1mod config;
2
3use arc_swap::ArcSwap;
4use std::{
5 fmt,
6 ops::{Deref, DerefMut},
7 sync::{
8 atomic::{AtomicUsize, Ordering},
9 Arc,
10 },
11};
12use tracing::warn;
13
14use deadpool::managed::{self, Object};
15use metrics::counter;
16use rusqlite::{CachedStatement, InterruptHandle, Params, Transaction};
17use tokio::time::{sleep, Duration};
18use tokio_util::sync::{CancellationToken, DropGuard};
19
20pub use deadpool::managed::reexports::*;
21pub use rusqlite;
22
/// deadpool managed pool whose objects are produced by [`Manager<T>`].
pub type Pool<T> = deadpool::managed::Pool<Manager<T>>;
/// [`Pool`] specialised to plain [`rusqlite::Connection`] objects.
pub type RusqlitePool = Pool<rusqlite::Connection>;
/// deadpool's `CreatePoolError` parameterised with this crate's [`ConfigError`].
pub type CreatePoolError = deadpool::managed::CreatePoolError<ConfigError>;
/// deadpool's `PoolBuilder` for pools driven by [`Manager<T>`].
pub type PoolBuilder<T> = deadpool::managed::PoolBuilder<Manager<T>, Object<Manager<T>>>;
/// deadpool's `PoolError` with [`rusqlite::Error`] as the backend error type.
pub type PoolError = deadpool::managed::PoolError<rusqlite::Error>;

/// deadpool's `Hook` type for pools driven by [`Manager<T>`].
pub type Hook<T> = deadpool::managed::Hook<Manager<T>>;
/// deadpool's `HookError` with [`rusqlite::Error`] as the backend error type.
pub type HookError = deadpool::managed::HookError<rusqlite::Error>;

/// Pooled object wrapper (`deadpool::managed::Object`) around a `T` managed by [`Manager<T>`].
pub type Connection<T> = deadpool::managed::Object<Manager<T>>;
/// [`Connection`] specialised to plain [`rusqlite::Connection`] objects.
pub type RusqliteConnection = Connection<rusqlite::Connection>;
34
35#[inline]
36pub fn noop_transform(conn: rusqlite::Connection) -> rusqlite::Result<rusqlite::Connection> {
37 Ok(conn)
38}
39
40pub use self::config::{Config, ConfigError};
41
/// Unsized callable type that turns a raw [`rusqlite::Connection`] into a `T`
/// (or fails with a [`rusqlite::Error`]); must be `Send + Sync`.
pub type TransformFn<T> = dyn Fn(rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + Sync;
43
/// Pool manager that opens SQLite connections and converts each one into a `T`.
pub struct Manager<T> {
    /// Configuration; its `path` and `open_flags` are used when opening connections.
    config: Config,
    /// Incremented once per `recycle` call.
    recycle_count: AtomicUsize,
    /// Applied to every newly opened connection to produce the pooled `T`.
    transform: Box<TransformFn<T>>,
}
52
53impl<T> Manager<T> {
54 #[must_use]
57 pub fn from_config(
58 config: &Config,
59 transform: impl Fn(rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
60 ) -> Self {
61 Self {
62 config: config.clone(),
63 recycle_count: AtomicUsize::new(0),
64 transform: Box::new(transform),
65 }
66 }
67}
68
69impl<T> fmt::Debug for Manager<T> {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("Manager")
72 .field("config", &self.config)
73 .field("recycle_count", &self.recycle_count)
74 .finish()
75 }
76}
77
/// A `Send` value that can expose an underlying [`rusqlite::Connection`].
///
/// This is the bound required of `T` for [`Manager<T>`] to act as a pool manager.
pub trait SqliteConn: Send {
    /// Borrows the underlying SQLite connection.
    fn conn(&self) -> &rusqlite::Connection;
}

/// A plain connection is trivially its own underlying connection.
impl SqliteConn for rusqlite::Connection {
    fn conn(&self) -> &rusqlite::Connection {
        self
    }
}
87
88impl<T> managed::Manager for Manager<T>
89where
90 T: SqliteConn,
91{
92 type Type = T;
93 type Error = rusqlite::Error;
94
95 async fn create(&self) -> Result<Self::Type, Self::Error> {
96 (self.transform)(rusqlite::Connection::open_with_flags(
97 &self.config.path,
98 self.config.open_flags,
99 )?)
100 }
101
102 async fn recycle(
103 &self,
104 _conn: &mut Self::Type,
105 _: &Metrics,
106 ) -> managed::RecycleResult<Self::Error> {
107 let _ = self.recycle_count.fetch_add(1, Ordering::Relaxed);
108 Ok(())
109 }
110}
111
/// Cloneable bundle of everything needed to interrupt a SQLite call that
/// exceeds a timeout.
#[derive(Clone)]
pub struct InterruptHandler {
    /// Handle on which `interrupt()` is called when the timeout elapses.
    interrupt_hdl: Arc<InterruptHandle>,
    /// Shared slot holding the most recently recorded SQL text; it is printed
    /// in the timeout warning.
    current_sql: Arc<ArcSwap<Option<String>>>,
    /// Time allowed per guarded call; `None` means no timeout task is spawned.
    timeout: Option<Duration>,
    /// Value of the `source` label on the interrupt counter metric.
    source: &'static str,
}
119
120impl InterruptHandler {
121 pub fn new(
122 interrupt_hdl: Arc<InterruptHandle>,
123 current_sql: Arc<ArcSwap<Option<String>>>,
124 timeout: Option<Duration>,
125 source: &'static str,
126 ) -> Self {
127 Self {
128 interrupt_hdl,
129 current_sql,
130 timeout,
131 source,
132 }
133 }
134
135 fn timeout_guard(&self) -> DropGuard {
136 let cancel_token = CancellationToken::new();
137
138 if let Some(timeout) = self.timeout {
139 let cloned_token = cancel_token.clone();
140 let interrupt_hdl = self.interrupt_hdl.clone();
141 let current_sql = self.current_sql.clone();
142 let source = self.source;
143 tokio::spawn(async move {
144 tokio::select! {
145 _ = cloned_token.cancelled() => {}
146 _ = sleep(timeout) => {
147 warn!("sql call took more than {timeout:?}, interrupting.. {:?}", current_sql);
148 interrupt_hdl.interrupt();
149 counter!("corro.sqlite.interrupt", "source" => source, "reason" => "timeout").increment(1);
150 }
151 }
152 });
153 }
154
155 cancel_token.drop_guard()
156 }
157}
158
/// Wrapper around a connection-like value `T` whose statement-running
/// methods are guarded by an [`InterruptHandler`] timeout.
pub struct InterruptibleTransaction<T> {
    /// The wrapped connection, transaction or savepoint.
    conn: T,
    /// Handler used to arm the per-call timeout guard.
    int_hdlr: InterruptHandler,
    /// Slot where the SQL text passed to the methods below is recorded.
    current_sql: Arc<ArcSwap<Option<String>>>,
}
164
165impl<T> InterruptibleTransaction<T>
166where
167 T: Deref<Target = rusqlite::Connection>,
168{
169 pub fn new(conn: T, timeout: Option<Duration>, source: &'static str) -> Self {
170 let interrupt_hdl = Arc::new(conn.get_interrupt_handle());
171 let query_store: Arc<ArcSwap<Option<String>>> = Arc::new(ArcSwap::new(Arc::new(None)));
172 let int_hdlr = InterruptHandler::new(interrupt_hdl, query_store.clone(), timeout, source);
173 Self {
174 conn,
175 int_hdlr,
176 current_sql: query_store,
177 }
178 }
179
180 pub fn new_with_hdlr(
181 conn: T,
182 query_store: Arc<ArcSwap<Option<String>>>,
183 int_hdlr: InterruptHandler,
184 ) -> Self {
185 Self {
186 conn,
187 int_hdlr,
188 current_sql: query_store,
189 }
190 }
191
192 pub fn execute(
193 &self,
194 sql: &str,
195 params: &[&dyn rusqlite::ToSql],
196 ) -> Result<usize, rusqlite::Error> {
197 let _guard = self.int_hdlr.timeout_guard();
198 self.current_sql.store(Arc::new(Some(sql.to_string())));
199 self.conn.execute(sql, params)
200 }
201
202 pub fn prepare(
203 &self,
204 sql: &str,
205 ) -> Result<InterruptibleStatement<Statement<'_>>, rusqlite::Error> {
206 let stmt = self.conn.prepare(sql)?;
207 self.current_sql.store(Arc::new(Some(sql.to_string())));
208 Ok(InterruptibleStatement::new(
209 Statement(stmt),
210 self.int_hdlr.clone(),
211 ))
212 }
213
214 pub fn prepare_cached(
215 &self,
216 sql: &str,
217 ) -> Result<InterruptibleStatement<CachedStatement<'_>>, rusqlite::Error> {
218 let stmt = self.conn.prepare_cached(sql)?;
219 self.current_sql.store(Arc::new(Some(sql.to_string())));
220 Ok(InterruptibleStatement::new(stmt, self.int_hdlr.clone()))
221 }
222
223 pub fn execute_batch(&self, sql: &str) -> Result<(), rusqlite::Error> {
224 let _guard = self.int_hdlr.timeout_guard();
225 self.current_sql.store(Arc::new(Some(sql.to_string())));
226 self.conn.execute_batch(sql)
227 }
228}
229
230impl<T> InterruptibleTransaction<T>
231where
232 T: Deref<Target = rusqlite::Connection> + Committable,
233{
234 pub fn commit(self) -> Result<(), rusqlite::Error> {
235 let _guard = self.int_hdlr.timeout_guard();
236 self.conn.commit()
237 }
238
239 pub fn savepoint(
240 &mut self,
241 ) -> Result<InterruptibleTransaction<rusqlite::Savepoint<'_>>, rusqlite::Error> {
242 let sp = self.conn.savepoint()?;
243 Ok(InterruptibleTransaction::new_with_hdlr(
244 sp,
245 self.current_sql.clone(),
246 self.int_hdlr.clone(),
247 ))
248 }
249}
250
251impl<T> Deref for InterruptibleTransaction<T>
252where
253 T: Deref<Target = rusqlite::Connection>,
254{
255 type Target = rusqlite::Connection;
256
257 fn deref(&self) -> &Self::Target {
258 &self.conn
259 }
260}
261
262impl<T> DerefMut for InterruptibleTransaction<T>
263where
264 T: DerefMut<Target = rusqlite::Connection>,
265{
266 fn deref_mut(&mut self) -> &mut Self::Target {
267 &mut self.conn
268 }
269}
270
/// A prepared statement paired with the [`InterruptHandler`] that guards its
/// execution.
pub struct InterruptibleStatement<T> {
    /// The wrapped statement (anything dereferencing to `rusqlite::Statement`).
    stmt: T,
    /// Handler used to arm the per-call timeout guard.
    int_hdlr: InterruptHandler,
}
275
276impl<'conn, 'a, T> InterruptibleStatement<T>
277where
278 T: Deref<Target = rusqlite::Statement<'conn>> + DerefMut<Target = rusqlite::Statement<'conn>>,
279{
280 pub fn new(stmt: T, int_hdlr: InterruptHandler) -> Self {
281 Self { stmt, int_hdlr }
282 }
283
284 pub fn execute<P: Params>(&mut self, params: P) -> Result<usize, rusqlite::Error> {
285 let _guard = self.int_hdlr.timeout_guard();
286 self.stmt.execute(params)
287 }
288
289 pub fn query<'rows, P: Params>(
290 &'a mut self,
291 params: P,
292 ) -> Result<InterruptibleRows<'rows>, rusqlite::Error>
293 where
294 'conn: 'rows,
295 'a: 'rows,
296 {
297 let _guard = self.int_hdlr.timeout_guard();
298 let rows = self.stmt.query(params)?;
299 Ok(InterruptibleRows::new(rows, self.int_hdlr.clone()))
300 }
301
302 pub fn query_map<P: Params, S, F>(
303 &'a mut self,
304 params: P,
305 f: F,
306 ) -> rusqlite::Result<InterruptibleMappedRows<'a, F>>
307 where
308 F: FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<S>,
309 'conn: 'a,
310 {
311 let _guard = self.int_hdlr.timeout_guard();
312 let mapped_rows = self.stmt.query_map(params, f)?;
313 Ok(InterruptibleMappedRows::new(
314 mapped_rows,
315 self.int_hdlr.clone(),
316 ))
317 }
318}
319
320impl<'conn, T: Deref<Target = rusqlite::Statement<'conn>>> Deref for InterruptibleStatement<T> {
321 type Target = rusqlite::Statement<'conn>;
322
323 fn deref(&self) -> &Self::Target {
324 &self.stmt
325 }
326}
327
328impl<'conn, T: DerefMut<Target = rusqlite::Statement<'conn>>> DerefMut
329 for InterruptibleStatement<T>
330{
331 fn deref_mut(&mut self) -> &mut Self::Target {
332 &mut self.stmt
333 }
334}
335
/// Operations shared by values that can be committed and can open nested
/// savepoints (implemented below for `Transaction` and `Savepoint`).
pub trait Committable {
    /// Consumes the value and commits it.
    fn commit(self) -> Result<(), rusqlite::Error>;
    /// Opens a savepoint borrowed from this value.
    fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error>;
}
340
341impl Committable for Transaction<'_> {
342 fn commit(self) -> Result<(), rusqlite::Error> {
343 self.commit()
344 }
345
346 fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error> {
347 self.savepoint()
348 }
349}
350
351impl Committable for rusqlite::Savepoint<'_> {
352 fn commit(self) -> Result<(), rusqlite::Error> {
353 self.commit()
354 }
355
356 fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error> {
357 self.savepoint()
358 }
359}
360
/// Newtype around [`rusqlite::Statement`]; dereferences to the inner
/// statement so it can be wrapped by `InterruptibleStatement`.
pub struct Statement<'conn>(pub rusqlite::Statement<'conn>);
362
363impl<'conn> Deref for Statement<'conn> {
364 type Target = rusqlite::Statement<'conn>;
365
366 fn deref(&self) -> &Self::Target {
367 &self.0
368 }
369}
370
371impl<'conn> DerefMut for Statement<'conn> {
372 fn deref_mut(&mut self) -> &mut Self::Target {
373 &mut self.0
374 }
375}
376
/// Iterator over mapped rows in which every `next` call runs under an
/// [`InterruptHandler`] timeout guard.
pub struct InterruptibleMappedRows<'a, F> {
    /// The underlying rusqlite mapped-row iterator.
    rows: rusqlite::MappedRows<'a, F>,
    /// Handler used to arm the per-call timeout guard.
    int_hdlr: InterruptHandler,
}
381
382impl<'a, F> InterruptibleMappedRows<'a, F> {
383 pub fn new(rows: rusqlite::MappedRows<'a, F>, int_hdlr: InterruptHandler) -> Self {
384 Self { rows, int_hdlr }
385 }
386}
387
388impl<'a, F, T> Iterator for InterruptibleMappedRows<'a, F>
389where
390 F: FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
391{
392 type Item = rusqlite::Result<T>;
393
394 fn next(&mut self) -> Option<Self::Item> {
395 let _guard = self.int_hdlr.timeout_guard();
396 self.rows.next()
397 }
398}
399
/// Row cursor in which every `next` call runs under an [`InterruptHandler`]
/// timeout guard.
pub struct InterruptibleRows<'stmt> {
    /// The underlying rusqlite row cursor.
    rows: rusqlite::Rows<'stmt>,
    /// Handler used to arm the per-call timeout guard.
    int_hdlr: InterruptHandler,
}
404
405impl<'stmt> InterruptibleRows<'stmt> {
406 pub fn new(rows: rusqlite::Rows<'stmt>, int_hdlr: InterruptHandler) -> Self {
407 Self { rows, int_hdlr }
408 }
409}
410
411impl<'stmt> InterruptibleRows<'stmt> {
412 #[allow(clippy::should_implement_trait)]
413 pub fn next(&mut self) -> Result<Option<&rusqlite::Row<'stmt>>, rusqlite::Error> {
414 let _guard = self.int_hdlr.timeout_guard();
415 self.rows.next()
416 }
417}
418
// Test module placeholder: currently contains no tests.
#[cfg(test)]
mod tests {}