Skip to main content

systemprompt_database/services/
transaction.rs

1use crate::error::RepositoryError;
2use crate::repository::PgDbPool;
3use sqlx::{PgPool, Postgres, Transaction};
4use std::future::Future;
5use std::pin::Pin;
6
/// A heap-allocated, pinned, type-erased future that is `Send` and may borrow
/// data for the lifetime `'a`.
///
/// The transaction helpers in this module use it as the return type of their
/// closures, so a closure can hand back a future that borrows the transaction
/// it was given.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
8
9pub async fn with_transaction<F, T, E>(pool: &PgDbPool, f: F) -> Result<T, E>
10where
11    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
12    E: From<sqlx::Error>,
13{
14    let mut tx = pool.begin().await?;
15    let result = f(&mut tx).await?;
16    tx.commit().await?;
17    Ok(result)
18}
19
20pub async fn with_transaction_raw<F, T, E>(pool: &PgPool, f: F) -> Result<T, E>
21where
22    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
23    E: From<sqlx::Error>,
24{
25    let mut tx = pool.begin().await?;
26    let result = f(&mut tx).await?;
27    tx.commit().await?;
28    Ok(result)
29}
30
31pub async fn with_transaction_retry<F, T>(
32    pool: &PgDbPool,
33    max_retries: u32,
34    f: F,
35) -> Result<T, RepositoryError>
36where
37    F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, RepositoryError>>
38        + Send,
39{
40    let mut attempts = 0;
41    let base_delay_ms = 10u64;
42
43    loop {
44        let mut tx = pool.begin().await?;
45
46        match f(&mut tx).await {
47            Ok(result) => {
48                tx.commit().await?;
49                return Ok(result);
50            },
51            Err(e) => {
52                if attempts < max_retries && is_retriable_error(&e) {
53                    if let Err(rollback_err) = tx.rollback().await {
54                        tracing::error!(error = %rollback_err, "Transaction rollback failed during retry");
55                    }
56                    attempts += 1;
57                    let delay_ms = base_delay_ms * (1 << attempts.min(6));
58                    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
59                    continue;
60                }
61                if let Err(rollback_err) = tx.rollback().await {
62                    tracing::error!(error = %rollback_err, "Transaction rollback failed");
63                }
64                return Err(e);
65            },
66        }
67    }
68}
69
70fn is_retriable_error(error: &RepositoryError) -> bool {
71    match error {
72        RepositoryError::Database(sqlx_error) => {
73            sqlx_error.as_database_error().is_some_and(|db_error| {
74                let code = db_error.code().map(|c| c.to_string());
75                matches!(code.as_deref(), Some("40001" | "40P01"))
76            })
77        },
78        RepositoryError::NotFound(_)
79        | RepositoryError::Constraint(_)
80        | RepositoryError::Serialization(_)
81        | RepositoryError::InvalidArgument(_)
82        | RepositoryError::Internal(_) => false,
83    }
84}