// sqlx_utils/traits/repository/transactions/mod.rs

//! Transaction-related traits.
//!
//! There are many issues with lifetimes when using [`Transaction`] from [`sqlx`]. This is due to
//! the implementation in [`sqlx`] and, from my testing, can't be fixed without the inner traits
//! being implemented using [`async_trait`](https://docs.rs/async-trait). If this changes in the
//! future, I will fix this.
6
7use crate::{
8    mod_def,
9    traits::{Model, Repository},
10    types::Database,
11};
12use futures::future::try_join_all;
13use sqlx::{Error, Transaction};
14use std::future::Future;
15use std::sync::Arc;
16
// Declares the per-operation transaction submodules (insert, update, delete, save).
// NOTE(review): `!export` presumably makes `mod_def!` re-export each module's public items
// from this module — inferred from the flag name; confirm against the macro definition.
mod_def! {
    !export
    pub(crate) mod insert_tx;
    pub(crate) mod update_tx;
    pub(crate) mod delete_tx;
    pub(crate) mod save_tx;
}
24
25/// Extension trait for Repository to work with transactions
26///
27/// This trait adds transactions capabilities to any repository that implements
28/// the [`Repository`] trait. It provides several methods for executing operations
29/// within database transactions, with different strategies for concurrency and
30/// error handling.
31///
32/// The trait is automatically implemented for any type that implements [`Repository<M>`],
33/// making transactions capabilities available to all repositories without additional code.
34pub trait TransactionRepository<M>: Repository<M>
35where
36    M: Model,
37{
38    /// Executes a callback within a transactions, handling the transactions lifecycle automatically.
39    ///
40    /// This method:
41    /// 1. Begins a transactions from the repository's connection pool
42    /// 2. Passes the transactions to the callback function
43    /// 3. Waits for the callback to complete and return both a result and the transactions
44    /// 4. Commits the transactions if the result is `Ok`, or rolls it back if it's `Err`
45    /// 5. Returns the final result
46    ///
47    /// # Type Parameters
48    ///
49    /// * `F`: The type of the callback function [^func]
50    /// * `Fut`: The future type returned by the callback
51    /// * `R`: The result type
52    /// * `E`: The error type, which must be convertible from [`Error`]
53    ///
54    /// # Parameters
55    ///
56    /// * `callback`: A function that accepts a [`Transaction`] and returns a future
57    ///
58    /// # Returns
59    ///
60    /// A future that resolves to `Result<R, E>`.
61    ///
62    /// # Example
63    ///
64    /// ```no_compile
65    /// let result = repo.with_transaction(|mut tx| async move {
66    ///     let model = Model::new();
67    ///     let res = repo.save_with_executor(&mut tx, model).await;
68    ///     (res, tx)
69    /// }).await;
70    /// ```
71    ///
72    /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
73    ///    Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
74    fn with_transaction<'a, 'b, F, Fut, R, E>(
75        &'a self,
76        callback: F,
77    ) -> impl Future<Output = Result<R, E>> + Send + 'a
78    where
79        F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
80        Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
81        R: Send + 'a,
82        E: From<Error> + Send,
83    {
84        async move {
85            let transaction = self.pool().begin().await.map_err(E::from)?;
86
87            let (ret, tx) = callback(transaction).await;
88
89            match ret {
90                Ok(val) => {
91                    tx.commit().await.map_err(E::from)?;
92                    Ok(val)
93                }
94                Err(err) => {
95                    tx.rollback().await.map_err(E::from)?;
96                    Err(err)
97                }
98            }
99        }
100    }
101
102    /// Executes multiple operations sequentially in a transactions, stopping at the first error.
103    ///
104    /// This method provides an optimized approach for cases where you want to stop processing
105    /// as soon as any action fails, immediately rolling back the transactions.
106    ///
107    /// # Type Parameters
108    ///
109    /// * `I`: The iterator type
110    /// * `F`: The action function type [^func]
111    /// * `Fut`: The future type returned by each action
112    /// * `R`: The result type
113    /// * `E`: The error type, which must be convertible from [`Error`]
114    ///
115    /// # Parameters
116    ///
117    /// * `actions`: An iterator of functions that will be executed in the transactions
118    ///
119    /// # Returns
120    ///
121    /// A future that resolves to:
122    /// * `Ok(Vec<R>)`: A vector of results if all actions succeeded
123    /// * `Err(E)`: The first error encountered
124    ///
125    /// # Implementation Details
126    ///
127    /// 1. Begins a transactions from the repository's connection pool
128    /// 2. Executes each action sequentially, collecting results
129    /// 3. If any action fails, rolls back the transactions and returns the error
130    /// 4. If all actions succeed, commits the transactions and returns the results
131    ///
132    /// Due to complex lifetime bounds in underlying types we must take ownership and then return it
133    /// back.
134    ///
135    /// # Examples
136    ///
137    /// ## Basic
138    ///
139    /// ```no_compile
140    /// let results = repo.transaction_sequential([
141    ///     |tx| async move { repo.save_with_executor(tx, model1).await },
142    ///     |tx| async move { repo.save_with_executor(tx, model2).await }
143    /// ]).await;
144    /// ```
145    ///
146    /// ## Complete
147    ///
148    /// ```rust,should_panic
149    /// use sqlx::Transaction;
150    /// use sqlx_utils::prelude::*;
151    /// #
152    /// # repository! {
153    /// #     !zst
154    /// #     UserRepo<User>;
155    /// # }
156    ///
157    /// struct User {
158    ///     id: String,
159    ///     name: String
160    /// }
161    ///
162    /// impl Model for User {
163    ///     type Id = String;
164    ///
165    ///     fn get_id(&self) -> Option<Self::Id> {
166    ///         Some(self.id.to_owned())
167    ///     }
168    /// }
169    ///
170    /// #[derive(Debug)]
171    /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
172    ///     kind: Box<dyn std::error::Error + Send>
173    /// }
174    ///
175    /// impl From<sqlx::Error> for Error {
176    ///     fn from(value: sqlx::Error) -> Self {
177    ///         Self {
178    ///             kind: Box::new(value)
179    ///         }
180    ///     }
181    /// }
182    ///
183    /// async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<User, Error>, Transaction<'b, Database>) {
184    ///      unimplemented!()
185    ///  }
186    ///
187    /// # #[tokio::main]
188    /// # async fn main() {
189    /// USER_REPO.transaction_sequential([action, action, action]).await.unwrap();
190    /// # }
191    /// ```
192    ///
193    /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
194    ///    Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
195    fn transaction_sequential<'a, 'b, I, F, Fut, R, E>(
196        &'a self,
197        actions: I,
198    ) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
199    where
200        I: IntoIterator<Item = F> + Send + 'a,
201        I::IntoIter: Send + 'a,
202        F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
203        Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
204        R: Send + 'a,
205        E: From<Error> + Send + 'a,
206    {
207        async move {
208            let mut tx = self.pool().begin().await.map_err(E::from)?;
209            let mut results = Vec::new();
210
211            for action in actions {
212                let (result, new_tx) = action(tx).await;
213                tx = new_tx; // Get back ownership
214
215                match result {
216                    Ok(value) => results.push(value),
217                    Err(e) => {
218                        let _ = tx.rollback().await;
219                        return Err(e);
220                    }
221                }
222            }
223
224            tx.commit().await.map_err(E::from)?;
225            Ok(results)
226        }
227    }
228
229    /// Executes multiple operations concurrently in a transactions.
230    ///
231    /// This method allows for concurrent execution of actions within a transactions,
232    /// which can significantly improve performance for I/O-bound operations.
233    /// Note that this only works when the actions don't have data dependencies.
234    ///
235    /// # Type Parameters
236    ///
237    /// * `I`: The iterator type
238    /// * `F`: The action function type [^func] [^mutex]
239    /// * `Fut`: The future type returned by each action
240    /// * `R`: The result type
241    /// * `E`: The error type, which must be convertible from [`Error`]
242    ///
243    /// # Parameters
244    ///
245    /// * `actions`: An iterator of functions that will be executed concurrently in the transactions
246    ///
247    /// # Returns
248    ///
249    /// A future that resolves to:
250    /// * `Ok(Vec<R>)`: A vector of results if all actions succeeded
251    /// * `Err(E)`: The first error encountered
252    ///
253    /// # Implementation Details
254    ///
255    /// 1. Begins a transactions from the repository's connection pool
256    /// 2. Wraps the transactions in an [`Arc<Mutex<_>>`] to safely share it between concurrent operations [^mutex]
257    /// 3. Creates futures for all actions but doesn't execute them yet
258    /// 4. Executes all futures concurrently using [`try_join_all`]
259    /// 5. If all operations succeed, commits the transactions and returns the results
260    /// 6. If any operation fails, rolls back the transactions and returns the error
261    ///
262    /// # Notes
263    ///
264    /// - Uses [`parking_lot::Mutex`] for better performance than `std::sync::Mutex`
265    /// - Requires the transactions to be safely shared between multiple futures
266    ///
267    /// # Example
268    ///
269    /// ## Basic
270    ///
271    /// ```no_compile
272    /// let results = repo.transaction_concurrent([
273    ///     |tx_arc| async move {
274    ///         let mut tx = tx_arc.lock();
275    ///         repo.save_with_executor(&mut *tx, model1).await
276    ///     },
277    ///     |tx_arc| async move {
278    ///         let mut tx = tx_arc.lock();
279    ///         repo.save_with_executor(&mut *tx, model2).await
280    ///     }
281    /// ]).await;
282    /// ```
283    ///
284    /// ## Complete
285    ///
286    /// ```rust,should_panic
287    /// use sqlx::Transaction;
288    /// use std::sync::Arc;
289    /// use sqlx_utils::prelude::*;
290    /// #
291    /// # repository! {
292    /// #     !zst
293    /// #     UserRepo<User>;
294    /// # }
295    ///
296    /// struct User {
297    ///     id: String,
298    ///     name: String
299    /// }
300    ///
301    /// impl Model for User {
302    ///     type Id = String;
303    ///
304    ///     fn get_id(&self) -> Option<Self::Id> {
305    ///         Some(self.id.to_owned())
306    ///     }
307    /// }
308    ///
309    /// #[derive(Debug)]
310    /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
311    ///     kind: Box<dyn std::error::Error + Send>
312    /// }
313    ///
314    /// impl From<sqlx::Error> for Error {
315    ///     fn from(value: sqlx::Error) -> Self {
316    ///         Self {
317    ///             kind: Box::new(value)
318    ///         }
319    ///     }
320    /// }
321    ///
322    /// async fn action<'b>(tx: Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Result<User, Error> {
323    ///     unimplemented!()
324    ///  }
325    ///
326    /// # #[tokio::main]
327    /// # async fn main() {
328    /// USER_REPO.transaction_concurrent([action, action, action]).await.unwrap();
329    /// # }
330    /// ```
331    ///
332    /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Result<T, E>`
333    ///    Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
334    ///
335    /// [^mutex]: It is up to you to ensure we don't get deadlocks, the function itself will not lock the mutex,
336    ///    it will however attempt to get the inner value of the [`Arc`] after all actions has completed where it also consumes the mutex.
337    ///    This makes it in theory impossible to get a deadlock in this method, however deadlocks can occur between different actions.
338    fn transaction_concurrent<'a, 'b, I, F, Fut, R, E>(
339        &'a self,
340        actions: I,
341    ) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
342    where
343        I: IntoIterator<Item = F> + Send + 'a,
344        I::IntoIter: Send + 'a,
345        F: FnOnce(Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Fut + Send + 'a,
346        Fut: Future<Output = Result<R, E>> + Send + 'a,
347        R: Send + 'a,
348        E: From<Error> + Send + 'a,
349    {
350        async move {
351            let tx = self.pool().begin().await.map_err(E::from)?;
352            let tx = Arc::new(parking_lot::Mutex::new(tx));
353
354            // Create futures but don't await them yet
355            let futures: Vec<_> = actions
356                .into_iter()
357                .map(|action_fn| action_fn(tx.clone()))
358                .collect();
359
360            // Execute all futures concurrently
361            let results = try_join_all(futures).await;
362
363            match results {
364                Ok(values) => {
365                    let tx = match Arc::into_inner(tx) {
366                        Some(mutex) => mutex.into_inner(),
367                        None => return Err(E::from(Error::PoolClosed)),
368                    };
369
370                    tx.commit().await.map_err(E::from)?;
371                    Ok(values)
372                }
373                Err(e) => {
374                    let tx = match Arc::into_inner(tx) {
375                        Some(mutex) => mutex.into_inner(),
376                        None => return Err(E::from(Error::PoolClosed)),
377                    };
378
379                    tx.rollback().await.map_err(E::from)?;
380                    Err(e)
381                }
382            }
383        }
384    }
385
386    /// Executes multiple operations and collects all results, committing only if all succeed.
387    ///
388    /// This method runs all actions sequentially, collecting results (both successes and failures).
389    /// The transactions is committed only if all actions succeed; otherwise, it's rolled back.
390    ///
391    /// # Type Parameters
392    ///
393    /// * `I`: The iterator type
394    /// * `F`: The action function type [^func]
395    /// * `Fut`: The future type returned by each action
396    /// * `R`: The result type
397    /// * `E`: The error type, which must be convertible from [`Error`]
398    ///
399    /// # Parameters
400    ///
401    /// * `actions`: An iterator of functions that will be executed in the transactions
402    ///
403    /// # Returns
404    ///
405    /// A future that resolves to:
406    /// * `Ok(Vec<R>)`: A vector of results if all operations succeeded
407    /// * `Err(Vec<E>)`: A vector of all errors if any operation failed
408    ///
409    /// # Implementation Details
410    ///
411    /// 1. Begins a transactions from the repository's connection pool
412    /// 2. Executes each action sequentially, collecting all results and errors
413    /// 3. If any errors occurred, rolls back the transactions and returns all errors
414    /// 4. If all operations succeeded, commits the transactions and returns the results
415    ///
416    /// # Example
417    ///
418    /// ## Basic
419    ///
420    /// ```no_compile
421    /// match repo.try_transaction([
422    ///     |tx| async move { repo.save_with_executor(tx, model1).await },
423    ///     |tx| async move { repo.save_with_executor(tx, model2).await }
424    /// ]).await {
425    ///     Ok(results) => println!("All operations succeeded"),
426    ///     Err(errors) => println!("Some operations failed: {:?}", errors)
427    /// }
428    /// ```
429    ///
430    /// ## Complete
431    ///
432    /// ```rust,should_panic
433    /// use sqlx::Transaction;
434    /// use sqlx_utils::prelude::*;
435    /// #
436    /// # repository! {
437    /// #     !zst
438    /// #     UserRepo<User>;
439    /// # }
440    ///
441    /// struct User {
442    ///     id: String,
443    ///     name: String
444    /// }
445    ///
446    /// impl Model for User {
447    ///     type Id = String;
448    ///
449    ///     fn get_id(&self) -> Option<Self::Id> {
450    ///         Some(self.id.to_owned())
451    ///     }
452    /// }
453    ///
454    /// #[derive(Debug)]
455    /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
456    ///     kind: Box<dyn std::error::Error + Send>
457    /// }
458    ///
459    /// impl From<sqlx::Error> for Error {
460    ///     fn from(value: sqlx::Error) -> Self {
461    ///         Self {
462    ///             kind: Box::new(value)
463    ///         }
464    ///     }
465    /// }
466    ///
467    /// async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<User, Error>, Transaction<'b, Database>) {
468    ///      unimplemented!()
469    ///  }
470    ///
471    /// # #[tokio::main]
472    /// # async fn main() {
473    /// USER_REPO.try_transaction([action, action, action]).await.unwrap();
474    /// # }
475    /// ```
476    ///
477    /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
478    ///    Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
479    fn try_transaction<'a, 'b, I, F, Fut, R, E>(
480        &'a self,
481        actions: I,
482    ) -> impl Future<Output = Result<Vec<R>, Vec<E>>> + Send + 'a
483    where
484        I: IntoIterator<Item = F> + Send + 'a,
485        I::IntoIter: Send + 'a,
486        F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
487        Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
488        R: Send + 'a,
489        E: From<Error> + Send + 'a,
490    {
491        async move {
492            let mut tx = self.pool().begin().await.map_err(|e| vec![E::from(e)])?;
493            let mut results = Vec::new();
494            let mut errors = Vec::new();
495
496            for action in actions {
497                let (result, new_tx) = action(tx).await;
498                tx = new_tx;
499
500                match result {
501                    Ok(result) => results.push(result),
502                    Err(e) => errors.push(e),
503                }
504            }
505
506            if errors.is_empty() {
507                tx.commit().await.map_err(|e| vec![E::from(e)])?;
508                Ok(results)
509            } else {
510                let _ = tx.rollback().await;
511                Err(errors)
512            }
513        }
514    }
515}
516
517impl<T, M> TransactionRepository<M> for T
518where
519    T: Repository<M>,
520    M: Model,
521{
522}