Skip to main content

sqlglot_rust/ast/
types.rs

1use serde::{Deserialize, Serialize};
2
3// ═══════════════════════════════════════════════════════════════════════
4// Identifier quoting style
5// ═══════════════════════════════════════════════════════════════════════
6
/// How an identifier (column, table, alias) was quoted in the source SQL.
///
/// Used to preserve and transform quoting across dialects (e.g. backtick
/// for MySQL/BigQuery → double-quote for PostgreSQL → bracket for T-SQL).
///
/// The `Default` value is [`QuoteStyle::None`], so fields of this type that
/// are marked `#[serde(default)]` deserialize as unquoted when absent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum QuoteStyle {
    /// Bare / unquoted identifier (this is the `Default` variant)
    #[default]
    None,
    /// `"identifier"` — ANSI SQL, PostgreSQL, Oracle, Snowflake, etc.
    DoubleQuote,
    /// `` `identifier` `` — MySQL, BigQuery, Hive, Spark, etc.
    Backtick,
    /// `[identifier]` — T-SQL / SQL Server
    Bracket,
}
23
24impl QuoteStyle {
25    /// Returns the canonical quoting style for the given dialect.
26    #[must_use]
27    pub fn for_dialect(dialect: crate::dialects::Dialect) -> Self {
28        use crate::dialects::Dialect;
29        match dialect {
30            Dialect::Tsql | Dialect::Fabric => QuoteStyle::Bracket,
31            Dialect::Mysql
32            | Dialect::BigQuery
33            | Dialect::Hive
34            | Dialect::Spark
35            | Dialect::Databricks
36            | Dialect::Doris
37            | Dialect::SingleStore
38            | Dialect::StarRocks => QuoteStyle::Backtick,
39            // ANSI, Postgres, Oracle, Snowflake, Presto, Trino, etc.
40            _ => QuoteStyle::DoubleQuote,
41        }
42    }
43
44    /// Returns `true` when the identifier carries explicit quoting.
45    #[must_use]
46    pub fn is_quoted(self) -> bool {
47        !matches!(self, QuoteStyle::None)
48    }
49}
50
51// ═══════════════════════════════════════════════════════════════════════
52// Top-level statement types
53// ═══════════════════════════════════════════════════════════════════════
54
/// A fully parsed SQL statement.
///
/// Corresponds to the top-level node in sqlglot's expression tree.
///
/// Statements nest: several payload types hold further boxed `Statement`s
/// (set-operation operands, CTE bodies, the statement being explained, ...).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Statement {
    /// SELECT ...
    Select(SelectStatement),
    /// INSERT ...
    Insert(InsertStatement),
    /// UPDATE ...
    Update(UpdateStatement),
    /// DELETE ...
    Delete(DeleteStatement),
    /// CREATE TABLE ...
    CreateTable(CreateTableStatement),
    /// DROP TABLE ...
    DropTable(DropTableStatement),
    /// UNION / INTERSECT / EXCEPT between queries
    SetOperation(SetOperationStatement),
    /// ALTER TABLE ...
    AlterTable(AlterTableStatement),
    /// CREATE VIEW ...
    CreateView(CreateViewStatement),
    /// DROP VIEW ...
    DropView(DropViewStatement),
    /// TRUNCATE TABLE ...
    Truncate(TruncateStatement),
    /// BEGIN / COMMIT / ROLLBACK
    Transaction(TransactionStatement),
    /// EXPLAIN <statement>
    Explain(ExplainStatement),
    /// USE database
    Use(UseStatement),
    /// Raw / passthrough expression (for expressions that don't fit a specific statement type)
    Expression(Expr),
}
85
86// ═══════════════════════════════════════════════════════════════════════
87// SELECT
88// ═══════════════════════════════════════════════════════════════════════
89
/// A SELECT statement, including CTEs.
///
/// Aligned with sqlglot's `Select` expression which wraps `With`, `From`,
/// `Where`, `Group`, `Having`, `Order`, `Limit`, `Offset`, `Window`.
///
/// Optional clauses are `Option`s; list-valued clauses are `Vec`s
/// (presumably empty when the clause is absent — confirm against the parser).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// Common Table Expressions (WITH clause)
    pub ctes: Vec<Cte>,
    /// DISTINCT flag
    pub distinct: bool,
    /// TOP N (TSQL-style)
    pub top: Option<Box<Expr>>,
    /// Select-list items
    pub columns: Vec<SelectItem>,
    /// FROM clause
    pub from: Option<FromClause>,
    /// JOIN clauses
    pub joins: Vec<JoinClause>,
    /// WHERE predicate
    pub where_clause: Option<Expr>,
    /// GROUP BY expressions
    pub group_by: Vec<Expr>,
    /// HAVING predicate
    pub having: Option<Expr>,
    /// ORDER BY items
    pub order_by: Vec<OrderByItem>,
    /// LIMIT expression
    pub limit: Option<Expr>,
    /// OFFSET expression
    pub offset: Option<Expr>,
    /// Oracle-style FETCH FIRST n ROWS ONLY
    pub fetch_first: Option<Expr>,
    /// QUALIFY clause (BigQuery, Snowflake)
    pub qualify: Option<Expr>,
    /// Named WINDOW definitions
    pub window_definitions: Vec<WindowDefinition>,
}
117
/// A Common Table Expression: `name [(col1, col2)] AS [NOT] MATERIALIZED (query)`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Cte {
    /// The CTE's `name`
    pub name: String,
    /// The optional `(col1, col2)` column list
    pub columns: Vec<String>,
    /// The parenthesized `query` body
    pub query: Box<Statement>,
    /// The `[NOT] MATERIALIZED` hint. Presumably `Some(true)` = `MATERIALIZED`,
    /// `Some(false)` = `NOT MATERIALIZED`, `None` = not specified — TODO confirm.
    pub materialized: Option<bool>,
    /// NOTE(review): presumably set for `WITH RECURSIVE` — confirm against the parser.
    pub recursive: bool,
}
127
/// Named WINDOW definition: `window_name AS (window_spec)`
///
/// Stored in `SelectStatement::window_definitions`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowDefinition {
    /// The `window_name` part
    pub name: String,
    /// The `window_spec` part
    pub spec: WindowSpec,
}
134
135// ═══════════════════════════════════════════════════════════════════════
136// Set operations (UNION, INTERSECT, EXCEPT)
137// ═══════════════════════════════════════════════════════════════════════
138
/// UNION / INTERSECT / EXCEPT between two or more queries.
///
/// The struct itself holds exactly two operands; since each operand is a
/// `Statement`, longer chains can be expressed by nesting another
/// `Statement::SetOperation` in `left` or `right`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetOperationStatement {
    /// Which set operation is applied
    pub op: SetOperationType,
    /// ALL modifier flag (e.g. `UNION ALL`) — presumably `false` means DISTINCT semantics
    pub all: bool,
    /// Left-hand operand
    pub left: Box<Statement>,
    /// Right-hand operand
    pub right: Box<Statement>,
    /// ORDER BY items (presumably applying to the combined result — TODO confirm)
    pub order_by: Vec<OrderByItem>,
    /// LIMIT expression
    pub limit: Option<Expr>,
    /// OFFSET expression
    pub offset: Option<Expr>,
}
150
/// The kind of set operation used by [`SetOperationStatement`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetOperationType {
    /// UNION
    Union,
    /// INTERSECT
    Intersect,
    /// EXCEPT
    Except,
}
157
158// ═══════════════════════════════════════════════════════════════════════
159// SELECT items and FROM
160// ═══════════════════════════════════════════════════════════════════════
161
/// An item in a SELECT list.
///
/// Also used for the `returning` lists of INSERT / UPDATE / DELETE statements.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectItem {
    /// `*`
    Wildcard,
    /// `table.*` — `table` is the qualifier in front of `.*`
    QualifiedWildcard { table: String },
    /// An expression with optional alias: `expr AS alias`
    Expr { expr: Expr, alias: Option<String> },
}
172
/// A FROM clause.
///
/// Holds a single [`TableSource`], which may be a plain table, a subquery,
/// a table function, a LATERAL source or an UNNEST. Additional sources are
/// represented elsewhere (e.g. as `JoinClause`s on the owning statement).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FromClause {
    /// The source selected from
    pub source: TableSource,
}
178
/// A table source can be a table reference, subquery, or table function.
///
/// Used by [`FromClause`] and [`JoinClause`]. The type is recursive through
/// `Lateral`, and through `Subquery` via the boxed [`Statement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableSource {
    /// A plain table reference
    Table(TableRef),
    /// A subquery used as a source, with an optional alias
    Subquery {
        query: Box<Statement>,
        alias: Option<String>,
    },
    /// A table function call `name(args...)`, with an optional alias
    TableFunction {
        name: String,
        args: Vec<Expr>,
        alias: Option<String>,
    },
    /// LATERAL subquery or function
    Lateral {
        /// The wrapped source
        source: Box<TableSource>,
    },
    /// UNNEST(array_expr)
    Unnest {
        /// The expression being unnested
        expr: Box<Expr>,
        alias: Option<String>,
        /// NOTE(review): presumably set for `WITH OFFSET` — confirm against the parser.
        with_offset: bool,
    },
}
203
/// A reference to a table.
///
/// Only `name` carries a quoting style; `catalog`, `schema` and `alias`
/// have no quote-style field in this struct.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TableRef {
    /// Optional catalog qualifier
    pub catalog: Option<String>,
    /// Optional schema qualifier
    pub schema: Option<String>,
    /// The table name itself
    pub name: String,
    /// Optional alias
    pub alias: Option<String>,
    /// How the table name was quoted in the source SQL.
    /// Falls back to `QuoteStyle::default()` when missing on deserialization.
    #[serde(default)]
    pub name_quote_style: QuoteStyle,
}
215
/// A JOIN clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// The kind of join
    pub join_type: JoinType,
    /// The joined source (any [`TableSource`], not only plain tables)
    pub table: TableSource,
    /// ON condition, if any
    pub on: Option<Expr>,
    /// USING column names (presumably empty when no USING clause is present)
    pub using: Vec<String>,
}
224
/// The type of JOIN.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum JoinType {
    /// INNER JOIN
    Inner,
    /// LEFT JOIN
    Left,
    /// RIGHT JOIN
    Right,
    /// FULL JOIN
    Full,
    /// CROSS JOIN
    Cross,
    /// NATURAL JOIN
    Natural,
    /// LATERAL JOIN
    Lateral,
}
238
/// An ORDER BY item.
///
/// Used by SELECT statements, set operations and window specifications.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    /// The expression to sort by
    pub expr: Expr,
    /// Sort direction — presumably `true` = ASC, `false` = DESC; TODO confirm
    /// what the parser stores when no direction is written.
    pub ascending: bool,
    /// NULLS FIRST / NULLS LAST
    /// (presumably `Some(true)` = FIRST, `Some(false)` = LAST, `None` = unspecified)
    pub nulls_first: Option<bool>,
}
247
248// ═══════════════════════════════════════════════════════════════════════
249// Expressions (the core of the AST)
250// ═══════════════════════════════════════════════════════════════════════
251
/// An expression in SQL.
///
/// This enum is aligned with sqlglot's Expression class hierarchy.
/// Key additions over the basic implementation:
/// - Subquery, Exists, Cast, Extract, Window functions
/// - Interval, Array/Struct constructors
/// - Postgres-style casting (`::`)
///
/// The type is recursive: child expressions are held in `Box`/`Vec`, and the
/// subquery-bearing variants hold a boxed [`Statement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expr {
    /// A column reference, optionally qualified by a table: `[table.]column`.
    ///
    /// Only a single `table` qualifier is stored; there are no separate
    /// catalog/schema fields on this variant.
    Column {
        table: Option<String>,
        name: String,
        /// How the column name was quoted in the source SQL.
        #[serde(default)]
        quote_style: QuoteStyle,
        /// How the table qualifier was quoted, if present.
        #[serde(default)]
        table_quote_style: QuoteStyle,
    },
    /// A numeric literal, kept as its textual form.
    Number(String),
    /// A string literal.
    StringLiteral(String),
    /// A boolean literal.
    Boolean(bool),
    /// NULL literal.
    Null,
    /// A binary operation: `left op right`
    BinaryOp {
        left: Box<Expr>,
        op: BinaryOperator,
        right: Box<Expr>,
    },
    /// A unary operation: `op expr`
    UnaryOp { op: UnaryOperator, expr: Box<Expr> },
    /// A function call: `name(args...)` with optional DISTINCT, ORDER BY, etc.
    Function {
        name: String,
        args: Vec<Expr>,
        distinct: bool,
        /// FILTER (WHERE expr) clause on aggregate
        filter: Option<Box<Expr>>,
        /// OVER window specification for window functions
        over: Option<WindowSpec>,
    },
    /// `expr BETWEEN low AND high`
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr IN (list...)` — see `InSubquery` for the `IN (SELECT ...)` form
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `expr IN (SELECT ...)`
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<Statement>,
        negated: bool,
    },
    /// `expr op ANY(subexpr)` — PostgreSQL array/subquery comparison
    AnyOp {
        expr: Box<Expr>,
        op: BinaryOperator,
        right: Box<Expr>,
    },
    /// `expr op ALL(subexpr)` — PostgreSQL array/subquery comparison
    AllOp {
        expr: Box<Expr>,
        op: BinaryOperator,
        right: Box<Expr>,
    },
    /// `expr IS [NOT] NULL`
    IsNull { expr: Box<Expr>, negated: bool },
    /// `expr IS [NOT] TRUE` / `expr IS [NOT] FALSE`
    IsBool {
        expr: Box<Expr>,
        /// The boolean being tested against (TRUE or FALSE)
        value: bool,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape_char]`
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        negated: bool,
        escape: Option<Box<Expr>>,
    },
    /// `expr [NOT] ILIKE pattern [ESCAPE escape_char]` (case-insensitive LIKE)
    ILike {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        negated: bool,
        escape: Option<Box<Expr>>,
    },
    /// `CASE [operand] WHEN ... THEN ... ELSE ... END`
    Case {
        operand: Option<Box<Expr>>,
        /// `(WHEN condition, THEN result)` pairs
        when_clauses: Vec<(Expr, Expr)>,
        else_clause: Option<Box<Expr>>,
    },
    /// A parenthesized sub-expression.
    Nested(Box<Expr>),
    /// A wildcard `*` used in contexts like `COUNT(*)`.
    /// NOTE(review): `Star` below is also documented as `*`; which one the
    /// parser emits in which context is not visible from this file.
    Wildcard,
    /// A scalar subquery: `(SELECT ...)`
    Subquery(Box<Statement>),
    /// `EXISTS (SELECT ...)`
    Exists {
        subquery: Box<Statement>,
        negated: bool,
    },
    /// `CAST(expr AS type)` or `expr::type` (PostgreSQL)
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// `TRY_CAST(expr AS type)`
    TryCast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// `EXTRACT(field FROM expr)`
    Extract {
        field: DateTimeField,
        expr: Box<Expr>,
    },
    /// `INTERVAL 'value' unit`
    Interval {
        value: Box<Expr>,
        unit: Option<DateTimeField>,
    },
    /// Array literal: `ARRAY[1, 2, 3]` or `[1, 2, 3]`
    ArrayLiteral(Vec<Expr>),
    /// Struct literal / row constructor: `(1, 'a', true)`
    Tuple(Vec<Expr>),
    /// `COALESCE(a, b, c)`
    Coalesce(Vec<Expr>),
    /// `IF(condition, true_val, false_val)` (MySQL, BigQuery)
    If {
        condition: Box<Expr>,
        true_val: Box<Expr>,
        false_val: Option<Box<Expr>>,
    },
    /// `NULLIF(a, b)` — presumably `expr` is `a` and `r#else` is `b`
    NullIf { expr: Box<Expr>, r#else: Box<Expr> },
    /// `expr COLLATE collation`
    Collate { expr: Box<Expr>, collation: String },
    /// Parameter / placeholder: `$1`, `?`, `:name`
    Parameter(String),
    /// A type expression used in DDL contexts or CAST
    TypeExpr(DataType),
    /// `table.*` in expression context
    QualifiedWildcard { table: String },
    /// Star expression `*`
    Star,
    /// Alias expression: `expr AS name`
    Alias { expr: Box<Expr>, name: String },
    /// Array access: `expr[index]`
    ArrayIndex { expr: Box<Expr>, index: Box<Expr> },
    /// JSON access: `expr->key` or `expr->>key`
    JsonAccess {
        expr: Box<Expr>,
        path: Box<Expr>,
        /// false = ->, true = ->>
        as_text: bool,
    },
    /// Lambda expression: `x -> x + 1`
    Lambda {
        /// Parameter names on the left of `->`
        params: Vec<String>,
        /// Expression on the right of `->`
        body: Box<Expr>,
    },
    /// `DEFAULT` keyword in INSERT/UPDATE contexts
    Default,
}
431
432// ═══════════════════════════════════════════════════════════════════════
433// Window specification
434// ═══════════════════════════════════════════════════════════════════════
435
/// Window specification for window functions: OVER (PARTITION BY ... ORDER BY ... frame)
///
/// Used both inline (`Expr::Function::over`) and by named
/// [`WindowDefinition`]s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowSpec {
    /// Reference to a named window
    pub window_ref: Option<String>,
    /// PARTITION BY expressions
    pub partition_by: Vec<Expr>,
    /// ORDER BY items
    pub order_by: Vec<OrderByItem>,
    /// Optional frame clause
    pub frame: Option<WindowFrame>,
}
445
/// The frame clause of a [`WindowSpec`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowFrame {
    /// Frame unit (ROWS / RANGE / GROUPS)
    pub kind: WindowFrameKind,
    /// Start bound of the frame
    pub start: WindowFrameBound,
    /// End bound; presumably `None` when the frame is written without
    /// `BETWEEN ... AND ...` — TODO confirm against the parser.
    pub end: Option<WindowFrameBound>,
}
452
/// The unit a [`WindowFrame`] is expressed in.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WindowFrameKind {
    /// ROWS
    Rows,
    /// RANGE
    Range,
    /// GROUPS
    Groups,
}
459
/// One bound (start or end) of a [`WindowFrame`].
///
/// For `Preceding` / `Following`, a `Some` payload is the offset expression
/// and `None` stands for the UNBOUNDED form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WindowFrameBound {
    /// CURRENT ROW
    CurrentRow,
    Preceding(Option<Box<Expr>>), // None = UNBOUNDED PRECEDING
    Following(Option<Box<Expr>>), // None = UNBOUNDED FOLLOWING
}
466
467// ═══════════════════════════════════════════════════════════════════════
468// Date/time fields (for EXTRACT, INTERVAL)
469// ═══════════════════════════════════════════════════════════════════════
470
/// A date/time field name.
///
/// Used as the `field` of `Expr::Extract` and the optional `unit` of
/// `Expr::Interval`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DateTimeField {
    Year,
    Quarter,
    Month,
    Week,
    Day,
    DayOfWeek,
    DayOfYear,
    Hour,
    Minute,
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
    Epoch,
    Timezone,
    TimezoneHour,
    TimezoneMinute,
}
491
492// ═══════════════════════════════════════════════════════════════════════
493// Operators
494// ═══════════════════════════════════════════════════════════════════════
495
/// Binary operators.
///
/// Used by `Expr::BinaryOp` and as the comparison operator of
/// `Expr::AnyOp` / `Expr::AllOp`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,
    // Comparison
    Eq,
    Neq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Logical
    And,
    Or,
    Xor,
    // String concatenation
    Concat,
    // Bitwise
    BitwiseAnd,
    BitwiseOr,
    BitwiseXor,
    ShiftLeft,
    ShiftRight,
    /// `->` JSON access operator
    /// NOTE(review): `Expr::JsonAccess` also models `->` / `->>`; which
    /// representation the parser produces is not visible from this file.
    Arrow,
    /// `->>` JSON text access
    DoubleArrow,
}
524
/// Unary operators, used by `Expr::UnaryOp` (`op expr`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Logical negation
    Not,
    /// Arithmetic negation
    Minus,
    /// Unary plus
    Plus,
    /// Bitwise complement
    BitwiseNot,
}
533
534// ═══════════════════════════════════════════════════════════════════════
535// DML statements
536// ═══════════════════════════════════════════════════════════════════════
537
/// An INSERT statement, including `INSERT ... SELECT` and ON CONFLICT forms.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InsertStatement {
    /// Target table
    pub table: TableRef,
    /// Target column names (presumably empty when no column list is written)
    pub columns: Vec<String>,
    /// Where the inserted rows come from
    pub source: InsertSource,
    /// ON CONFLICT / ON DUPLICATE KEY
    pub on_conflict: Option<OnConflict>,
    /// RETURNING clause
    pub returning: Vec<SelectItem>,
}
549
/// The row source of an [`InsertStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InsertSource {
    /// Literal rows — presumably one inner `Vec` per `VALUES` tuple
    Values(Vec<Vec<Expr>>),
    /// Rows produced by a query (`INSERT ... SELECT`)
    Query(Box<Statement>),
    /// NOTE(review): presumably `DEFAULT VALUES` — confirm against the parser.
    Default,
}
556
/// The conflict-handling clause of an [`InsertStatement`]
/// (ON CONFLICT / ON DUPLICATE KEY).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnConflict {
    /// Column names attached to the clause — presumably the conflict target
    pub columns: Vec<String>,
    /// What to do when a conflict occurs
    pub action: ConflictAction,
}
562
/// The action taken by an [`OnConflict`] clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConflictAction {
    /// Skip the conflicting row
    DoNothing,
    /// Update instead; the pairs have the same shape as
    /// `UpdateStatement::assignments` (presumably `(column, value)`).
    DoUpdate(Vec<(String, Expr)>),
}
568
/// An UPDATE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UpdateStatement {
    /// Table being updated
    pub table: TableRef,
    /// SET assignments — presumably `(column name, new value)` pairs
    pub assignments: Vec<(String, Expr)>,
    /// Optional FROM clause
    pub from: Option<FromClause>,
    /// Optional WHERE predicate
    pub where_clause: Option<Expr>,
    /// RETURNING items
    pub returning: Vec<SelectItem>,
}
578
/// A DELETE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeleteStatement {
    /// Table rows are deleted from
    pub table: TableRef,
    /// Optional USING clause (stored with the same type as a FROM clause)
    pub using: Option<FromClause>,
    /// Optional WHERE predicate
    pub where_clause: Option<Expr>,
    /// RETURNING items
    pub returning: Vec<SelectItem>,
}
587
588// ═══════════════════════════════════════════════════════════════════════
589// DDL statements
590// ═══════════════════════════════════════════════════════════════════════
591
/// A CREATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateTableStatement {
    /// IF NOT EXISTS flag
    pub if_not_exists: bool,
    /// TEMPORARY flag
    pub temporary: bool,
    /// Table being created
    pub table: TableRef,
    /// Column definitions
    pub columns: Vec<ColumnDef>,
    /// Table-level constraints
    pub constraints: Vec<TableConstraint>,
    /// CREATE TABLE ... AS SELECT ...
    pub as_select: Option<Box<Statement>>,
}
603
/// Table-level constraints.
///
/// Every variant carries an optional constraint `name`. Used by
/// `CreateTableStatement::constraints` and `AlterTableAction::AddConstraint`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableConstraint {
    /// PRIMARY KEY over `columns`
    PrimaryKey {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// UNIQUE over `columns`
    Unique {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// FOREIGN KEY from `columns` to `ref_columns` of `ref_table`
    ForeignKey {
        name: Option<String>,
        columns: Vec<String>,
        ref_table: TableRef,
        ref_columns: Vec<String>,
        /// ON DELETE action, if specified
        on_delete: Option<ReferentialAction>,
        /// ON UPDATE action, if specified
        on_update: Option<ReferentialAction>,
    },
    /// CHECK (`expr`)
    Check {
        name: Option<String>,
        expr: Expr,
    },
}
628
/// Referential action of a foreign key, used for the `on_delete` /
/// `on_update` fields of `TableConstraint::ForeignKey`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReferentialAction {
    /// CASCADE
    Cascade,
    /// RESTRICT
    Restrict,
    /// NO ACTION
    NoAction,
    /// SET NULL
    SetNull,
    /// SET DEFAULT
    SetDefault,
}
637
/// A column definition in CREATE TABLE.
///
/// Also used by `AlterTableAction::AddColumn`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name
    pub name: String,
    /// Declared data type
    pub data_type: DataType,
    /// Nullability — presumably `Some(false)` = NOT NULL, `Some(true)` = NULL,
    /// `None` = not specified; TODO confirm against the parser.
    pub nullable: Option<bool>,
    /// DEFAULT expression, if any
    pub default: Option<Expr>,
    /// Column-level PRIMARY KEY flag
    pub primary_key: bool,
    /// Column-level UNIQUE flag
    pub unique: bool,
    /// Auto-increment flag
    pub auto_increment: bool,
    /// Collation name, if any
    pub collation: Option<String>,
    /// Column comment, if any
    pub comment: Option<String>,
}
651
/// ALTER TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AlterTableStatement {
    /// Table being altered
    pub table: TableRef,
    /// Alterations carried by the statement
    pub actions: Vec<AlterTableAction>,
}
658
/// A single alteration within an [`AlterTableStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AlterTableAction {
    /// Add a column with the given definition
    AddColumn(ColumnDef),
    /// Drop column `name`, optionally guarded by IF EXISTS
    DropColumn { name: String, if_exists: bool },
    /// Rename column `old_name` to `new_name`
    RenameColumn { old_name: String, new_name: String },
    /// Change the data type of column `name`
    AlterColumnType { name: String, data_type: DataType },
    /// Add a table-level constraint
    AddConstraint(TableConstraint),
    /// Drop the constraint called `name`
    DropConstraint { name: String },
    /// Rename the table to `new_name`
    RenameTable { new_name: String },
}
669
/// CREATE VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateViewStatement {
    /// Name of the view (stored as a table reference)
    pub name: TableRef,
    /// View column names (presumably empty when no column list is written)
    pub columns: Vec<String>,
    /// The query defining the view
    pub query: Box<Statement>,
    /// OR REPLACE flag
    pub or_replace: bool,
    /// MATERIALIZED flag
    pub materialized: bool,
    /// IF NOT EXISTS flag
    pub if_not_exists: bool,
}
680
/// DROP VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropViewStatement {
    /// Name of the view (stored as a table reference)
    pub name: TableRef,
    /// IF EXISTS flag
    pub if_exists: bool,
    /// MATERIALIZED flag
    pub materialized: bool,
}
688
/// TRUNCATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TruncateStatement {
    /// Table being truncated
    pub table: TableRef,
}
694
/// Transaction control statements.
///
/// The `String` payloads presumably carry the savepoint name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransactionStatement {
    /// BEGIN
    Begin,
    /// COMMIT
    Commit,
    /// ROLLBACK
    Rollback,
    /// SAVEPOINT
    Savepoint(String),
    /// RELEASE SAVEPOINT
    ReleaseSavepoint(String),
    /// ROLLBACK TO
    RollbackTo(String),
}
705
/// EXPLAIN statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExplainStatement {
    /// ANALYZE flag
    pub analyze: bool,
    /// The statement being explained
    pub statement: Box<Statement>,
}
712
/// USE database statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UseStatement {
    /// Name following `USE` (kept as a plain string, without quoting info)
    pub name: String,
}
718
/// A DROP TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropTableStatement {
    /// IF EXISTS flag
    pub if_exists: bool,
    /// Table being dropped
    pub table: TableRef,
    /// CASCADE flag
    pub cascade: bool,
}
726
727// ═══════════════════════════════════════════════════════════════════════
728// Data types
729// ═══════════════════════════════════════════════════════════════════════
730
/// SQL data types. Significantly expanded to match sqlglot's DataType.Type enum.
///
/// Used by casts (`Expr::Cast` / `Expr::TryCast`), `Expr::TypeExpr`, column
/// definitions and `ALTER COLUMN ... TYPE`. The type is recursive through the
/// complex variants (`Array`, `Map`, `Struct`, `Tuple`).
///
/// `Option<u32>` payloads presumably hold the optional length / precision
/// argument when one was written — TODO confirm against the parser.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    // Numeric
    TinyInt,
    SmallInt,
    Int,
    BigInt,
    Float,
    Double,
    Decimal {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    Numeric {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    Real,

    // String
    Varchar(Option<u32>),
    Char(Option<u32>),
    Text,
    String,
    Binary(Option<u32>),
    Varbinary(Option<u32>),

    // Boolean
    Boolean,

    // Date/Time
    Date,
    Time {
        precision: Option<u32>,
    },
    Timestamp {
        precision: Option<u32>,
        /// Time-zone flag (presumably `WITH TIME ZONE`)
        with_tz: bool,
    },
    Interval,
    DateTime,

    // Binary
    Blob,
    Bytea,
    Bytes,

    // JSON
    Json,
    Jsonb,

    // UUID
    Uuid,

    // Complex types
    /// Array with an optional element type
    Array(Option<Box<DataType>>),
    /// Map from `key` type to `value` type
    Map {
        key: Box<DataType>,
        value: Box<DataType>,
    },
    /// Struct as `(field name, field type)` pairs
    Struct(Vec<(String, DataType)>),
    /// Tuple of unnamed element types
    Tuple(Vec<DataType>),

    // Special
    Null,
    /// NOTE(review): presumably the raw name of a type that matched no other
    /// variant — confirm against the parser.
    Unknown(String),
    Variant,
    Object,
    Xml,
    Inet,
    Cidr,
    Macaddr,
    Bit(Option<u32>),
    Money,
    Serial,
    BigSerial,
    SmallSerial,
    Regclass,
    Regtype,
    Hstore,
    Geography,
    Geometry,
    Super,
}
816
817// ═══════════════════════════════════════════════════════════════════════
818// Expression tree traversal helpers
819// ═══════════════════════════════════════════════════════════════════════
820
821impl Expr {
822    /// Recursively walk this expression tree, calling `visitor` on each node.
823    /// If `visitor` returns `false`, children of that node are not visited.
824    pub fn walk<F>(&self, visitor: &mut F)
825    where
826        F: FnMut(&Expr) -> bool,
827    {
828        if !visitor(self) {
829            return;
830        }
831        match self {
832            Expr::BinaryOp { left, right, .. } => {
833                left.walk(visitor);
834                right.walk(visitor);
835            }
836            Expr::UnaryOp { expr, .. } => expr.walk(visitor),
837            Expr::Function { args, filter, .. } => {
838                for arg in args {
839                    arg.walk(visitor);
840                }
841                if let Some(f) = filter {
842                    f.walk(visitor);
843                }
844            }
845            Expr::Between {
846                expr, low, high, ..
847            } => {
848                expr.walk(visitor);
849                low.walk(visitor);
850                high.walk(visitor);
851            }
852            Expr::InList { expr, list, .. } => {
853                expr.walk(visitor);
854                for item in list {
855                    item.walk(visitor);
856                }
857            }
858            Expr::InSubquery { expr, .. } => {
859                expr.walk(visitor);
860            }
861            Expr::IsNull { expr, .. } => expr.walk(visitor),
862            Expr::IsBool { expr, .. } => expr.walk(visitor),
863            Expr::AnyOp { expr, right, .. } | Expr::AllOp { expr, right, .. } => {
864                expr.walk(visitor);
865                right.walk(visitor);
866            }
867            Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
868                expr.walk(visitor);
869                pattern.walk(visitor);
870            }
871            Expr::Case {
872                operand,
873                when_clauses,
874                else_clause,
875            } => {
876                if let Some(op) = operand {
877                    op.walk(visitor);
878                }
879                for (cond, result) in when_clauses {
880                    cond.walk(visitor);
881                    result.walk(visitor);
882                }
883                if let Some(el) = else_clause {
884                    el.walk(visitor);
885                }
886            }
887            Expr::Nested(inner) => inner.walk(visitor),
888            Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr.walk(visitor),
889            Expr::Extract { expr, .. } => expr.walk(visitor),
890            Expr::Interval { value, .. } => value.walk(visitor),
891            Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
892                for item in items {
893                    item.walk(visitor);
894                }
895            }
896            Expr::If {
897                condition,
898                true_val,
899                false_val,
900            } => {
901                condition.walk(visitor);
902                true_val.walk(visitor);
903                if let Some(fv) = false_val {
904                    fv.walk(visitor);
905                }
906            }
907            Expr::NullIf { expr, r#else } => {
908                expr.walk(visitor);
909                r#else.walk(visitor);
910            }
911            Expr::Collate { expr, .. } => expr.walk(visitor),
912            Expr::Alias { expr, .. } => expr.walk(visitor),
913            Expr::ArrayIndex { expr, index } => {
914                expr.walk(visitor);
915                index.walk(visitor);
916            }
917            Expr::JsonAccess { expr, path, .. } => {
918                expr.walk(visitor);
919                path.walk(visitor);
920            }
921            Expr::Lambda { body, .. } => body.walk(visitor),
922            // Leaf nodes
923            Expr::Column { .. }
924            | Expr::Number(_)
925            | Expr::StringLiteral(_)
926            | Expr::Boolean(_)
927            | Expr::Null
928            | Expr::Wildcard
929            | Expr::Star
930            | Expr::Parameter(_)
931            | Expr::TypeExpr(_)
932            | Expr::QualifiedWildcard { .. }
933            | Expr::Default
934            | Expr::Subquery(_)
935            | Expr::Exists { .. } => {}
936        }
937    }
938
939    /// Find the first expression matching the predicate.
940    #[must_use]
941    pub fn find<F>(&self, predicate: &F) -> Option<&Expr>
942    where
943        F: Fn(&Expr) -> bool,
944    {
945        let mut result = None;
946        self.walk(&mut |expr| {
947            if result.is_some() {
948                return false;
949            }
950            if predicate(expr) {
951                result = Some(expr as *const Expr);
952                false
953            } else {
954                true
955            }
956        });
957        // SAFETY: the pointer is valid as long as self is alive
958        result.map(|p| unsafe { &*p })
959    }
960
961    /// Find all expressions matching the predicate.
962    #[must_use]
963    pub fn find_all<F>(&self, predicate: &F) -> Vec<&Expr>
964    where
965        F: Fn(&Expr) -> bool,
966    {
967        let mut results: Vec<*const Expr> = Vec::new();
968        self.walk(&mut |expr| {
969            if predicate(expr) {
970                results.push(expr as *const Expr);
971            }
972            true
973        });
974        results.into_iter().map(|p| unsafe { &*p }).collect()
975    }
976
977    /// Transform this expression tree by applying a function to each node.
978    /// The function can return a new expression to replace the current one.
979    #[must_use]
980    pub fn transform<F>(self, func: &F) -> Expr
981    where
982        F: Fn(Expr) -> Expr,
983    {
984        let transformed = match self {
985            Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
986                left: Box::new(left.transform(func)),
987                op,
988                right: Box::new(right.transform(func)),
989            },
990            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
991                op,
992                expr: Box::new(expr.transform(func)),
993            },
994            Expr::Function {
995                name,
996                args,
997                distinct,
998                filter,
999                over,
1000            } => Expr::Function {
1001                name,
1002                args: args.into_iter().map(|a| a.transform(func)).collect(),
1003                distinct,
1004                filter: filter.map(|f| Box::new(f.transform(func))),
1005                over,
1006            },
1007            Expr::Nested(inner) => Expr::Nested(Box::new(inner.transform(func))),
1008            Expr::Cast { expr, data_type } => Expr::Cast {
1009                expr: Box::new(expr.transform(func)),
1010                data_type,
1011            },
1012            Expr::Between {
1013                expr,
1014                low,
1015                high,
1016                negated,
1017            } => Expr::Between {
1018                expr: Box::new(expr.transform(func)),
1019                low: Box::new(low.transform(func)),
1020                high: Box::new(high.transform(func)),
1021                negated,
1022            },
1023            Expr::Case {
1024                operand,
1025                when_clauses,
1026                else_clause,
1027            } => Expr::Case {
1028                operand: operand.map(|o| Box::new(o.transform(func))),
1029                when_clauses: when_clauses
1030                    .into_iter()
1031                    .map(|(c, r)| (c.transform(func), r.transform(func)))
1032                    .collect(),
1033                else_clause: else_clause.map(|e| Box::new(e.transform(func))),
1034            },
1035            Expr::IsBool {
1036                expr,
1037                value,
1038                negated,
1039            } => Expr::IsBool {
1040                expr: Box::new(expr.transform(func)),
1041                value,
1042                negated,
1043            },
1044            Expr::AnyOp { expr, op, right } => Expr::AnyOp {
1045                expr: Box::new(expr.transform(func)),
1046                op,
1047                right: Box::new(right.transform(func)),
1048            },
1049            Expr::AllOp { expr, op, right } => Expr::AllOp {
1050                expr: Box::new(expr.transform(func)),
1051                op,
1052                right: Box::new(right.transform(func)),
1053            },
1054            other => other,
1055        };
1056        func(transformed)
1057    }
1058
1059    /// Check whether this expression is a column reference.
1060    #[must_use]
1061    pub fn is_column(&self) -> bool {
1062        matches!(self, Expr::Column { .. })
1063    }
1064
1065    /// Check whether this expression is a literal value (number, string, bool, null).
1066    #[must_use]
1067    pub fn is_literal(&self) -> bool {
1068        matches!(
1069            self,
1070            Expr::Number(_) | Expr::StringLiteral(_) | Expr::Boolean(_) | Expr::Null
1071        )
1072    }
1073
1074    /// Get the SQL representation of this expression for display purposes.
1075    /// For full generation, use the Generator.
1076    #[must_use]
1077    pub fn sql(&self) -> String {
1078        use crate::generator::Generator;
1079        Generator::expr_to_sql(self)
1080    }
1081}
1082
1083/// Helper: collect all column references from an expression.
1084#[must_use]
1085pub fn find_columns(expr: &Expr) -> Vec<&Expr> {
1086    expr.find_all(&|e| matches!(e, Expr::Column { .. }))
1087}
1088
1089/// Helper: collect all table references from a statement.
1090#[must_use]
1091pub fn find_tables(statement: &Statement) -> Vec<&TableRef> {
1092    match statement {
1093        Statement::Select(sel) => {
1094            let mut tables = Vec::new();
1095            if let Some(from) = &sel.from {
1096                collect_table_refs_from_source(&from.source, &mut tables);
1097            }
1098            for join in &sel.joins {
1099                collect_table_refs_from_source(&join.table, &mut tables);
1100            }
1101            tables
1102        }
1103        Statement::Insert(ins) => vec![&ins.table],
1104        Statement::Update(upd) => vec![&upd.table],
1105        Statement::Delete(del) => vec![&del.table],
1106        Statement::CreateTable(ct) => vec![&ct.table],
1107        Statement::DropTable(dt) => vec![&dt.table],
1108        _ => vec![],
1109    }
1110}
1111
1112fn collect_table_refs_from_source<'a>(source: &'a TableSource, tables: &mut Vec<&'a TableRef>) {
1113    match source {
1114        TableSource::Table(table_ref) => tables.push(table_ref),
1115        TableSource::Subquery { .. } => {}
1116        TableSource::TableFunction { .. } => {}
1117        TableSource::Lateral { source } => collect_table_refs_from_source(source, tables),
1118        TableSource::Unnest { .. } => {}
1119    }
1120}