Skip to main content

systemprompt_database/services/postgres/
transaction.rs

1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3
4use super::conversion::{bind_params, row_to_json};
5use crate::models::{DatabaseTransaction, JsonRow, QuerySelector, ToDbValue};
6
7pub struct PostgresTransaction {
8    tx: Option<sqlx::Transaction<'static, sqlx::Postgres>>,
9}
10
11impl std::fmt::Debug for PostgresTransaction {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        f.debug_struct("PostgresTransaction")
14            .field("tx", &self.tx.is_some())
15            .finish()
16    }
17}
18
19impl PostgresTransaction {
20    #[must_use]
21    pub const fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
22        Self { tx: Some(tx) }
23    }
24}
25
26#[async_trait]
27impl DatabaseTransaction for PostgresTransaction {
28    async fn execute(
29        &mut self,
30        query: &dyn QuerySelector,
31        params: &[&dyn ToDbValue],
32    ) -> Result<u64> {
33        let sql = query.select_query();
34        let tx = self
35            .tx
36            .as_mut()
37            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
38
39        let query_obj = sqlx::query(sql);
40        let query_obj = bind_params(query_obj, params);
41
42        let result = query_obj
43            .execute(&mut **tx)
44            .await
45            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
46
47        Ok(result.rows_affected())
48    }
49
50    async fn fetch_all(
51        &mut self,
52        query: &dyn QuerySelector,
53        params: &[&dyn ToDbValue],
54    ) -> Result<Vec<JsonRow>> {
55        let sql = query.select_query();
56        let tx = self
57            .tx
58            .as_mut()
59            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
60
61        let query_obj = sqlx::query(sql);
62        let query_obj = bind_params(query_obj, params);
63
64        let rows = query_obj
65            .fetch_all(&mut **tx)
66            .await
67            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
68
69        Ok(rows.iter().map(row_to_json).collect())
70    }
71
72    async fn fetch_one(
73        &mut self,
74        query: &dyn QuerySelector,
75        params: &[&dyn ToDbValue],
76    ) -> Result<JsonRow> {
77        let sql = query.select_query();
78        let tx = self
79            .tx
80            .as_mut()
81            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
82
83        let query_obj = sqlx::query(sql);
84        let query_obj = bind_params(query_obj, params);
85
86        let row = query_obj
87            .fetch_one(&mut **tx)
88            .await
89            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
90
91        Ok(row_to_json(&row))
92    }
93
94    async fn fetch_optional(
95        &mut self,
96        query: &dyn QuerySelector,
97        params: &[&dyn ToDbValue],
98    ) -> Result<Option<JsonRow>> {
99        let sql = query.select_query();
100        let tx = self
101            .tx
102            .as_mut()
103            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
104
105        let query_obj = sqlx::query(sql);
106        let query_obj = bind_params(query_obj, params);
107
108        let row = query_obj
109            .fetch_optional(&mut **tx)
110            .await
111            .map_err(|e| anyhow!("Query execution failed: {e}"))?;
112
113        Ok(row.map(|r| row_to_json(&r)))
114    }
115
116    async fn commit(mut self: Box<Self>) -> Result<()> {
117        let tx = self
118            .tx
119            .take()
120            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
121
122        tx.commit()
123            .await
124            .map_err(|e| anyhow!("Transaction commit failed: {e}"))?;
125
126        Ok(())
127    }
128
129    async fn rollback(mut self: Box<Self>) -> Result<()> {
130        let tx = self
131            .tx
132            .take()
133            .ok_or_else(|| anyhow!("Transaction already consumed"))?;
134
135        tx.rollback()
136            .await
137            .map_err(|e| anyhow!("Transaction rollback failed: {e}"))?;
138
139        Ok(())
140    }
141}