Skip to main content

yang_db/mysql/
transaction.rs

1use crate::error::DbError;
2use crate::mysql::condition::{Condition, SqlValue};
3use crate::mysql::field::FieldType;
4use sqlx::Transaction as SqlxTransaction;
5use std::collections::HashMap;
6
7/// 数据库事务
8pub struct Transaction {
9    tx: Option<SqlxTransaction<'static, sqlx::MySql>>,
10    enable_logging: bool,
11}
12
13impl Transaction {
14    /// 创建新的事务实例
15    pub(crate) fn new(tx: SqlxTransaction<'static, sqlx::MySql>, enable_logging: bool) -> Self {
16        Self {
17            tx: Some(tx),
18            enable_logging,
19        }
20    }
21
22    /// 提交事务
23    pub async fn commit(mut self) -> Result<(), DbError> {
24        if self.enable_logging {
25            log::debug!("提交事务");
26        }
27
28        if let Some(tx) = self.tx.take() {
29            tx.commit().await?;
30        }
31
32        Ok(())
33    }
34
35    /// 回滚事务
36    pub async fn rollback(mut self) -> Result<(), DbError> {
37        if self.enable_logging {
38            log::debug!("回滚事务");
39        }
40
41        if let Some(tx) = self.tx.take() {
42            tx.rollback().await?;
43        }
44
45        Ok(())
46    }
47
48    /// 执行原生 SQL
49    pub async fn execute(&mut self, sql: &str) -> Result<u64, DbError> {
50        if self.enable_logging {
51            log::debug!("事务中执行: {}", sql);
52        }
53
54        if let Some(tx) = &mut self.tx {
55            let result = sqlx::query(sql).execute(&mut **tx).await?;
56            Ok(result.rows_affected())
57        } else {
58            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
59        }
60    }
61
62    /// 执行带参数的原生 SQL(参数化查询,防止 SQL 注入)
63    ///
64    /// # 参数
65    /// - sql: SQL 语句,使用 `?` 作为参数占位符
66    /// - params: 参数列表,使用 `serde_json::Value` 类型
67    ///
68    /// # 返回
69    /// - Ok(u64): 受影响的行数
70    /// - Err(DbError): 执行失败错误
71    ///
72    /// # 示例
73    ///
74    /// ```no_run
75    /// use yang_db::Database;
76    /// use serde_json::json;
77    ///
78    /// # async fn example() -> Result<(), yang_db::DbError> {
79    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
80    /// let mut tx = db.transaction().await?;
81    /// let params = vec![json!("张三"), json!("张三@example.com")];
82    /// tx.execute_with_params("INSERT INTO users (name, email) VALUES (?, ?)", params).await?;
83    /// tx.commit().await?;
84    /// # Ok(())
85    /// # }
86    /// ```
87    pub async fn execute_with_params(
88        &mut self,
89        sql: &str,
90        params: Vec<serde_json::Value>,
91    ) -> Result<u64, DbError> {
92        if self.enable_logging {
93            log::debug!("事务中执行参数化语句: {}, 参数数量: {}", sql, params.len());
94        }
95
96        if let Some(tx) = &mut self.tx {
97            // 构建查询并逐一绑定参数
98            let mut query = sqlx::query(sql);
99            for param in &params {
100                query = bind_json_param_tx(query, param);
101            }
102            let result = query.execute(&mut **tx).await?;
103            Ok(result.rows_affected())
104        } else {
105            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
106        }
107    }
108
109    /// 执行带参数的原生 SELECT 查询(参数化查询,防止 SQL 注入)
110    ///
111    /// # 参数
112    /// - sql: SQL 查询语句,使用 `?` 作为参数占位符
113    /// - params: 参数列表,使用 `serde_json::Value` 类型
114    ///
115    /// # 返回
116    /// - Ok(Vec<T>): 查询结果列表
117    /// - Err(DbError): 查询失败错误
118    ///
119    /// # 示例
120    ///
121    /// ```no_run
122    /// use yang_db::Database;
123    /// use serde_json::json;
124    ///
125    /// # async fn example() -> Result<(), yang_db::DbError> {
126    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
127    /// let mut tx = db.transaction().await?;
128    /// let params = vec![json!(1i64)];
129    /// // let users: Vec<User> = tx.query_with_params("SELECT * FROM users WHERE id = ?", params).await?;
130    /// tx.commit().await?;
131    /// # Ok(())
132    /// # }
133    /// ```
134    pub async fn query_with_params<T>(
135        &mut self,
136        sql: &str,
137        params: Vec<serde_json::Value>,
138    ) -> Result<Vec<T>, DbError>
139    where
140        T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
141    {
142        if self.enable_logging {
143            log::debug!("事务中执行参数化查询: {}, 参数数量: {}", sql, params.len());
144        }
145
146        if let Some(tx) = &mut self.tx {
147            // 构建查询并逐一绑定参数
148            let mut query = sqlx::query_as::<_, T>(sql);
149            for param in &params {
150                query = bind_json_param_as_tx(query, param);
151            }
152            let rows = query.fetch_all(&mut **tx).await?;
153            Ok(rows)
154        } else {
155            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
156        }
157    }
158
159    /// 选择表,返回事务中的查询构建器
160    ///
161    /// # 参数
162    /// - table_name: 表名
163    ///
164    /// # 返回
165    /// - TransactionQueryBuilder: 事务查询构建器
166    ///
167    /// # 示例
168    /// ```no_run
169    /// use yang_db::Database;
170    /// use serde_json::json;
171    ///
172    /// # async fn example() -> Result<(), yang_db::DbError> {
173    /// let db = Database::connect("mysql://root:password@localhost/test").await?;
174    /// let mut tx = db.transaction().await?;
175    ///
176    /// // 在事务中插入数据
177    /// let user_data = json!({"name": "张三", "email": "zhangsan@example.com"});
178    /// let user_id = tx.table("users").insert(&user_data).await?;
179    ///
180    /// // 在事务中更新数据
181    /// let update_data = json!({"status": 1});
182    /// tx.table("users")
183    ///     .where_and("id", "=", user_id)
184    ///     .update(&update_data)
185    ///     .await?;
186    ///
187    /// // 提交事务
188    /// tx.commit().await?;
189    /// # Ok(())
190    /// # }
191    /// ```
192    pub fn table(&mut self, table_name: &str) -> TransactionQueryBuilder<'_> {
193        TransactionQueryBuilder::new(self, table_name)
194    }
195}
196
197/// 事务查询构建器
198///
199/// 用于在事务上下文中构建和执行查询
200pub struct TransactionQueryBuilder<'a> {
201    tx: &'a mut Transaction,
202    table: String,
203    conditions: Vec<Condition>,
204    field_types: HashMap<String, FieldType>,
205}
206
207impl<'a> TransactionQueryBuilder<'a> {
208    /// 创建新的事务查询构建器
209    fn new(tx: &'a mut Transaction, table_name: &str) -> Self {
210        Self {
211            tx,
212            table: table_name.to_string(),
213            conditions: Vec::new(),
214            field_types: HashMap::new(),
215        }
216    }
217
218    /// 标记字段为 JSON 类型
219    pub fn json(mut self, field: &str) -> Self {
220        self.field_types.insert(field.to_string(), FieldType::Json);
221        self
222    }
223
224    /// 标记字段为 DATETIME 类型
225    pub fn datetime(mut self, field: &str) -> Self {
226        self.field_types
227            .insert(field.to_string(), FieldType::DateTime);
228        self
229    }
230
231    /// 标记字段为 TIMESTAMP 类型
232    pub fn timestamp(mut self, field: &str) -> Self {
233        self.field_types
234            .insert(field.to_string(), FieldType::Timestamp);
235        self
236    }
237
238    /// 标记字段为 DECIMAL 类型
239    pub fn decimal(mut self, field: &str) -> Self {
240        self.field_types
241            .insert(field.to_string(), FieldType::Decimal);
242        self
243    }
244
245    /// 标记字段为 BLOB 类型
246    pub fn blob(mut self, field: &str) -> Self {
247        self.field_types.insert(field.to_string(), FieldType::Blob);
248        self
249    }
250
251    /// 标记字段为 TEXT 类型
252    pub fn text(mut self, field: &str) -> Self {
253        self.field_types.insert(field.to_string(), FieldType::Text);
254        self
255    }
256
257    /// 添加 AND 条件
258    pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
259    where
260        V: Into<SqlValue>,
261    {
262        let sql_value = value.into();
263        let condition = match op {
264            "=" => Condition::Eq(field.to_string(), sql_value),
265            "!=" => Condition::Ne(field.to_string(), sql_value),
266            ">" => Condition::Gt(field.to_string(), sql_value),
267            "<" => Condition::Lt(field.to_string(), sql_value),
268            ">=" => Condition::Gte(field.to_string(), sql_value),
269            "<=" => Condition::Lte(field.to_string(), sql_value),
270            "like" | "LIKE" => {
271                if let SqlValue::String(s) = sql_value {
272                    Condition::Like(field.to_string(), s)
273                } else {
274                    Condition::Like(field.to_string(), format!("{:?}", sql_value))
275                }
276            }
277            _ => panic!("不支持的操作符: {}", op),
278        };
279
280        self.conditions.push(condition);
281        self
282    }
283
284    /// 插入数据
285    ///
286    /// 在事务中执行 INSERT 操作
287    ///
288    /// # 类型参数
289    /// - T: 数据类型,必须实现 Serialize trait
290    ///
291    /// # 参数
292    /// - data: 要插入的数据
293    ///
294    /// # 返回
295    /// - Ok(u64): 插入成功,返回插入记录的 ID(自增主键)
296    /// - Err(DbError): 插入失败
297    pub async fn insert<T>(self, data: &T) -> Result<u64, DbError>
298    where
299        T: serde::Serialize,
300    {
301        // 记录日志
302        if self.tx.enable_logging {
303            log::debug!("事务中执行 insert() 操作,表: {}", self.table);
304        }
305
306        // 将数据序列化为 JSON
307        let json_data = serde_json::to_value(data)
308            .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
309
310        // 生成 INSERT 语句
311        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
312        generator.build_insert(&self.table, &json_data, &self.field_types)?;
313
314        let sql = generator.get_sql();
315        let params = generator.get_params();
316
317        // 记录日志
318        if self.tx.enable_logging {
319            log::debug!("事务中执行 insert() SQL: {}", sql);
320            log::debug!("参数: {:?}", params);
321        }
322
323        // 构建查询
324        let mut query = sqlx::query(sql);
325
326        // 绑定参数
327        for param in params {
328            query = bind_execute_param(query, param);
329        }
330
331        // 执行插入
332        if let Some(tx) = &mut self.tx.tx {
333            let result = query.execute(&mut **tx).await?;
334            let last_insert_id = result.last_insert_id();
335
336            if self.tx.enable_logging {
337                log::debug!("事务中 insert() 成功,插入 ID: {}", last_insert_id);
338            }
339
340            Ok(last_insert_id)
341        } else {
342            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
343        }
344    }
345
346    /// 更新数据
347    ///
348    /// 在事务中执行 UPDATE 操作
349    /// 为了防止误操作,必须提供 WHERE 条件,否则会返回错误
350    ///
351    /// # 类型参数
352    /// - T: 数据类型,必须实现 Serialize trait
353    ///
354    /// # 参数
355    /// - data: 要更新的数据
356    ///
357    /// # 返回
358    /// - Ok(u64): 更新成功,返回受影响的行数
359    /// - Err(DbError): 更新失败
360    pub async fn update<T>(self, data: &T) -> Result<u64, DbError>
361    where
362        T: serde::Serialize,
363    {
364        // 记录日志
365        if self.tx.enable_logging {
366            log::debug!("事务中执行 update() 操作,表: {}", self.table);
367        }
368
369        // 检查是否有 WHERE 条件
370        if self.conditions.is_empty() {
371            log::warn!("事务中 update() 操作缺少 WHERE 条件,禁止全表更新");
372            return Err(DbError::MissingWhereClause);
373        }
374
375        // 将数据序列化为 JSON
376        let json_data = serde_json::to_value(data)
377            .map_err(|e| DbError::SerializationError(format!("数据序列化失败: {}", e)))?;
378
379        // 生成 UPDATE 语句
380        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
381        generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
382
383        let sql = generator.get_sql();
384        let params = generator.get_params();
385
386        // 记录日志
387        if self.tx.enable_logging {
388            log::debug!("事务中执行 update() SQL: {}", sql);
389            log::debug!("参数: {:?}", params);
390        }
391
392        // 构建查询
393        let mut query = sqlx::query(sql);
394
395        // 绑定参数
396        for param in params {
397            query = bind_execute_param(query, param);
398        }
399
400        // 执行更新
401        if let Some(tx) = &mut self.tx.tx {
402            let result = query.execute(&mut **tx).await?;
403            let rows_affected = result.rows_affected();
404
405            if self.tx.enable_logging {
406                log::debug!("事务中 update() 成功,影响 {} 行", rows_affected);
407            }
408
409            Ok(rows_affected)
410        } else {
411            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
412        }
413    }
414
415    /// 删除数据
416    ///
417    /// 在事务中执行 DELETE 操作
418    /// 为了防止误操作,必须提供 WHERE 条件,否则会返回错误
419    ///
420    /// # 返回
421    /// - Ok(u64): 删除成功,返回受影响的行数
422    /// - Err(DbError): 删除失败
423    pub async fn delete(self) -> Result<u64, DbError> {
424        // 记录日志
425        if self.tx.enable_logging {
426            log::debug!("事务中执行 delete() 操作,表: {}", self.table);
427        }
428
429        // 检查是否有 WHERE 条件
430        if self.conditions.is_empty() {
431            log::warn!("事务中 delete() 操作缺少 WHERE 条件,禁止全表删除");
432            return Err(DbError::MissingWhereClause);
433        }
434
435        // 生成 DELETE 语句
436        let mut generator = crate::mysql::query_builder::SqlGenerator::new();
437        generator.build_delete(&self.table, &self.conditions)?;
438
439        let sql = generator.get_sql();
440        let params = generator.get_params();
441
442        // 记录日志
443        if self.tx.enable_logging {
444            log::debug!("事务中执行 delete() SQL: {}", sql);
445            log::debug!("参数: {:?}", params);
446        }
447
448        // 构建查询
449        let mut query = sqlx::query(sql);
450
451        // 绑定参数
452        for param in params {
453            query = bind_execute_param(query, param);
454        }
455
456        // 执行删除
457        if let Some(tx) = &mut self.tx.tx {
458            let result = query.execute(&mut **tx).await?;
459            let rows_affected = result.rows_affected();
460
461            if self.tx.enable_logging {
462                log::debug!("事务中 delete() 成功,影响 {} 行", rows_affected);
463            }
464
465            Ok(rows_affected)
466        } else {
467            Err(DbError::TransactionError("事务已提交或回滚".to_string()))
468        }
469    }
470}
471
472/// 绑定参数到执行查询(用于事务中的 INSERT/UPDATE/DELETE)
473///
474/// # 参数
475/// - query: sqlx 查询对象
476/// - param: SQL 参数值
477///
478/// # 返回
479/// - 绑定参数后的查询对象
480fn bind_execute_param<'q>(
481    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
482    param: &SqlValue,
483) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
484    match param {
485        SqlValue::Null => query.bind(Option::<i32>::None),
486        SqlValue::Bool(b) => query.bind(*b),
487        SqlValue::Int(i) => query.bind(*i),
488        SqlValue::Float(f) => query.bind(*f),
489        SqlValue::String(s) => query.bind(s.clone()),
490        SqlValue::Bytes(b) => query.bind(b.clone()),
491        SqlValue::Json(j) => query.bind(j.to_string()),
492        SqlValue::DateTime(dt) => query.bind(*dt),
493        SqlValue::Timestamp(ts) => query.bind(*ts),
494    }
495}
496
497/// 将 `serde_json::Value` 参数绑定到事务执行查询(用于参数化 INSERT/UPDATE/DELETE)
498///
499/// # 参数
500/// - query: sqlx 执行查询对象
501/// - param: JSON 参数值
502///
503/// # 返回
504/// - 绑定参数后的查询对象
505fn bind_json_param_tx<'q>(
506    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
507    param: &serde_json::Value,
508) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
509    match param {
510        // 字符串类型直接绑定
511        serde_json::Value::String(s) => query.bind(s.clone()),
512        // 数字类型转为 i64 绑定
513        serde_json::Value::Number(n) => {
514            if let Some(i) = n.as_i64() {
515                query.bind(i)
516            } else if let Some(f) = n.as_f64() {
517                // 浮点数转为字符串绑定,避免精度丢失
518                query.bind(f.to_string())
519            } else {
520                query.bind(Option::<String>::None)
521            }
522        }
523        // 布尔类型绑定
524        serde_json::Value::Bool(b) => query.bind(*b),
525        // NULL 类型绑定为 None
526        serde_json::Value::Null => query.bind(Option::<String>::None),
527        // 数组和对象类型序列化为 JSON 字符串绑定
528        other => query.bind(other.to_string()),
529    }
530}
531
532/// 将 `serde_json::Value` 参数绑定到事务 `query_as` 查询(用于参数化 SELECT)
533///
534/// # 参数
535/// - query: sqlx query_as 查询对象
536/// - param: JSON 参数值
537///
538/// # 返回
539/// - 绑定参数后的查询对象
540fn bind_json_param_as_tx<'q, T>(
541    query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
542    param: &serde_json::Value,
543) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
544where
545    T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow>,
546{
547    match param {
548        // 字符串类型直接绑定
549        serde_json::Value::String(s) => query.bind(s.clone()),
550        // 数字类型转为 i64 绑定
551        serde_json::Value::Number(n) => {
552            if let Some(i) = n.as_i64() {
553                query.bind(i)
554            } else if let Some(f) = n.as_f64() {
555                // 浮点数转为字符串绑定,避免精度丢失
556                query.bind(f.to_string())
557            } else {
558                query.bind(Option::<String>::None)
559            }
560        }
561        // 布尔类型绑定
562        serde_json::Value::Bool(b) => query.bind(*b),
563        // NULL 类型绑定为 None
564        serde_json::Value::Null => query.bind(Option::<String>::None),
565        // 数组和对象类型序列化为 JSON 字符串绑定
566        other => query.bind(other.to_string()),
567    }
568}