Skip to main content

systemprompt_database/services/
transaction.rs

//! Generic transaction wrappers that work directly with [`PgPool`] /
//! [`PgDbPool`] instead of going through the dyn-safe trait: they are generic
//! over the closure and result types, which a dyn-safe trait method cannot be.
3
4use crate::error::RepositoryError;
5use crate::repository::PgDbPool;
6use crate::resilience::classify::Outcome;
7use crate::resilience::config::RetryConfig;
8use crate::resilience::retry::retry_async;
9use sqlx::{PgPool, Postgres, Transaction};
10use std::future::Future;
11use std::pin::Pin;
12use std::time::Duration;
13
/// A heap-allocated, pinned, `Send` future yielding `T` and borrowing for `'a`.
///
/// The transaction helpers in this module use it as the return type of their
/// closures, so that the future may borrow the `&'c mut Transaction` it is given.
pub type BoxFuture<'a, T> = Pin<Box<dyn Send + Future<Output = T> + 'a>>;
15
16pub async fn with_transaction<F, T, E>(pool: &PgDbPool, f: F) -> Result<T, E>
17where
18    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
19    E: From<sqlx::Error>,
20{
21    let mut tx = pool.begin().await?;
22    let result = f(&mut tx).await?;
23    tx.commit().await?;
24    Ok(result)
25}
26
27pub async fn with_transaction_raw<F, T, E>(pool: &PgPool, f: F) -> Result<T, E>
28where
29    F: for<'c> FnOnce(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, E>>,
30    E: From<sqlx::Error>,
31{
32    let mut tx = pool.begin().await?;
33    let result = f(&mut tx).await?;
34    tx.commit().await?;
35    Ok(result)
36}
37
38pub async fn with_transaction_retry<F, T>(
39    pool: &PgDbPool,
40    max_retries: u32,
41    f: F,
42) -> Result<T, RepositoryError>
43where
44    T: Send,
45    F: for<'c> Fn(&'c mut Transaction<'_, Postgres>) -> BoxFuture<'c, Result<T, RepositoryError>>
46        + Send
47        + Sync,
48{
49    let cfg = RetryConfig {
50        max_attempts: max_retries.saturating_add(1),
51        base_delay: Duration::from_millis(20),
52        max_delay: Duration::from_millis(640),
53        jitter: false,
54    };
55    let classify = |err: &RepositoryError| {
56        if is_retriable_error(err) {
57            Outcome::Transient { retry_after: None }
58        } else {
59            Outcome::Permanent
60        }
61    };
62    let attempt = || async {
63        let mut tx = pool.begin().await?;
64        match f(&mut tx).await {
65            Ok(result) => {
66                tx.commit().await?;
67                Ok(result)
68            },
69            Err(e) => {
70                if let Err(rollback_err) = tx.rollback().await {
71                    tracing::error!(error = %rollback_err, "Transaction rollback failed");
72                }
73                Err(e)
74            },
75        }
76    };
77    retry_async(&cfg, "transaction", classify, attempt).await
78}
79
80fn is_retriable_error(error: &RepositoryError) -> bool {
81    match error {
82        RepositoryError::Database(sqlx_error) => {
83            sqlx_error.as_database_error().is_some_and(|db_error| {
84                let code = db_error.code().map(|c| c.to_string());
85                matches!(code.as_deref(), Some("40001" | "40P01"))
86            })
87        },
88        RepositoryError::NotFound(_)
89        | RepositoryError::Constraint(_)
90        | RepositoryError::Serialization(_)
91        | RepositoryError::InvalidArgument(_)
92        | RepositoryError::InvalidState(_)
93        | RepositoryError::Internal(_) => false,
94    }
95}