sqlx_utils/traits/repository/transactions/mod.rs
//! Transaction-related traits.
//!
//! There are many lifetime issues when using [`Transaction`] from [`sqlx`]. They stem from
//! the implementation in [`sqlx`] and, from my testing, can't be fixed without the inner traits
//! being implemented using [`async_trait`]. If this changes in the future, I will fix this.
6
7use crate::{
8 mod_def,
9 traits::{Model, Repository},
10 types::Database,
11};
12use futures::future::try_join_all;
13use sqlx::{Error, Transaction};
14use std::future::Future;
15use std::sync::Arc;
16
17mod_def! {
18 !export
19 pub(crate) mod insert_tx;
20 pub(crate) mod update_tx;
21 pub(crate) mod delete_tx;
22 pub(crate) mod save_tx;
23}
24
25/// Extension trait for Repository to work with transactions
26///
27/// This trait adds transactions capabilities to any repository that implements
28/// the [`Repository`] trait. It provides several methods for executing operations
29/// within database transactions, with different strategies for concurrency and
30/// error handling.
31///
32/// The trait is automatically implemented for any type that implements [`Repository<M>`],
33/// making transactions capabilities available to all repositories without additional code.
34pub trait TransactionRepository<M>: Repository<M>
35where
36 M: Model,
37{
38 /// Executes a callback within a transactions, handling the transactions lifecycle automatically.
39 ///
40 /// This method:
41 /// 1. Begins a transactions from the repository's connection pool
42 /// 2. Passes the transactions to the callback function
43 /// 3. Waits for the callback to complete and return both a result and the transactions
44 /// 4. Commits the transactions if the result is `Ok`, or rolls it back if it's `Err`
45 /// 5. Returns the final result
46 ///
47 /// # Type Parameters
48 ///
49 /// * `F`: The type of the callback function [^func]
50 /// * `Fut`: The future type returned by the callback
51 /// * `R`: The result type
52 /// * `E`: The error type, which must be convertible from [`Error`]
53 ///
54 /// # Parameters
55 ///
56 /// * `callback`: A function that accepts a [`Transaction`] and returns a future
57 ///
58 /// # Returns
59 ///
60 /// A future that resolves to `Result<R, E>`.
61 ///
62 /// # Example
63 ///
64 /// ```no_compile
65 /// let result = repo.with_transaction(|mut tx| async move {
66 /// let model = Model::new();
67 /// let res = repo.save_with_executor(&mut tx, model).await;
68 /// (res, tx)
69 /// }).await;
70 /// ```
71 ///
72 /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
73 /// Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
74 fn with_transaction<'a, 'b, F, Fut, R, E>(
75 &'a self,
76 callback: F,
77 ) -> impl Future<Output = Result<R, E>> + Send + 'a
78 where
79 F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
80 Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
81 R: Send + 'a,
82 E: From<Error> + Send,
83 {
84 async move {
85 let transaction = self.pool().begin().await.map_err(E::from)?;
86
87 let (ret, tx) = callback(transaction).await;
88
89 match ret {
90 Ok(val) => {
91 tx.commit().await.map_err(E::from)?;
92 Ok(val)
93 }
94 Err(err) => {
95 tx.rollback().await.map_err(E::from)?;
96 Err(err)
97 }
98 }
99 }
100 }
101
102 /// Executes multiple operations sequentially in a transactions, stopping at the first error.
103 ///
104 /// This method provides an optimized approach for cases where you want to stop processing
105 /// as soon as any action fails, immediately rolling back the transactions.
106 ///
107 /// # Type Parameters
108 ///
109 /// * `I`: The iterator type
110 /// * `F`: The action function type [^func]
111 /// * `Fut`: The future type returned by each action
112 /// * `R`: The result type
113 /// * `E`: The error type, which must be convertible from [`Error`]
114 ///
115 /// # Parameters
116 ///
117 /// * `actions`: An iterator of functions that will be executed in the transactions
118 ///
119 /// # Returns
120 ///
121 /// A future that resolves to:
122 /// * `Ok(Vec<R>)`: A vector of results if all actions succeeded
123 /// * `Err(E)`: The first error encountered
124 ///
125 /// # Implementation Details
126 ///
127 /// 1. Begins a transactions from the repository's connection pool
128 /// 2. Executes each action sequentially, collecting results
129 /// 3. If any action fails, rolls back the transactions and returns the error
130 /// 4. If all actions succeed, commits the transactions and returns the results
131 ///
132 /// Due to complex lifetime bounds in underlying types we must take ownership and then return it
133 /// back.
134 ///
135 /// # Examples
136 ///
137 /// ## Basic
138 ///
139 /// ```no_compile
140 /// let results = repo.transaction_sequential([
141 /// |tx| async move { repo.save_with_executor(tx, model1).await },
142 /// |tx| async move { repo.save_with_executor(tx, model2).await }
143 /// ]).await;
144 /// ```
145 ///
146 /// ## Complete
147 ///
148 /// ```rust,should_panic
149 /// use sqlx::Transaction;
150 /// use sqlx_utils::prelude::*;
151 /// #
152 /// # repository! {
153 /// # !zst
154 /// # UserRepo<User>;
155 /// # }
156 ///
157 /// struct User {
158 /// id: String,
159 /// name: String
160 /// }
161 ///
162 /// impl Model for User {
163 /// type Id = String;
164 ///
165 /// fn get_id(&self) -> Option<Self::Id> {
166 /// Some(self.id.to_owned())
167 /// }
168 /// }
169 ///
170 /// #[derive(Debug)]
171 /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
172 /// kind: Box<dyn std::error::Error + Send>
173 /// }
174 ///
175 /// impl From<sqlx::Error> for Error {
176 /// fn from(value: sqlx::Error) -> Self {
177 /// Self {
178 /// kind: Box::new(value)
179 /// }
180 /// }
181 /// }
182 ///
183 /// async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<User, Error>, Transaction<'b, Database>) {
184 /// unimplemented!()
185 /// }
186 ///
187 /// # #[tokio::main]
188 /// # async fn main() {
189 /// USER_REPO.transaction_sequential([action, action, action]).await.unwrap();
190 /// # }
191 /// ```
192 ///
193 /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
194 /// Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
195 fn transaction_sequential<'a, 'b, I, F, Fut, R, E>(
196 &'a self,
197 actions: I,
198 ) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
199 where
200 I: IntoIterator<Item = F> + Send + 'a,
201 I::IntoIter: Send + 'a,
202 F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
203 Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
204 R: Send + 'a,
205 E: From<Error> + Send + 'a,
206 {
207 async move {
208 let mut tx = self.pool().begin().await.map_err(E::from)?;
209 let mut results = Vec::new();
210
211 for action in actions {
212 let (result, new_tx) = action(tx).await;
213 tx = new_tx; // Get back ownership
214
215 match result {
216 Ok(value) => results.push(value),
217 Err(e) => {
218 let _ = tx.rollback().await;
219 return Err(e);
220 }
221 }
222 }
223
224 tx.commit().await.map_err(E::from)?;
225 Ok(results)
226 }
227 }
228
229 /// Executes multiple operations concurrently in a transactions.
230 ///
231 /// This method allows for concurrent execution of actions within a transactions,
232 /// which can significantly improve performance for I/O-bound operations.
233 /// Note that this only works when the actions don't have data dependencies.
234 ///
235 /// # Type Parameters
236 ///
237 /// * `I`: The iterator type
238 /// * `F`: The action function type [^func] [^mutex]
239 /// * `Fut`: The future type returned by each action
240 /// * `R`: The result type
241 /// * `E`: The error type, which must be convertible from [`Error`]
242 ///
243 /// # Parameters
244 ///
245 /// * `actions`: An iterator of functions that will be executed concurrently in the transactions
246 ///
247 /// # Returns
248 ///
249 /// A future that resolves to:
250 /// * `Ok(Vec<R>)`: A vector of results if all actions succeeded
251 /// * `Err(E)`: The first error encountered
252 ///
253 /// # Implementation Details
254 ///
255 /// 1. Begins a transactions from the repository's connection pool
256 /// 2. Wraps the transactions in an [`Arc<Mutex<_>>`] to safely share it between concurrent operations [^mutex]
257 /// 3. Creates futures for all actions but doesn't execute them yet
258 /// 4. Executes all futures concurrently using [`try_join_all`]
259 /// 5. If all operations succeed, commits the transactions and returns the results
260 /// 6. If any operation fails, rolls back the transactions and returns the error
261 ///
262 /// # Notes
263 ///
264 /// - Uses [`parking_lot::Mutex`] for better performance than `std::sync::Mutex`
265 /// - Requires the transactions to be safely shared between multiple futures
266 ///
267 /// # Example
268 ///
269 /// ## Basic
270 ///
271 /// ```no_compile
272 /// let results = repo.transaction_concurrent([
273 /// |tx_arc| async move {
274 /// let mut tx = tx_arc.lock();
275 /// repo.save_with_executor(&mut *tx, model1).await
276 /// },
277 /// |tx_arc| async move {
278 /// let mut tx = tx_arc.lock();
279 /// repo.save_with_executor(&mut *tx, model2).await
280 /// }
281 /// ]).await;
282 /// ```
283 ///
284 /// ## Complete
285 ///
286 /// ```rust,should_panic
287 /// use sqlx::Transaction;
288 /// use std::sync::Arc;
289 /// use sqlx_utils::prelude::*;
290 /// #
291 /// # repository! {
292 /// # !zst
293 /// # UserRepo<User>;
294 /// # }
295 ///
296 /// struct User {
297 /// id: String,
298 /// name: String
299 /// }
300 ///
301 /// impl Model for User {
302 /// type Id = String;
303 ///
304 /// fn get_id(&self) -> Option<Self::Id> {
305 /// Some(self.id.to_owned())
306 /// }
307 /// }
308 ///
309 /// #[derive(Debug)]
310 /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
311 /// kind: Box<dyn std::error::Error + Send>
312 /// }
313 ///
314 /// impl From<sqlx::Error> for Error {
315 /// fn from(value: sqlx::Error) -> Self {
316 /// Self {
317 /// kind: Box::new(value)
318 /// }
319 /// }
320 /// }
321 ///
322 /// async fn action<'b>(tx: Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Result<User, Error> {
323 /// unimplemented!()
324 /// }
325 ///
326 /// # #[tokio::main]
327 /// # async fn main() {
328 /// USER_REPO.transaction_concurrent([action, action, action]).await.unwrap();
329 /// # }
330 /// ```
331 ///
332 /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Result<T, E>`
333 /// Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
334 ///
335 /// [^mutex]: It is up to you to ensure we don't get deadlocks, the function itself will not lock the mutex,
336 /// it will however attempt to get the inner value of the [`Arc`] after all actions has completed where it also consumes the mutex.
337 /// This makes it in theory impossible to get a deadlock in this method, however deadlocks can occur between different actions.
338 fn transaction_concurrent<'a, 'b, I, F, Fut, R, E>(
339 &'a self,
340 actions: I,
341 ) -> impl Future<Output = Result<Vec<R>, E>> + Send + 'a
342 where
343 I: IntoIterator<Item = F> + Send + 'a,
344 I::IntoIter: Send + 'a,
345 F: FnOnce(Arc<parking_lot::Mutex<Transaction<'b, Database>>>) -> Fut + Send + 'a,
346 Fut: Future<Output = Result<R, E>> + Send + 'a,
347 R: Send + 'a,
348 E: From<Error> + Send + 'a,
349 {
350 async move {
351 let tx = self.pool().begin().await.map_err(E::from)?;
352 let tx = Arc::new(parking_lot::Mutex::new(tx));
353
354 // Create futures but don't await them yet
355 let futures: Vec<_> = actions
356 .into_iter()
357 .map(|action_fn| action_fn(tx.clone()))
358 .collect();
359
360 // Execute all futures concurrently
361 let results = try_join_all(futures).await;
362
363 match results {
364 Ok(values) => {
365 let tx = match Arc::into_inner(tx) {
366 Some(mutex) => mutex.into_inner(),
367 None => return Err(E::from(Error::PoolClosed)),
368 };
369
370 tx.commit().await.map_err(E::from)?;
371 Ok(values)
372 }
373 Err(e) => {
374 let tx = match Arc::into_inner(tx) {
375 Some(mutex) => mutex.into_inner(),
376 None => return Err(E::from(Error::PoolClosed)),
377 };
378
379 tx.rollback().await.map_err(E::from)?;
380 Err(e)
381 }
382 }
383 }
384 }
385
386 /// Executes multiple operations and collects all results, committing only if all succeed.
387 ///
388 /// This method runs all actions sequentially, collecting results (both successes and failures).
389 /// The transactions is committed only if all actions succeed; otherwise, it's rolled back.
390 ///
391 /// # Type Parameters
392 ///
393 /// * `I`: The iterator type
394 /// * `F`: The action function type [^func]
395 /// * `Fut`: The future type returned by each action
396 /// * `R`: The result type
397 /// * `E`: The error type, which must be convertible from [`Error`]
398 ///
399 /// # Parameters
400 ///
401 /// * `actions`: An iterator of functions that will be executed in the transactions
402 ///
403 /// # Returns
404 ///
405 /// A future that resolves to:
406 /// * `Ok(Vec<R>)`: A vector of results if all operations succeeded
407 /// * `Err(Vec<E>)`: A vector of all errors if any operation failed
408 ///
409 /// # Implementation Details
410 ///
411 /// 1. Begins a transactions from the repository's connection pool
412 /// 2. Executes each action sequentially, collecting all results and errors
413 /// 3. If any errors occurred, rolls back the transactions and returns all errors
414 /// 4. If all operations succeeded, commits the transactions and returns the results
415 ///
416 /// # Example
417 ///
418 /// ## Basic
419 ///
420 /// ```no_compile
421 /// match repo.try_transaction([
422 /// |tx| async move { repo.save_with_executor(tx, model1).await },
423 /// |tx| async move { repo.save_with_executor(tx, model2).await }
424 /// ]).await {
425 /// Ok(results) => println!("All operations succeeded"),
426 /// Err(errors) => println!("Some operations failed: {:?}", errors)
427 /// }
428 /// ```
429 ///
430 /// ## Complete
431 ///
432 /// ```rust,should_panic
433 /// use sqlx::Transaction;
434 /// use sqlx_utils::prelude::*;
435 /// #
436 /// # repository! {
437 /// # !zst
438 /// # UserRepo<User>;
439 /// # }
440 ///
441 /// struct User {
442 /// id: String,
443 /// name: String
444 /// }
445 ///
446 /// impl Model for User {
447 /// type Id = String;
448 ///
449 /// fn get_id(&self) -> Option<Self::Id> {
450 /// Some(self.id.to_owned())
451 /// }
452 /// }
453 ///
454 /// #[derive(Debug)]
455 /// struct Error { // Any error type that implements `From<sqlx::Error>` is allowed
456 /// kind: Box<dyn std::error::Error + Send>
457 /// }
458 ///
459 /// impl From<sqlx::Error> for Error {
460 /// fn from(value: sqlx::Error) -> Self {
461 /// Self {
462 /// kind: Box::new(value)
463 /// }
464 /// }
465 /// }
466 ///
467 /// async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<User, Error>, Transaction<'b, Database>) {
468 /// unimplemented!()
469 /// }
470 ///
471 /// # #[tokio::main]
472 /// # async fn main() {
473 /// USER_REPO.try_transaction([action, action, action]).await.unwrap();
474 /// # }
475 /// ```
476 ///
477 /// [^func]: The function signature of an action must be `async fn action<'b>(tx: Transaction<'b, Database>) -> (Result<T, E>, Transaction<'b, Database>)`
478 /// Take note of the lifetimes as you might run into errors related to lifetimes if they are not specified due to invariance. The future must also be [`Send`]
479 fn try_transaction<'a, 'b, I, F, Fut, R, E>(
480 &'a self,
481 actions: I,
482 ) -> impl Future<Output = Result<Vec<R>, Vec<E>>> + Send + 'a
483 where
484 I: IntoIterator<Item = F> + Send + 'a,
485 I::IntoIter: Send + 'a,
486 F: FnOnce(Transaction<'b, Database>) -> Fut + Send + 'a,
487 Fut: Future<Output = (Result<R, E>, Transaction<'b, Database>)> + Send,
488 R: Send + 'a,
489 E: From<Error> + Send + 'a,
490 {
491 async move {
492 let mut tx = self.pool().begin().await.map_err(|e| vec![E::from(e)])?;
493 let mut results = Vec::new();
494 let mut errors = Vec::new();
495
496 for action in actions {
497 let (result, new_tx) = action(tx).await;
498 tx = new_tx;
499
500 match result {
501 Ok(result) => results.push(result),
502 Err(e) => errors.push(e),
503 }
504 }
505
506 if errors.is_empty() {
507 tx.commit().await.map_err(|e| vec![E::from(e)])?;
508 Ok(results)
509 } else {
510 let _ = tx.rollback().await;
511 Err(errors)
512 }
513 }
514 }
515}
516
517impl<T, M> TransactionRepository<M> for T
518where
519 T: Repository<M>,
520 M: Model,
521{
522}