Skip to main content

systemprompt_database/services/
transaction.rs

1//! Generic transaction wrappers that work directly with [`PgPool`] /
2//! [`PgDbPool`] without going through the dyn-safe trait.
3
4use crate::error::RepositoryError;
5use crate::repository::PgDbPool;
6use sqlx::{PgPool, Postgres, Transaction};
7use std::future::Future;
8use std::pin::Pin;
9
/// A heap-allocated, pinned, `Send` future yielding `T` that may borrow data
/// for `'a`.
///
/// Transaction callbacks in this module return this type so the future can
/// hold the `&mut Transaction` it was given.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
11
12pub async fn with_transaction<F, T, E>(pool: &PgDbPool, f: F) -> Result<T, E>
13where
14    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
15    E: From<sqlx::Error>,
16{
17    let mut tx = pool.begin().await?;
18    let result = f(&mut tx).await?;
19    tx.commit().await?;
20    Ok(result)
21}
22
23pub async fn with_transaction_raw<F, T, E>(pool: &PgPool, f: F) -> Result<T, E>
24where
25    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
26    E: From<sqlx::Error>,
27{
28    let mut tx = pool.begin().await?;
29    let result = f(&mut tx).await?;
30    tx.commit().await?;
31    Ok(result)
32}
33
34pub async fn with_transaction_retry<F, T>(
35    pool: &PgDbPool,
36    max_retries: u32,
37    f: F,
38) -> Result<T, RepositoryError>
39where
40    F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, RepositoryError>>
41        + Send,
42{
43    let mut attempts = 0;
44    let base_delay_ms = 10u64;
45
46    loop {
47        let mut tx = pool.begin().await?;
48
49        match f(&mut tx).await {
50            Ok(result) => {
51                tx.commit().await?;
52                return Ok(result);
53            },
54            Err(e) => {
55                if attempts < max_retries && is_retriable_error(&e) {
56                    if let Err(rollback_err) = tx.rollback().await {
57                        tracing::error!(error = %rollback_err, "Transaction rollback failed during retry");
58                    }
59                    attempts += 1;
60                    let delay_ms = base_delay_ms * (1 << attempts.min(6));
61                    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
62                    continue;
63                }
64                if let Err(rollback_err) = tx.rollback().await {
65                    tracing::error!(error = %rollback_err, "Transaction rollback failed");
66                }
67                return Err(e);
68            },
69        }
70    }
71}
72
73fn is_retriable_error(error: &RepositoryError) -> bool {
74    match error {
75        RepositoryError::Database(sqlx_error) => {
76            sqlx_error.as_database_error().is_some_and(|db_error| {
77                let code = db_error.code().map(|c| c.to_string());
78                matches!(code.as_deref(), Some("40001" | "40P01"))
79            })
80        },
81        RepositoryError::NotFound(_)
82        | RepositoryError::Constraint(_)
83        | RepositoryError::Serialization(_)
84        | RepositoryError::InvalidArgument(_)
85        | RepositoryError::InvalidState(_)
86        | RepositoryError::Internal(_) => false,
87    }
88}