// sql_middleware/postgres/transaction.rs

use std::fmt;
use std::ops::DerefMut;

use tokio_postgres::{Client, Statement, Transaction as PgTransaction};

use crate::middleware::{
    ConversionMode, ParamConverter, ResultSet, RowValues, SqlMiddlewareDbError,
};
use crate::tx_outcome::TxOutcome;

use super::{Params, build_result_set};

12/// Lightweight transaction wrapper for Postgres.
13pub struct Tx<'a> {
14    tx: PgTransaction<'a>,
15}
16
17/// Prepared statement wrapper for Postgres.
18pub struct Prepared {
19    stmt: Statement,
20}
21
22/// Begin a new transaction on the provided Postgres connection.
23///
24/// # Errors
25/// Returns an error if creating the transaction fails.
26pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28    C: DerefMut<Target = Client>,
29{
30    let tx = conn.deref_mut().transaction().await?;
31    Ok(Tx { tx })
32}
33
34impl Tx<'_> {
35    /// Prepare a SQL statement tied to this transaction.
36    ///
37    /// # Errors
38    /// Returns an error if the prepare call fails.
39    pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40        let stmt = self.tx.prepare(sql).await?;
41        Ok(Prepared { stmt })
42    }
43
44    /// Execute a parameterized DML statement and return the affected row count.
45    ///
46    /// # Errors
47    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
48    pub async fn execute_prepared(
49        &self,
50        prepared: &Prepared,
51        params: &[RowValues],
52    ) -> Result<usize, SqlMiddlewareDbError> {
53        let converted =
54            <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Execute)?;
55
56        let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
57
58        usize::try_from(rows).map_err(|e| {
59            SqlMiddlewareDbError::ExecutionError(format!("Invalid rows affected count: {e}"))
60        })
61    }
62
63    /// Execute a parameterized SELECT and return a `ResultSet`.
64    ///
65    /// # Errors
66    /// Returns an error if parameter conversion, execution, or result building fails.
67    pub async fn query_prepared(
68        &self,
69        prepared: &Prepared,
70        params: &[RowValues],
71    ) -> Result<ResultSet, SqlMiddlewareDbError> {
72        let converted =
73            <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Query)?;
74        build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
75    }
76
77    /// Execute a batch of SQL statements inside the transaction.
78    ///
79    /// # Errors
80    /// Returns an error if execution fails.
81    pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
82        self.tx.batch_execute(sql).await?;
83        Ok(())
84    }
85
86    /// Commit the transaction.
87    ///
88    /// # Errors
89    /// Returns an error if commit fails.
90    pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
91        self.tx.commit().await?;
92        Ok(TxOutcome::without_restored_connection())
93    }
94
95    /// Roll back the transaction.
96    ///
97    /// # Errors
98    /// Returns an error if rollback fails.
99    pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
100        self.tx.rollback().await?;
101        Ok(TxOutcome::without_restored_connection())
102    }
103}