Skip to main content

sqlglot_rust/optimizer/
annotate_types.rs

1//! Type annotation pass for SQL expressions.
2//!
3//! Infers and propagates SQL data types across AST nodes using schema metadata.
4//! Inspired by Python sqlglot's `annotate_types` optimizer pass.
5//!
6//! # Overview
7//!
8//! The pass walks the AST bottom-up, resolving types for:
9//! - **Literals**: `42` → `Int`, `'hello'` → `Varchar`, `TRUE` → `Boolean`
10//! - **Column references**: looked up from the provided [`Schema`]
11//! - **Binary operators**: result type from operand coercion (e.g. `INT + FLOAT → FLOAT`)
12//! - **CAST / TRY_CAST**: the target data type
13//! - **Functions**: return type based on function signature and argument types
14//! - **CASE**: common type across all THEN / ELSE branches
15//! - **Aggregates**: `COUNT → BigInt`, `SUM` depends on input, etc.
16//! - **Subqueries**: type of the single output column
17//!
18//! # Example
19//!
20//! ```rust
21//! use sqlglot_rust::optimizer::annotate_types::annotate_types;
22//! use sqlglot_rust::schema::{MappingSchema, Schema};
23//! use sqlglot_rust::ast::DataType;
24//! use sqlglot_rust::{parse, Dialect};
25//!
26//! let mut schema = MappingSchema::new(Dialect::Ansi);
27//! schema.add_table(&["t"], vec![
28//!     ("id".to_string(), DataType::Int),
29//!     ("name".to_string(), DataType::Varchar(Some(255))),
30//! ]).unwrap();
31//!
32//! let stmt = parse("SELECT id, name FROM t WHERE id > 1", Dialect::Ansi).unwrap();
33//! let annotations = annotate_types(&stmt, &schema);
34//! // annotations now contains inferred types for every expression node
35//! ```
36
37use std::collections::HashMap;
38
39use crate::ast::*;
40use crate::schema::Schema;
41
42// ═══════════════════════════════════════════════════════════════════════
43// TypeAnnotations — the result of type inference
44// ═══════════════════════════════════════════════════════════════════════
45
46/// Stores inferred [`DataType`] annotations for expression nodes in an AST.
47///
48/// Annotations are keyed by raw pointer identity, so this structure is valid
49/// only as long as the underlying AST is not moved, cloned, or dropped.
50/// Intended for single-pass analysis over a borrowed AST.
51pub struct TypeAnnotations {
52    types: HashMap<*const Expr, DataType>,
53}
54
55// Raw pointers are not Send/Sync by default, but our usage is safe because
56// the pointers are derived from shared references with a known lifetime.
57unsafe impl Send for TypeAnnotations {}
58unsafe impl Sync for TypeAnnotations {}
59
60impl TypeAnnotations {
61    fn new() -> Self {
62        Self {
63            types: HashMap::new(),
64        }
65    }
66
67    fn set(&mut self, expr: &Expr, dt: DataType) {
68        self.types.insert(expr as *const Expr, dt);
69    }
70
71    /// Retrieve the inferred type of an expression, if annotated.
72    #[must_use]
73    pub fn get_type(&self, expr: &Expr) -> Option<&DataType> {
74        self.types.get(&(expr as *const Expr))
75    }
76
77    /// Number of annotated nodes.
78    #[must_use]
79    pub fn len(&self) -> usize {
80        self.types.len()
81    }
82
83    /// Returns `true` if no annotations were recorded.
84    #[must_use]
85    pub fn is_empty(&self) -> bool {
86        self.types.is_empty()
87    }
88}
89
90// ═══════════════════════════════════════════════════════════════════════
91// Public entry point
92// ═══════════════════════════════════════════════════════════════════════
93
94/// Annotate all expression nodes in a statement with inferred SQL types.
95///
96/// Walks the AST bottom-up, resolving types from literals, schema column
97/// lookups, operator/function signatures, and type coercion rules.
98///
99/// The returned [`TypeAnnotations`] is valid only while the borrowed `stmt`
100/// is alive and unmodified.
101#[must_use]
102pub fn annotate_types<S: Schema>(stmt: &Statement, schema: &S) -> TypeAnnotations {
103    let mut ann = TypeAnnotations::new();
104    let mut ctx = AnnotationContext::new(schema);
105    annotate_statement(stmt, &mut ctx, &mut ann);
106    ann
107}
108
109// ═══════════════════════════════════════════════════════════════════════
110// Internal context
111// ═══════════════════════════════════════════════════════════════════════
112
113/// Carries schema reference and table alias mappings through the walk.
114struct AnnotationContext<'s, S: Schema> {
115    schema: &'s S,
116    /// Maps table alias or name → table path for column type lookups.
117    table_aliases: HashMap<String, Vec<String>>,
118}
119
120impl<'s, S: Schema> AnnotationContext<'s, S> {
121    fn new(schema: &'s S) -> Self {
122        Self {
123            schema,
124            table_aliases: HashMap::new(),
125        }
126    }
127
128    /// Register a table (by ref) so that columns can be looked up by alias.
129    fn register_table(&mut self, table_ref: &TableRef) {
130        let path = vec![table_ref.name.clone()];
131        let alias = table_ref
132            .alias
133            .as_deref()
134            .unwrap_or(&table_ref.name)
135            .to_string();
136        self.table_aliases.insert(alias, path);
137    }
138
139    /// Look up the type of a column, resolving through table aliases.
140    fn resolve_column_type(&self, table: Option<&str>, column: &str) -> Option<DataType> {
141        if let Some(tbl) = table {
142            // Qualified column — look up via alias map
143            if let Some(path) = self.table_aliases.get(tbl) {
144                let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
145                return self.schema.get_column_type(&path_refs, column).ok();
146            }
147            // Try the table name directly
148            return self.schema.get_column_type(&[tbl], column).ok();
149        }
150        // Unqualified — search all registered tables
151        for path in self.table_aliases.values() {
152            let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
153            if let Ok(dt) = self.schema.get_column_type(&path_refs, column) {
154                return Some(dt);
155            }
156        }
157        None
158    }
159}
160
161// ═══════════════════════════════════════════════════════════════════════
162// Statement-level annotation
163// ═══════════════════════════════════════════════════════════════════════
164
165fn annotate_statement<S: Schema>(
166    stmt: &Statement,
167    ctx: &mut AnnotationContext<S>,
168    ann: &mut TypeAnnotations,
169) {
170    match stmt {
171        Statement::Select(sel) => annotate_select(sel, ctx, ann),
172        Statement::SetOperation(set_op) => {
173            annotate_statement(&set_op.left, ctx, ann);
174            annotate_statement(&set_op.right, ctx, ann);
175        }
176        Statement::Insert(ins) => {
177            if let InsertSource::Query(q) = &ins.source {
178                annotate_statement(q, ctx, ann);
179            }
180            for row in match &ins.source {
181                InsertSource::Values(rows) => rows.as_slice(),
182                _ => &[],
183            } {
184                for expr in row {
185                    annotate_expr(expr, ctx, ann);
186                }
187            }
188        }
189        Statement::Update(upd) => {
190            for (_, expr) in &upd.assignments {
191                annotate_expr(expr, ctx, ann);
192            }
193            if let Some(wh) = &upd.where_clause {
194                annotate_expr(wh, ctx, ann);
195            }
196        }
197        Statement::Delete(del) => {
198            if let Some(wh) = &del.where_clause {
199                annotate_expr(wh, ctx, ann);
200            }
201        }
202        Statement::Expression(expr) => {
203            annotate_expr(expr, ctx, ann);
204        }
205        Statement::Explain(expl) => {
206            annotate_statement(&expl.statement, ctx, ann);
207        }
208        // DDL / transaction / other statements — no expression types to annotate
209        _ => {}
210    }
211}
212
213fn annotate_select<S: Schema>(
214    sel: &SelectStatement,
215    ctx: &mut AnnotationContext<S>,
216    ann: &mut TypeAnnotations,
217) {
218    // 1. Register CTEs
219    for cte in &sel.ctes {
220        annotate_statement(&cte.query, ctx, ann);
221    }
222
223    // 2. Register FROM sources
224    if let Some(from) = &sel.from {
225        register_table_source(&from.source, ctx);
226    }
227    for join in &sel.joins {
228        register_table_source(&join.table, ctx);
229    }
230
231    // 3. Annotate WHERE clause
232    if let Some(wh) = &sel.where_clause {
233        annotate_expr(wh, ctx, ann);
234    }
235
236    // 4. Annotate SELECT columns
237    for item in &sel.columns {
238        if let SelectItem::Expr { expr, .. } = item {
239            annotate_expr(expr, ctx, ann);
240        }
241    }
242
243    // 5. Annotate GROUP BY
244    for expr in &sel.group_by {
245        annotate_expr(expr, ctx, ann);
246    }
247
248    // 6. Annotate HAVING
249    if let Some(having) = &sel.having {
250        annotate_expr(having, ctx, ann);
251    }
252
253    // 7. Annotate ORDER BY
254    for ob in &sel.order_by {
255        annotate_expr(&ob.expr, ctx, ann);
256    }
257
258    // 8. Annotate LIMIT / OFFSET
259    if let Some(limit) = &sel.limit {
260        annotate_expr(limit, ctx, ann);
261    }
262    if let Some(offset) = &sel.offset {
263        annotate_expr(offset, ctx, ann);
264    }
265    if let Some(fetch) = &sel.fetch_first {
266        annotate_expr(fetch, ctx, ann);
267    }
268
269    // 9. Annotate QUALIFY
270    if let Some(qualify) = &sel.qualify {
271        annotate_expr(qualify, ctx, ann);
272    }
273
274    // 10. Annotate JOIN ON conditions
275    for join in &sel.joins {
276        if let Some(on) = &join.on {
277            annotate_expr(on, ctx, ann);
278        }
279    }
280}
281
282fn register_table_source<S: Schema>(source: &TableSource, ctx: &mut AnnotationContext<S>) {
283    match source {
284        TableSource::Table(tref) => ctx.register_table(tref),
285        TableSource::Subquery { alias, .. } => {
286            // Subqueries as sources don't have schema entries to register.
287            // Their output column types would come from recursive annotation.
288            let _ = alias;
289        }
290        TableSource::TableFunction { alias, .. } => {
291            let _ = alias;
292        }
293        TableSource::Lateral { source } => register_table_source(source, ctx),
294        TableSource::Unnest { .. } => {}
295    }
296}
297
298// ═══════════════════════════════════════════════════════════════════════
299// Expression-level annotation (bottom-up)
300// ═══════════════════════════════════════════════════════════════════════
301
302fn annotate_expr<S: Schema>(expr: &Expr, ctx: &AnnotationContext<S>, ann: &mut TypeAnnotations) {
303    // First annotate children, then determine this node's type.
304    annotate_children(expr, ctx, ann);
305
306    let dt = infer_type(expr, ctx, ann);
307    if let Some(t) = dt {
308        ann.set(expr, t);
309    }
310}
311
/// Recursively annotate child expressions before the parent.
///
/// Every sub-expression of `expr` is passed to `annotate_expr`, which recurses
/// further. Afterwards `infer_type` can read the child types back out of `ann`.
/// Leaf nodes have nothing to visit.
fn annotate_children<S: Schema>(
    expr: &Expr,
    ctx: &AnnotationContext<S>,
    ann: &mut TypeAnnotations,
) {
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            annotate_expr(left, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        Expr::UnaryOp { expr: inner, .. } => annotate_expr(inner, ctx, ann),
        // Only the arguments and the optional FILTER predicate are visited;
        // any other fields of the call are skipped by `..`.
        Expr::Function { args, filter, .. } => {
            for arg in args {
                annotate_expr(arg, ctx, ann);
            }
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        Expr::Between {
            expr: e, low, high, ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(low, ctx, ann);
            annotate_expr(high, ctx, ann);
        }
        Expr::InList { expr: e, list, .. } => {
            annotate_expr(e, ctx, ann);
            for item in list {
                annotate_expr(item, ctx, ann);
            }
        }
        // The tested expression belongs to the outer scope. The subquery gets
        // a fresh, empty context, so tables registered by the enclosing query
        // are not visible inside it.
        // NOTE(review): unqualified correlated references to outer columns
        // therefore stay untyped — confirm this is intended.
        Expr::InSubquery {
            expr: e, subquery, ..
        } => {
            annotate_expr(e, ctx, ann);
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::IsNull { expr: e, .. } | Expr::IsBool { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Like {
            expr: e,
            pattern,
            escape,
            ..
        }
        | Expr::ILike {
            expr: e,
            pattern,
            escape,
            ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(pattern, ctx, ann);
            if let Some(esc) = escape {
                annotate_expr(esc, ctx, ann);
            }
        }
        // Operand, every WHEN condition / THEN result, and ELSE are all visited.
        Expr::Case {
            operand,
            when_clauses,
            else_clause,
        } => {
            if let Some(op) = operand {
                annotate_expr(op, ctx, ann);
            }
            for (cond, result) in when_clauses {
                annotate_expr(cond, ctx, ann);
                annotate_expr(result, ctx, ann);
            }
            if let Some(el) = else_clause {
                annotate_expr(el, ctx, ann);
            }
        }
        Expr::Nested(inner) => annotate_expr(inner, ctx, ann),
        Expr::Cast { expr: e, .. } | Expr::TryCast { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Extract { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Interval { value, .. } => annotate_expr(value, ctx, ann),
        Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
            for item in items {
                annotate_expr(item, ctx, ann);
            }
        }
        Expr::If {
            condition,
            true_val,
            false_val,
        } => {
            annotate_expr(condition, ctx, ann);
            annotate_expr(true_val, ctx, ann);
            if let Some(fv) = false_val {
                annotate_expr(fv, ctx, ann);
            }
        }
        Expr::NullIf { expr: e, r#else } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(r#else, ctx, ann);
        }
        Expr::Collate { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Alias { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::ArrayIndex { expr: e, index } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(index, ctx, ann);
        }
        Expr::JsonAccess { expr: e, path, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(path, ctx, ann);
        }
        // Only the body is visited; lambda parameters are not registered in
        // `ctx` (which is borrowed immutably here).
        Expr::Lambda { body, .. } => annotate_expr(body, ctx, ann),
        Expr::AnyOp { expr: e, right, .. } | Expr::AllOp { expr: e, right, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        // Scalar subquery: annotated in its own fresh context (see InSubquery).
        Expr::Subquery(sub) => {
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(sub, &mut sub_ctx, ann);
        }
        // EXISTS subquery: annotated in its own fresh context (see InSubquery).
        Expr::Exists { subquery, .. } => {
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::TypedFunction { func, filter, .. } => {
            annotate_typed_function_children(func, ctx, ann);
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        // Leaf nodes — no children to annotate
        Expr::Column { .. }
        | Expr::Number(_)
        | Expr::StringLiteral(_)
        | Expr::Boolean(_)
        | Expr::Null
        | Expr::Wildcard
        | Expr::Star
        | Expr::Parameter(_)
        | Expr::TypeExpr(_)
        | Expr::QualifiedWildcard { .. }
        | Expr::Default => {}
    }
}
458
459/// Annotate children of a TypedFunction.
460fn annotate_typed_function_children<S: Schema>(
461    func: &TypedFunction,
462    ctx: &AnnotationContext<S>,
463    ann: &mut TypeAnnotations,
464) {
465    // Use walk_children to visit all child expressions and annotate each
466    func.walk_children(&mut |child| {
467        annotate_expr(child, ctx, ann);
468        true
469    });
470}
471
472// ═══════════════════════════════════════════════════════════════════════
473// Type inference for a single expression node
474// ═══════════════════════════════════════════════════════════════════════
475
/// Infer the type of a single expression node from its own shape and the
/// annotations already recorded for its children.
///
/// `annotate_expr` calls this after `annotate_children`, so child types can be
/// read from `ann`. Returns `None` when no type can be determined, for example
/// for wildcards, parameters, lambdas, `DEFAULT`, or when the needed child or
/// column types are unknown.
fn infer_type<S: Schema>(
    expr: &Expr,
    ctx: &AnnotationContext<S>,
    ann: &TypeAnnotations,
) -> Option<DataType> {
    match expr {
        // ── Literals ───────────────────────────────────────────────────
        Expr::Number(s) => Some(infer_number_type(s)),
        Expr::StringLiteral(_) => Some(DataType::Varchar(None)),
        Expr::Boolean(_) => Some(DataType::Boolean),
        Expr::Null => Some(DataType::Null),

        // ── Column reference ──────────────────────────────────────────
        // Resolved through the table aliases registered in `ctx`.
        Expr::Column { table, name, .. } => ctx.resolve_column_type(table.as_deref(), name),

        // ── Binary operators ──────────────────────────────────────────
        Expr::BinaryOp { left, op, right } => {
            infer_binary_op_type(op, ann.get_type(left), ann.get_type(right))
        }

        // ── Unary operators ───────────────────────────────────────────
        // NOT is always Boolean; sign and bitwise negation keep the operand type.
        Expr::UnaryOp { op, expr: inner } => match op {
            UnaryOperator::Not => Some(DataType::Boolean),
            UnaryOperator::Minus | UnaryOperator::Plus => ann.get_type(inner).cloned(),
            UnaryOperator::BitwiseNot => ann.get_type(inner).cloned(),
        },

        // ── CAST / TRY_CAST ──────────────────────────────────────────
        Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => Some(data_type.clone()),

        // ── CASE expression ──────────────────────────────────────────
        // Common type of every THEN result plus ELSE; untyped branches are skipped.
        Expr::Case {
            when_clauses,
            else_clause,
            ..
        } => {
            let mut result_types: Vec<&DataType> = Vec::new();
            for (_, result) in when_clauses {
                if let Some(t) = ann.get_type(result) {
                    result_types.push(t);
                }
            }
            if let Some(el) = else_clause {
                if let Some(t) = ann.get_type(el.as_ref()) {
                    result_types.push(t);
                }
            }
            common_type(&result_types)
        }

        // ── IF expression ────────────────────────────────────────────
        // Common type of the true and (optional) false branches.
        Expr::If {
            true_val,
            false_val,
            ..
        } => {
            let mut types = Vec::new();
            if let Some(t) = ann.get_type(true_val) {
                types.push(t);
            }
            if let Some(fv) = false_val {
                if let Some(t) = ann.get_type(fv.as_ref()) {
                    types.push(t);
                }
            }
            common_type(&types)
        }

        // ── COALESCE ─────────────────────────────────────────────────
        Expr::Coalesce(items) => {
            let types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
            common_type(&types)
        }

        // ── NULLIF ───────────────────────────────────────────────────
        // Typed like its first argument.
        Expr::NullIf { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),

        // ── Generic function ─────────────────────────────────────────
        Expr::Function { name, args, .. } => infer_generic_function_type(name, args, ctx, ann),

        // ── Typed functions ──────────────────────────────────────────
        Expr::TypedFunction { func, .. } => infer_typed_function_type(func, ann),

        // ── Subquery (scalar) ────────────────────────────────────────
        Expr::Subquery(sub) => infer_subquery_type(sub, ann),

        // ── EXISTS → Boolean ─────────────────────────────────────────
        Expr::Exists { .. } => Some(DataType::Boolean),

        // ── Boolean predicates ───────────────────────────────────────
        Expr::Between { .. }
        | Expr::InList { .. }
        | Expr::InSubquery { .. }
        | Expr::IsNull { .. }
        | Expr::IsBool { .. }
        | Expr::Like { .. }
        | Expr::ILike { .. }
        | Expr::AnyOp { .. }
        | Expr::AllOp { .. } => Some(DataType::Boolean),

        // ── EXTRACT → numeric ────────────────────────────────────────
        // Always Int here, regardless of which field is extracted.
        Expr::Extract { .. } => Some(DataType::Int),

        // ── INTERVAL → Interval ──────────────────────────────────────
        Expr::Interval { .. } => Some(DataType::Interval),

        // ── Array literal ────────────────────────────────────────────
        // Element type is the common type of the typed elements; it is `None`
        // inside `Array` when no common type is found.
        Expr::ArrayLiteral(items) => {
            let elem_types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
            let elem = common_type(&elem_types);
            Some(DataType::Array(elem.map(Box::new)))
        }

        // ── Tuple ────────────────────────────────────────────────────
        // Untyped members are recorded as `Null` so positions stay aligned.
        Expr::Tuple(items) => {
            let types: Vec<DataType> = items
                .iter()
                .map(|e| ann.get_type(e).cloned().unwrap_or(DataType::Null))
                .collect();
            Some(DataType::Tuple(types))
        }

        // ── Array index → element type ───────────────────────────────
        // Only typed when the indexed value is an array with a known element type.
        Expr::ArrayIndex { expr: e, .. } => match ann.get_type(e.as_ref()) {
            Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
            _ => None,
        },

        // ── JSON access ──────────────────────────────────────────────
        // Text when accessed as text, otherwise Json.
        Expr::JsonAccess { as_text, .. } => {
            if *as_text {
                Some(DataType::Text)
            } else {
                Some(DataType::Json)
            }
        }

        // ── Nested / Alias — pass through ────────────────────────────
        Expr::Nested(inner) => ann.get_type(inner.as_ref()).cloned(),
        Expr::Alias { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),

        // ── Collate → Varchar ────────────────────────────────────────
        Expr::Collate { .. } => Some(DataType::Varchar(None)),

        // ── TypeExpr ─────────────────────────────────────────────────
        Expr::TypeExpr(dt) => Some(dt.clone()),

        // ── Others — no type ─────────────────────────────────────────
        Expr::Wildcard
        | Expr::Star
        | Expr::QualifiedWildcard { .. }
        | Expr::Parameter(_)
        | Expr::Lambda { .. }
        | Expr::Default => None,
    }
}
632
633// ═══════════════════════════════════════════════════════════════════════
634// Number type inference
635// ═══════════════════════════════════════════════════════════════════════
636
637fn infer_number_type(s: &str) -> DataType {
638    if s.contains('.') || s.contains('e') || s.contains('E') {
639        DataType::Double
640    } else if let Ok(v) = s.parse::<i64>() {
641        if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
642            DataType::Int
643        } else {
644            DataType::BigInt
645        }
646    } else {
647        // Very large numbers or special formats
648        DataType::BigInt
649    }
650}
651
652// ═══════════════════════════════════════════════════════════════════════
653// Binary operator type inference
654// ═══════════════════════════════════════════════════════════════════════
655
656fn infer_binary_op_type(
657    op: &BinaryOperator,
658    left: Option<&DataType>,
659    right: Option<&DataType>,
660) -> Option<DataType> {
661    use BinaryOperator::*;
662    match op {
663        // Comparison operators → Boolean
664        Eq | Neq | Lt | Gt | LtEq | GtEq => Some(DataType::Boolean),
665
666        // Logical operators → Boolean
667        And | Or | Xor => Some(DataType::Boolean),
668
669        // String concatenation → Varchar
670        Concat => Some(DataType::Varchar(None)),
671
672        // Arithmetic → coerce operand types
673        Plus | Minus | Multiply | Divide | Modulo => match (left, right) {
674            (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
675            (Some(l), None) => Some(l.clone()),
676            (None, Some(r)) => Some(r.clone()),
677            (None, None) => None,
678        },
679
680        // Bitwise → integer type
681        BitwiseAnd | BitwiseOr | BitwiseXor | ShiftLeft | ShiftRight => match (left, right) {
682            (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
683            (Some(l), None) => Some(l.clone()),
684            (None, Some(r)) => Some(r.clone()),
685            (None, None) => Some(DataType::Int),
686        },
687
688        // JSON operators
689        Arrow => Some(DataType::Json),
690        DoubleArrow => Some(DataType::Text),
691    }
692}
693
694// ═══════════════════════════════════════════════════════════════════════
695// Generic (untyped) function return type inference
696// ═══════════════════════════════════════════════════════════════════════
697
698fn infer_generic_function_type<S: Schema>(
699    name: &str,
700    args: &[Expr],
701    ctx: &AnnotationContext<S>,
702    ann: &TypeAnnotations,
703) -> Option<DataType> {
704    let upper = name.to_uppercase();
705    match upper.as_str() {
706        // Aggregate functions
707        "COUNT" | "COUNT_BIG" => Some(DataType::BigInt),
708        "SUM" => args
709            .first()
710            .and_then(|a| ann.get_type(a))
711            .map(|t| coerce_sum_type(t)),
712        "AVG" => Some(DataType::Double),
713        "MIN" | "MAX" => args.first().and_then(|a| ann.get_type(a)).cloned(),
714        "VARIANCE" | "VAR_SAMP" | "VAR_POP" | "STDDEV" | "STDDEV_SAMP" | "STDDEV_POP" => {
715            Some(DataType::Double)
716        }
717        "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(DataType::BigInt),
718
719        // String functions
720        "CONCAT" | "UPPER" | "LOWER" | "TRIM" | "LTRIM" | "RTRIM" | "LPAD" | "RPAD" | "REPLACE"
721        | "REVERSE" | "SUBSTRING" | "SUBSTR" | "LEFT" | "RIGHT" | "INITCAP" | "REPEAT"
722        | "TRANSLATE" | "FORMAT" | "CONCAT_WS" | "SPACE" | "REPLICATE" => {
723            Some(DataType::Varchar(None))
724        }
725        "LENGTH" | "LEN" | "CHAR_LENGTH" | "CHARACTER_LENGTH" | "OCTET_LENGTH" | "BIT_LENGTH" => {
726            Some(DataType::Int)
727        }
728        "POSITION" | "STRPOS" | "LOCATE" | "INSTR" | "CHARINDEX" => Some(DataType::Int),
729        "ASCII" => Some(DataType::Int),
730        "CHR" | "CHAR" => Some(DataType::Varchar(Some(1))),
731
732        // Math functions
733        "ABS" | "CEIL" | "CEILING" | "FLOOR" => args.first().and_then(|a| ann.get_type(a)).cloned(),
734        "ROUND" | "TRUNCATE" | "TRUNC" => args.first().and_then(|a| ann.get_type(a)).cloned(),
735        "SQRT" | "LN" | "LOG" | "LOG2" | "LOG10" | "EXP" | "POWER" | "POW" | "ACOS" | "ASIN"
736        | "ATAN" | "ATAN2" | "COS" | "SIN" | "TAN" | "COT" | "DEGREES" | "RADIANS" | "PI"
737        | "SIGN" => Some(DataType::Double),
738        "MOD" => {
739            match (
740                args.first().and_then(|a| ann.get_type(a)),
741                args.get(1).and_then(|a| ann.get_type(a)),
742            ) {
743                (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
744                (Some(l), _) => Some(l.clone()),
745                (_, Some(r)) => Some(r.clone()),
746                _ => Some(DataType::Int),
747            }
748        }
749        "GREATEST" | "LEAST" => {
750            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
751            common_type(&types)
752        }
753        "RANDOM" | "RAND" => Some(DataType::Double),
754
755        // Date/Time functions
756        "CURRENT_DATE" | "CURDATE" | "TODAY" => Some(DataType::Date),
757        "CURRENT_TIMESTAMP" | "NOW" | "GETDATE" | "SYSDATE" | "SYSTIMESTAMP" | "LOCALTIMESTAMP" => {
758            Some(DataType::Timestamp {
759                precision: None,
760                with_tz: false,
761            })
762        }
763        "CURRENT_TIME" | "CURTIME" => Some(DataType::Time { precision: None }),
764        "DATE" | "TO_DATE" | "DATE_TRUNC" | "DATE_ADD" | "DATE_SUB" | "DATEADD" | "DATESUB"
765        | "ADDDATE" | "SUBDATE" => Some(DataType::Date),
766        "TIMESTAMP" | "TO_TIMESTAMP" => Some(DataType::Timestamp {
767            precision: None,
768            with_tz: false,
769        }),
770        "YEAR" | "MONTH" | "DAY" | "DAYOFWEEK" | "DAYOFYEAR" | "HOUR" | "MINUTE" | "SECOND"
771        | "QUARTER" | "WEEK" | "EXTRACT" | "DATEDIFF" | "TIMESTAMPDIFF" | "MONTHS_BETWEEN" => {
772            Some(DataType::Int)
773        }
774
775        // Type conversion
776        "CAST" | "TRY_CAST" | "SAFE_CAST" | "CONVERT" => None, // handled by Expr::Cast
777
778        // Boolean functions
779        "COALESCE" => {
780            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
781            common_type(&types)
782        }
783        "NULLIF" => args.first().and_then(|a| ann.get_type(a)).cloned(),
784        "IF" | "IIF" => {
785            // IF(cond, true_val, false_val) — type from true_val
786            args.get(1).and_then(|a| ann.get_type(a)).cloned()
787        }
788        "IFNULL" | "NVL" | "ISNULL" => {
789            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
790            common_type(&types)
791        }
792
793        // JSON functions
794        "JSON_EXTRACT" | "JSON_QUERY" | "GET_JSON_OBJECT" => Some(DataType::Json),
795        "JSON_EXTRACT_SCALAR" | "JSON_VALUE" | "JSON_EXTRACT_PATH_TEXT" => {
796            Some(DataType::Varchar(None))
797        }
798        "TO_JSON" | "JSON_OBJECT" | "JSON_ARRAY" | "JSON_BUILD_OBJECT" | "JSON_BUILD_ARRAY" => {
799            Some(DataType::Json)
800        }
801        "PARSE_JSON" | "JSON_PARSE" | "JSON" => Some(DataType::Json),
802
803        // Array functions
804        "ARRAY_AGG" | "COLLECT_LIST" | "COLLECT_SET" => {
805            let elem = args.first().and_then(|a| ann.get_type(a)).cloned();
806            Some(DataType::Array(elem.map(Box::new)))
807        }
808        "ARRAY_LENGTH" | "ARRAY_SIZE" | "CARDINALITY" => Some(DataType::Int),
809        "ARRAY" | "ARRAY_CONSTRUCT" => {
810            let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
811            let elem = common_type(&types);
812            Some(DataType::Array(elem.map(Box::new)))
813        }
814        "ARRAY_CONTAINS" | "ARRAY_POSITION" => Some(DataType::Boolean),
815
816        // Window ranking
817        "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "NTILE" | "CUME_DIST" | "PERCENT_RANK" => {
818            Some(DataType::BigInt)
819        }
820
821        // Hash / crypto
822        "MD5" | "SHA1" | "SHA" | "SHA2" | "SHA256" | "SHA512" => Some(DataType::Varchar(None)),
823        "HEX" | "TO_HEX" => Some(DataType::Varchar(None)),
824        "UNHEX" | "FROM_HEX" => Some(DataType::Varbinary(None)),
825        "CRC32" | "HASH" => Some(DataType::BigInt),
826
827        // Type checking
828        "TYPEOF" | "TYPE_OF" => Some(DataType::Varchar(None)),
829
830        // UDFs — check schema
831        _ => ctx.schema.get_udf_type(&upper).cloned(),
832    }
833}
834
835// ═══════════════════════════════════════════════════════════════════════
836// TypedFunction return type inference
837// ═══════════════════════════════════════════════════════════════════════
838
839fn infer_typed_function_type(func: &TypedFunction, ann: &TypeAnnotations) -> Option<DataType> {
840    match func {
841        // ── Date/Time → Date or Timestamp ────────────────────────────
842        TypedFunction::DateAdd { .. }
843        | TypedFunction::DateSub { .. }
844        | TypedFunction::DateTrunc { .. }
845        | TypedFunction::TsOrDsToDate { .. } => Some(DataType::Date),
846        TypedFunction::DateDiff { .. } => Some(DataType::Int),
847        TypedFunction::CurrentDate => Some(DataType::Date),
848        TypedFunction::CurrentTimestamp => Some(DataType::Timestamp {
849            precision: None,
850            with_tz: false,
851        }),
852        TypedFunction::StrToTime { .. } => Some(DataType::Timestamp {
853            precision: None,
854            with_tz: false,
855        }),
856        TypedFunction::TimeToStr { .. } => Some(DataType::Varchar(None)),
857        TypedFunction::Year { .. } | TypedFunction::Month { .. } | TypedFunction::Day { .. } => {
858            Some(DataType::Int)
859        }
860
861        // ── String → Varchar ─────────────────────────────────────────
862        TypedFunction::Trim { .. }
863        | TypedFunction::Substring { .. }
864        | TypedFunction::Upper { .. }
865        | TypedFunction::Lower { .. }
866        | TypedFunction::Initcap { .. }
867        | TypedFunction::Replace { .. }
868        | TypedFunction::Reverse { .. }
869        | TypedFunction::Left { .. }
870        | TypedFunction::Right { .. }
871        | TypedFunction::Lpad { .. }
872        | TypedFunction::Rpad { .. }
873        | TypedFunction::ConcatWs { .. } => Some(DataType::Varchar(None)),
874        TypedFunction::Length { .. } => Some(DataType::Int),
875        TypedFunction::RegexpLike { .. } => Some(DataType::Boolean),
876        TypedFunction::RegexpExtract { .. } => Some(DataType::Varchar(None)),
877        TypedFunction::RegexpReplace { .. } => Some(DataType::Varchar(None)),
878        TypedFunction::Split { .. } => {
879            Some(DataType::Array(Some(Box::new(DataType::Varchar(None)))))
880        }
881
882        // ── Aggregates ───────────────────────────────────────────────
883        TypedFunction::Count { .. } => Some(DataType::BigInt),
884        TypedFunction::Sum { expr, .. } => ann.get_type(expr.as_ref()).map(|t| coerce_sum_type(t)),
885        TypedFunction::Avg { .. } => Some(DataType::Double),
886        TypedFunction::Min { expr } | TypedFunction::Max { expr } => {
887            ann.get_type(expr.as_ref()).cloned()
888        }
889        TypedFunction::ArrayAgg { expr, .. } => {
890            let elem = ann.get_type(expr.as_ref()).cloned();
891            Some(DataType::Array(elem.map(Box::new)))
892        }
893        TypedFunction::ApproxDistinct { .. } => Some(DataType::BigInt),
894        TypedFunction::Variance { .. } | TypedFunction::Stddev { .. } => Some(DataType::Double),
895
896        // ── Array ────────────────────────────────────────────────────
897        TypedFunction::ArrayConcat { arrays } => {
898            // Type is the same as input arrays
899            arrays.first().and_then(|a| ann.get_type(a)).cloned()
900        }
901        TypedFunction::ArrayContains { .. } => Some(DataType::Boolean),
902        TypedFunction::ArraySize { .. } => Some(DataType::Int),
903        TypedFunction::Explode { expr } => {
904            // Unwrap array element type
905            match ann.get_type(expr.as_ref()) {
906                Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
907                _ => None,
908            }
909        }
910        TypedFunction::GenerateSeries { .. } => Some(DataType::Int),
911        TypedFunction::Flatten { expr } => ann.get_type(expr.as_ref()).cloned(),
912
913        // ── JSON ─────────────────────────────────────────────────────
914        TypedFunction::JSONExtract { .. } => Some(DataType::Json),
915        TypedFunction::JSONExtractScalar { .. } => Some(DataType::Varchar(None)),
916        TypedFunction::ParseJSON { .. } | TypedFunction::JSONFormat { .. } => Some(DataType::Json),
917
918        // ── Window ───────────────────────────────────────────────────
919        TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => {
920            Some(DataType::BigInt)
921        }
922        TypedFunction::NTile { .. } => Some(DataType::BigInt),
923        TypedFunction::Lead { expr, .. }
924        | TypedFunction::Lag { expr, .. }
925        | TypedFunction::FirstValue { expr }
926        | TypedFunction::LastValue { expr } => ann.get_type(expr.as_ref()).cloned(),
927
928        // ── Math ─────────────────────────────────────────────────────
929        TypedFunction::Abs { expr }
930        | TypedFunction::Ceil { expr }
931        | TypedFunction::Floor { expr } => ann.get_type(expr.as_ref()).cloned(),
932        TypedFunction::Round { expr, .. } => ann.get_type(expr.as_ref()).cloned(),
933        TypedFunction::Log { .. }
934        | TypedFunction::Ln { .. }
935        | TypedFunction::Pow { .. }
936        | TypedFunction::Sqrt { .. } => Some(DataType::Double),
937        TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => {
938            let types: Vec<&DataType> = exprs.iter().filter_map(|e| ann.get_type(e)).collect();
939            common_type(&types)
940        }
941        TypedFunction::Mod { left, right } => {
942            match (ann.get_type(left.as_ref()), ann.get_type(right.as_ref())) {
943                (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
944                (Some(l), _) => Some(l.clone()),
945                (_, Some(r)) => Some(r.clone()),
946                _ => Some(DataType::Int),
947            }
948        }
949
950        // ── Conversion ───────────────────────────────────────────────
951        TypedFunction::Hex { .. } | TypedFunction::Md5 { .. } | TypedFunction::Sha { .. } => {
952            Some(DataType::Varchar(None))
953        }
954        TypedFunction::Sha2 { .. } => Some(DataType::Varchar(None)),
955        TypedFunction::Unhex { .. } => Some(DataType::Varbinary(None)),
956    }
957}
958
959// ═══════════════════════════════════════════════════════════════════════
960// Subquery type inference
961// ═══════════════════════════════════════════════════════════════════════
962
963fn infer_subquery_type(sub: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
964    // The type of a scalar subquery is the type of its single output column
965    if let Statement::Select(sel) = sub {
966        if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
967            return ann.get_type(expr).cloned();
968        }
969    }
970    None
971}
972
973// ═══════════════════════════════════════════════════════════════════════
974// Type coercion helpers
975// ═══════════════════════════════════════════════════════════════════════
976
977/// Numeric type widening precedence (higher = wider).
978fn numeric_precedence(dt: &DataType) -> u8 {
979    match dt {
980        DataType::Boolean => 1,
981        DataType::TinyInt => 2,
982        DataType::SmallInt => 3,
983        DataType::Int | DataType::Serial => 4,
984        DataType::BigInt | DataType::BigSerial => 5,
985        DataType::Real | DataType::Float => 6,
986        DataType::Double => 7,
987        DataType::Decimal { .. } | DataType::Numeric { .. } => 8,
988        _ => 0,
989    }
990}
991
992/// Coerce two numeric types to their common wider type.
993fn coerce_numeric(left: &DataType, right: &DataType) -> DataType {
994    let lp = numeric_precedence(left);
995    let rp = numeric_precedence(right);
996    if lp == 0 && rp == 0 {
997        // Neither is numeric — fall back to left
998        return left.clone();
999    }
1000    if lp >= rp {
1001        left.clone()
1002    } else {
1003        right.clone()
1004    }
1005}
1006
1007/// Determine the return type of SUM based on input type.
1008fn coerce_sum_type(input: &DataType) -> DataType {
1009    match input {
1010        DataType::TinyInt | DataType::SmallInt | DataType::Int | DataType::BigInt => {
1011            DataType::BigInt
1012        }
1013        DataType::Float | DataType::Real => DataType::Double,
1014        DataType::Double => DataType::Double,
1015        DataType::Decimal { precision, scale } => DataType::Decimal {
1016            precision: *precision,
1017            scale: *scale,
1018        },
1019        DataType::Numeric { precision, scale } => DataType::Numeric {
1020            precision: *precision,
1021            scale: *scale,
1022        },
1023        _ => DataType::BigInt,
1024    }
1025}
1026
1027/// Find the common (widest) type among a set of types.
1028fn common_type(types: &[&DataType]) -> Option<DataType> {
1029    if types.is_empty() {
1030        return None;
1031    }
1032    let mut result = types[0];
1033    for t in &types[1..] {
1034        // Skip NULL — it doesn't contribute to the common type
1035        if **t == DataType::Null {
1036            continue;
1037        }
1038        if *result == DataType::Null {
1039            result = t;
1040            continue;
1041        }
1042        // If both are numeric, pick the wider one
1043        let lp = numeric_precedence(result);
1044        let rp = numeric_precedence(t);
1045        if lp > 0 && rp > 0 {
1046            if rp > lp {
1047                result = t;
1048            }
1049            continue;
1050        }
1051        // If both are string-like, prefer VARCHAR
1052        if is_string_type(result) && is_string_type(t) {
1053            result = if matches!(result, DataType::Text) || matches!(t, DataType::Text) {
1054                if matches!(result, DataType::Text) {
1055                    result
1056                } else {
1057                    t
1058                }
1059            } else {
1060                result // keep first
1061            };
1062            continue;
1063        }
1064        // Otherwise keep the first non-null type
1065    }
1066    Some(result.clone())
1067}
1068
1069fn is_string_type(dt: &DataType) -> bool {
1070    matches!(
1071        dt,
1072        DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String
1073    )
1074}
1075
1076// ═══════════════════════════════════════════════════════════════════════
1077// Tests
1078// ═══════════════════════════════════════════════════════════════════════
1079
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::Parser;
    use crate::schema::{MappingSchema, Schema};

    /// Build the test schema with two tables:
    /// - `users`: int, varchar(255), int, double, boolean and timestamp columns.
    /// - `orders`: int, int, DECIMAL(10, 2) and varchar(50) columns.
    fn setup_schema() -> MappingSchema {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["users"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("age".to_string(), DataType::Int),
                    ("salary".to_string(), DataType::Double),
                    ("active".to_string(), DataType::Boolean),
                    (
                        "created_at".to_string(),
                        DataType::Timestamp {
                            precision: None,
                            with_tz: false,
                        },
                    ),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    ("id".to_string(), DataType::Int),
                    ("user_id".to_string(), DataType::Int),
                    (
                        "amount".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                    ("status".to_string(), DataType::Varchar(Some(50))),
                ],
            )
            .unwrap();
        schema
    }

    /// Parse `sql` as a single statement and annotate it against `schema`.
    ///
    /// Note: the statement is moved into the returned tuple after annotation.
    /// See `test_where_clause_annotated` for why inline fields must not be
    /// read through a moved statement.
    fn parse_and_annotate(sql: &str, schema: &MappingSchema) -> (Statement, TypeAnnotations) {
        let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
        let ann = annotate_types(&stmt, schema);
        (stmt, ann)
    }

    /// Helper: get the type of the first SELECT column
    fn first_col_type(stmt: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
        if let Statement::Select(sel) = stmt {
            if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
                return ann.get_type(expr).cloned();
            }
        }
        None
    }

    // ── Literal type inference ────────────────────────────────────────

    #[test]
    fn test_number_literal_int() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT 42", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
    }

    #[test]
    fn test_number_literal_big_int() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT 9999999999", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
    }

    #[test]
    fn test_number_literal_double() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT 3.14", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
    }

    #[test]
    fn test_string_literal() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT 'hello'", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
    }

    #[test]
    fn test_boolean_literal() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT TRUE", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    #[test]
    fn test_null_literal() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT NULL", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Null));
    }

    // ── Column reference type lookup ─────────────────────────────────

    #[test]
    fn test_column_type_from_schema() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
    }

    #[test]
    fn test_qualified_column_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT users.name FROM users", &schema);
        assert_eq!(
            first_col_type(&stmt, &ann),
            Some(DataType::Varchar(Some(255)))
        );
    }

    #[test]
    fn test_aliased_table_column_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT u.salary FROM users AS u", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
    }

    // ── Binary operator type inference ───────────────────────────────

    #[test]
    fn test_int_plus_int() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id + age FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
    }

    #[test]
    fn test_int_plus_double() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id + salary FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
    }

    #[test]
    fn test_comparison_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id > 5 FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    #[test]
    fn test_and_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id > 5 AND age < 30 FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    // ── CAST type inference ──────────────────────────────────────────

    #[test]
    fn test_cast_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT CAST(id AS BIGINT) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
    }

    #[test]
    fn test_cast_to_varchar() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT CAST(id AS VARCHAR) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
    }

    // ── CASE expression ──────────────────────────────────────────────

    #[test]
    fn test_case_expression_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate(
            "SELECT CASE WHEN id > 1 THEN salary ELSE 0.0 END FROM users",
            &schema,
        );
        let t = first_col_type(&stmt, &ann);
        assert!(
            matches!(t, Some(DataType::Double)),
            "Expected Double, got {t:?}"
        );
    }

    // ── Function return types ────────────────────────────────────────

    #[test]
    fn test_count_returns_bigint() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT COUNT(*) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
    }

    #[test]
    fn test_sum_returns_bigint_for_int() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT SUM(id) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
    }

    #[test]
    fn test_avg_returns_double() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT AVG(age) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
    }

    #[test]
    fn test_min_preserves_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT MIN(salary) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
    }

    #[test]
    fn test_upper_returns_varchar() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT UPPER(name) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
    }

    #[test]
    fn test_length_returns_int() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT LENGTH(name) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
    }

    // ── Predicate types ──────────────────────────────────────────────

    #[test]
    fn test_between_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT age BETWEEN 18 AND 65 FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    #[test]
    fn test_in_list_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id IN (1, 2, 3) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    #[test]
    fn test_is_null_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT name IS NULL FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    #[test]
    fn test_like_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT name LIKE '%test%' FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    // ── Exists ───────────────────────────────────────────────────────

    #[test]
    fn test_exists_returns_boolean() {
        let schema = setup_schema();
        let (stmt, ann) =
            parse_and_annotate("SELECT EXISTS (SELECT 1 FROM orders) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
    }

    // ── Nested expressions ───────────────────────────────────────────

    #[test]
    fn test_nested_expression_propagation() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT (id + age) * salary FROM users", &schema);
        let t = first_col_type(&stmt, &ann);
        // INT + INT = INT, INT * DOUBLE = DOUBLE
        assert!(
            matches!(t, Some(DataType::Double)),
            "Expected Double, got {t:?}"
        );
    }

    // ── EXTRACT ──────────────────────────────────────────────────────

    #[test]
    fn test_extract_returns_int() {
        let schema = setup_schema();
        let (stmt, ann) =
            parse_and_annotate("SELECT EXTRACT(YEAR FROM created_at) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
    }

    // ── Multiple columns ─────────────────────────────────────────────

    #[test]
    fn test_multiple_columns_annotated() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT id, name, salary FROM users", &schema);
        let sel = match &stmt {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Collect every projection's type. A non-expression item maps to `None`,
        // so a wrong statement shape fails here instead of passing vacuously.
        let col_types: Vec<Option<&DataType>> = sel
            .columns
            .iter()
            .map(|item| match item {
                SelectItem::Expr { expr, .. } => ann.get_type(expr),
                _ => None,
            })
            .collect();
        assert_eq!(
            col_types,
            vec![
                Some(&DataType::Int),
                Some(&DataType::Varchar(Some(255))),
                Some(&DataType::Double),
            ]
        );
    }

    // ── WHERE clause annotation ──────────────────────────────────────

    #[test]
    fn test_where_clause_annotated() {
        let schema = setup_schema();
        // Don't move stmt after annotation — raw pointers for inline fields
        // (like where_clause: Option<Expr>) are invalidated on move.
        let stmt = Parser::new("SELECT id FROM users WHERE age > 21")
            .unwrap()
            .parse_statement()
            .unwrap();
        let ann = annotate_types(&stmt, &schema);
        let sel = match &stmt {
            Statement::Select(sel) => sel,
            _ => panic!("expected a SELECT statement"),
        };
        // Require the WHERE clause so the assertion cannot be skipped silently.
        let wh = sel
            .where_clause
            .as_ref()
            .expect("WHERE clause should have been parsed");
        assert_eq!(ann.get_type(wh), Some(&DataType::Boolean));
    }

    // ── Coercion rules ──────────────────────────────────────────────

    #[test]
    fn test_int_and_bigint_coercion() {
        assert_eq!(
            coerce_numeric(&DataType::Int, &DataType::BigInt),
            DataType::BigInt
        );
    }

    #[test]
    fn test_float_and_double_coercion() {
        assert_eq!(
            coerce_numeric(&DataType::Float, &DataType::Double),
            DataType::Double
        );
    }

    #[test]
    fn test_int_and_double_coercion() {
        assert_eq!(
            coerce_numeric(&DataType::Int, &DataType::Double),
            DataType::Double
        );
    }

    // ── Common type ─────────────────────────────────────────────────

    #[test]
    fn test_common_type_nulls_skipped() {
        let types = vec![&DataType::Null, &DataType::Int, &DataType::Null];
        assert_eq!(common_type(&types), Some(DataType::Int));
    }

    #[test]
    fn test_common_type_numeric_widening() {
        let types = vec![&DataType::Int, &DataType::Double, &DataType::Float];
        assert_eq!(common_type(&types), Some(DataType::Double));
    }

    #[test]
    fn test_common_type_empty() {
        let types: Vec<&DataType> = vec![];
        assert_eq!(common_type(&types), None);
    }

    // ── UDF type support ─────────────────────────────────────────────

    #[test]
    fn test_udf_return_type() {
        let mut schema = setup_schema();
        schema.add_udf("my_func", DataType::Varchar(None));
        let (stmt, ann) = parse_and_annotate("SELECT my_func(id) FROM users", &schema);
        assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
    }

    // ── Annotation count ─────────────────────────────────────────────

    #[test]
    fn test_annotations_not_empty() {
        let schema = setup_schema();
        let (_, ann) = parse_and_annotate("SELECT id, name FROM users WHERE age > 21", &schema);
        assert!(!ann.is_empty());
        // Should have at least the SELECT columns and WHERE predicate
        assert!(ann.len() >= 3);
    }

    // ── SUM of DECIMAL preserves precision ───────────────────────────

    #[test]
    fn test_sum_decimal_preserves_type() {
        let schema = setup_schema();
        let (stmt, ann) = parse_and_annotate("SELECT SUM(amount) FROM orders", &schema);
        assert_eq!(
            first_col_type(&stmt, &ann),
            Some(DataType::Decimal {
                precision: Some(10),
                scale: Some(2)
            })
        );
    }
}