sql_cli/sql/parser/
ast.rs

1//! Abstract Syntax Tree (AST) definitions for SQL queries
2//!
3//! This module contains all the data structures that represent
4//! the parsed SQL query structure.
5
6// ===== Comment Types =====
7
/// Represents a SQL comment (line or block).
///
/// Only the comment body is stored; `is_line_comment` records which of the
/// two delimiter styles the comment was written with.
#[derive(Debug, Clone, PartialEq)]
pub struct Comment {
    /// The comment text (without delimiters like -- or /* */)
    pub text: String,
    /// True for line comments (--), false for block comments (/* */)
    pub is_line_comment: bool,
}
16
17impl Comment {
18    /// Create a new line comment
19    pub fn line(text: String) -> Self {
20        Self {
21            text,
22            is_line_comment: true,
23        }
24    }
25
26    /// Create a new block comment
27    pub fn block(text: String) -> Self {
28        Self {
29            text,
30            is_line_comment: false,
31        }
32    }
33}
34
35// ===== Expression Types =====
36
/// Quote style for identifiers (column names, table names, etc.)
///
/// Records how an identifier was (or should be) delimited so it can be
/// written back out with the same delimiters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QuoteStyle {
    /// No quotes needed (valid unquoted identifier)
    None,
    /// Double quotes: "Customer Id"
    DoubleQuotes,
    /// SQL Server style brackets: [Customer Id]
    Brackets,
}
47
/// Column reference with optional quoting information and table prefix
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// The column name itself, without any quoting delimiters or prefix
    pub name: String,
    /// How the column name is quoted when rendered as SQL
    pub quote_style: QuoteStyle,
    /// Optional table/alias prefix (e.g., "messages" in "messages.field_name")
    pub table_prefix: Option<String>,
}
56
57impl ColumnRef {
58    /// Create an unquoted column reference
59    pub fn unquoted(name: String) -> Self {
60        Self {
61            name,
62            quote_style: QuoteStyle::None,
63            table_prefix: None,
64        }
65    }
66
67    /// Create a double-quoted column reference
68    pub fn quoted(name: String) -> Self {
69        Self {
70            name,
71            quote_style: QuoteStyle::DoubleQuotes,
72            table_prefix: None,
73        }
74    }
75
76    /// Create a qualified column reference (table.column)
77    pub fn qualified(table: String, name: String) -> Self {
78        Self {
79            name,
80            quote_style: QuoteStyle::None,
81            table_prefix: Some(table),
82        }
83    }
84
85    /// Get the full qualified string representation
86    pub fn to_qualified_string(&self) -> String {
87        match &self.table_prefix {
88            Some(table) => format!("{}.{}", table, self.name),
89            None => self.name.clone(),
90        }
91    }
92
93    /// Create a bracket-quoted column reference
94    pub fn bracketed(name: String) -> Self {
95        Self {
96            name,
97            quote_style: QuoteStyle::Brackets,
98            table_prefix: None,
99        }
100    }
101
102    /// Format the column reference with appropriate quoting
103    pub fn to_sql(&self) -> String {
104        let column_part = match self.quote_style {
105            QuoteStyle::None => self.name.clone(),
106            QuoteStyle::DoubleQuotes => format!("\"{}\"", self.name),
107            QuoteStyle::Brackets => format!("[{}]", self.name),
108        };
109
110        match &self.table_prefix {
111            Some(table) => format!("{}.{}", table, column_part),
112            None => column_part,
113        }
114    }
115}
116
117impl PartialEq<str> for ColumnRef {
118    fn eq(&self, other: &str) -> bool {
119        self.name == other
120    }
121}
122
123impl PartialEq<&str> for ColumnRef {
124    fn eq(&self, other: &&str) -> bool {
125        self.name == *other
126    }
127}
128
129impl std::fmt::Display for ColumnRef {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(f, "{}", self.to_sql())
132    }
133}
134
/// A node in the expression tree of a parsed SQL query.
///
/// Nested operands are boxed, so arbitrarily deep expressions (including
/// subqueries) can be represented.
#[derive(Debug, Clone)]
pub enum SqlExpression {
    /// Reference to a column
    Column(ColumnRef),
    /// String literal
    StringLiteral(String),
    /// Numeric literal, kept in string form (no numeric conversion happens here)
    NumberLiteral(String),
    /// Boolean literal
    BooleanLiteral(bool),
    /// NULL literal
    Null,
    /// Date/time built from explicit components; the time-of-day parts are optional
    DateTimeConstructor {
        year: i32,
        month: u32,
        day: u32,
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// NOTE(review): presumably "today's date" with an optional time of day —
    /// confirm against the parser/evaluator
    DateTimeToday {
        hour: Option<u32>,
        minute: Option<u32>,
        second: Option<u32>,
    },
    /// Method call whose receiver is identified by name (`object` is a plain
    /// string, not a nested expression)
    MethodCall {
        object: String,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// Method call whose receiver (`base`) is itself an expression
    ChainedMethodCall {
        base: Box<SqlExpression>,
        method: String,
        args: Vec<SqlExpression>,
    },
    /// Function call by name with its argument expressions
    FunctionCall {
        name: String,
        args: Vec<SqlExpression>,
        distinct: bool, // For COUNT(DISTINCT col), SUM(DISTINCT col), etc.
    },
    /// Function call evaluated over a window described by `window_spec`
    WindowFunction {
        name: String,
        args: Vec<SqlExpression>,
        window_spec: WindowSpec,
    },
    /// Binary operation; the operator is kept as its textual form in `op`
    BinaryOp {
        left: Box<SqlExpression>,
        op: String,
        right: Box<SqlExpression>,
    },
    /// IN test of `expr` against an explicit list of value expressions
    InList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// NOT IN test of `expr` against an explicit list of value expressions
    NotInList {
        expr: Box<SqlExpression>,
        values: Vec<SqlExpression>,
    },
    /// BETWEEN test of `expr` against a `lower` and an `upper` bound expression
    Between {
        expr: Box<SqlExpression>,
        lower: Box<SqlExpression>,
        upper: Box<SqlExpression>,
    },
    /// Negation of the inner expression
    Not {
        expr: Box<SqlExpression>,
    },
    /// CASE expression whose branches each carry their own condition
    /// (see `WhenBranch`), with an optional ELSE result
    CaseExpression {
        when_branches: Vec<WhenBranch>,
        else_branch: Option<Box<SqlExpression>>,
    },
    /// CASE expression with a subject `expr`; each branch carries a value
    /// (see `SimpleWhenBranch`), with an optional ELSE result
    SimpleCaseExpression {
        expr: Box<SqlExpression>,
        when_branches: Vec<SimpleWhenBranch>,
        else_branch: Option<Box<SqlExpression>>,
    },
    /// Scalar subquery that returns a single value
    /// Used in expressions like: WHERE col = (SELECT MAX(id) FROM table)
    ScalarSubquery {
        query: Box<SelectStatement>,
    },
    /// IN subquery that returns multiple values
    /// Used in expressions like: WHERE col IN (SELECT id FROM table WHERE ...)
    InSubquery {
        expr: Box<SqlExpression>,
        subquery: Box<SelectStatement>,
    },
    /// UNNEST - Row expansion function that splits delimited strings
    /// Used like: SELECT UNNEST(accounts, '|') AS account FROM fix_trades
    /// Causes row multiplication - one input row becomes N output rows
    Unnest {
        column: Box<SqlExpression>,
        delimiter: String,
    },
    /// NOT IN subquery
    /// Used in expressions like: WHERE col NOT IN (SELECT id FROM table WHERE ...)
    NotInSubquery {
        expr: Box<SqlExpression>,
        subquery: Box<SelectStatement>,
    },
}
230
/// One WHEN branch of `SqlExpression::CaseExpression`: a condition
/// expression paired with the result expression for that branch.
#[derive(Debug, Clone)]
pub struct WhenBranch {
    /// Condition expression of this branch
    pub condition: Box<SqlExpression>,
    /// Result expression of this branch
    pub result: Box<SqlExpression>,
}
236
/// One WHEN branch of `SqlExpression::SimpleCaseExpression`: a value
/// expression paired with the result expression for that branch.
#[derive(Debug, Clone)]
pub struct SimpleWhenBranch {
    /// Value expression of this branch
    /// (presumably compared with the CASE subject — confirm in the evaluator)
    pub value: Box<SqlExpression>,
    /// Result expression of this branch
    pub result: Box<SqlExpression>,
}
242
243// ===== WHERE Clause Types =====
244
/// WHERE clause, stored as a flat sequence of conditions.
///
/// Each `Condition` carries the connector (AND/OR) that links it to the
/// condition after it.
#[derive(Debug, Clone)]
pub struct WhereClause {
    /// The conditions, in sequence
    pub conditions: Vec<Condition>,
}
249
/// A single entry of a `WhereClause`.
#[derive(Debug, Clone)]
pub struct Condition {
    /// The condition expression
    pub expr: SqlExpression,
    /// AND/OR connecting to next condition
    pub connector: Option<LogicalOp>,
}
255
/// Logical connector between two consecutive WHERE conditions.
#[derive(Debug, Clone)]
pub enum LogicalOp {
    /// AND connector
    And,
    /// OR connector
    Or,
}
261
262// ===== ORDER BY Types =====
263
/// Sort direction of an ORDER BY entry.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending order
    Asc,
    /// Descending order
    Desc,
}
269
270impl SortDirection {
271    pub fn as_u8(&self) -> u8 {
272        match self {
273            SortDirection::Asc => 0,
274            SortDirection::Desc => 1,
275        }
276    }
277}
278
/// Legacy structure - kept for backward compatibility
/// New code should use OrderByItem
#[derive(Debug, Clone)]
pub struct OrderByColumn {
    /// Column name to sort by
    pub column: String,
    /// Sort direction for this column
    pub direction: SortDirection,
}
286
/// Modern ORDER BY item that supports expressions
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Expression to sort by
    pub expr: SqlExpression,
    /// Sort direction for this item
    pub direction: SortDirection,
}
293
294impl OrderByItem {
295    /// Create from a simple column name (for backward compatibility)
296    pub fn from_column_name(name: String, direction: SortDirection) -> Self {
297        Self {
298            expr: SqlExpression::Column(ColumnRef {
299                name,
300                quote_style: QuoteStyle::None,
301                table_prefix: None,
302            }),
303            direction,
304        }
305    }
306
307    /// Create from an expression
308    pub fn from_expression(expr: SqlExpression, direction: SortDirection) -> Self {
309        Self { expr, direction }
310    }
311}
312
313// ===== Window Function Types =====
314
/// Window frame bounds
///
/// One endpoint of a window frame. `Preceding` and `Following` carry the
/// offset written in the frame clause.
#[derive(Debug, Clone, PartialEq)]
pub enum FrameBound {
    /// UNBOUNDED PRECEDING
    UnboundedPreceding,
    /// CURRENT ROW
    CurrentRow,
    /// Offset PRECEDING
    Preceding(i64),
    /// Offset FOLLOWING
    Following(i64),
    /// UNBOUNDED FOLLOWING
    UnboundedFollowing,
}
324
/// Window frame unit (ROWS or RANGE)
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FrameUnit {
    /// ROWS frame
    Rows,
    /// RANGE frame
    Range,
}
331
332impl FrameUnit {
333    pub fn as_u8(&self) -> u8 {
334        match self {
335            FrameUnit::Rows => 0,
336            FrameUnit::Range => 1,
337        }
338    }
339}
340
/// Window frame specification
#[derive(Debug, Clone)]
pub struct WindowFrame {
    /// Whether the frame is expressed in ROWS or RANGE
    pub unit: FrameUnit,
    /// Start bound of the frame
    pub start: FrameBound,
    /// End bound of the frame; None means CURRENT ROW
    pub end: Option<FrameBound>,
}
348
/// Window specification attached to a window function: partitioning,
/// ordering, and an optional frame.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// PARTITION BY entries, stored as plain strings
    pub partition_by: Vec<String>,
    /// ORDER BY items within the window
    pub order_by: Vec<OrderByItem>,
    /// Optional window frame
    pub frame: Option<WindowFrame>,
}
355
356impl WindowSpec {
357    /// Compute a fast hash for cache key purposes
358    /// Much faster than format!("{:?}", spec) used previously
359    pub fn compute_hash(&self) -> u64 {
360        use std::collections::hash_map::DefaultHasher;
361        use std::hash::{Hash, Hasher};
362
363        let mut hasher = DefaultHasher::new();
364
365        // Hash partition_by columns
366        for col in &self.partition_by {
367            col.hash(&mut hasher);
368        }
369
370        // Hash order_by items (just the column names for simplicity)
371        for item in &self.order_by {
372            // For ORDER BY, we typically just have column references
373            // Hash a string representation for simplicity
374            format!("{:?}", item.expr).hash(&mut hasher);
375            item.direction.as_u8().hash(&mut hasher);
376        }
377
378        // Hash frame specification
379        if let Some(ref frame) = self.frame {
380            frame.unit.as_u8().hash(&mut hasher);
381            format!("{:?}", frame.start).hash(&mut hasher);
382            if let Some(ref end) = frame.end {
383                format!("{:?}", end).hash(&mut hasher);
384            }
385        }
386
387        hasher.finish()
388    }
389}
390
391// ===== SELECT Statement Types =====
392
/// Set operation type for combining SELECT statements
///
/// Per the variant notes below, only `UnionAll` is marked as implemented.
#[derive(Debug, Clone, PartialEq)]
pub enum SetOperation {
    /// UNION ALL - combines results without deduplication
    UnionAll,
    /// UNION - combines results with deduplication (not yet implemented)
    Union,
    /// INTERSECT - returns common rows (not yet implemented)
    Intersect,
    /// EXCEPT - returns rows from left not in right (not yet implemented)
    Except,
}
405
/// Represents a SELECT item - either a simple column or a computed expression with alias
///
/// Every variant also carries the comments attached to the item
/// (`leading_comments`, `trailing_comment`) so they can be preserved.
#[derive(Debug, Clone)]
pub enum SelectItem {
    /// Simple column reference: "`column_name`"
    Column {
        column: ColumnRef,
        leading_comments: Vec<Comment>,
        trailing_comment: Option<Comment>,
    },
    /// Computed expression with alias: "expr AS alias"
    Expression {
        expr: SqlExpression,
        alias: String,
        leading_comments: Vec<Comment>,
        trailing_comment: Option<Comment>,
    },
    /// Star selector: "*" or "table.*"
    Star {
        /// Table/alias prefix, e.g., Some("p") for "p.*"
        table_prefix: Option<String>,
        leading_comments: Vec<Comment>,
        trailing_comment: Option<Comment>,
    },
    /// Star with EXCLUDE: "* EXCLUDE (col1, col2)"
    StarExclude {
        table_prefix: Option<String>,
        /// Column names listed inside EXCLUDE (...)
        excluded_columns: Vec<String>,
        leading_comments: Vec<Comment>,
        trailing_comment: Option<Comment>,
    },
}
436
/// A parsed SELECT statement, including its CTEs, joins, set operations and
/// preserved comments.
///
/// The FROM source is spread over `from_table`, `from_subquery` and
/// `from_function`, each of which is optional.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// SELECT DISTINCT flag
    pub distinct: bool,
    /// Keep for backward compatibility, will be deprecated
    pub columns: Vec<String>,
    /// New field for computed expressions
    pub select_items: Vec<SelectItem>,
    /// Table name in the FROM clause, if any
    pub from_table: Option<String>,
    /// Subquery in FROM clause
    pub from_subquery: Option<Box<SelectStatement>>,
    /// Table function like RANGE() in FROM clause
    pub from_function: Option<TableFunction>,
    /// Alias for subquery (AS name)
    pub from_alias: Option<String>,
    /// JOIN clauses
    pub joins: Vec<JoinClause>,
    /// WHERE clause, if any
    pub where_clause: Option<WhereClause>,
    /// Supports expressions: columns, aggregates, CASE, etc.
    pub order_by: Option<Vec<OrderByItem>>,
    /// Changed from Vec<String> to support expressions
    pub group_by: Option<Vec<SqlExpression>>,
    /// HAVING clause for post-aggregation filtering
    pub having: Option<SqlExpression>,
    /// QUALIFY clause for window function filtering (Snowflake-style)
    pub qualify: Option<SqlExpression>,
    /// LIMIT value, if any
    pub limit: Option<usize>,
    /// OFFSET value, if any
    pub offset: Option<usize>,
    /// Common Table Expressions (WITH clause)
    pub ctes: Vec<CTE>,
    /// INTO clause for temporary tables
    pub into_table: Option<IntoTable>,
    /// UNION/INTERSECT/EXCEPT operations
    pub set_operations: Vec<(SetOperation, Box<SelectStatement>)>,

    // Comment preservation
    /// Comments before the SELECT keyword
    pub leading_comments: Vec<Comment>,
    /// Trailing comment at end of statement
    pub trailing_comment: Option<Comment>,
}
462
463impl Default for SelectStatement {
464    fn default() -> Self {
465        SelectStatement {
466            distinct: false,
467            columns: Vec::new(),
468            select_items: Vec::new(),
469            from_table: None,
470            from_subquery: None,
471            from_function: None,
472            from_alias: None,
473            joins: Vec::new(),
474            where_clause: None,
475            order_by: None,
476            group_by: None,
477            having: None,
478            qualify: None,
479            limit: None,
480            offset: None,
481            ctes: Vec::new(),
482            into_table: None,
483            set_operations: Vec::new(),
484            leading_comments: Vec::new(),
485            trailing_comment: None,
486        }
487    }
488}
489
/// INTO clause for creating temporary tables
///
/// This type only stores the name; the `#` naming rule noted on the field is
/// not enforced here.
#[derive(Debug, Clone, PartialEq)]
pub struct IntoTable {
    /// Name of the temporary table (must start with #)
    pub name: String,
}
496
497// ===== Table and Join Types =====
498
/// Table function that generates virtual tables
#[derive(Debug, Clone)]
pub enum TableFunction {
    /// A generator identified by name and called with argument expressions
    Generator {
        /// Name of the function
        name: String,
        /// Argument expressions passed to the function
        args: Vec<SqlExpression>,
    },
}
507
/// Common Table Expression (CTE) structure
#[derive(Debug, Clone)]
pub struct CTE {
    /// Name the CTE is declared under
    pub name: String,
    /// Optional column list: WITH t(col1, col2) AS ...
    pub column_list: Option<Vec<String>>,
    /// What backs the CTE: a standard query or a WEB fetch
    pub cte_type: CTEType,
}
515
/// Type of CTE - standard SQL or WEB fetch
#[derive(Debug, Clone)]
pub enum CTEType {
    /// CTE defined by a SELECT statement
    Standard(SelectStatement),
    /// CTE defined by a WEB fetch specification
    Web(WebCTESpec),
}
522
/// Specification for WEB CTEs
///
/// Describes the HTTP request to issue and how the response is to be
/// interpreted.
#[derive(Debug, Clone)]
pub struct WebCTESpec {
    /// URL to fetch
    pub url: String,
    /// CSV, JSON, or auto-detect
    pub format: Option<DataFormat>,
    /// HTTP headers
    pub headers: Vec<(String, String)>,
    /// Cache duration
    pub cache_seconds: Option<u64>,
    /// HTTP method (GET, POST, etc.)
    pub method: Option<HttpMethod>,
    /// Request body for POST/PUT
    pub body: Option<String>,
    /// JSON path to extract (e.g., "Result" for {Result: [...]})
    pub json_path: Option<String>,
    /// Multipart form files: (field_name, file_path)
    pub form_files: Vec<(String, String)>,
    /// Multipart form fields: (field_name, value)
    pub form_fields: Vec<(String, String)>,
    /// Template variables for injection from temp tables
    pub template_vars: Vec<TemplateVar>,
}
537
/// Template variable for injecting temp table data into WEB CTEs
#[derive(Debug, Clone)]
pub struct TemplateVar {
    /// Placeholder text as written, e.g., "${#instruments}"
    pub placeholder: String,
    /// Temp table referenced by the placeholder, e.g., "#instruments"
    pub table_name: String,
    /// Optional column, e.g., Some("symbol") for ${#instruments.symbol}
    pub column: Option<String>,
    /// Optional index, e.g., Some(0) for ${#instruments[0]}
    pub index: Option<usize>,
}
546
/// HTTP methods for WEB CTEs
///
/// Variant names are spelled in upper case, matching the HTTP verbs.
#[derive(Debug, Clone)]
pub enum HttpMethod {
    /// HTTP GET
    GET,
    /// HTTP POST
    POST,
    /// HTTP PUT
    PUT,
    /// HTTP DELETE
    DELETE,
    /// HTTP PATCH
    PATCH,
}
556
/// Data format for WEB CTEs
#[derive(Debug, Clone)]
pub enum DataFormat {
    /// CSV data
    CSV,
    /// JSON data
    JSON,
    /// Auto-detect from Content-Type or extension
    Auto,
}
564
/// Table source - either a file/table name or a derived table (subquery/CTE)
#[derive(Debug, Clone)]
pub enum TableSource {
    /// Regular table from CSV/JSON
    Table(String),
    /// Derived table; used for both CTE and subquery
    DerivedTable {
        /// The query that produces the table
        query: Box<SelectStatement>,
        /// Required alias for subqueries
        alias: String,
    },
}
575
/// Join type enumeration
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    /// INNER JOIN
    Inner,
    /// LEFT JOIN
    Left,
    /// RIGHT JOIN
    Right,
    /// FULL JOIN
    Full,
    /// CROSS JOIN
    Cross,
}
585
/// Join operator for join conditions
///
/// The comparison applied between the left and right expressions of a
/// `SingleJoinCondition`.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinOperator {
    /// Equality comparison
    Equal,
    /// Inequality comparison
    NotEqual,
    /// Less-than comparison
    LessThan,
    /// Greater-than comparison
    GreaterThan,
    /// Less-than-or-equal comparison
    LessThanOrEqual,
    /// Greater-than-or-equal comparison
    GreaterThanOrEqual,
}
596
/// Single join condition
///
/// One comparison of the form `left_expr <operator> right_expr`.
#[derive(Debug, Clone)]
pub struct SingleJoinCondition {
    /// Expression from left table (can be column, function call, etc.)
    pub left_expr: SqlExpression,
    /// Join operator
    pub operator: JoinOperator,
    /// Expression from right table (can be column, function call, etc.)
    pub right_expr: SqlExpression,
}
604
/// Join condition - can be multiple conditions connected by AND
#[derive(Debug, Clone)]
pub struct JoinCondition {
    /// Multiple conditions connected by AND
    pub conditions: Vec<SingleJoinCondition>,
}
610
/// Join clause structure
#[derive(Debug, Clone)]
pub struct JoinClause {
    /// Kind of join (inner, left, right, full, cross)
    pub join_type: JoinType,
    /// The table being joined
    pub table: TableSource,
    /// Optional alias for the joined table
    pub alias: Option<String>,
    /// ON condition(s)
    pub condition: JoinCondition,
}