1use super::{
2 DatabaseDriver, EncodeColumn, executor::Executor, mutation::MutationExt, query::QueryExt,
3 schema::Schema,
4};
5use std::fmt::Display;
6use zino_core::{
7 BoxFuture, Map,
8 error::Error,
9 extension::JsonValueExt,
10 model::{Mutation, Query},
11};
12
13#[cfg(feature = "orm-sqlx")]
14use sqlx::Acquire;
15
16pub trait Transaction<K, Tx>: Schema<PrimaryKey = K>
58where
59 K: Default + Display + PartialEq,
60{
61 async fn transaction<F, T>(tx: F) -> Result<T, Error>
65 where
66 F: for<'t> FnOnce(&'t mut Tx) -> BoxFuture<'t, Result<T, Error>>;
67
68 async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error>;
72
73 async fn transactional_insert<M: Schema>(self, models: Vec<M>) -> Result<u64, Error>;
75
76 async fn transactional_update<M: Schema>(
78 queries: (&Query, &Query),
79 mutations: (&mut Mutation, &mut Mutation),
80 ) -> Result<u64, Error>;
81
82 async fn transactional_delete<M: Schema>(queries: (&Query, &Query)) -> Result<u64, Error>;
84}
85
#[cfg(feature = "orm-sqlx")]
impl<'c, M, K> Transaction<K, sqlx::Transaction<'c, DatabaseDriver>> for M
where
    M: Schema<PrimaryKey = K>,
    K: Default + Display + PartialEq,
{
    /// Begins a transaction on the pool of `Self::acquire_writer()`, runs the
    /// closure `tx` against it and commits once the closure has succeeded.
    ///
    /// # Errors
    ///
    /// Propagates any error from acquiring the writer, beginning the
    /// transaction, the closure itself, or the commit. On an early return the
    /// transaction is dropped without being committed (sqlx rolls back an
    /// uncommitted transaction when it goes out of scope).
    async fn transaction<F, T>(tx: F) -> Result<T, Error>
    where
        F: for<'t> FnOnce(
            &'t mut sqlx::Transaction<'c, DatabaseDriver>,
        ) -> BoxFuture<'t, Result<T, Error>>,
    {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let data = tx(&mut transaction).await?;
        transaction.commit().await?;
        Ok(data)
    }

    /// Executes every statement in `queries` on one connection inside a single
    /// transaction and returns the sum of the affected rows.
    ///
    /// Each statement is prepared with `Query::prepare_query(query, params)`;
    /// the resulting values are rendered with `to_string_unquoted()` and passed
    /// as arguments. `Self::before_scan` / `Self::after_scan` are invoked
    /// around every statement.
    ///
    /// # Errors
    ///
    /// The first failing hook or statement aborts the loop; the transaction is
    /// then dropped without being committed.
    async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        let mut total_rows = 0;
        for query in queries {
            let (sql, values) = Query::prepare_query(query, params);
            let mut ctx = Self::before_scan(&sql).await?;
            ctx.set_query(sql);

            let mut arguments = values
                .iter()
                .map(|v| v.to_string_unquoted())
                .collect::<Vec<_>>();
            let rows_affected = connection
                .execute_with(ctx.query(), &arguments)
                .await?
                .rows_affected();
            total_rows += rows_affected;
            // The arguments are handed to the context only after execution,
            // because `execute_with` borrows them.
            ctx.append_arguments(&mut arguments);
            ctx.set_query_result(rows_affected, true);
            Self::after_scan(&ctx).await?;
        }

        // Commits the transaction.
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Inserts the model and its `associations` inside a single transaction
    /// and returns the total number of rows affected.
    ///
    /// The model is inserted first with one `INSERT` statement that omits
    /// auto-increment columns. The associations are then inserted with one
    /// multi-row `INSERT` statement. When `associations` is empty, the second
    /// statement is skipped entirely: an `INSERT ... VALUES` without any row
    /// is not valid SQL and would otherwise fail the whole transaction.
    ///
    /// # Errors
    ///
    /// Propagates any error from the hooks or the database. On an early return
    /// the transaction is dropped without being committed.
    async fn transactional_insert<S: Schema>(mut self, associations: Vec<S>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Inserts the model.
        let model_data = self.before_insert().await?;
        let table_name = if let Some(table) = self.before_prepare().await? {
            Query::escape_table_name(&table)
        } else {
            Query::escape_table_name(Self::table_name())
        };
        let map = self.into_map();
        let columns = Self::columns();

        // Column names and encoded values are collected in lockstep so that
        // both lists skip the auto-increment columns.
        let mut fields = Vec::with_capacity(columns.len());
        let values = columns
            .iter()
            .filter_map(|col| {
                if col.auto_increment() {
                    None
                } else {
                    let name = col.name();
                    fields.push(name);
                    Some(col.encode_value(map.get(name)))
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        let fields = fields.join(", ");
        let sql = format!("INSERT INTO {table_name} ({fields}) VALUES ({values});");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let query_result = connection.execute(ctx.query()).await?;
        let (last_insert_id, rows_affected) = Query::parse_query_result(query_result);
        let success = rows_affected == 1;
        if let Some(last_insert_id) = last_insert_id {
            ctx.set_last_insert_id(last_insert_id);
        }
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, success);
        Self::after_scan(&ctx).await?;
        Self::after_insert(&ctx, model_data).await?;

        // Inserts the associations. With no associations there is nothing to
        // insert, and building the statement would yield `... VALUES ;`.
        if !associations.is_empty() {
            let columns = S::columns();
            let mut values = Vec::with_capacity(associations.len());
            for mut association in associations {
                // NOTE(review): the data returned by `before_insert` is
                // discarded and `S::after_insert` is never called — confirm
                // this is intended.
                let _association_data = association.before_insert().await?;
                let map = association.into_map();
                // NOTE(review): unlike the model insert above, auto-increment
                // columns are not filtered out here — confirm this is intended.
                let entries = columns
                    .iter()
                    .map(|col| col.encode_value(map.get(col.name())))
                    .collect::<Vec<_>>()
                    .join(", ");
                values.push(format!("({entries})"));
            }

            let table_name = Query::escape_table_name(S::table_name());
            let fields = S::fields().join(", ");
            let values = values.join(", ");
            let sql = format!("INSERT INTO {table_name} ({fields}) VALUES {values};");
            let mut ctx = S::before_scan(&sql).await?;
            ctx.set_query(sql);

            let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
            total_rows += rows_affected;
            ctx.set_query_result(rows_affected, true);
            S::after_scan(&ctx).await?;
        }

        // Commits the transaction.
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Updates rows of this model and of the associated model `S` inside a
    /// single transaction and returns the total number of rows affected.
    ///
    /// `queries.0` / `mutations.0` are applied to `Self`, `queries.1` /
    /// `mutations.1` are applied to `S`. The `before_mutation`, `before_scan`,
    /// `after_scan` and `after_mutation` hooks of the respective model are
    /// invoked around each `UPDATE` statement.
    ///
    /// # Errors
    ///
    /// Propagates any error from the hooks or the database. On an early return
    /// the transaction is dropped without being committed.
    async fn transactional_update<S: Schema>(
        queries: (&Query, &Query),
        mutations: (&mut Mutation, &mut Mutation),
    ) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Updates the model.
        let query = queries.0;
        let mutation = mutations.0;
        Self::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let updates = mutation.format_updates::<Self>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_mutation(&ctx).await?;

        // Updates the associated model.
        let query = queries.1;
        let mutation = mutations.1;
        S::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let updates = mutation.format_updates::<S>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_mutation(&ctx).await?;

        // Commits the transaction.
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Deletes rows of this model and of the associated model `S` inside a
    /// single transaction and returns the total number of rows affected.
    ///
    /// `queries.0` selects the rows of `Self`, `queries.1` selects the rows of
    /// `S`. The `before_query`, `before_scan`, `after_scan` and `after_query`
    /// hooks of the respective model are invoked around each `DELETE`
    /// statement.
    ///
    /// # Errors
    ///
    /// Propagates any error from the hooks or the database. On an early return
    /// the transaction is dropped without being committed.
    async fn transactional_delete<S: Schema>(queries: (&Query, &Query)) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Deletes rows of the model.
        let query = queries.0;
        Self::before_query(query).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_query(&ctx).await?;

        // Deletes rows of the associated model.
        let query = queries.1;
        S::before_query(query).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_query(&ctx).await?;

        // Commits the transaction.
        transaction.commit().await?;
        Ok(total_rows)
    }
}