sql_middleware/postgres/
transaction.rs1use std::ops::DerefMut;
2
3use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
4
5use crate::adapters::params::convert_params;
6use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
7use crate::tx_outcome::TxOutcome;
8
9use super::{Params, build_result_set};
10use crate::postgres::query::{build_result_set_from_rows, convert_affected_rows};
11
12pub struct Tx<'a> {
14 tx: PgTransaction<'a>,
15}
16
17pub struct Prepared {
19 stmt: Statement,
20}
21
22pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28 C: DerefMut<Target = Client>,
29{
30 let tx = conn.deref_mut().transaction().await?;
31 Ok(Tx { tx })
32}
33
34impl Tx<'_> {
35 pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40 let stmt = self.tx.prepare(sql).await?;
41 Ok(Prepared { stmt })
42 }
43
44 pub async fn execute_prepared(
49 &self,
50 prepared: &Prepared,
51 params: &[RowValues],
52 ) -> Result<usize, SqlMiddlewareDbError> {
53 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
54
55 let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
56
57 convert_affected_rows(rows, "Invalid rows affected count")
58 }
59
60 pub async fn execute_dml(
65 &self,
66 query: &str,
67 params: &[RowValues],
68 ) -> Result<usize, SqlMiddlewareDbError> {
69 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
70 let rows = self.tx.execute(query, converted.as_refs()).await?;
71 convert_affected_rows(rows, "Invalid rows affected count")
72 }
73
74 pub async fn query_prepared(
79 &self,
80 prepared: &Prepared,
81 params: &[RowValues],
82 ) -> Result<ResultSet, SqlMiddlewareDbError> {
83 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
84 build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
85 }
86
87 pub async fn query(
92 &self,
93 query: &str,
94 params: &[RowValues],
95 ) -> Result<ResultSet, SqlMiddlewareDbError> {
96 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
97 let rows = self.tx.query(query, converted.as_refs()).await?;
98 build_result_set_from_rows(&rows)
99 }
100
101 pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
106 self.tx.batch_execute(sql).await?;
107 Ok(())
108 }
109
110 pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
115 self.tx.commit().await?;
116 Ok(TxOutcome::without_restored_connection())
117 }
118
119 pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
124 self.tx.rollback().await?;
125 Ok(TxOutcome::without_restored_connection())
126 }
127}