sql_cli/sql/parser/
ast.rs

1//! Abstract Syntax Tree (AST) definitions for SQL queries
2//!
3//! This module contains all the data structures that represent
4//! the parsed SQL query structure.
5
6// ===== Expression Types =====
7
/// Quote style for identifiers (column names, table names, etc.).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QuoteStyle {
    /// No quotes needed (valid unquoted identifier).
    None,
    /// Double quotes: `"Customer Id"`.
    DoubleQuotes,
    /// SQL Server style brackets: `[Customer Id]`.
    Brackets,
}

/// Column reference with optional quoting information.
///
/// `name` holds the bare identifier text; the surrounding delimiters are
/// re-added by [`ColumnRef::to_sql`] according to `quote_style`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Identifier text without surrounding quote characters.
    pub name: String,
    /// How the identifier is delimited when rendered back to SQL.
    pub quote_style: QuoteStyle,
}

impl ColumnRef {
    /// Create an unquoted column reference.
    pub fn unquoted(name: String) -> Self {
        Self {
            name,
            quote_style: QuoteStyle::None,
        }
    }

    /// Create a double-quoted column reference.
    pub fn quoted(name: String) -> Self {
        Self {
            name,
            quote_style: QuoteStyle::DoubleQuotes,
        }
    }

    /// Create a bracket-quoted column reference.
    pub fn bracketed(name: String) -> Self {
        Self {
            name,
            quote_style: QuoteStyle::Brackets,
        }
    }

    /// Format the column reference with appropriate quoting.
    ///
    /// A closing delimiter that occurs inside `name` is escaped by doubling it:
    /// `"` becomes `""` inside double quotes (SQL standard) and `]` becomes
    /// `]]` inside brackets (SQL Server). Without this, the rendered identifier
    /// would be terminated early and produce invalid SQL. Unquoted names are
    /// returned unchanged.
    ///
    /// NOTE(review): assumes the lexer stores `name` already un-escaped
    /// (e.g. `"a""b"` parsed to `a"b`); confirm against the tokenizer.
    pub fn to_sql(&self) -> String {
        match self.quote_style {
            QuoteStyle::None => self.name.clone(),
            QuoteStyle::DoubleQuotes => format!("\"{}\"", self.name.replace('"', "\"\"")),
            QuoteStyle::Brackets => format!("[{}]", self.name.replace(']', "]]")),
        }
    }
}

/// Compares only the identifier text; the quote style is ignored.
impl PartialEq<str> for ColumnRef {
    fn eq(&self, other: &str) -> bool {
        self.name == other
    }
}

/// Compares only the identifier text; the quote style is ignored.
impl PartialEq<&str> for ColumnRef {
    fn eq(&self, other: &&str) -> bool {
        self.name == *other
    }
}

/// Renders the reference exactly as [`ColumnRef::to_sql`] does.
impl std::fmt::Display for ColumnRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.to_sql())
    }
}
78
79#[derive(Debug, Clone)]
80pub enum SqlExpression {
81    Column(ColumnRef),
82    StringLiteral(String),
83    NumberLiteral(String),
84    BooleanLiteral(bool),
85    Null, // NULL literal
86    DateTimeConstructor {
87        year: i32,
88        month: u32,
89        day: u32,
90        hour: Option<u32>,
91        minute: Option<u32>,
92        second: Option<u32>,
93    },
94    DateTimeToday {
95        hour: Option<u32>,
96        minute: Option<u32>,
97        second: Option<u32>,
98    },
99    MethodCall {
100        object: String,
101        method: String,
102        args: Vec<SqlExpression>,
103    },
104    ChainedMethodCall {
105        base: Box<SqlExpression>,
106        method: String,
107        args: Vec<SqlExpression>,
108    },
109    FunctionCall {
110        name: String,
111        args: Vec<SqlExpression>,
112        distinct: bool, // For COUNT(DISTINCT col), SUM(DISTINCT col), etc.
113    },
114    WindowFunction {
115        name: String,
116        args: Vec<SqlExpression>,
117        window_spec: WindowSpec,
118    },
119    BinaryOp {
120        left: Box<SqlExpression>,
121        op: String,
122        right: Box<SqlExpression>,
123    },
124    InList {
125        expr: Box<SqlExpression>,
126        values: Vec<SqlExpression>,
127    },
128    NotInList {
129        expr: Box<SqlExpression>,
130        values: Vec<SqlExpression>,
131    },
132    Between {
133        expr: Box<SqlExpression>,
134        lower: Box<SqlExpression>,
135        upper: Box<SqlExpression>,
136    },
137    Not {
138        expr: Box<SqlExpression>,
139    },
140    CaseExpression {
141        when_branches: Vec<WhenBranch>,
142        else_branch: Option<Box<SqlExpression>>,
143    },
144    SimpleCaseExpression {
145        expr: Box<SqlExpression>,
146        when_branches: Vec<SimpleWhenBranch>,
147        else_branch: Option<Box<SqlExpression>>,
148    },
149    /// Scalar subquery that returns a single value
150    /// Used in expressions like: WHERE col = (SELECT MAX(id) FROM table)
151    ScalarSubquery {
152        query: Box<SelectStatement>,
153    },
154    /// IN subquery that returns multiple values  
155    /// Used in expressions like: WHERE col IN (SELECT id FROM table WHERE ...)
156    InSubquery {
157        expr: Box<SqlExpression>,
158        subquery: Box<SelectStatement>,
159    },
160    /// NOT IN subquery
161    /// Used in expressions like: WHERE col NOT IN (SELECT id FROM table WHERE ...)
162    NotInSubquery {
163        expr: Box<SqlExpression>,
164        subquery: Box<SelectStatement>,
165    },
166}
167
168#[derive(Debug, Clone)]
169pub struct WhenBranch {
170    pub condition: Box<SqlExpression>,
171    pub result: Box<SqlExpression>,
172}
173
174#[derive(Debug, Clone)]
175pub struct SimpleWhenBranch {
176    pub value: Box<SqlExpression>,
177    pub result: Box<SqlExpression>,
178}
179
180// ===== WHERE Clause Types =====
181
182#[derive(Debug, Clone)]
183pub struct WhereClause {
184    pub conditions: Vec<Condition>,
185}
186
187#[derive(Debug, Clone)]
188pub struct Condition {
189    pub expr: SqlExpression,
190    pub connector: Option<LogicalOp>, // AND/OR connecting to next condition
191}
192
/// Logical operator linking one `WHERE` condition to the next.
///
/// Derives `PartialEq`/`Eq` like the other field-less AST enums
/// (`SortDirection`, `JoinType`, ...) so callers can compare operators directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
198
199// ===== ORDER BY Types =====
200
/// Direction of an `ORDER BY` entry.
#[derive(Debug, Clone, PartialEq)]
pub enum SortDirection {
    /// Ascending order (`ASC`).
    Asc,
    /// Descending order (`DESC`).
    Desc,
}

/// One `ORDER BY` entry: the column to sort on and its direction.
#[derive(Debug, Clone)]
pub struct OrderByColumn {
    /// Name of the column to sort by.
    pub column: String,
    /// Direction in which to sort.
    pub direction: SortDirection,
}
212
213// ===== Window Function Types =====
214
/// One endpoint of a window frame.
#[derive(Debug, Clone, PartialEq)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `CURRENT ROW`
    CurrentRow,
    /// `n PRECEDING`
    Preceding(i64),
    /// `n FOLLOWING`
    Following(i64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}

/// Unit in which a window frame is measured: `ROWS` or `RANGE`.
#[derive(Debug, Clone, PartialEq)]
pub enum FrameUnit {
    /// `ROWS`
    Rows,
    /// `RANGE`
    Range,
}

/// A window frame clause: unit plus start and optional end bound.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    /// Whether the frame is counted in rows or as a range.
    pub unit: FrameUnit,
    /// Frame start bound.
    pub start: FrameBound,
    /// Frame end bound; `None` means `CURRENT ROW`.
    pub end: Option<FrameBound>,
}
239
240#[derive(Debug, Clone)]
241pub struct WindowSpec {
242    pub partition_by: Vec<String>,
243    pub order_by: Vec<OrderByColumn>,
244    pub frame: Option<WindowFrame>, // Optional window frame
245}
246
247// ===== SELECT Statement Types =====
248
249/// Represents a SELECT item - either a simple column or a computed expression with alias
250#[derive(Debug, Clone)]
251pub enum SelectItem {
252    /// Simple column reference: "`column_name`"
253    Column(ColumnRef),
254    /// Computed expression with alias: "expr AS alias"
255    Expression { expr: SqlExpression, alias: String },
256    /// Star selector: "*"
257    Star,
258}
259
260#[derive(Debug, Clone)]
261pub struct SelectStatement {
262    pub distinct: bool,                // SELECT DISTINCT flag
263    pub columns: Vec<String>,          // Keep for backward compatibility, will be deprecated
264    pub select_items: Vec<SelectItem>, // New field for computed expressions
265    pub from_table: Option<String>,
266    pub from_subquery: Option<Box<SelectStatement>>, // Subquery in FROM clause
267    pub from_function: Option<TableFunction>,        // Table function like RANGE() in FROM clause
268    pub from_alias: Option<String>,                  // Alias for subquery (AS name)
269    pub joins: Vec<JoinClause>,                      // JOIN clauses
270    pub where_clause: Option<WhereClause>,
271    pub order_by: Option<Vec<OrderByColumn>>,
272    pub group_by: Option<Vec<SqlExpression>>, // Changed from Vec<String> to support expressions
273    pub having: Option<SqlExpression>,        // HAVING clause for post-aggregation filtering
274    pub limit: Option<usize>,
275    pub offset: Option<usize>,
276    pub ctes: Vec<CTE>, // Common Table Expressions (WITH clause)
277}
278
279// ===== Table and Join Types =====
280
281/// Table function that generates virtual tables
282#[derive(Debug, Clone)]
283pub enum TableFunction {
284    Range {
285        start: SqlExpression,
286        end: SqlExpression,
287        step: Option<SqlExpression>,
288    },
289    Split {
290        text: SqlExpression,
291        delimiter: Option<SqlExpression>,
292    },
293    Generator {
294        name: String,
295        args: Vec<SqlExpression>,
296    },
297}
298
299/// Common Table Expression (CTE) structure
300#[derive(Debug, Clone)]
301pub struct CTE {
302    pub name: String,
303    pub column_list: Option<Vec<String>>, // Optional column list: WITH t(col1, col2) AS ...
304    pub cte_type: CTEType,
305}
306
307/// Type of CTE - standard SQL or WEB fetch
308#[derive(Debug, Clone)]
309pub enum CTEType {
310    Standard(SelectStatement),
311    Web(WebCTESpec),
312}
313
/// Specification for a WEB CTE.
#[derive(Debug, Clone)]
pub struct WebCTESpec {
    /// URL to fetch.
    pub url: String,
    /// Requested data format (CSV, JSON, or auto-detect), if one was given.
    pub format: Option<DataFormat>,
    /// HTTP headers to send, as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Cache duration in seconds, if any.
    pub cache_seconds: Option<u64>,
}

/// Data format for WEB CTEs.
///
/// Derives `PartialEq`/`Eq` like the other field-less AST enums, so a
/// spec's `format` can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataFormat {
    /// Comma-separated values.
    CSV,
    /// JSON.
    JSON,
    /// Auto-detect from Content-Type or extension.
    Auto,
}
330
331/// Table source - either a file/table name or a derived table (subquery/CTE)
332#[derive(Debug, Clone)]
333pub enum TableSource {
334    Table(String), // Regular table from CSV/JSON
335    DerivedTable {
336        // Both CTE and subquery
337        query: Box<SelectStatement>,
338        alias: String, // Required alias for subqueries
339    },
340}
341
/// Kind of JOIN.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT JOIN`
    Left,
    /// `RIGHT JOIN`
    Right,
    /// `FULL JOIN`
    Full,
    /// `CROSS JOIN`
    Cross,
}

/// Comparison operator used in a join condition.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinOperator {
    /// `=`
    Equal,
    /// Not equal.
    NotEqual,
    /// `<`
    LessThan,
    /// `>`
    GreaterThan,
    /// `<=`
    LessThanOrEqual,
    /// `>=`
    GreaterThanOrEqual,
}

/// Join condition of the form `left_column <operator> right_column`.
///
/// NOTE(review): earlier comments said only column equality was supported
/// initially; confirm which operators the parser/executor actually accept.
#[derive(Debug, Clone)]
pub struct JoinCondition {
    /// Column from the left table; may include a table prefix.
    pub left_column: String,
    /// Comparison applied between the two columns.
    pub operator: JoinOperator,
    /// Column from the right table; may include a table prefix.
    pub right_column: String,
}
370
371/// Join clause structure
372#[derive(Debug, Clone)]
373pub struct JoinClause {
374    pub join_type: JoinType,
375    pub table: TableSource,       // The table being joined
376    pub alias: Option<String>,    // Optional alias for the joined table
377    pub condition: JoinCondition, // ON condition
378}