1use crate::error::DbError;
2use crate::mysql::condition::{Condition, SqlValue};
3use crate::mysql::field::FieldType;
4use sqlx::Transaction as SqlxTransaction;
5use std::collections::HashMap;
6
7pub struct Transaction {
9 tx: Option<SqlxTransaction<'static, sqlx::MySql>>,
10 enable_logging: bool,
11}
12
13impl Transaction {
14 pub(crate) fn new(tx: SqlxTransaction<'static, sqlx::MySql>, enable_logging: bool) -> Self {
16 Self {
17 tx: Some(tx),
18 enable_logging,
19 }
20 }
21
22 pub async fn commit(mut self) -> Result<(), DbError> {
24 if self.enable_logging {
25 log::debug!("提交事务");
26 }
27
28 if let Some(tx) = self.tx.take() {
29 tx.commit().await?;
30 }
31
32 Ok(())
33 }
34
35 pub async fn rollback(mut self) -> Result<(), DbError> {
37 if self.enable_logging {
38 log::debug!("回滚事务");
39 }
40
41 if let Some(tx) = self.tx.take() {
42 tx.rollback().await?;
43 }
44
45 Ok(())
46 }
47
48 pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
50 if self.enable_logging {
51 log::debug!("事务中执行: {}", sql);
52 }
53
54 if let Some(tx) = &mut self.tx {
55 let result = sqlx::query(sql).execute(&mut **tx).await?;
56 Ok(result.rows_affected())
57 } else {
58 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
59 }
60 }
61
62 pub async fn execute_with_params(
88 &mut self,
89 sql: &str,
90 params: Vec<serde_json::Value>,
91 ) -> Result<u64, DbError> {
92 if self.enable_logging {
93 log::debug!("事务中执行参数化语句: {}, 参数数量: {}", sql, params.len());
94 }
95
96 if let Some(tx) = &mut self.tx {
97 let mut query = sqlx::query(sql);
99 for param in ¶ms {
100 query = bind_json_param_tx(query, param);
101 }
102 let result = query.execute(&mut **tx).await?;
103 Ok(result.rows_affected())
104 } else {
105 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
106 }
107 }
108
109 pub async fn query_with_params<T>(
135 &mut self,
136 sql: &str,
137 params: Vec<serde_json::Value>,
138 ) -> Result<Vec<T>, DbError>
139 where
140 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
141 {
142 if self.enable_logging {
143 log::debug!("事务中执行参数化查询: {}, 参数数量: {}", sql, params.len());
144 }
145
146 if let Some(tx) = &mut self.tx {
147 let mut query = sqlx::query_as::<_, T>(sql);
149 for param in ¶ms {
150 query = bind_json_param_as_tx(query, param);
151 }
152 let rows = query.fetch_all(&mut **tx).await?;
153 Ok(rows)
154 } else {
155 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
156 }
157 }
158
159 pub fn table(&mut self, table_name: &str) -> TransactionQueryBuilder<'_> {
193 TransactionQueryBuilder::new(self, table_name)
194 }
195}
196
197pub struct TransactionQueryBuilder<'a> {
201 tx: &'a mut Transaction,
202 table: String,
203 conditions: Vec<Condition>,
204 field_types: HashMap<String, FieldType>,
205}
206
207impl<'a> TransactionQueryBuilder<'a> {
208 fn new(tx: &'a mut Transaction, table_name: &str) -> Self {
210 Self {
211 tx,
212 table: table_name.to_string(),
213 conditions: Vec::new(),
214 field_types: HashMap::new(),
215 }
216 }
217
218 pub fn json(mut self, field: &str) -> Self {
220 self.field_types.insert(field.to_string(), FieldType::Json);
221 self
222 }
223
224 pub fn datetime(mut self, field: &str) -> Self {
226 self.field_types
227 .insert(field.to_string(), FieldType::DateTime);
228 self
229 }
230
231 pub fn timestamp(mut self, field: &str) -> Self {
233 self.field_types
234 .insert(field.to_string(), FieldType::Timestamp);
235 self
236 }
237
238 pub fn decimal(mut self, field: &str) -> Self {
240 self.field_types
241 .insert(field.to_string(), FieldType::Decimal);
242 self
243 }
244
245 pub fn blob(mut self, field: &str) -> Self {
247 self.field_types.insert(field.to_string(), FieldType::Blob);
248 self
249 }
250
251 pub fn text(mut self, field: &str) -> Self {
253 self.field_types.insert(field.to_string(), FieldType::Text);
254 self
255 }
256
257 pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
259 where
260 V: Into<SqlValue>,
261 {
262 let sql_value = value.into();
263 let condition = match op {
264 "=" => Condition::Eq(field.to_string(), sql_value),
265 "!=" => Condition::Ne(field.to_string(), sql_value),
266 ">" => Condition::Gt(field.to_string(), sql_value),
267 "<" => Condition::Lt(field.to_string(), sql_value),
268 ">=" => Condition::Gte(field.to_string(), sql_value),
269 "<=" => Condition::Lte(field.to_string(), sql_value),
270 "like" | "LIKE" => {
271 if let SqlValue::String(s) = sql_value {
272 Condition::Like(field.to_string(), s)
273 } else {
274 Condition::Like(field.to_string(), format!("{:?}", sql_value))
275 }
276 }
277 _ => panic!("不支持的操作符: {}", op),
278 };
279
280 self.conditions.push(condition);
281 self
282 }
283
284 pub async fn insert<T>(self, data: &T) -> Result<u64, DbError>
298 where
299 T: serde::Serialize,
300 {
301 if self.tx.enable_logging {
303 log::debug!("事务中执行 insert() 操作,表: {}", self.table);
304 }
305
306 let json_data = serde_json::to_value(data)
308 .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
309
310 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
312 generator.build_insert(&self.table, &json_data, &self.field_types)?;
313
314 let sql = generator.get_sql();
315 let params = generator.get_params();
316
317 if self.tx.enable_logging {
319 log::debug!("事务中执行 insert() SQL: {}", sql);
320 log::debug!("参数: {:?}", params);
321 }
322
323 let mut query = sqlx::query(sql);
325
326 for param in params {
328 query = bind_execute_param(query, param);
329 }
330
331 if let Some(tx) = &mut self.tx.tx {
333 let result = query.execute(&mut **tx).await?;
334 let last_insert_id = result.last_insert_id();
335
336 if self.tx.enable_logging {
337 log::debug!("事务中 insert() 成功,插入 ID: {}", last_insert_id);
338 }
339
340 Ok(last_insert_id)
341 } else {
342 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
343 }
344 }
345
346 pub async fn update<T>(self, data: &T) -> Result<u64, DbError>
361 where
362 T: serde::Serialize,
363 {
364 if self.tx.enable_logging {
366 log::debug!("事务中执行 update() 操作,表: {}", self.table);
367 }
368
369 if self.conditions.is_empty() {
371 log::warn!("事务中 update() 操作缺少 WHERE 条件,禁止全表更新");
372 return Err(DbError::MissingWhereClause);
373 }
374
375 let json_data = serde_json::to_value(data)
377 .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
378
379 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
381 generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
382
383 let sql = generator.get_sql();
384 let params = generator.get_params();
385
386 if self.tx.enable_logging {
388 log::debug!("事务中执行 update() SQL: {}", sql);
389 log::debug!("参数: {:?}", params);
390 }
391
392 let mut query = sqlx::query(sql);
394
395 for param in params {
397 query = bind_execute_param(query, param);
398 }
399
400 if let Some(tx) = &mut self.tx.tx {
402 let result = query.execute(&mut **tx).await?;
403 let rows_affected = result.rows_affected();
404
405 if self.tx.enable_logging {
406 log::debug!("事务中 update() 成功,影响 {} 行", rows_affected);
407 }
408
409 Ok(rows_affected)
410 } else {
411 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
412 }
413 }
414
415 pub async fn delete(self) -> Result<u64, DbError> {
424 if self.tx.enable_logging {
426 log::debug!("事务中执行 delete() 操作,表: {}", self.table);
427 }
428
429 if self.conditions.is_empty() {
431 log::warn!("事务中 delete() 操作缺少 WHERE 条件,禁止全表删除");
432 return Err(DbError::MissingWhereClause);
433 }
434
435 let mut generator = crate::mysql::query_builder::SqlGenerator::new();
437 generator.build_delete(&self.table, &self.conditions)?;
438
439 let sql = generator.get_sql();
440 let params = generator.get_params();
441
442 if self.tx.enable_logging {
444 log::debug!("事务中执行 delete() SQL: {}", sql);
445 log::debug!("参数: {:?}", params);
446 }
447
448 let mut query = sqlx::query(sql);
450
451 for param in params {
453 query = bind_execute_param(query, param);
454 }
455
456 if let Some(tx) = &mut self.tx.tx {
458 let result = query.execute(&mut **tx).await?;
459 let rows_affected = result.rows_affected();
460
461 if self.tx.enable_logging {
462 log::debug!("事务中 delete() 成功,影响 {} 行", rows_affected);
463 }
464
465 Ok(rows_affected)
466 } else {
467 Err(DbError::TransactionError("事务已提交或回滚".to_string()))
468 }
469 }
470}
471
472fn bind_execute_param<'q>(
481 query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
482 param: &SqlValue,
483) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
484 match param {
485 SqlValue::Null => query.bind(Option::<i32>::None),
486 SqlValue::Bool(b) => query.bind(*b),
487 SqlValue::Int(i) => query.bind(*i),
488 SqlValue::Float(f) => query.bind(*f),
489 SqlValue::String(s) => query.bind(s.clone()),
490 SqlValue::Bytes(b) => query.bind(b.clone()),
491 SqlValue::Json(j) => query.bind(j.to_string()),
492 SqlValue::DateTime(dt) => query.bind(*dt),
493 SqlValue::Timestamp(ts) => query.bind(*ts),
494 }
495}
496
497fn bind_json_param_tx<'q>(
506 query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
507 param: &serde_json::Value,
508) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
509 match param {
510 serde_json::Value::String(s) => query.bind(s.clone()),
512 serde_json::Value::Number(n) => {
514 if let Some(i) = n.as_i64() {
515 query.bind(i)
516 } else if let Some(f) = n.as_f64() {
517 query.bind(f.to_string())
519 } else {
520 query.bind(Option::<String>::None)
521 }
522 }
523 serde_json::Value::Bool(b) => query.bind(*b),
525 serde_json::Value::Null => query.bind(Option::<String>::None),
527 other => query.bind(other.to_string()),
529 }
530}
531
532fn bind_json_param_as_tx<'q, T>(
541 query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
542 param: &serde_json::Value,
543) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
544where
545 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
546{
547 match param {
548 serde_json::Value::String(s) => query.bind(s.clone()),
550 serde_json::Value::Number(n) => {
552 if let Some(i) = n.as_i64() {
553 query.bind(i)
554 } else if let Some(f) = n.as_f64() {
555 query.bind(f.to_string())
557 } else {
558 query.bind(Option::<String>::None)
559 }
560 }
561 serde_json::Value::Bool(b) => query.bind(*b),
563 serde_json::Value::Null => query.bind(Option::<String>::None),
565 other => query.bind(other.to_string()),
567 }
568}