zino_orm/
transaction.rs

1use super::{
2    DatabaseDriver, EncodeColumn, executor::Executor, mutation::MutationExt, query::QueryExt,
3    schema::Schema,
4};
5use std::fmt::Display;
6use zino_core::{
7    BoxFuture, Map,
8    error::Error,
9    extension::JsonValueExt,
10    model::{Mutation, Query},
11};
12
13#[cfg(feature = "orm-sqlx")]
14use sqlx::Acquire;
15
16/// An in-progress database transaction.
17///
18/// # Examples
19///
20/// ```rust,ignore
21/// use crate::model::{Account, AccountColumn, Order, Stock, StockColumn};
22/// use zino_orm::{Executor, MutationBuilder, QueryBuilder, Schema, Transaction};
23///
24/// let user_id = "0193d8e6-2970-7b52-bc06-80a981212aa9";
25/// let product_id = "0193c06d-bee6-7070-a5e7-9659161bddb5";
26///
27/// let order = Order::from_customer(user_id, product_id);
28/// let quantity = order.quantity();
29/// let total_price = order.total_price();
30/// let order_ctx = order.prepare_insert().await?;
31///
32/// let stock_query = QueryBuilder::new()
33///     .and_eq(StockColumn::ProductId, product_id)
34///     .and_ge(StockColumn::Quantity, quantity)
35///     .build();
36/// let mut stock_mutation = MutationBuilder::<Stock>::new()
37///     .inc(StockColumn::Quantity, -quantity)
38///     .build();
39/// let stock_ctx = Stock::prepare_update_one(&stock_query, &mut stock_mutation).await?;
40///
41/// let account_query = QueryBuilder::new()
42///     .and_eq(AccountColumn::UserId, user_id)
43///     .and_ge(AccountColumn::Balance, total_price)
44///     .build();
45/// let mut account_mutation = MutationBuilder::<Account>::new()
46///     .inc(AccountColumn::Balance, -total_price)
47///     .build();
48/// let account_ctx = Account::prepare_update_one(&account_query, &mut account_mutation).await?;
49///
50/// Order::transaction(move |tx| Box::pin(async move {
51///     tx.execute(order_ctx.query()).await?;
52///     tx.execute(stock_ctx.query()).await?;
53///     tx.execute(account_ctx.query()).await?;
54///     Ok(())
55/// })).await?;
56/// ```
57pub trait Transaction<K, Tx>: Schema<PrimaryKey = K>
58where
59    K: Default + Display + PartialEq,
60{
61    /// Executes the specific operations inside of a transaction.
62    /// If the operations return an error, the transaction will be rolled back;
63    /// if not, the transaction will be committed.
64    async fn transaction<F, T>(tx: F) -> Result<T, Error>
65    where
66        F: for<'t> FnOnce(&'t mut Tx) -> BoxFuture<'t, Result<T, Error>>;
67
68    /// Executes the queries sequentially inside of a transaction.
69    /// If it returns an error, the transaction will be rolled back;
70    /// if not, the transaction will be committed.
71    async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error>;
72
73    /// Inserts the model and its associations inside of a transaction.
74    async fn transactional_insert<M: Schema>(self, models: Vec<M>) -> Result<u64, Error>;
75
76    /// Updates the models inside of a transaction.
77    async fn transactional_update<M: Schema>(
78        queries: (&Query, &Query),
79        mutations: (&mut Mutation, &mut Mutation),
80    ) -> Result<u64, Error>;
81
82    /// Deletes the models inside of a transaction.
83    async fn transactional_delete<M: Schema>(queries: (&Query, &Query)) -> Result<u64, Error>;
84}
85
#[cfg(feature = "orm-sqlx")]
impl<'c, M, K> Transaction<K, sqlx::Transaction<'c, DatabaseDriver>> for M
where
    M: Schema<PrimaryKey = K>,
    K: Default + Display + PartialEq,
{
    async fn transaction<F, T>(tx: F) -> Result<T, Error>
    where
        F: for<'t> FnOnce(
            &'t mut sqlx::Transaction<'c, DatabaseDriver>,
        ) -> BoxFuture<'t, Result<T, Error>>,
    {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        // An early return via `?` drops the uncommitted transaction,
        // which sqlx rolls back.
        let data = tx(&mut transaction).await?;
        transaction.commit().await?;
        Ok(data)
    }

    async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        let mut total_rows = 0;
        for query in queries {
            let (sql, values) = Query::prepare_query(query, params);
            let mut ctx = Self::before_scan(&sql).await?;
            ctx.set_query(sql);

            let mut arguments = values
                .iter()
                .map(|v| v.to_string_unquoted())
                .collect::<Vec<_>>();
            let rows_affected = connection
                .execute_with(ctx.query(), &arguments)
                .await?
                .rows_affected();
            total_rows += rows_affected;
            ctx.append_arguments(&mut arguments);
            ctx.set_query_result(rows_affected, true);
            Self::after_scan(&ctx).await?;
        }
        transaction.commit().await?;
        Ok(total_rows)
    }

    async fn transactional_insert<S: Schema>(mut self, associations: Vec<S>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Inserts the model
        let model_data = self.before_insert().await?;
        let table_name = if let Some(table) = self.before_prepare().await? {
            Query::escape_table_name(&table)
        } else {
            Query::escape_table_name(Self::table_name())
        };
        let map = self.into_map();
        let columns = Self::columns();

        // Auto-increment columns are left out of the column and value lists.
        let (fields, values): (Vec<_>, Vec<_>) = columns
            .iter()
            .filter(|col| !col.auto_increment())
            .map(|col| {
                let name = col.name();
                (name, col.encode_value(map.get(name)))
            })
            .unzip();
        let fields = fields.join(", ");
        let values = values.join(", ");
        let sql = format!("INSERT INTO {table_name} ({fields}) VALUES ({values});");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let query_result = connection.execute(ctx.query()).await?;
        let (last_insert_id, rows_affected) = Query::parse_query_result(query_result);
        let success = rows_affected == 1;
        if let Some(last_insert_id) = last_insert_id {
            ctx.set_last_insert_id(last_insert_id);
        }
        let mut total_rows = rows_affected;
        ctx.set_query_result(rows_affected, success);
        Self::after_scan(&ctx).await?;
        // NOTE(review): this hook runs before the commit below, so it can observe
        // an insert that is later rolled back if inserting the associations fails.
        Self::after_insert(&ctx, model_data).await?;

        // Inserts associations. An empty `VALUES` list is invalid SQL, so the
        // statement is skipped when there is nothing to insert instead of
        // failing (and rolling back) the whole transaction.
        if !associations.is_empty() {
            let columns = S::columns();
            let mut values = Vec::with_capacity(associations.len());
            for mut association in associations {
                let _association_data = association.before_insert().await?;
                let map = association.into_map();
                let entries = columns
                    .iter()
                    .map(|col| col.encode_value(map.get(col.name())))
                    .collect::<Vec<_>>()
                    .join(", ");
                values.push(format!("({entries})"));
            }

            let table_name = Query::escape_table_name(S::table_name());
            let fields = S::fields().join(", ");
            let values = values.join(", ");
            let sql = format!("INSERT INTO {table_name} ({fields}) VALUES {values};");
            let mut ctx = S::before_scan(&sql).await?;
            ctx.set_query(sql);

            let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
            total_rows += rows_affected;
            ctx.set_query_result(rows_affected, true);
            S::after_scan(&ctx).await?;
        }

        // Commits the transaction
        transaction.commit().await?;
        Ok(total_rows)
    }

    async fn transactional_update<S: Schema>(
        queries: (&Query, &Query),
        mutations: (&mut Mutation, &mut Mutation),
    ) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Updates the models of `Self` with the first query and mutation
        let query = queries.0;
        let mutation = mutations.0;
        Self::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let updates = mutation.format_updates::<Self>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        let mut total_rows = rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_mutation(&ctx).await?;

        // Updates the models of `S` with the second query and mutation
        let query = queries.1;
        let mutation = mutations.1;
        S::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let updates = mutation.format_updates::<S>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_mutation(&ctx).await?;

        // Commits the transaction
        transaction.commit().await?;
        Ok(total_rows)
    }

    async fn transactional_delete<S: Schema>(queries: (&Query, &Query)) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Deletes the models of `Self` matched by the first query
        let query = queries.0;
        Self::before_query(query).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        let mut total_rows = rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_query(&ctx).await?;

        // Deletes the models of `S` matched by the second query
        let query = queries.1;
        S::before_query(query).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_query(&ctx).await?;

        // Commits the transaction
        transaction.commit().await?;
        Ok(total_rows)
    }
}