// sqlite_pool/lib.rs

1mod config;
2
3use arc_swap::ArcSwap;
4use std::{
5    fmt,
6    ops::{Deref, DerefMut},
7    sync::{
8        atomic::{AtomicUsize, Ordering},
9        Arc,
10    },
11};
12use tracing::warn;
13
14use deadpool::managed::{self, Object};
15use metrics::counter;
16use rusqlite::{CachedStatement, InterruptHandle, Params, Transaction};
17use tokio::time::{sleep, Duration};
18use tokio_util::sync::{CancellationToken, DropGuard};
19
20pub use deadpool::managed::reexports::*;
21pub use rusqlite;
22
23pub type Pool<T> = deadpool::managed::Pool<Manager<T>>;
24pub type RusqlitePool = Pool<rusqlite::Connection>;
25pub type CreatePoolError = deadpool::managed::CreatePoolError<ConfigError>;
26pub type PoolBuilder<T> = deadpool::managed::PoolBuilder<Manager<T>, Object<Manager<T>>>;
27pub type PoolError = deadpool::managed::PoolError<rusqlite::Error>;
28
29pub type Hook<T> = deadpool::managed::Hook<Manager<T>>;
30pub type HookError = deadpool::managed::HookError<rusqlite::Error>;
31
32pub type Connection<T> = deadpool::managed::Object<Manager<T>>;
33pub type RusqliteConnection = Connection<rusqlite::Connection>;
34
35#[inline]
36pub fn noop_transform(conn: rusqlite::Connection) -> rusqlite::Result<rusqlite::Connection> {
37    Ok(conn)
38}
39
40pub use self::config::{Config, ConfigError};
41
42pub type TransformFn<T> = dyn Fn(rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + Sync;
43
44/// [`Manager`] for creating and recycling SQLite [`Connection`]s.
45///
46/// [`Manager`]: managed::Manager
47pub struct Manager<T> {
48    config: Config,
49    recycle_count: AtomicUsize,
50    transform: Box<TransformFn<T>>,
51}
52
53impl<T> Manager<T> {
54    /// Creates a new [`Manager`] using the given [`Config`] backed by the
55    /// specified [`Runtime`].
56    #[must_use]
57    pub fn from_config(
58        config: &Config,
59        transform: impl Fn(rusqlite::Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
60    ) -> Self {
61        Self {
62            config: config.clone(),
63            recycle_count: AtomicUsize::new(0),
64            transform: Box::new(transform),
65        }
66    }
67}
68
69impl<T> fmt::Debug for Manager<T> {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("Manager")
72            .field("config", &self.config)
73            .field("recycle_count", &self.recycle_count)
74            .finish()
75    }
76}
77
78pub trait SqliteConn: Send {
79    fn conn(&self) -> &rusqlite::Connection;
80}
81
82impl SqliteConn for rusqlite::Connection {
83    fn conn(&self) -> &rusqlite::Connection {
84        self
85    }
86}
87
88impl<T> managed::Manager for Manager<T>
89where
90    T: SqliteConn,
91{
92    type Type = T;
93    type Error = rusqlite::Error;
94
95    async fn create(&self) -> Result<Self::Type, Self::Error> {
96        (self.transform)(rusqlite::Connection::open_with_flags(
97            &self.config.path,
98            self.config.open_flags,
99        )?)
100    }
101
102    async fn recycle(
103        &self,
104        _conn: &mut Self::Type,
105        _: &Metrics,
106    ) -> managed::RecycleResult<Self::Error> {
107        let _ = self.recycle_count.fetch_add(1, Ordering::Relaxed);
108        Ok(())
109    }
110}
111
112#[derive(Clone)]
113pub struct InterruptHandler {
114    interrupt_hdl: Arc<InterruptHandle>,
115    current_sql: Arc<ArcSwap<Option<String>>>,
116    timeout: Option<Duration>,
117    source: &'static str,
118}
119
120impl InterruptHandler {
121    pub fn new(
122        interrupt_hdl: Arc<InterruptHandle>,
123        current_sql: Arc<ArcSwap<Option<String>>>,
124        timeout: Option<Duration>,
125        source: &'static str,
126    ) -> Self {
127        Self {
128            interrupt_hdl,
129            current_sql,
130            timeout,
131            source,
132        }
133    }
134
135    fn timeout_guard(&self) -> DropGuard {
136        let cancel_token = CancellationToken::new();
137
138        if let Some(timeout) = self.timeout {
139            let cloned_token = cancel_token.clone();
140            let interrupt_hdl = self.interrupt_hdl.clone();
141            let current_sql = self.current_sql.clone();
142            let source = self.source;
143            tokio::spawn(async move {
144                tokio::select! {
145                    _ = cloned_token.cancelled() => {}
146                    _ = sleep(timeout) => {
147                        warn!("sql call took more than {timeout:?}, interrupting.. {:?}", current_sql);
148                        interrupt_hdl.interrupt();
149                        counter!("corro.sqlite.interrupt", "source" => source, "reason" => "timeout").increment(1);
150                    }
151                }
152            });
153        }
154
155        cancel_token.drop_guard()
156    }
157}
158
159pub struct InterruptibleTransaction<T> {
160    conn: T,
161    int_hdlr: InterruptHandler,
162    current_sql: Arc<ArcSwap<Option<String>>>,
163}
164
165impl<T> InterruptibleTransaction<T>
166where
167    T: Deref<Target = rusqlite::Connection>,
168{
169    pub fn new(conn: T, timeout: Option<Duration>, source: &'static str) -> Self {
170        let interrupt_hdl = Arc::new(conn.get_interrupt_handle());
171        let query_store: Arc<ArcSwap<Option<String>>> = Arc::new(ArcSwap::new(Arc::new(None)));
172        let int_hdlr = InterruptHandler::new(interrupt_hdl, query_store.clone(), timeout, source);
173        Self {
174            conn,
175            int_hdlr,
176            current_sql: query_store,
177        }
178    }
179
180    pub fn new_with_hdlr(
181        conn: T,
182        query_store: Arc<ArcSwap<Option<String>>>,
183        int_hdlr: InterruptHandler,
184    ) -> Self {
185        Self {
186            conn,
187            int_hdlr,
188            current_sql: query_store,
189        }
190    }
191
192    pub fn execute(
193        &self,
194        sql: &str,
195        params: &[&dyn rusqlite::ToSql],
196    ) -> Result<usize, rusqlite::Error> {
197        let _guard = self.int_hdlr.timeout_guard();
198        self.current_sql.store(Arc::new(Some(sql.to_string())));
199        self.conn.execute(sql, params)
200    }
201
202    pub fn prepare(
203        &self,
204        sql: &str,
205    ) -> Result<InterruptibleStatement<Statement<'_>>, rusqlite::Error> {
206        let stmt = self.conn.prepare(sql)?;
207        self.current_sql.store(Arc::new(Some(sql.to_string())));
208        Ok(InterruptibleStatement::new(
209            Statement(stmt),
210            self.int_hdlr.clone(),
211        ))
212    }
213
214    pub fn prepare_cached(
215        &self,
216        sql: &str,
217    ) -> Result<InterruptibleStatement<CachedStatement<'_>>, rusqlite::Error> {
218        let stmt = self.conn.prepare_cached(sql)?;
219        self.current_sql.store(Arc::new(Some(sql.to_string())));
220        Ok(InterruptibleStatement::new(stmt, self.int_hdlr.clone()))
221    }
222
223    pub fn execute_batch(&self, sql: &str) -> Result<(), rusqlite::Error> {
224        let _guard = self.int_hdlr.timeout_guard();
225        self.current_sql.store(Arc::new(Some(sql.to_string())));
226        self.conn.execute_batch(sql)
227    }
228}
229
230impl<T> InterruptibleTransaction<T>
231where
232    T: Deref<Target = rusqlite::Connection> + Committable,
233{
234    pub fn commit(self) -> Result<(), rusqlite::Error> {
235        let _guard = self.int_hdlr.timeout_guard();
236        self.conn.commit()
237    }
238
239    pub fn savepoint(
240        &mut self,
241    ) -> Result<InterruptibleTransaction<rusqlite::Savepoint<'_>>, rusqlite::Error> {
242        let sp = self.conn.savepoint()?;
243        Ok(InterruptibleTransaction::new_with_hdlr(
244            sp,
245            self.current_sql.clone(),
246            self.int_hdlr.clone(),
247        ))
248    }
249}
250
251impl<T> Deref for InterruptibleTransaction<T>
252where
253    T: Deref<Target = rusqlite::Connection>,
254{
255    type Target = rusqlite::Connection;
256
257    fn deref(&self) -> &Self::Target {
258        &self.conn
259    }
260}
261
262impl<T> DerefMut for InterruptibleTransaction<T>
263where
264    T: DerefMut<Target = rusqlite::Connection>,
265{
266    fn deref_mut(&mut self) -> &mut Self::Target {
267        &mut self.conn
268    }
269}
270
271pub struct InterruptibleStatement<T> {
272    stmt: T,
273    int_hdlr: InterruptHandler,
274}
275
276impl<'conn, 'a, T> InterruptibleStatement<T>
277where
278    T: Deref<Target = rusqlite::Statement<'conn>> + DerefMut<Target = rusqlite::Statement<'conn>>,
279{
280    pub fn new(stmt: T, int_hdlr: InterruptHandler) -> Self {
281        Self { stmt, int_hdlr }
282    }
283
284    pub fn execute<P: Params>(&mut self, params: P) -> Result<usize, rusqlite::Error> {
285        let _guard = self.int_hdlr.timeout_guard();
286        self.stmt.execute(params)
287    }
288
289    pub fn query<'rows, P: Params>(
290        &'a mut self,
291        params: P,
292    ) -> Result<InterruptibleRows<'rows>, rusqlite::Error>
293    where
294        'conn: 'rows,
295        'a: 'rows,
296    {
297        let _guard = self.int_hdlr.timeout_guard();
298        let rows = self.stmt.query(params)?;
299        Ok(InterruptibleRows::new(rows, self.int_hdlr.clone()))
300    }
301
302    pub fn query_map<P: Params, S, F>(
303        &'a mut self,
304        params: P,
305        f: F,
306    ) -> rusqlite::Result<InterruptibleMappedRows<'a, F>>
307    where
308        F: FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<S>,
309        'conn: 'a,
310    {
311        let _guard = self.int_hdlr.timeout_guard();
312        let mapped_rows = self.stmt.query_map(params, f)?;
313        Ok(InterruptibleMappedRows::new(
314            mapped_rows,
315            self.int_hdlr.clone(),
316        ))
317    }
318}
319
320impl<'conn, T: Deref<Target = rusqlite::Statement<'conn>>> Deref for InterruptibleStatement<T> {
321    type Target = rusqlite::Statement<'conn>;
322
323    fn deref(&self) -> &Self::Target {
324        &self.stmt
325    }
326}
327
328impl<'conn, T: DerefMut<Target = rusqlite::Statement<'conn>>> DerefMut
329    for InterruptibleStatement<T>
330{
331    fn deref_mut(&mut self) -> &mut Self::Target {
332        &mut self.stmt
333    }
334}
335
336pub trait Committable {
337    fn commit(self) -> Result<(), rusqlite::Error>;
338    fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error>;
339}
340
341impl Committable for Transaction<'_> {
342    fn commit(self) -> Result<(), rusqlite::Error> {
343        self.commit()
344    }
345
346    fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error> {
347        self.savepoint()
348    }
349}
350
351impl Committable for rusqlite::Savepoint<'_> {
352    fn commit(self) -> Result<(), rusqlite::Error> {
353        self.commit()
354    }
355
356    fn savepoint(&mut self) -> Result<rusqlite::Savepoint<'_>, rusqlite::Error> {
357        self.savepoint()
358    }
359}
360
361pub struct Statement<'conn>(pub rusqlite::Statement<'conn>);
362
363impl<'conn> Deref for Statement<'conn> {
364    type Target = rusqlite::Statement<'conn>;
365
366    fn deref(&self) -> &Self::Target {
367        &self.0
368    }
369}
370
371impl<'conn> DerefMut for Statement<'conn> {
372    fn deref_mut(&mut self) -> &mut Self::Target {
373        &mut self.0
374    }
375}
376
377pub struct InterruptibleMappedRows<'a, F> {
378    rows: rusqlite::MappedRows<'a, F>,
379    int_hdlr: InterruptHandler,
380}
381
382impl<'a, F> InterruptibleMappedRows<'a, F> {
383    pub fn new(rows: rusqlite::MappedRows<'a, F>, int_hdlr: InterruptHandler) -> Self {
384        Self { rows, int_hdlr }
385    }
386}
387
388impl<'a, F, T> Iterator for InterruptibleMappedRows<'a, F>
389where
390    F: FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
391{
392    type Item = rusqlite::Result<T>;
393
394    fn next(&mut self) -> Option<Self::Item> {
395        let _guard = self.int_hdlr.timeout_guard();
396        self.rows.next()
397    }
398}
399
400pub struct InterruptibleRows<'stmt> {
401    rows: rusqlite::Rows<'stmt>,
402    int_hdlr: InterruptHandler,
403}
404
405impl<'stmt> InterruptibleRows<'stmt> {
406    pub fn new(rows: rusqlite::Rows<'stmt>, int_hdlr: InterruptHandler) -> Self {
407        Self { rows, int_hdlr }
408    }
409}
410
411impl<'stmt> InterruptibleRows<'stmt> {
412    #[allow(clippy::should_implement_trait)]
413    pub fn next(&mut self) -> Result<Option<&rusqlite::Row<'stmt>>, rusqlite::Error> {
414        let _guard = self.int_hdlr.timeout_guard();
415        self.rows.next()
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// The identity transform must hand back a usable connection.
    #[test]
    fn noop_transform_returns_usable_connection() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let conn = noop_transform(conn).unwrap();
        let one: i64 = conn.query_row("SELECT 1", [], |row| row.get(0)).unwrap();
        assert_eq!(one, 1);
    }

    /// Without a timeout no tokio task is spawned, so the guarded API works in a plain test.
    #[test]
    fn guarded_calls_without_timeout() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let tx = InterruptibleTransaction::new(&conn, None, "test");
        tx.execute_batch("CREATE TABLE t (x INTEGER); INSERT INTO t VALUES (1), (2);")
            .unwrap();
        let changed = tx
            .execute("INSERT INTO t VALUES (?1)", rusqlite::params![3i64])
            .unwrap();
        assert_eq!(changed, 1);

        let mut stmt = tx.prepare("SELECT x FROM t ORDER BY x").unwrap();
        let xs: Vec<i64> = stmt
            .query_map([], |row| row.get(0))
            .unwrap()
            .collect::<Result<_, _>>()
            .unwrap();
        assert_eq!(xs, vec![1, 2, 3]);
    }

    /// Committing through the wrapper persists the transaction's writes.
    #[test]
    fn commit_persists_changes() {
        let mut conn = rusqlite::Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE t (x INTEGER)").unwrap();
        let tx = InterruptibleTransaction::new(conn.transaction().unwrap(), None, "test");
        tx.execute("INSERT INTO t VALUES (?1)", rusqlite::params![7i64])
            .unwrap();
        tx.commit().unwrap();
        let x: i64 = conn.query_row("SELECT x FROM t", [], |row| row.get(0)).unwrap();
        assert_eq!(x, 7);
    }
}