1use crate::error::{QueryError, Result};
2use serde_json::Value;
3use std::fmt;
4
/// Comparison operator used by [`Condition::Simple`].
///
/// [`Display`](fmt::Display) renders the operator's SQL spelling, which is also
/// available without going through the formatter via [`ComparisonOp::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComparisonOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN` — the right-hand side is rendered as a parenthesized list.
    In,
    /// `NOT IN` — the right-hand side is rendered as a parenthesized list.
    NotIn,
    /// `IS NULL` — takes no right-hand operand.
    IsNull,
    /// `IS NOT NULL` — takes no right-hand operand.
    IsNotNull,
}

impl ComparisonOp {
    /// Returns the SQL spelling of this operator (e.g. `"NOT LIKE"`).
    pub fn as_str(self) -> &'static str {
        match self {
            ComparisonOp::Eq => "=",
            ComparisonOp::Ne => "!=",
            ComparisonOp::Lt => "<",
            ComparisonOp::Le => "<=",
            ComparisonOp::Gt => ">",
            ComparisonOp::Ge => ">=",
            ComparisonOp::Like => "LIKE",
            ComparisonOp::NotLike => "NOT LIKE",
            ComparisonOp::In => "IN",
            ComparisonOp::NotIn => "NOT IN",
            ComparisonOp::IsNull => "IS NULL",
            ComparisonOp::IsNotNull => "IS NOT NULL",
        }
    }
}

impl fmt::Display for ComparisonOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
40
/// Connective used to join the members of a [`Condition::Compound`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    /// Rendered as `AND`.
    And,
    /// Rendered as `OR`.
    Or,
}

impl fmt::Display for LogicalOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::And => "AND",
            Self::Or => "OR",
        };
        f.write_str(keyword)
    }
}
56
/// Ordering applied to a single `ORDER BY` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Rendered as `ASC`.
    Asc,
    /// Rendered as `DESC`.
    Desc,
}

impl fmt::Display for SortDirection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            SortDirection::Asc => "ASC",
            SortDirection::Desc => "DESC",
        })
    }
}
72
/// Kind of SQL join; [`Display`](fmt::Display) renders the full keyword pair,
/// e.g. `LEFT JOIN`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT JOIN`
    Left,
    /// `RIGHT JOIN`
    Right,
    /// `FULL JOIN`
    Full,
}

impl fmt::Display for JoinType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let qualifier = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
        };
        write!(f, "{} JOIN", qualifier)
    }
}
92
/// SQL aggregate function that can be projected via `SelectBuilder::aggregate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateFunc {
    /// `COUNT(column)`; renders `COUNT(*)` when `column` is `*` or blank.
    Count,
    /// `COUNT(DISTINCT column)`.
    CountDistinct,
    /// `SUM(column)`.
    Sum,
    /// `AVG(column)`.
    Avg,
    /// `MIN(column)`.
    Min,
    /// `MAX(column)`.
    Max,
}

impl AggregateFunc {
    /// Renders this aggregate applied to `column` as a SQL expression.
    ///
    /// `column` is interpolated verbatim (not quoted or escaped), so it must be a
    /// trusted identifier or expression.
    pub fn to_sql(&self, column: &str) -> String {
        match self {
            AggregateFunc::Count => {
                // `COUNT(col)` skips NULLs while `COUNT(*)` counts rows, so the
                // requested column must be honored; `*`/blank means "count rows".
                let target = column.trim();
                if target.is_empty() || target == "*" {
                    "COUNT(*)".to_string()
                } else {
                    format!("COUNT({})", column)
                }
            }
            AggregateFunc::CountDistinct => format!("COUNT(DISTINCT {})", column),
            AggregateFunc::Sum => format!("SUM({})", column),
            AggregateFunc::Avg => format!("AVG({})", column),
            AggregateFunc::Min => format!("MIN({})", column),
            AggregateFunc::Max => format!("MAX({})", column),
        }
    }
}
116
117#[derive(Debug, Clone)]
119pub enum Condition {
120 Simple {
121 column: String,
122 op: ComparisonOp,
123 value: Option<Value>,
124 },
125 Compound {
126 conditions: Vec<Condition>,
127 op: LogicalOp,
128 },
129 Raw(String),
130}
131
132impl Condition {
133 pub fn to_sql(&self) -> String {
134 match self {
135 Condition::Simple { column, op, value } => {
136 if matches!(op, ComparisonOp::IsNull | ComparisonOp::IsNotNull) {
137 format!("{} {}", column, op)
138 } else if matches!(op, ComparisonOp::In | ComparisonOp::NotIn) {
139 if let Some(Value::Array(arr)) = value {
140 let values = arr
141 .iter()
142 .map(format_value)
143 .collect::<Vec<_>>()
144 .join(", ");
145 format!("{} {} ({})", column, op, values)
146 } else {
147 format!("{} {} ()", column, op)
148 }
149 } else {
150 let val = value
151 .as_ref()
152 .map(format_value)
153 .unwrap_or_else(|| "NULL".to_string());
154 format!("{} {} {}", column, op, val)
155 }
156 }
157 Condition::Compound { conditions, op } => {
158 if conditions.is_empty() {
159 "TRUE".to_string()
160 } else {
161 let parts = conditions
162 .iter()
163 .map(|c| c.to_sql())
164 .collect::<Vec<_>>()
165 .join(&format!(" {} ", op));
166 format!("({})", parts)
167 }
168 }
169 Condition::Raw(sql) => sql.clone(),
170 }
171 }
172}
173
/// A single `JOIN` clause of a `SelectBuilder` query, rendered by [`Join::to_sql`]
/// as `<join type> <table> ON <condition>`.
#[derive(Debug, Clone)]
pub struct Join {
    /// Kind of join (`INNER`, `LEFT`, `RIGHT`, `FULL`).
    pub join_type: JoinType,
    /// Joined table; interpolated verbatim into the SQL, not quoted or escaped.
    pub table: String,
    /// Predicate emitted after `ON`.
    pub on_condition: Condition,
}
181
182impl Join {
183 pub fn to_sql(&self) -> String {
184 format!("{} {} ON {}", self.join_type, self.table, self.on_condition.to_sql())
185 }
186}
187
188#[derive(Debug, Clone)]
190pub struct OrderBy {
191 pub column: String,
192 pub direction: SortDirection,
193}
194
195impl OrderBy {
196 pub fn to_sql(&self) -> String {
197 format!("{} {}", self.column, self.direction)
198 }
199}
200
201#[derive(Debug, Clone)]
203pub struct SelectBuilder {
204 table: String,
205 columns: Vec<String>,
206 joins: Vec<Join>,
207 where_clause: Option<Condition>,
208 group_by: Vec<String>,
209 having: Option<Condition>,
210 order_by: Vec<OrderBy>,
211 limit: Option<usize>,
212 offset: Option<usize>,
213 distinct: bool,
214}
215
216impl SelectBuilder {
217 pub fn new(table: impl Into<String>) -> Self {
219 Self {
220 table: table.into(),
221 columns: vec!["*".to_string()],
222 joins: Vec::new(),
223 where_clause: None,
224 group_by: Vec::new(),
225 having: None,
226 order_by: Vec::new(),
227 limit: None,
228 offset: None,
229 distinct: false,
230 }
231 }
232
233 pub fn select(mut self, columns: Vec<impl Into<String>>) -> Self {
235 self.columns = columns.into_iter().map(|c| c.into()).collect();
236 self
237 }
238
239 pub fn select_all(mut self) -> Self {
241 self.columns = vec!["*".to_string()];
242 self
243 }
244
245 pub fn add_column(mut self, column: impl Into<String>) -> Self {
247 if self.columns == vec!["*".to_string()] {
248 self.columns.clear();
249 }
250 self.columns.push(column.into());
251 self
252 }
253
254 pub fn aggregate(mut self, func: AggregateFunc, column: impl Into<String>, alias: Option<impl Into<String>>) -> Self {
256 if self.columns == vec!["*".to_string()] {
257 self.columns.clear();
258 }
259 let col_str = func.to_sql(&column.into());
260 if let Some(alias) = alias {
261 self.columns.push(format!("{} AS {}", col_str, alias.into()));
262 } else {
263 self.columns.push(col_str);
264 }
265 self
266 }
267
268 pub fn distinct(mut self) -> Self {
270 self.distinct = true;
271 self
272 }
273
274 pub fn where_clause(mut self, condition: Condition) -> Self {
276 self.where_clause = Some(condition);
277 self
278 }
279
280 pub fn and_where(mut self, condition: Condition) -> Self {
282 if let Some(existing) = self.where_clause {
283 self.where_clause = Some(Condition::Compound {
284 conditions: vec![existing, condition],
285 op: LogicalOp::And,
286 });
287 } else {
288 self.where_clause = Some(condition);
289 }
290 self
291 }
292
293 pub fn or_where(mut self, condition: Condition) -> Self {
295 if let Some(existing) = self.where_clause {
296 self.where_clause = Some(Condition::Compound {
297 conditions: vec![existing, condition],
298 op: LogicalOp::Or,
299 });
300 } else {
301 self.where_clause = Some(condition);
302 }
303 self
304 }
305
306 pub fn where_eq(self, column: impl Into<String>, value: Value) -> Self {
308 self.and_where(Condition::Simple {
309 column: column.into(),
310 op: ComparisonOp::Eq,
311 value: Some(value),
312 })
313 }
314
315 pub fn where_gt(self, column: impl Into<String>, value: Value) -> Self {
317 self.and_where(Condition::Simple {
318 column: column.into(),
319 op: ComparisonOp::Gt,
320 value: Some(value),
321 })
322 }
323
324 pub fn where_lt(self, column: impl Into<String>, value: Value) -> Self {
326 self.and_where(Condition::Simple {
327 column: column.into(),
328 op: ComparisonOp::Lt,
329 value: Some(value),
330 })
331 }
332
333 pub fn where_in(self, column: impl Into<String>, values: Vec<Value>) -> Self {
335 self.and_where(Condition::Simple {
336 column: column.into(),
337 op: ComparisonOp::In,
338 value: Some(Value::Array(values)),
339 })
340 }
341
342 pub fn where_null(self, column: impl Into<String>) -> Self {
344 self.and_where(Condition::Simple {
345 column: column.into(),
346 op: ComparisonOp::IsNull,
347 value: None,
348 })
349 }
350
351 pub fn where_like(self, column: impl Into<String>, pattern: impl Into<String>) -> Self {
353 self.and_where(Condition::Simple {
354 column: column.into(),
355 op: ComparisonOp::Like,
356 value: Some(Value::String(pattern.into())),
357 })
358 }
359
360 pub fn join(mut self, join_type: JoinType, table: impl Into<String>, on: Condition) -> Self {
362 self.joins.push(Join {
363 join_type,
364 table: table.into(),
365 on_condition: on,
366 });
367 self
368 }
369
370 pub fn inner_join(self, table: impl Into<String>, on: Condition) -> Self {
372 self.join(JoinType::Inner, table, on)
373 }
374
375 pub fn left_join(self, table: impl Into<String>, on: Condition) -> Self {
377 self.join(JoinType::Left, table, on)
378 }
379
380 pub fn group_by(mut self, columns: Vec<impl Into<String>>) -> Self {
382 self.group_by = columns.into_iter().map(|c| c.into()).collect();
383 self
384 }
385
386 pub fn having(mut self, condition: Condition) -> Self {
388 self.having = Some(condition);
389 self
390 }
391
392 pub fn order_by(mut self, column: impl Into<String>, direction: SortDirection) -> Self {
394 self.order_by.push(OrderBy {
395 column: column.into(),
396 direction,
397 });
398 self
399 }
400
401 pub fn order_asc(self, column: impl Into<String>) -> Self {
403 self.order_by(column, SortDirection::Asc)
404 }
405
406 pub fn order_desc(self, column: impl Into<String>) -> Self {
408 self.order_by(column, SortDirection::Desc)
409 }
410
411 pub fn limit(mut self, limit: usize) -> Self {
413 self.limit = Some(limit);
414 self
415 }
416
417 pub fn offset(mut self, offset: usize) -> Self {
419 self.offset = Some(offset);
420 self
421 }
422
423 pub fn build(self) -> Result<String> {
425 let mut sql = String::from("SELECT ");
426
427 if self.distinct {
428 sql.push_str("DISTINCT ");
429 }
430
431 sql.push_str(&self.columns.join(", "));
432 sql.push_str(&format!(" FROM {}", self.table));
433
434 for join in &self.joins {
435 sql.push_str(" ");
436 sql.push_str(&join.to_sql());
437 }
438
439 if let Some(where_clause) = &self.where_clause {
440 sql.push_str(" WHERE ");
441 sql.push_str(&where_clause.to_sql());
442 }
443
444 if !self.group_by.is_empty() {
445 sql.push_str(" GROUP BY ");
446 sql.push_str(&self.group_by.join(", "));
447 }
448
449 if let Some(having) = &self.having {
450 sql.push_str(" HAVING ");
451 sql.push_str(&having.to_sql());
452 }
453
454 if !self.order_by.is_empty() {
455 sql.push_str(" ORDER BY ");
456 sql.push_str(
457 &self
458 .order_by
459 .iter()
460 .map(|o| o.to_sql())
461 .collect::<Vec<_>>()
462 .join(", "),
463 );
464 }
465
466 if let Some(limit) = self.limit {
467 sql.push_str(&format!(" LIMIT {}", limit));
468 }
469
470 if let Some(offset) = self.offset {
471 sql.push_str(&format!(" OFFSET {}", offset));
472 }
473
474 Ok(sql)
475 }
476
477 pub fn to_sql(self) -> Result<String> {
479 self.build()
480 }
481}
482
483#[derive(Debug, Clone)]
485pub struct UpdateBuilder {
486 table: String,
487 set_values: Vec<(String, Value)>,
488 where_clause: Option<Condition>,
489}
490
491impl UpdateBuilder {
492 pub fn new(table: impl Into<String>) -> Self {
494 Self {
495 table: table.into(),
496 set_values: Vec::new(),
497 where_clause: None,
498 }
499 }
500
501 pub fn set(mut self, column: impl Into<String>, value: Value) -> Self {
503 self.set_values.push((column.into(), value));
504 self
505 }
506
507 pub fn set_many(mut self, values: Vec<(impl Into<String>, Value)>) -> Self {
509 for (col, val) in values {
510 self.set_values.push((col.into(), val));
511 }
512 self
513 }
514
515 pub fn where_clause(mut self, condition: Condition) -> Self {
517 self.where_clause = Some(condition);
518 self
519 }
520
521 pub fn where_eq(mut self, column: impl Into<String>, value: Value) -> Self {
523 let condition = Condition::Simple {
524 column: column.into(),
525 op: ComparisonOp::Eq,
526 value: Some(value),
527 };
528 if let Some(existing) = self.where_clause {
529 self.where_clause = Some(Condition::Compound {
530 conditions: vec![existing, condition],
531 op: LogicalOp::And,
532 });
533 } else {
534 self.where_clause = Some(condition);
535 }
536 self
537 }
538
539 pub fn build(self) -> Result<String> {
541 if self.set_values.is_empty() {
542 return Err(QueryError::Query("UPDATE must have at least one SET value".to_string()));
543 }
544
545 let mut sql = format!("UPDATE {} SET ", self.table);
546
547 let set_clauses: Vec<String> = self
548 .set_values
549 .iter()
550 .map(|(col, val)| format!("{} = {}", col, format_value(val)))
551 .collect();
552
553 sql.push_str(&set_clauses.join(", "));
554
555 if let Some(where_clause) = &self.where_clause {
556 sql.push_str(" WHERE ");
557 sql.push_str(&where_clause.to_sql());
558 }
559
560 Ok(sql)
561 }
562
563 pub fn to_sql(self) -> Result<String> {
565 self.build()
566 }
567}
568
569#[derive(Debug, Clone)]
571pub struct DeleteBuilder {
572 table: String,
573 where_clause: Option<Condition>,
574}
575
576impl DeleteBuilder {
577 pub fn new(table: impl Into<String>) -> Self {
579 Self {
580 table: table.into(),
581 where_clause: None,
582 }
583 }
584
585 pub fn where_clause(mut self, condition: Condition) -> Self {
587 self.where_clause = Some(condition);
588 self
589 }
590
591 pub fn where_eq(mut self, column: impl Into<String>, value: Value) -> Self {
593 let condition = Condition::Simple {
594 column: column.into(),
595 op: ComparisonOp::Eq,
596 value: Some(value),
597 };
598 if let Some(existing) = self.where_clause {
599 self.where_clause = Some(Condition::Compound {
600 conditions: vec![existing, condition],
601 op: LogicalOp::And,
602 });
603 } else {
604 self.where_clause = Some(condition);
605 }
606 self
607 }
608
609 pub fn build(self) -> Result<String> {
611 let mut sql = format!("DELETE FROM {}", self.table);
612
613 if let Some(where_clause) = &self.where_clause {
614 sql.push_str(" WHERE ");
615 sql.push_str(&where_clause.to_sql());
616 }
617
618 Ok(sql)
619 }
620
621 pub fn to_sql(self) -> Result<String> {
623 self.build()
624 }
625}
626
627fn format_value(value: &Value) -> String {
629 match value {
630 Value::Null => "NULL".to_string(),
631 Value::Bool(b) => b.to_string().to_uppercase(),
632 Value::Number(n) => n.to_string(),
633 Value::String(s) => format!("'{}'", s.replace('\'', "''")),
634 Value::Array(_) | Value::Object(_) => {
635 format!("'{}'", serde_json::to_string(value).unwrap_or_default().replace('\'', "''"))
636 }
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_basic() {
        let query = SelectBuilder::new("users")
            .select_all()
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users");
    }

    #[test]
    fn test_select_columns() {
        let query = SelectBuilder::new("users")
            .select(vec!["id", "name", "email"])
            .build()
            .unwrap();

        assert_eq!(query, "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_select_where() {
        let query = SelectBuilder::new("users")
            .select_all()
            .where_eq("id", Value::Number(1.into()))
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_select_where_multiple() {
        let query = SelectBuilder::new("users")
            .select_all()
            .where_eq("age", Value::Number(25.into()))
            .where_gt("score", Value::Number(100.into()))
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users WHERE (age = 25 AND score > 100)");
    }

    #[test]
    fn test_select_where_in() {
        let query = SelectBuilder::new("users")
            .select_all()
            .where_in("id", vec![Value::Number(1.into()), Value::Number(2.into())])
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users WHERE id IN (1, 2)");
    }

    #[test]
    fn test_select_or_where() {
        let query = SelectBuilder::new("users")
            .where_eq("role", Value::String("admin".to_string()))
            .or_where(Condition::Simple {
                column: "id".to_string(),
                op: ComparisonOp::Eq,
                value: Some(Value::Number(1.into())),
            })
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users WHERE (role = 'admin' OR id = 1)");
    }

    #[test]
    fn test_select_join() {
        let query = SelectBuilder::new("users")
            .select(vec!["users.name", "posts.title"])
            .inner_join(
                "posts",
                Condition::Raw("users.id = posts.user_id".to_string()),
            )
            .build()
            .unwrap();

        assert_eq!(
            query,
            "SELECT users.name, posts.title FROM users INNER JOIN posts ON users.id = posts.user_id"
        );
    }

    #[test]
    fn test_select_left_join() {
        let query = SelectBuilder::new("users")
            .left_join(
                "posts",
                Condition::Raw("users.id = posts.user_id".to_string()),
            )
            .build()
            .unwrap();

        assert_eq!(
            query,
            "SELECT * FROM users LEFT JOIN posts ON users.id = posts.user_id"
        );
    }

    #[test]
    fn test_select_order_limit() {
        let query = SelectBuilder::new("users")
            .select_all()
            .order_desc("created_at")
            .limit(10)
            .build()
            .unwrap();

        assert_eq!(query, "SELECT * FROM users ORDER BY created_at DESC LIMIT 10");
    }

    #[test]
    fn test_select_multiple_order_limit_offset() {
        let query = SelectBuilder::new("users")
            .order_asc("last_name")
            .order_desc("id")
            .limit(10)
            .offset(20)
            .build()
            .unwrap();

        assert_eq!(
            query,
            "SELECT * FROM users ORDER BY last_name ASC, id DESC LIMIT 10 OFFSET 20"
        );
    }

    #[test]
    fn test_select_group_by() {
        let query = SelectBuilder::new("orders")
            .add_column("user_id")
            .aggregate(AggregateFunc::Count, "*", Some("order_count"))
            .group_by(vec!["user_id"])
            .build()
            .unwrap();

        assert_eq!(query, "SELECT user_id, COUNT(*) AS order_count FROM orders GROUP BY user_id");
    }

    #[test]
    fn test_select_count_distinct() {
        let query = SelectBuilder::new("users")
            .aggregate(AggregateFunc::CountDistinct, "email", Some("n"))
            .build()
            .unwrap();

        assert_eq!(query, "SELECT COUNT(DISTINCT email) AS n FROM users");
    }

    #[test]
    fn test_select_distinct() {
        let query = SelectBuilder::new("users")
            .select(vec!["country"])
            .distinct()
            .build()
            .unwrap();

        assert_eq!(query, "SELECT DISTINCT country FROM users");
    }

    #[test]
    fn test_update_basic() {
        let query = UpdateBuilder::new("users")
            .set("name", Value::String("Alice".to_string()))
            .where_eq("id", Value::Number(1.into()))
            .build()
            .unwrap();

        assert_eq!(query, "UPDATE users SET name = 'Alice' WHERE id = 1");
    }

    #[test]
    fn test_update_multiple() {
        let query = UpdateBuilder::new("users")
            .set("name", Value::String("Alice".to_string()))
            .set("age", Value::Number(30.into()))
            .where_eq("id", Value::Number(1.into()))
            .build()
            .unwrap();

        assert_eq!(query, "UPDATE users SET name = 'Alice', age = 30 WHERE id = 1");
    }

    #[test]
    fn test_update_requires_set_value() {
        let result = UpdateBuilder::new("users")
            .where_eq("id", Value::Number(1.into()))
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_delete_basic() {
        let query = DeleteBuilder::new("users")
            .where_eq("id", Value::Number(1.into()))
            .build()
            .unwrap();

        assert_eq!(query, "DELETE FROM users WHERE id = 1");
    }

    #[test]
    fn test_delete_without_where() {
        let query = DeleteBuilder::new("users").build().unwrap();

        assert_eq!(query, "DELETE FROM users");
    }

    #[test]
    fn test_condition_is_null() {
        let condition = Condition::Simple {
            column: "deleted_at".to_string(),
            op: ComparisonOp::IsNull,
            value: None,
        };

        assert_eq!(condition.to_sql(), "deleted_at IS NULL");
    }

    #[test]
    fn test_condition_like() {
        let condition = Condition::Simple {
            column: "name".to_string(),
            op: ComparisonOp::Like,
            value: Some(Value::String("%Alice%".to_string())),
        };

        assert_eq!(condition.to_sql(), "name LIKE '%Alice%'");
    }

    #[test]
    fn test_condition_not_in() {
        let condition = Condition::Simple {
            column: "id".to_string(),
            op: ComparisonOp::NotIn,
            value: Some(Value::Array(vec![
                Value::Number(1.into()),
                Value::Number(2.into()),
            ])),
        };

        assert_eq!(condition.to_sql(), "id NOT IN (1, 2)");
    }

    #[test]
    fn test_condition_empty_and_is_true() {
        let condition = Condition::Compound {
            conditions: Vec::new(),
            op: LogicalOp::And,
        };

        assert_eq!(condition.to_sql(), "TRUE");
    }

    #[test]
    fn test_format_value_string_escaping() {
        let value = Value::String("O'Reilly".to_string());
        assert_eq!(format_value(&value), "'O''Reilly'");
    }

    #[test]
    fn test_format_value_scalars() {
        assert_eq!(format_value(&Value::Null), "NULL");
        assert_eq!(format_value(&Value::Bool(true)), "TRUE");
        assert_eq!(format_value(&Value::Bool(false)), "FALSE");
        assert_eq!(format_value(&Value::Number(42.into())), "42");
    }

    #[test]
    fn test_format_value_array_is_quoted_json() {
        let value = Value::Array(vec![
            Value::String("a'b".to_string()),
            Value::Number(1.into()),
        ]);
        assert_eq!(format_value(&value), "'[\"a''b\",1]'");
    }
}