1use crate::mysql::condition::{Condition, SqlValue};
2use crate::mysql::field::{FieldType, JoinClause, OrderClause};
3use sqlx::mysql::MySqlPool;
4use std::collections::HashMap;
5
/// Binds one `SqlValue` onto a query-like expression by calling its `bind` method.
///
/// `$query` is any expression that exposes `bind(..)`; `$param` is matched against the
/// `SqlValue` variants and the macro evaluates to the result of the `bind` call.
macro_rules! bind_value_match {
    ($query:expr, $param:expr) => {
        match $param {
            // SQL NULL is bound as a typed `None`; `i32` is only the carrier type used here.
            SqlValue::Null => $query.bind(Option::<i32>::None),
            SqlValue::Bool(b) => $query.bind(*b),
            SqlValue::Int(i) => $query.bind(*i),
            SqlValue::Float(f) => $query.bind(*f),
            // Owned payloads are cloned before being handed to `bind`.
            SqlValue::String(s) => $query.bind(s.clone()),
            SqlValue::Bytes(b) => $query.bind(b.clone()),
            // JSON is bound in its serialized text form.
            SqlValue::Json(j) => $query.bind(j.to_string()),
            SqlValue::DateTime(dt) => $query.bind(*dt),
            SqlValue::Timestamp(ts) => $query.bind(*ts),
        }
    };
}
42
/// Record-count threshold used by `insert_batch`: inputs with at most this many records
/// are passed to `insert_chunk` in a single call.
const INSERT_BATCH_SIZE: usize = 500;

/// Batch size for batched updates.
/// NOTE(review): no use of this constant is visible in this part of the file — confirm
/// against the batch-update execution path.
const UPDATE_BATCH_SIZE: usize = 1000;
51
/// Accumulates the text of one SQL statement together with its positional parameters.
#[allow(dead_code)]
pub(crate) struct SqlGenerator {
    /// SQL text built so far; values appear as `?` placeholders.
    sql: String,
    /// Values for the `?` placeholders, in the order they were added.
    params: Vec<SqlValue>,
}
60
61#[allow(dead_code)]
62impl SqlGenerator {
63 pub(crate) fn new() -> Self {
69 Self {
70 sql: String::with_capacity(256),
72 params: Vec::with_capacity(8),
74 }
75 }
76
77 pub(crate) fn get_sql(&self) -> &str {
79 &self.sql
80 }
81
82 pub(crate) fn get_params(&self) -> &[SqlValue] {
84 &self.params
85 }
86
87 fn append(&mut self, fragment: &str) {
89 self.sql.push_str(fragment);
90 }
91
92 fn add_param(&mut self, param: SqlValue) {
94 self.params.push(param);
95 }
96
97 fn clear(&mut self) {
99 self.sql.clear();
100 self.params.clear();
101 }
102
103 #[cfg(test)]
105 pub(crate) fn clear_for_test(&mut self) {
106 self.clear();
107 }
108
109 #[cfg(test)]
111 pub(crate) fn sql_capacity(&self) -> usize {
112 self.sql.capacity()
113 }
114
115 #[cfg(test)]
117 pub(crate) fn params_capacity(&self) -> usize {
118 self.params.capacity()
119 }
120
121 fn build_select(&mut self, builder: &QueryBuilder) -> Result<(), crate::error::DbError> {
130 self.clear();
132
133 self.append("SELECT ");
135
136 if builder.distinct {
138 self.append("DISTINCT ");
139 }
140
141 if builder.fields.is_empty() {
143 self.append("*");
144 } else {
145 self.append(&builder.fields.join(", "));
146 }
147
148 self.append(" FROM ");
150 self.append(&builder.table);
151
152 if !builder.joins.is_empty() {
154 self.build_joins(&builder.joins);
155 }
156
157 if !builder.conditions.is_empty() {
159 self.build_where(&builder.conditions)?;
160 }
161
162 if !builder.group_by.is_empty() {
164 self.build_group_by(&builder.group_by);
165 }
166
167 if !builder.having_clause.is_empty() {
169 if builder.group_by.is_empty() {
170 return Err(crate::error::DbError::MissingGroupByClause);
171 }
172 self.build_having(&builder.having_clause)?;
173 }
174
175 if !builder.order_by.is_empty() {
177 self.build_order_by(&builder.order_by);
178 }
179
180 if let Some(limit) = builder.limit {
182 self.append(&format!(" LIMIT {}", limit));
183 }
184
185 if let Some(offset) = builder.offset {
187 self.append(&format!(" OFFSET {}", offset));
188 }
189
190 Ok(())
191 }
192
193 fn build_where(&mut self, conditions: &[Condition]) -> Result<(), crate::error::DbError> {
202 if conditions.is_empty() {
203 return Ok(());
204 }
205
206 self.append(" WHERE ");
207
208 if conditions.len() == 1 {
210 let sql = crate::mysql::condition::condition_to_sql(&conditions[0], &mut self.params);
211 self.append(&sql);
212 } else {
213 let combined = Condition::And(conditions.to_vec());
215 let sql = crate::mysql::condition::condition_to_sql(&combined, &mut self.params);
216 self.append(&sql);
217 }
218
219 Ok(())
220 }
221
222 fn build_joins(&mut self, joins: &[JoinClause]) {
227 use crate::mysql::field::JoinType;
228
229 for join in joins {
230 let join_type_str = match join.join_type {
231 JoinType::Inner => " INNER JOIN ",
232 JoinType::Left => " LEFT JOIN ",
233 JoinType::Right => " RIGHT JOIN ",
234 };
235
236 self.append(join_type_str);
237 self.append(&join.table);
238 self.append(" ON ");
239 self.append(&join.on);
240 }
241 }
242
243 fn build_order_by(&mut self, orders: &[OrderClause]) {
248 if orders.is_empty() {
249 return;
250 }
251
252 self.append(" ORDER BY ");
253
254 let order_parts: Vec<String> = orders
255 .iter()
256 .map(|order| {
257 let direction = if order.asc { "ASC" } else { "DESC" };
258 format!("{} {}", order.field, direction)
259 })
260 .collect();
261
262 self.append(&order_parts.join(", "));
263 }
264
265 fn build_group_by(&mut self, groups: &[String]) {
270 if groups.is_empty() {
271 return;
272 }
273
274 self.append(" GROUP BY ");
275 self.append(&groups.join(", "));
276 }
277
278 fn build_having(&mut self, conditions: &[Condition]) -> Result<(), crate::error::DbError> {
280 self.append(" HAVING ");
281 if conditions.len() == 1 {
282 let sql = crate::mysql::condition::condition_to_sql(&conditions[0], &mut self.params);
283 self.append(&sql);
284 } else {
285 let parts: Vec<String> = conditions
286 .iter()
287 .map(|c| crate::mysql::condition::condition_to_sql(c, &mut self.params))
288 .collect();
289 self.append(&parts.join(" AND "));
290 }
291 Ok(())
292 }
293
294 pub(crate) fn build_insert(
305 &mut self,
306 table: &str,
307 data: &serde_json::Value,
308 field_types: &HashMap<String, FieldType>,
309 ) -> Result<(), crate::error::DbError> {
310 self.clear();
312
313 let obj = data.as_object().ok_or_else(|| {
315 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
316 })?;
317
318 if obj.is_empty() {
319 return Err(crate::error::DbError::SerializationError(
320 "插入数据不能为空".to_string(),
321 ));
322 }
323
324 let mut fields = Vec::new();
326 let mut placeholders = Vec::new();
327
328 for (key, value) in obj.iter() {
329 fields.push(key.clone());
330 placeholders.push("?".to_string());
331
332 let sql_value = self.json_value_to_sql_value(value, field_types.get(key))?;
334 self.add_param(sql_value);
335 }
336
337 self.append("INSERT INTO ");
339 self.append(table);
340 self.append(" (");
341 self.append(&fields.join(", "));
342 self.append(") VALUES (");
343 self.append(&placeholders.join(", "));
344 self.append(")");
345
346 Ok(())
347 }
348
349 pub(crate) fn build_insert_batch(
360 &mut self,
361 table: &str,
362 data_list: &[serde_json::Value],
363 field_types: &HashMap<String, FieldType>,
364 ) -> Result<(), crate::error::DbError> {
365 self.clear();
367
368 if data_list.is_empty() {
369 return Err(crate::error::DbError::SerializationError(
370 "批量插入数据不能为空".to_string(),
371 ));
372 }
373
374 let first_obj = data_list[0].as_object().ok_or_else(|| {
376 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
377 })?;
378
379 if first_obj.is_empty() {
380 return Err(crate::error::DbError::SerializationError(
381 "插入数据不能为空".to_string(),
382 ));
383 }
384
385 let fields: Vec<String> = first_obj.keys().cloned().collect();
387
388 self.append("INSERT INTO ");
390 self.append(table);
391 self.append(" (");
392 self.append(&fields.join(", "));
393 self.append(") VALUES ");
394
395 for (record_idx, data) in data_list.iter().enumerate() {
397 if record_idx > 0 {
399 self.sql.push_str(", ");
400 }
401
402 let obj = data.as_object().ok_or_else(|| {
403 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
404 })?;
405
406 self.sql.push('(');
408
409 for (field_idx, field) in fields.iter().enumerate() {
411 if field_idx > 0 {
413 self.sql.push_str(", ");
414 }
415 self.sql.push('?');
416
417 let value = obj.get(field).unwrap_or(&serde_json::Value::Null);
419
420 let sql_value = self.json_value_to_sql_value(value, field_types.get(field))?;
422 self.add_param(sql_value);
423 }
424
425 self.sql.push(')');
427 }
428
429 Ok(())
430 }
431
432 pub(crate) fn build_update(
444 &mut self,
445 table: &str,
446 data: &serde_json::Value,
447 field_types: &HashMap<String, FieldType>,
448 conditions: &[Condition],
449 ) -> Result<(), crate::error::DbError> {
450 self.clear();
452
453 if conditions.is_empty() {
455 return Err(crate::error::DbError::MissingWhereClause);
456 }
457
458 let obj = data.as_object().ok_or_else(|| {
460 crate::error::DbError::SerializationError("更新数据必须是 JSON 对象".to_string())
461 })?;
462
463 if obj.is_empty() {
464 return Err(crate::error::DbError::SerializationError(
465 "更新数据不能为空".to_string(),
466 ));
467 }
468
469 self.append("UPDATE ");
471 self.append(table);
472 self.append(" SET ");
473
474 let mut set_clauses = Vec::new();
476
477 for (key, value) in obj.iter() {
478 set_clauses.push(format!("{} = ?", key));
479
480 let sql_value = self.json_value_to_sql_value(value, field_types.get(key))?;
482 self.add_param(sql_value);
483 }
484
485 self.append(&set_clauses.join(", "));
486
487 self.build_where(conditions)?;
489
490 Ok(())
491 }
492
493 pub(crate) fn build_delete(
503 &mut self,
504 table: &str,
505 conditions: &[Condition],
506 ) -> Result<(), crate::error::DbError> {
507 self.clear();
509
510 if conditions.is_empty() {
512 return Err(crate::error::DbError::MissingWhereClause);
513 }
514
515 self.append("DELETE FROM ");
517 self.append(table);
518
519 self.build_where(conditions)?;
521
522 Ok(())
523 }
524
525 pub(crate) fn build_update_batch(
530 &mut self,
531 table: &str,
532 records: &[serde_json::Value],
533 id_field: &str,
534 field_types: &std::collections::HashMap<String, FieldType>,
535 ) -> Result<(), crate::error::DbError> {
536 self.clear();
537
538 if records.is_empty() {
539 return Err(crate::error::DbError::SerializationError(
540 "批量更新数据不能为空".to_string(),
541 ));
542 }
543
544 let first = records[0].as_object().ok_or_else(|| {
545 crate::error::DbError::SerializationError("更新数据必须是 JSON 对象".to_string())
546 })?;
547
548 let update_fields: Vec<String> = first
550 .keys()
551 .filter(|k| k.as_str() != id_field)
552 .cloned()
553 .collect();
554
555 if update_fields.is_empty() {
556 return Err(crate::error::DbError::SerializationError(
557 "没有可更新的字段".to_string(),
558 ));
559 }
560
561 self.sql.push_str("UPDATE ");
563 self.sql.push_str(table);
564 self.sql.push_str(" SET ");
565
566 for (field_idx, field) in update_fields.iter().enumerate() {
569 if field_idx > 0 {
571 self.sql.push_str(", ");
572 }
573
574 self.sql.push_str(field);
576 self.sql.push_str(" = CASE ");
577
578 for record in records {
580 let id_val = record.get(id_field).unwrap_or(&serde_json::Value::Null);
581 let field_val = record.get(field.as_str()).unwrap_or(&serde_json::Value::Null);
582
583 let id_sql_val = self.json_value_to_sql_value(id_val, field_types.get(id_field))?;
585 let field_sql_val =
586 self.json_value_to_sql_value(field_val, field_types.get(field.as_str()))?;
587
588 self.sql.push_str("WHEN ");
590 self.sql.push_str(id_field);
591 self.sql.push_str("=? THEN ? ");
592
593 self.params.push(id_sql_val);
595 self.params.push(field_sql_val);
596 }
597
598 self.sql.push_str("END");
600 }
601
602 self.sql.push_str(" WHERE ");
604 self.sql.push_str(id_field);
605 self.sql.push_str(" IN (");
606
607 for (idx, record) in records.iter().enumerate() {
609 if idx > 0 {
610 self.sql.push(',');
611 }
612 self.sql.push('?');
613
614 let id_val = record.get(id_field).unwrap_or(&serde_json::Value::Null);
616 let id_sql_val = self.json_value_to_sql_value(id_val, field_types.get(id_field))?;
617 self.params.push(id_sql_val);
618 }
619
620 self.sql.push(')');
621
622 Ok(())
623 }
624
625 pub(crate) fn build_upsert(
627 &mut self,
628 table: &str,
629 data: &serde_json::Value,
630 field_types: &std::collections::HashMap<String, FieldType>,
631 ) -> Result<(), crate::error::DbError> {
632 self.clear();
633
634 let obj = data.as_object().ok_or_else(|| {
635 crate::error::DbError::SerializationError("插入数据必须是 JSON 对象".to_string())
636 })?;
637
638 if obj.is_empty() {
639 return Err(crate::error::DbError::SerializationError(
640 "插入数据不能为空".to_string(),
641 ));
642 }
643
644 let fields: Vec<String> = obj.keys().cloned().collect();
645 let placeholders: Vec<&str> = fields.iter().map(|_| "?").collect();
646
647 self.append(&format!(
648 "INSERT INTO {} ({}) VALUES ({})",
649 table,
650 fields.join(", "),
651 placeholders.join(", ")
652 ));
653
654 for field in &fields {
655 let val = obj.get(field.as_str()).unwrap_or(&serde_json::Value::Null);
656 self.add_param(self.json_value_to_sql_value(val, field_types.get(field.as_str()))?);
657 }
658
659 let update_parts: Vec<String> = fields
660 .iter()
661 .map(|f| format!("{}=VALUES({})", f, f))
662 .collect();
663
664 self.append(&format!(
665 " ON DUPLICATE KEY UPDATE {}",
666 update_parts.join(", ")
667 ));
668
669 Ok(())
670 }
671
672 fn json_value_to_sql_value(
681 &self,
682 value: &serde_json::Value,
683 field_type: Option<&FieldType>,
684 ) -> Result<SqlValue, crate::error::DbError> {
685 use serde_json::Value;
686
687 if let Some(ft) = field_type {
689 match ft {
690 FieldType::Json => {
691 return Ok(SqlValue::Json(value.clone()));
693 }
694 FieldType::DateTime => {
695 if let Some(s) = value.as_str() {
697 let dt = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S")
698 .map_err(|e| {
699 crate::error::DbError::TypeConversionError(format!(
700 "无法解析 DATETIME 字符串: {}",
701 e
702 ))
703 })?;
704 return Ok(SqlValue::DateTime(dt));
705 }
706 }
707 FieldType::Timestamp => {
708 if let Some(i) = value.as_i64() {
710 return Ok(SqlValue::Timestamp(i));
711 }
712 }
713 FieldType::Decimal => {
714 if let Some(f) = value.as_f64() {
716 return Ok(SqlValue::Float(f));
717 } else if let Some(i) = value.as_i64() {
718 return Ok(SqlValue::Float(i as f64));
719 }
720 }
721 FieldType::Blob => {
722 if let Some(s) = value.as_str() {
724 use base64::Engine;
726 if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(s) {
727 return Ok(SqlValue::Bytes(bytes));
728 }
729 return Ok(SqlValue::Bytes(s.as_bytes().to_vec()));
731 }
732 }
733 FieldType::Text => {
734 if let Some(s) = value.as_str() {
736 return Ok(SqlValue::String(s.to_string()));
737 }
738 }
739 FieldType::Standard => {
740 }
742 }
743 }
744
745 match value {
747 Value::Null => Ok(SqlValue::Null),
748 Value::Bool(b) => Ok(SqlValue::Bool(*b)),
749 Value::Number(n) => {
750 if let Some(i) = n.as_i64() {
751 Ok(SqlValue::Int(i))
752 } else if let Some(f) = n.as_f64() {
753 Ok(SqlValue::Float(f))
754 } else {
755 Err(crate::error::DbError::TypeConversionError(
756 "无法转换数字类型".to_string(),
757 ))
758 }
759 }
760 Value::String(s) => Ok(SqlValue::String(s.clone())),
761 Value::Array(_) | Value::Object(_) => {
762 Ok(SqlValue::Json(value.clone()))
764 }
765 }
766 }
767}
768
/// Chainable builder that describes one query against a single table and runs it on a
/// borrowed MySQL connection pool.
pub struct QueryBuilder<'a> {
    /// Connection pool the terminal methods (`find`, `select`, `insert`, ...) execute on.
    #[allow(dead_code)]
    pool: &'a MySqlPool,
    /// Name of the table the query targets.
    table: String,
    /// Selected column expressions; empty means `*`.
    fields: Vec<String>,
    /// WHERE conditions accumulated by the `where_*` methods.
    #[allow(dead_code)]
    conditions: Vec<Condition>,
    /// JOIN clauses in the order they were added.
    #[allow(dead_code)]
    joins: Vec<JoinClause>,
    /// ORDER BY entries in the order they were added.
    #[allow(dead_code)]
    order_by: Vec<OrderClause>,
    /// GROUP BY columns in the order they were added.
    #[allow(dead_code)]
    group_by: Vec<String>,
    /// HAVING conditions accumulated by `having_cond`.
    #[allow(dead_code)]
    having_clause: Vec<Condition>,
    /// Optional LIMIT value.
    limit: Option<u64>,
    /// Optional OFFSET value.
    offset: Option<u64>,
    /// Whether `SELECT DISTINCT` is emitted.
    distinct: bool,
    /// Per-column type hints registered via `json`, `datetime`, `timestamp`, `decimal`,
    /// `blob` and `text`; passed to the SQL generator when writing data.
    field_types: HashMap<String, FieldType>,
    /// When true, the terminal methods emit `log::debug!` traces of SQL and parameters.
    #[allow(dead_code)]
    enable_logging: bool,
}
792
793impl<'a> QueryBuilder<'a> {
794 pub(crate) fn new(pool: &'a MySqlPool, table_name: &str, enable_logging: bool) -> Self {
796 Self {
797 pool,
798 table: table_name.to_string(),
799 fields: Vec::new(),
800 conditions: Vec::new(),
801 joins: Vec::new(),
802 order_by: Vec::new(),
803 group_by: Vec::new(),
804 having_clause: Vec::new(),
805 limit: None,
806 offset: None,
807 distinct: false,
808 field_types: HashMap::new(),
809 enable_logging,
810 }
811 }
812
813 pub fn field(mut self, field: &str) -> Self {
815 self.fields.push(field.to_string());
816 self
817 }
818
819 pub fn fields(mut self, fields: &[&str]) -> Self {
821 for field in fields {
822 self.fields.push(field.to_string());
823 }
824 self
825 }
826
827 pub fn json(mut self, field: &str) -> Self {
829 self.field_types.insert(field.to_string(), FieldType::Json);
830 self
831 }
832
833 pub fn datetime(mut self, field: &str) -> Self {
835 self.field_types
836 .insert(field.to_string(), FieldType::DateTime);
837 self
838 }
839
840 pub fn timestamp(mut self, field: &str) -> Self {
842 self.field_types
843 .insert(field.to_string(), FieldType::Timestamp);
844 self
845 }
846
847 pub fn decimal(mut self, field: &str) -> Self {
849 self.field_types
850 .insert(field.to_string(), FieldType::Decimal);
851 self
852 }
853
854 pub fn blob(mut self, field: &str) -> Self {
856 self.field_types.insert(field.to_string(), FieldType::Blob);
857 self
858 }
859
860 pub fn text(mut self, field: &str) -> Self {
862 self.field_types.insert(field.to_string(), FieldType::Text);
863 self
864 }
865
866 pub fn distinct(mut self) -> Self {
868 self.distinct = true;
869 self
870 }
871
872 pub fn where_and<V>(mut self, field: &str, op: &str, value: V) -> Result<Self, crate::error::DbError>
886 where
887 V: Into<crate::mysql::condition::SqlValue>,
888 {
889 use crate::mysql::condition::{Condition, SqlValue};
890
891 let sql_value = value.into();
892 let condition = match op {
893 "=" => Condition::Eq(field.to_string(), sql_value),
894 "!=" => Condition::Ne(field.to_string(), sql_value),
895 ">" => Condition::Gt(field.to_string(), sql_value),
896 "<" => Condition::Lt(field.to_string(), sql_value),
897 ">=" => Condition::Gte(field.to_string(), sql_value),
898 "<=" => Condition::Lte(field.to_string(), sql_value),
899 "like" | "LIKE" => {
900 if let SqlValue::String(s) = sql_value {
901 Condition::Like(field.to_string(), s)
902 } else {
903 Condition::Like(field.to_string(), format!("{:?}", sql_value))
905 }
906 }
907 _ => return Err(crate::error::DbError::UnsupportedOperator(op.to_string())),
909 };
910
911 self.conditions.push(condition);
912 Ok(self)
913 }
914
915 pub fn where_and_unchecked<V>(self, field: &str, op: &str, value: V) -> Self
925 where
926 V: Into<crate::mysql::condition::SqlValue>,
927 {
928 self.where_and(field, op, value)
930 .unwrap_or_else(|e| panic!("{}", e))
931 }
932
933 pub fn where_or<V>(mut self, field: &str, op: &str, value: V) -> Result<Self, crate::error::DbError>
947 where
948 V: Into<crate::mysql::condition::SqlValue>,
949 {
950 use crate::mysql::condition::{Condition, SqlValue};
951
952 let sql_value = value.into();
953 let condition = match op {
954 "=" => Condition::Eq(field.to_string(), sql_value),
955 "!=" => Condition::Ne(field.to_string(), sql_value),
956 ">" => Condition::Gt(field.to_string(), sql_value),
957 "<" => Condition::Lt(field.to_string(), sql_value),
958 ">=" => Condition::Gte(field.to_string(), sql_value),
959 "<=" => Condition::Lte(field.to_string(), sql_value),
960 "like" | "LIKE" => {
961 if let SqlValue::String(s) = sql_value {
962 Condition::Like(field.to_string(), s)
963 } else {
964 Condition::Like(field.to_string(), format!("{:?}", sql_value))
965 }
966 }
967 _ => return Err(crate::error::DbError::UnsupportedOperator(op.to_string())),
969 };
970
971 if !self.conditions.is_empty() {
973 let existing = std::mem::take(&mut self.conditions);
974 self.conditions.push(Condition::Or(vec![
975 if existing.len() == 1 {
976 existing.into_iter().next().unwrap()
977 } else {
978 Condition::And(existing)
979 },
980 condition,
981 ]));
982 } else {
983 self.conditions.push(condition);
984 }
985
986 Ok(self)
987 }
988
989 pub fn where_or_unchecked<V>(self, field: &str, op: &str, value: V) -> Self
999 where
1000 V: Into<crate::mysql::condition::SqlValue>,
1001 {
1002 self.where_or(field, op, value)
1004 .unwrap_or_else(|e| panic!("{}", e))
1005 }
1006
1007 pub fn where_in<V>(mut self, field: &str, values: Vec<V>) -> Self
1009 where
1010 V: Into<crate::mysql::condition::SqlValue>,
1011 {
1012 use crate::mysql::condition::Condition;
1013
1014 let sql_values: Vec<_> = values.into_iter().map(|v| v.into()).collect();
1015 self.conditions
1016 .push(Condition::In(field.to_string(), sql_values));
1017 self
1018 }
1019
1020 pub fn where_between<V>(mut self, field: &str, start: V, end: V) -> Self
1022 where
1023 V: Into<crate::mysql::condition::SqlValue>,
1024 {
1025 use crate::mysql::condition::Condition;
1026
1027 self.conditions.push(Condition::Between(
1028 field.to_string(),
1029 start.into(),
1030 end.into(),
1031 ));
1032 self
1033 }
1034
1035 pub fn where_null(mut self, field: &str) -> Self {
1056 self.conditions.push(Condition::IsNull(field.to_string()));
1057 self
1058 }
1059
1060 pub fn where_not_null(mut self, field: &str) -> Self {
1067 self.conditions
1068 .push(Condition::IsNotNull(field.to_string()));
1069 self
1070 }
1071
1072 pub fn having_cond<V>(mut self, field: &str, op: &str, value: V) -> Result<Self, crate::error::DbError>
1100 where
1101 V: Into<crate::mysql::condition::SqlValue>,
1102 {
1103 let sql_value = value.into();
1104 let condition = match op {
1105 "=" => Condition::Eq(field.to_string(), sql_value),
1106 "!=" => Condition::Ne(field.to_string(), sql_value),
1107 ">" => Condition::Gt(field.to_string(), sql_value),
1108 "<" => Condition::Lt(field.to_string(), sql_value),
1109 ">=" => Condition::Gte(field.to_string(), sql_value),
1110 "<=" => Condition::Lte(field.to_string(), sql_value),
1111 _ => return Err(crate::error::DbError::UnsupportedOperator(op.to_string())),
1113 };
1114 self.having_clause.push(condition);
1115 Ok(self)
1116 }
1117
1118 pub fn having_cond_unchecked<V>(self, field: &str, op: &str, value: V) -> Self
1128 where
1129 V: Into<crate::mysql::condition::SqlValue>,
1130 {
1131 self.having_cond(field, op, value)
1133 .unwrap_or_else(|e| panic!("{}", e))
1134 }
1135
1136 pub fn join(mut self, table: &str, on: &str) -> Self {
1138 use crate::mysql::field::{JoinClause, JoinType};
1139
1140 self.joins.push(JoinClause {
1141 join_type: JoinType::Inner,
1142 table: table.to_string(),
1143 on: on.to_string(),
1144 });
1145 self
1146 }
1147
1148 pub fn left_join(mut self, table: &str, on: &str) -> Self {
1150 use crate::mysql::field::{JoinClause, JoinType};
1151
1152 self.joins.push(JoinClause {
1153 join_type: JoinType::Left,
1154 table: table.to_string(),
1155 on: on.to_string(),
1156 });
1157 self
1158 }
1159
1160 pub fn right_join(mut self, table: &str, on: &str) -> Self {
1162 use crate::mysql::field::{JoinClause, JoinType};
1163
1164 self.joins.push(JoinClause {
1165 join_type: JoinType::Right,
1166 table: table.to_string(),
1167 on: on.to_string(),
1168 });
1169 self
1170 }
1171
1172 pub fn order(mut self, field: &str, asc: bool) -> Self {
1174 use crate::mysql::field::OrderClause;
1175
1176 self.order_by.push(OrderClause {
1177 field: field.to_string(),
1178 asc,
1179 });
1180 self
1181 }
1182
1183 pub fn group(mut self, field: &str) -> Self {
1185 self.group_by.push(field.to_string());
1186 self
1187 }
1188
1189 pub fn limit(mut self, limit: u64) -> Self {
1191 self.limit = Some(limit);
1192 self
1193 }
1194
1195 pub fn offset(mut self, offset: u64) -> Self {
1197 self.offset = Some(offset);
1198 self
1199 }
1200
1201 pub fn to_sql(&self) -> String {
1206 let mut generator = SqlGenerator::new();
1207
1208 match generator.build_select(self) {
1210 Ok(_) => generator.get_sql().to_string(),
1211 Err(_) => {
1212 let fields_str = if self.fields.is_empty() {
1214 "*".to_string()
1215 } else {
1216 self.fields.join(", ")
1217 };
1218
1219 let distinct_str = if self.distinct { "DISTINCT " } else { "" };
1220
1221 format!("SELECT {}{} FROM {}", distinct_str, fields_str, self.table)
1222 }
1223 }
1224 }
1225
1226 pub async fn find<T>(mut self) -> Result<Option<T>, crate::error::DbError>
1264 where
1265 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
1266 {
1267 self.limit = Some(1);
1269
1270 let mut generator = SqlGenerator::new();
1272 generator.build_select(&self)?;
1273
1274 let sql = generator.get_sql();
1275 let params = generator.get_params();
1276
1277 if self.enable_logging {
1279 log::debug!("执行 find() 查询: {}", sql);
1280 log::debug!("参数: {:?}", params);
1281 }
1282
1283 let mut query = sqlx::query_as::<_, T>(sql);
1285
1286 for param in params {
1288 query = bind_param(query, param);
1289 }
1290
1291 let result = query.fetch_optional(self.pool).await;
1293
1294 match result {
1295 Ok(row) => {
1296 if self.enable_logging {
1297 if row.is_some() {
1298 log::debug!("find() 查询成功,返回 1 条记录");
1299 } else {
1300 log::debug!("find() 查询成功,未找到匹配记录");
1301 }
1302 }
1303 Ok(row)
1304 }
1305 Err(e) => {
1306 log::error!("find() 查询失败: {}", e);
1307 Err(crate::error::DbError::from(e))
1308 }
1309 }
1310 }
1311
1312 pub async fn select<T>(self) -> Result<Vec<T>, crate::error::DbError>
1350 where
1351 T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
1352 {
1353 let mut generator = SqlGenerator::new();
1355 generator.build_select(&self)?;
1356
1357 let sql = generator.get_sql();
1358 let params = generator.get_params();
1359
1360 if self.enable_logging {
1362 log::debug!("执行 select() 查询: {}", sql);
1363 log::debug!("参数: {:?}", params);
1364 }
1365
1366 let mut query = sqlx::query_as::<_, T>(sql);
1368
1369 for param in params {
1371 query = bind_param(query, param);
1372 }
1373
1374 let result = query.fetch_all(self.pool).await;
1376
1377 match result {
1378 Ok(rows) => {
1379 if self.enable_logging {
1380 log::debug!("select() 查询成功,返回 {} 条记录", rows.len());
1381 }
1382 Ok(rows)
1383 }
1384 Err(e) => {
1385 log::error!("select() 查询失败: {}", e);
1386 Err(crate::error::DbError::from(e))
1387 }
1388 }
1389 }
1390
1391 pub async fn value<T>(mut self, field: &str) -> Result<Option<T>, crate::error::DbError>
1435 where
1436 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1437 {
1438 self.fields.clear();
1440 self.fields.push(field.to_string());
1441
1442 self.limit = Some(1);
1444
1445 let mut generator = SqlGenerator::new();
1447 generator.build_select(&self)?;
1448
1449 let sql = generator.get_sql();
1450 let params = generator.get_params();
1451
1452 if self.enable_logging {
1454 log::debug!("执行 value() 查询: {}", sql);
1455 log::debug!("参数: {:?}", params);
1456 }
1457
1458 let mut query = sqlx::query_scalar::<_, T>(sql);
1460
1461 for param in params {
1463 query = bind_scalar_param(query, param);
1464 }
1465
1466 let result = query.fetch_optional(self.pool).await;
1468
1469 match result {
1470 Ok(value) => {
1471 if self.enable_logging {
1472 if value.is_some() {
1473 log::debug!("value() 查询成功,返回字段值");
1474 } else {
1475 log::debug!("value() 查询成功,未找到匹配记录");
1476 }
1477 }
1478 Ok(value)
1479 }
1480 Err(e) => {
1481 log::error!("value() 查询失败: {}", e);
1482 Err(crate::error::DbError::from(e))
1483 }
1484 }
1485 }
1486
1487 pub async fn count(self) -> Result<i64, crate::error::DbError> {
1518 if self.enable_logging {
1520 log::debug!("执行 count() 查询");
1521 }
1522
1523 let result = self.value::<i64>("COUNT(*)").await?;
1525
1526 Ok(result.unwrap_or(0))
1528 }
1529
1530 pub async fn sum(self, field: &str) -> Result<Option<f64>, crate::error::DbError> {
1574 if self.enable_logging {
1576 log::debug!("执行 sum() 查询,字段: {}", field);
1577 }
1578
1579 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", field);
1582
1583 let mut builder = self;
1585 builder.fields.clear();
1586 builder.fields.push(sum_expr.clone());
1587
1588 builder.limit = Some(1);
1590
1591 let mut generator = SqlGenerator::new();
1593 generator.build_select(&builder)?;
1594
1595 let sql = generator.get_sql();
1596 let params = generator.get_params();
1597
1598 if builder.enable_logging {
1600 log::debug!("执行 sum() 查询: {}", sql);
1601 log::debug!("参数: {:?}", params);
1602 }
1603
1604 let mut query = sqlx::query_scalar::<_, Option<f64>>(sql);
1606
1607 for param in params {
1609 query = bind_scalar_param_option(query, param);
1610 }
1611
1612 let result = query.fetch_optional(builder.pool).await;
1614
1615 match result {
1616 Ok(Some(value)) => {
1617 if builder.enable_logging {
1619 if value.is_some() {
1620 log::debug!("sum() 查询成功,返回总和");
1621 } else {
1622 log::debug!("sum() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
1623 }
1624 }
1625 Ok(value)
1626 }
1627 Ok(None) => {
1628 if builder.enable_logging {
1630 log::debug!("sum() 查询成功,未找到匹配记录");
1631 }
1632 Ok(None)
1633 }
1634 Err(e) => {
1635 log::error!("sum() 查询失败: {}", e);
1636 Err(crate::error::DbError::from(e))
1637 }
1638 }
1639 }
1640
1641 pub async fn avg(self, field: &str) -> Result<Option<f64>, crate::error::DbError> {
1689 if self.enable_logging {
1691 log::debug!("执行 avg() 查询,字段: {}", field);
1692 }
1693
1694 let avg_expr = format!("CAST(AVG({}) AS DOUBLE)", field);
1697
1698 let mut builder = self;
1700 builder.fields.clear();
1701 builder.fields.push(avg_expr.clone());
1702
1703 builder.limit = Some(1);
1705
1706 let mut generator = SqlGenerator::new();
1708 generator.build_select(&builder)?;
1709
1710 let sql = generator.get_sql();
1711 let params = generator.get_params();
1712
1713 if builder.enable_logging {
1715 log::debug!("执行 avg() 查询: {}", sql);
1716 log::debug!("参数: {:?}", params);
1717 }
1718
1719 let mut query = sqlx::query_scalar::<_, Option<f64>>(sql);
1721
1722 for param in params {
1724 query = bind_scalar_param_option(query, param);
1725 }
1726
1727 let result = query.fetch_optional(builder.pool).await;
1729
1730 match result {
1731 Ok(Some(value)) => {
1732 if builder.enable_logging {
1734 if value.is_some() {
1735 log::debug!("avg() 查询成功,返回平均值");
1736 } else {
1737 log::debug!("avg() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
1738 }
1739 }
1740 Ok(value)
1741 }
1742 Ok(None) => {
1743 if builder.enable_logging {
1745 log::debug!("avg() 查询成功,未找到匹配记录");
1746 }
1747 Ok(None)
1748 }
1749 Err(e) => {
1750 log::error!("avg() 查询失败: {}", e);
1751 Err(crate::error::DbError::from(e))
1752 }
1753 }
1754 }
1755
1756 pub async fn min<T>(self, field: &str) -> Result<Option<T>, crate::error::DbError>
1819 where
1820 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1821 {
1822 if self.enable_logging {
1824 log::debug!("执行 min() 查询,字段: {}", field);
1825 }
1826
1827 let min_expr = format!("MIN({})", field);
1829
1830 let mut builder = self;
1832 builder.fields.clear();
1833 builder.fields.push(min_expr.clone());
1834
1835 builder.limit = Some(1);
1837
1838 let mut generator = SqlGenerator::new();
1840 generator.build_select(&builder)?;
1841
1842 let sql = generator.get_sql();
1843 let params = generator.get_params();
1844
1845 if builder.enable_logging {
1847 log::debug!("执行 min() 查询: {}", sql);
1848 log::debug!("参数: {:?}", params);
1849 }
1850
1851 let mut query = sqlx::query_scalar::<_, Option<T>>(sql);
1853
1854 for param in params {
1856 query = bind_scalar_param_option(query, param);
1857 }
1858
1859 let result = query.fetch_optional(builder.pool).await;
1861
1862 match result {
1863 Ok(Some(value)) => {
1864 if builder.enable_logging {
1866 if value.is_some() {
1867 log::debug!("min() 查询成功,返回最小值");
1868 } else {
1869 log::debug!("min() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
1870 }
1871 }
1872 Ok(value)
1873 }
1874 Ok(None) => {
1875 if builder.enable_logging {
1877 log::debug!("min() 查询成功,未找到匹配记录");
1878 }
1879 Ok(None)
1880 }
1881 Err(e) => {
1882 log::error!("min() 查询失败: {}", e);
1883 Err(crate::error::DbError::from(e))
1884 }
1885 }
1886 }
1887
1888 pub async fn max<T>(self, field: &str) -> Result<Option<T>, crate::error::DbError>
1951 where
1952 T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
1953 {
1954 if self.enable_logging {
1956 log::debug!("执行 max() 查询,字段: {}", field);
1957 }
1958
1959 let max_expr = format!("MAX({})", field);
1961
1962 let mut builder = self;
1964 builder.fields.clear();
1965 builder.fields.push(max_expr.clone());
1966
1967 builder.limit = Some(1);
1969
1970 let mut generator = SqlGenerator::new();
1972 generator.build_select(&builder)?;
1973
1974 let sql = generator.get_sql();
1975 let params = generator.get_params();
1976
1977 if builder.enable_logging {
1979 log::debug!("执行 max() 查询: {}", sql);
1980 log::debug!("参数: {:?}", params);
1981 }
1982
1983 let mut query = sqlx::query_scalar::<_, Option<T>>(sql);
1985
1986 for param in params {
1988 query = bind_scalar_param_option(query, param);
1989 }
1990
1991 let result = query.fetch_optional(builder.pool).await;
1993
1994 match result {
1995 Ok(Some(value)) => {
1996 if builder.enable_logging {
1998 if value.is_some() {
1999 log::debug!("max() 查询成功,返回最大值");
2000 } else {
2001 log::debug!("max() 查询成功,返回 None(没有匹配记录或所有值为 NULL)");
2002 }
2003 }
2004 Ok(value)
2005 }
2006 Ok(None) => {
2007 if builder.enable_logging {
2009 log::debug!("max() 查询成功,未找到匹配记录");
2010 }
2011 Ok(None)
2012 }
2013 Err(e) => {
2014 log::error!("max() 查询失败: {}", e);
2015 Err(crate::error::DbError::from(e))
2016 }
2017 }
2018 }
2019
2020 pub async fn insert<T>(self, data: &T) -> Result<u64, crate::error::DbError>
2073 where
2074 T: serde::Serialize,
2075 {
2076 if self.enable_logging {
2078 log::debug!("执行 insert() 操作,表: {}", self.table);
2079 }
2080
2081 let json_data = serde_json::to_value(data).map_err(|e| {
2083 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
2084 })?;
2085
2086 let mut generator = SqlGenerator::new();
2088 generator.build_insert(&self.table, &json_data, &self.field_types)?;
2089
2090 let sql = generator.get_sql();
2091 let params = generator.get_params();
2092
2093 if self.enable_logging {
2095 log::debug!("执行 insert() SQL: {}", sql);
2096 log::debug!("参数: {:?}", params);
2097 }
2098
2099 let mut query = sqlx::query(sql);
2101
2102 for param in params {
2104 query = bind_execute_param(query, param);
2105 }
2106
2107 let result = query.execute(self.pool).await;
2109
2110 match result {
2111 Ok(query_result) => {
2112 let last_insert_id = query_result.last_insert_id();
2113 if self.enable_logging {
2114 log::debug!("insert() 成功,插入 ID: {}", last_insert_id);
2115 }
2116 Ok(last_insert_id)
2117 }
2118 Err(e) => {
2119 log::error!("insert() 失败: {}", e);
2120 Err(crate::error::DbError::from(e))
2121 }
2122 }
2123 }
2124
2125 pub async fn insert_batch<T>(self, data: &[T]) -> Result<u64, crate::error::DbError>
2193 where
2194 T: serde::Serialize,
2195 {
2196 if self.enable_logging {
2198 log::debug!(
2199 "执行 insert_batch() 操作,表: {},记录数: {}",
2200 self.table,
2201 data.len()
2202 );
2203 }
2204
2205 if data.is_empty() {
2207 return Err(crate::error::DbError::SerializationError(
2208 "批量插入数据不能为空".to_string(),
2209 ));
2210 }
2211
2212 if data.len() <= INSERT_BATCH_SIZE {
2214 return self.insert_chunk(data).await;
2215 }
2216
2217 let mut total_affected = 0u64;
2219
2220 for (batch_index, chunk) in data.chunks(INSERT_BATCH_SIZE).enumerate() {
2222 if self.enable_logging {
2223 log::debug!(
2224 "执行第 {} 批插入,本批记录数: {}",
2225 batch_index + 1,
2226 chunk.len()
2227 );
2228 }
2229
2230 let chunk_builder = QueryBuilder {
2233 pool: self.pool,
2234 table: self.table.clone(),
2235 fields: self.fields.clone(),
2236 conditions: self.conditions.clone(),
2237 joins: self.joins.clone(),
2238 order_by: self.order_by.clone(),
2239 group_by: self.group_by.clone(),
2240 having_clause: self.having_clause.clone(),
2241 limit: self.limit,
2242 offset: self.offset,
2243 distinct: self.distinct,
2244 field_types: self.field_types.clone(),
2245 enable_logging: self.enable_logging,
2246 };
2247
2248 let affected = chunk_builder.insert_chunk(chunk).await?;
2249 total_affected += affected;
2250
2251 if self.enable_logging {
2252 log::debug!("第 {} 批插入成功,影响 {} 行", batch_index + 1, affected);
2253 }
2254 }
2255
2256 if self.enable_logging {
2257 log::debug!("insert_batch() 全部完成,总共影响 {} 行", total_affected);
2258 }
2259
2260 Ok(total_affected)
2261 }
2262
2263 pub async fn insert_batch_with_size<T>(
2303 self,
2304 data: &[T],
2305 batch_size: usize,
2306 ) -> Result<u64, crate::error::DbError>
2307 where
2308 T: serde::Serialize,
2309 {
2310 if batch_size == 0 {
2312 return Err(crate::error::DbError::SerializationError(
2313 "batch_size 不能为 0".to_string(),
2314 ));
2315 }
2316
2317 if self.enable_logging {
2319 log::debug!(
2320 "执行 insert_batch_with_size() 操作,表: {},记录数: {},批次大小: {}",
2321 self.table,
2322 data.len(),
2323 batch_size
2324 );
2325 }
2326
2327 if data.is_empty() {
2329 return Err(crate::error::DbError::SerializationError(
2330 "批量插入数据不能为空".to_string(),
2331 ));
2332 }
2333
2334 let mut total_affected = 0u64;
2336
2337 for (batch_index, chunk) in data.chunks(batch_size).enumerate() {
2338 if self.enable_logging {
2339 log::debug!(
2340 "执行第 {} 批插入,本批记录数: {}",
2341 batch_index + 1,
2342 chunk.len()
2343 );
2344 }
2345
2346 let chunk_builder = QueryBuilder {
2348 pool: self.pool,
2349 table: self.table.clone(),
2350 fields: self.fields.clone(),
2351 conditions: self.conditions.clone(),
2352 joins: self.joins.clone(),
2353 order_by: self.order_by.clone(),
2354 group_by: self.group_by.clone(),
2355 having_clause: self.having_clause.clone(),
2356 limit: self.limit,
2357 offset: self.offset,
2358 distinct: self.distinct,
2359 field_types: self.field_types.clone(),
2360 enable_logging: self.enable_logging,
2361 };
2362
2363 let affected = chunk_builder.insert_chunk(chunk).await?;
2365 total_affected += affected;
2366
2367 if self.enable_logging {
2368 log::debug!("第 {} 批插入成功,影响 {} 行", batch_index + 1, affected);
2369 }
2370 }
2371
2372 if self.enable_logging {
2373 log::debug!(
2374 "insert_batch_with_size() 全部完成,总共影响 {} 行",
2375 total_affected
2376 );
2377 }
2378
2379 Ok(total_affected)
2380 }
2381
2382 async fn insert_chunk<T>(&self, data: &[T]) -> Result<u64, crate::error::DbError>
2397 where
2398 T: serde::Serialize,
2399 {
2400 let json_data_list: Result<Vec<_>, _> = data
2402 .iter()
2403 .map(|item| {
2404 serde_json::to_value(item).map_err(|e| {
2405 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
2406 })
2407 })
2408 .collect();
2409
2410 let json_data_list = json_data_list?;
2411
2412 let mut generator = SqlGenerator::new();
2414 generator.build_insert_batch(&self.table, &json_data_list, &self.field_types)?;
2415
2416 let sql = generator.get_sql();
2417 let params = generator.get_params();
2418
2419 if self.enable_logging {
2421 log::debug!("执行 insert_chunk() SQL: {}", sql);
2422 log::debug!("参数数量: {}", params.len());
2423 }
2424
2425 let mut query = sqlx::query(sql);
2427
2428 for param in params {
2430 query = bind_execute_param(query, param);
2431 }
2432
2433 let result = query.execute(self.pool).await;
2435
2436 match result {
2437 Ok(query_result) => {
2438 let rows_affected = query_result.rows_affected();
2439 if self.enable_logging {
2440 log::debug!("insert_chunk() 成功,影响 {} 行", rows_affected);
2441 }
2442 Ok(rows_affected)
2443 }
2444 Err(e) => {
2445 log::error!("insert_chunk() 失败: {}", e);
2446 Err(crate::error::DbError::from(e))
2447 }
2448 }
2449 }
2450
2451 pub async fn update<T>(self, data: &T) -> Result<u64, crate::error::DbError>
2490 where
2491 T: serde::Serialize,
2492 {
2493 if self.enable_logging {
2495 log::debug!("执行 update() 操作,表: {}", self.table);
2496 }
2497
2498 if self.conditions.is_empty() {
2500 log::warn!("update() 操作缺少 WHERE 条件,禁止全表更新");
2501 return Err(crate::error::DbError::MissingWhereClause);
2502 }
2503
2504 let json_data = serde_json::to_value(data).map_err(|e| {
2506 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
2507 })?;
2508
2509 let mut generator = SqlGenerator::new();
2511 generator.build_update(&self.table, &json_data, &self.field_types, &self.conditions)?;
2512
2513 let sql = generator.get_sql();
2514 let params = generator.get_params();
2515
2516 if self.enable_logging {
2518 log::debug!("执行 update() SQL: {}", sql);
2519 log::debug!("参数: {:?}", params);
2520 }
2521
2522 let mut query = sqlx::query(sql);
2524
2525 for param in params {
2527 query = bind_execute_param(query, param);
2528 }
2529
2530 let result = query.execute(self.pool).await;
2532
2533 match result {
2534 Ok(query_result) => {
2535 let rows_affected = query_result.rows_affected();
2536 if self.enable_logging {
2537 log::debug!("update() 成功,影响 {} 行", rows_affected);
2538 }
2539 Ok(rows_affected)
2540 }
2541 Err(e) => {
2542 log::error!("update() 失败: {}", e);
2543 Err(crate::error::DbError::from(e))
2544 }
2545 }
2546 }
2547
2548 pub async fn delete(self) -> Result<u64, crate::error::DbError> {
2575 if self.enable_logging {
2577 log::debug!("执行 delete() 操作,表: {}", self.table);
2578 }
2579
2580 if self.conditions.is_empty() {
2582 log::warn!("delete() 操作缺少 WHERE 条件,禁止全表删除");
2583 return Err(crate::error::DbError::MissingWhereClause);
2584 }
2585
2586 let mut generator = SqlGenerator::new();
2588 generator.build_delete(&self.table, &self.conditions)?;
2589
2590 let sql = generator.get_sql();
2591 let params = generator.get_params();
2592
2593 if self.enable_logging {
2595 log::debug!("执行 delete() SQL: {}", sql);
2596 log::debug!("参数: {:?}", params);
2597 }
2598
2599 let mut query = sqlx::query(sql);
2601
2602 for param in params {
2604 query = bind_execute_param(query, param);
2605 }
2606
2607 let result = query.execute(self.pool).await;
2609
2610 match result {
2611 Ok(query_result) => {
2612 let rows_affected = query_result.rows_affected();
2613 if self.enable_logging {
2614 log::debug!("delete() 成功,影响 {} 行", rows_affected);
2615 }
2616 Ok(rows_affected)
2617 }
2618 Err(e) => {
2619 log::error!("delete() 失败: {}", e);
2620 Err(crate::error::DbError::from(e))
2621 }
2622 }
2623 }
2624
2625 pub async fn update_batch<T>(
2655 self,
2656 records: &[T],
2657 where_field: &str,
2658 ) -> Result<u64, crate::error::DbError>
2659 where
2660 T: serde::Serialize,
2661 {
2662 if records.is_empty() {
2663 return Err(crate::error::DbError::SerializationError(
2664 "批量更新数据不能为空".to_string(),
2665 ));
2666 }
2667
2668 let json_records: Vec<serde_json::Value> = records
2669 .iter()
2670 .map(|r| {
2671 serde_json::to_value(r).map_err(|e| {
2672 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
2673 })
2674 })
2675 .collect::<Result<_, _>>()?;
2676
2677 let mut tx = self
2678 .pool
2679 .begin()
2680 .await
2681 .map_err(crate::error::DbError::from)?;
2682 let mut total = 0u64;
2683
2684 for chunk in json_records.chunks(UPDATE_BATCH_SIZE) {
2685 let mut generator = SqlGenerator::new();
2686 generator.build_update_batch(&self.table, chunk, where_field, &self.field_types)?;
2687
2688 let sql = generator.get_sql();
2689 let params = generator.get_params();
2690
2691 let mut query = sqlx::query(sql);
2692 for param in params {
2693 query = bind_execute_param(query, param);
2694 }
2695
2696 let result = query
2697 .execute(&mut *tx)
2698 .await
2699 .map_err(crate::error::DbError::from)?;
2700 total += result.rows_affected();
2701 }
2702
2703 tx.commit().await.map_err(crate::error::DbError::from)?;
2704 Ok(total)
2705 }
2706
2707 pub async fn upsert<T>(self, data: &T) -> Result<u64, crate::error::DbError>
2732 where
2733 T: serde::Serialize,
2734 {
2735 if self.enable_logging {
2736 log::debug!("执行 upsert() 操作,表: {}", self.table);
2737 }
2738
2739 let json_data = serde_json::to_value(data).map_err(|e| {
2740 crate::error::DbError::SerializationError(format!("数据序列化失败: {}", e))
2741 })?;
2742
2743 let mut generator = SqlGenerator::new();
2744 generator.build_upsert(&self.table, &json_data, &self.field_types)?;
2745
2746 let sql = generator.get_sql();
2747 let params = generator.get_params();
2748
2749 if self.enable_logging {
2750 log::debug!("执行 upsert() SQL: {}", sql);
2751 }
2752
2753 let mut query = sqlx::query(sql);
2754 for param in params {
2755 query = bind_execute_param(query, param);
2756 }
2757
2758 let result = query
2759 .execute(self.pool)
2760 .await
2761 .map_err(crate::error::DbError::from)?;
2762 let rows = result.rows_affected();
2763
2764 if self.enable_logging {
2765 log::debug!("upsert() 完成,rows_affected: {}", rows);
2766 }
2767
2768 Ok(rows)
2769 }
2770}
2771
/// Binds one `SqlValue` to a plain `sqlx::query` statement (INSERT/UPDATE/DELETE style
/// execution) and returns the query with the argument appended.
///
/// The per-variant dispatch lives in `bind_value_match!`: `Null` is bound as
/// `Option::<i32>::None`, `String`/`Bytes` are cloned, `Json` is bound as its string
/// form, and the remaining variants are bound by value.
fn bind_execute_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    param: &SqlValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    // Shared match over every `SqlValue` variant.
    bind_value_match!(query, param)
}
2787
/// Binds one `SqlValue` to a `query_as`-style query that decodes rows into `T`, and
/// returns the query with the argument appended.
///
/// The per-variant dispatch lives in `bind_value_match!`.
fn bind_param<'q, T>(
    query: sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
    param: &SqlValue,
) -> sqlx::query::QueryAs<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
where
    T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> + Send + Unpin,
{
    // Shared match over every `SqlValue` variant.
    bind_value_match!(query, param)
}
2806
/// Binds one `SqlValue` to a scalar query whose single column decodes into `T`, and
/// returns the query with the argument appended.
///
/// The per-variant dispatch lives in `bind_value_match!`.
fn bind_scalar_param<'q, T>(
    query: sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
    param: &SqlValue,
) -> sqlx::query::QueryScalar<'q, sqlx::MySql, T, sqlx::mysql::MySqlArguments>
where
    T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
{
    // Shared match over every `SqlValue` variant.
    bind_value_match!(query, param)
}
2825
/// Binds one `SqlValue` to a scalar query whose single column decodes into
/// `Option<T>` (used by the aggregate helpers, whose result may be NULL), and returns
/// the query with the argument appended.
///
/// The per-variant dispatch lives in `bind_value_match!`.
fn bind_scalar_param_option<'q, T>(
    query: sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>,
    param: &SqlValue,
) -> sqlx::query::QueryScalar<'q, sqlx::MySql, Option<T>, sqlx::mysql::MySqlArguments>
where
    T: for<'r> sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql> + Send + Unpin,
{
    // Shared match over every `SqlValue` variant.
    bind_value_match!(query, param)
}
2844
2845#[cfg(test)]
2846mod tests {
2847 use super::*;
2848 use sqlx::mysql::MySqlPoolOptions;
2849
2850 async fn create_test_pool() -> MySqlPool {
2852 MySqlPoolOptions::new()
2853 .max_connections(1)
2854 .connect("mysql://root:111111@localhost:3306/test")
2855 .await
2856 .expect("无法连接到测试数据库")
2857 }
2858
2859 #[tokio::test]
2860 async fn test_table_name_in_sql() {
2861 let pool = create_test_pool().await;
2862 let builder = QueryBuilder::new(&pool, "users", false);
2863 let sql = builder.to_sql();
2864 assert!(sql.contains("FROM users"));
2865 }
2866
/// A new generator starts with empty SQL text and no parameters.
#[test]
fn test_sql_generator_new() {
    let fresh = SqlGenerator::new();
    assert!(fresh.get_params().is_empty());
    assert_eq!(fresh.get_sql(), "");
}
2874
/// `append` stores the fragment verbatim in the SQL buffer.
#[test]
fn test_sql_generator_append() {
    let fragment = "SELECT * FROM users";
    let mut sql_gen = SqlGenerator::new();
    sql_gen.append(fragment);
    assert_eq!(sql_gen.get_sql(), "SELECT * FROM users");
}
2881
/// Every `add_param` call grows the parameter list by one entry.
#[test]
fn test_sql_generator_add_param() {
    let mut sql_gen = SqlGenerator::new();
    let values = vec![SqlValue::Int(42), SqlValue::String("test".to_string())];
    for value in values {
        sql_gen.add_param(value);
    }
    assert_eq!(sql_gen.get_params().len(), 2);
}
2889
/// `clear` wipes both the SQL text and the collected parameters.
#[test]
fn test_sql_generator_clear() {
    let mut sql_gen = SqlGenerator::new();
    sql_gen.add_param(SqlValue::Int(1));
    sql_gen.append("SELECT * FROM users");

    sql_gen.clear();

    assert!(sql_gen.get_params().is_empty());
    assert_eq!(sql_gen.get_sql(), "");
}
2901
/// Interleaved `append` / `add_param` calls accumulate in call order.
#[test]
fn test_sql_generator_multiple_operations() {
    let mut sql_gen = SqlGenerator::new();

    sql_gen.append("SELECT * FROM users WHERE id = ?");
    sql_gen.add_param(SqlValue::Int(1));

    sql_gen.append(" AND name = ?");
    sql_gen.add_param(SqlValue::String("test".to_string()));

    let expected_sql = "SELECT * FROM users WHERE id = ? AND name = ?";
    assert_eq!(sql_gen.get_sql(), expected_sql);
    assert_eq!(sql_gen.get_params().len(), 2);
}
2917
2918 #[tokio::test]
2919 async fn test_field_selection() {
2920 let pool = create_test_pool().await;
2921 let builder = QueryBuilder::new(&pool, "users", false)
2922 .field("id")
2923 .field("name");
2924 let sql = builder.to_sql();
2925 assert!(sql.contains("id, name"));
2926 }
2927
2928 #[tokio::test]
2929 async fn test_fields_selection() {
2930 let pool = create_test_pool().await;
2931 let builder = QueryBuilder::new(&pool, "users", false).fields(&["id", "name", "email"]);
2932 let sql = builder.to_sql();
2933 assert!(sql.contains("id, name, email"));
2934 }
2935
2936 #[tokio::test]
2937 async fn test_distinct() {
2938 let pool = create_test_pool().await;
2939 let builder = QueryBuilder::new(&pool, "users", false)
2940 .field("name")
2941 .distinct();
2942 let sql = builder.to_sql();
2943 assert!(sql.contains("SELECT DISTINCT"));
2944 }
2945
2946 #[tokio::test]
2947 async fn test_field_type_marking() {
2948 let pool = create_test_pool().await;
2949 let builder = QueryBuilder::new(&pool, "users", false)
2950 .json("data")
2951 .datetime("created_at")
2952 .timestamp("updated_at")
2953 .decimal("price")
2954 .blob("content")
2955 .text("description");
2956
2957 assert_eq!(builder.field_types.get("data"), Some(&FieldType::Json));
2958 assert_eq!(
2959 builder.field_types.get("created_at"),
2960 Some(&FieldType::DateTime)
2961 );
2962 assert_eq!(
2963 builder.field_types.get("updated_at"),
2964 Some(&FieldType::Timestamp)
2965 );
2966 assert_eq!(builder.field_types.get("price"), Some(&FieldType::Decimal));
2967 assert_eq!(builder.field_types.get("content"), Some(&FieldType::Blob));
2968 assert_eq!(
2969 builder.field_types.get("description"),
2970 Some(&FieldType::Text)
2971 );
2972 }
2973
2974 #[tokio::test]
2975 async fn test_where_and() {
2976 let pool = create_test_pool().await;
2977 let builder = QueryBuilder::new(&pool, "users", false)
2978 .where_and_unchecked("name", "=", "test")
2979 .where_and_unchecked("age", ">", 18);
2980
2981 assert_eq!(builder.conditions.len(), 2);
2982 }
2983
2984 #[tokio::test]
2985 async fn test_where_or() {
2986 let pool = create_test_pool().await;
2987 let builder = QueryBuilder::new(&pool, "users", false)
2988 .where_or_unchecked("status", "=", 1)
2989 .where_or_unchecked("status", "=", 2);
2990
2991 assert_eq!(builder.conditions.len(), 1);
2993 }
2994
2995 #[tokio::test]
2996 async fn test_where_in() {
2997 let pool = create_test_pool().await;
2998 let builder = QueryBuilder::new(&pool, "users", false).where_in("id", vec![1, 2, 3]);
2999
3000 assert_eq!(builder.conditions.len(), 1);
3001 }
3002
3003 #[tokio::test]
3004 async fn test_where_between() {
3005 let pool = create_test_pool().await;
3006 let builder = QueryBuilder::new(&pool, "users", false).where_between("age", 18, 65);
3007
3008 assert_eq!(builder.conditions.len(), 1);
3009 }
3010
3011 #[tokio::test]
3012 async fn test_join() {
3013 let pool = create_test_pool().await;
3014 let builder =
3015 QueryBuilder::new(&pool, "users", false).join("orders", "users.id = orders.user_id");
3016
3017 assert_eq!(builder.joins.len(), 1);
3018 }
3019
3020 #[tokio::test]
3021 async fn test_left_join() {
3022 let pool = create_test_pool().await;
3023 let builder = QueryBuilder::new(&pool, "users", false)
3024 .left_join("orders", "users.id = orders.user_id");
3025
3026 assert_eq!(builder.joins.len(), 1);
3027 }
3028
3029 #[tokio::test]
3030 async fn test_right_join() {
3031 let pool = create_test_pool().await;
3032 let builder = QueryBuilder::new(&pool, "users", false)
3033 .right_join("orders", "users.id = orders.user_id");
3034
3035 assert_eq!(builder.joins.len(), 1);
3036 }
3037
3038 #[tokio::test]
3039 async fn test_order() {
3040 let pool = create_test_pool().await;
3041 let builder = QueryBuilder::new(&pool, "users", false)
3042 .order("name", true)
3043 .order("age", false);
3044
3045 assert_eq!(builder.order_by.len(), 2);
3046 }
3047
3048 #[tokio::test]
3049 async fn test_group() {
3050 let pool = create_test_pool().await;
3051 let builder = QueryBuilder::new(&pool, "users", false)
3052 .group("status")
3053 .group("role");
3054
3055 assert_eq!(builder.group_by.len(), 2);
3056 }
3057
3058 #[tokio::test]
3060 async fn test_select_with_where() {
3061 let pool = create_test_pool().await;
3062 let builder = QueryBuilder::new(&pool, "users", false)
3063 .field("id")
3064 .field("name")
3065 .where_and_unchecked("status", "=", 1);
3066
3067 let sql = builder.to_sql();
3068 assert!(sql.contains("SELECT id, name FROM users"));
3069 assert!(sql.contains("WHERE"));
3070 }
3071
3072 #[tokio::test]
3073 async fn test_select_with_join() {
3074 let pool = create_test_pool().await;
3075 let builder = QueryBuilder::new(&pool, "users", false)
3076 .field("users.id")
3077 .field("orders.total")
3078 .join("orders", "users.id = orders.user_id");
3079
3080 let sql = builder.to_sql();
3081 assert!(sql.contains("SELECT users.id, orders.total FROM users"));
3082 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3083 }
3084
3085 #[tokio::test]
3086 async fn test_select_with_order_by() {
3087 let pool = create_test_pool().await;
3088 let builder = QueryBuilder::new(&pool, "users", false)
3089 .field("name")
3090 .order("name", true)
3091 .order("age", false);
3092
3093 let sql = builder.to_sql();
3094 assert!(sql.contains("ORDER BY name ASC, age DESC"));
3095 }
3096
3097 #[tokio::test]
3098 async fn test_select_with_group_by() {
3099 let pool = create_test_pool().await;
3100 let builder = QueryBuilder::new(&pool, "users", false)
3101 .field("status")
3102 .group("status");
3103
3104 let sql = builder.to_sql();
3105 assert!(sql.contains("GROUP BY status"));
3106 }
3107
3108 #[tokio::test]
3109 async fn test_select_with_limit_offset() {
3110 let pool = create_test_pool().await;
3111 let builder = QueryBuilder::new(&pool, "users", false)
3112 .field("id")
3113 .limit(10)
3114 .offset(20);
3115
3116 let sql = builder.to_sql();
3117 assert!(sql.contains("LIMIT 10"));
3118 assert!(sql.contains("OFFSET 20"));
3119 }
3120
3121 #[tokio::test]
3122 async fn test_select_complex_query() {
3123 let pool = create_test_pool().await;
3124 let builder = QueryBuilder::new(&pool, "users", false)
3125 .field("users.id")
3126 .field("users.name")
3127 .field("orders.total")
3128 .distinct()
3129 .join("orders", "users.id = orders.user_id")
3130 .where_and_unchecked("users.status", "=", 1)
3131 .where_and_unchecked("orders.total", ">", 100)
3132 .group("users.id")
3133 .order("orders.total", false)
3134 .limit(50);
3135
3136 let sql = builder.to_sql();
3137 assert!(sql.contains("SELECT DISTINCT"));
3138 assert!(sql.contains("users.id, users.name, orders.total"));
3139 assert!(sql.contains("FROM users"));
3140 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3141 assert!(sql.contains("WHERE"));
3142 assert!(sql.contains("GROUP BY users.id"));
3143 assert!(sql.contains("ORDER BY orders.total DESC"));
3144 assert!(sql.contains("LIMIT 50"));
3145 }
3146
3147 #[tokio::test]
3148 async fn test_select_with_multiple_joins() {
3149 let pool = create_test_pool().await;
3150 let builder = QueryBuilder::new(&pool, "users", false)
3151 .field("users.name")
3152 .field("orders.total")
3153 .field("products.name")
3154 .join("orders", "users.id = orders.user_id")
3155 .left_join("products", "orders.product_id = products.id");
3156
3157 let sql = builder.to_sql();
3158 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3159 assert!(sql.contains("LEFT JOIN products ON orders.product_id = products.id"));
3160 }
3161
3162 #[tokio::test]
3163 async fn test_select_with_in_condition() {
3164 let pool = create_test_pool().await;
3165 let builder = QueryBuilder::new(&pool, "users", false)
3166 .field("name")
3167 .where_in("id", vec![1, 2, 3, 4, 5]);
3168
3169 let sql = builder.to_sql();
3170 assert!(sql.contains("WHERE"));
3171 assert!(sql.contains("IN"));
3172 }
3173
3174 #[tokio::test]
3175 async fn test_select_with_between_condition() {
3176 let pool = create_test_pool().await;
3177 let builder = QueryBuilder::new(&pool, "users", false)
3178 .field("name")
3179 .where_between("age", 18, 65);
3180
3181 let sql = builder.to_sql();
3182 assert!(sql.contains("WHERE"));
3183 assert!(sql.contains("BETWEEN"));
3184 }
3185
3186 #[test]
3187 fn test_where_null_generates_is_null_sql() {
3188 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3189 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3190 let builder = QueryBuilder::new(pool, "users", false).where_null("deleted_at");
3191 let sql = builder.to_sql();
3192 assert!(sql.contains("deleted_at IS NULL"));
3193 }
3194
3195 #[test]
3196 fn test_where_not_null_generates_is_not_null_sql() {
3197 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3198 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3199 let builder = QueryBuilder::new(pool, "users", false).where_not_null("email");
3200 let sql = builder.to_sql();
3201 assert!(sql.contains("email IS NOT NULL"));
3202 }
3203
3204 #[test]
3205 fn test_is_null_with_and_condition() {
3206 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3207 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3208 let builder = QueryBuilder::new(pool, "users", false)
3209 .where_and_unchecked("status", "=", 1i64)
3210 .where_null("deleted_at");
3211 let sql = builder.to_sql();
3212 assert!(sql.contains("status = ?"));
3213 assert!(sql.contains("deleted_at IS NULL"));
3214 }
3215
3216 #[test]
3217 fn test_having_clause_sql_generation() {
3218 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3219 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3220 let builder = QueryBuilder::new(pool, "orders", false)
3221 .field("user_id")
3222 .field("COUNT(*) as cnt")
3223 .group("user_id")
3224 .having_cond_unchecked("cnt", ">", 5i64);
3225 let sql = builder.to_sql();
3226 assert!(sql.contains("HAVING"));
3227 assert!(sql.contains("cnt > ?"));
3228 }
3229
3230 #[test]
3231 fn test_having_without_group_returns_error() {
3232 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3233 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3234 let builder = QueryBuilder::new(pool, "orders", false).having_cond_unchecked("cnt", ">", 5i64);
3235 let mut generator = SqlGenerator::new();
3236 let result = generator.build_select(&builder);
3237 assert!(result.is_err());
3238 assert!(matches!(
3239 result.unwrap_err(),
3240 crate::DbError::MissingGroupByClause
3241 ));
3242 }
3243
3244 #[test]
3245 fn test_having_clause_order() {
3246 let pool_storage = std::mem::MaybeUninit::<MySqlPool>::uninit();
3247 let pool: &MySqlPool = unsafe { &*pool_storage.as_ptr() };
3248 let builder = QueryBuilder::new(pool, "orders", false)
3249 .group("user_id")
3250 .having_cond_unchecked("cnt", ">", 5i64)
3251 .order("cnt", false);
3252 let sql = builder.to_sql();
3253 let group_pos = sql.find("GROUP BY").unwrap();
3254 let having_pos = sql.find("HAVING").unwrap();
3255 let order_pos = sql.find("ORDER BY").unwrap();
3256 assert!(group_pos < having_pos);
3257 assert!(having_pos < order_pos);
3258 }
3259
/// A batch update renders `CASE WHEN` assignments plus a `WHERE ... IN (...)` filter.
#[test]
fn test_update_batch_case_when_sql() {
    let rows = vec![
        serde_json::json!({"id": 1, "name": "Alice", "age": 25}),
        serde_json::json!({"id": 2, "name": "Bob", "age": 30}),
    ];
    let no_type_hints = std::collections::HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    sql_gen
        .build_update_batch("users", &rows, "id", &no_type_hints)
        .unwrap();

    let rendered = sql_gen.get_sql();
    assert!(rendered.starts_with("UPDATE users SET "));
    assert!(rendered.contains("CASE WHEN id=? THEN ?"));
    assert!(rendered.contains("WHERE id IN ("));
}
3275
/// Building a batch update from zero records must fail.
#[test]
fn test_update_batch_empty_returns_error() {
    let no_rows: Vec<serde_json::Value> = Vec::new();
    let no_type_hints = std::collections::HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_update_batch("users", &no_rows, "id", &no_type_hints);

    assert!(outcome.is_err());
}
3288
/// An upsert renders an INSERT with an `ON DUPLICATE KEY UPDATE` tail.
#[test]
fn test_upsert_sql_generation() {
    let row = serde_json::json!({"id": 1, "name": "Alice", "email": "a@b.com"});
    let no_type_hints = std::collections::HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    sql_gen.build_upsert("users", &row, &no_type_hints).unwrap();

    let rendered = sql_gen.get_sql();
    assert!(rendered.starts_with("INSERT INTO users"));
    assert!(rendered.contains("ON DUPLICATE KEY UPDATE"));
    assert!(rendered.contains("name=VALUES(name)"));
}
3301
/// Building an upsert from an empty JSON object must fail.
#[test]
fn test_upsert_empty_data_returns_error() {
    let empty_row = serde_json::json!({});
    let no_type_hints = std::collections::HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_upsert("users", &empty_row, &no_type_hints);

    assert!(outcome.is_err());
}
3309
3310 #[tokio::test]
3312 async fn test_sql_generator_build_select_basic() {
3313 let pool = create_test_pool().await;
3314 let builder = QueryBuilder::new(&pool, "users", false)
3315 .field("id")
3316 .field("name");
3317
3318 let mut generator = SqlGenerator::new();
3319 let result = generator.build_select(&builder);
3320
3321 assert!(result.is_ok());
3322 assert_eq!(generator.get_sql(), "SELECT id, name FROM users");
3323 }
3324
3325 #[tokio::test]
3326 async fn test_sql_generator_build_select_with_distinct() {
3327 let pool = create_test_pool().await;
3328 let builder = QueryBuilder::new(&pool, "users", false)
3329 .field("name")
3330 .distinct();
3331
3332 let mut generator = SqlGenerator::new();
3333 let result = generator.build_select(&builder);
3334
3335 assert!(result.is_ok());
3336 assert_eq!(generator.get_sql(), "SELECT DISTINCT name FROM users");
3337 }
3338
3339 #[tokio::test]
3340 async fn test_sql_generator_build_select_all_fields() {
3341 let pool = create_test_pool().await;
3342 let builder = QueryBuilder::new(&pool, "users", false);
3343
3344 let mut generator = SqlGenerator::new();
3345 let result = generator.build_select(&builder);
3346
3347 assert!(result.is_ok());
3348 assert_eq!(generator.get_sql(), "SELECT * FROM users");
3349 }
3350
3351 #[tokio::test]
3353 async fn test_sql_generator_build_where() {
3354 let pool = create_test_pool().await;
3355 let builder = QueryBuilder::new(&pool, "users", false)
3356 .where_and_unchecked("status", "=", 1)
3357 .where_and_unchecked("age", ">", 18);
3358
3359 let mut generator = SqlGenerator::new();
3360 let result = generator.build_select(&builder);
3361
3362 assert!(result.is_ok());
3363 let sql = generator.get_sql();
3364 assert!(sql.contains("WHERE"));
3365 assert!(sql.contains("status"));
3366 assert!(sql.contains("age"));
3367 }
3368
3369 #[tokio::test]
3371 async fn test_sql_generator_build_joins() {
3372 let pool = create_test_pool().await;
3373 let builder = QueryBuilder::new(&pool, "users", false)
3374 .join("orders", "users.id = orders.user_id")
3375 .left_join("profiles", "users.id = profiles.user_id");
3376
3377 let mut generator = SqlGenerator::new();
3378 let result = generator.build_select(&builder);
3379
3380 assert!(result.is_ok());
3381 let sql = generator.get_sql();
3382 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3383 assert!(sql.contains("LEFT JOIN profiles ON users.id = profiles.user_id"));
3384 }
3385
3386 #[tokio::test]
3388 async fn test_sql_generator_build_order_by() {
3389 let pool = create_test_pool().await;
3390 let builder = QueryBuilder::new(&pool, "users", false)
3391 .order("name", true)
3392 .order("created_at", false);
3393
3394 let mut generator = SqlGenerator::new();
3395 let result = generator.build_select(&builder);
3396
3397 assert!(result.is_ok());
3398 let sql = generator.get_sql();
3399 assert!(sql.contains("ORDER BY name ASC, created_at DESC"));
3400 }
3401
3402 #[tokio::test]
3404 async fn test_sql_generator_build_group_by() {
3405 let pool = create_test_pool().await;
3406 let builder = QueryBuilder::new(&pool, "users", false)
3407 .group("status")
3408 .group("role");
3409
3410 let mut generator = SqlGenerator::new();
3411 let result = generator.build_select(&builder);
3412
3413 assert!(result.is_ok());
3414 let sql = generator.get_sql();
3415 assert!(sql.contains("GROUP BY status, role"));
3416 }
3417
3418 #[tokio::test]
3420 async fn test_sql_generator_build_limit_offset() {
3421 let pool = create_test_pool().await;
3422 let builder = QueryBuilder::new(&pool, "users", false)
3423 .limit(10)
3424 .offset(20);
3425
3426 let mut generator = SqlGenerator::new();
3427 let result = generator.build_select(&builder);
3428
3429 assert!(result.is_ok());
3430 let sql = generator.get_sql();
3431 assert!(sql.contains("LIMIT 10"));
3432 assert!(sql.contains("OFFSET 20"));
3433 }
3434
3435 #[tokio::test]
3437 async fn test_sql_generator_complex_query() {
3438 let pool = create_test_pool().await;
3439 let builder = QueryBuilder::new(&pool, "users", false)
3440 .field("users.id")
3441 .field("users.name")
3442 .field("COUNT(orders.id) as order_count")
3443 .distinct()
3444 .join("orders", "users.id = orders.user_id")
3445 .where_and_unchecked("users.status", "=", 1)
3446 .where_and_unchecked("orders.total", ">", 100)
3447 .group("users.id")
3448 .group("users.name")
3449 .order("order_count", false)
3450 .limit(20)
3451 .offset(10);
3452
3453 let mut generator = SqlGenerator::new();
3454 let result = generator.build_select(&builder);
3455
3456 assert!(result.is_ok());
3457 let sql = generator.get_sql();
3458
3459 assert!(sql.starts_with("SELECT DISTINCT"));
3461 assert!(sql.contains("users.id, users.name, COUNT(orders.id) as order_count"));
3462 assert!(sql.contains("FROM users"));
3463 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3464 assert!(sql.contains("WHERE"));
3465 assert!(sql.contains("GROUP BY users.id, users.name"));
3466 assert!(sql.contains("ORDER BY order_count DESC"));
3467 assert!(sql.contains("LIMIT 20"));
3468 assert!(sql.contains("OFFSET 10"));
3469 }
3470
3471 #[tokio::test]
3473 async fn test_find_adds_limit_one() {
3474 let pool = create_test_pool().await;
3475 let builder = QueryBuilder::new(&pool, "users", false)
3476 .field("id")
3477 .field("name")
3478 .where_and_unchecked("id", "=", 1);
3479
3480 assert_eq!(builder.limit, None);
3482
3483 let builder_with_limit = QueryBuilder::new(&pool, "users", false)
3485 .field("id")
3486 .field("name")
3487 .where_and_unchecked("id", "=", 1)
3488 .limit(1);
3489
3490 let sql = builder_with_limit.to_sql();
3491 assert!(sql.contains("LIMIT 1"), "find() 应该自动添加 LIMIT 1");
3492 }
3493
/// A flat JSON object with three keys yields an INSERT with three bound parameters.
#[test]
fn test_sql_generator_build_insert_basic() {
    let row = serde_json::json!({
        "name": "张三",
        "age": 25,
        "email": "zhangsan@example.com"
    });
    // No special field types are declared for this insert.
    let no_types = HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_insert("users", &row, &no_types).is_ok());

    let rendered = sql_gen.get_sql();
    assert!(rendered.starts_with("INSERT INTO users"));
    for fragment in ["name", "age", "email", "VALUES"] {
        assert!(rendered.contains(fragment), "missing SQL fragment: {}", fragment);
    }
    assert_eq!(sql_gen.get_params().len(), 3);
}
3516
/// A column registered as `FieldType::Json` must be bound as a `SqlValue::Json` parameter.
#[test]
fn test_sql_generator_build_insert_with_json_field() {
    let row = serde_json::json!({
        "name": "测试用户",
        "data": {"role": "admin", "permissions": ["read", "write"]}
    });

    // Mark the nested object column as JSON.
    let json_types = HashMap::from([("data".to_string(), FieldType::Json)]);

    let mut sql_gen = SqlGenerator::new();
    assert!(sql_gen.build_insert("users", &row, &json_types).is_ok());

    let rendered = sql_gen.get_sql();
    assert!(rendered.contains("INSERT INTO users"));
    assert!(rendered.contains("name"));
    assert!(rendered.contains("data"));

    let bound = sql_gen.get_params();
    assert_eq!(bound.len(), 2);
    assert!(
        bound.iter().any(|value| matches!(value, SqlValue::Json(_))),
        "应该包含 JSON 类型的参数"
    );
}
3542
/// An empty JSON object is rejected with `DbError::SerializationError`.
#[test]
fn test_sql_generator_build_insert_empty_data() {
    let empty_row = serde_json::json!({});
    let no_types = HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_insert("users", &empty_row, &no_types);

    // Must be an error, and specifically the serialization variant.
    assert!(matches!(
        outcome,
        Err(crate::error::DbError::SerializationError(_))
    ));
}
3556
/// A JSON value that is not an object (here an array) is rejected with
/// `DbError::SerializationError`.
#[test]
fn test_sql_generator_build_insert_not_object() {
    let not_an_object = serde_json::json!([1, 2, 3]);
    let no_types = HashMap::new();

    let mut sql_gen = SqlGenerator::new();
    let outcome = sql_gen.build_insert("users", &not_an_object, &no_types);

    // Must be an error, and specifically the serialization variant.
    assert!(matches!(
        outcome,
        Err(crate::error::DbError::SerializationError(_))
    ));
}
3570
3571 #[tokio::test]
3575 async fn test_avg_sql_generation() {
3576 let pool = create_test_pool().await;
3577
3578 let mut test_builder = QueryBuilder::new(&pool, "products", false);
3580 test_builder.fields.clear();
3581 test_builder
3582 .fields
3583 .push("CAST(AVG(price) AS DOUBLE)".to_string());
3584 test_builder.limit = Some(1);
3585
3586 let sql = test_builder.to_sql();
3587 assert!(sql.contains("SELECT CAST(AVG(price) AS DOUBLE)"));
3588 assert!(sql.contains("FROM products"));
3589 assert!(sql.contains("LIMIT 1"));
3590 }
3591
3592 #[tokio::test]
3594 async fn test_avg_with_where_sql() {
3595 let pool = create_test_pool().await;
3596
3597 let mut test_builder =
3599 QueryBuilder::new(&pool, "products", false).where_and_unchecked("status", "=", 1);
3600 test_builder.fields.clear();
3601 test_builder
3602 .fields
3603 .push("CAST(AVG(price) AS DOUBLE)".to_string());
3604 test_builder.limit = Some(1);
3605
3606 let sql = test_builder.to_sql();
3607 assert!(sql.contains("SELECT CAST(AVG(price) AS DOUBLE)"));
3608 assert!(sql.contains("FROM products"));
3609 assert!(sql.contains("WHERE"));
3610 assert!(sql.contains("status"));
3611 }
3612
3613 #[tokio::test]
3615 async fn test_min_sql_generation() {
3616 let pool = create_test_pool().await;
3617
3618 let mut test_builder = QueryBuilder::new(&pool, "products", false);
3620 test_builder.fields.clear();
3621 test_builder.fields.push("MIN(price)".to_string());
3622 test_builder.limit = Some(1);
3623
3624 let sql = test_builder.to_sql();
3625 assert!(sql.contains("SELECT MIN(price)"));
3626 assert!(sql.contains("FROM products"));
3627 assert!(sql.contains("LIMIT 1"));
3628 }
3629
3630 #[tokio::test]
3632 async fn test_max_sql_generation() {
3633 let pool = create_test_pool().await;
3634
3635 let mut test_builder = QueryBuilder::new(&pool, "products", false);
3637 test_builder.fields.clear();
3638 test_builder.fields.push("MAX(price)".to_string());
3639 test_builder.limit = Some(1);
3640
3641 let sql = test_builder.to_sql();
3642 assert!(sql.contains("SELECT MAX(price)"));
3643 assert!(sql.contains("FROM products"));
3644 assert!(sql.contains("LIMIT 1"));
3645 }
3646
3647 #[tokio::test]
3649 async fn test_min_max_different_types() {
3650 let pool = create_test_pool().await;
3651
3652 let mut builder_int = QueryBuilder::new(&pool, "products", false);
3654 builder_int.fields.clear();
3655 builder_int.fields.push("MIN(stock)".to_string());
3656 let sql_int = builder_int.to_sql();
3657 assert!(sql_int.contains("MIN(stock)"));
3658
3659 let mut builder_float = QueryBuilder::new(&pool, "products", false);
3661 builder_float.fields.clear();
3662 builder_float.fields.push("MAX(price)".to_string());
3663 let sql_float = builder_float.to_sql();
3664 assert!(sql_float.contains("MAX(price)"));
3665
3666 let mut builder_string = QueryBuilder::new(&pool, "users", false);
3668 builder_string.fields.clear();
3669 builder_string.fields.push("MIN(name)".to_string());
3670 let sql_string = builder_string.to_sql();
3671 assert!(sql_string.contains("MIN(name)"));
3672
3673 let mut builder_datetime = QueryBuilder::new(&pool, "users", false);
3675 builder_datetime.fields.clear();
3676 builder_datetime.fields.push("MAX(created_at)".to_string());
3677 let sql_datetime = builder_datetime.to_sql();
3678 assert!(sql_datetime.contains("MAX(created_at)"));
3679 }
3680
3681 #[tokio::test]
3683 async fn test_aggregates_with_group_by_sql() {
3684 let pool = create_test_pool().await;
3685
3686 let mut test_builder = QueryBuilder::new(&pool, "orders", false).group("user_id");
3688 test_builder.fields.clear();
3689 test_builder.fields.push("user_id".to_string());
3690 test_builder
3691 .fields
3692 .push("CAST(AVG(amount) AS DOUBLE) as avg_amount".to_string());
3693
3694 let sql = test_builder.to_sql();
3695 assert!(sql.contains("SELECT user_id, CAST(AVG(amount) AS DOUBLE) as avg_amount"));
3696 assert!(sql.contains("FROM orders"));
3697 assert!(sql.contains("GROUP BY user_id"));
3698 }
3699
3700 #[tokio::test]
3702 async fn test_multiple_aggregates_sql() {
3703 let pool = create_test_pool().await;
3704
3705 let mut test_builder =
3707 QueryBuilder::new(&pool, "orders", false).where_and_unchecked("status", "=", "completed");
3708 test_builder.fields.clear();
3709 test_builder
3710 .fields
3711 .push("CAST(AVG(amount) AS DOUBLE) as avg_amount".to_string());
3712 test_builder
3713 .fields
3714 .push("CAST(MIN(amount) AS DOUBLE) as min_amount".to_string());
3715 test_builder
3716 .fields
3717 .push("CAST(MAX(amount) AS DOUBLE) as max_amount".to_string());
3718 test_builder
3719 .fields
3720 .push("COUNT(*) as order_count".to_string());
3721
3722 let sql = test_builder.to_sql();
3723 assert!(sql.contains("CAST(AVG(amount) AS DOUBLE) as avg_amount"));
3724 assert!(sql.contains("CAST(MIN(amount) AS DOUBLE) as min_amount"));
3725 assert!(sql.contains("CAST(MAX(amount) AS DOUBLE) as max_amount"));
3726 assert!(sql.contains("COUNT(*) as order_count"));
3727 assert!(sql.contains("WHERE"));
3728 assert!(sql.contains("status"));
3729 }
3730
3731 #[tokio::test]
3733 async fn test_sql_clause_order_with_aggregates() {
3734 let pool = create_test_pool().await;
3735
3736 let mut test_builder = QueryBuilder::new(&pool, "orders", false)
3738 .where_and_unchecked("status", "=", "completed")
3739 .group("user_id")
3740 .order("total_amount", false);
3741 test_builder.fields.clear();
3742 test_builder.fields.push("user_id".to_string());
3743 test_builder
3744 .fields
3745 .push("SUM(amount) as total_amount".to_string());
3746
3747 let sql = test_builder.to_sql();
3748
3749 let where_pos = sql.find("WHERE").expect("应该包含 WHERE");
3751 let group_pos = sql.find("GROUP BY").expect("应该包含 GROUP BY");
3752 let order_pos = sql.find("ORDER BY").expect("应该包含 ORDER BY");
3753
3754 assert!(where_pos < group_pos, "WHERE 应该在 GROUP BY 之前");
3755 assert!(group_pos < order_pos, "GROUP BY 应该在 ORDER BY 之前");
3756 }
3757
3758 #[tokio::test]
3760 async fn test_aggregates_empty_result_sql() {
3761 let pool = create_test_pool().await;
3762
3763 let mut test_builder = QueryBuilder::new(&pool, "products", false).where_and_unchecked("id", "=", -1); test_builder.fields.clear();
3766 test_builder
3767 .fields
3768 .push("CAST(AVG(price) AS DOUBLE)".to_string());
3769 test_builder.limit = Some(1);
3770
3771 let sql = test_builder.to_sql();
3772 assert!(sql.contains("SELECT CAST(AVG(price) AS DOUBLE)"));
3773 assert!(sql.contains("WHERE"));
3774 }
3776
3777 #[tokio::test]
3779 async fn test_aggregates_with_special_field_names() {
3780 let pool = create_test_pool().await;
3781
3782 let mut test_builder = QueryBuilder::new(&pool, "products", false);
3784 test_builder.fields.clear();
3785 test_builder.fields.push("AVG(unit_price)".to_string());
3786
3787 let sql = test_builder.to_sql();
3788 assert!(sql.contains("AVG(unit_price)"));
3789
3790 let mut test_builder2 = QueryBuilder::new(&pool, "products", false);
3792 test_builder2.fields.clear();
3793 test_builder2.fields.push("MAX(`order`)".to_string());
3794
3795 let sql2 = test_builder2.to_sql();
3796 assert!(sql2.contains("MAX(`order`)"));
3797 }
3798
3799 #[tokio::test]
3801 async fn test_aggregates_with_distinct() {
3802 let pool = create_test_pool().await;
3803
3804 let mut test_builder = QueryBuilder::new(&pool, "orders", false);
3806 test_builder.fields.clear();
3807 test_builder
3808 .fields
3809 .push("COUNT(DISTINCT user_id) as unique_users".to_string());
3810
3811 let sql = test_builder.to_sql();
3812 assert!(sql.contains("COUNT(DISTINCT user_id)"));
3813 }
3814
3815 #[tokio::test]
3817 async fn test_aggregates_with_join() {
3818 let pool = create_test_pool().await;
3819
3820 let mut test_builder = QueryBuilder::new(&pool, "users", false)
3822 .join("orders", "users.id = orders.user_id")
3823 .group("users.id");
3824 test_builder.fields.clear();
3825 test_builder.fields.push("users.id".to_string());
3826 test_builder.fields.push("users.name".to_string());
3827 test_builder
3828 .fields
3829 .push("COUNT(orders.id) as order_count".to_string());
3830 test_builder
3831 .fields
3832 .push("SUM(orders.amount) as total_amount".to_string());
3833
3834 let sql = test_builder.to_sql();
3835 assert!(sql.contains("INNER JOIN orders ON users.id = orders.user_id"));
3836 assert!(sql.contains("COUNT(orders.id) as order_count"));
3837 assert!(sql.contains("SUM(orders.amount) as total_amount"));
3838 assert!(sql.contains("GROUP BY users.id"));
3839 }
3840
3841 #[tokio::test]
3843 async fn test_aggregates_sql_injection_prevention() {
3844 let pool = create_test_pool().await;
3845
3846 let builder = QueryBuilder::new(&pool, "products", false).where_and_unchecked(
3848 "category",
3849 "=",
3850 "'; DROP TABLE products; --",
3851 );
3852
3853 let mut generator = SqlGenerator::new();
3855 let result = generator.build_select(&builder);
3856 assert!(result.is_ok());
3857
3858 let sql = generator.get_sql();
3859 let params = generator.get_params();
3860
3861 assert!(sql.contains("?"));
3863 assert!(!sql.contains("DROP TABLE"));
3864
3865 assert_eq!(params.len(), 1);
3867 }
3868}
3869
3870#[cfg(test)]
3871mod property_tests {
3872 use super::*;
3873 use proptest::prelude::*;
3874 use sqlx::mysql::MySqlPoolOptions;
3875
3876 fn table_name_strategy() -> impl Strategy<Value = String> {
3878 "[a-z][a-z0-9_]{0,30}"
3879 }
3880
3881 fn field_name_strategy() -> impl Strategy<Value = String> {
3883 "[a-z][a-z0-9_]{0,30}"
3884 }
3885
3886 fn create_test_pool_sync() -> MySqlPool {
3888 tokio::runtime::Runtime::new().unwrap().block_on(async {
3889 MySqlPoolOptions::new()
3890 .max_connections(1)
3891 .connect("mysql://root:111111@localhost:3306/test")
3892 .await
3893 .expect("无法连接到测试数据库")
3894 })
3895 }
3896
3897 proptest! {
3900 #![proptest_config(ProptestConfig::with_cases(100))]
3901
3902 #[test]
3903 fn prop_table_name_in_sql(table_name in table_name_strategy()) {
3904 let pool = create_test_pool_sync();
3905 let builder = QueryBuilder::new(&pool, &table_name, false);
3906 let sql = builder.to_sql();
3907
3908 let expected = format!("FROM {}", table_name);
3910 prop_assert!(sql.contains(&expected));
3911 }
3912 }
3913
3914 proptest! {
3917 #![proptest_config(ProptestConfig::with_cases(100))]
3918
3919 #[test]
3920 fn prop_table_name_override(
3921 table_name1 in table_name_strategy(),
3922 table_name2 in table_name_strategy()
3923 ) {
3924 prop_assume!(table_name1 != table_name2);
3925
3926 let pool = create_test_pool_sync();
3927 let builder1 = QueryBuilder::new(&pool, &table_name1, false);
3929 let sql1 = builder1.to_sql();
3930 let expected1 = format!("FROM {}", table_name1);
3931 prop_assert!(sql1.contains(&expected1));
3932
3933 let builder2 = QueryBuilder::new(&pool, &table_name2, false);
3935 let sql2 = builder2.to_sql();
3936 let expected2 = format!("FROM {}", table_name2);
3937 prop_assert!(sql2.contains(&expected2));
3938
3939 let pattern1 = format!("FROM {} ", table_name1);
3942 let pattern1_alt = format!("FROM {}\n", table_name1);
3943 prop_assert!(!sql2.contains(&pattern1) && !sql2.contains(&pattern1_alt));
3944 }
3945 }
3946
3947 proptest! {
3950 #![proptest_config(ProptestConfig::with_cases(100))]
3951
3952 #[test]
3953 fn prop_field_selection(
3954 table_name in table_name_strategy(),
3955 fields in prop::collection::vec(field_name_strategy(), 1..10)
3956 ) {
3957 let pool = create_test_pool_sync();
3958 let mut builder = QueryBuilder::new(&pool, &table_name, false);
3959
3960 for field in &fields {
3962 builder = builder.field(field);
3963 }
3964
3965 let sql = builder.to_sql();
3966
3967 for field in &fields {
3969 prop_assert!(sql.contains(field));
3970 }
3971 }
3972 }
3973
3974 proptest! {
3977 #![proptest_config(ProptestConfig::with_cases(100))]
3978
3979 #[test]
3980 fn prop_distinct_keyword(
3981 table_name in table_name_strategy(),
3982 field in field_name_strategy()
3983 ) {
3984 let pool = create_test_pool_sync();
3985 let builder = QueryBuilder::new(&pool, &table_name, false)
3986 .field(&field)
3987 .distinct();
3988
3989 let sql = builder.to_sql();
3990
3991 prop_assert!(sql.contains("SELECT DISTINCT"));
3993 }
3994 }
3995
3996 proptest! {
3999 #![proptest_config(ProptestConfig::with_cases(100))]
4000
4001 #[test]
4002 fn prop_special_field_type_marking(
4003 table_name in table_name_strategy(),
4004 json_field in field_name_strategy(),
4005 datetime_field in field_name_strategy(),
4006 timestamp_field in field_name_strategy(),
4007 decimal_field in field_name_strategy(),
4008 blob_field in field_name_strategy(),
4009 text_field in field_name_strategy()
4010 ) {
4011 prop_assume!(json_field != datetime_field);
4013 prop_assume!(json_field != timestamp_field);
4014 prop_assume!(json_field != decimal_field);
4015 prop_assume!(json_field != blob_field);
4016 prop_assume!(json_field != text_field);
4017 prop_assume!(datetime_field != timestamp_field);
4018 prop_assume!(datetime_field != decimal_field);
4019 prop_assume!(datetime_field != blob_field);
4020 prop_assume!(datetime_field != text_field);
4021 prop_assume!(timestamp_field != decimal_field);
4022 prop_assume!(timestamp_field != blob_field);
4023 prop_assume!(timestamp_field != text_field);
4024 prop_assume!(decimal_field != blob_field);
4025 prop_assume!(decimal_field != text_field);
4026 prop_assume!(blob_field != text_field);
4027
4028 let pool = create_test_pool_sync();
4029 let builder = QueryBuilder::new(&pool, &table_name, false)
4030 .json(&json_field)
4031 .datetime(&datetime_field)
4032 .timestamp(×tamp_field)
4033 .decimal(&decimal_field)
4034 .blob(&blob_field)
4035 .text(&text_field);
4036
4037 prop_assert_eq!(builder.field_types.get(&json_field), Some(&FieldType::Json));
4039 prop_assert_eq!(builder.field_types.get(&datetime_field), Some(&FieldType::DateTime));
4040 prop_assert_eq!(builder.field_types.get(×tamp_field), Some(&FieldType::Timestamp));
4041 prop_assert_eq!(builder.field_types.get(&decimal_field), Some(&FieldType::Decimal));
4042 prop_assert_eq!(builder.field_types.get(&blob_field), Some(&FieldType::Blob));
4043 prop_assert_eq!(builder.field_types.get(&text_field), Some(&FieldType::Text));
4044 }
4045 }
4046
4047 proptest! {
4050 #![proptest_config(ProptestConfig::with_cases(100))]
4051
4052 #[test]
4053 fn prop_where_and_condition_added(
4054 table_name in table_name_strategy(),
4055 field in field_name_strategy(),
4056 value in any::<i32>()
4057 ) {
4058 let pool = create_test_pool_sync();
4059 let builder = QueryBuilder::new(&pool, &table_name, false)
4060 .where_and_unchecked(&field, "=", value);
4061
4062 prop_assert_eq!(builder.conditions.len(), 1);
4064 }
4065
4066 #[test]
4067 fn prop_where_or_condition_added(
4068 table_name in table_name_strategy(),
4069 field in field_name_strategy(),
4070 value1 in any::<i32>(),
4071 value2 in any::<i32>()
4072 ) {
4073 let pool = create_test_pool_sync();
4074 let builder = QueryBuilder::new(&pool, &table_name, false)
4075 .where_or_unchecked(&field, "=", value1)
4076 .where_or_unchecked(&field, "=", value2);
4077
4078 prop_assert_eq!(builder.conditions.len(), 1);
4080 }
4081 }
4082
4083 proptest! {
4086 #![proptest_config(ProptestConfig::with_cases(100))]
4087
4088 #[test]
4089 fn prop_in_operator_array_support(
4090 table_name in table_name_strategy(),
4091 field in field_name_strategy(),
4092 values in prop::collection::vec(any::<i32>(), 1..10)
4093 ) {
4094 let pool = create_test_pool_sync();
4095 let builder = QueryBuilder::new(&pool, &table_name, false)
4096 .where_in(&field, values);
4097
4098 prop_assert_eq!(builder.conditions.len(), 1);
4100 }
4101 }
4102
4103 proptest! {
4106 #![proptest_config(ProptestConfig::with_cases(100))]
4107
4108 #[test]
4109 fn prop_between_operator_boundary_support(
4110 table_name in table_name_strategy(),
4111 field in field_name_strategy(),
4112 start in any::<i32>(),
4113 end in any::<i32>()
4114 ) {
4115 let pool = create_test_pool_sync();
4116 let builder = QueryBuilder::new(&pool, &table_name, false)
4117 .where_between(&field, start, end);
4118
4119 prop_assert_eq!(builder.conditions.len(), 1);
4121 }
4122 }
4123
4124 proptest! {
4127 #![proptest_config(ProptestConfig::with_cases(100))]
4128
4129 #[test]
4130 fn prop_multiple_and_conditions(
4131 table_name in table_name_strategy(),
4132 field in field_name_strategy(),
4133 values in prop::collection::vec(any::<i32>(), 2..5)
4134 ) {
4135 let pool = create_test_pool_sync();
4136 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4137
4138 for value in &values {
4140 builder = builder.where_and_unchecked(&field, "=", *value);
4141 }
4142
4143 prop_assert_eq!(builder.conditions.len(), values.len());
4145 }
4146 }
4147
4148 proptest! {
4151 #![proptest_config(ProptestConfig::with_cases(100))]
4152
4153 #[test]
4154 fn prop_join_clause_generation(
4155 table_name in table_name_strategy(),
4156 join_table in table_name_strategy(),
4157 on_condition in "[a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20} = [a-z][a-z0-9_]{0,20}\\.[a-z][a-z0-9_]{0,20}"
4158 ) {
4159 let pool = create_test_pool_sync();
4160
4161 let builder_inner = QueryBuilder::new(&pool, &table_name, false)
4163 .join(&join_table, &on_condition);
4164 prop_assert_eq!(builder_inner.joins.len(), 1);
4165
4166 let builder_left = QueryBuilder::new(&pool, &table_name, false)
4168 .left_join(&join_table, &on_condition);
4169 prop_assert_eq!(builder_left.joins.len(), 1);
4170
4171 let builder_right = QueryBuilder::new(&pool, &table_name, false)
4173 .right_join(&join_table, &on_condition);
4174 prop_assert_eq!(builder_right.joins.len(), 1);
4175 }
4176 }
4177
4178 proptest! {
4181 #![proptest_config(ProptestConfig::with_cases(100))]
4182
4183 #[test]
4184 fn prop_multiple_join_support(
4185 table_name in table_name_strategy(),
4186 join_tables in prop::collection::vec(table_name_strategy(), 1..5)
4187 ) {
4188 let pool = create_test_pool_sync();
4189 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4190
4191 for join_table in &join_tables {
4193 let on_condition = format!("{}.id = {}.id", table_name, join_table);
4194 builder = builder.join(join_table, &on_condition);
4195 }
4196
4197 prop_assert_eq!(builder.joins.len(), join_tables.len());
4199 }
4200 }
4201
4202 proptest! {
4205 #![proptest_config(ProptestConfig::with_cases(100))]
4206
4207 #[test]
4208 fn prop_table_alias_support(
4209 base_table in table_name_strategy(),
4210 join_table in table_name_strategy(),
4211 base_alias in "[a-z][a-z0-9]{0,5}",
4212 join_alias in "[a-z][a-z0-9]{0,5}"
4213 ) {
4214 prop_assume!(base_table != join_table);
4215 prop_assume!(base_alias != join_alias);
4216
4217 let pool = create_test_pool_sync();
4218
4219 let base_table_with_alias = format!("{} AS {}", base_table, base_alias);
4221 let join_table_with_alias = format!("{} AS {}", join_table, join_alias);
4222
4223 let on_condition = format!("{}.id = {}.id", base_alias, join_alias);
4225
4226 let builder = QueryBuilder::new(&pool, &base_table_with_alias, false)
4228 .field(&format!("{}.id", base_alias))
4229 .field(&format!("{}.name", base_alias))
4230 .join(&join_table_with_alias, &on_condition);
4231
4232 let sql = builder.to_sql();
4233
4234 prop_assert!(sql.contains(&format!("FROM {}", base_table_with_alias)),
4236 "SQL 应该包含带别名的主表: FROM {}", base_table_with_alias);
4237
4238 prop_assert!(sql.contains(&join_table_with_alias),
4240 "SQL 应该包含带别名的 JOIN 表: {}", join_table_with_alias);
4241
4242 prop_assert!(sql.contains(&on_condition),
4244 "SQL 应该包含使用别名的 ON 条件: {}", on_condition);
4245
4246 prop_assert!(sql.contains(&format!("{}.id", base_alias)),
4248 "SQL 应该包含使用别名的字段: {}.id", base_alias);
4249 prop_assert!(sql.contains(&format!("{}.name", base_alias)),
4250 "SQL 应该包含使用别名的字段: {}.name", base_alias);
4251 }
4252 }
4253
4254 proptest! {
4257 #![proptest_config(ProptestConfig::with_cases(100))]
4258
4259 #[test]
4260 fn prop_order_by_clause_generation(
4261 table_name in table_name_strategy(),
4262 field in field_name_strategy(),
4263 asc in any::<bool>()
4264 ) {
4265 let pool = create_test_pool_sync();
4266 let builder = QueryBuilder::new(&pool, &table_name, false)
4267 .order(&field, asc);
4268
4269 prop_assert_eq!(builder.order_by.len(), 1);
4271 prop_assert_eq!(&builder.order_by[0].field, &field);
4272 prop_assert_eq!(builder.order_by[0].asc, asc);
4273 }
4274 }
4275
4276 proptest! {
4279 #![proptest_config(ProptestConfig::with_cases(100))]
4280
4281 #[test]
4282 fn prop_multiple_order_by_support(
4283 table_name in table_name_strategy(),
4284 fields in prop::collection::vec(field_name_strategy(), 1..5)
4285 ) {
4286 let pool = create_test_pool_sync();
4287 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4288
4289 for field in &fields {
4291 builder = builder.order(field, true);
4292 }
4293
4294 prop_assert_eq!(builder.order_by.len(), fields.len());
4296 }
4297 }
4298
4299 proptest! {
4302 #![proptest_config(ProptestConfig::with_cases(100))]
4303
4304 #[test]
4305 fn prop_group_by_clause_generation(
4306 table_name in table_name_strategy(),
4307 field in field_name_strategy()
4308 ) {
4309 let pool = create_test_pool_sync();
4310 let builder = QueryBuilder::new(&pool, &table_name, false)
4311 .group(&field);
4312
4313 prop_assert_eq!(builder.group_by.len(), 1);
4315 prop_assert_eq!(&builder.group_by[0], &field);
4316 }
4317 }
4318
4319 proptest! {
4322 #![proptest_config(ProptestConfig::with_cases(100))]
4323
4324 #[test]
4325 fn prop_multiple_group_by_support(
4326 table_name in table_name_strategy(),
4327 fields in prop::collection::vec(field_name_strategy(), 1..5)
4328 ) {
4329 let pool = create_test_pool_sync();
4330 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4331
4332 for field in &fields {
4334 builder = builder.group(field);
4335 }
4336
4337 prop_assert_eq!(builder.group_by.len(), fields.len());
4339 }
4340 }
4341
4342 proptest! {
4345 #![proptest_config(ProptestConfig::with_cases(100))]
4346
4347 #[test]
4348 fn prop_to_sql_returns_valid_sql(
4349 table_name in table_name_strategy(),
4350 fields in prop::collection::vec(field_name_strategy(), 0..5),
4351 use_distinct in any::<bool>(),
4352 limit_opt in prop::option::of(1u64..100),
4353 offset_opt in prop::option::of(0u64..100)
4354 ) {
4355 let pool = create_test_pool_sync();
4356 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4357
4358 for field in &fields {
4360 builder = builder.field(field);
4361 }
4362
4363 if use_distinct {
4365 builder = builder.distinct();
4366 }
4367
4368 if let Some(limit) = limit_opt {
4370 builder = builder.limit(limit);
4371 }
4372
4373 if let Some(offset) = offset_opt {
4375 builder = builder.offset(offset);
4376 }
4377
4378 let sql = builder.to_sql();
4380
4381 prop_assert!(!sql.is_empty(), "SQL 字符串不应为空");
4383
4384 prop_assert!(sql.contains("SELECT"), "SQL 应包含 SELECT 关键字");
4386 prop_assert!(sql.contains("FROM"), "SQL 应包含 FROM 关键字");
4387
4388 prop_assert!(sql.contains(&table_name), "SQL 应包含表名");
4390
4391 if use_distinct {
4393 prop_assert!(sql.contains("DISTINCT"), "SQL 应包含 DISTINCT 关键字");
4394 }
4395
4396 if let Some(limit) = limit_opt {
4398 prop_assert!(sql.contains("LIMIT"), "SQL 应包含 LIMIT 关键字");
4399 prop_assert!(sql.contains(&limit.to_string()), "SQL 应包含 LIMIT 值");
4400 }
4401
4402 if let Some(offset) = offset_opt {
4404 prop_assert!(sql.contains("OFFSET"), "SQL 应包含 OFFSET 关键字");
4405 prop_assert!(sql.contains(&offset.to_string()), "SQL 应包含 OFFSET 值");
4406 }
4407
4408 if !fields.is_empty() {
4410 for field in &fields {
4411 prop_assert!(sql.contains(field), "SQL 应包含字段 {}", field);
4412 }
4413 } else {
4414 prop_assert!(sql.contains("*"), "SQL 应包含 * 表示所有字段");
4416 }
4417 }
4418
4419 #[test]
4420 fn prop_to_sql_with_conditions(
4421 table_name in table_name_strategy(),
4422 field in field_name_strategy(),
4423 value in any::<i32>()
4424 ) {
4425 let pool = create_test_pool_sync();
4426 let builder = QueryBuilder::new(&pool, &table_name, false)
4427 .where_and_unchecked(&field, "=", value);
4428
4429 let sql = builder.to_sql();
4430
4431 prop_assert!(!sql.is_empty());
4433 prop_assert!(sql.contains("SELECT"));
4434 prop_assert!(sql.contains("FROM"));
4435 prop_assert!(sql.contains(&table_name));
4436
4437 prop_assert!(sql.contains("WHERE"), "SQL 应包含 WHERE 关键字");
4439 }
4440
4441 #[test]
4442 fn prop_to_sql_with_joins(
4443 table_name in table_name_strategy(),
4444 join_table in table_name_strategy(),
4445 on_field1 in field_name_strategy(),
4446 on_field2 in field_name_strategy()
4447 ) {
4448 let pool = create_test_pool_sync();
4449 let on_condition = format!("{}.{} = {}.{}", table_name, on_field1, join_table, on_field2);
4450 let builder = QueryBuilder::new(&pool, &table_name, false)
4451 .join(&join_table, &on_condition);
4452
4453 let sql = builder.to_sql();
4454
4455 prop_assert!(!sql.is_empty());
4457 prop_assert!(sql.contains("SELECT"));
4458 prop_assert!(sql.contains("FROM"));
4459
4460 prop_assert!(sql.contains("JOIN"), "SQL 应包含 JOIN 关键字");
4462 prop_assert!(sql.contains(&join_table), "SQL 应包含连接的表名");
4463 }
4464
4465 #[test]
4466 fn prop_to_sql_with_order_and_group(
4467 table_name in table_name_strategy(),
4468 order_field in field_name_strategy(),
4469 group_field in field_name_strategy(),
4470 asc in any::<bool>()
4471 ) {
4472 let pool = create_test_pool_sync();
4473 let builder = QueryBuilder::new(&pool, &table_name, false)
4474 .order(&order_field, asc)
4475 .group(&group_field);
4476
4477 let sql = builder.to_sql();
4478
4479 prop_assert!(!sql.is_empty());
4481 prop_assert!(sql.contains("SELECT"));
4482 prop_assert!(sql.contains("FROM"));
4483
4484 prop_assert!(sql.contains("ORDER BY"), "SQL 应包含 ORDER BY 关键字");
4486 prop_assert!(sql.contains("GROUP BY"), "SQL 应包含 GROUP BY 关键字");
4487 prop_assert!(sql.contains(&order_field), "SQL 应包含排序字段");
4488 prop_assert!(sql.contains(&group_field), "SQL 应包含分组字段");
4489 }
4490
4491 #[test]
4492 fn prop_to_sql_complex_query(
4493 table_name in table_name_strategy(),
4494 fields in prop::collection::vec(field_name_strategy(), 1..3),
4495 join_table in table_name_strategy(),
4496 where_field in field_name_strategy(),
4497 order_field in field_name_strategy(),
4498 group_field in field_name_strategy()
4499 ) {
4500 let pool = create_test_pool_sync();
4501 let mut builder = QueryBuilder::new(&pool, &table_name, false);
4502
4503 for field in &fields {
4505 builder = builder.field(field);
4506 }
4507
4508 let on_condition = format!("{}.id = {}.id", table_name, join_table);
4510 builder = builder.join(&join_table, &on_condition);
4511
4512 builder = builder.where_and_unchecked(&where_field, "=", 1);
4514
4515 builder = builder.order(&order_field, true);
4517
4518 builder = builder.group(&group_field);
4520
4521 builder = builder.limit(10);
4523
4524 let sql = builder.to_sql();
4525
4526 prop_assert!(!sql.is_empty());
4528 prop_assert!(sql.contains("SELECT"));
4529 prop_assert!(sql.contains("FROM"));
4530 prop_assert!(sql.contains(&table_name));
4531 prop_assert!(sql.contains("JOIN"));
4532 prop_assert!(sql.contains("WHERE"));
4533 prop_assert!(sql.contains("ORDER BY"));
4534 prop_assert!(sql.contains("GROUP BY"));
4535 prop_assert!(sql.contains("LIMIT"));
4536
4537 let select_pos = sql.find("SELECT").unwrap();
4539 let from_pos = sql.find("FROM").unwrap();
4540 let join_pos = sql.find("JOIN").unwrap();
4541 let where_pos = sql.find("WHERE").unwrap();
4542 let group_pos = sql.find("GROUP BY").unwrap();
4543 let order_pos = sql.find("ORDER BY").unwrap();
4544 let limit_pos = sql.find("LIMIT").unwrap();
4545
4546 prop_assert!(select_pos < from_pos, "SELECT 应在 FROM 之前");
4548 prop_assert!(from_pos < join_pos, "FROM 应在 JOIN 之前");
4549 prop_assert!(join_pos < where_pos, "JOIN 应在 WHERE 之前");
4550 prop_assert!(where_pos < group_pos, "WHERE 应在 GROUP BY 之前");
4551 prop_assert!(group_pos < order_pos, "GROUP BY 应在 ORDER BY 之前");
4552 prop_assert!(order_pos < limit_pos, "ORDER BY 应在 LIMIT 之前");
4553 }
4554 }
4555
4556 proptest! {
4559 #![proptest_config(ProptestConfig::with_cases(100))]
4560
4561 #[test]
4562 fn prop_sql_injection_prevention_single_quote(
4563 table_name in table_name_strategy(),
4564 field in field_name_strategy(),
4565 malicious_input in ".*'.*"
4566 ) {
4567 let pool = create_test_pool_sync();
4568 let builder = QueryBuilder::new(&pool, &table_name, false)
4569 .where_and_unchecked(&field, "=", malicious_input.as_str());
4570
4571 let sql = builder.to_sql();
4572
4573 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询(? 占位符)");
4576
4577 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4580 prop_assert!(!where_clause.contains(&malicious_input),
4581 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4582 }
4583
4584 #[test]
4585 fn prop_sql_injection_prevention_semicolon(
4586 table_name in table_name_strategy(),
4587 field in field_name_strategy(),
4588 malicious_input in ".*;.*"
4589 ) {
4590 let pool = create_test_pool_sync();
4591 let builder = QueryBuilder::new(&pool, &table_name, false)
4592 .where_and_unchecked(&field, "=", malicious_input.as_str());
4593
4594 let sql = builder.to_sql();
4595
4596 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4598
4599 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4601 prop_assert!(!where_clause.contains(&malicious_input),
4602 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4603 }
4604
4605 #[test]
4606 fn prop_sql_injection_prevention_comment(
4607 table_name in table_name_strategy(),
4608 field in field_name_strategy(),
4609 malicious_input in ".*--.*"
4610 ) {
4611 let pool = create_test_pool_sync();
4612 let builder = QueryBuilder::new(&pool, &table_name, false)
4613 .where_and_unchecked(&field, "=", malicious_input.as_str());
4614
4615 let sql = builder.to_sql();
4616
4617 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4619
4620 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4622 prop_assert!(!where_clause.contains(&malicious_input),
4623 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4624 }
4625
4626 #[test]
4627 fn prop_sql_injection_prevention_drop_table(
4628 table_name in table_name_strategy(),
4629 field in field_name_strategy()
4630 ) {
4631 let pool = create_test_pool_sync();
4632 let malicious_input = "'; DROP TABLE users; --";
4633 let builder = QueryBuilder::new(&pool, &table_name, false)
4634 .where_and_unchecked(&field, "=", malicious_input);
4635
4636 let sql = builder.to_sql();
4637
4638 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4640
4641 prop_assert!(!sql.to_uppercase().contains("DROP TABLE"),
4643 "SQL 不应该包含 DROP TABLE 语句");
4644
4645 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4647 prop_assert!(!where_clause.contains(malicious_input),
4648 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4649 }
4650
4651 #[test]
4652 fn prop_sql_injection_prevention_union_select(
4653 table_name in table_name_strategy(),
4654 field in field_name_strategy()
4655 ) {
4656 let pool = create_test_pool_sync();
4657 let malicious_input = "' UNION SELECT * FROM passwords --";
4658 let builder = QueryBuilder::new(&pool, &table_name, false)
4659 .where_and_unchecked(&field, "=", malicious_input);
4660
4661 let sql = builder.to_sql();
4662
4663 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4665
4666 let sql_upper = sql.to_uppercase();
4668 let union_count = sql_upper.matches("UNION").count();
4669 prop_assert_eq!(union_count, 0, "SQL 不应该包含 UNION 注入");
4670
4671 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4673 prop_assert!(!where_clause.contains(malicious_input),
4674 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4675 }
4676
4677 #[test]
4678 fn prop_sql_injection_prevention_or_always_true(
4679 table_name in table_name_strategy(),
4680 field in field_name_strategy()
4681 ) {
4682 let pool = create_test_pool_sync();
4683 let malicious_input = "' OR '1'='1";
4684 let builder = QueryBuilder::new(&pool, &table_name, false)
4685 .where_and_unchecked(&field, "=", malicious_input);
4686
4687 let sql = builder.to_sql();
4688
4689 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4691
4692 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4694 prop_assert!(!where_clause.contains(malicious_input),
4695 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4696
4697 let or_count = where_clause.matches(" OR ").count();
4700 prop_assert_eq!(or_count, 0, "不应该因为用户输入而产生 OR 条件");
4702 }
4703
4704 #[test]
4705 fn prop_sql_injection_prevention_multiple_special_chars(
4706 table_name in table_name_strategy(),
4707 field in field_name_strategy(),
4708 malicious_input in "[a-z0-9]*[';\"\\-][a-z0-9]*[';\"\\-][a-z0-9]*"
4709 ) {
4710 let pool = create_test_pool_sync();
4711 let builder = QueryBuilder::new(&pool, &table_name, false)
4712 .where_and_unchecked(&field, "=", malicious_input.as_str());
4713
4714 let sql = builder.to_sql();
4715
4716 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4718
4719 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4721 prop_assert!(!where_clause.contains(&malicious_input),
4722 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4723 }
4724
4725 #[test]
4726 fn prop_sql_injection_prevention_in_operator(
4727 table_name in table_name_strategy(),
4728 field in field_name_strategy(),
4729 malicious_values in prop::collection::vec(".*[';].*", 1..5)
4730 ) {
4731 let pool = create_test_pool_sync();
4732 let builder = QueryBuilder::new(&pool, &table_name, false)
4733 .where_in(&field, malicious_values.clone());
4734
4735 let sql = builder.to_sql();
4736
4737 prop_assert!(sql.contains("IN"), "SQL 应该包含 IN 操作符");
4739 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4740
4741 let placeholder_count = sql.matches("?").count();
4743 prop_assert!(placeholder_count >= malicious_values.len(),
4744 "每个 IN 值都应该有对应的参数占位符");
4745
4746 for malicious_value in &malicious_values {
4748 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4749 prop_assert!(!where_clause.contains(malicious_value),
4750 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4751 }
4752 }
4753
4754 #[test]
4755 fn prop_sql_injection_prevention_like_operator(
4756 table_name in table_name_strategy(),
4757 field in field_name_strategy(),
4758 malicious_pattern in ".*[';].*"
4759 ) {
4760 let pool = create_test_pool_sync();
4761 let builder = QueryBuilder::new(&pool, &table_name, false)
4762 .where_and_unchecked(&field, "like", malicious_pattern.as_str());
4763
4764 let sql = builder.to_sql();
4765
4766 prop_assert!(sql.contains("LIKE"), "SQL 应该包含 LIKE 操作符");
4768 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4769
4770 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4772 prop_assert!(!where_clause.contains(&malicious_pattern),
4773 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4774 }
4775
4776 #[test]
4777 fn prop_sql_injection_prevention_between_operator(
4778 table_name in table_name_strategy(),
4779 field in field_name_strategy(),
4780 malicious_start in ".*[';].*",
4781 malicious_end in ".*[';].*"
4782 ) {
4783 let pool = create_test_pool_sync();
4784 let builder = QueryBuilder::new(&pool, &table_name, false)
4785 .where_between(&field, malicious_start.as_str(), malicious_end.as_str());
4786
4787 let sql = builder.to_sql();
4788
4789 prop_assert!(sql.contains("BETWEEN"), "SQL 应该包含 BETWEEN 操作符");
4791 prop_assert!(sql.contains("?"), "SQL 应该使用参数化查询");
4792
4793 let where_clause = sql.split("WHERE").nth(1).unwrap_or("");
4795 let placeholder_count = where_clause.matches("?").count();
4796 prop_assert!(placeholder_count >= 2, "BETWEEN 应该有两个参数占位符");
4797
4798 prop_assert!(!where_clause.contains(&malicious_start),
4800 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4801 prop_assert!(!where_clause.contains(&malicious_end),
4802 "WHERE 子句不应该直接包含用户输入的恶意字符串");
4803 }
4804 }
4805
4806 proptest! {
4809 #![proptest_config(ProptestConfig::with_cases(100))]
4810
4811 #[test]
4812 fn prop_find_adds_limit_one(
4813 table_name in table_name_strategy(),
4814 field in field_name_strategy(),
4815 value in any::<i32>()
4816 ) {
4817 let pool = create_test_pool_sync();
4818
4819 let builder = QueryBuilder::new(&pool, &table_name, false)
4821 .field(&field)
4822 .where_and_unchecked(&field, "=", value)
4823 .limit(1); let sql = builder.to_sql();
4826
4827 prop_assert!(sql.contains("LIMIT 1"),
4829 "find() 方法应该自动添加 LIMIT 1 到查询中");
4830 }
4831 }
4832
4833 proptest! {
4841 #![proptest_config(ProptestConfig::with_cases(100))]
4842
4843 #[test]
4844 fn prop_count_aggregation_function(
4845 table_name in table_name_strategy()
4846 ) {
4847 let pool = create_test_pool_sync();
4848
4849 let builder = QueryBuilder::new(&pool, &table_name, false)
4852 .field("COUNT(*)");
4853
4854 let sql = builder.to_sql();
4855
4856 prop_assert!(
4858 sql.contains("COUNT(*)") || sql.contains("COUNT("),
4859 "count() 方法应该生成包含 COUNT(*) 或 COUNT(field) 的 SQL 语句,实际 SQL: {}",
4860 sql
4861 );
4862
4863 prop_assert!(
4865 sql.to_uppercase().contains("SELECT"),
4866 "count() 方法应该生成 SELECT 语句,实际 SQL: {}",
4867 sql
4868 );
4869
4870 prop_assert!(
4872 sql.contains(&format!("FROM {}", table_name)),
4873 "count() 方法应该包含正确的表名,实际 SQL: {}",
4874 sql
4875 );
4876 }
4877 }
4878
4879 proptest! {
4887 #![proptest_config(ProptestConfig::with_cases(100))]
4888
4889 #[test]
4890 fn prop_count_with_where_condition(
4891 table_name in table_name_strategy(),
4892 field_name in field_name_strategy(),
4893 field_value in 1i32..1000i32,
4894 ) {
4895 let pool = create_test_pool_sync();
4896
4897 let builder = QueryBuilder::new(&pool, &table_name, false)
4899 .where_and_unchecked(&field_name, "=", field_value)
4900 .field("COUNT(*)");
4901
4902 let sql = builder.to_sql();
4903
4904 prop_assert!(
4906 sql.contains("COUNT(*)"),
4907 "带条件的 count() 查询应该包含 COUNT(*),实际 SQL: {}",
4908 sql
4909 );
4910
4911 prop_assert!(
4913 sql.to_uppercase().contains("WHERE"),
4914 "带条件的 count() 查询应该包含 WHERE 子句,实际 SQL: {}",
4915 sql
4916 );
4917
4918 prop_assert!(
4920 sql.contains(&format!("FROM {}", table_name)),
4921 "count() 方法应该包含正确的表名,实际 SQL: {}",
4922 sql
4923 );
4924 }
4925 }
4926
4927 proptest! {
4935 #![proptest_config(ProptestConfig::with_cases(100))]
4936
4937 #[test]
4938 fn prop_count_specific_field(
4939 table_name in table_name_strategy(),
4940 field_name in field_name_strategy(),
4941 ) {
4942 let pool = create_test_pool_sync();
4943
4944 let count_expr = format!("COUNT({})", field_name);
4946 let builder = QueryBuilder::new(&pool, &table_name, false)
4947 .field(&count_expr);
4948
4949 let sql = builder.to_sql();
4950
4951 prop_assert!(
4953 sql.contains(&count_expr),
4954 "COUNT 特定字段应该包含 COUNT(field_name),实际 SQL: {}",
4955 sql
4956 );
4957
4958 prop_assert!(
4960 sql.to_uppercase().contains("SELECT"),
4961 "COUNT 查询应该是 SELECT 语句,实际 SQL: {}",
4962 sql
4963 );
4964
4965 prop_assert!(
4967 sql.contains(&format!("FROM {}", table_name)),
4968 "COUNT 查询应该包含正确的表名,实际 SQL: {}",
4969 sql
4970 );
4971 }
4972 }
4973
4974 proptest! {
4982 #![proptest_config(ProptestConfig::with_cases(100))]
4983
4984 #[test]
4985 fn prop_sum_aggregation_function(
4986 table_name in table_name_strategy(),
4987 field in field_name_strategy()
4988 ) {
4989 let pool = create_test_pool_sync();
4990
4991 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", field);
4994 let builder = QueryBuilder::new(&pool, &table_name, false)
4995 .field(&sum_expr);
4996
4997 let sql = builder.to_sql();
4998
4999 prop_assert!(
5001 sql.contains("SUM("),
5002 "sum() 方法应该生成包含 SUM(field) 的 SQL 语句,实际 SQL: {}",
5003 sql
5004 );
5005
5006 prop_assert!(
5008 sql.contains(&field),
5009 "sum() 方法生成的 SQL 应该包含指定的字段名 {},实际 SQL: {}",
5010 field,
5011 sql
5012 );
5013
5014 prop_assert!(
5016 sql.to_uppercase().contains("SELECT"),
5017 "sum() 方法应该生成 SELECT 语句,实际 SQL: {}",
5018 sql
5019 );
5020
5021 prop_assert!(
5023 sql.contains(&format!("FROM {}", table_name)),
5024 "sum() 方法应该包含正确的表名,实际 SQL: {}",
5025 sql
5026 );
5027
5028 prop_assert!(
5030 sql.to_uppercase().contains("CAST"),
5031 "sum() 方法应该使用 CAST 转换结果为 DOUBLE,实际 SQL: {}",
5032 sql
5033 );
5034 }
5035 }
5036
5037 proptest! {
5045 #![proptest_config(ProptestConfig::with_cases(100))]
5046
5047 #[test]
5048 fn prop_sum_with_where_condition(
5049 table_name in table_name_strategy(),
5050 sum_field in field_name_strategy(),
5051 where_field in field_name_strategy(),
5052 where_value in 1i32..1000i32,
5053 ) {
5054 prop_assume!(sum_field != where_field);
5056
5057 let pool = create_test_pool_sync();
5058
5059 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
5061 let builder = QueryBuilder::new(&pool, &table_name, false)
5062 .where_and_unchecked(&where_field, "=", where_value)
5063 .field(&sum_expr);
5064
5065 let sql = builder.to_sql();
5066
5067 prop_assert!(
5069 sql.contains("SUM("),
5070 "带条件的 sum() 查询应该包含 SUM(field),实际 SQL: {}",
5071 sql
5072 );
5073
5074 prop_assert!(
5076 sql.contains(&sum_field),
5077 "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
5078 sum_field,
5079 sql
5080 );
5081
5082 prop_assert!(
5084 sql.to_uppercase().contains("WHERE"),
5085 "带条件的 sum() 查询应该包含 WHERE 子句,实际 SQL: {}",
5086 sql
5087 );
5088
5089 prop_assert!(
5091 sql.contains(&format!("FROM {}", table_name)),
5092 "sum() 方法应该包含正确的表名,实际 SQL: {}",
5093 sql
5094 );
5095 }
5096 }
5097
5098 proptest! {
5106 #![proptest_config(ProptestConfig::with_cases(100))]
5107
5108 #[test]
5109 fn prop_sum_with_multiple_conditions(
5110 table_name in table_name_strategy(),
5111 sum_field in field_name_strategy(),
5112 where_field1 in field_name_strategy(),
5113 where_field2 in field_name_strategy(),
5114 value1 in 1i32..1000i32,
5115 value2 in 1i32..1000i32,
5116 ) {
5117 prop_assume!(sum_field != where_field1);
5119 prop_assume!(sum_field != where_field2);
5120 prop_assume!(where_field1 != where_field2);
5121
5122 let pool = create_test_pool_sync();
5123
5124 let sum_expr = format!("CAST(SUM({}) AS DOUBLE)", sum_field);
5126 let builder = QueryBuilder::new(&pool, &table_name, false)
5127 .where_and_unchecked(&where_field1, "=", value1)
5128 .where_and_unchecked(&where_field2, ">", value2)
5129 .field(&sum_expr);
5130
5131 let sql = builder.to_sql();
5132
5133 prop_assert!(
5135 sql.contains("SUM("),
5136 "多条件 sum() 查询应该包含 SUM(field),实际 SQL: {}",
5137 sql
5138 );
5139
5140 prop_assert!(
5142 sql.contains(&sum_field),
5143 "sum() 方法应该包含求和字段名 {},实际 SQL: {}",
5144 sum_field,
5145 sql
5146 );
5147
5148 prop_assert!(
5150 sql.to_uppercase().contains("WHERE"),
5151 "多条件查询应该包含 WHERE 子句,实际 SQL: {}",
5152 sql
5153 );
5154
5155 prop_assert!(
5157 sql.to_uppercase().contains(" AND "),
5158 "多个 where_and 条件应该用 AND 连接,实际 SQL: {}",
5159 sql
5160 );
5161 }
5162 }
5163
5164 proptest! {
5170 #![proptest_config(ProptestConfig::with_cases(200))]
5171
5172 #[test]
5173 fn prop_unsupported_operator_returns_error(
5174 op in "[a-zA-Z0-9!@#$%^&*]{1,10}"
5177 ) {
5178 let supported = ["=", "!=", ">", "<", ">=", "<=", "like", "LIKE"];
5180 prop_assume!(!supported.contains(&op.as_str()));
5181
5182 let pool = create_test_pool_sync();
5183 let builder = QueryBuilder::new(&pool, "users", false);
5184
5185 let result = builder.where_and("field", &op, 1i64);
5187
5188 prop_assert!(
5189 matches!(result, Err(crate::DbError::UnsupportedOperator(_))),
5190 "不支持的操作符 '{}' 应该返回 Err(DbError::UnsupportedOperator),实际结果: {:?}",
5191 op,
5192 result.map(|_| "Ok")
5193 );
5194 }
5195 }
5196
/// `where_and` accepts the `=` operator and records exactly one condition.
#[test]
fn test_where_and_supported_operators_eq() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("age", "=", 18i64);
    assert!(result.is_ok(), "操作符 '=' 应该返回 Ok");
    let builder = result.unwrap();
    assert_eq!(builder.conditions.len(), 1);
}
5210
/// `where_and` accepts the `!=` operator.
#[test]
fn test_where_and_supported_operators_ne() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("status", "!=", 0i64);
    assert!(result.is_ok(), "操作符 '!=' 应该返回 Ok");
}
5220
/// `where_and` accepts the `>` operator.
#[test]
fn test_where_and_supported_operators_gt() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("age", ">", 18i64);
    assert!(result.is_ok(), "操作符 '>' 应该返回 Ok");
}
5230
/// `where_and` accepts the `<` operator.
#[test]
fn test_where_and_supported_operators_lt() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("age", "<", 65i64);
    assert!(result.is_ok(), "操作符 '<' 应该返回 Ok");
}
5240
/// `where_and` accepts the `>=` operator.
#[test]
fn test_where_and_supported_operators_gte() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("score", ">=", 60i64);
    assert!(result.is_ok(), "操作符 '>=' 应该返回 Ok");
}
5250
/// `where_and` accepts the `<=` operator.
#[test]
fn test_where_and_supported_operators_lte() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("score", "<=", 100i64);
    assert!(result.is_ok(), "操作符 '<=' 应该返回 Ok");
}
5260
/// `where_and` accepts the lowercase `like` operator.
#[test]
fn test_where_and_supported_operators_like_lowercase() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("name", "like", "%test%");
    assert!(result.is_ok(), "操作符 'like' 应该返回 Ok");
}
5270
/// `where_and` accepts the uppercase `LIKE` operator.
#[test]
fn test_where_and_supported_operators_like_uppercase() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("name", "LIKE", "%test%");
    assert!(result.is_ok(), "操作符 'LIKE' 应该返回 Ok");
}
5280
/// `where_and` rejects `BETWEEN` with `DbError::UnsupportedOperator`.
#[test]
fn test_where_and_unsupported_operator_returns_error() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("age", "BETWEEN", 18i64);
    assert!(
        matches!(result, Err(crate::DbError::UnsupportedOperator(_))),
        "不支持的操作符 'BETWEEN' 应该返回 Err(DbError::UnsupportedOperator)"
    );
}
5293
/// The `UnsupportedOperator` error returned by `where_and` carries the
/// rejected operator text (`XOR` here).
#[test]
fn test_where_and_unsupported_operator_error_message() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_and("age", "XOR", 1i64);
    match result {
        Err(crate::DbError::UnsupportedOperator(op)) => {
            assert_eq!(op, "XOR", "错误消息应该包含操作符名称");
        }
        _ => panic!("应该返回 UnsupportedOperator 错误"),
    }
}
5308
/// `where_or` rejects `IN` with `DbError::UnsupportedOperator`.
#[test]
fn test_where_or_unsupported_operator_returns_error() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_or("age", "IN", 18i64);
    assert!(
        matches!(result, Err(crate::DbError::UnsupportedOperator(_))),
        "where_or 遇到不支持的操作符 'IN' 应该返回 Err(DbError::UnsupportedOperator)"
    );
}
5321
/// `where_or` accepts the `=` operator and records exactly one condition.
#[test]
fn test_where_or_supported_operators_work() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false).where_or("status", "=", 1i64);
    assert!(result.is_ok(), "where_or 操作符 '=' 应该返回 Ok");
    let builder = result.unwrap();
    assert_eq!(builder.conditions.len(), 1);
}
5333
/// `having_cond` rejects `LIKE` with `DbError::UnsupportedOperator`.
#[test]
fn test_having_cond_unsupported_operator_returns_error() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "orders", false).having_cond("cnt", "LIKE", 5i64);
    assert!(
        matches!(result, Err(crate::DbError::UnsupportedOperator(_))),
        "having_cond 遇到不支持的操作符 'LIKE' 应该返回 Err(DbError::UnsupportedOperator)"
    );
}
5346
/// `having_cond` accepts the `>` operator and records exactly one HAVING entry.
#[test]
fn test_having_cond_supported_operators_work() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "orders", false).having_cond("cnt", ">", 5i64);
    assert!(result.is_ok(), "having_cond 操作符 '>' 应该返回 Ok");
    let builder = result.unwrap();
    assert_eq!(builder.having_clause.len(), 1);
}
5358
/// `where_and_unchecked` with a valid operator records exactly one condition.
#[test]
fn test_where_and_unchecked_works_with_valid_operator() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let builder =
        QueryBuilder::new(&pool, "users", false).where_and_unchecked("age", "=", 18i64);
    assert_eq!(builder.conditions.len(), 1);
}
5368
/// Two consecutive `where_or_unchecked` calls leave a single entry in
/// `conditions`.
#[test]
fn test_where_or_unchecked_works_with_valid_operator() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let builder = QueryBuilder::new(&pool, "users", false)
        .where_or_unchecked("status", "=", 1i64)
        .where_or_unchecked("status", "=", 2i64);
    // NOTE(review): presumably consecutive OR conditions are merged into one
    // group entry, hence 1 rather than 2 — confirm against `where_or_unchecked`.
    assert_eq!(builder.conditions.len(), 1);
}
5380
/// `having_cond_unchecked` with a valid operator records exactly one HAVING
/// entry.
#[test]
fn test_having_cond_unchecked_works_with_valid_operator() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let builder =
        QueryBuilder::new(&pool, "orders", false).having_cond_unchecked("cnt", ">", 5i64);
    assert_eq!(builder.having_clause.len(), 1);
}
5390
/// Fallible `where_and` calls can be chained with `Result::and_then`; two
/// successful calls leave two conditions on the builder.
#[test]
fn test_where_and_chaining_with_result() {
    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference.
    let pool = create_test_pool_sync();
    let result = QueryBuilder::new(&pool, "users", false)
        .where_and("age", ">", 18i64)
        .and_then(|b| b.where_and("status", "=", 1i64));
    assert!(result.is_ok(), "链式 where_and 调用应该成功");
    let builder = result.unwrap();
    assert_eq!(builder.conditions.len(), 2);
}
5403
5404 proptest! {
5412 #![proptest_config(ProptestConfig::with_cases(500))]
5413
5414 #[test]
5415 fn prop_batch_size_chunk_count(
5416 n in 0usize..=1000,
5418 b in 1usize..=200
5420 ) {
5421 let data: Vec<u32> = (0..n as u32).collect();
5423
5424 let actual_chunk_count = if data.is_empty() {
5426 0
5427 } else {
5428 data.chunks(b).count()
5429 };
5430
5431 let expected_chunk_count = if n == 0 {
5433 0
5434 } else {
5435 n.div_ceil(b) };
5437
5438 prop_assert_eq!(
5440 actual_chunk_count,
5441 expected_chunk_count,
5442 "n={} 条记录,批次大小 b={},实际分批数 {} 应等于 ceil(n/b)={}",
5443 n, b, actual_chunk_count, expected_chunk_count
5444 );
5445
5446 for chunk in data.chunks(b) {
5448 prop_assert!(
5449 chunk.len() <= b,
5450 "每个分批的大小 {} 不应超过批次大小 {}",
5451 chunk.len(), b
5452 );
5453 }
5454
5455 let total_records: usize = data.chunks(b).map(|c| c.len()).sum();
5457 prop_assert_eq!(
5458 total_records,
5459 n,
5460 "所有分批的记录总数 {} 应等于原始记录数 {}",
5461 total_records, n
5462 );
5463 }
5464 }
5465
/// `insert_batch_with_size` called with a batch size of 0 returns
/// `DbError::SerializationError` whose message mentions `batch_size` or `0`.
#[test]
fn test_insert_batch_with_size_zero_batch_size_returns_error() {
    let rt = tokio::runtime::Runtime::new().unwrap();

    // Obtain the pool through the shared `create_test_pool_sync()` helper, as
    // the property tests in this module do, instead of forging a `&MySqlPool`
    // from uninitialized memory: a reference to an invalid value is undefined
    // behavior per the Rust Reference. The helper is called outside
    // `block_on`, the same non-async context the property tests call it from.
    let pool = create_test_pool_sync();

    rt.block_on(async {
        let builder = QueryBuilder::new(&pool, "users", false);

        let data = vec![serde_json::json!({"name": "张三"})];

        let result = builder.insert_batch_with_size(&data, 0).await;

        // A single match checks both the variant and its message.
        match result {
            Err(crate::DbError::SerializationError(msg)) => {
                assert!(
                    msg.contains("batch_size") || msg.contains("0"),
                    "错误消息应提及 batch_size 不能为 0,实际消息: {}",
                    msg
                );
            }
            // Replace the Ok payload by a marker so the outcome can be
            // debug-printed without relying on the payload type.
            other => panic!(
                "batch_size 为 0 应返回 SerializationError,实际结果: {:?}",
                other.map(|_| "Ok")
            ),
        }
    });
}
5504
/// Checks the boundary behavior of slice chunking that batch processing
/// relies on: an exact fit, fewer records than one batch, an even split, a
/// split with a remainder, and a batch size of 1.
#[test]
fn test_batch_chunk_logic_boundary_cases() {
    // Record count equals the batch size: one full batch.
    let exact: Vec<u32> = (0..10).collect();
    let exact_batches: Vec<&[u32]> = exact.chunks(10).collect();
    assert_eq!(exact_batches.len(), 1, "数据量等于批次大小时应只有 1 个分批");
    assert_eq!(exact_batches[0].len(), 10);

    // Fewer records than the batch size: one partial batch.
    let short: Vec<u32> = (0..5).collect();
    let short_batches: Vec<&[u32]> = short.chunks(10).collect();
    assert_eq!(short_batches.len(), 1, "数据量小于批次大小时应只有 1 个分批");
    assert_eq!(short_batches[0].len(), 5);

    // Record count is a multiple of the batch size: every batch is full.
    let even: Vec<u32> = (0..20).collect();
    let even_batches: Vec<&[u32]> = even.chunks(5).collect();
    assert_eq!(even_batches.len(), 4, "20 条记录按批次大小 5 分批应得到 4 个分批");
    even_batches
        .iter()
        .for_each(|batch| assert_eq!(batch.len(), 5));

    // Record count leaves a remainder: the last batch is short.
    let uneven: Vec<u32> = (0..11).collect();
    let uneven_batches: Vec<&[u32]> = uneven.chunks(5).collect();
    assert_eq!(uneven_batches.len(), 3, "11 条记录按批次大小 5 分批应得到 3 个分批");
    assert_eq!(uneven_batches[0].len(), 5);
    assert_eq!(uneven_batches[1].len(), 5);
    assert_eq!(uneven_batches[2].len(), 1, "最后一批应只有 1 条记录");

    // Batch size 1: one batch per record.
    let singles: Vec<u32> = (0..5).collect();
    let single_batches: Vec<&[u32]> = singles.chunks(1).collect();
    assert_eq!(single_batches.len(), 5, "批次大小为 1 时,分批数应等于记录数");
}
5543}