Skip to main content

sqlglot_rust/ast/
types.rs

1use serde::{Deserialize, Serialize};
2
3// ═══════════════════════════════════════════════════════════════════════
4// Identifier quoting style
5// ═══════════════════════════════════════════════════════════════════════
6
/// How an identifier (column, table, alias) was quoted in the source SQL.
///
/// Used to preserve and transform quoting across dialects (e.g. backtick
/// for MySQL/BigQuery → double-quote for PostgreSQL → bracket for T-SQL).
///
/// The derived [`Default`] is [`QuoteStyle::None`] (see the `#[default]`
/// marker on that variant), which is also the value serde fills in for
/// fields tagged `#[serde(default)]` when the key is missing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum QuoteStyle {
    /// Bare / unquoted identifier
    #[default]
    None,
    /// `"identifier"` — ANSI SQL, PostgreSQL, Oracle, Snowflake, etc.
    DoubleQuote,
    /// `` `identifier` `` — MySQL, BigQuery, Hive, Spark, etc.
    Backtick,
    /// `[identifier]` — T-SQL / SQL Server
    Bracket,
}
23
24impl QuoteStyle {
25    /// Returns the canonical quoting style for the given dialect.
26    #[must_use]
27    pub fn for_dialect(dialect: crate::dialects::Dialect) -> Self {
28        use crate::dialects::Dialect;
29        match dialect {
30            Dialect::Tsql | Dialect::Fabric => QuoteStyle::Bracket,
31            Dialect::Mysql | Dialect::BigQuery | Dialect::Hive | Dialect::Spark
32            | Dialect::Databricks | Dialect::Doris | Dialect::SingleStore
33            | Dialect::StarRocks => QuoteStyle::Backtick,
34            // ANSI, Postgres, Oracle, Snowflake, Presto, Trino, etc.
35            _ => QuoteStyle::DoubleQuote,
36        }
37    }
38
39    /// Returns `true` when the identifier carries explicit quoting.
40    #[must_use]
41    pub fn is_quoted(self) -> bool {
42        !matches!(self, QuoteStyle::None)
43    }
44}
45
46// ═══════════════════════════════════════════════════════════════════════
47// Top-level statement types
48// ═══════════════════════════════════════════════════════════════════════
49
/// A fully parsed SQL statement.
///
/// Corresponds to the top-level node in sqlglot's expression tree.
/// Each variant wraps the dedicated payload type declared further down in
/// this module.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Statement {
    /// SELECT query (see [`SelectStatement`])
    Select(SelectStatement),
    /// INSERT statement (see [`InsertStatement`])
    Insert(InsertStatement),
    /// UPDATE statement (see [`UpdateStatement`])
    Update(UpdateStatement),
    /// DELETE statement (see [`DeleteStatement`])
    Delete(DeleteStatement),
    /// CREATE TABLE ...
    CreateTable(CreateTableStatement),
    /// DROP TABLE ...
    DropTable(DropTableStatement),
    /// UNION / INTERSECT / EXCEPT between queries
    SetOperation(SetOperationStatement),
    /// ALTER TABLE ...
    AlterTable(AlterTableStatement),
    /// CREATE VIEW ...
    CreateView(CreateViewStatement),
    /// DROP VIEW ...
    DropView(DropViewStatement),
    /// TRUNCATE TABLE ...
    Truncate(TruncateStatement),
    /// BEGIN / COMMIT / ROLLBACK
    Transaction(TransactionStatement),
    /// `EXPLAIN <statement>`
    Explain(ExplainStatement),
    /// USE database
    Use(UseStatement),
    /// Raw / passthrough expression (for expressions that don't fit a specific statement type)
    Expression(Expr),
}
80
81// ═══════════════════════════════════════════════════════════════════════
82// SELECT
83// ═══════════════════════════════════════════════════════════════════════
84
/// A SELECT statement, including CTEs.
///
/// Aligned with sqlglot's `Select` expression which wraps `With`, `From`,
/// `Where`, `Group`, `Having`, `Order`, `Limit`, `Offset`, `Window`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// Common Table Expressions (WITH clause)
    pub ctes: Vec<Cte>,
    /// DISTINCT flag — presumably set for `SELECT DISTINCT` (confirm against the parser)
    pub distinct: bool,
    /// TOP N (TSQL-style)
    pub top: Option<Box<Expr>>,
    /// The select list
    pub columns: Vec<SelectItem>,
    /// FROM clause, if any
    pub from: Option<FromClause>,
    /// JOIN clauses attached to the FROM source
    pub joins: Vec<JoinClause>,
    /// WHERE predicate, if any
    pub where_clause: Option<Expr>,
    /// GROUP BY expressions (empty when there is no GROUP BY)
    pub group_by: Vec<Expr>,
    /// HAVING predicate, if any
    pub having: Option<Expr>,
    /// ORDER BY items (empty when there is no ORDER BY)
    pub order_by: Vec<OrderByItem>,
    /// LIMIT expression, if any
    pub limit: Option<Expr>,
    /// OFFSET expression, if any
    pub offset: Option<Expr>,
    /// Oracle-style FETCH FIRST n ROWS ONLY
    pub fetch_first: Option<Expr>,
    /// QUALIFY clause (BigQuery, Snowflake)
    pub qualify: Option<Expr>,
    /// Named WINDOW definitions
    pub window_definitions: Vec<WindowDefinition>,
}
112
/// A Common Table Expression: `name [(col1, col2)] AS [NOT] MATERIALIZED (query)`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Cte {
    /// Name the CTE is bound to
    pub name: String,
    /// Optional explicit column list (`(col1, col2)`); empty when omitted
    pub columns: Vec<String>,
    /// The statement that defines the CTE body
    pub query: Box<Statement>,
    /// Materialization hint. NOTE(review): presumably `Some(true)` = `MATERIALIZED`,
    /// `Some(false)` = `NOT MATERIALIZED`, `None` = unspecified — confirm against the parser.
    pub materialized: Option<bool>,
    /// Recursive flag — presumably set for `WITH RECURSIVE` (confirm against the parser)
    pub recursive: bool,
}
122
/// Named WINDOW definition: `window_name AS (window_spec)`
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowDefinition {
    /// The `window_name` part
    pub name: String,
    /// The `window_spec` part
    pub spec: WindowSpec,
}
129
130// ═══════════════════════════════════════════════════════════════════════
131// Set operations (UNION, INTERSECT, EXCEPT)
132// ═══════════════════════════════════════════════════════════════════════
133
/// UNION / INTERSECT / EXCEPT between two or more queries.
///
/// The node itself is binary (`left` / `right`); both operands are full
/// [`Statement`]s, so an operand can itself be another set operation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetOperationStatement {
    /// Which set operator combines the operands
    pub op: SetOperationType,
    /// `ALL` modifier — presumably set for e.g. `UNION ALL` (confirm against the parser)
    pub all: bool,
    /// Left-hand operand
    pub left: Box<Statement>,
    /// Right-hand operand
    pub right: Box<Statement>,
    /// ORDER BY items stored on the set operation node
    pub order_by: Vec<OrderByItem>,
    /// LIMIT stored on the set operation node, if any
    pub limit: Option<Expr>,
    /// OFFSET stored on the set operation node, if any
    pub offset: Option<Expr>,
}
145
/// The operator of a [`SetOperationStatement`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetOperationType {
    /// `UNION`
    Union,
    /// `INTERSECT`
    Intersect,
    /// `EXCEPT`
    Except,
}
152
153// ═══════════════════════════════════════════════════════════════════════
154// SELECT items and FROM
155// ═══════════════════════════════════════════════════════════════════════
156
/// An item in a SELECT list.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SelectItem {
    /// `*`
    Wildcard,
    /// `table.*` — `table` holds the qualifier in front of the star
    QualifiedWildcard { table: String },
    /// An expression with optional alias: `expr AS alias`
    Expr { expr: Expr, alias: Option<String> },
}
167
/// A FROM clause, now supporting subqueries and multiple tables.
///
/// The clause itself stores a single [`TableSource`]; in a
/// [`SelectStatement`] further sources are carried by its `joins` list.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FromClause {
    /// The source being selected from
    pub source: TableSource,
}
173
/// A table source can be a table reference, subquery, or table function.
///
/// `LATERAL` wrappers and `UNNEST` are modelled as their own variants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableSource {
    /// A plain (optionally qualified / aliased) table reference
    Table(TableRef),
    /// A derived table: `(query) [AS alias]`
    Subquery {
        query: Box<Statement>,
        alias: Option<String>,
    },
    /// A table-valued function call: `name(args...) [AS alias]`
    TableFunction {
        name: String,
        args: Vec<Expr>,
        alias: Option<String>,
    },
    /// LATERAL subquery or function
    Lateral {
        source: Box<TableSource>,
    },
    /// UNNEST(array_expr)
    Unnest {
        expr: Box<Expr>,
        alias: Option<String>,
        /// Presumably set for `WITH OFFSET` — confirm against the parser
        with_offset: bool,
    },
}
198
/// A reference to a table, optionally qualified as `catalog.schema.name`
/// and optionally aliased.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TableRef {
    /// Catalog qualifier, if present
    pub catalog: Option<String>,
    /// Schema qualifier, if present
    pub schema: Option<String>,
    /// The table name itself
    pub name: String,
    /// Alias given to the table, if any
    pub alias: Option<String>,
    /// How the table name was quoted in the source SQL.
    ///
    /// Falls back to `QuoteStyle::default()` when absent from serialized input.
    #[serde(default)]
    pub name_quote_style: QuoteStyle,
}
210
/// A JOIN clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// Kind of join (INNER, LEFT, ...)
    pub join_type: JoinType,
    /// The joined source
    pub table: TableSource,
    /// `ON` condition, if one was given
    pub on: Option<Expr>,
    /// Column names of a `USING (...)` list; presumably empty when absent
    pub using: Vec<String>,
}
219
/// The type of JOIN.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum JoinType {
    /// INNER JOIN
    Inner,
    /// LEFT JOIN
    Left,
    /// RIGHT JOIN
    Right,
    /// FULL JOIN
    Full,
    /// CROSS JOIN
    Cross,
    /// NATURAL JOIN
    Natural,
    /// LATERAL JOIN
    Lateral,
}
233
/// An ORDER BY item.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByItem {
    /// The expression to sort by
    pub expr: Expr,
    /// Sort direction flag; presumably `false` means `DESC` (confirm against the parser)
    pub ascending: bool,
    /// NULLS FIRST / NULLS LAST
    ///
    /// NOTE(review): `None` presumably means no explicit NULLS ordering was written.
    pub nulls_first: Option<bool>,
}
242
243// ═══════════════════════════════════════════════════════════════════════
244// Expressions (the core of the AST)
245// ═══════════════════════════════════════════════════════════════════════
246
/// An expression in SQL.
///
/// This enum is aligned with sqlglot's Expression class hierarchy.
/// Key additions over the basic implementation:
/// - Subquery, Exists, Cast, Extract, Window functions
/// - TypedString, Interval, Array/Struct constructors
/// - Postgres-style casting (`::`)
///
/// Child expressions are boxed wherever the variant would otherwise be
/// directly recursive; subqueries are stored as boxed [`Statement`]s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expr {
    /// A column reference, possibly qualified: `[catalog.][schema.]table.column`
    ///
    /// Only a single `table` qualifier is stored on this node.
    Column {
        table: Option<String>,
        name: String,
        /// How the column name was quoted in the source SQL.
        #[serde(default)]
        quote_style: QuoteStyle,
        /// How the table qualifier was quoted, if present.
        #[serde(default)]
        table_quote_style: QuoteStyle,
    },
    /// A numeric literal, kept as its textual form.
    Number(String),
    /// A string literal.
    StringLiteral(String),
    /// A boolean literal.
    Boolean(bool),
    /// NULL literal.
    Null,
    /// A binary operation: `left op right`
    BinaryOp {
        left: Box<Expr>,
        op: BinaryOperator,
        right: Box<Expr>,
    },
    /// A unary operation: `op expr`
    UnaryOp {
        op: UnaryOperator,
        expr: Box<Expr>,
    },
    /// A function call: `name(args...)` with optional DISTINCT, ORDER BY, etc.
    Function {
        name: String,
        args: Vec<Expr>,
        distinct: bool,
        /// FILTER (WHERE expr) clause on aggregate
        filter: Option<Box<Expr>>,
        /// OVER window specification for window functions
        over: Option<WindowSpec>,
    },
    /// `expr [NOT] BETWEEN low AND high`
    Between {
        expr: Box<Expr>,
        low: Box<Expr>,
        high: Box<Expr>,
        negated: bool,
    },
    /// `expr [NOT] IN (item, ...)`; the subquery form has its own variant,
    /// [`Expr::InSubquery`].
    InList {
        expr: Box<Expr>,
        list: Vec<Expr>,
        negated: bool,
    },
    /// `expr [NOT] IN (SELECT ...)`
    InSubquery {
        expr: Box<Expr>,
        subquery: Box<Statement>,
        negated: bool,
    },
    /// `expr IS [NOT] NULL`
    IsNull {
        expr: Box<Expr>,
        negated: bool,
    },
    /// `expr IS [NOT] TRUE` / `expr IS [NOT] FALSE`
    ///
    /// `value` selects between the TRUE and FALSE forms.
    IsBool {
        expr: Box<Expr>,
        value: bool,
        negated: bool,
    },
    /// `expr [NOT] LIKE pattern [ESCAPE escape_char]`
    Like {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        negated: bool,
        escape: Option<Box<Expr>>,
    },
    /// `expr [NOT] ILIKE pattern [ESCAPE escape_char]` (case-insensitive LIKE)
    ILike {
        expr: Box<Expr>,
        pattern: Box<Expr>,
        negated: bool,
        escape: Option<Box<Expr>>,
    },
    /// `CASE [operand] WHEN ... THEN ... ELSE ... END`
    ///
    /// Each entry of `when_clauses` is a `(WHEN condition, THEN result)` pair.
    Case {
        operand: Option<Box<Expr>>,
        when_clauses: Vec<(Expr, Expr)>,
        else_clause: Option<Box<Expr>>,
    },
    /// A parenthesized sub-expression.
    Nested(Box<Expr>),
    /// A wildcard `*` used in contexts like `COUNT(*)`.
    Wildcard,
    /// A scalar subquery: `(SELECT ...)`
    Subquery(Box<Statement>),
    /// `[NOT] EXISTS (SELECT ...)`
    Exists {
        subquery: Box<Statement>,
        negated: bool,
    },
    /// `CAST(expr AS type)` or `expr::type` (PostgreSQL)
    Cast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// `TRY_CAST(expr AS type)`
    TryCast {
        expr: Box<Expr>,
        data_type: DataType,
    },
    /// `EXTRACT(field FROM expr)`
    Extract {
        field: DateTimeField,
        expr: Box<Expr>,
    },
    /// `INTERVAL 'value' unit`
    Interval {
        value: Box<Expr>,
        unit: Option<DateTimeField>,
    },
    /// Array literal: `ARRAY[1, 2, 3]` or `[1, 2, 3]`
    ArrayLiteral(Vec<Expr>),
    /// Struct literal / row constructor: `(1, 'a', true)`
    Tuple(Vec<Expr>),
    /// `COALESCE(a, b, c)`
    Coalesce(Vec<Expr>),
    /// `IF(condition, true_val, false_val)` (MySQL, BigQuery)
    If {
        condition: Box<Expr>,
        true_val: Box<Expr>,
        false_val: Option<Box<Expr>>,
    },
    /// `NULLIF(a, b)` — `expr` is `a`, `r#else` is `b`
    NullIf {
        expr: Box<Expr>,
        r#else: Box<Expr>,
    },
    /// `expr COLLATE collation`
    Collate {
        expr: Box<Expr>,
        collation: String,
    },
    /// Parameter / placeholder: `$1`, `?`, `:name`
    Parameter(String),
    /// A type expression used in DDL contexts or CAST
    TypeExpr(DataType),
    /// `table.*` in expression context
    QualifiedWildcard { table: String },
    /// Star expression `*`
    Star,
    /// Alias expression: `expr AS name`
    Alias {
        expr: Box<Expr>,
        name: String,
    },
    /// Array access: `expr[index]`
    ArrayIndex {
        expr: Box<Expr>,
        index: Box<Expr>,
    },
    /// JSON access: `expr->key` or `expr->>key`
    JsonAccess {
        expr: Box<Expr>,
        path: Box<Expr>,
        /// false = ->, true = ->>
        as_text: bool,
    },
    /// Lambda expression: `x -> x + 1`
    Lambda {
        params: Vec<String>,
        body: Box<Expr>,
    },
    /// `DEFAULT` keyword in INSERT/UPDATE contexts
    Default,
}
432
433// ═══════════════════════════════════════════════════════════════════════
434// Window specification
435// ═══════════════════════════════════════════════════════════════════════
436
/// Window specification for window functions: OVER (PARTITION BY ... ORDER BY ... frame)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowSpec {
    /// Reference to a named window
    pub window_ref: Option<String>,
    /// PARTITION BY expressions
    pub partition_by: Vec<Expr>,
    /// ORDER BY items
    pub order_by: Vec<OrderByItem>,
    /// Frame clause, if any
    pub frame: Option<WindowFrame>,
}
446
/// The frame clause of a [`WindowSpec`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WindowFrame {
    /// Frame unit (ROWS / RANGE / GROUPS)
    pub kind: WindowFrameKind,
    /// Starting bound of the frame
    pub start: WindowFrameBound,
    /// Ending bound; presumably `None` when the frame was written without
    /// `BETWEEN ... AND ...` (confirm against the parser)
    pub end: Option<WindowFrameBound>,
}
453
/// The unit a [`WindowFrame`] is expressed in.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WindowFrameKind {
    /// `ROWS`
    Rows,
    /// `RANGE`
    Range,
    /// `GROUPS`
    Groups,
}
460
/// One bound of a [`WindowFrame`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WindowFrameBound {
    /// `CURRENT ROW`
    CurrentRow,
    /// `<expr> PRECEDING`; `None` = UNBOUNDED PRECEDING
    Preceding(Option<Box<Expr>>),
    /// `<expr> FOLLOWING`; `None` = UNBOUNDED FOLLOWING
    Following(Option<Box<Expr>>),
}
467
468// ═══════════════════════════════════════════════════════════════════════
469// Date/time fields (for EXTRACT, INTERVAL)
470// ═══════════════════════════════════════════════════════════════════════
471
/// Date/time field names, used as the `field` of [`Expr::Extract`] and the
/// optional `unit` of [`Expr::Interval`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DateTimeField {
    Year,
    Quarter,
    Month,
    Week,
    Day,
    DayOfWeek,
    DayOfYear,
    Hour,
    Minute,
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
    Epoch,
    Timezone,
    TimezoneHour,
    TimezoneMinute,
}
492
493// ═══════════════════════════════════════════════════════════════════════
494// Operators
495// ═══════════════════════════════════════════════════════════════════════
496
/// Binary operators, used as the `op` of [`Expr::BinaryOp`].
///
/// The variants cover arithmetic, comparison, logical, string
/// concatenation, bitwise/shift and JSON arrow operators. The concrete SQL
/// spelling of each operator is not defined here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,
    // Comparison
    Eq,
    Neq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Logical
    And,
    Or,
    Xor,
    // String concatenation
    Concat,
    // Bitwise and shifts
    BitwiseAnd,
    BitwiseOr,
    BitwiseXor,
    ShiftLeft,
    ShiftRight,
    /// `->` JSON access operator
    Arrow,
    /// `->>` JSON text access
    DoubleArrow,
}
525
/// Unary operators, used as the `op` of [`Expr::UnaryOp`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Logical negation
    Not,
    /// Arithmetic negation
    Minus,
    /// Unary plus
    Plus,
    /// Bitwise complement
    BitwiseNot,
}
534
535// ═══════════════════════════════════════════════════════════════════════
536// DML statements
537// ═══════════════════════════════════════════════════════════════════════
538
/// An INSERT statement, now supporting INSERT ... SELECT and ON CONFLICT.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InsertStatement {
    /// Target table
    pub table: TableRef,
    /// Explicit target column list; presumably empty when omitted
    pub columns: Vec<String>,
    /// Where the inserted rows come from (VALUES, a query, or defaults)
    pub source: InsertSource,
    /// ON CONFLICT / ON DUPLICATE KEY
    pub on_conflict: Option<OnConflict>,
    /// RETURNING clause
    pub returning: Vec<SelectItem>,
}
550
/// The row source of an [`InsertStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InsertSource {
    /// Literal rows: the outer `Vec` holds rows, each inner `Vec` one row's expressions
    Values(Vec<Vec<Expr>>),
    /// Rows produced by a nested statement (INSERT ... SELECT)
    Query(Box<Statement>),
    /// No explicit rows — presumably `DEFAULT VALUES` (confirm against the parser)
    Default,
}
557
/// The conflict-handling clause of an [`InsertStatement`]
/// (ON CONFLICT / ON DUPLICATE KEY).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnConflict {
    /// Column names listed on the clause — presumably the conflict target
    pub columns: Vec<String>,
    /// What to do when a conflict occurs
    pub action: ConflictAction,
}
563
/// The action taken by an [`OnConflict`] clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConflictAction {
    /// Skip the conflicting row (`DO NOTHING`)
    DoNothing,
    /// Update instead; the payload is a list of `(name, value expression)`
    /// pairs, presumably `column = expr` assignments
    DoUpdate(Vec<(String, Expr)>),
}
569
/// An UPDATE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UpdateStatement {
    /// Table being updated
    pub table: TableRef,
    /// `(name, value expression)` pairs, presumably the `SET column = expr` list
    pub assignments: Vec<(String, Expr)>,
    /// Optional FROM clause
    pub from: Option<FromClause>,
    /// WHERE predicate, if any
    pub where_clause: Option<Expr>,
    /// RETURNING items; presumably empty when there is no RETURNING clause
    pub returning: Vec<SelectItem>,
}
579
/// A DELETE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeleteStatement {
    /// Table rows are deleted from
    pub table: TableRef,
    /// Optional USING clause (stored with the same shape as a FROM clause)
    pub using: Option<FromClause>,
    /// WHERE predicate, if any
    pub where_clause: Option<Expr>,
    /// RETURNING items; presumably empty when there is no RETURNING clause
    pub returning: Vec<SelectItem>,
}
588
589// ═══════════════════════════════════════════════════════════════════════
590// DDL statements
591// ═══════════════════════════════════════════════════════════════════════
592
/// A CREATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateTableStatement {
    /// Presumably set for `IF NOT EXISTS`
    pub if_not_exists: bool,
    /// Presumably set for temporary tables
    pub temporary: bool,
    /// The table being created
    pub table: TableRef,
    /// Column definitions
    pub columns: Vec<ColumnDef>,
    /// Table-level constraints
    pub constraints: Vec<TableConstraint>,
    /// CREATE TABLE ... AS SELECT ...
    pub as_select: Option<Box<Statement>>,
}
604
/// Table-level constraints.
///
/// Every variant carries an optional constraint `name`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableConstraint {
    /// PRIMARY KEY over `columns`
    PrimaryKey {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// UNIQUE over `columns`
    Unique {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// FOREIGN KEY from `columns` to `ref_columns` of `ref_table`, with
    /// optional ON DELETE / ON UPDATE actions
    ForeignKey {
        name: Option<String>,
        columns: Vec<String>,
        ref_table: TableRef,
        ref_columns: Vec<String>,
        on_delete: Option<ReferentialAction>,
        on_update: Option<ReferentialAction>,
    },
    /// CHECK (`expr`)
    Check {
        name: Option<String>,
        expr: Expr,
    },
}
629
/// Referential action used by the `on_delete` / `on_update` fields of
/// [`TableConstraint::ForeignKey`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReferentialAction {
    /// `CASCADE`
    Cascade,
    /// `RESTRICT`
    Restrict,
    /// `NO ACTION`
    NoAction,
    /// `SET NULL`
    SetNull,
    /// `SET DEFAULT`
    SetDefault,
}
638
/// A column definition in CREATE TABLE.
///
/// Also used as the payload of [`AlterTableAction::AddColumn`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name
    pub name: String,
    /// Declared data type
    pub data_type: DataType,
    /// Nullability. NOTE(review): presumably `Some(false)` = `NOT NULL`,
    /// `Some(true)` = explicit `NULL`, `None` = unspecified — confirm against the parser.
    pub nullable: Option<bool>,
    /// DEFAULT expression, if any
    pub default: Option<Expr>,
    /// Column-level PRIMARY KEY flag
    pub primary_key: bool,
    /// Column-level UNIQUE flag
    pub unique: bool,
    /// Auto-increment flag
    pub auto_increment: bool,
    /// Column collation, if any
    pub collation: Option<String>,
    /// Column comment, if any
    pub comment: Option<String>,
}
652
/// ALTER TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AlterTableStatement {
    /// Table being altered
    pub table: TableRef,
    /// The alterations, in the order they are stored
    pub actions: Vec<AlterTableAction>,
}
659
/// A single action of an [`AlterTableStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AlterTableAction {
    /// Add a column
    AddColumn(ColumnDef),
    /// Drop column `name`; `if_exists` presumably mirrors `IF EXISTS`
    DropColumn { name: String, if_exists: bool },
    /// Rename column `old_name` to `new_name`
    RenameColumn { old_name: String, new_name: String },
    /// Change the type of column `name` to `data_type`
    AlterColumnType { name: String, data_type: DataType },
    /// Add a table-level constraint
    AddConstraint(TableConstraint),
    /// Drop the constraint called `name`
    DropConstraint { name: String },
    /// Rename the table to `new_name`
    RenameTable { new_name: String },
}
670
/// CREATE VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateViewStatement {
    /// Name of the view (stored as a [`TableRef`])
    pub name: TableRef,
    /// Explicit view column list; presumably empty when omitted
    pub columns: Vec<String>,
    /// The statement defining the view
    pub query: Box<Statement>,
    /// Presumably set for `CREATE OR REPLACE`
    pub or_replace: bool,
    /// Presumably set for materialized views
    pub materialized: bool,
    /// Presumably set for `IF NOT EXISTS`
    pub if_not_exists: bool,
}
681
/// DROP VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropViewStatement {
    /// Name of the view (stored as a [`TableRef`])
    pub name: TableRef,
    /// Presumably set for `IF EXISTS`
    pub if_exists: bool,
    /// Presumably set for materialized views
    pub materialized: bool,
}
689
/// TRUNCATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TruncateStatement {
    /// Table being truncated
    pub table: TableRef,
}
695
/// Transaction control statements.
///
/// The `String` payloads presumably hold the savepoint name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransactionStatement {
    /// BEGIN
    Begin,
    /// COMMIT
    Commit,
    /// ROLLBACK
    Rollback,
    /// SAVEPOINT
    Savepoint(String),
    /// RELEASE SAVEPOINT
    ReleaseSavepoint(String),
    /// ROLLBACK TO a savepoint
    RollbackTo(String),
}
706
/// EXPLAIN statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExplainStatement {
    /// Presumably set for `EXPLAIN ANALYZE`
    pub analyze: bool,
    /// The statement being explained
    pub statement: Box<Statement>,
}
713
/// USE database statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UseStatement {
    /// Name given to `USE`
    pub name: String,
}
719
/// A DROP TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropTableStatement {
    /// Presumably set for `IF EXISTS`
    pub if_exists: bool,
    /// Table being dropped
    pub table: TableRef,
    /// Presumably set for a trailing `CASCADE`
    pub cascade: bool,
}
727
728// ═══════════════════════════════════════════════════════════════════════
729// Data types
730// ═══════════════════════════════════════════════════════════════════════
731
/// SQL data types. Significantly expanded to match sqlglot's DataType.Type enum.
///
/// Parameterised types carry their arguments as `Option<u32>` so that an
/// absent length / precision / scale can be represented; nested types box
/// their element types. Type names this enum has no variant for can be kept
/// in [`DataType::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    // Numeric
    TinyInt,
    SmallInt,
    Int,
    BigInt,
    Float,
    Double,
    Decimal { precision: Option<u32>, scale: Option<u32> },
    Numeric { precision: Option<u32>, scale: Option<u32> },
    Real,

    // String (the `Option<u32>` payload is presumably the declared length)
    Varchar(Option<u32>),
    Char(Option<u32>),
    Text,
    String,
    Binary(Option<u32>),
    Varbinary(Option<u32>),

    // Boolean
    Boolean,

    // Date/Time
    Date,
    Time { precision: Option<u32> },
    Timestamp { precision: Option<u32>, with_tz: bool },
    Interval,
    DateTime,

    // Binary
    Blob,
    Bytea,
    Bytes,

    // JSON
    Json,
    Jsonb,

    // UUID
    Uuid,

    // Complex types
    /// Array with an optional element type
    Array(Option<Box<DataType>>),
    /// Map from `key` type to `value` type
    Map { key: Box<DataType>, value: Box<DataType> },
    /// Struct as a list of `(field name, field type)` pairs
    Struct(Vec<(String, DataType)>),
    /// Tuple of positional element types
    Tuple(Vec<DataType>),

    // Special
    Null,
    /// A type kept only as a string
    Unknown(String),
    Variant,
    Object,
    Xml,
    Inet,
    Cidr,
    Macaddr,
    Bit(Option<u32>),
    Money,
    Serial,
    BigSerial,
    SmallSerial,
    Regclass,
    Regtype,
    Hstore,
    Geography,
    Geometry,
    Super,
}
803
804// ═══════════════════════════════════════════════════════════════════════
805// Expression tree traversal helpers
806// ═══════════════════════════════════════════════════════════════════════
807
808impl Expr {
809    /// Recursively walk this expression tree, calling `visitor` on each node.
810    /// If `visitor` returns `false`, children of that node are not visited.
811    pub fn walk<F>(&self, visitor: &mut F)
812    where
813        F: FnMut(&Expr) -> bool,
814    {
815        if !visitor(self) {
816            return;
817        }
818        match self {
819            Expr::BinaryOp { left, right, .. } => {
820                left.walk(visitor);
821                right.walk(visitor);
822            }
823            Expr::UnaryOp { expr, .. } => expr.walk(visitor),
824            Expr::Function { args, filter, .. } => {
825                for arg in args {
826                    arg.walk(visitor);
827                }
828                if let Some(f) = filter {
829                    f.walk(visitor);
830                }
831            }
832            Expr::Between { expr, low, high, .. } => {
833                expr.walk(visitor);
834                low.walk(visitor);
835                high.walk(visitor);
836            }
837            Expr::InList { expr, list, .. } => {
838                expr.walk(visitor);
839                for item in list {
840                    item.walk(visitor);
841                }
842            }
843            Expr::InSubquery { expr, .. } => {
844                expr.walk(visitor);
845            }
846            Expr::IsNull { expr, .. } => expr.walk(visitor),
847            Expr::IsBool { expr, .. } => expr.walk(visitor),
848            Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
849                expr.walk(visitor);
850                pattern.walk(visitor);
851            }
852            Expr::Case {
853                operand,
854                when_clauses,
855                else_clause,
856            } => {
857                if let Some(op) = operand {
858                    op.walk(visitor);
859                }
860                for (cond, result) in when_clauses {
861                    cond.walk(visitor);
862                    result.walk(visitor);
863                }
864                if let Some(el) = else_clause {
865                    el.walk(visitor);
866                }
867            }
868            Expr::Nested(inner) => inner.walk(visitor),
869            Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr.walk(visitor),
870            Expr::Extract { expr, .. } => expr.walk(visitor),
871            Expr::Interval { value, .. } => value.walk(visitor),
872            Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
873                for item in items {
874                    item.walk(visitor);
875                }
876            }
877            Expr::If { condition, true_val, false_val } => {
878                condition.walk(visitor);
879                true_val.walk(visitor);
880                if let Some(fv) = false_val {
881                    fv.walk(visitor);
882                }
883            }
884            Expr::NullIf { expr, r#else } => {
885                expr.walk(visitor);
886                r#else.walk(visitor);
887            }
888            Expr::Collate { expr, .. } => expr.walk(visitor),
889            Expr::Alias { expr, .. } => expr.walk(visitor),
890            Expr::ArrayIndex { expr, index } => {
891                expr.walk(visitor);
892                index.walk(visitor);
893            }
894            Expr::JsonAccess { expr, path, .. } => {
895                expr.walk(visitor);
896                path.walk(visitor);
897            }
898            Expr::Lambda { body, .. } => body.walk(visitor),
899            // Leaf nodes
900            Expr::Column { .. }
901            | Expr::Number(_)
902            | Expr::StringLiteral(_)
903            | Expr::Boolean(_)
904            | Expr::Null
905            | Expr::Wildcard
906            | Expr::Star
907            | Expr::Parameter(_)
908            | Expr::TypeExpr(_)
909            | Expr::QualifiedWildcard { .. }
910            | Expr::Default
911            | Expr::Subquery(_)
912            | Expr::Exists { .. } => {}
913        }
914    }
915
916    /// Find the first expression matching the predicate.
917    #[must_use]
918    pub fn find<F>(&self, predicate: &F) -> Option<&Expr>
919    where
920        F: Fn(&Expr) -> bool,
921    {
922        let mut result = None;
923        self.walk(&mut |expr| {
924            if result.is_some() {
925                return false;
926            }
927            if predicate(expr) {
928                result = Some(expr as *const Expr);
929                false
930            } else {
931                true
932            }
933        });
934        // SAFETY: the pointer is valid as long as self is alive
935        result.map(|p| unsafe { &*p })
936    }
937
938    /// Find all expressions matching the predicate.
939    #[must_use]
940    pub fn find_all<F>(&self, predicate: &F) -> Vec<&Expr>
941    where
942        F: Fn(&Expr) -> bool,
943    {
944        let mut results: Vec<*const Expr> = Vec::new();
945        self.walk(&mut |expr| {
946            if predicate(expr) {
947                results.push(expr as *const Expr);
948            }
949            true
950        });
951        results.into_iter().map(|p| unsafe { &*p }).collect()
952    }
953
954    /// Transform this expression tree by applying a function to each node.
955    /// The function can return a new expression to replace the current one.
956    #[must_use]
957    pub fn transform<F>(self, func: &F) -> Expr
958    where
959        F: Fn(Expr) -> Expr,
960    {
961        let transformed = match self {
962            Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
963                left: Box::new(left.transform(func)),
964                op,
965                right: Box::new(right.transform(func)),
966            },
967            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
968                op,
969                expr: Box::new(expr.transform(func)),
970            },
971            Expr::Function { name, args, distinct, filter, over } => Expr::Function {
972                name,
973                args: args.into_iter().map(|a| a.transform(func)).collect(),
974                distinct,
975                filter: filter.map(|f| Box::new(f.transform(func))),
976                over,
977            },
978            Expr::Nested(inner) => Expr::Nested(Box::new(inner.transform(func))),
979            Expr::Cast { expr, data_type } => Expr::Cast {
980                expr: Box::new(expr.transform(func)),
981                data_type,
982            },
983            Expr::Between { expr, low, high, negated } => Expr::Between {
984                expr: Box::new(expr.transform(func)),
985                low: Box::new(low.transform(func)),
986                high: Box::new(high.transform(func)),
987                negated,
988            },
989            Expr::Case { operand, when_clauses, else_clause } => Expr::Case {
990                operand: operand.map(|o| Box::new(o.transform(func))),
991                when_clauses: when_clauses
992                    .into_iter()
993                    .map(|(c, r)| (c.transform(func), r.transform(func)))
994                    .collect(),
995                else_clause: else_clause.map(|e| Box::new(e.transform(func))),
996            },
997            Expr::IsBool { expr, value, negated } => Expr::IsBool {
998                expr: Box::new(expr.transform(func)),
999                value,
1000                negated,
1001            },
1002            other => other,
1003        };
1004        func(transformed)
1005    }
1006
1007    /// Check whether this expression is a column reference.
1008    #[must_use]
1009    pub fn is_column(&self) -> bool {
1010        matches!(self, Expr::Column { .. })
1011    }
1012
1013    /// Check whether this expression is a literal value (number, string, bool, null).
1014    #[must_use]
1015    pub fn is_literal(&self) -> bool {
1016        matches!(
1017            self,
1018            Expr::Number(_) | Expr::StringLiteral(_) | Expr::Boolean(_) | Expr::Null
1019        )
1020    }
1021
1022    /// Get the SQL representation of this expression for display purposes.
1023    /// For full generation, use the Generator.
1024    #[must_use]
1025    pub fn sql(&self) -> String {
1026        use crate::generator::Generator;
1027        Generator::expr_to_sql(self)
1028    }
1029}
1030
1031/// Helper: collect all column references from an expression.
1032#[must_use]
1033pub fn find_columns(expr: &Expr) -> Vec<&Expr> {
1034    expr.find_all(&|e| matches!(e, Expr::Column { .. }))
1035}
1036
1037/// Helper: collect all table references from a statement.
1038#[must_use]
1039pub fn find_tables(statement: &Statement) -> Vec<&TableRef> {
1040    match statement {
1041        Statement::Select(sel) => {
1042            let mut tables = Vec::new();
1043            if let Some(from) = &sel.from {
1044                collect_table_refs_from_source(&from.source, &mut tables);
1045            }
1046            for join in &sel.joins {
1047                collect_table_refs_from_source(&join.table, &mut tables);
1048            }
1049            tables
1050        }
1051        Statement::Insert(ins) => vec![&ins.table],
1052        Statement::Update(upd) => vec![&upd.table],
1053        Statement::Delete(del) => vec![&del.table],
1054        Statement::CreateTable(ct) => vec![&ct.table],
1055        Statement::DropTable(dt) => vec![&dt.table],
1056        _ => vec![],
1057    }
1058}
1059
1060fn collect_table_refs_from_source<'a>(source: &'a TableSource, tables: &mut Vec<&'a TableRef>) {
1061    match source {
1062        TableSource::Table(table_ref) => tables.push(table_ref),
1063        TableSource::Subquery { .. } => {}
1064        TableSource::TableFunction { .. } => {}
1065        TableSource::Lateral { source } => collect_table_refs_from_source(source, tables),
1066        TableSource::Unnest { .. } => {}
1067    }
1068}