// umbral_core/orm/queryset/tx.rs

//! `QuerySetTx` — a QuerySet bound to an open transaction.
//!
//! Construction happens in [`super::QuerySet::on_tx`] /
//! [`super::Manager::on_tx`] using struct-literal syntax against the
//! `pub(super)` fields. Every terminal here mirrors its plain-QuerySet
//! sibling but routes its SQL through the borrowed
//! [`crate::db::Transaction`], so the operations commit or roll back
//! as a unit with every other operation in the same
//! `umbral::db::transaction(...)` closure.
//!
//! The struct borrows `&mut Transaction` so the borrow checker
//! enforces that only one `QuerySetTx` uses the transaction at a
//! time, and that the transaction stays alive for the duration of
//! each terminal call.
15
use sea_query::{Expr, Func, OrderedStatement, PostgresQueryBuilder, SqliteQueryBuilder};
use sea_query_binder::SqlxBinder;
18
19use crate::orm::{HydrateRelated, Model};
20
21use super::QuerySet;
22use super::errors::GetError;
23use super::write_helpers::{build_insert_one_for, serialize_to_map};
24
25/// A `QuerySet` bound to an open transaction. See module docs for
26/// the construction sites and the borrow-checker contract.
27pub struct QuerySetTx<'tx, T> {
28    pub(super) qs: QuerySet<T>,
29    pub(super) tx: &'tx mut crate::db::Transaction,
30}
31
32impl<'tx, T: Model> QuerySetTx<'tx, T> {
33    // -----------------------------------------------------------------------
34    // Read terminals
35    // -----------------------------------------------------------------------
36
37    /// SELECT all matching rows inside the transaction.
38    pub async fn fetch(self) -> Result<Vec<T>, sqlx::Error>
39    where
40        T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
41            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
42            + HydrateRelated,
43    {
44        let q = self.qs.build_query_for(self.tx.backend_name());
45        let mut rows = match self.tx.backend_name() {
46            "sqlite" => {
47                let tx = self.tx.as_sqlite_mut().unwrap();
48                let (sql, values) = q.build_sqlx(SqliteQueryBuilder);
49                sqlx::query_as_with::<sqlx::Sqlite, T, _>(&sql, values)
50                    .fetch_all(&mut **tx)
51                    .await?
52            }
53            _ => {
54                let tx = self.tx.as_pg_mut().unwrap();
55                let (sql, values) = q.build_sqlx(PostgresQueryBuilder);
56                sqlx::query_as_with::<sqlx::Postgres, T, _>(&sql, values)
57                    .fetch_all(&mut **tx)
58                    .await?
59            }
60        };
61        // BUG-16 step 2: wire each row's PK into its M2M slots so
62        // junction-table accessors used inside the transaction see
63        // the right parent.
64        for r in &mut rows {
65            r.set_m2m_parent_ids();
66        }
67        Ok(rows)
68    }
69
70    /// SELECT LIMIT 1 and return the first row, if any.
71    pub async fn first(mut self) -> Result<Option<T>, sqlx::Error>
72    where
73        T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
74            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
75            + HydrateRelated,
76    {
77        self.qs.query.limit(1);
78        let q = self.qs.build_query_for(self.tx.backend_name());
79        let mut row = match self.tx.backend_name() {
80            "sqlite" => {
81                let tx = self.tx.as_sqlite_mut().unwrap();
82                let (sql, values) = q.build_sqlx(SqliteQueryBuilder);
83                sqlx::query_as_with::<sqlx::Sqlite, T, _>(&sql, values)
84                    .fetch_optional(&mut **tx)
85                    .await?
86            }
87            _ => {
88                let tx = self.tx.as_pg_mut().unwrap();
89                let (sql, values) = q.build_sqlx(PostgresQueryBuilder);
90                sqlx::query_as_with::<sqlx::Postgres, T, _>(&sql, values)
91                    .fetch_optional(&mut **tx)
92                    .await?
93            }
94        };
95        if let Some(r) = row.as_mut() {
96            r.set_m2m_parent_ids();
97        }
98        Ok(row)
99    }
100
101    /// SELECT COUNT(*) inside the transaction.
102    pub async fn count(self) -> Result<i64, sqlx::Error> {
103        let backend = self.tx.backend_name();
104        let mut rebuilt = self.qs.build_query_for(backend);
105        rebuilt.clear_selects();
106        // `sea_query::Asterisk` renders the bare SQL `*` token; `Alias::new("*")`
107        // would render `COUNT("*")` — a quoted identifier Postgres reads as a
108        // column named `*`. Matches the non-transactional count path.
109        rebuilt.expr(Func::count(Expr::col(sea_query::Asterisk)));
110        rebuilt.reset_limit();
111        rebuilt.reset_offset();
112        match backend {
113            "sqlite" => {
114                let tx = self.tx.as_sqlite_mut().unwrap();
115                let (sql, values) = rebuilt.build_sqlx(SqliteQueryBuilder);
116                let (n,): (i64,) = sqlx::query_as_with::<sqlx::Sqlite, (i64,), _>(&sql, values)
117                    .fetch_one(&mut **tx)
118                    .await?;
119                Ok(n)
120            }
121            _ => {
122                let tx = self.tx.as_pg_mut().unwrap();
123                let (sql, values) = rebuilt.build_sqlx(PostgresQueryBuilder);
124                let (n,): (i64,) = sqlx::query_as_with::<sqlx::Postgres, (i64,), _>(&sql, values)
125                    .fetch_one(&mut **tx)
126                    .await?;
127                Ok(n)
128            }
129        }
130    }
131
132    /// Return whether any row matches, inside the transaction.
133    pub async fn exists(mut self) -> Result<bool, sqlx::Error>
134    where
135        T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
136            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
137    {
138        self.qs.query.limit(1);
139        let backend = self.tx.backend_name();
140        let q = self.qs.build_query_for(backend);
141        let row_opt: Option<T> = match backend {
142            "sqlite" => {
143                let tx = self.tx.as_sqlite_mut().unwrap();
144                let (sql, values) = q.build_sqlx(SqliteQueryBuilder);
145                sqlx::query_as_with::<sqlx::Sqlite, T, _>(&sql, values)
146                    .fetch_optional(&mut **tx)
147                    .await?
148            }
149            _ => {
150                let tx = self.tx.as_pg_mut().unwrap();
151                let (sql, values) = q.build_sqlx(PostgresQueryBuilder);
152                sqlx::query_as_with::<sqlx::Postgres, T, _>(&sql, values)
153                    .fetch_optional(&mut **tx)
154                    .await?
155            }
156        };
157        Ok(row_opt.is_some())
158    }
159
160    /// Exactly-one terminal inside the transaction. See [`super::QuerySet::get`].
161    pub async fn get(mut self) -> Result<T, GetError>
162    where
163        T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
164            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
165    {
166        self.qs.query.limit(2);
167        let q = self.qs.build_query_for(self.tx.backend_name());
168        let mut rows: Vec<T> = match self.tx.backend_name() {
169            "sqlite" => {
170                let tx = self.tx.as_sqlite_mut().unwrap();
171                let (sql, values) = q.build_sqlx(SqliteQueryBuilder);
172                sqlx::query_as_with::<sqlx::Sqlite, T, _>(&sql, values)
173                    .fetch_all(&mut **tx)
174                    .await
175                    .map_err(GetError::Sqlx)?
176            }
177            _ => {
178                let tx = self.tx.as_pg_mut().unwrap();
179                let (sql, values) = q.build_sqlx(PostgresQueryBuilder);
180                sqlx::query_as_with::<sqlx::Postgres, T, _>(&sql, values)
181                    .fetch_all(&mut **tx)
182                    .await
183                    .map_err(GetError::Sqlx)?
184            }
185        };
186        match rows.len() {
187            0 => Err(GetError::NotFound),
188            1 => Ok(rows.pop().unwrap()),
189            _ => Err(GetError::MultipleObjectsReturned),
190        }
191    }
192
193    // -----------------------------------------------------------------------
194    // Write terminals
195    // -----------------------------------------------------------------------
196
197    /// DELETE inside the transaction. Returns the number of rows deleted.
198    pub async fn delete(self) -> Result<u64, sqlx::Error> {
199        let stmt = self.qs.build_delete_for(self.tx.backend_name());
200        match self.tx.backend_name() {
201            "sqlite" => {
202                let tx = self.tx.as_sqlite_mut().unwrap();
203                let (sql, values) = stmt.build_sqlx(SqliteQueryBuilder);
204                let result = sqlx::query_with::<sqlx::Sqlite, _>(&sql, values)
205                    .execute(&mut **tx)
206                    .await?;
207                Ok(result.rows_affected())
208            }
209            _ => {
210                let tx = self.tx.as_pg_mut().unwrap();
211                let (sql, values) = stmt.build_sqlx(PostgresQueryBuilder);
212                let result = sqlx::query_with::<sqlx::Postgres, _>(&sql, values)
213                    .execute(&mut **tx)
214                    .await?;
215                Ok(result.rows_affected())
216            }
217        }
218    }
219
220    /// UPDATE inside the transaction. Takes the same `column → JSON value`
221    /// map as [`super::QuerySet::update_values`].
222    pub async fn update_values(
223        self,
224        values: serde_json::Map<String, serde_json::Value>,
225    ) -> Result<u64, crate::orm::write::WriteError> {
226        let stmt = self.qs.build_update_for(self.tx.backend_name(), &values)?;
227        match self.tx.backend_name() {
228            "sqlite" => {
229                let tx = self.tx.as_sqlite_mut().unwrap();
230                let (sql, values) = stmt.build_sqlx(SqliteQueryBuilder);
231                let result = sqlx::query_with::<sqlx::Sqlite, _>(&sql, values)
232                    .execute(&mut **tx)
233                    .await?;
234                Ok(result.rows_affected())
235            }
236            _ => {
237                let tx = self.tx.as_pg_mut().unwrap();
238                let (sql, values) = stmt.build_sqlx(PostgresQueryBuilder);
239                let result = sqlx::query_with::<sqlx::Postgres, _>(&sql, values)
240                    .execute(&mut **tx)
241                    .await?;
242                Ok(result.rows_affected())
243            }
244        }
245    }
246
247    /// INSERT one row and return the populated row, inside the transaction.
248    ///
249    /// This is the `Manager::create_in_tx` equivalent called through the
250    /// QuerySet API: `Post::objects().on_tx(tx).create(instance).await?`.
251    pub async fn create(self, instance: T) -> Result<T, crate::orm::write::WriteError>
252    where
253        T: serde::Serialize
254            + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
255            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
256            + HydrateRelated,
257    {
258        let map = serialize_to_map(&instance)?;
259        let stmt = build_insert_one_for::<T>(self.tx.backend_name(), &map)?;
260        match self.tx.backend_name() {
261            "sqlite" => {
262                let tx = self.tx.as_sqlite_mut().unwrap();
263                let (sql, values) = stmt.build_sqlx(SqliteQueryBuilder);
264                let mut row = sqlx::query_as_with::<sqlx::Sqlite, T, _>(&sql, values)
265                    .fetch_one(&mut **tx)
266                    .await?;
267                row.set_m2m_parent_ids();
268                Ok(row)
269            }
270            _ => {
271                let tx = self.tx.as_pg_mut().unwrap();
272                let (sql, values) = stmt.build_sqlx(PostgresQueryBuilder);
273                let mut row = sqlx::query_as_with::<sqlx::Postgres, T, _>(&sql, values)
274                    .fetch_one(&mut **tx)
275                    .await?;
276                row.set_m2m_parent_ids();
277                Ok(row)
278            }
279        }
280    }
281}