Skip to main content

sql_middleware/postgres/
transaction.rs

1use std::ops::DerefMut;
2
3use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
4
5use crate::adapters::params::convert_params;
6use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
7use crate::tx_outcome::TxOutcome;
8
9use super::{Params, build_result_set};
10use crate::postgres::query::{build_result_set_from_rows, convert_affected_rows};
11
12/// Lightweight transaction wrapper for Postgres.
13pub struct Tx<'a> {
14    tx: PgTransaction<'a>,
15}
16
17/// Prepared statement wrapper for Postgres.
18pub struct Prepared {
19    stmt: Statement,
20}
21
22/// Begin a new transaction on the provided Postgres connection.
23///
24/// # Errors
25/// Returns an error if creating the transaction fails.
26pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28    C: DerefMut<Target = Client>,
29{
30    let tx = conn.deref_mut().transaction().await?;
31    Ok(Tx { tx })
32}
33
34impl Tx<'_> {
35    /// Prepare a SQL statement tied to this transaction.
36    ///
37    /// # Errors
38    /// Returns an error if the prepare call fails.
39    pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40        let stmt = self.tx.prepare(sql).await?;
41        Ok(Prepared { stmt })
42    }
43
44    /// Execute a parameterized DML statement and return the affected row count.
45    ///
46    /// # Errors
47    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
48    pub async fn execute_prepared(
49        &self,
50        prepared: &Prepared,
51        params: &[RowValues],
52    ) -> Result<usize, SqlMiddlewareDbError> {
53        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
54
55        let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
56
57        convert_affected_rows(rows, "Invalid rows affected count")
58    }
59
60    /// Execute a parameterized DML statement without preparing and return affected rows.
61    ///
62    /// # Errors
63    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
64    pub async fn execute_dml(
65        &self,
66        query: &str,
67        params: &[RowValues],
68    ) -> Result<usize, SqlMiddlewareDbError> {
69        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
70        let rows = self.tx.execute(query, converted.as_refs()).await?;
71        convert_affected_rows(rows, "Invalid rows affected count")
72    }
73
74    /// Execute a parameterized SELECT and return a `ResultSet`.
75    ///
76    /// # Errors
77    /// Returns an error if parameter conversion, execution, or result building fails.
78    pub async fn query_prepared(
79        &self,
80        prepared: &Prepared,
81        params: &[RowValues],
82    ) -> Result<ResultSet, SqlMiddlewareDbError> {
83        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
84        build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
85    }
86
87    /// Execute a parameterized SELECT without preparing and return a `ResultSet`.
88    ///
89    /// # Errors
90    /// Returns an error if parameter conversion or query execution fails.
91    pub async fn query(
92        &self,
93        query: &str,
94        params: &[RowValues],
95    ) -> Result<ResultSet, SqlMiddlewareDbError> {
96        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
97        let rows = self.tx.query(query, converted.as_refs()).await?;
98        build_result_set_from_rows(&rows)
99    }
100
101    /// Execute a batch of SQL statements inside the transaction.
102    ///
103    /// # Errors
104    /// Returns an error if execution fails.
105    pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
106        self.tx.batch_execute(sql).await?;
107        Ok(())
108    }
109
110    /// Commit the transaction.
111    ///
112    /// # Errors
113    /// Returns an error if commit fails.
114    pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
115        self.tx.commit().await?;
116        Ok(TxOutcome::without_restored_connection())
117    }
118
119    /// Roll back the transaction.
120    ///
121    /// # Errors
122    /// Returns an error if rollback fails.
123    pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
124        self.tx.rollback().await?;
125        Ok(TxOutcome::without_restored_connection())
126    }
127}