sqlite_watcher/
statement.rs

1use crate::connection::{SqlExecutor, SqlExecutorAsync, SqlExecutorMut};
2use std::future::Future;
3use tracing::{Instrument, Span, error};
4
/// Marker trait used to seal [`Statement`]: it is only visible to the parent module
/// (`pub(super)`), and [`Statement`] lists it as a supertrait.
pub(super) trait Sealed {}
6
7/// Basic abstraction that defers the execution of a SQL statement in order to reduce the duplication
8/// of sync and async code. Basic composability and chaining are also included.
9///
10/// The `Send` requirement is in theory not required for sync implementations, but this is not
11/// intended to be used outside of the scope of this crate.
12#[allow(private_bounds)]
13pub trait Statement: Send + Sealed {
14    /// Output of this statement.
15    type Output: Send;
16
17    /// Execute the statement and return the result.
18    ///
19    /// # Errors
20    ///
21    /// If the statement fails, return error.
22    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error>;
23
24    /// Execute the statement and return the result.
25    ///
26    /// # Errors
27    ///
28    /// If the statement fails, return error.
29    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error>;
30
31    /// Execute the statement and return the result.
32    ///
33    /// # Errors
34    ///
35    /// If the statement fails, return error.
36    fn execute_async<S: SqlExecutorAsync>(
37        self,
38        connection: &mut S,
39    ) -> impl Future<Output = Result<Self::Output, S::Error>> + Send;
40
41    /// If this statement succeeds, then execute the next `statement`.
42    fn then<Q: Statement>(self, statement: Q) -> Then<Self, Q>
43    where
44        Self: Sized,
45    {
46        Then {
47            a: self,
48            b: statement,
49        }
50    }
51
52    /// if the current statement succeeds, then execute the next `statement` with output of the
53    /// current [`Statement`].
54    fn pipe<Q: StatementWithInput<Input = Self::Output> + Send>(self, statement: Q) -> Pipe<Self, Q>
55    where
56        Self: Sized,
57    {
58        Pipe {
59            a: self,
60            b: statement,
61        }
62    }
63
64    /// Instrument the current statement with the given [`Span`]
65    fn spanned(self, span: Span) -> TracedStatement<Self>
66    where
67        Self: Sized,
68    {
69        TracedStatement::new(self, span)
70    }
71
72    /// Instrument the current statement with currently active [`Span`]
73    fn spanned_in_current(self) -> TracedStatement<Self>
74    where
75        Self: Sized,
76    {
77        TracedStatement::current(self)
78    }
79}
80
81/// Similar to [`Statement`] but accepts an input parameter.
82///
83/// This statement is intended to be used with [`Statement::pipe`].
84pub trait StatementWithInput: Send {
85    /// Input for the statement.
86    type Input: Send;
87    /// Output of the statement.
88    type Output: Send;
89
90    /// Execute the statement with the given `input` and return the result.
91    ///
92    /// # Errors
93    ///
94    /// If the statement fails, return error.
95    fn execute<S: SqlExecutor>(
96        self,
97        connection: &S,
98        input: Self::Input,
99    ) -> Result<Self::Output, S::Error>;
100
101    /// Execute the statement with the given `input` and return the result.
102    ///
103    /// # Errors
104    ///
105    /// If the statement fails, return error.
106    fn execute_mut<S: SqlExecutorMut>(
107        self,
108        connection: &mut S,
109        input: Self::Input,
110    ) -> Result<Self::Output, S::Error>;
111
112    /// Execute the statement with the given `input` and return the result.
113    ///
114    /// # Errors
115    ///
116    /// If the statement fails, return error.
117    fn execute_async<S: SqlExecutorAsync>(
118        self,
119        connection: &mut S,
120        input: Self::Input,
121    ) -> impl Future<Output = Result<Self::Output, S::Error>> + Send;
122}
123
124/// Link two [`Statement`]s.
125///
126/// Statement `B` is only executed if `A` fails.
127pub struct Then<A: Statement, B: Statement> {
128    a: A,
129    b: B,
130}
131impl<A: Statement, B: Statement> Sealed for Then<A, B> {}
132
133impl<A: Statement + Send, B: Statement + Send> Statement for Then<A, B> {
134    type Output = B::Output;
135
136    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
137        self.a.execute(connection)?;
138        self.b.execute(connection)
139    }
140
141    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
142        self.a.execute_mut(connection)?;
143        self.b.execute_mut(connection)
144    }
145
146    async fn execute_async<S: SqlExecutorAsync>(
147        self,
148        connection: &mut S,
149    ) -> Result<Self::Output, S::Error> {
150        self.a.execute_async(connection).await?;
151        self.b.execute_async(connection).await
152    }
153}
154
155/// Link two [`Statement`]s and use the output of `A` as the input of `B`.
156pub struct Pipe<A: Statement + Send, B: StatementWithInput<Input = A::Output> + Send> {
157    a: A,
158    b: B,
159}
160
161impl<A: Statement, B: StatementWithInput<Input = A::Output>> Sealed for Pipe<A, B> {}
162impl<A: Statement + Send, B: StatementWithInput<Input = A::Output> + Send> Statement
163    for Pipe<A, B>
164{
165    type Output = B::Output;
166
167    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
168        let output = self.a.execute(connection)?;
169        self.b.execute(connection, output)
170    }
171
172    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
173        let output = self.a.execute_mut(connection)?;
174        self.b.execute_mut(connection, output)
175    }
176
177    async fn execute_async<S: SqlExecutorAsync>(
178        self,
179        connection: &mut S,
180    ) -> Result<Self::Output, S::Error> {
181        let output = self.a.execute_async(connection).await?;
182        self.b.execute_async(connection, output).await
183    }
184}
185
186/// Execute an SQL statement which does not return any value.
187pub(super) struct SqlExecuteStatement<T: AsRef<str>> {
188    query: T,
189}
190
191impl<T: AsRef<str> + Send> SqlExecuteStatement<T> {
192    pub fn new(query: T) -> Self {
193        Self { query }
194    }
195}
196
197impl<T: AsRef<str> + Send> Sealed for SqlExecuteStatement<T> {}
198
199impl<T: AsRef<str> + Send> Statement for SqlExecuteStatement<T> {
200    type Output = ();
201
202    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
203        connection.sql_execute(self.query.as_ref())
204    }
205
206    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
207        connection.sql_execute(self.query.as_ref())
208    }
209
210    async fn execute_async<S: SqlExecutorAsync>(
211        self,
212        connection: &mut S,
213    ) -> Result<Self::Output, S::Error> {
214        connection.sql_execute(self.query.as_ref()).await
215    }
216}
217
/// Controls the transaction behavior.
///
/// The enum is a plain tag, so it derives the usual value-type traits; this allows it to be
/// copied, compared and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionMode {
    /// Locks only temporary tables. Note that this is not guaranteed if the transaction
    /// touches other tables that are not temporary.
    Temporary,
    /// Locks the full database.
    Full,
}
226
227/// Execute an SQL Transaction.
228pub(super) struct SqlTransactionStatement<Q: Statement> {
229    statement: Q,
230    mode: TransactionMode,
231}
232
233impl<Q: Statement> SqlTransactionStatement<Q> {
234    /// Create new transaction that only affects temporary tables.
235    pub fn temporary(statement: Q) -> Self {
236        Self {
237            statement,
238            mode: TransactionMode::Temporary,
239        }
240    }
241    /// Create a new transaction that affects all tables.
242    #[allow(dead_code)]
243    pub fn full(statement: Q) -> Self {
244        Self {
245            statement,
246            mode: TransactionMode::Full,
247        }
248    }
249
250    fn begin_statement(&self) -> &'static str {
251        match self.mode {
252            TransactionMode::Temporary => BEGIN_TRANSACTION_STATEMENT,
253            TransactionMode::Full => BEGIN_TRANSACTION_IMMEDIATE_STATEMENT,
254        }
255    }
256}
257
258impl<Q: Statement<Output = ()>> Sealed for SqlTransactionStatement<Q> {}
259
260impl<Q: Statement<Output = ()>> Statement for SqlTransactionStatement<Q> {
261    type Output = ();
262
263    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
264        connection
265            .sql_execute(self.begin_statement())
266            .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
267        if let Err(e) = self.statement.execute(connection) {
268            error!("Statement failed to execute: {e}");
269            if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT) {
270                error!("Failed to rollback transaction: {e}");
271            }
272            return Err(e);
273        }
274        connection
275            .sql_execute(COMMIT_TRANSACTION_STATEMENT)
276            .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
277        Ok(())
278    }
279
280    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
281        connection
282            .sql_execute(self.begin_statement())
283            .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
284        if let Err(e) = self.statement.execute_mut(connection) {
285            error!("Statement failed to execute: {e}");
286            if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT) {
287                error!("Failed to rollback transaction: {e}");
288            }
289            return Err(e);
290        }
291        connection
292            .sql_execute(COMMIT_TRANSACTION_STATEMENT)
293            .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
294        Ok(())
295    }
296    async fn execute_async<S: SqlExecutorAsync>(
297        self,
298        connection: &mut S,
299    ) -> Result<Self::Output, S::Error> {
300        connection
301            .sql_execute(self.begin_statement())
302            .await
303            .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
304        if let Err(e) = self.statement.execute_async(connection).await {
305            error!("Statement failed to execute: {e}");
306            if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT).await {
307                error!("Failed to rollback transaction: {e}");
308            }
309            return Err(e);
310        }
311        connection
312            .sql_execute(COMMIT_TRANSACTION_STATEMENT)
313            .await
314            .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
315        Ok(())
316    }
317}
318
319/// Execute a collections of [`Statement`].
320///
321/// Execution will halt on the first failed statement.
322pub(super) struct BatchQuery<Q: Statement>(Vec<Q>);
323
324impl<Q: Statement> BatchQuery<Q> {
325    pub fn new(v: impl IntoIterator<Item = Q>) -> Self {
326        Self(Vec::from_iter(v))
327    }
328
329    pub fn push(&mut self, q: Q) {
330        self.0.push(q);
331    }
332
333    pub fn extend<I: IntoIterator<Item = Q>>(&mut self, iter: I) {
334        self.0.extend(iter);
335    }
336}
337
338impl<Q: Statement<Output = ()>> Sealed for BatchQuery<Q> {}
339
340impl<Q: Statement<Output = ()>> Statement for BatchQuery<Q> {
341    type Output = ();
342
343    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
344        for q in self.0 {
345            q.execute(connection)?;
346        }
347        Ok(())
348    }
349
350    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
351        for q in self.0 {
352            q.execute_mut(connection)?;
353        }
354        Ok(())
355    }
356    async fn execute_async<S: SqlExecutorAsync>(
357        self,
358        connection: &mut S,
359    ) -> Result<Self::Output, S::Error> {
360        for q in self.0 {
361            q.execute_async(connection).await?;
362        }
363        Ok(())
364    }
365}
366
367impl<Q: Statement> Sealed for Option<Q> {}
368
369impl<Q: Statement> Statement for Option<Q> {
370    type Output = Option<Q::Output>;
371
372    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
373        Ok(match self {
374            Some(q) => Some(q.execute(connection)?),
375            None => None,
376        })
377    }
378
379    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
380        Ok(match self {
381            Some(q) => Some(q.execute_mut(connection)?),
382            None => None,
383        })
384    }
385
386    async fn execute_async<S: SqlExecutorAsync>(
387        self,
388        connection: &mut S,
389    ) -> Result<Self::Output, S::Error> {
390        Ok(match self {
391            Some(q) => Some(q.execute_async(connection).await?),
392            None => None,
393        })
394    }
395}
396
397pub struct TracedStatement<Q: Statement> {
398    statement: Q,
399    span: Span,
400}
401
402impl<Q: Statement> TracedStatement<Q> {
403    /// Create a new traced `span` for the `statement`.
404    pub fn new(statement: Q, span: Span) -> Self {
405        Self { statement, span }
406    }
407
408    /// Create a new traced span for the `statement` using the current active tracing span.
409    pub fn current(statement: Q) -> Self {
410        Self::new(statement, Span::current())
411    }
412}
413
414impl<Q: Statement> Sealed for TracedStatement<Q> {}
415
416impl<Q: Statement> Statement for TracedStatement<Q> {
417    type Output = Q::Output;
418    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
419        let _span = self.span.entered();
420        self.statement.execute(connection)
421    }
422
423    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
424        let _span = self.span.entered();
425        self.statement.execute_mut(connection)
426    }
427
428    async fn execute_async<S: SqlExecutorAsync>(
429        self,
430        connection: &mut S,
431    ) -> Result<Self::Output, S::Error> {
432        self.statement
433            .execute_async(connection)
434            .instrument(self.span)
435            .await
436    }
437}
438
/// Opens the transaction when the mode is `TransactionMode::Temporary`.
const BEGIN_TRANSACTION_STATEMENT: &str = "BEGIN";
/// Opens the transaction when the mode is `TransactionMode::Full`.
const BEGIN_TRANSACTION_IMMEDIATE_STATEMENT: &str = "BEGIN IMMEDIATE";
/// Issued by `SqlTransactionStatement` after the wrapped statement succeeded.
const COMMIT_TRANSACTION_STATEMENT: &str = "COMMIT";
/// Issued by `SqlTransactionStatement` after the wrapped statement failed.
const ROLLBACK_TRANSACTION_STATEMENT: &str = "ROLLBACK";