1use crate::mysql::condition::{Condition, SqlValue};
2use crate::mysql::field::{FieldType, JoinClause, OrderClause};
3use sqlx::mysql::MySqlPool;
4use std::collections::HashMap;
5
6#[allow(dead_code)]
8pub(crate) struct SqlGenerator {
9 sql: String,
11 params: Vec<SqlValue>,
13}
14
15#[allow(dead_code)]
16impl SqlGenerator {
17 pub(crate) fn new() -> Self {
19 Self {
20 sql: String::new(),
21 params: Vec::new(),
22 }
23 }
24
25 pub(crate) fn get_sql(&self) -> &str {
27 &self.sql
28 }
29
30 pub(crate) fn get_params(&self) -> &[SqlValue] {
32 &self.params
33 }
34
35 fn append(&mut self, fragment: &str) {
37 self.sql.push_str(fragment);
38 }
39
40 fn add_param(&mut self, param: SqlValue) {
42 self.params.push(param);
43 }
44
45 fn clear(&mut self) {
47 self.sql.clear();
48 self.params.clear();
49 }
50
51 fn build_select(&mut self, builder: &QueryBuilder) -> Result<(), crate::error::DbError> {
60 self.clear();
62
63 self.append("SELECT ");
65
66 if builder.distinct {
68 self.append("DISTINCT ");
69 }
70
71 if builder.fields.is_empty() {
73 self.append("*");
74 } else {
75 self.append(&builder.fields.join(", "));
76 }
77
78 self.append(" FROM ");
80 self.append(&builder.table);
81
82 if !builder.joins.is_empty() {
84 self.build_joins(&builder.joins);
85 }
86
87 if !builder.conditions.is_empty() {
89 self.build_where(&builder.conditions)?;
90 }
91
92 if !builder.group_by.is_empty() {
94 self.build_group_by(&builder.group_by);
95 }
96
97 if !builder.order_by.is_empty() {
99 self.build_order_by(&builder.order_by);
100 }
101
102 if let Some(limit) = builder.limit {
104 self.append(&format!(" LIMIT {}", limit));
105 }
106
107 if let Some(offset) = builder.offset {
109 self.append(&format!(" OFFSET {}", offset));
110 }
111
112 Ok(())
113 }
114
115 fn build_where(&mut self, conditions: &[Condition]) -> Result<(), crate::error::DbError> {
124 if conditions.is_empty() {
125 return Ok(());
126 }
127
128 self.append(" WHERE ");
129
130 if conditions.len() == 1 {
132 let sql = crate::mysql::condition::condition_to_sql(&conditions[0], &mut self.params);
133 self.append(&sql);
134 } else {
135 let combined = Condition::And(conditions.to_vec());
137 let sql = crate::mysql::condition::condition_to_sql(&combined, &mut self.params);
138 self.append(&sql);
139 }
140
141 Ok(())
142 }
143
144 fn build_joins(&mut self, joins: &[JoinClause]) {
149 use crate::mysql::field::JoinType;
150
151 for join in joins {
152 let join_type_str = match join.join_type {
153 JoinType::Inner => " INNER JOIN ",
154 JoinType::Left => " LEFT JOIN ",
155 JoinType::Right => " RIGHT JOIN ",
156 };
157
158 self.append(join_type_str);
159 self.append(&join.table);
160 self.append(" ON ");
161 self.append(&join.on);
162 }
163 }
164
165 fn build_order_by(&mut self, orders: &[OrderClause]) {
170 if orders.is_empty() {
171 return;
172 }
173
174 self.append(" ORDER BY ");
175
176 let order_parts: Vec<String> = orders
177 .iter()
178 .map(|order| {
179 let direction = if order.asc { "ASC" } else { "DESC" };
180 format!("{} {}", order.field, direction)
181 })
182 .collect();
183
184 self.append(&order_parts.join(", "));
185 }
186
187 fn build_group_by(&mut self, groups: &[String]) {
192 if groups.is_empty() {
193 return;
194 }
195
196 self.append(" GROUP BY ");
197 self.append(&groups.join(", "));
198 }
199
200 pub(crate) fn build_insert(
211 &mut self,
212 table: &str,
213 data: &serde_json::Value,
214 field_types: &HashMap<String, FieldType>,
215 ) -> Result<(), crate::error::DbError> {
216 self.clear();
218
219 let obj = data.as_object().ok_or_else(|| {
221 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
222 })?;
223
224 if obj.is_empty() {
225 return Err(crate::error::DbError::SerializationError(
226 "插入数据不能为空".to_string(),
227 ));
228 }
229
230 let mut fields = Vec::new();
232 let mut placeholders = Vec::new();
233
234 for (key, value) in obj.iter() {
235 fields.push(key.clone());
236 placeholders.push("?".to_string());
237
238 let sql_value = self.json_value_to_sql_value(value, field_types.get(key))?;
240 self.add_param(sql_value);
241 }
242
243 self.append("INSERT INTO ");
245 self.append(table);
246 self.append(" (");
247 self.append(&fields.join(", "));
248 self.append(") VALUES (");
249 self.append(&placeholders.join(", "));
250 self.append(")");
251
252 Ok(())
253 }
254
255 pub(crate) fn build_insert_batch(
266 &mut self,
267 table: &str,
268 data_list: &[serde_json::Value],
269 field_types: &HashMap<String, FieldType>,
270 ) -> Result<(), crate::error::DbError> {
271 self.clear();
273
274 if data_list.is_empty() {
275 return Err(crate::error::DbError::SerializationError(
276 "批量插入数据不能为空".to_string(),
277 ));
278 }
279
280 let first_obj = data_list[0].as_object().ok_or_else(|| {
282 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
283 })?;
284
285 if first_obj.is_empty() {
286 return Err(crate::error::DbError::SerializationError(
287 "插入数据不能为空".to_string(),
288 ));
289 }
290
291 let fields: Vec<String> = first_obj.keys().cloned().collect();
293
294 self.append("INSERT INTO ");
296 self.append(table);
297 self.append(" (");
298 self.append(&fields.join(", "));
299 self.append(") VALUES ");
300
301 let mut value_clauses = Vec::new();
303
304 for data in data_list {
305 let obj = data.as_object().ok_or_else(|| {
306 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
307 })?;
308
309 let mut placeholders = Vec::new();
311
312 for field in &fields {
313 placeholders.push("?".to_string());
314
315 let value = obj.get(field).unwrap_or(&serde_json::Value::Null);
317
318 let sql_value = self.json_value_to_sql_value(value, field_types.get(field))?;
320 self.add_param(sql_value);
321 }
322
323 value_clauses.push(format!("({})", placeholders.join(", ")));
324 }
325
326 self.append(&value_clauses.join(", "));
328
329 Ok(())
330 }
331
332 pub(crate) fn build_update(
344 &mut self,
345 table: &str,
346 data: &serde_json::Value,
347 field_types: &HashMap<String, FieldType>,
348 conditions: &[Condition],
349 ) -> Result<(), crate::error::DbError> {
350 self.clear();
352
353 if conditions.is_empty() {
355 return Err(crate::error::DbError::MissingWhereClause);
356 }
357
358 let obj = data.as_object().ok_or_else(|| {
360 crate::error::DbError::SerializationError("更新数据必须是 JSON 对象".to_string())
361 })?;
362
363 if obj.is_empty() {
364 return Err(crate::error::DbError::SerializationError(
365 "更新数据不能为空".to_string(),
366 ));
367 }
368
369 self.append("UPDATE ");
371 self.append(table);
372 self.append(" SET ");
373
374 let mut set_clauses = Vec::new();
376
377 for (key, value) in obj.iter() {
378 set_clauses.push(format!("{} = ?", key));
379
380 let sql_value = self.json_value_to_sql_value(value, field_types.get(key))?;
382 self.add_param(sql_value);
383 }
384
385 self.append(&set_clauses.join(", "));
386
387 self.build_where(conditions)?;
389
390 Ok(())
391 }
392
393 pub(crate) fn build_delete(
403 &mut self,
404 table: &str,
405 conditions: &[Condition],
406 ) -> Result<(), crate::error::DbError> {
407 self.clear();
409
410 if conditions.is_empty() {
412 return Err(crate::error::DbError::MissingWhereClause);
413 }
414
415 self.append("DELETE FROM ");
417 self.append(table);
418
419 self.build_where(conditions)?;
421
422 Ok(())
423 }
424
425 fn json_value_to_sql_value(
435 &self,
436 value: &serde_json::Value,
437 field_type: Option<&FieldType>,
438 ) -> Result<SqlValue, crate::error::DbError> {
439 use serde_json::Value;
440
441 if let Some(ft) = field_type {
443 match ft {
444 FieldType::Json => {
445 return Ok(SqlValue::Json(value.clone()));
447 }
448 FieldType::DateTime => {
449 if let Some(s) = value.as_str() {
451 let dt = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S")
452 .map_err(|e| {
453 crate::error::DbError::TypeConversionError(format!(
454 "无法解析 DATETIME 字符串: {}",
455 e
456 ))
457 })?;
458 return Ok(SqlValue::DateTime(dt));
459 }
460 }
461 FieldType::Timestamp => {
462 if let Some(i) = value.as_i64() {
464 return Ok(SqlValue::Timestamp(i));
465 }
466 }
467 FieldType::Decimal => {
468 if let Some(f) = value.as_f64() {
470 return Ok(SqlValue::Float(f));
471 } else if let Some(i) = value.as_i64() {
472 return Ok(SqlValue::Float(i as f64));
473 }
474 }
475 FieldType::Blob => {
476 if let Some(s) = value.as_str() {
478 use base64::Engine;
480 if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(s) {
481 return Ok(SqlValue::Bytes(bytes));
482 }
483 return Ok(SqlValue::Bytes(s.as_bytes().to_vec()));
485 }
486 }
487 FieldType::Text => {
488 if let Some(s) = value.as_str() {
490 return Ok(SqlValue::String(s.to_string()));
491 }
492 }
493 FieldType::Standard => {
494 }
496 }
497 }
498
499 match value {
501 Value::Null => Ok(SqlValue::Null),
502 Value::Bool(b) => Ok(SqlValue::Bool(*b)),
503 Value::Number(n) => {
504 if let Some(i) = n.as_i64() {
505 Ok(SqlValue::Int(i))
506 } else if let Some(f) = n.as_f64() {
507 Ok(SqlValue::Float(f))
508 } else {
509 Err(crate::error::DbError::TypeConversionError(
510 "无法转换数字类型".to_string(),
511 ))
512 }
513 }
514 Value::String(s) => Ok(SqlValue::String(s.clone())),
515 Value::Array(_) | Value::Object(_) => {
516 Ok(SqlValue::Json(value.clone()))
518 }
519 }
520 }
521}
522
523pub struct QueryBuilder<'a> {
525 #[allow(dead_code)]
526 pool: &'a MySqlPool,
527 table: String,
528 fields: Vec<String>,
529 #[allow(dead_code)]
530 conditions: Vec<Condition>,
531 #[allow(dead_code)]
532 joins: Vec<JoinClause>,
533 #[allow(dead_code)]
534 order_by: Vec<OrderClause>,
535 #[allow(dead_code)]
536 group_by: Vec<String>,
537 limit: Option<u64>,
538 offset: Option<u64>,
539 distinct: bool,
540 field_types: HashMap<String, FieldType>,
541 #[allow(dead_code)]
542 enable_logging: bool,
543}
544
545impl<'a> QueryBuilder<'a> {
546 pub(crate) fn new(pool: &'a MySqlPool, table_name: &str, enable_logging: bool) -> Self {
548 Self {
549 pool,
550 table: table_name.to_string(),
551 fields: Vec::new(),
552 conditions: Vec::new(),
553 joins: Vec::new(),
554 order_by: Vec::new(),
555 group_by: Vec::new(),
556 limit: None,
557 offset: None,
558 distinct: false,
559 field_types: HashMap::new(),
560 enable_logging,
561 }
562 }
563
564 pub fn field(mut self, field: &str) -> Self {
566 self.fields.push(field.to_string());
567 self
568 }
569
570 pub fn fields(mut self, fields: &[&str]) -> Self {
572 for field in fields {
573 self.fields.push(field.to_string());
574 }
575 self
576 }
577
578 pub fn json(mut self, field: &str) -> Self {
580 self.field_types.insert(field.to_string(), FieldType::Json);
581 self
582 }
583
584 pub fn datetime(mut self, field: &str) -> Self {
586 self.field_types
587 .insert(field.to_string(), FieldType::DateTime);
588 self
589 }
590
591 pub fn timestamp(mut self, field: &str) -> Self {
593 self.field_types
594 .insert(field.to_string(), FieldType::Timestamp);
595 self
596 }
597
598 pub fn decimal(mut self, field: &str) -> Self {
600 self.field_types
601 .insert(field.to_string(), FieldType::Decimal);
602 self
603 }
604
605 pub fn blob(mut self, field: &str) -> Self {
607 self.field_types.insert(field.to_string(), FieldType::Blob);
608 self
609 }
610
611 pub fn text(mut self, field: &str) -> Self {
613 self.field_types.insert(field.to_string(), FieldType::Text);
614 self
615 }
616
617 pub fn distinct(mut self) -> Self {
619 self.distinct = true;
620 self
621 }
622
623 pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Self
625 where
626 V: Into<crate::mysql::condition::SqlValue>,
627 {
628 use crate::mysql::condition::{Condition, SqlValue};
629
630 let sql_value = value.into();
631 let condition = match op {
632 "=" => Condition::Eq(field.to_string(), sql_value),
633 "!=" => Condition::Ne(field.to_string(), sql_value),
634 ">" => Condition::Gt(field.to_string(), sql_value),
635 "<" => Condition::Lt(field.to_string(), sql_value),
636 ">=" => Condition::Gte(field.to_string(), sql_value),
637 "<=" => Condition::Lte(field.to_string(), sql_value),
638 "like" | "LIKE" => {
639 if let SqlValue::String(s) = sql_value {
640 Condition::Like(field.to_string(), s)
641 } else {
642 Condition::Like(field.to_string(), format!("{:?}", sql_value))
644 }
645 }
646 _ => panic!("不支持的操作符: {}", op),
647 };
648
649 self.conditions.push(condition);
650 self
651 }
652
653 pub fn where_or<V>(mut self, field: &str, op: &str, value: V) -> Self
655 where
656 V: Into<crate::mysql::condition::SqlValue>,
657 {
658 use crate::mysql::condition::{Condition, SqlValue};
659
660 let sql_value = value.into();
661 let condition = match op {
662 "=" => Condition::Eq(field.to_string(), sql_value),
663 "!=" => Condition::Ne(field.to_string(), sql_value),
664 ">" => Condition::Gt(field.to_string(), sql_value),
665 "<" => Condition::Lt(field.to_string(), sql_value),
666 ">=" => Condition::Gte(field.to_string(), sql_value),
667 "<=" => Condition::Lte(field.to_string(), sql_value),
668 "like" | "LIKE" => {
669 if let SqlValue::String(s) = sql_value {
670 Condition::Like(field.to_string(), s)
671 } else {
672 Condition::Like(field.to_string(), format!("{:?}", sql_value))
673 }
674 }
675 _ => panic!("不支持的操作符: {}", op),
676 };
677
678 if !self.conditions.is_empty() {
680 let existing = std::mem::take(&mut self.conditions);
681 self.conditions.push(Condition::Or(vec![
682 if existing.len() == 1 {
683 existing.into_iter().next().unwrap()
684 } else {
685 Condition::And(existing)
686 },
687 condition,
688 ]));
689 } else {
690 self.conditions.push(condition);
691 }
692
693 self
694 }
695
696 pub fn where_in<V>(mut self, field: &str, values: Vec<V>) -> Self
698 where
699 V: Into<crate::mysql::condition::SqlValue>,
700 {
701 use crate::mysql::condition::Condition;
702
703 let sql_values: Vec<_> = values.into_iter().map(|v| v.into()).collect();
704 self.conditions
705 .push(Condition::In(field.to_string(), sql_values));
706 self
707 }
708
709 pub fn where_between<V>(mut self, field: &str, start: V, end: V) -> Self
711 where
712 V: Into<crate::mysql::condition::SqlValue>,
713 {
714 use crate::mysql::condition::Condition;
715
716 self.conditions.push(Condition::Between(
717 field.to_string(),
718 start.into(),
719 end.into(),
720 ));
721 self
722 }
723
724 pub fn join(mut self, table: &str, on: &str) -> Self {
726 use crate::mysql::field::{JoinClause, JoinType};
727
728 self.joins.push(JoinClause {
729 join_type: JoinType::Inner,
730 table: table.to_string(),
731 on: on.to_string(),
732 });
733 self
734 }
735
736 pub fn left_join(mut self, table: &str, on: &str) -> Self {
738 use crate::mysql::field::{JoinClause, JoinType};
739
740 self.joins.push(JoinClause {
741 join_type: JoinType::Left,
742 table: table.to_string(),
743 on: on.to_string(),
744 });
745 self
746 }
747
748 pub fn right_join(mut self, table: &str, on: &str) -> Self {
750 use crate::mysql::field::{JoinClause, JoinType};
751
752 self.joins.push(JoinClause {
753 join_type: JoinType::Right,
754 table: table.to_string(),
755 on: on.to_string(),
756 });
757 self
758 }
759
760 pub fn order(mut self, field: &str, asc: bool) -> Self {
762 use crate::mysql::field::OrderClause;
763
764 self.order_by.push(OrderClause {
765 field: field.to_string(),
766 asc,
767 });
768 self
769 }
770
771 pub fn group(mut self, field: &str) -> Self {
773 self.group_by.push(field.to_string());
774 self
775 }
776
777 pub fn limit(mut self, limit: u64) -> Self {
779 self.limit = Some(limit);
780 self
781 }
782
783 pub fn offset(mut self, offset: u64) -> Self {
785 self.offset = Some(offset);
786 self
787 }
788
789 pub fn to_sql(&self) -> String {
794 let mut generator = SqlGenerator::new();
795
796 match generator.build_select(self) {
798 Ok(_) => generator.get_sql().to_string(),
799 Err(_) => {
800 let fields_str = if self.fields.is_empty() {
802 "*".to_string()
803 } else {
804 self.fields.join(", ")
805 };
806
807 let distinct_str = if self.distinct { "DISTINCT " } else { "" };
808
809 format!("SELECT {}{} FROM {}", distinct_str, fields_str, self.table)
810 }
811 }
812 }
813
814 pub async fn find<T>(mut self) -> Result<Option<T>, crate::error::DbError>
852 where
853 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
854 {
855 self.limit = Some(1);
857
858 let mut generator = SqlGenerator::new();
860 generator.build_select(&self)?;
861
862 let sql = generator.get_sql();
863 let params = generator.get_params();
864
865 if self.enable_logging {
867 log::debug!("执行 find() 查询: {}", sql);
868 log::debug!("参数: {:?}", params);
869 }
870
871 let mut query = sqlx::query_as::<_, T>(sql);
873
874 for param in params {
876 query = bind_param(query, param);
877 }
878
879 let result = query.fetch_optional(self.pool).await;
881
882 match result {
883 Ok(row) => {
884 if self.enable_logging {
885 if row.is_some() {
886 log::debug!("find() 查询成功,返回 1 条记录");
887 } else {
888 log::debug!("find() 查询成功,未找到匹配记录");
889 }
890 }
891 Ok(row)
892 }
893 Err(e) => {
894 log::error!("find() 查询失败: {}", e);
895 Err(crate::error::DbError::from(e))
896 }
897 }
898 }
899
900 pub async fn select<T>(self) -> Result<Vec<T>, crate::error::DbError>
938 where
939 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
940 {
941 let mut generator = SqlGenerator::new();
943 generator.build_select(&self)?;
944
945 let sql = generator.get_sql();
946 let params = generator.get_params();
947
948 if self.enable_logging {
950 log::debug!("执行 select() 查询: {}", sql);
951 log::debug!("参数: {:?}", params);
952 }
953
954 let mut query = sqlx::query_as::<_, T>(sql);
956
957 for param in params {
959 query = bind_param(query, param);
960 }
961
962 let result = query.fetch_all(self.pool).await;
964
965 match result {
966 Ok(rows) => {
967 if self.enable_logging {
968 log::debug!("select() 查询成功,返回 {} 条记录", rows.len());
969 }
970 Ok(rows)
971 }
972 Err(e) => {
973 log::error!("select() 查询失败: {}", e);
974 Err(crate::error::DbError::from(e))
975 }
976 }
977 }
978
979 pub async fn value<T>(mut self, field: &str) -> Result<Option<T>, crate::error::DbError>
1023 where
1024 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1025 {
1026 self.fields.clear();
1028 self.fields.push(field.to_string());
1029
1030 self.limit = Some(1);
1032
1033 let mut generator = SqlGenerator::new();
1035 generator.build_select(&self)?;
1036
1037 let sql = generator.get_sql();
1038 let params = generator.get_params();
1039
1040 if self.enable_logging {
1042 log::debug!("执行 value() 查询: {}", sql);
1043 log::debug!("参数: {:?}", params);
1044 }
1045
1046 let mut query = sqlx::query_scalar::<_, T>(sql);
1048
1049 for param in params {
1051 query = bind_scalar_param(query, param);
1052 }
1053
1054 let result = query.fetch_optional(self.pool).await;
1056
1057 match result {
1058 Ok(value) => {
1059 if self.enable_logging {
1060 if value.is_some() {
1061 log::debug!("value() 查询成功,返回字段值");
1062 } else {
1063 log::debug!("value() 查询成功,未找到匹配记录");
1064 }
1065 }
1066 Ok(value)
1067 }
1068 Err(e) => {
1069 log::error!("value() 查询失败: {}", e);
1070 Err(crate::error::DbError::from(e))
1071 }
1072 }
1073 }
1074
1075 pub async fn count(self) -> Result<i64, crate::error::DbError> {
1106 if self.enable_logging {
1108 log::debug!("执行 count() 查询");
1109 }
1110
1111 let result = self.value::<i64>("COUNT(*)").await?;
1113
1114 Ok(result.unwrap_or(0))
1116 }
1117
1118 pub async fn sum(self, field: &str) -> Result<Option<f64>, crate::error::DbError> {
1162 if self.enable_logging {
1164 log::debug!("执行 sum() 查询,字段: {}", field);
1165 }
1166
1167 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", field);
1170
1171 let mut builder = self;
1173 builder.fields.clear();
1174 builder.fields.push(sum_expr.clone());
1175
1176 builder.limit = Some(1);
1178
1179 let mut generator = SqlGenerator::new();
1181 generator.build_select(&builder)?;
1182
1183 let sql = generator.get_sql();
1184 let params = generator.get_params();
1185
1186 if builder.enable_logging {
1188 log::debug!("执行 sum() 查询: {}", sql);
1189 log::debug!("参数: {:?}", params);
1190 }
1191
1192 let mut query = sqlx::query_scalar::<_, Option<f64>>(sql);
1194
1195 for param in params {
1197 query = bind_scalar_param_option(query, param);
1198 }
1199
1200 let result = query.fetch_optional(builder.pool).await;
1202
1203 match result {
1204 Ok(Some(value)) => {
1205 if builder.enable_logging {
1207 if value.is_some() {
1208 log::debug!("sum() 查询成功,返回总和");
1209 } else {
1210 log::debug!("sum() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
1211 }
1212 }
1213 Ok(value)
1214 }
1215 Ok(None) => {
1216 if builder.enable_logging {
1218 log::debug!("sum() 查询成功,未找到匹配记录");
1219 }
1220 Ok(None)
1221 }
1222 Err(e) => {
1223 log::error!("sum() 查询失败: {}", e);
1224 Err(crate::error::DbError::from(e))
1225 }
1226 }
1227 }
1228
1229 pub async fn insert<T>(self, data: &T) -> Result<u64, crate::error::DbError>
1282 where
1283 T: serde::Serialize,
1284 {
1285 if self.enable_logging {
1287 log::debug!("执行 insert() 操作,表: {}", self.table);
1288 }
1289
1290 let json_data = serde_json::to_value(data).map_err(|e| {
1292 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
1293 })?;
1294
1295 let mut generator = SqlGenerator::new();
1297 generator.build_insert(&self.table, &json_data, &self.field_types)?;
1298
1299 let sql = generator.get_sql();
1300 let params = generator.get_params();
1301
1302 if self.enable_logging {
1304 log::debug!("执行 insert() SQL: {}", sql);
1305 log::debug!("参数: {:?}", params);
1306 }
1307
1308 let mut query = sqlx::query(sql);
1310
1311 for param in params {
1313 query = bind_execute_param(query, param);
1314 }
1315
1316 let result = query.execute(self.pool).await;
1318
1319 match result {
1320 Ok(query_result) => {
1321 let last_insert_id = query_result.last_insert_id();
1322 if self.enable_logging {
1323 log::debug!("insert() 成功,插入 ID: {}", last_insert_id);
1324 }
1325 Ok(last_insert_id)
1326 }
1327 Err(e) => {
1328 log::error!("insert() 失败: {}", e);
1329 Err(crate::error::DbError::from(e))
1330 }
1331 }
1332 }
1333
1334 pub async fn insert_batch<T>(self, data: &[T]) -> Result<u64, crate::error::DbError>
1401 where
1402 T: serde::Serialize,
1403 {
1404 if self.enable_logging {
1406 log::debug!(
1407 "执行 insert_batch() 操作,表: {},记录数: {}",
1408 self.table,
1409 data.len()
1410 );
1411 }
1412
1413 if data.is_empty() {
1415 return Err(crate::error::DbError::SerializationError(
1416 "批量插入数据不能为空".to_string(),
1417 ));
1418 }
1419
1420 let json_data_list: Result<Vec<_>, _> = data
1422 .iter()
1423 .map(|item| {
1424 serde_json::to_value(item).map_err(|e| {
1425 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
1426 })
1427 })
1428 .collect();
1429
1430 let json_data_list = json_data_list?;
1431
1432 let mut generator = SqlGenerator::new();
1434 generator.build_insert_batch(&self.table, &json_data_list, &self.field_types)?;
1435
1436 let sql = generator.get_sql();
1437 let params = generator.get_params();
1438
1439 if self.enable_logging {
1441 log::debug!("执行 insert_batch() SQL: {}", sql);
1442 log::debug!("参数数量: {}", params.len());
1443 }
1444
1445 let mut query = sqlx::query(sql);
1447
1448 for param in params {
1450 query = bind_execute_param(query, param);
1451 }
1452
1453 let result = query.execute(self.pool).await;
1455
1456 match result {
1457 Ok(query_result) => {
1458 let rows_affected = query_result.rows_affected();
1459 if self.enable_logging {
1460 log::debug!("insert_batch() 成功,影响 {} 行", rows_affected);
1461 }
1462 Ok(rows_affected)
1463 }
1464 Err(e) => {
1465 log::error!("insert_batch() 失败: {}", e);
1466 Err(crate::error::DbError::from(e))
1467 }
1468 }
1469 }
1470
1471 pub async fn update<T>(self, data: &T) -> Result<u64, crate::error::DbError>
1510 where
1511 T: serde::Serialize,
1512 {
1513 if self.enable_logging {
1515 log::debug!("执行 update() 操作,表: {}", self.table);
1516 }
1517
1518 if self.conditions.is_empty() {
1520 log::warn!("update() 操作缺少 WHERE 条件,禁止全表更新");
1521 return Err(crate::error::DbError::MissingWhereClause);
1522 }
1523
1524 let json_data = serde_json::to_value(data).map_err(|e| {
1526 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
1527 })?;
1528
1529 let mut generator = SqlGenerator::new();
1531 generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
1532
1533 let sql = generator.get_sql();
1534 let params = generator.get_params();
1535
1536 if self.enable_logging {
1538 log::debug!("执行 update() SQL: {}", sql);
1539 log::debug!("参数: {:?}", params);
1540 }
1541
1542 let mut query = sqlx::query(sql);
1544
1545 for param in params {
1547 query = bind_execute_param(query, param);
1548 }
1549
1550 let result = query.execute(self.pool).await;
1552
1553 match result {
1554 Ok(query_result) => {
1555 let rows_affected = query_result.rows_affected();
1556 if self.enable_logging {
1557 log::debug!("update() 成功,影响 {} 行", rows_affected);
1558 }
1559 Ok(rows_affected)
1560 }
1561 Err(e) => {
1562 log::error!("update() 失败: {}", e);
1563 Err(crate::error::DbError::from(e))
1564 }
1565 }
1566 }
1567
1568 pub async fn delete(self) -> Result<u64, crate::error::DbError> {
1595 if self.enable_logging {
1597 log::debug!("执行 delete() 操作,表: {}", self.table);
1598 }
1599
1600 if self.conditions.is_empty() {
1602 log::warn!("delete() 操作缺少 WHERE 条件,禁止全表删除");
1603 return Err(crate::error::DbError::MissingWhereClause);
1604 }
1605
1606 let mut generator = SqlGenerator::new();
1608 generator.build_delete(&self.table, &self.conditions)?;
1609
1610 let sql = generator.get_sql();
1611 let params = generator.get_params();
1612
1613 if self.enable_logging {
1615 log::debug!("执行 delete() SQL: {}", sql);
1616 log::debug!("参数: {:?}", params);
1617 }
1618
1619 let mut query = sqlx::query(sql);
1621
1622 for param in params {
1624 query = bind_execute_param(query, param);
1625 }
1626
1627 let result = query.execute(self.pool).await;
1629
1630 match result {
1631 Ok(query_result) => {
1632 let rows_affected = query_result.rows_affected();
1633 if self.enable_logging {
1634 log::debug!("delete() 成功,影响 {} 行", rows_affected);
1635 }
1636 Ok(rows_affected)
1637 }
1638 Err(e) => {
1639 log::error!("delete() 失败: {}", e);
1640 Err(crate::error::DbError::from(e))
1641 }
1642 }
1643 }
1644}
1645
1646fn bind_execute_param<'q>(
1655 query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
1656 param: &SqlValue,
1657) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
1658 match param {
1659 SqlValue::Null => query.bind(Option::<i32>::None),
1660 SqlValue::Bool(b) => query.bind(*b),
1661 SqlValue::Int(i) => query.bind(*i),
1662 SqlValue::Float(f) => query.bind(*f),
1663 SqlValue::String(s) => query.bind(s.clone()),
1664 SqlValue::Bytes(b) => query.bind(b.clone()),
1665 SqlValue::Json(j) => query.bind(j.to_string()),
1666 SqlValue::DateTime(dt) => query.bind(*dt),
1667 SqlValue::Timestamp(ts) => query.bind(*ts),
1668 }
1669}
1670
1671fn bind_param<'q, T>(
1680 query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
1681 param: &SqlValue,
1682) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
1683where
1684 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
1685{
1686 match param {
1687 SqlValue::Null => query.bind(Option::<i32>::None),
1688 SqlValue::Bool(b) => query.bind(*b),
1689 SqlValue::Int(i) => query.bind(*i),
1690 SqlValue::Float(f) => query.bind(*f),
1691 SqlValue::String(s) => query.bind(s.clone()),
1692 SqlValue::Bytes(b) => query.bind(b.clone()),
1693 SqlValue::Json(j) => query.bind(j.to_string()),
1694 SqlValue::DateTime(dt) => query.bind(*dt),
1695 SqlValue::Timestamp(ts) => query.bind(*ts),
1696 }
1697}
1698
1699fn bind_scalar_param<'q, T>(
1708 query: sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
1709 param: &SqlValue,
1710) -> sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
1711where
1712 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1713{
1714 match param {
1715 SqlValue::Null => query.bind(Option::<i32>::None),
1716 SqlValue::Bool(b) => query.bind(*b),
1717 SqlValue::Int(i) => query.bind(*i),
1718 SqlValue::Float(f) => query.bind(*f),
1719 SqlValue::String(s) => query.bind(s.clone()),
1720 SqlValue::Bytes(b) => query.bind(b.clone()),
1721 SqlValue::Json(j) => query.bind(j.to_string()),
1722 SqlValue::DateTime(dt) => query.bind(*dt),
1723 SqlValue::Timestamp(ts) => query.bind(*ts),
1724 }
1725}
1726
1727fn bind_scalar_param_option<'q, T>(
1736 query: sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>,
1737 param: &SqlValue,
1738) -> sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>
1739where
1740 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1741{
1742 match param {
1743 SqlValue::Null => query.bind(Option::<i32>::None),
1744 SqlValue::Bool(b) => query.bind(*b),
1745 SqlValue::Int(i) => query.bind(*i),
1746 SqlValue::Float(f) => query.bind(*f),
1747 SqlValue::String(s) => query.bind(s.clone()),
1748 SqlValue::Bytes(b) => query.bind(b.clone()),
1749 SqlValue::Json(j) => query.bind(j.to_string()),
1750 SqlValue::DateTime(dt) => query.bind(*dt),
1751 SqlValue::Timestamp(ts) => query.bind(*ts),
1752 }
1753}
1754
1755#[cfg(test)]
1756mod tests {
1757 use super::*;
1758 use sqlx::mysql::MySqlPoolOptions;
1759
1760 async fn create_test_pool() -> MySqlPool {
1762 MySqlPoolOptions::new()
1763 .max_connections(1)
1764 .connect("mysql://root:111111@localhost:3306/test")
1765 .await
1766 .expect("无法连接到测试数据库")
1767 }
1768
1769 #[tokio::test]
1770 async fn test_table_name_in_sql() {
1771 let pool = create_test_pool().await;
1772 let builder = QueryBuilder::new(&pool, "users", false);
1773 let sql = builder.to_sql();
1774 assert!(sql.contains("FROM users"));
1775 }
1776
1777 #[test]
1779 fn test_sql_generator_new() {
1780 let generator = SqlGenerator::new();
1781 assert_eq!(generator.get_sql(), "");
1782 assert_eq!(generator.get_params().len(), 0);
1783 }
1784
1785 #[test]
1786 fn test_sql_generator_append() {
1787 let mut generator = SqlGenerator::new();
1788 generator.append("SELECT * FROM users");
1789 assert_eq!(generator.get_sql(), "SELECT * FROM users");
1790 }
1791
1792 #[test]
1793 fn test_sql_generator_add_param() {
1794 let mut generator = SqlGenerator::new();
1795 generator.add_param(SqlValue::Int(42));
1796 generator.add_param(SqlValue::String("test".to_string()));
1797 assert_eq!(generator.get_params().len(), 2);
1798 }
1799
1800 #[test]
1801 fn test_sql_generator_clear() {
1802 let mut generator = SqlGenerator::new();
1803 generator.append("SELECT * FROM users");
1804 generator.add_param(SqlValue::Int(1));
1805
1806 generator.clear();
1807
1808 assert_eq!(generator.get_sql(), "");
1809 assert_eq!(generator.get_params().len(), 0);
1810 }
1811
1812 #[test]
1813 fn test_sql_generator_multiple_operations() {
1814 let mut generator = SqlGenerator::new();
1815
1816 generator.append("SELECT * FROM users WHERE id = ?");
1817 generator.add_param(SqlValue::Int(1));
1818 generator.append(" AND name = ?");
1819 generator.add_param(SqlValue::String("test".to_string()));
1820
1821 assert_eq!(
1822 generator.get_sql(),
1823 "SELECT * FROM users WHERE id = ? AND name = ?"
1824 );
1825 assert_eq!(generator.get_params().len(), 2);
1826 }
1827
1828 #[tokio::test]
1829 async fn test_field_selection() {
1830 let pool = create_test_pool().await;
1831 let builder = QueryBuilder::new(&pool, "users", false)
1832 .field("id")
1833 .field("name");
1834 let sql = builder.to_sql();
1835 assert!(sql.contains("id, name"));
1836 }
1837
1838 #[tokio::test]
1839 async fn test_fields_selection() {
1840 let pool = create_test_pool().await;
1841 let builder = QueryBuilder::new(&pool, "users", false).fields(&["id", "name", "email"]);
1842 let sql = builder.to_sql();
1843 assert!(sql.contains("id, name, email"));
1844 }
1845
1846 #[tokio::test]
1847 async fn test_distinct() {
1848 let pool = create_test_pool().await;
1849 let builder = QueryBuilder::new(&pool, "users", false)
1850 .field("name")
1851 .distinct();
1852 let sql = builder.to_sql();
1853 assert!(sql.contains("SELECT DISTINCT"));
1854 }
1855
1856 #[tokio::test]
1857 async fn test_field_type_marking() {
1858 let pool = create_test_pool().await;
1859 let builder = QueryBuilder::new(&pool, "users", false)
1860 .json("data")
1861 .datetime("created_at")
1862 .timestamp("updated_at")
1863 .decimal("price")
1864 .blob("content")
1865 .text("description");
1866
1867 assert_eq!(builder.field_types.get("data"), Some(&FieldType::Json));
1868 assert_eq!(
1869 builder.field_types.get("created_at"),
1870 Some(&FieldType::DateTime)
1871 );
1872 assert_eq!(
1873 builder.field_types.get("updated_at"),
1874 Some(&FieldType::Timestamp)
1875 );
1876 assert_eq!(builder.field_types.get("price"), Some(&FieldType::Decimal));
1877 assert_eq!(builder.field_types.get("content"), Some(&FieldType::Blob));
1878 assert_eq!(
1879 builder.field_types.get("description"),
1880 Some(&FieldType::Text)
1881 );
1882 }
1883
1884 #[tokio::test]
1885 async fn test_where_and() {
1886 let pool = create_test_pool().await;
1887 let builder = QueryBuilder::new(&pool, "users", false)
1888 .where_and("name", "=", "test")
1889 .where_and("age", ">", 18);
1890
1891 assert_eq!(builder.conditions.len(), 2);
1892 }
1893
1894 #[tokio::test]
1895 async fn test_where_or() {
1896 let pool = create_test_pool().await;
1897 let builder = QueryBuilder::new(&pool, "users", false)
1898 .where_or("status", "=", 1)
1899 .where_or("status", "=", 2);
1900
1901 assert_eq!(builder.conditions.len(), 1);
1903 }
1904
1905 #[tokio::test]
1906 async fn test_where_in() {
1907 let pool = create_test_pool().await;
1908 let builder = QueryBuilder::new(&pool, "users", false).where_in("id", vec![1, 2, 3]);
1909
1910 assert_eq!(builder.conditions.len(), 1);
1911 }
1912
1913 #[tokio::test]
1914 async fn test_where_between() {
1915 let pool = create_test_pool().await;
1916 let builder = QueryBuilder::new(&pool, "users", false).where_between("age", 18, 65);
1917
1918 assert_eq!(builder.conditions.len(), 1);
1919 }
1920
1921 #[tokio::test]
1922 async fn test_join() {
1923 let pool = create_test_pool().await;
1924 let builder =
1925 QueryBuilder::new(&pool, "users", false).join("orders", "users.id = orders.user_id");
1926
1927 assert_eq!(builder.joins.len(), 1);
1928 }
1929
1930 #[tokio::test]
1931 async fn test_left_join() {
1932 let pool = create_test_pool().await;
1933 let builder = QueryBuilder::new(&pool, "users", false)
1934 .left_join("orders", "users.id = orders.user_id");
1935
1936 assert_eq!(builder.joins.len(), 1);
1937 }
1938
1939 #[tokio::test]
1940 async fn test_right_join() {
1941 let pool = create_test_pool().await;
1942 let builder = QueryBuilder::new(&pool, "users", false)
1943 .right_join("orders", "users.id = orders.user_id");
1944
1945 assert_eq!(builder.joins.len(), 1);
1946 }
1947
1948 #[tokio::test]
1949 async fn test_order() {
1950 let pool = create_test_pool().await;
1951 let builder = QueryBuilder::new(&pool, "users", false)
1952 .order("name", true)
1953 .order("age", false);
1954
1955 assert_eq!(builder.order_by.len(), 2);
1956 }
1957
1958 #[tokio::test]
1959 async fn test_group() {
1960 let pool = create_test_pool().await;
1961 let builder = QueryBuilder::new(&pool, "users", false)
1962 .group("status")
1963 .group("role");
1964
1965 assert_eq!(builder.group_by.len(), 2);
1966 }
1967
1968 #[tokio::test]
1970 async fn test_select_with_where() {
1971 let pool = create_test_pool().await;
1972 let builder = QueryBuilder::new(&pool, "users", false)
1973 .field("id")
1974 .field("name")
1975 .where_and("status", "=", 1);
1976
1977 let sql = builder.to_sql();
1978 assert!(sql.contains("SELECT id, name FROM users"));
1979 assert!(sql.contains("WHERE"));
1980 }
1981
1982 #[tokio::test]
1983 async fn test_select_with_join() {
1984 let pool = create_test_pool().await;
1985 let builder = QueryBuilder::new(&pool, "users", false)
1986 .field("users.id")
1987 .field("orders.total")
1988 .join("orders", "users.id = orders.user_id");
1989
1990 let sql = builder.to_sql();
1991 assert!(sql.contains("SELECT users.id, orders.total FROM users"));
1992 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
1993 }
1994
1995 #[tokio::test]
1996 async fn test_select_with_order_by() {
1997 let pool = create_test_pool().await;
1998 let builder = QueryBuilder::new(&pool, "users", false)
1999 .field("name")
2000 .order("name", true)
2001 .order("age", false);
2002
2003 let sql = builder.to_sql();
2004 assert!(sql.contains("ORDER BY name ASC, age DESC"));
2005 }
2006
2007 #[tokio::test]
2008 async fn test_select_with_group_by() {
2009 let pool = create_test_pool().await;
2010 let builder = QueryBuilder::new(&pool, "users", false)
2011 .field("status")
2012 .group("status");
2013
2014 let sql = builder.to_sql();
2015 assert!(sql.contains("GROUP BY status"));
2016 }
2017
2018 #[tokio::test]
2019 async fn test_select_with_limit_offset() {
2020 let pool = create_test_pool().await;
2021 let builder = QueryBuilder::new(&pool, "users", false)
2022 .field("id")
2023 .limit(10)
2024 .offset(20);
2025
2026 let sql = builder.to_sql();
2027 assert!(sql.contains("LIMIT 10"));
2028 assert!(sql.contains("OFFSET 20"));
2029 }
2030
2031 #[tokio::test]
2032 async fn test_select_complex_query() {
2033 let pool = create_test_pool().await;
2034 let builder = QueryBuilder::new(&pool, "users", false)
2035 .field("users.id")
2036 .field("users.name")
2037 .field("orders.total")
2038 .distinct()
2039 .join("orders", "users.id = orders.user_id")
2040 .where_and("users.status", "=", 1)
2041 .where_and("orders.total", ">", 100)
2042 .group("users.id")
2043 .order("orders.total", false)
2044 .limit(50);
2045
2046 let sql = builder.to_sql();
2047 assert!(sql.contains("SELECT DISTINCT"));
2048 assert!(sql.contains("users.id, users.name, orders.total"));
2049 assert!(sql.contains("FROM users"));
2050 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
2051 assert!(sql.contains("WHERE"));
2052 assert!(sql.contains("GROUP BY users.id"));
2053 assert!(sql.contains("ORDER BY orders.total DESC"));
2054 assert!(sql.contains("LIMIT 50"));
2055 }
2056
2057 #[tokio::test]
2058 async fn test_select_with_multiple_joins() {
2059 let pool = create_test_pool().await;
2060 let builder = QueryBuilder::new(&pool, "users", false)
2061 .field("users.name")
2062 .field("orders.total")
2063 .field("products.name")
2064 .join("orders", "users.id = orders.user_id")
2065 .left_join("products", "orders.product_id = products.id");
2066
2067 let sql = builder.to_sql();
2068 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
2069 assert!(sql.contains("LEFT JOIN products ON orders.product_id = products.id"));
2070 }
2071
2072 #[tokio::test]
2073 async fn test_select_with_in_condition() {
2074 let pool = create_test_pool().await;
2075 let builder = QueryBuilder::new(&pool, "users", false)
2076 .field("name")
2077 .where_in("id", vec![1, 2, 3, 4, 5]);
2078
2079 let sql = builder.to_sql();
2080 assert!(sql.contains("WHERE"));
2081 assert!(sql.contains("IN"));
2082 }
2083
2084 #[tokio::test]
2085 async fn test_select_with_between_condition() {
2086 let pool = create_test_pool().await;
2087 let builder = QueryBuilder::new(&pool, "users", false)
2088 .field("name")
2089 .where_between("age", 18, 65);
2090
2091 let sql = builder.to_sql();
2092 assert!(sql.contains("WHERE"));
2093 assert!(sql.contains("BETWEEN"));
2094 }
2095
2096 #[tokio::test]
2098 async fn test_sql_generator_build_select_basic() {
2099 let pool = create_test_pool().await;
2100 let builder = QueryBuilder::new(&pool, "users", false)
2101 .field("id")
2102 .field("name");
2103
2104 let mut generator = SqlGenerator::new();
2105 let result = generator.build_select(&builder);
2106
2107 assert!(result.is_ok());
2108 assert_eq!(generator.get_sql(), "SELECT id, name FROM users");
2109 }
2110
2111 #[tokio::test]
2112 async fn test_sql_generator_build_select_with_distinct() {
2113 let pool = create_test_pool().await;
2114 let builder = QueryBuilder::new(&pool, "users", false)
2115 .field("name")
2116 .distinct();
2117
2118 let mut generator = SqlGenerator::new();
2119 let result = generator.build_select(&builder);
2120
2121 assert!(result.is_ok());
2122 assert_eq!(generator.get_sql(), "SELECT DISTINCT name FROM users");
2123 }
2124
2125 #[tokio::test]
2126 async fn test_sql_generator_build_select_all_fields() {
2127 let pool = create_test_pool().await;
2128 let builder = QueryBuilder::new(&pool, "users", false);
2129
2130 let mut generator = SqlGenerator::new();
2131 let result = generator.build_select(&builder);
2132
2133 assert!(result.is_ok());
2134 assert_eq!(generator.get_sql(), "SELECT * FROM users");
2135 }
2136
2137 #[tokio::test]
2139 async fn test_sql_generator_build_where() {
2140 let pool = create_test_pool().await;
2141 let builder = QueryBuilder::new(&pool, "users", false)
2142 .where_and("status", "=", 1)
2143 .where_and("age", ">", 18);
2144
2145 let mut generator = SqlGenerator::new();
2146 let result = generator.build_select(&builder);
2147
2148 assert!(result.is_ok());
2149 let sql = generator.get_sql();
2150 assert!(sql.contains("WHERE"));
2151 assert!(sql.contains("status"));
2152 assert!(sql.contains("age"));
2153 }
2154
2155 #[tokio::test]
2157 async fn test_sql_generator_build_joins() {
2158 let pool = create_test_pool().await;
2159 let builder = QueryBuilder::new(&pool, "users", false)
2160 .join("orders", "users.id = orders.user_id")
2161 .left_join("profiles", "users.id = profiles.user_id");
2162
2163 let mut generator = SqlGenerator::new();
2164 let result = generator.build_select(&builder);
2165
2166 assert!(result.is_ok());
2167 let sql = generator.get_sql();
2168 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
2169 assert!(sql.contains("LEFT JOIN profiles ON users.id = profiles.user_id"));
2170 }
2171
2172 #[tokio::test]
2174 async fn test_sql_generator_build_order_by() {
2175 let pool = create_test_pool().await;
2176 let builder = QueryBuilder::new(&pool, "users", false)
2177 .order("name", true)
2178 .order("created_at", false);
2179
2180 let mut generator = SqlGenerator::new();
2181 let result = generator.build_select(&builder);
2182
2183 assert!(result.is_ok());
2184 let sql = generator.get_sql();
2185 assert!(sql.contains("ORDER BY name ASC, created_at DESC"));
2186 }
2187
2188 #[tokio::test]
2190 async fn test_sql_generator_build_group_by() {
2191 let pool = create_test_pool().await;
2192 let builder = QueryBuilder::new(&pool, "users", false)
2193 .group("status")
2194 .group("role");
2195
2196 let mut generator = SqlGenerator::new();
2197 let result = generator.build_select(&builder);
2198
2199 assert!(result.is_ok());
2200 let sql = generator.get_sql();
2201 assert!(sql.contains("GROUP BY status, role"));
2202 }
2203
2204 #[tokio::test]
2206 async fn test_sql_generator_build_limit_offset() {
2207 let pool = create_test_pool().await;
2208 let builder = QueryBuilder::new(&pool, "users", false)
2209 .limit(10)
2210 .offset(20);
2211
2212 let mut generator = SqlGenerator::new();
2213 let result = generator.build_select(&builder);
2214
2215 assert!(result.is_ok());
2216 let sql = generator.get_sql();
2217 assert!(sql.contains("LIMIT 10"));
2218 assert!(sql.contains("OFFSET 20"));
2219 }
2220
2221 #[tokio::test]
2223 async fn test_sql_generator_complex_query() {
2224 let pool = create_test_pool().await;
2225 let builder = QueryBuilder::new(&pool, "users", false)
2226 .field("users.id")
2227 .field("users.name")
2228 .field("COUNT(orders.id) as order_count")
2229 .distinct()
2230 .join("orders", "users.id = orders.user_id")
2231 .where_and("users.status", "=", 1)
2232 .where_and("orders.total", ">", 100)
2233 .group("users.id")
2234 .group("users.name")
2235 .order("order_count", false)
2236 .limit(20)
2237 .offset(10);
2238
2239 let mut generator = SqlGenerator::new();
2240 let result = generator.build_select(&builder);
2241
2242 assert!(result.is_ok());
2243 let sql = generator.get_sql();
2244
2245 assert!(sql.starts_with("SELECT DISTINCT"));
2247 assert!(sql.contains("users.id, users.name, COUNT(orders.id) as order_count"));
2248 assert!(sql.contains("FROM users"));
2249 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
2250 assert!(sql.contains("WHERE"));
2251 assert!(sql.contains("GROUP BY users.id, users.name"));
2252 assert!(sql.contains("ORDER BY order_count DESC"));
2253 assert!(sql.contains("LIMIT 20"));
2254 assert!(sql.contains("OFFSET 10"));
2255 }
2256
2257 #[tokio::test]
2259 async fn test_find_adds_limit_one() {
2260 let pool = create_test_pool().await;
2261 let builder = QueryBuilder::new(&pool, "users", false)
2262 .field("id")
2263 .field("name")
2264 .where_and("id", "=", 1);
2265
2266 assert_eq!(builder.limit, None);
2268
2269 let builder_with_limit = QueryBuilder::new(&pool, "users", false)
2271 .field("id")
2272 .field("name")
2273 .where_and("id", "=", 1)
2274 .limit(1);
2275
2276 let sql = builder_with_limit.to_sql();
2277 assert!(sql.contains("LIMIT 1"), "find() 应该自动添加 LIMIT 1");
2278 }
2279
2280 #[test]
2282 fn test_sql_generator_build_insert_basic() {
2283 let mut generator = SqlGenerator::new();
2284 let data = serde_json::json!({
2285 "name": "张三",
2286 "age": 25,
2287 "email": "zhangsan@example.com"
2288 });
2289 let field_types = HashMap::new();
2290
2291 let result = generator.build_insert("users", &data, &field_types);
2292 assert!(result.is_ok());
2293
2294 let sql = generator.get_sql();
2295 assert!(sql.starts_with("INSERT INTO users"));
2296 assert!(sql.contains("name"));
2297 assert!(sql.contains("age"));
2298 assert!(sql.contains("email"));
2299 assert!(sql.contains("VALUES"));
2300 assert_eq!(generator.get_params().len(), 3);
2301 }
2302
2303 #[test]
2304 fn test_sql_generator_build_insert_with_json_field() {
2305 let mut generator = SqlGenerator::new();
2306 let data = serde_json::json!({
2307 "name": "测试用户",
2308 "data": {"role": "admin", "permissions": ["read", "write"]}
2309 });
2310
2311 let mut field_types = HashMap::new();
2312 field_types.insert("data".to_string(), FieldType::Json);
2313
2314 let result = generator.build_insert("users", &data, &field_types);
2315 assert!(result.is_ok());
2316
2317 let sql = generator.get_sql();
2318 assert!(sql.contains("INSERT INTO users"));
2319 assert!(sql.contains("name"));
2320 assert!(sql.contains("data"));
2321 assert_eq!(generator.get_params().len(), 2);
2322
2323 let params = generator.get_params();
2325 let has_json = params.iter().any(|p| matches!(p, SqlValue::Json(_)));
2326 assert!(has_json, "应该包含 JSON 类型的参数");
2327 }
2328
2329 #[test]
2330 fn test_sql_generator_build_insert_empty_data() {
2331 let mut generator = SqlGenerator::new();
2332 let data = serde_json::json!({});
2333 let field_types = HashMap::new();
2334
2335 let result = generator.build_insert("users", &data, &field_types);
2336 assert!(result.is_err());
2337 assert!(matches!(
2338 result.unwrap_err(),
2339 crate::error::DbError::SerializationError(_)
2340 ));
2341 }
2342
2343 #[test]
2344 fn test_sql_generator_build_insert_not_object() {
2345 let mut generator = SqlGenerator::new();
2346 let data = serde_json::json!([1, 2, 3]); let field_types = HashMap::new();
2348
2349 let result = generator.build_insert("users", &data, &field_types);
2350 assert!(result.is_err());
2351 assert!(matches!(
2352 result.unwrap_err(),
2353 crate::error::DbError::SerializationError(_)
2354 ));
2355 }
2356}
2357
2358#[cfg(test)]
2359mod property_tests {
2360 use super::*;
2361 use proptest::prelude::*;
2362 use sqlx::mysql::MySqlPoolOptions;
2363
2364 fn table_name_strategy() -> impl Strategy<Value = String> {
2366 "[a-z][a-z0-9_]{0,30}"
2367 }
2368
2369 fn field_name_strategy() -> impl Strategy<Value = String> {
2371 "[a-z][a-z0-9_]{0,30}"
2372 }
2373
2374 fn create_test_pool_sync() -> MySqlPool {
2376 tokio::runtime::Runtime::new().unwrap().block_on(async {
2377 MySqlPoolOptions::new()
2378 .max_connections(1)
2379 .connect("mysql://root:111111@localhost:3306/test")
2380 .await
2381 .expect("无法连接到测试数据库")
2382 })
2383 }
2384
2385 proptest! {
2388 #![proptest_config(ProptestConfig::with_cases(100))]
2389
2390 #[test]
2391 fn prop_table_name_in_sql(table_name in table_name_strategy()) {
2392 let pool = create_test_pool_sync();
2393 let builder = QueryBuilder::new(&pool, &table_name, false);
2394 let sql = builder.to_sql();
2395
2396 let expected = format!("FROM {}", table_name);
2398 prop_assert!(sql.contains(&expected));
2399 }
2400 }
2401
2402 proptest! {
2405 #![proptest_config(ProptestConfig::with_cases(100))]
2406
2407 #[test]
2408 fn prop_table_name_override(
2409 table_name1 in table_name_strategy(),
2410 table_name2 in table_name_strategy()
2411 ) {
2412 prop_assume!(table_name1 != table_name2);
2413
2414 let pool = create_test_pool_sync();
2415 let builder1 = QueryBuilder::new(&pool, &table_name1, false);
2417 let sql1 = builder1.to_sql();
2418 let expected1 = format!("FROM {}", table_name1);
2419 prop_assert!(sql1.contains(&expected1));
2420
2421 let builder2 = QueryBuilder::new(&pool, &table_name2, false);
2423 let sql2 = builder2.to_sql();
2424 let expected2 = format!("FROM {}", table_name2);
2425 prop_assert!(sql2.contains(&expected2));
2426
2427 let pattern1 = format!("FROM {} ", table_name1);
2430 let pattern1_alt = format!("FROM {}\n", table_name1);
2431 prop_assert!(!sql2.contains(&pattern1) && !sql2.contains(&pattern1_alt));
2432 }
2433 }
2434
2435 proptest! {
2438 #![proptest_config(ProptestConfig::with_cases(100))]
2439
2440 #[test]
2441 fn prop_field_selection(
2442 table_name in table_name_strategy(),
2443 fields in prop::collection::vec(field_name_strategy(), 1..10)
2444 ) {
2445 let pool = create_test_pool_sync();
2446 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2447
2448 for field in &fields {
2450 builder = builder.field(field);
2451 }
2452
2453 let sql = builder.to_sql();
2454
2455 for field in &fields {
2457 prop_assert!(sql.contains(field));
2458 }
2459 }
2460 }
2461
2462 proptest! {
2465 #![proptest_config(ProptestConfig::with_cases(100))]
2466
2467 #[test]
2468 fn prop_distinct_keyword(
2469 table_name in table_name_strategy(),
2470 field in field_name_strategy()
2471 ) {
2472 let pool = create_test_pool_sync();
2473 let builder = QueryBuilder::new(&pool, &table_name, false)
2474 .field(&field)
2475 .distinct();
2476
2477 let sql = builder.to_sql();
2478
2479 prop_assert!(sql.contains("SELECT DISTINCT"));
2481 }
2482 }
2483
2484 proptest! {
2487 #![proptest_config(ProptestConfig::with_cases(100))]
2488
2489 #[test]
2490 fn prop_special_field_type_marking(
2491 table_name in table_name_strategy(),
2492 json_field in field_name_strategy(),
2493 datetime_field in field_name_strategy(),
2494 timestamp_field in field_name_strategy(),
2495 decimal_field in field_name_strategy(),
2496 blob_field in field_name_strategy(),
2497 text_field in field_name_strategy()
2498 ) {
2499 prop_assume!(json_field != datetime_field);
2501 prop_assume!(json_field != timestamp_field);
2502 prop_assume!(json_field != decimal_field);
2503 prop_assume!(json_field != blob_field);
2504 prop_assume!(json_field != text_field);
2505 prop_assume!(datetime_field != timestamp_field);
2506 prop_assume!(datetime_field != decimal_field);
2507 prop_assume!(datetime_field != blob_field);
2508 prop_assume!(datetime_field != text_field);
2509 prop_assume!(timestamp_field != decimal_field);
2510 prop_assume!(timestamp_field != blob_field);
2511 prop_assume!(timestamp_field != text_field);
2512 prop_assume!(decimal_field != blob_field);
2513 prop_assume!(decimal_field != text_field);
2514 prop_assume!(blob_field != text_field);
2515
2516 let pool = create_test_pool_sync();
2517 let builder = QueryBuilder::new(&pool, &table_name, false)
2518 .json(&json_field)
2519 .datetime(&datetime_field)
2520 .timestamp(×tamp_field)
2521 .decimal(&decimal_field)
2522 .blob(&blob_field)
2523 .text(&text_field);
2524
2525 prop_assert_eq!(builder.field_types.get(&json_field), Some(&FieldType::Json));
2527 prop_assert_eq!(builder.field_types.get(&datetime_field), Some(&FieldType::DateTime));
2528 prop_assert_eq!(builder.field_types.get(×tamp_field), Some(&FieldType::Timestamp));
2529 prop_assert_eq!(builder.field_types.get(&decimal_field), Some(&FieldType::Decimal));
2530 prop_assert_eq!(builder.field_types.get(&blob_field), Some(&FieldType::Blob));
2531 prop_assert_eq!(builder.field_types.get(&text_field), Some(&FieldType::Text));
2532 }
2533 }
2534
2535 proptest! {
2538 #![proptest_config(ProptestConfig::with_cases(100))]
2539
2540 #[test]
2541 fn prop_where_and_condition_added(
2542 table_name in table_name_strategy(),
2543 field in field_name_strategy(),
2544 value in any::<i32>()
2545 ) {
2546 let pool = create_test_pool_sync();
2547 let builder = QueryBuilder::new(&pool, &table_name, false)
2548 .where_and(&field, "=", value);
2549
2550 prop_assert_eq!(builder.conditions.len(), 1);
2552 }
2553
2554 #[test]
2555 fn prop_where_or_condition_added(
2556 table_name in table_name_strategy(),
2557 field in field_name_strategy(),
2558 value1 in any::<i32>(),
2559 value2 in any::<i32>()
2560 ) {
2561 let pool = create_test_pool_sync();
2562 let builder = QueryBuilder::new(&pool, &table_name, false)
2563 .where_or(&field, "=", value1)
2564 .where_or(&field, "=", value2);
2565
2566 prop_assert_eq!(builder.conditions.len(), 1);
2568 }
2569 }
2570
2571 proptest! {
2574 #![proptest_config(ProptestConfig::with_cases(100))]
2575
2576 #[test]
2577 fn prop_in_operator_array_support(
2578 table_name in table_name_strategy(),
2579 field in field_name_strategy(),
2580 values in prop::collection::vec(any::<i32>(), 1..10)
2581 ) {
2582 let pool = create_test_pool_sync();
2583 let builder = QueryBuilder::new(&pool, &table_name, false)
2584 .where_in(&field, values);
2585
2586 prop_assert_eq!(builder.conditions.len(), 1);
2588 }
2589 }
2590
2591 proptest! {
2594 #![proptest_config(ProptestConfig::with_cases(100))]
2595
2596 #[test]
2597 fn prop_between_operator_boundary_support(
2598 table_name in table_name_strategy(),
2599 field in field_name_strategy(),
2600 start in any::<i32>(),
2601 end in any::<i32>()
2602 ) {
2603 let pool = create_test_pool_sync();
2604 let builder = QueryBuilder::new(&pool, &table_name, false)
2605 .where_between(&field, start, end);
2606
2607 prop_assert_eq!(builder.conditions.len(), 1);
2609 }
2610 }
2611
2612 proptest! {
2615 #![proptest_config(ProptestConfig::with_cases(100))]
2616
2617 #[test]
2618 fn prop_multiple_and_conditions(
2619 table_name in table_name_strategy(),
2620 field in field_name_strategy(),
2621 values in prop::collection::vec(any::<i32>(), 2..5)
2622 ) {
2623 let pool = create_test_pool_sync();
2624 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2625
2626 for value in &values {
2628 builder = builder.where_and(&field, "=", *value);
2629 }
2630
2631 prop_assert_eq!(builder.conditions.len(), values.len());
2633 }
2634 }
2635
2636 proptest! {
2639 #![proptest_config(ProptestConfig::with_cases(100))]
2640
2641 #[test]
2642 fn prop_join_clause_generation(
2643 table_name in table_name_strategy(),
2644 join_table in table_name_strategy(),
2645 on_condition in "[a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20} = [a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20}"
2646 ) {
2647 let pool = create_test_pool_sync();
2648
2649 let builder_inner = QueryBuilder::new(&pool, &table_name, false)
2651 .join(&join_table, &on_condition);
2652 prop_assert_eq!(builder_inner.joins.len(), 1);
2653
2654 let builder_left = QueryBuilder::new(&pool, &table_name, false)
2656 .left_join(&join_table, &on_condition);
2657 prop_assert_eq!(builder_left.joins.len(), 1);
2658
2659 let builder_right = QueryBuilder::new(&pool, &table_name, false)
2661 .right_join(&join_table, &on_condition);
2662 prop_assert_eq!(builder_right.joins.len(), 1);
2663 }
2664 }
2665
2666 proptest! {
2669 #![proptest_config(ProptestConfig::with_cases(100))]
2670
2671 #[test]
2672 fn prop_multiple_join_support(
2673 table_name in table_name_strategy(),
2674 join_tables in prop::collection::vec(table_name_strategy(), 1..5)
2675 ) {
2676 let pool = create_test_pool_sync();
2677 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2678
2679 for join_table in &join_tables {
2681 let on_condition = format!("{}.id = {}.id", table_name, join_table);
2682 builder = builder.join(join_table, &on_condition);
2683 }
2684
2685 prop_assert_eq!(builder.joins.len(), join_tables.len());
2687 }
2688 }
2689
2690 proptest! {
2693 #![proptest_config(ProptestConfig::with_cases(100))]
2694
2695 #[test]
2696 fn prop_table_alias_support(
2697 base_table in table_name_strategy(),
2698 join_table in table_name_strategy(),
2699 base_alias in "[a-z][a-z0-9]{0,5}",
2700 join_alias in "[a-z][a-z0-9]{0,5}"
2701 ) {
2702 prop_assume!(base_table != join_table);
2703 prop_assume!(base_alias != join_alias);
2704
2705 let pool = create_test_pool_sync();
2706
2707 let base_table_with_alias = format!("{} AS {}", base_table, base_alias);
2709 let join_table_with_alias = format!("{} AS {}", join_table, join_alias);
2710
2711 let on_condition = format!("{}.id = {}.id", base_alias, join_alias);
2713
2714 let builder = QueryBuilder::new(&pool, &base_table_with_alias, false)
2716 .field(&format!("{}.id", base_alias))
2717 .field(&format!("{}.name", base_alias))
2718 .join(&join_table_with_alias, &on_condition);
2719
2720 let sql = builder.to_sql();
2721
2722 prop_assert!(sql.contains(&format!("FROM {}", base_table_with_alias)),
2724 "SQL 应该包含带别名的主表: FROM {}", base_table_with_alias);
2725
2726 prop_assert!(sql.contains(&join_table_with_alias),
2728 "SQL 应该包含带别名的 JOIN 表: {}", join_table_with_alias);
2729
2730 prop_assert!(sql.contains(&on_condition),
2732 "SQL 应该包含使用别名的 ON 条件: {}", on_condition);
2733
2734 prop_assert!(sql.contains(&format!("{}.id", base_alias)),
2736 "SQL 应该包含使用别名的字段: {}.id", base_alias);
2737 prop_assert!(sql.contains(&format!("{}.name", base_alias)),
2738 "SQL 应该包含使用别名的字段: {}.name", base_alias);
2739 }
2740 }
2741
2742 proptest! {
2745 #![proptest_config(ProptestConfig::with_cases(100))]
2746
2747 #[test]
2748 fn prop_order_by_clause_generation(
2749 table_name in table_name_strategy(),
2750 field in field_name_strategy(),
2751 asc in any::<bool>()
2752 ) {
2753 let pool = create_test_pool_sync();
2754 let builder = QueryBuilder::new(&pool, &table_name, false)
2755 .order(&field, asc);
2756
2757 prop_assert_eq!(builder.order_by.len(), 1);
2759 prop_assert_eq!(&builder.order_by[0].field, &field);
2760 prop_assert_eq!(builder.order_by[0].asc, asc);
2761 }
2762 }
2763
2764 proptest! {
2767 #![proptest_config(ProptestConfig::with_cases(100))]
2768
2769 #[test]
2770 fn prop_multiple_order_by_support(
2771 table_name in table_name_strategy(),
2772 fields in prop::collection::vec(field_name_strategy(), 1..5)
2773 ) {
2774 let pool = create_test_pool_sync();
2775 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2776
2777 for field in &fields {
2779 builder = builder.order(field, true);
2780 }
2781
2782 prop_assert_eq!(builder.order_by.len(), fields.len());
2784 }
2785 }
2786
2787 proptest! {
2790 #![proptest_config(ProptestConfig::with_cases(100))]
2791
2792 #[test]
2793 fn prop_group_by_clause_generation(
2794 table_name in table_name_strategy(),
2795 field in field_name_strategy()
2796 ) {
2797 let pool = create_test_pool_sync();
2798 let builder = QueryBuilder::new(&pool, &table_name, false)
2799 .group(&field);
2800
2801 prop_assert_eq!(builder.group_by.len(), 1);
2803 prop_assert_eq!(&builder.group_by[0], &field);
2804 }
2805 }
2806
2807 proptest! {
2810 #![proptest_config(ProptestConfig::with_cases(100))]
2811
2812 #[test]
2813 fn prop_multiple_group_by_support(
2814 table_name in table_name_strategy(),
2815 fields in prop::collection::vec(field_name_strategy(), 1..5)
2816 ) {
2817 let pool = create_test_pool_sync();
2818 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2819
2820 for field in &fields {
2822 builder = builder.group(field);
2823 }
2824
2825 prop_assert_eq!(builder.group_by.len(), fields.len());
2827 }
2828 }
2829
2830 proptest! {
2833 #![proptest_config(ProptestConfig::with_cases(100))]
2834
2835 #[test]
2836 fn prop_to_sql_returns_valid_sql(
2837 table_name in table_name_strategy(),
2838 fields in prop::collection::vec(field_name_strategy(), 0..5),
2839 use_distinct in any::<bool>(),
2840 limit_opt in prop::option::of(1u64..100),
2841 offset_opt in prop::option::of(0u64..100)
2842 ) {
2843 let pool = create_test_pool_sync();
2844 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2845
2846 for field in &fields {
2848 builder = builder.field(field);
2849 }
2850
2851 if use_distinct {
2853 builder = builder.distinct();
2854 }
2855
2856 if let Some(limit) = limit_opt {
2858 builder = builder.limit(limit);
2859 }
2860
2861 if let Some(offset) = offset_opt {
2863 builder = builder.offset(offset);
2864 }
2865
2866 let sql = builder.to_sql();
2868
2869 prop_assert!(!sql.is_empty(), "SQL 字符串不应为空");
2871
2872 prop_assert!(sql.contains("SELECT"), "SQL 应包含 SELECT 关键字");
2874 prop_assert!(sql.contains("FROM"), "SQL 应包含 FROM 关键字");
2875
2876 prop_assert!(sql.contains(&table_name), "SQL 应包含表名");
2878
2879 if use_distinct {
2881 prop_assert!(sql.contains("DISTINCT"), "SQL 应包含 DISTINCT 关键字");
2882 }
2883
2884 if let Some(limit) = limit_opt {
2886 prop_assert!(sql.contains("LIMIT"), "SQL 应包含 LIMIT 关键字");
2887 prop_assert!(sql.contains(&limit.to_string()), "SQL 应包含 LIMIT 值");
2888 }
2889
2890 if let Some(offset) = offset_opt {
2892 prop_assert!(sql.contains("OFFSET"), "SQL 应包含 OFFSET 关键字");
2893 prop_assert!(sql.contains(&offset.to_string()), "SQL 应包含 OFFSET 值");
2894 }
2895
2896 if !fields.is_empty() {
2898 for field in &fields {
2899 prop_assert!(sql.contains(field), "SQL 应包含字段 {}", field);
2900 }
2901 } else {
2902 prop_assert!(sql.contains("*"), "SQL 应包含 * 表示所有字段");
2904 }
2905 }
2906
2907 #[test]
2908 fn prop_to_sql_with_conditions(
2909 table_name in table_name_strategy(),
2910 field in field_name_strategy(),
2911 value in any::<i32>()
2912 ) {
2913 let pool = create_test_pool_sync();
2914 let builder = QueryBuilder::new(&pool, &table_name, false)
2915 .where_and(&field, "=", value);
2916
2917 let sql = builder.to_sql();
2918
2919 prop_assert!(!sql.is_empty());
2921 prop_assert!(sql.contains("SELECT"));
2922 prop_assert!(sql.contains("FROM"));
2923 prop_assert!(sql.contains(&table_name));
2924
2925 prop_assert!(sql.contains("WHERE"), "SQL 应包含 WHERE 关键字");
2927 }
2928
2929 #[test]
2930 fn prop_to_sql_with_joins(
2931 table_name in table_name_strategy(),
2932 join_table in table_name_strategy(),
2933 on_field1 in field_name_strategy(),
2934 on_field2 in field_name_strategy()
2935 ) {
2936 let pool = create_test_pool_sync();
2937 let on_condition = format!("{}.{} = {}.{}", table_name, on_field1, join_table, on_field2);
2938 let builder = QueryBuilder::new(&pool, &table_name, false)
2939 .join(&join_table, &on_condition);
2940
2941 let sql = builder.to_sql();
2942
2943 prop_assert!(!sql.is_empty());
2945 prop_assert!(sql.contains("SELECT"));
2946 prop_assert!(sql.contains("FROM"));
2947
2948 prop_assert!(sql.contains("JOIN"), "SQL 应包含 JOIN 关键字");
2950 prop_assert!(sql.contains(&join_table), "SQL 应包含连接的表名");
2951 }
2952
2953 #[test]
2954 fn prop_to_sql_with_order_and_group(
2955 table_name in table_name_strategy(),
2956 order_field in field_name_strategy(),
2957 group_field in field_name_strategy(),
2958 asc in any::<bool>()
2959 ) {
2960 let pool = create_test_pool_sync();
2961 let builder = QueryBuilder::new(&pool, &table_name, false)
2962 .order(&order_field, asc)
2963 .group(&group_field);
2964
2965 let sql = builder.to_sql();
2966
2967 prop_assert!(!sql.is_empty());
2969 prop_assert!(sql.contains("SELECT"));
2970 prop_assert!(sql.contains("FROM"));
2971
2972 prop_assert!(sql.contains("ORDER BY"), "SQL 应包含 ORDER BY 关键字");
2974 prop_assert!(sql.contains("GROUP BY"), "SQL 应包含 GROUP BY 关键字");
2975 prop_assert!(sql.contains(&order_field), "SQL 应包含排序字段");
2976 prop_assert!(sql.contains(&group_field), "SQL 应包含分组字段");
2977 }
2978
2979 #[test]
2980 fn prop_to_sql_complex_query(
2981 table_name in table_name_strategy(),
2982 fields in prop::collection::vec(field_name_strategy(), 1..3),
2983 join_table in table_name_strategy(),
2984 where_field in field_name_strategy(),
2985 order_field in field_name_strategy(),
2986 group_field in field_name_strategy()
2987 ) {
2988 let pool = create_test_pool_sync();
2989 let mut builder = QueryBuilder::new(&pool, &table_name, false);
2990
2991 for field in &fields {
2993 builder = builder.field(field);
2994 }
2995
2996 let on_condition = format!("{}.id = {}.id", table_name, join_table);
2998 builder = builder.join(&join_table, &on_condition);
2999
3000 builder = builder.where_and(&where_field, "=", 1);
3002
3003 builder = builder.order(&order_field, true);
3005
3006 builder = builder.group(&group_field);
3008
3009 builder = builder.limit(10);
3011
3012 let sql = builder.to_sql();
3013
3014 prop_assert!(!sql.is_empty());
3016 prop_assert!(sql.contains("SELECT"));
3017 prop_assert!(sql.contains("FROM"));
3018 prop_assert!(sql.contains(&table_name));
3019 prop_assert!(sql.contains("JOIN"));
3020 prop_assert!(sql.contains("WHERE"));
3021 prop_assert!(sql.contains("ORDER BY"));
3022 prop_assert!(sql.contains("GROUP BY"));
3023 prop_assert!(sql.contains("LIMIT"));
3024
3025 let select_pos = sql.find("SELECT").unwrap();
3027 let from_pos = sql.find("FROM").unwrap();
3028 let join_pos = sql.find("JOIN").unwrap();
3029 let where_pos = sql.find("WHERE").unwrap();
3030 let group_pos = sql.find("GROUP BY").unwrap();
3031 let order_pos = sql.find("ORDER BY").unwrap();
3032 let limit_pos = sql.find("LIMIT").unwrap();
3033
3034 prop_assert!(select_pos < from_pos, "SELECT 应在 FROM 之前");
3036 prop_assert!(from_pos < join_pos, "FROM 应在 JOIN 之前");
3037 prop_assert!(join_pos < where_pos, "JOIN 应在 WHERE 之前");
3038 prop_assert!(where_pos < group_pos, "WHERE 应在 GROUP BY 之前");
3039 prop_assert!(group_pos < order_pos, "GROUP BY 应在 ORDER BY 之前");
3040 prop_assert!(order_pos < limit_pos, "ORDER BY 应在 LIMIT 之前");
3041 }
3042 }
3043
3044 proptest! {
3047 #![proptest_config(ProptestConfig::with_cases(100))]
3048
3049 #[test]
3050 fn prop_sql_injection_prevention_single_quote(
3051 table_name in table_name_strategy(),
3052 field in field_name_strategy(),
3053 malicious_input in ".*'.*"
3054 ) {
3055 let pool = create_test_pool_sync();
3056 let builder = QueryBuilder::new(&pool, &table_name, false)
3057 .where_and(&field, "=", malicious_input.as_str());
3058
3059 let sql = builder.to_sql();
3060
3061 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询(? 占位符)");
3064
3065 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3068 prop_assert!(!where_clause.contains(&malicious_input),
3069 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3070 }
3071
3072 #[test]
3073 fn prop_sql_injection_prevention_semicolon(
3074 table_name in table_name_strategy(),
3075 field in field_name_strategy(),
3076 malicious_input in ".*;.*"
3077 ) {
3078 let pool = create_test_pool_sync();
3079 let builder = QueryBuilder::new(&pool, &table_name, false)
3080 .where_and(&field, "=", malicious_input.as_str());
3081
3082 let sql = builder.to_sql();
3083
3084 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3086
3087 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3089 prop_assert!(!where_clause.contains(&malicious_input),
3090 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3091 }
3092
3093 #[test]
3094 fn prop_sql_injection_prevention_comment(
3095 table_name in table_name_strategy(),
3096 field in field_name_strategy(),
3097 malicious_input in ".*--.*"
3098 ) {
3099 let pool = create_test_pool_sync();
3100 let builder = QueryBuilder::new(&pool, &table_name, false)
3101 .where_and(&field, "=", malicious_input.as_str());
3102
3103 let sql = builder.to_sql();
3104
3105 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3107
3108 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3110 prop_assert!(!where_clause.contains(&malicious_input),
3111 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3112 }
3113
3114 #[test]
3115 fn prop_sql_injection_prevention_drop_table(
3116 table_name in table_name_strategy(),
3117 field in field_name_strategy()
3118 ) {
3119 let pool = create_test_pool_sync();
3120 let malicious_input = "'; DROP TABLE users; --";
3121 let builder = QueryBuilder::new(&pool, &table_name, false)
3122 .where_and(&field, "=", malicious_input);
3123
3124 let sql = builder.to_sql();
3125
3126 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3128
3129 prop_assert!(!sql.to_uppercase().contains("DROP TABLE"),
3131 "SQL 不应该包含 DROP TABLE 语句");
3132
3133 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3135 prop_assert!(!where_clause.contains(malicious_input),
3136 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3137 }
3138
3139 #[test]
3140 fn prop_sql_injection_prevention_union_select(
3141 table_name in table_name_strategy(),
3142 field in field_name_strategy()
3143 ) {
3144 let pool = create_test_pool_sync();
3145 let malicious_input = "' UNION SELECT * FROM passwords --";
3146 let builder = QueryBuilder::new(&pool, &table_name, false)
3147 .where_and(&field, "=", malicious_input);
3148
3149 let sql = builder.to_sql();
3150
3151 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3153
3154 let sql_upper = sql.to_uppercase();
3156 let union_count = sql_upper.matches("UNION").count();
3157 prop_assert_eq!(union_count, 0, "SQL 不应该包含 UNION 注入");
3158
3159 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3161 prop_assert!(!where_clause.contains(malicious_input),
3162 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3163 }
3164
3165 #[test]
3166 fn prop_sql_injection_prevention_or_always_true(
3167 table_name in table_name_strategy(),
3168 field in field_name_strategy()
3169 ) {
3170 let pool = create_test_pool_sync();
3171 let malicious_input = "' OR '1'='1";
3172 let builder = QueryBuilder::new(&pool, &table_name, false)
3173 .where_and(&field, "=", malicious_input);
3174
3175 let sql = builder.to_sql();
3176
3177 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3179
3180 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3182 prop_assert!(!where_clause.contains(malicious_input),
3183 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3184
3185 let or_count = where_clause.matches(" OR ").count();
3188 prop_assert_eq!(or_count, 0, "不应该因为用户输入而产生 OR 条件");
3190 }
3191
3192 #[test]
3193 fn prop_sql_injection_prevention_multiple_special_chars(
3194 table_name in table_name_strategy(),
3195 field in field_name_strategy(),
3196 malicious_input in "[a-z0-9]*[';\"\\-][a-z0-9]*[';\"\\-][a-z0-9]*"
3197 ) {
3198 let pool = create_test_pool_sync();
3199 let builder = QueryBuilder::new(&pool, &table_name, false)
3200 .where_and(&field, "=", malicious_input.as_str());
3201
3202 let sql = builder.to_sql();
3203
3204 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3206
3207 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3209 prop_assert!(!where_clause.contains(&malicious_input),
3210 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3211 }
3212
3213 #[test]
3214 fn prop_sql_injection_prevention_in_operator(
3215 table_name in table_name_strategy(),
3216 field in field_name_strategy(),
3217 malicious_values in prop::collection::vec(".*[';].*", 1..5)
3218 ) {
3219 let pool = create_test_pool_sync();
3220 let builder = QueryBuilder::new(&pool, &table_name, false)
3221 .where_in(&field, malicious_values.clone());
3222
3223 let sql = builder.to_sql();
3224
3225 prop_assert!(sql.contains("IN"), "SQL 应该包含 IN 操作符");
3227 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3228
3229 let placeholder_count = sql.matches("?").count();
3231 prop_assert!(placeholder_count >= malicious_values.len(),
3232 "每个 IN 值都应该有对应的参数占位符");
3233
3234 for malicious_value in &malicious_values {
3236 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3237 prop_assert!(!where_clause.contains(malicious_value),
3238 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3239 }
3240 }
3241
3242 #[test]
3243 fn prop_sql_injection_prevention_like_operator(
3244 table_name in table_name_strategy(),
3245 field in field_name_strategy(),
3246 malicious_pattern in ".*[';].*"
3247 ) {
3248 let pool = create_test_pool_sync();
3249 let builder = QueryBuilder::new(&pool, &table_name, false)
3250 .where_and(&field, "like", malicious_pattern.as_str());
3251
3252 let sql = builder.to_sql();
3253
3254 prop_assert!(sql.contains("LIKE"), "SQL 应该包含 LIKE 操作符");
3256 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3257
3258 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3260 prop_assert!(!where_clause.contains(&malicious_pattern),
3261 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3262 }
3263
3264 #[test]
3265 fn prop_sql_injection_prevention_between_operator(
3266 table_name in table_name_strategy(),
3267 field in field_name_strategy(),
3268 malicious_start in ".*[';].*",
3269 malicious_end in ".*[';].*"
3270 ) {
3271 let pool = create_test_pool_sync();
3272 let builder = QueryBuilder::new(&pool, &table_name, false)
3273 .where_between(&field, malicious_start.as_str(), malicious_end.as_str());
3274
3275 let sql = builder.to_sql();
3276
3277 prop_assert!(sql.contains("BETWEEN"), "SQL 应该包含 BETWEEN 操作符");
3279 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
3280
3281 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
3283 let placeholder_count = where_clause.matches("?").count();
3284 prop_assert!(placeholder_count >= 2, "BETWEEN 应该有两个参数占位符");
3285
3286 prop_assert!(!where_clause.contains(&malicious_start),
3288 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3289 prop_assert!(!where_clause.contains(&malicious_end),
3290 "WHERE 子句不应该直接包含用户输入的恶意字符串");
3291 }
3292 }
3293
3294 proptest! {
3297 #![proptest_config(ProptestConfig::with_cases(100))]
3298
3299 #[test]
3300 fn prop_find_adds_limit_one(
3301 table_name in table_name_strategy(),
3302 field in field_name_strategy(),
3303 value in any::<i32>()
3304 ) {
3305 let pool = create_test_pool_sync();
3306
3307 let builder = QueryBuilder::new(&pool, &table_name, false)
3309 .field(&field)
3310 .where_and(&field, "=", value)
3311 .limit(1); let sql = builder.to_sql();
3314
3315 prop_assert!(sql.contains("LIMIT 1"),
3317 "find() 方法应该自动添加 LIMIT 1 到查询中");
3318 }
3319 }
3320
3321 proptest! {
3329 #![proptest_config(ProptestConfig::with_cases(100))]
3330
3331 #[test]
3332 fn prop_count_aggregation_function(
3333 table_name in table_name_strategy()
3334 ) {
3335 let pool = create_test_pool_sync();
3336
3337 let builder = QueryBuilder::new(&pool, &table_name, false)
3340 .field("COUNT(*)");
3341
3342 let sql = builder.to_sql();
3343
3344 prop_assert!(
3346 sql.contains("COUNT(*)") || sql.contains("COUNT("),
3347 "count() 方法应该生成包含 COUNT(*) 或 COUNT(field) 的 SQL 语句,实际 SQL: {}",
3348 sql
3349 );
3350
3351 prop_assert!(
3353 sql.to_uppercase().contains("SELECT"),
3354 "count() 方法应该生成 SELECT 语句,实际 SQL: {}",
3355 sql
3356 );
3357
3358 prop_assert!(
3360 sql.contains(&format!("FROM {}", table_name)),
3361 "count() 方法应该包含正确的表名,实际 SQL: {}",
3362 sql
3363 );
3364 }
3365 }
3366
3367 proptest! {
3375 #![proptest_config(ProptestConfig::with_cases(100))]
3376
3377 #[test]
3378 fn prop_count_with_where_condition(
3379 table_name in table_name_strategy(),
3380 field_name in field_name_strategy(),
3381 field_value in 1i32..1000i32,
3382 ) {
3383 let pool = create_test_pool_sync();
3384
3385 let builder = QueryBuilder::new(&pool, &table_name, false)
3387 .where_and(&field_name, "=", field_value)
3388 .field("COUNT(*)");
3389
3390 let sql = builder.to_sql();
3391
3392 prop_assert!(
3394 sql.contains("COUNT(*)"),
3395 "带条件的 count() 查询应该包含 COUNT(*),实际 SQL: {}",
3396 sql
3397 );
3398
3399 prop_assert!(
3401 sql.to_uppercase().contains("WHERE"),
3402 "带条件的 count() 查询应该包含 WHERE 子句,实际 SQL: {}",
3403 sql
3404 );
3405
3406 prop_assert!(
3408 sql.contains(&format!("FROM {}", table_name)),
3409 "count() 方法应该包含正确的表名,实际 SQL: {}",
3410 sql
3411 );
3412 }
3413 }
3414
3415 proptest! {
3423 #![proptest_config(ProptestConfig::with_cases(100))]
3424
3425 #[test]
3426 fn prop_count_specific_field(
3427 table_name in table_name_strategy(),
3428 field_name in field_name_strategy(),
3429 ) {
3430 let pool = create_test_pool_sync();
3431
3432 let count_expr = format!("COUNT({})", field_name);
3434 let builder = QueryBuilder::new(&pool, &table_name, false)
3435 .field(&count_expr);
3436
3437 let sql = builder.to_sql();
3438
3439 prop_assert!(
3441 sql.contains(&count_expr),
3442 "COUNT 特定字段应该包含 COUNT(field_name),实际 SQL: {}",
3443 sql
3444 );
3445
3446 prop_assert!(
3448 sql.to_uppercase().contains("SELECT"),
3449 "COUNT 查询应该是 SELECT 语句,实际 SQL: {}",
3450 sql
3451 );
3452
3453 prop_assert!(
3455 sql.contains(&format!("FROM {}", table_name)),
3456 "COUNT 查询应该包含正确的表名,实际 SQL: {}",
3457 sql
3458 );
3459 }
3460 }
3461
3462 proptest! {
3470 #![proptest_config(ProptestConfig::with_cases(100))]
3471
3472 #[test]
3473 fn prop_sum_aggregation_function(
3474 table_name in table_name_strategy(),
3475 field in field_name_strategy()
3476 ) {
3477 let pool = create_test_pool_sync();
3478
3479 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", field);
3482 let builder = QueryBuilder::new(&pool, &table_name, false)
3483 .field(&sum_expr);
3484
3485 let sql = builder.to_sql();
3486
3487 prop_assert!(
3489 sql.contains("SUM("),
3490 "sum() 方法应该生成包含 SUM(field) 的 SQL 语句,实际 SQL: {}",
3491 sql
3492 );
3493
3494 prop_assert!(
3496 sql.contains(&field),
3497 "sum() 方法生成的 SQL 应该包含指定的字段名 {},实际 SQL: {}",
3498 field,
3499 sql
3500 );
3501
3502 prop_assert!(
3504 sql.to_uppercase().contains("SELECT"),
3505 "sum() 方法应该生成 SELECT 语句,实际 SQL: {}",
3506 sql
3507 );
3508
3509 prop_assert!(
3511 sql.contains(&format!("FROM {}", table_name)),
3512 "sum() 方法应该包含正确的表名,实际 SQL: {}",
3513 sql
3514 );
3515
3516 prop_assert!(
3518 sql.to_uppercase().contains("CAST"),
3519 "sum() 方法应该使用 CAST 转换结果为 DOUBLE,实际 SQL: {}",
3520 sql
3521 );
3522 }
3523 }
3524
3525 proptest! {
3533 #![proptest_config(ProptestConfig::with_cases(100))]
3534
3535 #[test]
3536 fn prop_sum_with_where_condition(
3537 table_name in table_name_strategy(),
3538 sum_field in field_name_strategy(),
3539 where_field in field_name_strategy(),
3540 where_value in 1i32..1000i32,
3541 ) {
3542 prop_assume!(sum_field != where_field);
3544
3545 let pool = create_test_pool_sync();
3546
3547 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
3549 let builder = QueryBuilder::new(&pool, &table_name, false)
3550 .where_and(&where_field, "=", where_value)
3551 .field(&sum_expr);
3552
3553 let sql = builder.to_sql();
3554
3555 prop_assert!(
3557 sql.contains("SUM("),
3558 "带条件的 sum() 查询应该包含 SUM(field),实际 SQL: {}",
3559 sql
3560 );
3561
3562 prop_assert!(
3564 sql.contains(&sum_field),
3565 "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
3566 sum_field,
3567 sql
3568 );
3569
3570 prop_assert!(
3572 sql.to_uppercase().contains("WHERE"),
3573 "带条件的 sum() 查询应该包含 WHERE 子句,实际 SQL: {}",
3574 sql
3575 );
3576
3577 prop_assert!(
3579 sql.contains(&format!("FROM {}", table_name)),
3580 "sum() 方法应该包含正确的表名,实际 SQL: {}",
3581 sql
3582 );
3583 }
3584 }
3585
3586 proptest! {
3594 #![proptest_config(ProptestConfig::with_cases(100))]
3595
3596 #[test]
3597 fn prop_sum_with_multiple_conditions(
3598 table_name in table_name_strategy(),
3599 sum_field in field_name_strategy(),
3600 where_field1 in field_name_strategy(),
3601 where_field2 in field_name_strategy(),
3602 value1 in 1i32..1000i32,
3603 value2 in 1i32..1000i32,
3604 ) {
3605 prop_assume!(sum_field != where_field1);
3607 prop_assume!(sum_field != where_field2);
3608 prop_assume!(where_field1 != where_field2);
3609
3610 let pool = create_test_pool_sync();
3611
3612 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
3614 let builder = QueryBuilder::new(&pool, &table_name, false)
3615 .where_and(&where_field1, "=", value1)
3616 .where_and(&where_field2, ">", value2)
3617 .field(&sum_expr);
3618
3619 let sql = builder.to_sql();
3620
3621 prop_assert!(
3623 sql.contains("SUM("),
3624 "多条件 sum() 查询应该包含 SUM(field),实际 SQL: {}",
3625 sql
3626 );
3627
3628 prop_assert!(
3630 sql.contains(&sum_field),
3631 "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
3632 sum_field,
3633 sql
3634 );
3635
3636 prop_assert!(
3638 sql.to_uppercase().contains("WHERE"),
3639 "多条件查询应该包含 WHERE 子句,实际 SQL: {}",
3640 sql
3641 );
3642
3643 prop_assert!(
3645 sql.to_uppercase().contains(" AND "),
3646 "多个 where_and 条件应该用 AND 连接,实际 SQL: {}",
3647 sql
3648 );
3649 }
3650 }
3651}