Skip to main content

sql_middleware/postgres/transaction.rs

1use std::ops::DerefMut;
2
3use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
4
5use crate::adapters::params::convert_params;
6use crate::middleware::{ConversionMode, CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
7use crate::tx_outcome::TxOutcome;
8
9use super::{Params, build_result_set};
10use crate::postgres::query::{build_result_set_from_rows, convert_affected_rows};
11
12/// Lightweight transaction wrapper for Postgres.
13pub struct Tx<'a> {
14    tx: PgTransaction<'a>,
15}
16
17/// Prepared statement wrapper for Postgres.
18pub struct Prepared {
19    stmt: Statement,
20}
21
22/// Begin a new transaction on the provided Postgres connection.
23///
24/// # Errors
25/// Returns an error if creating the transaction fails.
26pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28    C: DerefMut<Target = Client>,
29{
30    let tx = conn.deref_mut().transaction().await?;
31    Ok(Tx { tx })
32}
33
34impl<'conn> Tx<'conn> {
35    /// Prepare a SQL statement tied to this transaction.
36    ///
37    /// # Errors
38    /// Returns an error if the prepare call fails.
39    pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40        let stmt = self.tx.prepare(sql).await?;
41        Ok(Prepared { stmt })
42    }
43
44    /// Start configuring a prepared SELECT execution.
45    #[must_use]
46    pub fn select<'tx, 'prepared>(
47        &'tx self,
48        prepared: &'prepared Prepared,
49    ) -> PreparedSelect<'tx, 'prepared, 'static, 'conn> {
50        PreparedSelect {
51            tx: self,
52            prepared,
53            params: &[],
54        }
55    }
56
57    /// Start configuring a prepared DML execution.
58    #[must_use]
59    pub fn execute<'tx, 'prepared>(
60        &'tx self,
61        prepared: &'prepared Prepared,
62    ) -> PreparedExecute<'tx, 'prepared, 'static, 'conn> {
63        PreparedExecute {
64            tx: self,
65            prepared,
66            params: &[],
67        }
68    }
69
70    /// Execute a parameterized DML statement and return the affected row count.
71    ///
72    /// # Errors
73    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
74    pub(crate) async fn execute_prepared(
75        &self,
76        prepared: &Prepared,
77        params: &[RowValues],
78    ) -> Result<usize, SqlMiddlewareDbError> {
79        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
80
81        let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
82
83        convert_affected_rows(rows, "Invalid rows affected count")
84    }
85
86    /// Execute a parameterized DML statement without preparing and return affected rows.
87    ///
88    /// # Errors
89    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
90    pub async fn execute_dml(
91        &self,
92        query: &str,
93        params: &[RowValues],
94    ) -> Result<usize, SqlMiddlewareDbError> {
95        let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
96        let rows = self.tx.execute(query, converted.as_refs()).await?;
97        convert_affected_rows(rows, "Invalid rows affected count")
98    }
99
100    /// Execute a parameterized SELECT and return a `ResultSet`.
101    ///
102    /// # Errors
103    /// Returns an error if parameter conversion, execution, or result building fails.
104    pub(crate) async fn query_prepared(
105        &self,
106        prepared: &Prepared,
107        params: &[RowValues],
108    ) -> Result<ResultSet, SqlMiddlewareDbError> {
109        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
110        build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
111    }
112
113    /// Execute a prepared SELECT and return the first row, if present.
114    ///
115    /// # Errors
116    /// Returns an error if parameter conversion, execution, or result building fails.
117    pub(crate) async fn query_prepared_optional(
118        &self,
119        prepared: &Prepared,
120        params: &[RowValues],
121    ) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
122        self.query_prepared(prepared, params)
123            .await
124            .map(ResultSet::into_optional)
125    }
126
127    /// Execute a prepared SELECT and return the first row.
128    ///
129    /// # Errors
130    /// Returns an error if execution fails or no row is returned.
131    pub(crate) async fn query_prepared_one(
132        &self,
133        prepared: &Prepared,
134        params: &[RowValues],
135    ) -> Result<CustomDbRow, SqlMiddlewareDbError> {
136        self.query_prepared(prepared, params).await?.into_one()
137    }
138
139    /// Execute a prepared SELECT and map the first native Postgres row.
140    ///
141    /// Use this for hot paths that only need one row and can decode directly from
142    /// `tokio_postgres::Row`, avoiding `ResultSet` materialisation.
143    ///
144    /// # Errors
145    /// Returns an error if execution fails, no row is returned, or the mapper fails.
146    pub(crate) async fn query_prepared_map_one<T, F>(
147        &self,
148        prepared: &Prepared,
149        params: &[RowValues],
150        mapper: F,
151    ) -> Result<T, SqlMiddlewareDbError>
152    where
153        F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
154    {
155        self.query_prepared_map_optional(prepared, params, mapper)
156            .await?
157            .ok_or_else(|| SqlMiddlewareDbError::ExecutionError("query returned no rows".into()))
158    }
159
160    /// Execute a prepared SELECT and map the first native Postgres row, returning `None` if no row
161    /// exists.
162    ///
163    /// # Errors
164    /// Returns an error if execution or the mapper fails.
165    pub(crate) async fn query_prepared_map_optional<T, F>(
166        &self,
167        prepared: &Prepared,
168        params: &[RowValues],
169        mapper: F,
170    ) -> Result<Option<T>, SqlMiddlewareDbError>
171    where
172        F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
173    {
174        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
175        let row = self
176            .tx
177            .query_opt(&prepared.stmt, converted.as_refs())
178            .await?;
179        row.as_ref().map(mapper).transpose()
180    }
181
182    /// Execute a parameterized SELECT without preparing and return a `ResultSet`.
183    ///
184    /// # Errors
185    /// Returns an error if parameter conversion or query execution fails.
186    pub async fn query(
187        &self,
188        query: &str,
189        params: &[RowValues],
190    ) -> Result<ResultSet, SqlMiddlewareDbError> {
191        let converted = convert_params::<Params>(params, ConversionMode::Query)?;
192        let rows = self.tx.query(query, converted.as_refs()).await?;
193        build_result_set_from_rows(&rows)
194    }
195
196    /// Execute a batch of SQL statements inside the transaction.
197    ///
198    /// # Errors
199    /// Returns an error if execution fails.
200    pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
201        self.tx.batch_execute(sql).await?;
202        Ok(())
203    }
204
205    /// Commit the transaction.
206    ///
207    /// # Errors
208    /// Returns an error if commit fails.
209    pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
210        self.tx.commit().await?;
211        Ok(TxOutcome::without_restored_connection())
212    }
213
214    /// Roll back the transaction.
215    ///
216    /// # Errors
217    /// Returns an error if rollback fails.
218    pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
219        self.tx.rollback().await?;
220        Ok(TxOutcome::without_restored_connection())
221    }
222}
223
224/// Builder for executing a prepared Postgres DML statement inside a transaction.
225pub struct PreparedExecute<'tx, 'prepared, 'params, 'conn> {
226    tx: &'tx Tx<'conn>,
227    prepared: &'prepared Prepared,
228    params: &'params [RowValues],
229}
230
231impl<'tx, 'prepared, 'params, 'conn> PreparedExecute<'tx, 'prepared, 'params, 'conn> {
232    /// Use middleware `RowValues` parameters.
233    #[must_use]
234    pub fn params<'next>(
235        self,
236        params: &'next [RowValues],
237    ) -> PreparedExecute<'tx, 'prepared, 'next, 'conn> {
238        PreparedExecute {
239            tx: self.tx,
240            prepared: self.prepared,
241            params,
242        }
243    }
244
245    /// Execute the DML statement and return affected rows.
246    ///
247    /// # Errors
248    /// Returns an error if parameter conversion, execution, or row-count conversion fails.
249    pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
250        self.tx.execute_prepared(self.prepared, self.params).await
251    }
252}
253
254/// Builder for executing a prepared Postgres SELECT inside a transaction.
255pub struct PreparedSelect<'tx, 'prepared, 'params, 'conn> {
256    tx: &'tx Tx<'conn>,
257    prepared: &'prepared Prepared,
258    params: &'params [RowValues],
259}
260
261impl<'tx, 'prepared, 'params, 'conn> PreparedSelect<'tx, 'prepared, 'params, 'conn> {
262    /// Use middleware `RowValues` parameters.
263    #[must_use]
264    pub fn params<'next>(
265        self,
266        params: &'next [RowValues],
267    ) -> PreparedSelect<'tx, 'prepared, 'next, 'conn> {
268        PreparedSelect {
269            tx: self.tx,
270            prepared: self.prepared,
271            params,
272        }
273    }
274
275    /// Execute and return all rows as a `ResultSet`.
276    ///
277    /// # Errors
278    /// Returns an error if parameter conversion, execution, or result building fails.
279    pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
280        self.tx.query_prepared(self.prepared, self.params).await
281    }
282
283    /// Execute and return the first row, if present.
284    ///
285    /// # Errors
286    /// Returns an error if parameter conversion, execution, or result building fails.
287    pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
288        self.tx
289            .query_prepared_optional(self.prepared, self.params)
290            .await
291    }
292
293    /// Execute and return exactly one row.
294    ///
295    /// # Errors
296    /// Returns an error if execution fails or no row is returned.
297    pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
298        self.tx.query_prepared_one(self.prepared, self.params).await
299    }
300
301    /// Execute and map exactly one native Postgres row.
302    ///
303    /// # Errors
304    /// Returns an error if execution fails, no row is returned, or the mapper fails.
305    pub async fn map_one<T, F>(self, mapper: F) -> Result<T, SqlMiddlewareDbError>
306    where
307        F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
308    {
309        self.tx
310            .query_prepared_map_one(self.prepared, self.params, mapper)
311            .await
312    }
313
314    /// Execute and map the first native Postgres row, if present.
315    ///
316    /// # Errors
317    /// Returns an error if execution or the mapper fails.
318    pub async fn map_optional<T, F>(self, mapper: F) -> Result<Option<T>, SqlMiddlewareDbError>
319    where
320        F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
321    {
322        self.tx
323            .query_prepared_map_optional(self.prepared, self.params, mapper)
324            .await
325    }
326}