Skip to main content

sqlglot_rust/ast/
types.rs

1use serde::{Deserialize, Serialize};
2
// ═══════════════════════════════════════════════════════════════════════
// Identifier quoting style
// ═══════════════════════════════════════════════════════════════════════
6
7/// How an identifier (column, table, alias) was quoted in the source SQL.
8///
9/// Used to preserve and transform quoting across dialects (e.g. backtick
10/// for MySQL/BigQuery → double-quote for PostgreSQL → bracket for T-SQL).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
12pub enum QuoteStyle {
13    /// Bare / unquoted identifier
14    #[default]
15    None,
16    /// `"identifier"` — ANSI SQL, PostgreSQL, Oracle, Snowflake, etc.
17    DoubleQuote,
18    /// `` `identifier` `` — MySQL, BigQuery, Hive, Spark, etc.
19    Backtick,
20    /// `[identifier]` — T-SQL / SQL Server
21    Bracket,
22}
23
24impl QuoteStyle {
25    /// Returns the canonical quoting style for the given dialect.
26    #[must_use]
27    pub fn for_dialect(dialect: crate::dialects::Dialect) -> Self {
28        use crate::dialects::Dialect;
29        match dialect {
30            Dialect::Tsql | Dialect::Fabric => QuoteStyle::Bracket,
31            Dialect::Mysql
32            | Dialect::BigQuery
33            | Dialect::Hive
34            | Dialect::Spark
35            | Dialect::Databricks
36            | Dialect::Doris
37            | Dialect::SingleStore
38            | Dialect::StarRocks => QuoteStyle::Backtick,
39            // ANSI, Postgres, Oracle, Snowflake, Presto, Trino, etc.
40            _ => QuoteStyle::DoubleQuote,
41        }
42    }
43
44    /// Returns `true` when the identifier carries explicit quoting.
45    #[must_use]
46    pub fn is_quoted(self) -> bool {
47        !matches!(self, QuoteStyle::None)
48    }
49}
50
// ═══════════════════════════════════════════════════════════════════════
// Top-level statement types
// ═══════════════════════════════════════════════════════════════════════
54
55/// A fully parsed SQL statement.
56///
57/// Corresponds to the top-level node in sqlglot's expression tree.
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
59pub enum Statement {
60    Select(SelectStatement),
61    Insert(InsertStatement),
62    Update(UpdateStatement),
63    Delete(DeleteStatement),
64    CreateTable(CreateTableStatement),
65    DropTable(DropTableStatement),
66    /// UNION / INTERSECT / EXCEPT between queries
67    SetOperation(SetOperationStatement),
68    /// ALTER TABLE ...
69    AlterTable(AlterTableStatement),
70    /// CREATE VIEW ...
71    CreateView(CreateViewStatement),
72    /// DROP VIEW ...
73    DropView(DropViewStatement),
74    /// TRUNCATE TABLE ...
75    Truncate(TruncateStatement),
76    /// BEGIN / COMMIT / ROLLBACK
77    Transaction(TransactionStatement),
78    /// EXPLAIN <statement>
79    Explain(ExplainStatement),
80    /// USE database
81    Use(UseStatement),
82    /// Raw / passthrough expression (for expressions that don't fit a specific statement type)
83    Expression(Expr),
84}
85
// ═══════════════════════════════════════════════════════════════════════
// SELECT
// ═══════════════════════════════════════════════════════════════════════
89
90/// A SELECT statement, including CTEs.
91///
92/// Aligned with sqlglot's `Select` expression which wraps `With`, `From`,
93/// `Where`, `Group`, `Having`, `Order`, `Limit`, `Offset`, `Window`.
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct SelectStatement {
96    /// Common Table Expressions (WITH clause)
97    pub ctes: Vec<Cte>,
98    pub distinct: bool,
99    /// TOP N (TSQL-style)
100    pub top: Option<Box<Expr>>,
101    pub columns: Vec<SelectItem>,
102    pub from: Option<FromClause>,
103    pub joins: Vec<JoinClause>,
104    pub where_clause: Option<Expr>,
105    pub group_by: Vec<Expr>,
106    pub having: Option<Expr>,
107    pub order_by: Vec<OrderByItem>,
108    pub limit: Option<Expr>,
109    pub offset: Option<Expr>,
110    /// Oracle-style FETCH FIRST n ROWS ONLY
111    pub fetch_first: Option<Expr>,
112    /// QUALIFY clause (BigQuery, Snowflake)
113    pub qualify: Option<Expr>,
114    /// Named WINDOW definitions
115    pub window_definitions: Vec<WindowDefinition>,
116}
117
118/// A Common Table Expression: `name [(col1, col2)] AS [NOT] MATERIALIZED (query)`
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub struct Cte {
121    pub name: String,
122    pub columns: Vec<String>,
123    pub query: Box<Statement>,
124    pub materialized: Option<bool>,
125    pub recursive: bool,
126}
127
128/// Named WINDOW definition: `window_name AS (window_spec)`
129#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
130pub struct WindowDefinition {
131    pub name: String,
132    pub spec: WindowSpec,
133}
134
// ═══════════════════════════════════════════════════════════════════════
// Set operations (UNION, INTERSECT, EXCEPT)
// ═══════════════════════════════════════════════════════════════════════
138
139/// UNION / INTERSECT / EXCEPT between two or more queries.
140#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
141pub struct SetOperationStatement {
142    pub op: SetOperationType,
143    pub all: bool,
144    pub left: Box<Statement>,
145    pub right: Box<Statement>,
146    pub order_by: Vec<OrderByItem>,
147    pub limit: Option<Expr>,
148    pub offset: Option<Expr>,
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
152pub enum SetOperationType {
153    Union,
154    Intersect,
155    Except,
156}
157
// ═══════════════════════════════════════════════════════════════════════
// SELECT items and FROM
// ═══════════════════════════════════════════════════════════════════════
161
162/// An item in a SELECT list.
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub enum SelectItem {
165    /// `*`
166    Wildcard,
167    /// `table.*`
168    QualifiedWildcard { table: String },
169    /// An expression with optional alias: `expr AS alias`
170    Expr { expr: Expr, alias: Option<String> },
171}
172
173/// A FROM clause, now supporting subqueries and multiple tables.
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
175pub struct FromClause {
176    pub source: TableSource,
177}
178
179/// A table source can be a table reference, subquery, or table function.
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
181pub enum TableSource {
182    Table(TableRef),
183    Subquery {
184        query: Box<Statement>,
185        alias: Option<String>,
186    },
187    TableFunction {
188        name: String,
189        args: Vec<Expr>,
190        alias: Option<String>,
191    },
192    /// LATERAL subquery or function
193    Lateral {
194        source: Box<TableSource>,
195    },
196    /// UNNEST(array_expr)
197    Unnest {
198        expr: Box<Expr>,
199        alias: Option<String>,
200        with_offset: bool,
201    },
202}
203
204/// A reference to a table.
205#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
206pub struct TableRef {
207    pub catalog: Option<String>,
208    pub schema: Option<String>,
209    pub name: String,
210    pub alias: Option<String>,
211    /// How the table name was quoted in the source SQL.
212    #[serde(default)]
213    pub name_quote_style: QuoteStyle,
214}
215
216/// A JOIN clause.
217#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
218pub struct JoinClause {
219    pub join_type: JoinType,
220    pub table: TableSource,
221    pub on: Option<Expr>,
222    pub using: Vec<String>,
223}
224
225/// The type of JOIN.
226#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
227pub enum JoinType {
228    Inner,
229    Left,
230    Right,
231    Full,
232    Cross,
233    /// NATURAL JOIN
234    Natural,
235    /// LATERAL JOIN
236    Lateral,
237}
238
239/// An ORDER BY item.
240#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
241pub struct OrderByItem {
242    pub expr: Expr,
243    pub ascending: bool,
244    /// NULLS FIRST / NULLS LAST
245    pub nulls_first: Option<bool>,
246}
247
// ═══════════════════════════════════════════════════════════════════════
// Expressions (the core of the AST)
// ═══════════════════════════════════════════════════════════════════════
251
252/// An expression in SQL.
253///
254/// This enum is aligned with sqlglot's Expression class hierarchy.
255/// Key additions over the basic implementation:
256/// - Subquery, Exists, Cast, Extract, Window functions
257/// - TypedString, Interval, Array/Struct constructors
258/// - Postgres-style casting (`::`)
259#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
260pub enum Expr {
261    /// A column reference, possibly qualified: `[catalog.][schema.]table.column`
262    Column {
263        table: Option<String>,
264        name: String,
265        /// How the column name was quoted in the source SQL.
266        #[serde(default)]
267        quote_style: QuoteStyle,
268        /// How the table qualifier was quoted, if present.
269        #[serde(default)]
270        table_quote_style: QuoteStyle,
271    },
272    /// A numeric literal.
273    Number(String),
274    /// A string literal.
275    StringLiteral(String),
276    /// A boolean literal.
277    Boolean(bool),
278    /// NULL literal.
279    Null,
280    /// A binary operation: `left op right`
281    BinaryOp {
282        left: Box<Expr>,
283        op: BinaryOperator,
284        right: Box<Expr>,
285    },
286    /// A unary operation: `op expr`
287    UnaryOp { op: UnaryOperator, expr: Box<Expr> },
288    /// A function call: `name(args...)` with optional DISTINCT, ORDER BY, etc.
289    Function {
290        name: String,
291        args: Vec<Expr>,
292        distinct: bool,
293        /// FILTER (WHERE expr) clause on aggregate
294        filter: Option<Box<Expr>>,
295        /// OVER window specification for window functions
296        over: Option<WindowSpec>,
297    },
298    /// `expr BETWEEN low AND high`
299    Between {
300        expr: Box<Expr>,
301        low: Box<Expr>,
302        high: Box<Expr>,
303        negated: bool,
304    },
305    /// `expr IN (list...)` or `expr IN (subquery)`
306    InList {
307        expr: Box<Expr>,
308        list: Vec<Expr>,
309        negated: bool,
310    },
311    /// `expr IN (SELECT ...)`
312    InSubquery {
313        expr: Box<Expr>,
314        subquery: Box<Statement>,
315        negated: bool,
316    },
317    /// `expr op ANY(subexpr)` — PostgreSQL array/subquery comparison
318    AnyOp {
319        expr: Box<Expr>,
320        op: BinaryOperator,
321        right: Box<Expr>,
322    },
323    /// `expr op ALL(subexpr)` — PostgreSQL array/subquery comparison
324    AllOp {
325        expr: Box<Expr>,
326        op: BinaryOperator,
327        right: Box<Expr>,
328    },
329    /// `expr IS [NOT] NULL`
330    IsNull { expr: Box<Expr>, negated: bool },
331    /// `expr IS [NOT] TRUE` / `expr IS [NOT] FALSE`
332    IsBool {
333        expr: Box<Expr>,
334        value: bool,
335        negated: bool,
336    },
337    /// `expr [NOT] LIKE pattern [ESCAPE escape_char]`
338    Like {
339        expr: Box<Expr>,
340        pattern: Box<Expr>,
341        negated: bool,
342        escape: Option<Box<Expr>>,
343    },
344    /// `expr [NOT] ILIKE pattern [ESCAPE escape_char]` (case-insensitive LIKE)
345    ILike {
346        expr: Box<Expr>,
347        pattern: Box<Expr>,
348        negated: bool,
349        escape: Option<Box<Expr>>,
350    },
351    /// `CASE [operand] WHEN ... THEN ... ELSE ... END`
352    Case {
353        operand: Option<Box<Expr>>,
354        when_clauses: Vec<(Expr, Expr)>,
355        else_clause: Option<Box<Expr>>,
356    },
357    /// A parenthesized sub-expression.
358    Nested(Box<Expr>),
359    /// A wildcard `*` used in contexts like `COUNT(*)`.
360    Wildcard,
361    /// A scalar subquery: `(SELECT ...)`
362    Subquery(Box<Statement>),
363    /// `EXISTS (SELECT ...)`
364    Exists {
365        subquery: Box<Statement>,
366        negated: bool,
367    },
368    /// `CAST(expr AS type)` or `expr::type` (PostgreSQL)
369    Cast {
370        expr: Box<Expr>,
371        data_type: DataType,
372    },
373    /// `TRY_CAST(expr AS type)`
374    TryCast {
375        expr: Box<Expr>,
376        data_type: DataType,
377    },
378    /// `EXTRACT(field FROM expr)`
379    Extract {
380        field: DateTimeField,
381        expr: Box<Expr>,
382    },
383    /// `INTERVAL 'value' unit`
384    Interval {
385        value: Box<Expr>,
386        unit: Option<DateTimeField>,
387    },
388    /// Array literal: `ARRAY[1, 2, 3]` or `[1, 2, 3]`
389    ArrayLiteral(Vec<Expr>),
390    /// Struct literal / row constructor: `(1, 'a', true)`
391    Tuple(Vec<Expr>),
392    /// `COALESCE(a, b, c)`
393    Coalesce(Vec<Expr>),
394    /// `IF(condition, true_val, false_val)` (MySQL, BigQuery)
395    If {
396        condition: Box<Expr>,
397        true_val: Box<Expr>,
398        false_val: Option<Box<Expr>>,
399    },
400    /// `NULLIF(a, b)`
401    NullIf { expr: Box<Expr>, r#else: Box<Expr> },
402    /// `expr COLLATE collation`
403    Collate { expr: Box<Expr>, collation: String },
404    /// Parameter / placeholder: `$1`, `?`, `:name`
405    Parameter(String),
406    /// A type expression used in DDL contexts or CAST
407    TypeExpr(DataType),
408    /// `table.*` in expression context
409    QualifiedWildcard { table: String },
410    /// Star expression `*`
411    Star,
412    /// Alias expression: `expr AS name`
413    Alias { expr: Box<Expr>, name: String },
414    /// Array access: `expr[index]`
415    ArrayIndex { expr: Box<Expr>, index: Box<Expr> },
416    /// JSON access: `expr->key` or `expr->>key`
417    JsonAccess {
418        expr: Box<Expr>,
419        path: Box<Expr>,
420        /// false = ->, true = ->>
421        as_text: bool,
422    },
423    /// Lambda expression: `x -> x + 1`
424    Lambda {
425        params: Vec<String>,
426        body: Box<Expr>,
427    },
428    /// `DEFAULT` keyword in INSERT/UPDATE contexts
429    Default,
430    /// A typed function expression with semantic awareness.
431    /// Enables per-function, per-dialect code generation and transpilation.
432    TypedFunction {
433        func: TypedFunction,
434        /// FILTER (WHERE expr) clause on aggregate
435        filter: Option<Box<Expr>>,
436        /// OVER window specification for window functions
437        over: Option<WindowSpec>,
438    },
439}
440
// ═══════════════════════════════════════════════════════════════════════
// Window specification
// ═══════════════════════════════════════════════════════════════════════
444
445/// Window specification for window functions: OVER (PARTITION BY ... ORDER BY ... frame)
446#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
447pub struct WindowSpec {
448    /// Reference to a named window
449    pub window_ref: Option<String>,
450    pub partition_by: Vec<Expr>,
451    pub order_by: Vec<OrderByItem>,
452    pub frame: Option<WindowFrame>,
453}
454
455#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
456pub struct WindowFrame {
457    pub kind: WindowFrameKind,
458    pub start: WindowFrameBound,
459    pub end: Option<WindowFrameBound>,
460}
461
462#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
463pub enum WindowFrameKind {
464    Rows,
465    Range,
466    Groups,
467}
468
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
470pub enum WindowFrameBound {
471    CurrentRow,
472    Preceding(Option<Box<Expr>>), // None = UNBOUNDED PRECEDING
473    Following(Option<Box<Expr>>), // None = UNBOUNDED FOLLOWING
474}
475
// ═══════════════════════════════════════════════════════════════════════
// Date/time fields (for EXTRACT, INTERVAL)
// ═══════════════════════════════════════════════════════════════════════
479
480#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
481pub enum DateTimeField {
482    Year,
483    Quarter,
484    Month,
485    Week,
486    Day,
487    DayOfWeek,
488    DayOfYear,
489    Hour,
490    Minute,
491    Second,
492    Millisecond,
493    Microsecond,
494    Nanosecond,
495    Epoch,
496    Timezone,
497    TimezoneHour,
498    TimezoneMinute,
499}
500
// ═══════════════════════════════════════════════════════════════════════
// Trim type
// ═══════════════════════════════════════════════════════════════════════
504
505/// The type of TRIM operation.
506#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
507pub enum TrimType {
508    Leading,
509    Trailing,
510    Both,
511}
512
// ═══════════════════════════════════════════════════════════════════════
// Typed function expressions
// ═══════════════════════════════════════════════════════════════════════
516
517/// Typed function variants enabling per-function transpilation rules,
518/// function signature validation, and dialect-specific code generation.
519///
520/// Each variant carries semantically typed arguments rather than a generic
521/// `Vec<Expr>`, allowing the generator to emit dialect-specific SQL.
522#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
523pub enum TypedFunction {
524    // ── Date/Time ──────────────────────────────────────────────────────
525    /// `DATE_ADD(expr, interval)` — add an interval to a date/timestamp
526    DateAdd {
527        expr: Box<Expr>,
528        interval: Box<Expr>,
529        unit: Option<DateTimeField>,
530    },
531    /// `DATE_DIFF(start, end)` — difference between two dates
532    DateDiff {
533        start: Box<Expr>,
534        end: Box<Expr>,
535        unit: Option<DateTimeField>,
536    },
537    /// `DATE_TRUNC(unit, expr)` — truncate to the given precision
538    DateTrunc {
539        unit: DateTimeField,
540        expr: Box<Expr>,
541    },
542    /// `DATE_SUB(expr, interval)` — subtract an interval from a date
543    DateSub {
544        expr: Box<Expr>,
545        interval: Box<Expr>,
546        unit: Option<DateTimeField>,
547    },
548    /// `CURRENT_DATE`
549    CurrentDate,
550    /// `CURRENT_TIMESTAMP` / `NOW()` / `GETDATE()`
551    CurrentTimestamp,
552    /// `STR_TO_TIME(expr, format)` / `TO_TIMESTAMP` / `PARSE_DATETIME`
553    StrToTime { expr: Box<Expr>, format: Box<Expr> },
554    /// `TIME_TO_STR(expr, format)` / `DATE_FORMAT` / `FORMAT_DATETIME`
555    TimeToStr { expr: Box<Expr>, format: Box<Expr> },
556    /// `TS_OR_DS_TO_DATE(expr)` — convert timestamp or date-string to date
557    TsOrDsToDate { expr: Box<Expr> },
558    /// `YEAR(expr)` — extract year from a date/timestamp
559    Year { expr: Box<Expr> },
560    /// `MONTH(expr)` — extract month from a date/timestamp
561    Month { expr: Box<Expr> },
562    /// `DAY(expr)` — extract day from a date/timestamp
563    Day { expr: Box<Expr> },
564
565    // ── String ─────────────────────────────────────────────────────────
566    /// `TRIM([LEADING|TRAILING|BOTH] [chars FROM] expr)`
567    Trim {
568        expr: Box<Expr>,
569        trim_type: TrimType,
570        trim_chars: Option<Box<Expr>>,
571    },
572    /// `SUBSTRING(expr, start [, length])` / `SUBSTR`
573    Substring {
574        expr: Box<Expr>,
575        start: Box<Expr>,
576        length: Option<Box<Expr>>,
577    },
578    /// `UPPER(expr)` / `UCASE`
579    Upper { expr: Box<Expr> },
580    /// `LOWER(expr)` / `LCASE`
581    Lower { expr: Box<Expr> },
582    /// `REGEXP_LIKE(expr, pattern [, flags])` / `~` (Postgres)
583    RegexpLike {
584        expr: Box<Expr>,
585        pattern: Box<Expr>,
586        flags: Option<Box<Expr>>,
587    },
588    /// `REGEXP_EXTRACT(expr, pattern [, group_index])`
589    RegexpExtract {
590        expr: Box<Expr>,
591        pattern: Box<Expr>,
592        group_index: Option<Box<Expr>>,
593    },
594    /// `REGEXP_REPLACE(expr, pattern, replacement [, flags])`
595    RegexpReplace {
596        expr: Box<Expr>,
597        pattern: Box<Expr>,
598        replacement: Box<Expr>,
599        flags: Option<Box<Expr>>,
600    },
601    /// `CONCAT_WS(separator, expr, ...)`
602    ConcatWs {
603        separator: Box<Expr>,
604        exprs: Vec<Expr>,
605    },
606    /// `SPLIT(expr, delimiter)` / `STRING_SPLIT`
607    Split {
608        expr: Box<Expr>,
609        delimiter: Box<Expr>,
610    },
611    /// `INITCAP(expr)` — capitalize first letter of each word
612    Initcap { expr: Box<Expr> },
613    /// `LENGTH(expr)` / `LEN`
614    Length { expr: Box<Expr> },
615    /// `REPLACE(expr, from, to)`
616    Replace {
617        expr: Box<Expr>,
618        from: Box<Expr>,
619        to: Box<Expr>,
620    },
621    /// `REVERSE(expr)`
622    Reverse { expr: Box<Expr> },
623    /// `LEFT(expr, n)`
624    Left { expr: Box<Expr>, n: Box<Expr> },
625    /// `RIGHT(expr, n)`
626    Right { expr: Box<Expr>, n: Box<Expr> },
627    /// `LPAD(expr, length [, pad])`
628    Lpad {
629        expr: Box<Expr>,
630        length: Box<Expr>,
631        pad: Option<Box<Expr>>,
632    },
633    /// `RPAD(expr, length [, pad])`
634    Rpad {
635        expr: Box<Expr>,
636        length: Box<Expr>,
637        pad: Option<Box<Expr>>,
638    },
639
640    // ── Aggregate ──────────────────────────────────────────────────────
641    /// `COUNT(expr)` or `COUNT(DISTINCT expr)` or `COUNT(*)`
642    Count { expr: Box<Expr>, distinct: bool },
643    /// `SUM([DISTINCT] expr)`
644    Sum { expr: Box<Expr>, distinct: bool },
645    /// `AVG([DISTINCT] expr)`
646    Avg { expr: Box<Expr>, distinct: bool },
647    /// `MIN(expr)`
648    Min { expr: Box<Expr> },
649    /// `MAX(expr)`
650    Max { expr: Box<Expr> },
651    /// `ARRAY_AGG([DISTINCT] expr)` / `LIST` / `COLLECT_LIST`
652    ArrayAgg { expr: Box<Expr>, distinct: bool },
653    /// `APPROX_DISTINCT(expr)` / `APPROX_COUNT_DISTINCT`
654    ApproxDistinct { expr: Box<Expr> },
655    /// `VARIANCE(expr)` / `VAR_SAMP`
656    Variance { expr: Box<Expr> },
657    /// `STDDEV(expr)` / `STDDEV_SAMP`
658    Stddev { expr: Box<Expr> },
659
660    // ── Array ──────────────────────────────────────────────────────────
661    /// `ARRAY_CONCAT(arr1, arr2)` / `ARRAY_CAT`
662    ArrayConcat { arrays: Vec<Expr> },
663    /// `ARRAY_CONTAINS(array, element)` / `ARRAY_POSITION`
664    ArrayContains {
665        array: Box<Expr>,
666        element: Box<Expr>,
667    },
668    /// `ARRAY_SIZE(expr)` / `ARRAY_LENGTH` / `CARDINALITY`
669    ArraySize { expr: Box<Expr> },
670    /// `EXPLODE(expr)` — Hive/Spark array expansion
671    Explode { expr: Box<Expr> },
672    /// `GENERATE_SERIES(start, stop [, step])`
673    GenerateSeries {
674        start: Box<Expr>,
675        stop: Box<Expr>,
676        step: Option<Box<Expr>>,
677    },
678    /// `FLATTEN(expr)` — flatten nested arrays
679    Flatten { expr: Box<Expr> },
680
681    // ── JSON ───────────────────────────────────────────────────────────
682    /// `JSON_EXTRACT(expr, path)` / `JSON_VALUE` / `->` (Postgres)
683    JSONExtract { expr: Box<Expr>, path: Box<Expr> },
684    /// `JSON_EXTRACT_SCALAR(expr, path)` / `->>`
685    JSONExtractScalar { expr: Box<Expr>, path: Box<Expr> },
686    /// `PARSE_JSON(expr)` / `JSON_PARSE`
687    ParseJSON { expr: Box<Expr> },
688    /// `JSON_FORMAT(expr)` / `TO_JSON`
689    JSONFormat { expr: Box<Expr> },
690
691    // ── Window ─────────────────────────────────────────────────────────
692    /// `ROW_NUMBER()`
693    RowNumber,
694    /// `RANK()`
695    Rank,
696    /// `DENSE_RANK()`
697    DenseRank,
698    /// `NTILE(n)`
699    NTile { n: Box<Expr> },
700    /// `LEAD(expr [, offset [, default]])`
701    Lead {
702        expr: Box<Expr>,
703        offset: Option<Box<Expr>>,
704        default: Option<Box<Expr>>,
705    },
706    /// `LAG(expr [, offset [, default]])`
707    Lag {
708        expr: Box<Expr>,
709        offset: Option<Box<Expr>>,
710        default: Option<Box<Expr>>,
711    },
712    /// `FIRST_VALUE(expr)`
713    FirstValue { expr: Box<Expr> },
714    /// `LAST_VALUE(expr)`
715    LastValue { expr: Box<Expr> },
716
717    // ── Math ───────────────────────────────────────────────────────────
718    /// `ABS(expr)`
719    Abs { expr: Box<Expr> },
720    /// `CEIL(expr)` / `CEILING`
721    Ceil { expr: Box<Expr> },
722    /// `FLOOR(expr)`
723    Floor { expr: Box<Expr> },
724    /// `ROUND(expr [, decimals])`
725    Round {
726        expr: Box<Expr>,
727        decimals: Option<Box<Expr>>,
728    },
729    /// `LOG(expr [, base])` — semantics vary by dialect
730    Log {
731        expr: Box<Expr>,
732        base: Option<Box<Expr>>,
733    },
734    /// `LN(expr)` — natural logarithm
735    Ln { expr: Box<Expr> },
736    /// `POW(base, exponent)` / `POWER`
737    Pow {
738        base: Box<Expr>,
739        exponent: Box<Expr>,
740    },
741    /// `SQRT(expr)`
742    Sqrt { expr: Box<Expr> },
743    /// `GREATEST(expr, ...)`
744    Greatest { exprs: Vec<Expr> },
745    /// `LEAST(expr, ...)`
746    Least { exprs: Vec<Expr> },
747    /// `MOD(a, b)` — modulo function
748    Mod { left: Box<Expr>, right: Box<Expr> },
749
750    // ── Conversion ─────────────────────────────────────────────────────
751    /// `HEX(expr)` / `TO_HEX`
752    Hex { expr: Box<Expr> },
753    /// `UNHEX(expr)` / `FROM_HEX`
754    Unhex { expr: Box<Expr> },
755    /// `MD5(expr)`
756    Md5 { expr: Box<Expr> },
757    /// `SHA(expr)` / `SHA1`
758    Sha { expr: Box<Expr> },
759    /// `SHA2(expr, bit_length)` — SHA-256/SHA-512
760    Sha2 {
761        expr: Box<Expr>,
762        bit_length: Box<Expr>,
763    },
764}
765
766impl TypedFunction {
767    /// Walk child expressions, calling `visitor` on each.
768    pub fn walk_children<F>(&self, visitor: &mut F)
769    where
770        F: FnMut(&Expr) -> bool,
771    {
772        match self {
773            // Date/Time
774            TypedFunction::DateAdd { expr, interval, .. }
775            | TypedFunction::DateSub { expr, interval, .. } => {
776                expr.walk(visitor);
777                interval.walk(visitor);
778            }
779            TypedFunction::DateDiff { start, end, .. } => {
780                start.walk(visitor);
781                end.walk(visitor);
782            }
783            TypedFunction::DateTrunc { expr, .. } => expr.walk(visitor),
784            TypedFunction::CurrentDate | TypedFunction::CurrentTimestamp => {}
785            TypedFunction::StrToTime { expr, format }
786            | TypedFunction::TimeToStr { expr, format } => {
787                expr.walk(visitor);
788                format.walk(visitor);
789            }
790            TypedFunction::TsOrDsToDate { expr }
791            | TypedFunction::Year { expr }
792            | TypedFunction::Month { expr }
793            | TypedFunction::Day { expr } => expr.walk(visitor),
794
795            // String
796            TypedFunction::Trim {
797                expr, trim_chars, ..
798            } => {
799                expr.walk(visitor);
800                if let Some(c) = trim_chars {
801                    c.walk(visitor);
802                }
803            }
804            TypedFunction::Substring {
805                expr,
806                start,
807                length,
808            } => {
809                expr.walk(visitor);
810                start.walk(visitor);
811                if let Some(l) = length {
812                    l.walk(visitor);
813                }
814            }
815            TypedFunction::Upper { expr }
816            | TypedFunction::Lower { expr }
817            | TypedFunction::Initcap { expr }
818            | TypedFunction::Length { expr }
819            | TypedFunction::Reverse { expr } => expr.walk(visitor),
820            TypedFunction::RegexpLike {
821                expr,
822                pattern,
823                flags,
824            } => {
825                expr.walk(visitor);
826                pattern.walk(visitor);
827                if let Some(f) = flags {
828                    f.walk(visitor);
829                }
830            }
831            TypedFunction::RegexpExtract {
832                expr,
833                pattern,
834                group_index,
835            } => {
836                expr.walk(visitor);
837                pattern.walk(visitor);
838                if let Some(g) = group_index {
839                    g.walk(visitor);
840                }
841            }
842            TypedFunction::RegexpReplace {
843                expr,
844                pattern,
845                replacement,
846                flags,
847            } => {
848                expr.walk(visitor);
849                pattern.walk(visitor);
850                replacement.walk(visitor);
851                if let Some(f) = flags {
852                    f.walk(visitor);
853                }
854            }
855            TypedFunction::ConcatWs { separator, exprs } => {
856                separator.walk(visitor);
857                for e in exprs {
858                    e.walk(visitor);
859                }
860            }
861            TypedFunction::Split { expr, delimiter } => {
862                expr.walk(visitor);
863                delimiter.walk(visitor);
864            }
865            TypedFunction::Replace { expr, from, to } => {
866                expr.walk(visitor);
867                from.walk(visitor);
868                to.walk(visitor);
869            }
870            TypedFunction::Left { expr, n } | TypedFunction::Right { expr, n } => {
871                expr.walk(visitor);
872                n.walk(visitor);
873            }
874            TypedFunction::Lpad { expr, length, pad }
875            | TypedFunction::Rpad { expr, length, pad } => {
876                expr.walk(visitor);
877                length.walk(visitor);
878                if let Some(p) = pad {
879                    p.walk(visitor);
880                }
881            }
882
883            // Aggregate
884            TypedFunction::Count { expr, .. }
885            | TypedFunction::Sum { expr, .. }
886            | TypedFunction::Avg { expr, .. }
887            | TypedFunction::Min { expr }
888            | TypedFunction::Max { expr }
889            | TypedFunction::ArrayAgg { expr, .. }
890            | TypedFunction::ApproxDistinct { expr }
891            | TypedFunction::Variance { expr }
892            | TypedFunction::Stddev { expr } => expr.walk(visitor),
893
894            // Array
895            TypedFunction::ArrayConcat { arrays } => {
896                for a in arrays {
897                    a.walk(visitor);
898                }
899            }
900            TypedFunction::ArrayContains { array, element } => {
901                array.walk(visitor);
902                element.walk(visitor);
903            }
904            TypedFunction::ArraySize { expr }
905            | TypedFunction::Explode { expr }
906            | TypedFunction::Flatten { expr } => expr.walk(visitor),
907            TypedFunction::GenerateSeries { start, stop, step } => {
908                start.walk(visitor);
909                stop.walk(visitor);
910                if let Some(s) = step {
911                    s.walk(visitor);
912                }
913            }
914
915            // JSON
916            TypedFunction::JSONExtract { expr, path }
917            | TypedFunction::JSONExtractScalar { expr, path } => {
918                expr.walk(visitor);
919                path.walk(visitor);
920            }
921            TypedFunction::ParseJSON { expr } | TypedFunction::JSONFormat { expr } => {
922                expr.walk(visitor)
923            }
924
925            // Window
926            TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => {}
927            TypedFunction::NTile { n } => n.walk(visitor),
928            TypedFunction::Lead {
929                expr,
930                offset,
931                default,
932            }
933            | TypedFunction::Lag {
934                expr,
935                offset,
936                default,
937            } => {
938                expr.walk(visitor);
939                if let Some(o) = offset {
940                    o.walk(visitor);
941                }
942                if let Some(d) = default {
943                    d.walk(visitor);
944                }
945            }
946            TypedFunction::FirstValue { expr } | TypedFunction::LastValue { expr } => {
947                expr.walk(visitor)
948            }
949
950            // Math
951            TypedFunction::Abs { expr }
952            | TypedFunction::Ceil { expr }
953            | TypedFunction::Floor { expr }
954            | TypedFunction::Ln { expr }
955            | TypedFunction::Sqrt { expr } => expr.walk(visitor),
956            TypedFunction::Round { expr, decimals } => {
957                expr.walk(visitor);
958                if let Some(d) = decimals {
959                    d.walk(visitor);
960                }
961            }
962            TypedFunction::Log { expr, base } => {
963                expr.walk(visitor);
964                if let Some(b) = base {
965                    b.walk(visitor);
966                }
967            }
968            TypedFunction::Pow { base, exponent } => {
969                base.walk(visitor);
970                exponent.walk(visitor);
971            }
972            TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => {
973                for e in exprs {
974                    e.walk(visitor);
975                }
976            }
977            TypedFunction::Mod { left, right } => {
978                left.walk(visitor);
979                right.walk(visitor);
980            }
981
982            // Conversion
983            TypedFunction::Hex { expr }
984            | TypedFunction::Unhex { expr }
985            | TypedFunction::Md5 { expr }
986            | TypedFunction::Sha { expr } => expr.walk(visitor),
987            TypedFunction::Sha2 { expr, bit_length } => {
988                expr.walk(visitor);
989                bit_length.walk(visitor);
990            }
991        }
992    }
993
994    /// Transform child expressions, returning a new `TypedFunction`.
995    #[must_use]
996    pub fn transform_children<F>(self, func: &F) -> TypedFunction
997    where
998        F: Fn(Expr) -> Expr,
999    {
1000        match self {
1001            // Date/Time
1002            TypedFunction::DateAdd {
1003                expr,
1004                interval,
1005                unit,
1006            } => TypedFunction::DateAdd {
1007                expr: Box::new(expr.transform(func)),
1008                interval: Box::new(interval.transform(func)),
1009                unit,
1010            },
1011            TypedFunction::DateDiff { start, end, unit } => TypedFunction::DateDiff {
1012                start: Box::new(start.transform(func)),
1013                end: Box::new(end.transform(func)),
1014                unit,
1015            },
1016            TypedFunction::DateTrunc { unit, expr } => TypedFunction::DateTrunc {
1017                unit,
1018                expr: Box::new(expr.transform(func)),
1019            },
1020            TypedFunction::DateSub {
1021                expr,
1022                interval,
1023                unit,
1024            } => TypedFunction::DateSub {
1025                expr: Box::new(expr.transform(func)),
1026                interval: Box::new(interval.transform(func)),
1027                unit,
1028            },
1029            TypedFunction::CurrentDate => TypedFunction::CurrentDate,
1030            TypedFunction::CurrentTimestamp => TypedFunction::CurrentTimestamp,
1031            TypedFunction::StrToTime { expr, format } => TypedFunction::StrToTime {
1032                expr: Box::new(expr.transform(func)),
1033                format: Box::new(format.transform(func)),
1034            },
1035            TypedFunction::TimeToStr { expr, format } => TypedFunction::TimeToStr {
1036                expr: Box::new(expr.transform(func)),
1037                format: Box::new(format.transform(func)),
1038            },
1039            TypedFunction::TsOrDsToDate { expr } => TypedFunction::TsOrDsToDate {
1040                expr: Box::new(expr.transform(func)),
1041            },
1042            TypedFunction::Year { expr } => TypedFunction::Year {
1043                expr: Box::new(expr.transform(func)),
1044            },
1045            TypedFunction::Month { expr } => TypedFunction::Month {
1046                expr: Box::new(expr.transform(func)),
1047            },
1048            TypedFunction::Day { expr } => TypedFunction::Day {
1049                expr: Box::new(expr.transform(func)),
1050            },
1051
1052            // String
1053            TypedFunction::Trim {
1054                expr,
1055                trim_type,
1056                trim_chars,
1057            } => TypedFunction::Trim {
1058                expr: Box::new(expr.transform(func)),
1059                trim_type,
1060                trim_chars: trim_chars.map(|c| Box::new(c.transform(func))),
1061            },
1062            TypedFunction::Substring {
1063                expr,
1064                start,
1065                length,
1066            } => TypedFunction::Substring {
1067                expr: Box::new(expr.transform(func)),
1068                start: Box::new(start.transform(func)),
1069                length: length.map(|l| Box::new(l.transform(func))),
1070            },
1071            TypedFunction::Upper { expr } => TypedFunction::Upper {
1072                expr: Box::new(expr.transform(func)),
1073            },
1074            TypedFunction::Lower { expr } => TypedFunction::Lower {
1075                expr: Box::new(expr.transform(func)),
1076            },
1077            TypedFunction::RegexpLike {
1078                expr,
1079                pattern,
1080                flags,
1081            } => TypedFunction::RegexpLike {
1082                expr: Box::new(expr.transform(func)),
1083                pattern: Box::new(pattern.transform(func)),
1084                flags: flags.map(|f| Box::new(f.transform(func))),
1085            },
1086            TypedFunction::RegexpExtract {
1087                expr,
1088                pattern,
1089                group_index,
1090            } => TypedFunction::RegexpExtract {
1091                expr: Box::new(expr.transform(func)),
1092                pattern: Box::new(pattern.transform(func)),
1093                group_index: group_index.map(|g| Box::new(g.transform(func))),
1094            },
1095            TypedFunction::RegexpReplace {
1096                expr,
1097                pattern,
1098                replacement,
1099                flags,
1100            } => TypedFunction::RegexpReplace {
1101                expr: Box::new(expr.transform(func)),
1102                pattern: Box::new(pattern.transform(func)),
1103                replacement: Box::new(replacement.transform(func)),
1104                flags: flags.map(|f| Box::new(f.transform(func))),
1105            },
1106            TypedFunction::ConcatWs { separator, exprs } => TypedFunction::ConcatWs {
1107                separator: Box::new(separator.transform(func)),
1108                exprs: exprs.into_iter().map(|e| e.transform(func)).collect(),
1109            },
1110            TypedFunction::Split { expr, delimiter } => TypedFunction::Split {
1111                expr: Box::new(expr.transform(func)),
1112                delimiter: Box::new(delimiter.transform(func)),
1113            },
1114            TypedFunction::Initcap { expr } => TypedFunction::Initcap {
1115                expr: Box::new(expr.transform(func)),
1116            },
1117            TypedFunction::Length { expr } => TypedFunction::Length {
1118                expr: Box::new(expr.transform(func)),
1119            },
1120            TypedFunction::Replace { expr, from, to } => TypedFunction::Replace {
1121                expr: Box::new(expr.transform(func)),
1122                from: Box::new(from.transform(func)),
1123                to: Box::new(to.transform(func)),
1124            },
1125            TypedFunction::Reverse { expr } => TypedFunction::Reverse {
1126                expr: Box::new(expr.transform(func)),
1127            },
1128            TypedFunction::Left { expr, n } => TypedFunction::Left {
1129                expr: Box::new(expr.transform(func)),
1130                n: Box::new(n.transform(func)),
1131            },
1132            TypedFunction::Right { expr, n } => TypedFunction::Right {
1133                expr: Box::new(expr.transform(func)),
1134                n: Box::new(n.transform(func)),
1135            },
1136            TypedFunction::Lpad { expr, length, pad } => TypedFunction::Lpad {
1137                expr: Box::new(expr.transform(func)),
1138                length: Box::new(length.transform(func)),
1139                pad: pad.map(|p| Box::new(p.transform(func))),
1140            },
1141            TypedFunction::Rpad { expr, length, pad } => TypedFunction::Rpad {
1142                expr: Box::new(expr.transform(func)),
1143                length: Box::new(length.transform(func)),
1144                pad: pad.map(|p| Box::new(p.transform(func))),
1145            },
1146
1147            // Aggregate
1148            TypedFunction::Count { expr, distinct } => TypedFunction::Count {
1149                expr: Box::new(expr.transform(func)),
1150                distinct,
1151            },
1152            TypedFunction::Sum { expr, distinct } => TypedFunction::Sum {
1153                expr: Box::new(expr.transform(func)),
1154                distinct,
1155            },
1156            TypedFunction::Avg { expr, distinct } => TypedFunction::Avg {
1157                expr: Box::new(expr.transform(func)),
1158                distinct,
1159            },
1160            TypedFunction::Min { expr } => TypedFunction::Min {
1161                expr: Box::new(expr.transform(func)),
1162            },
1163            TypedFunction::Max { expr } => TypedFunction::Max {
1164                expr: Box::new(expr.transform(func)),
1165            },
1166            TypedFunction::ArrayAgg { expr, distinct } => TypedFunction::ArrayAgg {
1167                expr: Box::new(expr.transform(func)),
1168                distinct,
1169            },
1170            TypedFunction::ApproxDistinct { expr } => TypedFunction::ApproxDistinct {
1171                expr: Box::new(expr.transform(func)),
1172            },
1173            TypedFunction::Variance { expr } => TypedFunction::Variance {
1174                expr: Box::new(expr.transform(func)),
1175            },
1176            TypedFunction::Stddev { expr } => TypedFunction::Stddev {
1177                expr: Box::new(expr.transform(func)),
1178            },
1179
1180            // Array
1181            TypedFunction::ArrayConcat { arrays } => TypedFunction::ArrayConcat {
1182                arrays: arrays.into_iter().map(|a| a.transform(func)).collect(),
1183            },
1184            TypedFunction::ArrayContains { array, element } => TypedFunction::ArrayContains {
1185                array: Box::new(array.transform(func)),
1186                element: Box::new(element.transform(func)),
1187            },
1188            TypedFunction::ArraySize { expr } => TypedFunction::ArraySize {
1189                expr: Box::new(expr.transform(func)),
1190            },
1191            TypedFunction::Explode { expr } => TypedFunction::Explode {
1192                expr: Box::new(expr.transform(func)),
1193            },
1194            TypedFunction::GenerateSeries { start, stop, step } => TypedFunction::GenerateSeries {
1195                start: Box::new(start.transform(func)),
1196                stop: Box::new(stop.transform(func)),
1197                step: step.map(|s| Box::new(s.transform(func))),
1198            },
1199            TypedFunction::Flatten { expr } => TypedFunction::Flatten {
1200                expr: Box::new(expr.transform(func)),
1201            },
1202
1203            // JSON
1204            TypedFunction::JSONExtract { expr, path } => TypedFunction::JSONExtract {
1205                expr: Box::new(expr.transform(func)),
1206                path: Box::new(path.transform(func)),
1207            },
1208            TypedFunction::JSONExtractScalar { expr, path } => TypedFunction::JSONExtractScalar {
1209                expr: Box::new(expr.transform(func)),
1210                path: Box::new(path.transform(func)),
1211            },
1212            TypedFunction::ParseJSON { expr } => TypedFunction::ParseJSON {
1213                expr: Box::new(expr.transform(func)),
1214            },
1215            TypedFunction::JSONFormat { expr } => TypedFunction::JSONFormat {
1216                expr: Box::new(expr.transform(func)),
1217            },
1218
1219            // Window
1220            TypedFunction::RowNumber => TypedFunction::RowNumber,
1221            TypedFunction::Rank => TypedFunction::Rank,
1222            TypedFunction::DenseRank => TypedFunction::DenseRank,
1223            TypedFunction::NTile { n } => TypedFunction::NTile {
1224                n: Box::new(n.transform(func)),
1225            },
1226            TypedFunction::Lead {
1227                expr,
1228                offset,
1229                default,
1230            } => TypedFunction::Lead {
1231                expr: Box::new(expr.transform(func)),
1232                offset: offset.map(|o| Box::new(o.transform(func))),
1233                default: default.map(|d| Box::new(d.transform(func))),
1234            },
1235            TypedFunction::Lag {
1236                expr,
1237                offset,
1238                default,
1239            } => TypedFunction::Lag {
1240                expr: Box::new(expr.transform(func)),
1241                offset: offset.map(|o| Box::new(o.transform(func))),
1242                default: default.map(|d| Box::new(d.transform(func))),
1243            },
1244            TypedFunction::FirstValue { expr } => TypedFunction::FirstValue {
1245                expr: Box::new(expr.transform(func)),
1246            },
1247            TypedFunction::LastValue { expr } => TypedFunction::LastValue {
1248                expr: Box::new(expr.transform(func)),
1249            },
1250
1251            // Math
1252            TypedFunction::Abs { expr } => TypedFunction::Abs {
1253                expr: Box::new(expr.transform(func)),
1254            },
1255            TypedFunction::Ceil { expr } => TypedFunction::Ceil {
1256                expr: Box::new(expr.transform(func)),
1257            },
1258            TypedFunction::Floor { expr } => TypedFunction::Floor {
1259                expr: Box::new(expr.transform(func)),
1260            },
1261            TypedFunction::Round { expr, decimals } => TypedFunction::Round {
1262                expr: Box::new(expr.transform(func)),
1263                decimals: decimals.map(|d| Box::new(d.transform(func))),
1264            },
1265            TypedFunction::Log { expr, base } => TypedFunction::Log {
1266                expr: Box::new(expr.transform(func)),
1267                base: base.map(|b| Box::new(b.transform(func))),
1268            },
1269            TypedFunction::Ln { expr } => TypedFunction::Ln {
1270                expr: Box::new(expr.transform(func)),
1271            },
1272            TypedFunction::Pow { base, exponent } => TypedFunction::Pow {
1273                base: Box::new(base.transform(func)),
1274                exponent: Box::new(exponent.transform(func)),
1275            },
1276            TypedFunction::Sqrt { expr } => TypedFunction::Sqrt {
1277                expr: Box::new(expr.transform(func)),
1278            },
1279            TypedFunction::Greatest { exprs } => TypedFunction::Greatest {
1280                exprs: exprs.into_iter().map(|e| e.transform(func)).collect(),
1281            },
1282            TypedFunction::Least { exprs } => TypedFunction::Least {
1283                exprs: exprs.into_iter().map(|e| e.transform(func)).collect(),
1284            },
1285            TypedFunction::Mod { left, right } => TypedFunction::Mod {
1286                left: Box::new(left.transform(func)),
1287                right: Box::new(right.transform(func)),
1288            },
1289
1290            // Conversion
1291            TypedFunction::Hex { expr } => TypedFunction::Hex {
1292                expr: Box::new(expr.transform(func)),
1293            },
1294            TypedFunction::Unhex { expr } => TypedFunction::Unhex {
1295                expr: Box::new(expr.transform(func)),
1296            },
1297            TypedFunction::Md5 { expr } => TypedFunction::Md5 {
1298                expr: Box::new(expr.transform(func)),
1299            },
1300            TypedFunction::Sha { expr } => TypedFunction::Sha {
1301                expr: Box::new(expr.transform(func)),
1302            },
1303            TypedFunction::Sha2 { expr, bit_length } => TypedFunction::Sha2 {
1304                expr: Box::new(expr.transform(func)),
1305                bit_length: Box::new(bit_length.transform(func)),
1306            },
1307        }
1308    }
1309}
1310
1311// ═══════════════════════════════════════════════════════════════════════
1312// Operators
1313// ═══════════════════════════════════════════════════════════════════════
1314
/// Binary operators.
///
/// Variants name the operation; only the JSON access variants document the
/// token they correspond to.
/// NOTE(review): presumably this is the `op` carried by `Expr::BinaryOp` —
/// confirm against the `Expr` definition.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,
    // Comparison
    Eq,
    Neq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    // Logical (distinct from `BitwiseXor` below)
    And,
    Or,
    Xor,
    // String concatenation
    Concat,
    // Bitwise
    BitwiseAnd,
    BitwiseOr,
    BitwiseXor,
    ShiftLeft,
    ShiftRight,
    /// `->` JSON access operator
    Arrow,
    /// `->>` JSON text access
    DoubleArrow,
}
1343
/// Unary operators.
///
/// NOTE(review): presumably this is the `op` carried by `Expr::UnaryOp` —
/// confirm against the `Expr` definition.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Logical negation (`NOT`).
    Not,
    /// Arithmetic negation.
    Minus,
    /// Unary plus.
    Plus,
    /// Bitwise complement.
    BitwiseNot,
}
1352
1353// ═══════════════════════════════════════════════════════════════════════
1354// DML statements
1355// ═══════════════════════════════════════════════════════════════════════
1356
/// An `INSERT` statement, including `INSERT ... SELECT` sources and
/// `ON CONFLICT` / `ON DUPLICATE KEY` handling.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InsertStatement {
    /// Table receiving the rows.
    pub table: TableRef,
    /// Explicit target column list.
    /// NOTE(review): presumably empty when the statement names no columns —
    /// confirm with the parser.
    pub columns: Vec<String>,
    /// Where the inserted rows come from; see [`InsertSource`].
    pub source: InsertSource,
    /// `ON CONFLICT` / `ON DUPLICATE KEY` clause, if present.
    pub on_conflict: Option<OnConflict>,
    /// `RETURNING` clause items (presumably empty when the clause is absent).
    pub returning: Vec<SelectItem>,
}
1368
/// The row source of an [`InsertStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InsertSource {
    /// A `VALUES` list; presumably one inner `Vec` per row.
    Values(Vec<Vec<Expr>>),
    /// A nested statement, e.g. `INSERT ... SELECT`.
    Query(Box<Statement>),
    /// NOTE(review): presumably `DEFAULT VALUES` — confirm with the parser.
    Default,
}
1375
/// Conflict handling attached to an `INSERT`
/// (`ON CONFLICT` / `ON DUPLICATE KEY`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnConflict {
    /// Conflict-target columns.
    /// NOTE(review): presumably empty when no target is written (e.g. the
    /// `ON DUPLICATE KEY` form) — confirm with the parser.
    pub columns: Vec<String>,
    /// What to do when a conflict occurs.
    pub action: ConflictAction,
}
1381
/// The action taken by an [`OnConflict`] clause.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConflictAction {
    /// `DO NOTHING`.
    DoNothing,
    /// Update the existing row; presumably each pair is
    /// (column name, new value expression).
    DoUpdate(Vec<(String, Expr)>),
}
1387
/// An UPDATE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UpdateStatement {
    /// Table being updated.
    pub table: TableRef,
    /// `SET` assignments, presumably as (column name, value expression) pairs.
    pub assignments: Vec<(String, Expr)>,
    /// Optional `FROM` clause.
    pub from: Option<FromClause>,
    /// Optional `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `RETURNING` clause items (presumably empty when the clause is absent).
    pub returning: Vec<SelectItem>,
}
1397
/// A DELETE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeleteStatement {
    /// Table rows are deleted from.
    pub table: TableRef,
    /// Optional `USING` clause.
    pub using: Option<FromClause>,
    /// Optional `WHERE` predicate.
    pub where_clause: Option<Expr>,
    /// `RETURNING` clause items (presumably empty when the clause is absent).
    pub returning: Vec<SelectItem>,
}
1406
1407// ═══════════════════════════════════════════════════════════════════════
1408// DDL statements
1409// ═══════════════════════════════════════════════════════════════════════
1410
/// A CREATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateTableStatement {
    /// `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
    /// `TEMPORARY` table.
    pub temporary: bool,
    /// Table being created.
    pub table: TableRef,
    /// Column definitions.
    pub columns: Vec<ColumnDef>,
    /// Table-level constraints (per-column flags live on [`ColumnDef`]).
    pub constraints: Vec<TableConstraint>,
    /// CREATE TABLE ... AS SELECT ...
    pub as_select: Option<Box<Statement>>,
}
1422
/// Table-level constraints.
///
/// Every variant carries an optional `name`, presumably taken from a
/// `CONSTRAINT <name>` prefix.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableConstraint {
    /// `PRIMARY KEY (columns...)`.
    PrimaryKey {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// `UNIQUE (columns...)`.
    Unique {
        name: Option<String>,
        columns: Vec<String>,
    },
    /// `FOREIGN KEY (columns...) REFERENCES ref_table (ref_columns...)`, with
    /// optional `ON DELETE` / `ON UPDATE` referential actions.
    ForeignKey {
        name: Option<String>,
        columns: Vec<String>,
        ref_table: TableRef,
        ref_columns: Vec<String>,
        on_delete: Option<ReferentialAction>,
        on_update: Option<ReferentialAction>,
    },
    /// `CHECK (expr)`.
    Check {
        name: Option<String>,
        expr: Expr,
    },
}
1447
/// Referential action used by the `on_delete` / `on_update` fields of
/// `TableConstraint::ForeignKey`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReferentialAction {
    /// `CASCADE`
    Cascade,
    /// `RESTRICT`
    Restrict,
    /// `NO ACTION`
    NoAction,
    /// `SET NULL`
    SetNull,
    /// `SET DEFAULT`
    SetDefault,
}
1456
/// A column definition in CREATE TABLE.
///
/// Also carried by `AlterTableAction::AddColumn`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Declared column type.
    pub data_type: DataType,
    /// Explicit nullability.
    /// NOTE(review): presumably `Some(false)` for `NOT NULL`, `Some(true)` for
    /// an explicit `NULL`, and `None` when unspecified — confirm with the parser.
    pub nullable: Option<bool>,
    /// `DEFAULT` expression, if any.
    pub default: Option<Expr>,
    /// Inline `PRIMARY KEY` marker.
    pub primary_key: bool,
    /// Inline `UNIQUE` marker.
    pub unique: bool,
    /// Auto-increment column marker.
    pub auto_increment: bool,
    /// Collation name, if any.
    pub collation: Option<String>,
    /// Column comment, if any.
    pub comment: Option<String>,
}
1470
/// ALTER TABLE statement: one target table and the actions applied to it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AlterTableStatement {
    /// Table being altered.
    pub table: TableRef,
    /// Actions in the statement; see [`AlterTableAction`].
    pub actions: Vec<AlterTableAction>,
}
1477
/// A single action inside an [`AlterTableStatement`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AlterTableAction {
    /// Add a new column.
    AddColumn(ColumnDef),
    /// Drop column `name`; `if_exists` records an `IF EXISTS` guard.
    DropColumn { name: String, if_exists: bool },
    /// Rename column `old_name` to `new_name`.
    RenameColumn { old_name: String, new_name: String },
    /// Change the data type of column `name`.
    AlterColumnType { name: String, data_type: DataType },
    /// Add a table-level constraint.
    AddConstraint(TableConstraint),
    /// Drop the constraint called `name`.
    DropConstraint { name: String },
    /// Rename the table itself.
    RenameTable { new_name: String },
}
1488
/// CREATE VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CreateViewStatement {
    /// View name.
    pub name: TableRef,
    /// Explicit view column list (presumably empty when omitted).
    pub columns: Vec<String>,
    /// Defining query.
    pub query: Box<Statement>,
    /// `OR REPLACE` was specified.
    pub or_replace: bool,
    /// `MATERIALIZED` view.
    pub materialized: bool,
    /// `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
}
1499
/// DROP VIEW statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropViewStatement {
    /// View being dropped.
    pub name: TableRef,
    /// `IF EXISTS` was specified.
    pub if_exists: bool,
    /// `MATERIALIZED` view.
    pub materialized: bool,
}
1507
/// TRUNCATE TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TruncateStatement {
    /// Table being truncated.
    pub table: TableRef,
}
1513
/// Transaction control statements.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransactionStatement {
    /// Start a transaction.
    Begin,
    /// Commit the current transaction.
    Commit,
    /// Roll back the current transaction.
    Rollback,
    /// Create a savepoint with the given name.
    Savepoint(String),
    /// Release the named savepoint.
    ReleaseSavepoint(String),
    /// Roll back to the named savepoint.
    RollbackTo(String),
}
1524
/// EXPLAIN statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExplainStatement {
    /// `EXPLAIN ANALYZE` rather than plain `EXPLAIN`.
    pub analyze: bool,
    /// Statement being explained.
    pub statement: Box<Statement>,
}
1531
/// USE database statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UseStatement {
    /// Name of the database to switch to.
    pub name: String,
}
1537
/// A DROP TABLE statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DropTableStatement {
    /// `IF EXISTS` was specified.
    pub if_exists: bool,
    /// Table being dropped.
    pub table: TableRef,
    /// `CASCADE` was specified.
    pub cascade: bool,
}
1545
1546// ═══════════════════════════════════════════════════════════════════════
1547// Data types
1548// ═══════════════════════════════════════════════════════════════════════
1549
/// SQL data types, modelled on sqlglot's `DataType.Type` enum.
///
/// `Option<u32>` payloads hold an optional length/precision argument.
/// NOTE(review): presumably `None` when the type was written without one
/// (e.g. bare `VARCHAR`) — confirm with the parser.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    // Numeric
    TinyInt,
    SmallInt,
    Int,
    BigInt,
    Float,
    Double,
    /// `DECIMAL` with optional precision and scale.
    Decimal {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    /// `NUMERIC` with optional precision and scale.
    Numeric {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    Real,

    // String
    /// `VARCHAR` with an optional length.
    Varchar(Option<u32>),
    /// `CHAR` with an optional length.
    Char(Option<u32>),
    Text,
    String,
    /// `BINARY` with an optional length.
    Binary(Option<u32>),
    /// `VARBINARY` with an optional length.
    Varbinary(Option<u32>),

    // Boolean
    Boolean,

    // Date/Time
    Date,
    /// `TIME` with an optional precision.
    Time {
        precision: Option<u32>,
    },
    /// `TIMESTAMP` with an optional precision; `with_tz` marks the
    /// time-zone-aware form (presumably `WITH TIME ZONE` / `TIMESTAMPTZ`).
    Timestamp {
        precision: Option<u32>,
        with_tz: bool,
    },
    Interval,
    DateTime,

    // Binary
    Blob,
    Bytea,
    Bytes,

    // JSON
    Json,
    Jsonb,

    // UUID
    Uuid,

    // Complex types
    /// `ARRAY` with an optional element type.
    Array(Option<Box<DataType>>),
    /// `MAP` from the `key` type to the `value` type.
    Map {
        key: Box<DataType>,
        value: Box<DataType>,
    },
    /// `STRUCT` fields as (field name, field type) pairs.
    Struct(Vec<(String, DataType)>),
    /// Positional tuple of element types.
    Tuple(Vec<DataType>),

    // Special
    Null,
    /// NOTE(review): presumably a type name not covered by the other variants,
    /// kept verbatim — confirm with the parser.
    Unknown(String),
    Variant,
    Object,
    Xml,
    Inet,
    Cidr,
    Macaddr,
    /// `BIT` with an optional length.
    Bit(Option<u32>),
    Money,
    Serial,
    BigSerial,
    SmallSerial,
    Regclass,
    Regtype,
    Hstore,
    Geography,
    Geometry,
    Super,
}
1635
1636// ═══════════════════════════════════════════════════════════════════════
1637// Expression tree traversal helpers
1638// ═══════════════════════════════════════════════════════════════════════
1639
1640impl Expr {
1641    /// Recursively walk this expression tree, calling `visitor` on each node.
1642    /// If `visitor` returns `false`, children of that node are not visited.
1643    pub fn walk<F>(&self, visitor: &mut F)
1644    where
1645        F: FnMut(&Expr) -> bool,
1646    {
1647        if !visitor(self) {
1648            return;
1649        }
1650        match self {
1651            Expr::BinaryOp { left, right, .. } => {
1652                left.walk(visitor);
1653                right.walk(visitor);
1654            }
1655            Expr::UnaryOp { expr, .. } => expr.walk(visitor),
1656            Expr::Function { args, filter, .. } => {
1657                for arg in args {
1658                    arg.walk(visitor);
1659                }
1660                if let Some(f) = filter {
1661                    f.walk(visitor);
1662                }
1663            }
1664            Expr::Between {
1665                expr, low, high, ..
1666            } => {
1667                expr.walk(visitor);
1668                low.walk(visitor);
1669                high.walk(visitor);
1670            }
1671            Expr::InList { expr, list, .. } => {
1672                expr.walk(visitor);
1673                for item in list {
1674                    item.walk(visitor);
1675                }
1676            }
1677            Expr::InSubquery { expr, .. } => {
1678                expr.walk(visitor);
1679            }
1680            Expr::IsNull { expr, .. } => expr.walk(visitor),
1681            Expr::IsBool { expr, .. } => expr.walk(visitor),
1682            Expr::AnyOp { expr, right, .. } | Expr::AllOp { expr, right, .. } => {
1683                expr.walk(visitor);
1684                right.walk(visitor);
1685            }
1686            Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
1687                expr.walk(visitor);
1688                pattern.walk(visitor);
1689            }
1690            Expr::Case {
1691                operand,
1692                when_clauses,
1693                else_clause,
1694            } => {
1695                if let Some(op) = operand {
1696                    op.walk(visitor);
1697                }
1698                for (cond, result) in when_clauses {
1699                    cond.walk(visitor);
1700                    result.walk(visitor);
1701                }
1702                if let Some(el) = else_clause {
1703                    el.walk(visitor);
1704                }
1705            }
1706            Expr::Nested(inner) => inner.walk(visitor),
1707            Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr.walk(visitor),
1708            Expr::Extract { expr, .. } => expr.walk(visitor),
1709            Expr::Interval { value, .. } => value.walk(visitor),
1710            Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
1711                for item in items {
1712                    item.walk(visitor);
1713                }
1714            }
1715            Expr::If {
1716                condition,
1717                true_val,
1718                false_val,
1719            } => {
1720                condition.walk(visitor);
1721                true_val.walk(visitor);
1722                if let Some(fv) = false_val {
1723                    fv.walk(visitor);
1724                }
1725            }
1726            Expr::NullIf { expr, r#else } => {
1727                expr.walk(visitor);
1728                r#else.walk(visitor);
1729            }
1730            Expr::Collate { expr, .. } => expr.walk(visitor),
1731            Expr::Alias { expr, .. } => expr.walk(visitor),
1732            Expr::ArrayIndex { expr, index } => {
1733                expr.walk(visitor);
1734                index.walk(visitor);
1735            }
1736            Expr::JsonAccess { expr, path, .. } => {
1737                expr.walk(visitor);
1738                path.walk(visitor);
1739            }
1740            Expr::Lambda { body, .. } => body.walk(visitor),
1741            Expr::TypedFunction { func, filter, .. } => {
1742                func.walk_children(visitor);
1743                if let Some(f) = filter {
1744                    f.walk(visitor);
1745                }
1746            }
1747            // Leaf nodes
1748            Expr::Column { .. }
1749            | Expr::Number(_)
1750            | Expr::StringLiteral(_)
1751            | Expr::Boolean(_)
1752            | Expr::Null
1753            | Expr::Wildcard
1754            | Expr::Star
1755            | Expr::Parameter(_)
1756            | Expr::TypeExpr(_)
1757            | Expr::QualifiedWildcard { .. }
1758            | Expr::Default
1759            | Expr::Subquery(_)
1760            | Expr::Exists { .. } => {}
1761        }
1762    }
1763
1764    /// Find the first expression matching the predicate.
1765    #[must_use]
1766    pub fn find<F>(&self, predicate: &F) -> Option<&Expr>
1767    where
1768        F: Fn(&Expr) -> bool,
1769    {
1770        let mut result = None;
1771        self.walk(&mut |expr| {
1772            if result.is_some() {
1773                return false;
1774            }
1775            if predicate(expr) {
1776                result = Some(expr as *const Expr);
1777                false
1778            } else {
1779                true
1780            }
1781        });
1782        // SAFETY: the pointer is valid as long as self is alive
1783        result.map(|p| unsafe { &*p })
1784    }
1785
1786    /// Find all expressions matching the predicate.
1787    #[must_use]
1788    pub fn find_all<F>(&self, predicate: &F) -> Vec<&Expr>
1789    where
1790        F: Fn(&Expr) -> bool,
1791    {
1792        let mut results: Vec<*const Expr> = Vec::new();
1793        self.walk(&mut |expr| {
1794            if predicate(expr) {
1795                results.push(expr as *const Expr);
1796            }
1797            true
1798        });
1799        results.into_iter().map(|p| unsafe { &*p }).collect()
1800    }
1801
1802    /// Transform this expression tree by applying a function to each node.
1803    /// The function can return a new expression to replace the current one.
1804    #[must_use]
1805    pub fn transform<F>(self, func: &F) -> Expr
1806    where
1807        F: Fn(Expr) -> Expr,
1808    {
1809        let transformed = match self {
1810            Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
1811                left: Box::new(left.transform(func)),
1812                op,
1813                right: Box::new(right.transform(func)),
1814            },
1815            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
1816                op,
1817                expr: Box::new(expr.transform(func)),
1818            },
1819            Expr::Function {
1820                name,
1821                args,
1822                distinct,
1823                filter,
1824                over,
1825            } => Expr::Function {
1826                name,
1827                args: args.into_iter().map(|a| a.transform(func)).collect(),
1828                distinct,
1829                filter: filter.map(|f| Box::new(f.transform(func))),
1830                over,
1831            },
1832            Expr::Nested(inner) => Expr::Nested(Box::new(inner.transform(func))),
1833            Expr::Cast { expr, data_type } => Expr::Cast {
1834                expr: Box::new(expr.transform(func)),
1835                data_type,
1836            },
1837            Expr::Between {
1838                expr,
1839                low,
1840                high,
1841                negated,
1842            } => Expr::Between {
1843                expr: Box::new(expr.transform(func)),
1844                low: Box::new(low.transform(func)),
1845                high: Box::new(high.transform(func)),
1846                negated,
1847            },
1848            Expr::Case {
1849                operand,
1850                when_clauses,
1851                else_clause,
1852            } => Expr::Case {
1853                operand: operand.map(|o| Box::new(o.transform(func))),
1854                when_clauses: when_clauses
1855                    .into_iter()
1856                    .map(|(c, r)| (c.transform(func), r.transform(func)))
1857                    .collect(),
1858                else_clause: else_clause.map(|e| Box::new(e.transform(func))),
1859            },
1860            Expr::IsBool {
1861                expr,
1862                value,
1863                negated,
1864            } => Expr::IsBool {
1865                expr: Box::new(expr.transform(func)),
1866                value,
1867                negated,
1868            },
1869            Expr::AnyOp { expr, op, right } => Expr::AnyOp {
1870                expr: Box::new(expr.transform(func)),
1871                op,
1872                right: Box::new(right.transform(func)),
1873            },
1874            Expr::AllOp { expr, op, right } => Expr::AllOp {
1875                expr: Box::new(expr.transform(func)),
1876                op,
1877                right: Box::new(right.transform(func)),
1878            },
1879            Expr::TypedFunction {
1880                func: tf,
1881                filter,
1882                over,
1883            } => Expr::TypedFunction {
1884                func: tf.transform_children(func),
1885                filter: filter.map(|f| Box::new(f.transform(func))),
1886                over,
1887            },
1888            Expr::InList {
1889                expr,
1890                list,
1891                negated,
1892            } => Expr::InList {
1893                expr: Box::new(expr.transform(func)),
1894                list: list.into_iter().map(|e| e.transform(func)).collect(),
1895                negated,
1896            },
1897            Expr::InSubquery {
1898                expr,
1899                subquery,
1900                negated,
1901            } => Expr::InSubquery {
1902                expr: Box::new(expr.transform(func)),
1903                subquery, // Statement — not transformable via Expr func
1904                negated,
1905            },
1906            Expr::IsNull { expr, negated } => Expr::IsNull {
1907                expr: Box::new(expr.transform(func)),
1908                negated,
1909            },
1910            Expr::Like {
1911                expr,
1912                pattern,
1913                negated,
1914                escape,
1915            } => Expr::Like {
1916                expr: Box::new(expr.transform(func)),
1917                pattern: Box::new(pattern.transform(func)),
1918                negated,
1919                escape: escape.map(|e| Box::new(e.transform(func))),
1920            },
1921            Expr::ILike {
1922                expr,
1923                pattern,
1924                negated,
1925                escape,
1926            } => Expr::ILike {
1927                expr: Box::new(expr.transform(func)),
1928                pattern: Box::new(pattern.transform(func)),
1929                negated,
1930                escape: escape.map(|e| Box::new(e.transform(func))),
1931            },
1932            Expr::TryCast { expr, data_type } => Expr::TryCast {
1933                expr: Box::new(expr.transform(func)),
1934                data_type,
1935            },
1936            Expr::Extract { field, expr } => Expr::Extract {
1937                field,
1938                expr: Box::new(expr.transform(func)),
1939            },
1940            Expr::Interval { value, unit } => Expr::Interval {
1941                value: Box::new(value.transform(func)),
1942                unit,
1943            },
1944            Expr::ArrayLiteral(elems) => {
1945                Expr::ArrayLiteral(elems.into_iter().map(|e| e.transform(func)).collect())
1946            }
1947            Expr::Tuple(elems) => {
1948                Expr::Tuple(elems.into_iter().map(|e| e.transform(func)).collect())
1949            }
1950            Expr::Coalesce(elems) => {
1951                Expr::Coalesce(elems.into_iter().map(|e| e.transform(func)).collect())
1952            }
1953            Expr::If {
1954                condition,
1955                true_val,
1956                false_val,
1957            } => Expr::If {
1958                condition: Box::new(condition.transform(func)),
1959                true_val: Box::new(true_val.transform(func)),
1960                false_val: false_val.map(|f| Box::new(f.transform(func))),
1961            },
1962            Expr::NullIf { expr, r#else } => Expr::NullIf {
1963                expr: Box::new(expr.transform(func)),
1964                r#else: Box::new(r#else.transform(func)),
1965            },
1966            Expr::Collate { expr, collation } => Expr::Collate {
1967                expr: Box::new(expr.transform(func)),
1968                collation,
1969            },
1970            Expr::Alias { expr, name } => Expr::Alias {
1971                expr: Box::new(expr.transform(func)),
1972                name,
1973            },
1974            Expr::ArrayIndex { expr, index } => Expr::ArrayIndex {
1975                expr: Box::new(expr.transform(func)),
1976                index: Box::new(index.transform(func)),
1977            },
1978            Expr::JsonAccess {
1979                expr,
1980                path,
1981                as_text,
1982            } => Expr::JsonAccess {
1983                expr: Box::new(expr.transform(func)),
1984                path: Box::new(path.transform(func)),
1985                as_text,
1986            },
1987            Expr::Lambda { params, body } => Expr::Lambda {
1988                params,
1989                body: Box::new(body.transform(func)),
1990            },
1991            other => other,
1992        };
1993        func(transformed)
1994    }
1995
1996    /// Check whether this expression is a column reference.
1997    #[must_use]
1998    pub fn is_column(&self) -> bool {
1999        matches!(self, Expr::Column { .. })
2000    }
2001
2002    /// Check whether this expression is a literal value (number, string, bool, null).
2003    #[must_use]
2004    pub fn is_literal(&self) -> bool {
2005        matches!(
2006            self,
2007            Expr::Number(_) | Expr::StringLiteral(_) | Expr::Boolean(_) | Expr::Null
2008        )
2009    }
2010
2011    /// Get the SQL representation of this expression for display purposes.
2012    /// For full generation, use the Generator.
2013    #[must_use]
2014    pub fn sql(&self) -> String {
2015        use crate::generator::Generator;
2016        Generator::expr_to_sql(self)
2017    }
2018}
2019
2020/// Helper: collect all column references from an expression.
2021#[must_use]
2022pub fn find_columns(expr: &Expr) -> Vec<&Expr> {
2023    expr.find_all(&|e| matches!(e, Expr::Column { .. }))
2024}
2025
2026/// Helper: collect all table references from a statement.
2027#[must_use]
2028pub fn find_tables(statement: &Statement) -> Vec<&TableRef> {
2029    match statement {
2030        Statement::Select(sel) => {
2031            let mut tables = Vec::new();
2032            if let Some(from) = &sel.from {
2033                collect_table_refs_from_source(&from.source, &mut tables);
2034            }
2035            for join in &sel.joins {
2036                collect_table_refs_from_source(&join.table, &mut tables);
2037            }
2038            tables
2039        }
2040        Statement::Insert(ins) => vec![&ins.table],
2041        Statement::Update(upd) => vec![&upd.table],
2042        Statement::Delete(del) => vec![&del.table],
2043        Statement::CreateTable(ct) => vec![&ct.table],
2044        Statement::DropTable(dt) => vec![&dt.table],
2045        _ => vec![],
2046    }
2047}
2048
2049fn collect_table_refs_from_source<'a>(source: &'a TableSource, tables: &mut Vec<&'a TableRef>) {
2050    match source {
2051        TableSource::Table(table_ref) => tables.push(table_ref),
2052        TableSource::Subquery { .. } => {}
2053        TableSource::TableFunction { .. } => {}
2054        TableSource::Lateral { source } => collect_table_refs_from_source(source, tables),
2055        TableSource::Unnest { .. } => {}
2056    }
2057}