tx2_query/
builder.rs

1use crate::error::{QueryError, Result};
2use serde_json::Value;
3use std::fmt;
4
/// Comparison operators usable in a [`Condition::Simple`].
///
/// The `Display` impl renders the SQL spelling of each operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComparisonOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `NOT IN (...)`
    NotIn,
    /// `IS NULL` (takes no value)
    IsNull,
    /// `IS NOT NULL` (takes no value)
    IsNotNull,
}

impl fmt::Display for ComparisonOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sql = match self {
            ComparisonOp::Eq => "=",
            ComparisonOp::Ne => "!=",
            ComparisonOp::Lt => "<",
            ComparisonOp::Le => "<=",
            ComparisonOp::Gt => ">",
            ComparisonOp::Ge => ">=",
            ComparisonOp::Like => "LIKE",
            ComparisonOp::NotLike => "NOT LIKE",
            ComparisonOp::In => "IN",
            ComparisonOp::NotIn => "NOT IN",
            ComparisonOp::IsNull => "IS NULL",
            ComparisonOp::IsNotNull => "IS NOT NULL",
        };
        f.write_str(sql)
    }
}
40
/// Boolean connective used to join the members of a compound condition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    /// Rendered as `AND`.
    And,
    /// Rendered as `OR`.
    Or,
}

impl fmt::Display for LogicalOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            LogicalOp::And => "AND",
            LogicalOp::Or => "OR",
        };
        f.write_str(keyword)
    }
}
56
/// Direction of a single `ORDER BY` term.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Rendered as `ASC`.
    Asc,
    /// Rendered as `DESC`.
    Desc,
}

impl fmt::Display for SortDirection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            SortDirection::Asc => "ASC",
            SortDirection::Desc => "DESC",
        })
    }
}
72
/// Kind of table join; `Display` yields the full SQL keyword pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Rendered as `INNER JOIN`.
    Inner,
    /// Rendered as `LEFT JOIN`.
    Left,
    /// Rendered as `RIGHT JOIN`.
    Right,
    /// Rendered as `FULL JOIN`.
    Full,
}

impl fmt::Display for JoinType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let clause = match self {
            JoinType::Inner => "INNER JOIN",
            JoinType::Left => "LEFT JOIN",
            JoinType::Right => "RIGHT JOIN",
            JoinType::Full => "FULL JOIN",
        };
        f.write_str(clause)
    }
}
92
/// Aggregate function that can be placed in a SELECT column list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
    /// `COUNT(*)` when the column is `*` or empty, otherwise `COUNT(column)`.
    Count,
    /// `COUNT(DISTINCT column)`.
    CountDistinct,
    /// `SUM(column)`.
    Sum,
    /// `AVG(column)`.
    Avg,
    /// `MIN(column)`.
    Min,
    /// `MAX(column)`.
    Max,
}

impl AggregateFunc {
    /// Render this aggregate applied to `column`.
    ///
    /// `column` is inserted verbatim (it is not quoted or escaped).
    ///
    /// For [`AggregateFunc::Count`], a column of `*` or an empty/blank string
    /// produces `COUNT(*)`. Any other column produces `COUNT(column)`, which in
    /// SQL skips rows where that column is NULL. Previously the column was
    /// silently ignored for `Count`.
    pub fn to_sql(&self, column: &str) -> String {
        match self {
            AggregateFunc::Count if matches!(column.trim(), "" | "*") => "COUNT(*)".to_string(),
            AggregateFunc::Count => format!("COUNT({})", column),
            AggregateFunc::CountDistinct => format!("COUNT(DISTINCT {})", column),
            AggregateFunc::Sum => format!("SUM({})", column),
            AggregateFunc::Avg => format!("AVG({})", column),
            AggregateFunc::Min => format!("MIN({})", column),
            AggregateFunc::Max => format!("MAX({})", column),
        }
    }
}
116
117/// WHERE condition
118#[derive(Debug, Clone)]
119pub enum Condition {
120    Simple {
121        column: String,
122        op: ComparisonOp,
123        value: Option<Value>,
124    },
125    Compound {
126        conditions: Vec<Condition>,
127        op: LogicalOp,
128    },
129    Raw(String),
130}
131
132impl Condition {
133    pub fn to_sql(&self) -> String {
134        match self {
135            Condition::Simple { column, op, value } => {
136                if matches!(op, ComparisonOp::IsNull | ComparisonOp::IsNotNull) {
137                    format!("{} {}", column, op)
138                } else if matches!(op, ComparisonOp::In | ComparisonOp::NotIn) {
139                    if let Some(Value::Array(arr)) = value {
140                        let values = arr
141                            .iter()
142                            .map(format_value)
143                            .collect::<Vec<_>>()
144                            .join(", ");
145                        format!("{} {} ({})", column, op, values)
146                    } else {
147                        format!("{} {} ()", column, op)
148                    }
149                } else {
150                    let val = value
151                        .as_ref()
152                        .map(format_value)
153                        .unwrap_or_else(|| "NULL".to_string());
154                    format!("{} {} {}", column, op, val)
155                }
156            }
157            Condition::Compound { conditions, op } => {
158                if conditions.is_empty() {
159                    "TRUE".to_string()
160                } else {
161                    let parts = conditions
162                        .iter()
163                        .map(|c| c.to_sql())
164                        .collect::<Vec<_>>()
165                        .join(&format!(" {} ", op));
166                    format!("({})", parts)
167                }
168            }
169            Condition::Raw(sql) => sql.clone(),
170        }
171    }
172}
173
174/// Join clause
175#[derive(Debug, Clone)]
176pub struct Join {
177    pub join_type: JoinType,
178    pub table: String,
179    pub on_condition: Condition,
180}
181
182impl Join {
183    pub fn to_sql(&self) -> String {
184        format!("{} {} ON {}", self.join_type, self.table, self.on_condition.to_sql())
185    }
186}
187
188/// ORDER BY clause
189#[derive(Debug, Clone)]
190pub struct OrderBy {
191    pub column: String,
192    pub direction: SortDirection,
193}
194
195impl OrderBy {
196    pub fn to_sql(&self) -> String {
197        format!("{} {}", self.column, self.direction)
198    }
199}
200
201/// SELECT query builder
202#[derive(Debug, Clone)]
203pub struct SelectBuilder {
204    table: String,
205    columns: Vec<String>,
206    joins: Vec<Join>,
207    where_clause: Option<Condition>,
208    group_by: Vec<String>,
209    having: Option<Condition>,
210    order_by: Vec<OrderBy>,
211    limit: Option<usize>,
212    offset: Option<usize>,
213    distinct: bool,
214}
215
216impl SelectBuilder {
217    /// Create a new SELECT query builder
218    pub fn new(table: impl Into<String>) -> Self {
219        Self {
220            table: table.into(),
221            columns: vec!["*".to_string()],
222            joins: Vec::new(),
223            where_clause: None,
224            group_by: Vec::new(),
225            having: None,
226            order_by: Vec::new(),
227            limit: None,
228            offset: None,
229            distinct: false,
230        }
231    }
232
233    /// Select specific columns
234    pub fn select(mut self, columns: Vec<impl Into<String>>) -> Self {
235        self.columns = columns.into_iter().map(|c| c.into()).collect();
236        self
237    }
238
239    /// Select all columns
240    pub fn select_all(mut self) -> Self {
241        self.columns = vec!["*".to_string()];
242        self
243    }
244
245    /// Add a column to select
246    pub fn add_column(mut self, column: impl Into<String>) -> Self {
247        if self.columns == vec!["*".to_string()] {
248            self.columns.clear();
249        }
250        self.columns.push(column.into());
251        self
252    }
253
254    /// Add an aggregate function
255    pub fn aggregate(mut self, func: AggregateFunc, column: impl Into<String>, alias: Option<impl Into<String>>) -> Self {
256        if self.columns == vec!["*".to_string()] {
257            self.columns.clear();
258        }
259        let col_str = func.to_sql(&column.into());
260        if let Some(alias) = alias {
261            self.columns.push(format!("{} AS {}", col_str, alias.into()));
262        } else {
263            self.columns.push(col_str);
264        }
265        self
266    }
267
268    /// Use DISTINCT
269    pub fn distinct(mut self) -> Self {
270        self.distinct = true;
271        self
272    }
273
274    /// Add a WHERE condition
275    pub fn where_clause(mut self, condition: Condition) -> Self {
276        self.where_clause = Some(condition);
277        self
278    }
279
280    /// Add an AND condition to existing WHERE
281    pub fn and_where(mut self, condition: Condition) -> Self {
282        if let Some(existing) = self.where_clause {
283            self.where_clause = Some(Condition::Compound {
284                conditions: vec![existing, condition],
285                op: LogicalOp::And,
286            });
287        } else {
288            self.where_clause = Some(condition);
289        }
290        self
291    }
292
293    /// Add an OR condition to existing WHERE
294    pub fn or_where(mut self, condition: Condition) -> Self {
295        if let Some(existing) = self.where_clause {
296            self.where_clause = Some(Condition::Compound {
297                conditions: vec![existing, condition],
298                op: LogicalOp::Or,
299            });
300        } else {
301            self.where_clause = Some(condition);
302        }
303        self
304    }
305
306    /// Add a simple WHERE condition (column = value)
307    pub fn where_eq(self, column: impl Into<String>, value: Value) -> Self {
308        self.and_where(Condition::Simple {
309            column: column.into(),
310            op: ComparisonOp::Eq,
311            value: Some(value),
312        })
313    }
314
315    /// Add a WHERE column > value condition
316    pub fn where_gt(self, column: impl Into<String>, value: Value) -> Self {
317        self.and_where(Condition::Simple {
318            column: column.into(),
319            op: ComparisonOp::Gt,
320            value: Some(value),
321        })
322    }
323
324    /// Add a WHERE column < value condition
325    pub fn where_lt(self, column: impl Into<String>, value: Value) -> Self {
326        self.and_where(Condition::Simple {
327            column: column.into(),
328            op: ComparisonOp::Lt,
329            value: Some(value),
330        })
331    }
332
333    /// Add a WHERE column IN (...) condition
334    pub fn where_in(self, column: impl Into<String>, values: Vec<Value>) -> Self {
335        self.and_where(Condition::Simple {
336            column: column.into(),
337            op: ComparisonOp::In,
338            value: Some(Value::Array(values)),
339        })
340    }
341
342    /// Add a WHERE column IS NULL condition
343    pub fn where_null(self, column: impl Into<String>) -> Self {
344        self.and_where(Condition::Simple {
345            column: column.into(),
346            op: ComparisonOp::IsNull,
347            value: None,
348        })
349    }
350
351    /// Add a WHERE column LIKE pattern condition
352    pub fn where_like(self, column: impl Into<String>, pattern: impl Into<String>) -> Self {
353        self.and_where(Condition::Simple {
354            column: column.into(),
355            op: ComparisonOp::Like,
356            value: Some(Value::String(pattern.into())),
357        })
358    }
359
360    /// Add a JOIN clause
361    pub fn join(mut self, join_type: JoinType, table: impl Into<String>, on: Condition) -> Self {
362        self.joins.push(Join {
363            join_type,
364            table: table.into(),
365            on_condition: on,
366        });
367        self
368    }
369
370    /// Add an INNER JOIN
371    pub fn inner_join(self, table: impl Into<String>, on: Condition) -> Self {
372        self.join(JoinType::Inner, table, on)
373    }
374
375    /// Add a LEFT JOIN
376    pub fn left_join(self, table: impl Into<String>, on: Condition) -> Self {
377        self.join(JoinType::Left, table, on)
378    }
379
380    /// Add GROUP BY columns
381    pub fn group_by(mut self, columns: Vec<impl Into<String>>) -> Self {
382        self.group_by = columns.into_iter().map(|c| c.into()).collect();
383        self
384    }
385
386    /// Add a HAVING condition
387    pub fn having(mut self, condition: Condition) -> Self {
388        self.having = Some(condition);
389        self
390    }
391
392    /// Add ORDER BY
393    pub fn order_by(mut self, column: impl Into<String>, direction: SortDirection) -> Self {
394        self.order_by.push(OrderBy {
395            column: column.into(),
396            direction,
397        });
398        self
399    }
400
401    /// Add ascending ORDER BY
402    pub fn order_asc(self, column: impl Into<String>) -> Self {
403        self.order_by(column, SortDirection::Asc)
404    }
405
406    /// Add descending ORDER BY
407    pub fn order_desc(self, column: impl Into<String>) -> Self {
408        self.order_by(column, SortDirection::Desc)
409    }
410
411    /// Set LIMIT
412    pub fn limit(mut self, limit: usize) -> Self {
413        self.limit = Some(limit);
414        self
415    }
416
417    /// Set OFFSET
418    pub fn offset(mut self, offset: usize) -> Self {
419        self.offset = Some(offset);
420        self
421    }
422
423    /// Build the SQL query string
424    pub fn build(self) -> Result<String> {
425        let mut sql = String::from("SELECT ");
426
427        if self.distinct {
428            sql.push_str("DISTINCT ");
429        }
430
431        sql.push_str(&self.columns.join(", "));
432        sql.push_str(&format!(" FROM {}", self.table));
433
434        for join in &self.joins {
435            sql.push_str(" ");
436            sql.push_str(&join.to_sql());
437        }
438
439        if let Some(where_clause) = &self.where_clause {
440            sql.push_str(" WHERE ");
441            sql.push_str(&where_clause.to_sql());
442        }
443
444        if !self.group_by.is_empty() {
445            sql.push_str(" GROUP BY ");
446            sql.push_str(&self.group_by.join(", "));
447        }
448
449        if let Some(having) = &self.having {
450            sql.push_str(" HAVING ");
451            sql.push_str(&having.to_sql());
452        }
453
454        if !self.order_by.is_empty() {
455            sql.push_str(" ORDER BY ");
456            sql.push_str(
457                &self
458                    .order_by
459                    .iter()
460                    .map(|o| o.to_sql())
461                    .collect::<Vec<_>>()
462                    .join(", "),
463            );
464        }
465
466        if let Some(limit) = self.limit {
467            sql.push_str(&format!(" LIMIT {}", limit));
468        }
469
470        if let Some(offset) = self.offset {
471            sql.push_str(&format!(" OFFSET {}", offset));
472        }
473
474        Ok(sql)
475    }
476
477    /// Build and return the SQL query string (convenience method)
478    pub fn to_sql(self) -> Result<String> {
479        self.build()
480    }
481}
482
483/// UPDATE query builder
484#[derive(Debug, Clone)]
485pub struct UpdateBuilder {
486    table: String,
487    set_values: Vec<(String, Value)>,
488    where_clause: Option<Condition>,
489}
490
491impl UpdateBuilder {
492    /// Create a new UPDATE query builder
493    pub fn new(table: impl Into<String>) -> Self {
494        Self {
495            table: table.into(),
496            set_values: Vec::new(),
497            where_clause: None,
498        }
499    }
500
501    /// Set a column value
502    pub fn set(mut self, column: impl Into<String>, value: Value) -> Self {
503        self.set_values.push((column.into(), value));
504        self
505    }
506
507    /// Set multiple column values
508    pub fn set_many(mut self, values: Vec<(impl Into<String>, Value)>) -> Self {
509        for (col, val) in values {
510            self.set_values.push((col.into(), val));
511        }
512        self
513    }
514
515    /// Add WHERE condition
516    pub fn where_clause(mut self, condition: Condition) -> Self {
517        self.where_clause = Some(condition);
518        self
519    }
520
521    /// Add simple WHERE condition (column = value)
522    pub fn where_eq(mut self, column: impl Into<String>, value: Value) -> Self {
523        let condition = Condition::Simple {
524            column: column.into(),
525            op: ComparisonOp::Eq,
526            value: Some(value),
527        };
528        if let Some(existing) = self.where_clause {
529            self.where_clause = Some(Condition::Compound {
530                conditions: vec![existing, condition],
531                op: LogicalOp::And,
532            });
533        } else {
534            self.where_clause = Some(condition);
535        }
536        self
537    }
538
539    /// Build the SQL query string
540    pub fn build(self) -> Result<String> {
541        if self.set_values.is_empty() {
542            return Err(QueryError::Query("UPDATE must have at least one SET value".to_string()));
543        }
544
545        let mut sql = format!("UPDATE {} SET ", self.table);
546
547        let set_clauses: Vec<String> = self
548            .set_values
549            .iter()
550            .map(|(col, val)| format!("{} = {}", col, format_value(val)))
551            .collect();
552
553        sql.push_str(&set_clauses.join(", "));
554
555        if let Some(where_clause) = &self.where_clause {
556            sql.push_str(" WHERE ");
557            sql.push_str(&where_clause.to_sql());
558        }
559
560        Ok(sql)
561    }
562
563    /// Build and return the SQL query string (convenience method)
564    pub fn to_sql(self) -> Result<String> {
565        self.build()
566    }
567}
568
569/// DELETE query builder
570#[derive(Debug, Clone)]
571pub struct DeleteBuilder {
572    table: String,
573    where_clause: Option<Condition>,
574}
575
576impl DeleteBuilder {
577    /// Create a new DELETE query builder
578    pub fn new(table: impl Into<String>) -> Self {
579        Self {
580            table: table.into(),
581            where_clause: None,
582        }
583    }
584
585    /// Add WHERE condition
586    pub fn where_clause(mut self, condition: Condition) -> Self {
587        self.where_clause = Some(condition);
588        self
589    }
590
591    /// Add simple WHERE condition (column = value)
592    pub fn where_eq(mut self, column: impl Into<String>, value: Value) -> Self {
593        let condition = Condition::Simple {
594            column: column.into(),
595            op: ComparisonOp::Eq,
596            value: Some(value),
597        };
598        if let Some(existing) = self.where_clause {
599            self.where_clause = Some(Condition::Compound {
600                conditions: vec![existing, condition],
601                op: LogicalOp::And,
602            });
603        } else {
604            self.where_clause = Some(condition);
605        }
606        self
607    }
608
609    /// Build the SQL query string
610    pub fn build(self) -> Result<String> {
611        let mut sql = format!("DELETE FROM {}", self.table);
612
613        if let Some(where_clause) = &self.where_clause {
614            sql.push_str(" WHERE ");
615            sql.push_str(&where_clause.to_sql());
616        }
617
618        Ok(sql)
619    }
620
621    /// Build and return the SQL query string (convenience method)
622    pub fn to_sql(self) -> Result<String> {
623        self.build()
624    }
625}
626
627/// Format a JSON value for SQL
628fn format_value(value: &Value) -> String {
629    match value {
630        Value::Null => "NULL".to_string(),
631        Value::Bool(b) => b.to_string().to_uppercase(),
632        Value::Number(n) => n.to_string(),
633        Value::String(s) => format!("'{}'", s.replace('\'', "''")),
634        Value::Array(_) | Value::Object(_) => {
635            format!("'{}'", serde_json::to_string(value).unwrap_or_default().replace('\'', "''"))
636        }
637    }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer JSON value shorthand.
    fn int(n: i64) -> Value {
        Value::Number(n.into())
    }

    /// String JSON value shorthand.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    #[test]
    fn test_select_basic() {
        let sql = SelectBuilder::new("users").select_all().build().unwrap();
        assert_eq!(sql, "SELECT * FROM users");
    }

    #[test]
    fn test_select_columns() {
        let sql = SelectBuilder::new("users")
            .select(vec!["id", "name", "email"])
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_select_where() {
        let sql = SelectBuilder::new("users")
            .select_all()
            .where_eq("id", int(1))
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_select_where_multiple() {
        let sql = SelectBuilder::new("users")
            .select_all()
            .where_eq("age", int(25))
            .where_gt("score", int(100))
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT * FROM users WHERE (age = 25 AND score > 100)");
    }

    #[test]
    fn test_select_where_in() {
        let sql = SelectBuilder::new("users")
            .select_all()
            .where_in("id", vec![int(1), int(2)])
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT * FROM users WHERE id IN (1, 2)");
    }

    #[test]
    fn test_select_join() {
        let on = Condition::Raw("users.id = posts.user_id".to_string());
        let sql = SelectBuilder::new("users")
            .select(vec!["users.name", "posts.title"])
            .inner_join("posts", on)
            .build()
            .unwrap();
        assert_eq!(
            sql,
            "SELECT users.name, posts.title FROM users INNER JOIN posts ON users.id = posts.user_id"
        );
    }

    #[test]
    fn test_select_order_limit() {
        let sql = SelectBuilder::new("users")
            .select_all()
            .order_desc("created_at")
            .limit(10)
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT * FROM users ORDER BY created_at DESC LIMIT 10");
    }

    #[test]
    fn test_select_group_by() {
        let sql = SelectBuilder::new("orders")
            .add_column("user_id")
            .aggregate(AggregateFunc::Count, "*", Some("order_count"))
            .group_by(vec!["user_id"])
            .build()
            .unwrap();
        assert_eq!(
            sql,
            "SELECT user_id, COUNT(*) AS order_count FROM orders GROUP BY user_id"
        );
    }

    #[test]
    fn test_select_distinct() {
        let sql = SelectBuilder::new("users")
            .select(vec!["country"])
            .distinct()
            .build()
            .unwrap();
        assert_eq!(sql, "SELECT DISTINCT country FROM users");
    }

    #[test]
    fn test_update_basic() {
        let sql = UpdateBuilder::new("users")
            .set("name", text("Alice"))
            .where_eq("id", int(1))
            .build()
            .unwrap();
        assert_eq!(sql, "UPDATE users SET name = 'Alice' WHERE id = 1");
    }

    #[test]
    fn test_update_multiple() {
        let sql = UpdateBuilder::new("users")
            .set("name", text("Alice"))
            .set("age", int(30))
            .where_eq("id", int(1))
            .build()
            .unwrap();
        assert_eq!(sql, "UPDATE users SET name = 'Alice', age = 30 WHERE id = 1");
    }

    #[test]
    fn test_delete_basic() {
        let sql = DeleteBuilder::new("users")
            .where_eq("id", int(1))
            .build()
            .unwrap();
        assert_eq!(sql, "DELETE FROM users WHERE id = 1");
    }

    #[test]
    fn test_condition_is_null() {
        let cond = Condition::Simple {
            column: "deleted_at".to_string(),
            op: ComparisonOp::IsNull,
            value: None,
        };
        assert_eq!(cond.to_sql(), "deleted_at IS NULL");
    }

    #[test]
    fn test_condition_like() {
        let cond = Condition::Simple {
            column: "name".to_string(),
            op: ComparisonOp::Like,
            value: Some(text("%Alice%")),
        };
        assert_eq!(cond.to_sql(), "name LIKE '%Alice%'");
    }

    #[test]
    fn test_format_value_string_escaping() {
        assert_eq!(format_value(&text("O'Reilly")), "'O''Reilly'");
    }
}