Skip to main content

sqlglot_rust/optimizer/
annotate_types.rs

1//! Type annotation pass for SQL expressions.
2//!
3//! Infers and propagates SQL data types across AST nodes using schema metadata.
4//! Inspired by Python sqlglot's `annotate_types` optimizer pass.
5//!
6//! # Overview
7//!
8//! The pass walks the AST bottom-up, resolving types for:
9//! - **Literals**: `42` → `Int`, `'hello'` → `Varchar`, `TRUE` → `Boolean`
10//! - **Column references**: looked up from the provided [`Schema`]
11//! - **Binary operators**: result type from operand coercion (e.g. `INT + FLOAT → FLOAT`)
12//! - **CAST / TRY_CAST**: the target data type
13//! - **Functions**: return type based on function signature and argument types
14//! - **CASE**: common type across all THEN / ELSE branches
15//! - **Aggregates**: `COUNT → BigInt`, `SUM` depends on input, etc.
16//! - **Subqueries**: type of the single output column
17//!
18//! # Example
19//!
20//! ```rust
21//! use sqlglot_rust::optimizer::annotate_types::annotate_types;
22//! use sqlglot_rust::schema::{MappingSchema, Schema};
23//! use sqlglot_rust::ast::DataType;
24//! use sqlglot_rust::{parse, Dialect};
25//!
26//! let mut schema = MappingSchema::new(Dialect::Ansi);
27//! schema.add_table(&["t"], vec![
28//!     ("id".to_string(), DataType::Int),
29//!     ("name".to_string(), DataType::Varchar(Some(255))),
30//! ]).unwrap();
31//!
32//! let stmt = parse("SELECT id, name FROM t WHERE id > 1", Dialect::Ansi).unwrap();
33//! let annotations = annotate_types(&stmt, &schema);
34//! // annotations now contains inferred types for every expression node
35//! ```
36
37use std::collections::HashMap;
38
39use crate::ast::*;
40use crate::schema::Schema;
41
42// ═══════════════════════════════════════════════════════════════════════
43// TypeAnnotations — the result of type inference
44// ═══════════════════════════════════════════════════════════════════════
45
/// Stores inferred [`DataType`] annotations for expression nodes in an AST.
///
/// Each annotation is keyed by the address of the [`Expr`] node it describes
/// (`*const Expr`), i.e. by node identity rather than by structural equality.
/// The addresses serve purely as map keys: the accessors on this type insert
/// and look them up without ever dereferencing them.
///
/// Because the keys are addresses, the annotations are meaningful only while
/// the annotated AST stays alive at the same location. Moving or dropping the
/// AST leaves stale keys behind, and a clone of the AST has different node
/// addresses, so lookups against the clone find nothing.
/// Intended for single-pass analysis over a borrowed AST.
pub struct TypeAnnotations {
    /// Inferred type per expression node, keyed by the node's address.
    types: HashMap<*const Expr, DataType>,
}
54
// SAFETY: raw pointers are not `Send`/`Sync` by default. The `*const Expr`
// keys are used only as opaque identity tokens for hashing and comparison
// (see `set` / `get_type`); they are never dereferenced through this type,
// so sending or sharing the map across threads does not touch the AST itself.
// NOTE(review): these impls also bypass the auto-trait check for the stored
// `DataType` values; this assumes `DataType` is itself `Send + Sync` — confirm.
unsafe impl Send for TypeAnnotations {}
unsafe impl Sync for TypeAnnotations {}
59
60impl TypeAnnotations {
61    fn new() -> Self {
62        Self {
63            types: HashMap::new(),
64        }
65    }
66
67    fn set(&mut self, expr: &Expr, dt: DataType) {
68        self.types.insert(expr as *const Expr, dt);
69    }
70
71    /// Retrieve the inferred type of an expression, if annotated.
72    #[must_use]
73    pub fn get_type(&self, expr: &Expr) -> Option<&DataType> {
74        self.types.get(&(expr as *const Expr))
75    }
76
77    /// Number of annotated nodes.
78    #[must_use]
79    pub fn len(&self) -> usize {
80        self.types.len()
81    }
82
83    /// Returns `true` if no annotations were recorded.
84    #[must_use]
85    pub fn is_empty(&self) -> bool {
86        self.types.is_empty()
87    }
88}
89
90// ═══════════════════════════════════════════════════════════════════════
91// Public entry point
92// ═══════════════════════════════════════════════════════════════════════
93
94/// Annotate all expression nodes in a statement with inferred SQL types.
95///
96/// Walks the AST bottom-up, resolving types from literals, schema column
97/// lookups, operator/function signatures, and type coercion rules.
98///
99/// The returned [`TypeAnnotations`] is valid only while the borrowed `stmt`
100/// is alive and unmodified.
101#[must_use]
102pub fn annotate_types<S: Schema>(stmt: &Statement, schema: &S) -> TypeAnnotations {
103    let mut ann = TypeAnnotations::new();
104    let mut ctx = AnnotationContext::new(schema);
105    annotate_statement(stmt, &mut ctx, &mut ann);
106    ann
107}
108
109// ═══════════════════════════════════════════════════════════════════════
110// Internal context
111// ═══════════════════════════════════════════════════════════════════════
112
/// Carries schema reference and table alias mappings through the walk.
///
/// The context is filled by `register_table` while FROM / JOIN sources are
/// processed and is then consulted by `resolve_column_type` whenever a
/// column reference needs a type.
struct AnnotationContext<'s, S: Schema> {
    /// Schema consulted for column types (and for UDF return types).
    schema: &'s S,
    /// Maps table alias or name → table path for column type lookups.
    /// A table registered without an alias is keyed by its own name.
    table_aliases: HashMap<String, Vec<String>>,
}
119
120impl<'s, S: Schema> AnnotationContext<'s, S> {
121    fn new(schema: &'s S) -> Self {
122        Self {
123            schema,
124            table_aliases: HashMap::new(),
125        }
126    }
127
128    /// Register a table (by ref) so that columns can be looked up by alias.
129    fn register_table(&mut self, table_ref: &TableRef) {
130        let path = vec![table_ref.name.clone()];
131        let alias = table_ref
132            .alias
133            .as_deref()
134            .unwrap_or(&table_ref.name)
135            .to_string();
136        self.table_aliases.insert(alias, path);
137    }
138
139    /// Look up the type of a column, resolving through table aliases.
140    fn resolve_column_type(&self, table: Option<&str>, column: &str) -> Option<DataType> {
141        if let Some(tbl) = table {
142            // Qualified column — look up via alias map
143            if let Some(path) = self.table_aliases.get(tbl) {
144                let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
145                return self.schema.get_column_type(&path_refs, column).ok();
146            }
147            // Try the table name directly
148            return self.schema.get_column_type(&[tbl], column).ok();
149        }
150        // Unqualified — search all registered tables
151        for path in self.table_aliases.values() {
152            let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
153            if let Ok(dt) = self.schema.get_column_type(&path_refs, column) {
154                return Some(dt);
155            }
156        }
157        None
158    }
159}
160
161// ═══════════════════════════════════════════════════════════════════════
162// Statement-level annotation
163// ═══════════════════════════════════════════════════════════════════════
164
165fn annotate_statement<S: Schema>(
166    stmt: &Statement,
167    ctx: &mut AnnotationContext<S>,
168    ann: &mut TypeAnnotations,
169) {
170    match stmt {
171        Statement::Select(sel) => annotate_select(sel, ctx, ann),
172        Statement::SetOperation(set_op) => {
173            annotate_statement(&set_op.left, ctx, ann);
174            annotate_statement(&set_op.right, ctx, ann);
175        }
176        Statement::Insert(ins) => {
177            if let InsertSource::Query(q) = &ins.source {
178                annotate_statement(q, ctx, ann);
179            }
180            for row in match &ins.source {
181                InsertSource::Values(rows) => rows.as_slice(),
182                _ => &[],
183            } {
184                for expr in row {
185                    annotate_expr(expr, ctx, ann);
186                }
187            }
188        }
189        Statement::Update(upd) => {
190            for (_, expr) in &upd.assignments {
191                annotate_expr(expr, ctx, ann);
192            }
193            if let Some(wh) = &upd.where_clause {
194                annotate_expr(wh, ctx, ann);
195            }
196        }
197        Statement::Delete(del) => {
198            if let Some(wh) = &del.where_clause {
199                annotate_expr(wh, ctx, ann);
200            }
201        }
202        Statement::Expression(expr) => {
203            annotate_expr(expr, ctx, ann);
204        }
205        Statement::Explain(expl) => {
206            annotate_statement(&expl.statement, ctx, ann);
207        }
208        // DDL / transaction / other statements — no expression types to annotate
209        _ => {}
210    }
211}
212
213fn annotate_select<S: Schema>(
214    sel: &SelectStatement,
215    ctx: &mut AnnotationContext<S>,
216    ann: &mut TypeAnnotations,
217) {
218    // 1. Register CTEs
219    for cte in &sel.ctes {
220        annotate_statement(&cte.query, ctx, ann);
221    }
222
223    // 2. Register FROM sources
224    if let Some(from) = &sel.from {
225        register_table_source(&from.source, ctx);
226    }
227    for join in &sel.joins {
228        register_table_source(&join.table, ctx);
229    }
230
231    // 3. Annotate WHERE clause
232    if let Some(wh) = &sel.where_clause {
233        annotate_expr(wh, ctx, ann);
234    }
235
236    // 4. Annotate SELECT columns
237    for item in &sel.columns {
238        if let SelectItem::Expr { expr, .. } = item {
239            annotate_expr(expr, ctx, ann);
240        }
241    }
242
243    // 5. Annotate GROUP BY
244    for expr in &sel.group_by {
245        annotate_expr(expr, ctx, ann);
246    }
247
248    // 6. Annotate HAVING
249    if let Some(having) = &sel.having {
250        annotate_expr(having, ctx, ann);
251    }
252
253    // 7. Annotate ORDER BY
254    for ob in &sel.order_by {
255        annotate_expr(&ob.expr, ctx, ann);
256    }
257
258    // 8. Annotate LIMIT / OFFSET
259    if let Some(limit) = &sel.limit {
260        annotate_expr(limit, ctx, ann);
261    }
262    if let Some(offset) = &sel.offset {
263        annotate_expr(offset, ctx, ann);
264    }
265    if let Some(fetch) = &sel.fetch_first {
266        annotate_expr(fetch, ctx, ann);
267    }
268
269    // 9. Annotate QUALIFY
270    if let Some(qualify) = &sel.qualify {
271        annotate_expr(qualify, ctx, ann);
272    }
273
274    // 10. Annotate JOIN ON conditions
275    for join in &sel.joins {
276        if let Some(on) = &join.on {
277            annotate_expr(on, ctx, ann);
278        }
279    }
280}
281
282fn register_table_source<S: Schema>(source: &TableSource, ctx: &mut AnnotationContext<S>) {
283    match source {
284        TableSource::Table(tref) => ctx.register_table(tref),
285        TableSource::Subquery { alias, .. } => {
286            // Subqueries as sources don't have schema entries to register.
287            // Their output column types would come from recursive annotation.
288            let _ = alias;
289        }
290        TableSource::TableFunction { alias, .. } => {
291            let _ = alias;
292        }
293        TableSource::Lateral { source } => register_table_source(source, ctx),
294        TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
295            register_table_source(source, ctx);
296        }
297        TableSource::Unnest { .. } => {}
298    }
299}
300
301// ═══════════════════════════════════════════════════════════════════════
302// Expression-level annotation (bottom-up)
303// ═══════════════════════════════════════════════════════════════════════
304
305fn annotate_expr<S: Schema>(expr: &Expr, ctx: &AnnotationContext<S>, ann: &mut TypeAnnotations) {
306    // First annotate children, then determine this node's type.
307    annotate_children(expr, ctx, ann);
308
309    let dt = infer_type(expr, ctx, ann);
310    if let Some(t) = dt {
311        ann.set(expr, t);
312    }
313}
314
315/// Recursively annotate child expressions before the parent.
316fn annotate_children<S: Schema>(
317    expr: &Expr,
318    ctx: &AnnotationContext<S>,
319    ann: &mut TypeAnnotations,
320) {
321    match expr {
322        Expr::BinaryOp { left, right, .. } => {
323            annotate_expr(left, ctx, ann);
324            annotate_expr(right, ctx, ann);
325        }
326        Expr::UnaryOp { expr: inner, .. } => annotate_expr(inner, ctx, ann),
327        Expr::Function { args, filter, .. } => {
328            for arg in args {
329                annotate_expr(arg, ctx, ann);
330            }
331            if let Some(f) = filter {
332                annotate_expr(f, ctx, ann);
333            }
334        }
335        Expr::Between {
336            expr: e, low, high, ..
337        } => {
338            annotate_expr(e, ctx, ann);
339            annotate_expr(low, ctx, ann);
340            annotate_expr(high, ctx, ann);
341        }
342        Expr::InList { expr: e, list, .. } => {
343            annotate_expr(e, ctx, ann);
344            for item in list {
345                annotate_expr(item, ctx, ann);
346            }
347        }
348        Expr::InSubquery {
349            expr: e, subquery, ..
350        } => {
351            annotate_expr(e, ctx, ann);
352            let mut sub_ctx = AnnotationContext::new(ctx.schema);
353            annotate_statement(subquery, &mut sub_ctx, ann);
354        }
355        Expr::IsNull { expr: e, .. } | Expr::IsBool { expr: e, .. } => {
356            annotate_expr(e, ctx, ann);
357        }
358        Expr::Like {
359            expr: e,
360            pattern,
361            escape,
362            ..
363        }
364        | Expr::ILike {
365            expr: e,
366            pattern,
367            escape,
368            ..
369        } => {
370            annotate_expr(e, ctx, ann);
371            annotate_expr(pattern, ctx, ann);
372            if let Some(esc) = escape {
373                annotate_expr(esc, ctx, ann);
374            }
375        }
376        Expr::Case {
377            operand,
378            when_clauses,
379            else_clause,
380        } => {
381            if let Some(op) = operand {
382                annotate_expr(op, ctx, ann);
383            }
384            for (cond, result) in when_clauses {
385                annotate_expr(cond, ctx, ann);
386                annotate_expr(result, ctx, ann);
387            }
388            if let Some(el) = else_clause {
389                annotate_expr(el, ctx, ann);
390            }
391        }
392        Expr::Nested(inner) => annotate_expr(inner, ctx, ann),
393        Expr::Cast { expr: e, .. } | Expr::TryCast { expr: e, .. } => {
394            annotate_expr(e, ctx, ann);
395        }
396        Expr::Extract { expr: e, .. } => annotate_expr(e, ctx, ann),
397        Expr::Interval { value, .. } => annotate_expr(value, ctx, ann),
398        Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
399            for item in items {
400                annotate_expr(item, ctx, ann);
401            }
402        }
403        Expr::If {
404            condition,
405            true_val,
406            false_val,
407        } => {
408            annotate_expr(condition, ctx, ann);
409            annotate_expr(true_val, ctx, ann);
410            if let Some(fv) = false_val {
411                annotate_expr(fv, ctx, ann);
412            }
413        }
414        Expr::NullIf { expr: e, r#else } => {
415            annotate_expr(e, ctx, ann);
416            annotate_expr(r#else, ctx, ann);
417        }
418        Expr::Collate { expr: e, .. } => annotate_expr(e, ctx, ann),
419        Expr::Alias { expr: e, .. } => annotate_expr(e, ctx, ann),
420        Expr::ArrayIndex { expr: e, index } => {
421            annotate_expr(e, ctx, ann);
422            annotate_expr(index, ctx, ann);
423        }
424        Expr::JsonAccess { expr: e, path, .. } => {
425            annotate_expr(e, ctx, ann);
426            annotate_expr(path, ctx, ann);
427        }
428        Expr::Lambda { body, .. } => annotate_expr(body, ctx, ann),
429        Expr::AnyOp { expr: e, right, .. } | Expr::AllOp { expr: e, right, .. } => {
430            annotate_expr(e, ctx, ann);
431            annotate_expr(right, ctx, ann);
432        }
433        Expr::Subquery(sub) => {
434            let mut sub_ctx = AnnotationContext::new(ctx.schema);
435            annotate_statement(sub, &mut sub_ctx, ann);
436        }
437        Expr::Exists { subquery, .. } => {
438            let mut sub_ctx = AnnotationContext::new(ctx.schema);
439            annotate_statement(subquery, &mut sub_ctx, ann);
440        }
441        Expr::TypedFunction { func, filter, .. } => {
442            annotate_typed_function_children(func, ctx, ann);
443            if let Some(f) = filter {
444                annotate_expr(f, ctx, ann);
445            }
446        }
447        Expr::Cube { exprs } | Expr::Rollup { exprs } | Expr::GroupingSets { sets: exprs } => {
448            for item in exprs {
449                annotate_expr(item, ctx, ann);
450            }
451        }
452        // Leaf nodes — no children to annotate
453        Expr::Column { .. }
454        | Expr::Number(_)
455        | Expr::StringLiteral(_)
456        | Expr::Boolean(_)
457        | Expr::Null
458        | Expr::Wildcard
459        | Expr::Star
460        | Expr::Parameter(_)
461        | Expr::TypeExpr(_)
462        | Expr::QualifiedWildcard { .. }
463        | Expr::Default
464        | Expr::Commented { .. } => {}
465    }
466}
467
468/// Annotate children of a TypedFunction.
469fn annotate_typed_function_children<S: Schema>(
470    func: &TypedFunction,
471    ctx: &AnnotationContext<S>,
472    ann: &mut TypeAnnotations,
473) {
474    // Use walk_children to visit all child expressions and annotate each
475    func.walk_children(&mut |child| {
476        annotate_expr(child, ctx, ann);
477        true
478    });
479}
480
481// ═══════════════════════════════════════════════════════════════════════
482// Type inference for a single expression node
483// ═══════════════════════════════════════════════════════════════════════
484
485fn infer_type<S: Schema>(
486    expr: &Expr,
487    ctx: &AnnotationContext<S>,
488    ann: &TypeAnnotations,
489) -> Option<DataType> {
490    match expr {
491        // ── Literals ───────────────────────────────────────────────────
492        Expr::Number(s) => Some(infer_number_type(s)),
493        Expr::StringLiteral(_) => Some(DataType::Varchar(None)),
494        Expr::Boolean(_) => Some(DataType::Boolean),
495        Expr::Null => Some(DataType::Null),
496
497        // ── Column reference ──────────────────────────────────────────
498        Expr::Column { table, name, .. } => ctx.resolve_column_type(table.as_deref(), name),
499
500        // ── Binary operators ──────────────────────────────────────────
501        Expr::BinaryOp { left, op, right } => {
502            infer_binary_op_type(op, ann.get_type(left), ann.get_type(right))
503        }
504
505        // ── Unary operators ───────────────────────────────────────────
506        Expr::UnaryOp { op, expr: inner } => match op {
507            UnaryOperator::Not => Some(DataType::Boolean),
508            UnaryOperator::Minus | UnaryOperator::Plus => ann.get_type(inner).cloned(),
509            UnaryOperator::BitwiseNot => ann.get_type(inner).cloned(),
510        },
511
512        // ── CAST / TRY_CAST ──────────────────────────────────────────
513        Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => Some(data_type.clone()),
514
515        // ── CASE expression ──────────────────────────────────────────
516        Expr::Case {
517            when_clauses,
518            else_clause,
519            ..
520        } => {
521            let mut result_types: Vec<&DataType> = Vec::new();
522            for (_, result) in when_clauses {
523                if let Some(t) = ann.get_type(result) {
524                    result_types.push(t);
525                }
526            }
527            if let Some(el) = else_clause {
528                if let Some(t) = ann.get_type(el.as_ref()) {
529                    result_types.push(t);
530                }
531            }
532            common_type(&result_types)
533        }
534
535        // ── IF expression ────────────────────────────────────────────
536        Expr::If {
537            true_val,
538            false_val,
539            ..
540        } => {
541            let mut types = Vec::new();
542            if let Some(t) = ann.get_type(true_val) {
543                types.push(t);
544            }
545            if let Some(fv) = false_val {
546                if let Some(t) = ann.get_type(fv.as_ref()) {
547                    types.push(t);
548                }
549            }
550            common_type(&types)
551        }
552
553        // ── COALESCE ─────────────────────────────────────────────────
554        Expr::Coalesce(items) => {
555            let types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
556            common_type(&types)
557        }
558
559        // ── NULLIF ───────────────────────────────────────────────────
560        Expr::NullIf { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
561
562        // ── Generic function ─────────────────────────────────────────
563        Expr::Function { name, args, .. } => infer_generic_function_type(name, args, ctx, ann),
564
565        // ── Typed functions ──────────────────────────────────────────
566        Expr::TypedFunction { func, .. } => infer_typed_function_type(func, ann),
567
568        // ── Subquery (scalar) ────────────────────────────────────────
569        Expr::Subquery(sub) => infer_subquery_type(sub, ann),
570
571        // ── EXISTS → Boolean ─────────────────────────────────────────
572        Expr::Exists { .. } => Some(DataType::Boolean),
573
574        // ── Boolean predicates ───────────────────────────────────────
575        Expr::Between { .. }
576        | Expr::InList { .. }
577        | Expr::InSubquery { .. }
578        | Expr::IsNull { .. }
579        | Expr::IsBool { .. }
580        | Expr::Like { .. }
581        | Expr::ILike { .. }
582        | Expr::AnyOp { .. }
583        | Expr::AllOp { .. } => Some(DataType::Boolean),
584
585        // ── EXTRACT → numeric ────────────────────────────────────────
586        Expr::Extract { .. } => Some(DataType::Int),
587
588        // ── INTERVAL → Interval ──────────────────────────────────────
589        Expr::Interval { .. } => Some(DataType::Interval),
590
591        // ── Array literal ────────────────────────────────────────────
592        Expr::ArrayLiteral(items) => {
593            let elem_types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
594            let elem = common_type(&elem_types);
595            Some(DataType::Array(elem.map(Box::new)))
596        }
597
598        // ── Tuple ────────────────────────────────────────────────────
599        Expr::Tuple(items) => {
600            let types: Vec<DataType> = items
601                .iter()
602                .map(|e| ann.get_type(e).cloned().unwrap_or(DataType::Null))
603                .collect();
604            Some(DataType::Tuple(types))
605        }
606
607        // ── Array index → element type ───────────────────────────────
608        Expr::ArrayIndex { expr: e, .. } => match ann.get_type(e.as_ref()) {
609            Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
610            _ => None,
611        },
612
613        // ── JSON access ──────────────────────────────────────────────
614        Expr::JsonAccess { as_text, .. } => {
615            if *as_text {
616                Some(DataType::Text)
617            } else {
618                Some(DataType::Json)
619            }
620        }
621
622        // ── Nested / Alias — pass through ────────────────────────────
623        Expr::Nested(inner) => ann.get_type(inner.as_ref()).cloned(),
624        Expr::Alias { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
625
626        // ── Collate → Varchar ────────────────────────────────────────
627        Expr::Collate { .. } => Some(DataType::Varchar(None)),
628
629        // ── TypeExpr ─────────────────────────────────────────────────
630        Expr::TypeExpr(dt) => Some(dt.clone()),
631
632        // ── Others — no type ─────────────────────────────────────────
633        Expr::Wildcard
634        | Expr::Star
635        | Expr::QualifiedWildcard { .. }
636        | Expr::Parameter(_)
637        | Expr::Lambda { .. }
638        | Expr::Default
639        | Expr::Cube { .. }
640        | Expr::Rollup { .. }
641        | Expr::GroupingSets { .. }
642        | Expr::Commented { .. } => None,
643    }
644}
645
646// ═══════════════════════════════════════════════════════════════════════
647// Number type inference
648// ═══════════════════════════════════════════════════════════════════════
649
650fn infer_number_type(s: &str) -> DataType {
651    if s.contains('.') || s.contains('e') || s.contains('E') {
652        DataType::Double
653    } else if let Ok(v) = s.parse::<i64>() {
654        if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
655            DataType::Int
656        } else {
657            DataType::BigInt
658        }
659    } else {
660        // Very large numbers or special formats
661        DataType::BigInt
662    }
663}
664
665// ═══════════════════════════════════════════════════════════════════════
666// Binary operator type inference
667// ═══════════════════════════════════════════════════════════════════════
668
669fn infer_binary_op_type(
670    op: &BinaryOperator,
671    left: Option<&DataType>,
672    right: Option<&DataType>,
673) -> Option<DataType> {
674    use BinaryOperator::*;
675    match op {
676        // Comparison operators → Boolean
677        Eq | Neq | Lt | Gt | LtEq | GtEq => Some(DataType::Boolean),
678
679        // Logical operators → Boolean
680        And | Or | Xor => Some(DataType::Boolean),
681
682        // String concatenation → Varchar
683        Concat => Some(DataType::Varchar(None)),
684
685        // Arithmetic → coerce operand types
686        Plus | Minus | Multiply | Divide | Modulo => match (left, right) {
687            (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
688            (Some(l), None) => Some(l.clone()),
689            (None, Some(r)) => Some(r.clone()),
690            (None, None) => None,
691        },
692
693        // Bitwise → integer type
694        BitwiseAnd | BitwiseOr | BitwiseXor | ShiftLeft | ShiftRight => match (left, right) {
695            (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
696            (Some(l), None) => Some(l.clone()),
697            (None, Some(r)) => Some(r.clone()),
698            (None, None) => Some(DataType::Int),
699        },
700
701        // JSON operators
702        Arrow => Some(DataType::Json),
703        DoubleArrow => Some(DataType::Text),
704    }
705}
706
707// ═══════════════════════════════════════════════════════════════════════
708// Generic (untyped) function return type inference
709// ═══════════════════════════════════════════════════════════════════════
710
711fn infer_generic_function_type<S: Schema>(
712    name: &str,
713    args: &[Expr],
714    ctx: &AnnotationContext<S>,
715    ann: &TypeAnnotations,
716) -> Option<DataType> {
717    let upper = name.to_uppercase();
718    match upper.as_str() {
719        // Aggregate functions
720        "COUNT" | "COUNT_BIG" => Some(DataType::BigInt),
721        "SUM" => args
722            .first()
723            .and_then(|a| ann.get_type(a))
724            .map(|t| coerce_sum_type(t)),
725        "AVG" => Some(DataType::Double),
726        "MIN" | "MAX" => args.first().and_then(|a| ann.get_type(a)).cloned(),
727        "VARIANCE" | "VAR_SAMP" | "VAR_POP" | "STDDEV" | "STDDEV_SAMP" | "STDDEV_POP" => {
728            Some(DataType::Double)
729        }
730        "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(DataType::BigInt),
731
732        // String functions
733        "CONCAT" | "UPPER" | "LOWER" | "TRIM" | "LTRIM" | "RTRIM" | "LPAD" | "RPAD" | "REPLACE"
734        | "REVERSE" | "SUBSTRING" | "SUBSTR" | "LEFT" | "RIGHT" | "INITCAP" | "REPEAT"
735        | "TRANSLATE" | "FORMAT" | "CONCAT_WS" | "SPACE" | "REPLICATE" => {
736            Some(DataType::Varchar(None))
737        }
738        "LENGTH" | "LEN" | "CHAR_LENGTH" | "CHARACTER_LENGTH" | "OCTET_LENGTH" | "BIT_LENGTH" => {
739            Some(DataType::Int)
740        }
741        "POSITION" | "STRPOS" | "LOCATE" | "INSTR" | "CHARINDEX" => Some(DataType::Int),
742        "ASCII" => Some(DataType::Int),
743        "CHR" | "CHAR" => Some(DataType::Varchar(Some(1))),
744
745        // Math functions
746        "ABS" | "CEIL" | "CEILING" | "FLOOR" => args.first().and_then(|a| ann.get_type(a)).cloned(),
747        "ROUND" | "TRUNCATE" | "TRUNC" => args.first().and_then(|a| ann.get_type(a)).cloned(),
748        "SQRT" | "LN" | "LOG" | "LOG2" | "LOG10" | "EXP" | "POWER" | "POW" | "ACOS" | "ASIN"
749        | "ATAN" | "ATAN2" | "COS" | "SIN" | "TAN" | "COT" | "DEGREES" | "RADIANS" | "PI"
750        | "SIGN" => Some(DataType::Double),
751        "MOD" => {
752            match (
753                args.first().and_then(|a| ann.get_type(a)),
754                args.get(1).and_then(|a| ann.get_type(a)),
755            ) {
756                (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
757                (Some(l), _) => Some(l.clone()),
758                (_, Some(r)) => Some(r.clone()),
759                _ => Some(DataType::Int),
760            }
761        }
762        "GREATEST" | "LEAST" => {
763            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
764            common_type(&types)
765        }
766        "RANDOM" | "RAND" => Some(DataType::Double),
767
768        // Date/Time functions
769        "CURRENT_DATE" | "CURDATE" | "TODAY" => Some(DataType::Date),
770        "CURRENT_TIMESTAMP" | "NOW" | "GETDATE" | "SYSDATE" | "SYSTIMESTAMP" | "LOCALTIMESTAMP" => {
771            Some(DataType::Timestamp {
772                precision: None,
773                with_tz: false,
774            })
775        }
776        "CURRENT_TIME" | "CURTIME" => Some(DataType::Time { precision: None }),
777        "DATE" | "TO_DATE" | "DATE_TRUNC" | "DATE_ADD" | "DATE_SUB" | "DATEADD" | "DATESUB"
778        | "ADDDATE" | "SUBDATE" => Some(DataType::Date),
779        "TIMESTAMP" | "TO_TIMESTAMP" => Some(DataType::Timestamp {
780            precision: None,
781            with_tz: false,
782        }),
783        "YEAR" | "MONTH" | "DAY" | "DAYOFWEEK" | "DAYOFYEAR" | "HOUR" | "MINUTE" | "SECOND"
784        | "QUARTER" | "WEEK" | "EXTRACT" | "DATEDIFF" | "TIMESTAMPDIFF" | "MONTHS_BETWEEN" => {
785            Some(DataType::Int)
786        }
787
788        // Type conversion
789        "CAST" | "TRY_CAST" | "SAFE_CAST" | "CONVERT" => None, // handled by Expr::Cast
790
791        // Boolean functions
792        "COALESCE" => {
793            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
794            common_type(&types)
795        }
796        "NULLIF" => args.first().and_then(|a| ann.get_type(a)).cloned(),
797        "IF" | "IIF" => {
798            // IF(cond, true_val, false_val) — type from true_val
799            args.get(1).and_then(|a| ann.get_type(a)).cloned()
800        }
801        "IFNULL" | "NVL" | "ISNULL" => {
802            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
803            common_type(&types)
804        }
805
806        // JSON functions
807        "JSON_EXTRACT" | "JSON_QUERY" | "GET_JSON_OBJECT" => Some(DataType::Json),
808        "JSON_EXTRACT_SCALAR" | "JSON_VALUE" | "JSON_EXTRACT_PATH_TEXT" => {
809            Some(DataType::Varchar(None))
810        }
811        "TO_JSON" | "JSON_OBJECT" | "JSON_ARRAY" | "JSON_BUILD_OBJECT" | "JSON_BUILD_ARRAY" => {
812            Some(DataType::Json)
813        }
814        "PARSE_JSON" | "JSON_PARSE" | "JSON" => Some(DataType::Json),
815
816        // Array functions
817        "ARRAY_AGG" | "COLLECT_LIST" | "COLLECT_SET" => {
818            let elem = args.first().and_then(|a| ann.get_type(a)).cloned();
819            Some(DataType::Array(elem.map(Box::new)))
820        }
821        "ARRAY_LENGTH" | "ARRAY_SIZE" | "CARDINALITY" => Some(DataType::Int),
822        "ARRAY" | "ARRAY_CONSTRUCT" => {
823            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
824            let elem = common_type(&types);
825            Some(DataType::Array(elem.map(Box::new)))
826        }
827        "ARRAY_CONTAINS" | "ARRAY_POSITION" => Some(DataType::Boolean),
828
829        // Window ranking
830        "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "NTILE" | "CUME_DIST" | "PERCENT_RANK" => {
831            Some(DataType::BigInt)
832        }
833
834        // Hash / crypto
835        "MD5" | "SHA1" | "SHA" | "SHA2" | "SHA256" | "SHA512" => Some(DataType::Varchar(None)),
836        "HEX" | "TO_HEX" => Some(DataType::Varchar(None)),
837        "UNHEX" | "FROM_HEX" => Some(DataType::Varbinary(None)),
838        "CRC32" | "HASH" => Some(DataType::BigInt),
839
840        // Type checking
841        "TYPEOF" | "TYPE_OF" => Some(DataType::Varchar(None)),
842
843        // UDFs — check schema
844        _ => ctx.schema.get_udf_type(&upper).cloned(),
845    }
846}
847
848// ═══════════════════════════════════════════════════════════════════════
849// TypedFunction return type inference
850// ═══════════════════════════════════════════════════════════════════════
851
852fn infer_typed_function_type(func: &TypedFunction, ann: &TypeAnnotations) -> Option<DataType> {
853    match func {
854        // ── Date/Time → Date or Timestamp ────────────────────────────
855        TypedFunction::DateAdd { .. }
856        | TypedFunction::DateSub { .. }
857        | TypedFunction::DateTrunc { .. }
858        | TypedFunction::TsOrDsToDate { .. } => Some(DataType::Date),
859        TypedFunction::DateDiff { .. } => Some(DataType::Int),
860        TypedFunction::CurrentDate => Some(DataType::Date),
861        TypedFunction::CurrentTimestamp => Some(DataType::Timestamp {
862            precision: None,
863            with_tz: false,
864        }),
865        TypedFunction::StrToTime { .. } => Some(DataType::Timestamp {
866            precision: None,
867            with_tz: false,
868        }),
869        TypedFunction::TimeToStr { .. } => Some(DataType::Varchar(None)),
870        TypedFunction::Year { .. } | TypedFunction::Month { .. } | TypedFunction::Day { .. } => {
871            Some(DataType::Int)
872        }
873
874        // ── String → Varchar ─────────────────────────────────────────
875        TypedFunction::Trim { .. }
876        | TypedFunction::Substring { .. }
877        | TypedFunction::Upper { .. }
878        | TypedFunction::Lower { .. }
879        | TypedFunction::Initcap { .. }
880        | TypedFunction::Replace { .. }
881        | TypedFunction::Reverse { .. }
882        | TypedFunction::Left { .. }
883        | TypedFunction::Right { .. }
884        | TypedFunction::Lpad { .. }
885        | TypedFunction::Rpad { .. }
886        | TypedFunction::ConcatWs { .. } => Some(DataType::Varchar(None)),
887        TypedFunction::Length { .. } => Some(DataType::Int),
888        TypedFunction::RegexpLike { .. } => Some(DataType::Boolean),
889        TypedFunction::RegexpExtract { .. } => Some(DataType::Varchar(None)),
890        TypedFunction::RegexpReplace { .. } => Some(DataType::Varchar(None)),
891        TypedFunction::Split { .. } => {
892            Some(DataType::Array(Some(Box::new(DataType::Varchar(None)))))
893        }
894
895        // ── Aggregates ───────────────────────────────────────────────
896        TypedFunction::Count { .. } => Some(DataType::BigInt),
897        TypedFunction::Sum { expr, .. } => ann.get_type(expr.as_ref()).map(|t| coerce_sum_type(t)),
898        TypedFunction::Avg { .. } => Some(DataType::Double),
899        TypedFunction::Min { expr } | TypedFunction::Max { expr } => {
900            ann.get_type(expr.as_ref()).cloned()
901        }
902        TypedFunction::ArrayAgg { expr, .. } => {
903            let elem = ann.get_type(expr.as_ref()).cloned();
904            Some(DataType::Array(elem.map(Box::new)))
905        }
906        TypedFunction::ApproxDistinct { .. } => Some(DataType::BigInt),
907        TypedFunction::Variance { .. } | TypedFunction::Stddev { .. } => Some(DataType::Double),
908
909        // ── Array ────────────────────────────────────────────────────
910        TypedFunction::ArrayConcat { arrays } => {
911            // Type is the same as input arrays
912            arrays.first().and_then(|a| ann.get_type(a)).cloned()
913        }
914        TypedFunction::ArrayContains { .. } => Some(DataType::Boolean),
915        TypedFunction::ArraySize { .. } => Some(DataType::Int),
916        TypedFunction::Explode { expr } => {
917            // Unwrap array element type
918            match ann.get_type(expr.as_ref()) {
919                Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
920                _ => None,
921            }
922        }
923        TypedFunction::GenerateSeries { .. } => Some(DataType::Int),
924        TypedFunction::Flatten { expr } => ann.get_type(expr.as_ref()).cloned(),
925
926        // ── JSON ─────────────────────────────────────────────────────
927        TypedFunction::JSONExtract { .. } => Some(DataType::Json),
928        TypedFunction::JSONExtractScalar { .. } => Some(DataType::Varchar(None)),
929        TypedFunction::ParseJSON { .. } | TypedFunction::JSONFormat { .. } => Some(DataType::Json),
930
931        // ── Window ───────────────────────────────────────────────────
932        TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => {
933            Some(DataType::BigInt)
934        }
935        TypedFunction::NTile { .. } => Some(DataType::BigInt),
936        TypedFunction::Lead { expr, .. }
937        | TypedFunction::Lag { expr, .. }
938        | TypedFunction::FirstValue { expr }
939        | TypedFunction::LastValue { expr } => ann.get_type(expr.as_ref()).cloned(),
940
941        // ── Math ─────────────────────────────────────────────────────
942        TypedFunction::Abs { expr }
943        | TypedFunction::Ceil { expr }
944        | TypedFunction::Floor { expr } => ann.get_type(expr.as_ref()).cloned(),
945        TypedFunction::Round { expr, .. } => ann.get_type(expr.as_ref()).cloned(),
946        TypedFunction::Log { .. }
947        | TypedFunction::Ln { .. }
948        | TypedFunction::Pow { .. }
949        | TypedFunction::Sqrt { .. } => Some(DataType::Double),
950        TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => {
951            let types: Vec<&DataType> = exprs.iter().filter_map(|e| ann.get_type(e)).collect();
952            common_type(&types)
953        }
954        TypedFunction::Mod { left, right } => {
955            match (ann.get_type(left.as_ref()), ann.get_type(right.as_ref())) {
956                (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
957                (Some(l), _) => Some(l.clone()),
958                (_, Some(r)) => Some(r.clone()),
959                _ => Some(DataType::Int),
960            }
961        }
962
963        // ── Conversion ───────────────────────────────────────────────
964        TypedFunction::Hex { .. } | TypedFunction::Md5 { .. } | TypedFunction::Sha { .. } => {
965            Some(DataType::Varchar(None))
966        }
967        TypedFunction::Sha2 { .. } => Some(DataType::Varchar(None)),
968        TypedFunction::Unhex { .. } => Some(DataType::Varbinary(None)),
969    }
970}
971
972// ═══════════════════════════════════════════════════════════════════════
973// Subquery type inference
974// ═══════════════════════════════════════════════════════════════════════
975
976fn infer_subquery_type(sub: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
977    // The type of a scalar subquery is the type of its single output column
978    if let Statement::Select(sel) = sub {
979        if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
980            return ann.get_type(expr).cloned();
981        }
982    }
983    None
984}
985
986// ═══════════════════════════════════════════════════════════════════════
987// Type coercion helpers
988// ═══════════════════════════════════════════════════════════════════════
989
990/// Numeric type widening precedence (higher = wider).
991fn numeric_precedence(dt: &DataType) -> u8 {
992    match dt {
993        DataType::Boolean => 1,
994        DataType::TinyInt => 2,
995        DataType::SmallInt => 3,
996        DataType::Int | DataType::Serial => 4,
997        DataType::BigInt | DataType::BigSerial => 5,
998        DataType::Real | DataType::Float => 6,
999        DataType::Double => 7,
1000        DataType::Decimal { .. } | DataType::Numeric { .. } => 8,
1001        _ => 0,
1002    }
1003}
1004
1005/// Coerce two numeric types to their common wider type.
1006fn coerce_numeric(left: &DataType, right: &DataType) -> DataType {
1007    let lp = numeric_precedence(left);
1008    let rp = numeric_precedence(right);
1009    if lp == 0 && rp == 0 {
1010        // Neither is numeric — fall back to left
1011        return left.clone();
1012    }
1013    if lp >= rp {
1014        left.clone()
1015    } else {
1016        right.clone()
1017    }
1018}
1019
1020/// Determine the return type of SUM based on input type.
1021fn coerce_sum_type(input: &DataType) -> DataType {
1022    match input {
1023        DataType::TinyInt | DataType::SmallInt | DataType::Int | DataType::BigInt => {
1024            DataType::BigInt
1025        }
1026        DataType::Float | DataType::Real => DataType::Double,
1027        DataType::Double => DataType::Double,
1028        DataType::Decimal { precision, scale } => DataType::Decimal {
1029            precision: *precision,
1030            scale: *scale,
1031        },
1032        DataType::Numeric { precision, scale } => DataType::Numeric {
1033            precision: *precision,
1034            scale: *scale,
1035        },
1036        _ => DataType::BigInt,
1037    }
1038}
1039
1040/// Find the common (widest) type among a set of types.
1041fn common_type(types: &[&DataType]) -> Option<DataType> {
1042    if types.is_empty() {
1043        return None;
1044    }
1045    let mut result = types[0];
1046    for t in &types[1..] {
1047        // Skip NULL — it doesn't contribute to the common type
1048        if **t == DataType::Null {
1049            continue;
1050        }
1051        if *result == DataType::Null {
1052            result = t;
1053            continue;
1054        }
1055        // If both are numeric, pick the wider one
1056        let lp = numeric_precedence(result);
1057        let rp = numeric_precedence(t);
1058        if lp > 0 && rp > 0 {
1059            if rp > lp {
1060                result = t;
1061            }
1062            continue;
1063        }
1064        // If both are string-like, prefer VARCHAR
1065        if is_string_type(result) && is_string_type(t) {
1066            result = if matches!(result, DataType::Text) || matches!(t, DataType::Text) {
1067                if matches!(result, DataType::Text) {
1068                    result
1069                } else {
1070                    t
1071                }
1072            } else {
1073                result // keep first
1074            };
1075            continue;
1076        }
1077        // Otherwise keep the first non-null type
1078    }
1079    Some(result.clone())
1080}
1081
1082fn is_string_type(dt: &DataType) -> bool {
1083    matches!(
1084        dt,
1085        DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String
1086    )
1087}
1088
1089// ═══════════════════════════════════════════════════════════════════════
1090// Tests
1091// ═══════════════════════════════════════════════════════════════════════
1092
#[cfg(test)]
mod tests {
    //! Unit tests for the type annotation pass.
    //!
    //! Most tests annotate a single `SELECT` against the fixture schema and
    //! check the type recorded for its first projected column; those go
    //! through [`assert_first_col_type`]. Tests that inspect other parts of
    //! the statement destructure it explicitly and panic if the parsed shape
    //! is not what they expect, so an assertion can never be skipped silently.

    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::Parser;
    use crate::schema::{MappingSchema, Schema};

    /// Builds the fixture schema shared by all tests: a `users` table and an
    /// `orders` table.
    fn setup_schema() -> MappingSchema {
        let users = vec![
            ("id".to_string(), DataType::Int),
            ("name".to_string(), DataType::Varchar(Some(255))),
            ("age".to_string(), DataType::Int),
            ("salary".to_string(), DataType::Double),
            ("active".to_string(), DataType::Boolean),
            (
                "created_at".to_string(),
                DataType::Timestamp {
                    precision: None,
                    with_tz: false,
                },
            ),
        ];
        let orders = vec![
            ("id".to_string(), DataType::Int),
            ("user_id".to_string(), DataType::Int),
            (
                "amount".to_string(),
                DataType::Decimal {
                    precision: Some(10),
                    scale: Some(2),
                },
            ),
            ("status".to_string(), DataType::Varchar(Some(50))),
        ];

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_table(&["users"], users).unwrap();
        schema.add_table(&["orders"], orders).unwrap();
        schema
    }

    /// Parses `sql` as a single statement and runs the annotation pass over it.
    fn parse_and_annotate(sql: &str, schema: &MappingSchema) -> (Statement, TypeAnnotations) {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        let ann = annotate_types(&stmt, schema);
        (stmt, ann)
    }

    /// Helper: get the type of the first SELECT column, if the statement is a
    /// `SELECT` whose first item is a plain expression with an annotation.
    fn first_col_type(stmt: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
        match stmt {
            Statement::Select(sel) => match sel.columns.first() {
                Some(SelectItem::Expr { expr, .. }) => ann.get_type(expr).cloned(),
                _ => None,
            },
            _ => None,
        }
    }

    /// Annotates `sql` against the fixture schema and asserts that its first
    /// projected column was typed as `expected`.
    #[track_caller]
    fn assert_first_col_type(sql: &str, expected: DataType) {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate(sql, &schema);
        assert_eq!(
            first_col_type(&stmt, &ann),
            Some(expected),
            "unexpected type for first column of `{sql}`"
        );
    }

    // ── Literal type inference ────────────────────────────────────────

    #[test]
    fn test_number_literal_int() {
        assert_first_col_type("SELECT 42", DataType::Int);
    }

    #[test]
    fn test_number_literal_big_int() {
        assert_first_col_type("SELECT 9999999999", DataType::BigInt);
    }

    #[test]
    fn test_number_literal_double() {
        assert_first_col_type("SELECT 3.14", DataType::Double);
    }

    #[test]
    fn test_string_literal() {
        assert_first_col_type("SELECT 'hello'", DataType::Varchar(None));
    }

    #[test]
    fn test_boolean_literal() {
        assert_first_col_type("SELECT TRUE", DataType::Boolean);
    }

    #[test]
    fn test_null_literal() {
        assert_first_col_type("SELECT NULL", DataType::Null);
    }

    // ── Column reference type lookup ─────────────────────────────────

    #[test]
    fn test_column_type_from_schema() {
        assert_first_col_type("SELECT id FROM users", DataType::Int);
    }

    #[test]
    fn test_qualified_column_type() {
        assert_first_col_type(
            "SELECT users.name FROM users",
            DataType::Varchar(Some(255)),
        );
    }

    #[test]
    fn test_aliased_table_column_type() {
        assert_first_col_type("SELECT u.salary FROM users AS u", DataType::Double);
    }

    // ── Binary operator type inference ───────────────────────────────

    #[test]
    fn test_int_plus_int() {
        assert_first_col_type("SELECT id + age FROM users", DataType::Int);
    }

    #[test]
    fn test_int_plus_double() {
        assert_first_col_type("SELECT id + salary FROM users", DataType::Double);
    }

    #[test]
    fn test_comparison_returns_boolean() {
        assert_first_col_type("SELECT id > 5 FROM users", DataType::Boolean);
    }

    #[test]
    fn test_and_returns_boolean() {
        assert_first_col_type("SELECT id > 5 AND age < 30 FROM users", DataType::Boolean);
    }

    // ── CAST type inference ──────────────────────────────────────────

    #[test]
    fn test_cast_type() {
        assert_first_col_type("SELECT CAST(id AS BIGINT) FROM users", DataType::BigInt);
    }

    #[test]
    fn test_cast_to_varchar() {
        assert_first_col_type(
            "SELECT CAST(id AS VARCHAR) FROM users",
            DataType::Varchar(None),
        );
    }

    // ── CASE expression ──────────────────────────────────────────────

    #[test]
    fn test_case_expression_type() {
        assert_first_col_type(
            "SELECT CASE WHEN id > 1 THEN salary ELSE 0.0 END FROM users",
            DataType::Double,
        );
    }

    // ── Function return types ────────────────────────────────────────

    #[test]
    fn test_count_returns_bigint() {
        assert_first_col_type("SELECT COUNT(*) FROM users", DataType::BigInt);
    }

    #[test]
    fn test_sum_returns_bigint_for_int() {
        assert_first_col_type("SELECT SUM(id) FROM users", DataType::BigInt);
    }

    #[test]
    fn test_avg_returns_double() {
        assert_first_col_type("SELECT AVG(age) FROM users", DataType::Double);
    }

    #[test]
    fn test_min_preserves_type() {
        assert_first_col_type("SELECT MIN(salary) FROM users", DataType::Double);
    }

    #[test]
    fn test_upper_returns_varchar() {
        assert_first_col_type("SELECT UPPER(name) FROM users", DataType::Varchar(None));
    }

    #[test]
    fn test_length_returns_int() {
        assert_first_col_type("SELECT LENGTH(name) FROM users", DataType::Int);
    }

    // ── Predicate types ──────────────────────────────────────────────

    #[test]
    fn test_between_returns_boolean() {
        assert_first_col_type(
            "SELECT age BETWEEN 18 AND 65 FROM users",
            DataType::Boolean,
        );
    }

    #[test]
    fn test_in_list_returns_boolean() {
        assert_first_col_type("SELECT id IN (1, 2, 3) FROM users", DataType::Boolean);
    }

    #[test]
    fn test_is_null_returns_boolean() {
        assert_first_col_type("SELECT name IS NULL FROM users", DataType::Boolean);
    }

    #[test]
    fn test_like_returns_boolean() {
        assert_first_col_type("SELECT name LIKE '%test%' FROM users", DataType::Boolean);
    }

    // ── Exists ───────────────────────────────────────────────────────

    #[test]
    fn test_exists_returns_boolean() {
        assert_first_col_type(
            "SELECT EXISTS (SELECT 1 FROM orders) FROM users",
            DataType::Boolean,
        );
    }

    // ── Nested expressions ───────────────────────────────────────────

    #[test]
    fn test_nested_expression_propagation() {
        // INT + INT = INT, INT * DOUBLE = DOUBLE
        assert_first_col_type("SELECT (id + age) * salary FROM users", DataType::Double);
    }

    // ── EXTRACT ──────────────────────────────────────────────────────

    #[test]
    fn test_extract_returns_int() {
        assert_first_col_type(
            "SELECT EXTRACT(YEAR FROM created_at) FROM users",
            DataType::Int,
        );
    }

    // ── Multiple columns ─────────────────────────────────────────────

    #[test]
    fn test_multiple_columns_annotated() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id, name, salary FROM users", &schema);

        // Fail loudly on an unexpected statement shape instead of skipping the checks.
        let sel = match &stmt {
            Statement::Select(sel) => sel,
            _ => panic!("expected the parser to produce a SELECT statement"),
        };

        // id → Int, name → Varchar(255), salary → Double
        let expected = [
            DataType::Int,
            DataType::Varchar(Some(255)),
            DataType::Double,
        ];
        assert_eq!(sel.columns.len(), expected.len());

        for (index, (item, want)) in sel.columns.iter().zip(expected.iter()).enumerate() {
            match item {
                SelectItem::Expr { expr, .. } => {
                    assert_eq!(ann.get_type(expr), Some(want), "column #{index}");
                }
                _ => panic!("select item #{index} is not a plain expression"),
            }
        }
    }

    // ── WHERE clause annotation ──────────────────────────────────────

    #[test]
    fn test_where_clause_annotated() {
        let schema = setup_schema();
        // Don't move stmt after annotation — raw pointers for inline fields
        // (like where_clause: Option<Expr>) are invalidated on move.
        let stmt = Parser::new("SELECT id FROM users WHERE age > 21")
            .unwrap()
            .parse_statement()
            .unwrap();
        let ann = annotate_types(&stmt, &schema);

        let sel = match &stmt {
            Statement::Select(sel) => sel,
            _ => panic!("expected the parser to produce a SELECT statement"),
        };
        // A missing WHERE clause is a test failure, not a reason to skip the check.
        let predicate = sel
            .where_clause
            .as_ref()
            .expect("parsed SELECT should carry its WHERE clause");
        assert_eq!(ann.get_type(predicate), Some(&DataType::Boolean));
    }

    // ── Coercion rules ──────────────────────────────────────────────

    #[test]
    fn test_int_and_bigint_coercion() {
        let widened = coerce_numeric(&DataType::Int, &DataType::BigInt);
        assert_eq!(widened, DataType::BigInt);
    }

    #[test]
    fn test_float_and_double_coercion() {
        let widened = coerce_numeric(&DataType::Float, &DataType::Double);
        assert_eq!(widened, DataType::Double);
    }

    #[test]
    fn test_int_and_double_coercion() {
        let widened = coerce_numeric(&DataType::Int, &DataType::Double);
        assert_eq!(widened, DataType::Double);
    }

    // ── Common type ─────────────────────────────────────────────────

    #[test]
    fn test_common_type_nulls_skipped() {
        let types = vec![&DataType::Null, &DataType::Int, &DataType::Null];
        assert_eq!(common_type(&types), Some(DataType::Int));
    }

    #[test]
    fn test_common_type_numeric_widening() {
        let types = vec![&DataType::Int, &DataType::Double, &DataType::Float];
        assert_eq!(common_type(&types), Some(DataType::Double));
    }

    #[test]
    fn test_common_type_empty() {
        let types: Vec<&DataType> = Vec::new();
        assert_eq!(common_type(&types), None);
    }

    // ── UDF type support ─────────────────────────────────────────────

    #[test]
    fn test_udf_return_type() {
        // Needs its own schema: the UDF is registered on top of the fixture.
        let mut schema = setup_schema();
        schema.add_udf("my_func", DataType::Varchar(None));
        let (stmt, ann) = parse_and_annotate("SELECT my_func(id) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
    }

    // ── Annotation count ─────────────────────────────────────────────

    #[test]
    fn test_annotations_not_empty() {
        let schema = setup_schema();
        let (_, ann) = parse_and_annotate("SELECT id, name FROM users WHERE age > 21", &schema);
        assert!(!ann.is_empty());
        // Should have at least the SELECT columns and WHERE predicate
        assert!(ann.len() >= 3);
    }

    // ── SUM of DECIMAL preserves precision ───────────────────────────

    #[test]
    fn test_sum_decimal_preserves_type() {
        assert_first_col_type(
            "SELECT SUM(amount) FROM orders",
            DataType::Decimal {
                precision: Some(10),
                scale: Some(2),
            },
        );
    }
}