1use std::collections::HashMap;
38
39use crate::ast::*;
40use crate::schema::Schema;
41
42pub struct TypeAnnotations {
52 types: HashMap<*const Expr, DataType>,
53}
54
55unsafe impl Send for TypeAnnotations {}
58unsafe impl Sync for TypeAnnotations {}
59
60impl TypeAnnotations {
61 fn new() -> Self {
62 Self {
63 types: HashMap::new(),
64 }
65 }
66
67 fn set(&mut self, expr: &Expr, dt: DataType) {
68 self.types.insert(expr as *const Expr, dt);
69 }
70
71 #[must_use]
73 pub fn get_type(&self, expr: &Expr) -> Option<&DataType> {
74 self.types.get(&(expr as *const Expr))
75 }
76
77 #[must_use]
79 pub fn len(&self) -> usize {
80 self.types.len()
81 }
82
83 #[must_use]
85 pub fn is_empty(&self) -> bool {
86 self.types.is_empty()
87 }
88}
89
90#[must_use]
102pub fn annotate_types<S: Schema>(stmt: &Statement, schema: &S) -> TypeAnnotations {
103 let mut ann = TypeAnnotations::new();
104 let mut ctx = AnnotationContext::new(schema);
105 annotate_statement(stmt, &mut ctx, &mut ann);
106 ann
107}
108
109struct AnnotationContext<'s, S: Schema> {
115 schema: &'s S,
116 table_aliases: HashMap<String, Vec<String>>,
118}
119
120impl<'s, S: Schema> AnnotationContext<'s, S> {
121 fn new(schema: &'s S) -> Self {
122 Self {
123 schema,
124 table_aliases: HashMap::new(),
125 }
126 }
127
128 fn register_table(&mut self, table_ref: &TableRef) {
130 let path = vec![table_ref.name.clone()];
131 let alias = table_ref
132 .alias
133 .as_deref()
134 .unwrap_or(&table_ref.name)
135 .to_string();
136 self.table_aliases.insert(alias, path);
137 }
138
139 fn resolve_column_type(&self, table: Option<&str>, column: &str) -> Option<DataType> {
141 if let Some(tbl) = table {
142 if let Some(path) = self.table_aliases.get(tbl) {
144 let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
145 return self.schema.get_column_type(&path_refs, column).ok();
146 }
147 return self.schema.get_column_type(&[tbl], column).ok();
149 }
150 for path in self.table_aliases.values() {
152 let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
153 if let Ok(dt) = self.schema.get_column_type(&path_refs, column) {
154 return Some(dt);
155 }
156 }
157 None
158 }
159}
160
161fn annotate_statement<S: Schema>(
166 stmt: &Statement,
167 ctx: &mut AnnotationContext<S>,
168 ann: &mut TypeAnnotations,
169) {
170 match stmt {
171 Statement::Select(sel) => annotate_select(sel, ctx, ann),
172 Statement::SetOperation(set_op) => {
173 annotate_statement(&set_op.left, ctx, ann);
174 annotate_statement(&set_op.right, ctx, ann);
175 }
176 Statement::Insert(ins) => {
177 if let InsertSource::Query(q) = &ins.source {
178 annotate_statement(q, ctx, ann);
179 }
180 for row in match &ins.source {
181 InsertSource::Values(rows) => rows.as_slice(),
182 _ => &[],
183 } {
184 for expr in row {
185 annotate_expr(expr, ctx, ann);
186 }
187 }
188 }
189 Statement::Update(upd) => {
190 for (_, expr) in &upd.assignments {
191 annotate_expr(expr, ctx, ann);
192 }
193 if let Some(wh) = &upd.where_clause {
194 annotate_expr(wh, ctx, ann);
195 }
196 }
197 Statement::Delete(del) => {
198 if let Some(wh) = &del.where_clause {
199 annotate_expr(wh, ctx, ann);
200 }
201 }
202 Statement::Expression(expr) => {
203 annotate_expr(expr, ctx, ann);
204 }
205 Statement::Explain(expl) => {
206 annotate_statement(&expl.statement, ctx, ann);
207 }
208 _ => {}
210 }
211}
212
213fn annotate_select<S: Schema>(
214 sel: &SelectStatement,
215 ctx: &mut AnnotationContext<S>,
216 ann: &mut TypeAnnotations,
217) {
218 for cte in &sel.ctes {
220 annotate_statement(&cte.query, ctx, ann);
221 }
222
223 if let Some(from) = &sel.from {
225 register_table_source(&from.source, ctx);
226 }
227 for join in &sel.joins {
228 register_table_source(&join.table, ctx);
229 }
230
231 if let Some(wh) = &sel.where_clause {
233 annotate_expr(wh, ctx, ann);
234 }
235
236 for item in &sel.columns {
238 if let SelectItem::Expr { expr, .. } = item {
239 annotate_expr(expr, ctx, ann);
240 }
241 }
242
243 for expr in &sel.group_by {
245 annotate_expr(expr, ctx, ann);
246 }
247
248 if let Some(having) = &sel.having {
250 annotate_expr(having, ctx, ann);
251 }
252
253 for ob in &sel.order_by {
255 annotate_expr(&ob.expr, ctx, ann);
256 }
257
258 if let Some(limit) = &sel.limit {
260 annotate_expr(limit, ctx, ann);
261 }
262 if let Some(offset) = &sel.offset {
263 annotate_expr(offset, ctx, ann);
264 }
265 if let Some(fetch) = &sel.fetch_first {
266 annotate_expr(fetch, ctx, ann);
267 }
268
269 if let Some(qualify) = &sel.qualify {
271 annotate_expr(qualify, ctx, ann);
272 }
273
274 for join in &sel.joins {
276 if let Some(on) = &join.on {
277 annotate_expr(on, ctx, ann);
278 }
279 }
280}
281
282fn register_table_source<S: Schema>(source: &TableSource, ctx: &mut AnnotationContext<S>) {
283 match source {
284 TableSource::Table(tref) => ctx.register_table(tref),
285 TableSource::Subquery { alias, .. } => {
286 let _ = alias;
289 }
290 TableSource::TableFunction { alias, .. } => {
291 let _ = alias;
292 }
293 TableSource::Lateral { source } => register_table_source(source, ctx),
294 TableSource::Unnest { .. } => {}
295 }
296}
297
298fn annotate_expr<S: Schema>(expr: &Expr, ctx: &AnnotationContext<S>, ann: &mut TypeAnnotations) {
303 annotate_children(expr, ctx, ann);
305
306 let dt = infer_type(expr, ctx, ann);
307 if let Some(t) = dt {
308 ann.set(expr, t);
309 }
310}
311
/// Annotates every direct sub-expression of `expr` by calling [`annotate_expr`] on it.
///
/// Nested statements (`IN (subquery)`, scalar subqueries, `EXISTS`) are annotated
/// with a fresh [`AnnotationContext`] that starts with no registered tables. Their
/// expressions therefore cannot resolve outer-query aliases through the alias map.
/// Leaf expressions have no children and are skipped.
fn annotate_children<S: Schema>(
    expr: &Expr,
    ctx: &AnnotationContext<S>,
    ann: &mut TypeAnnotations,
) {
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            annotate_expr(left, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        Expr::UnaryOp { expr: inner, .. } => annotate_expr(inner, ctx, ann),
        // Only the arguments and the FILTER clause are visited; any other fields of
        // this variant (matched by `..`) are not annotated.
        Expr::Function { args, filter, .. } => {
            for arg in args {
                annotate_expr(arg, ctx, ann);
            }
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        Expr::Between {
            expr: e, low, high, ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(low, ctx, ann);
            annotate_expr(high, ctx, ann);
        }
        Expr::InList { expr: e, list, .. } => {
            annotate_expr(e, ctx, ann);
            for item in list {
                annotate_expr(item, ctx, ann);
            }
        }
        Expr::InSubquery {
            expr: e, subquery, ..
        } => {
            annotate_expr(e, ctx, ann);
            // The subquery is its own scope: start from an empty context.
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::IsNull { expr: e, .. } | Expr::IsBool { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Like {
            expr: e,
            pattern,
            escape,
            ..
        }
        | Expr::ILike {
            expr: e,
            pattern,
            escape,
            ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(pattern, ctx, ann);
            if let Some(esc) = escape {
                annotate_expr(esc, ctx, ann);
            }
        }
        Expr::Case {
            operand,
            when_clauses,
            else_clause,
        } => {
            if let Some(op) = operand {
                annotate_expr(op, ctx, ann);
            }
            // Both the WHEN condition and its THEN result are annotated.
            for (cond, result) in when_clauses {
                annotate_expr(cond, ctx, ann);
                annotate_expr(result, ctx, ann);
            }
            if let Some(el) = else_clause {
                annotate_expr(el, ctx, ann);
            }
        }
        Expr::Nested(inner) => annotate_expr(inner, ctx, ann),
        Expr::Cast { expr: e, .. } | Expr::TryCast { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Extract { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Interval { value, .. } => annotate_expr(value, ctx, ann),
        Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
            for item in items {
                annotate_expr(item, ctx, ann);
            }
        }
        Expr::If {
            condition,
            true_val,
            false_val,
        } => {
            annotate_expr(condition, ctx, ann);
            annotate_expr(true_val, ctx, ann);
            if let Some(fv) = false_val {
                annotate_expr(fv, ctx, ann);
            }
        }
        Expr::NullIf { expr: e, r#else } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(r#else, ctx, ann);
        }
        Expr::Collate { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Alias { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::ArrayIndex { expr: e, index } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(index, ctx, ann);
        }
        Expr::JsonAccess { expr: e, path, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(path, ctx, ann);
        }
        // The body is annotated with the enclosing scope; nothing here registers the
        // lambda's own parameters. NOTE(review): references to them presumably stay
        // untyped — confirm.
        Expr::Lambda { body, .. } => annotate_expr(body, ctx, ann),
        Expr::AnyOp { expr: e, right, .. } | Expr::AllOp { expr: e, right, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        // Scalar subquery: annotated in its own, initially empty, scope.
        Expr::Subquery(sub) => {
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(sub, &mut sub_ctx, ann);
        }
        // EXISTS subquery: annotated in its own, initially empty, scope.
        Expr::Exists { subquery, .. } => {
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::TypedFunction { func, filter, .. } => {
            annotate_typed_function_children(func, ctx, ann);
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        // Leaves: nothing below them to annotate.
        Expr::Column { .. }
        | Expr::Number(_)
        | Expr::StringLiteral(_)
        | Expr::Boolean(_)
        | Expr::Null
        | Expr::Wildcard
        | Expr::Star
        | Expr::Parameter(_)
        | Expr::TypeExpr(_)
        | Expr::QualifiedWildcard { .. }
        | Expr::Default => {}
    }
}
458
459fn annotate_typed_function_children<S: Schema>(
461 func: &TypedFunction,
462 ctx: &AnnotationContext<S>,
463 ann: &mut TypeAnnotations,
464) {
465 func.walk_children(&mut |child| {
467 annotate_expr(child, ctx, ann);
468 true
469 });
470}
471
472fn infer_type<S: Schema>(
477 expr: &Expr,
478 ctx: &AnnotationContext<S>,
479 ann: &TypeAnnotations,
480) -> Option<DataType> {
481 match expr {
482 Expr::Number(s) => Some(infer_number_type(s)),
484 Expr::StringLiteral(_) => Some(DataType::Varchar(None)),
485 Expr::Boolean(_) => Some(DataType::Boolean),
486 Expr::Null => Some(DataType::Null),
487
488 Expr::Column { table, name, .. } => ctx.resolve_column_type(table.as_deref(), name),
490
491 Expr::BinaryOp { left, op, right } => {
493 infer_binary_op_type(op, ann.get_type(left), ann.get_type(right))
494 }
495
496 Expr::UnaryOp { op, expr: inner } => match op {
498 UnaryOperator::Not => Some(DataType::Boolean),
499 UnaryOperator::Minus | UnaryOperator::Plus => ann.get_type(inner).cloned(),
500 UnaryOperator::BitwiseNot => ann.get_type(inner).cloned(),
501 },
502
503 Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => Some(data_type.clone()),
505
506 Expr::Case {
508 when_clauses,
509 else_clause,
510 ..
511 } => {
512 let mut result_types: Vec<&DataType> = Vec::new();
513 for (_, result) in when_clauses {
514 if let Some(t) = ann.get_type(result) {
515 result_types.push(t);
516 }
517 }
518 if let Some(el) = else_clause {
519 if let Some(t) = ann.get_type(el.as_ref()) {
520 result_types.push(t);
521 }
522 }
523 common_type(&result_types)
524 }
525
526 Expr::If {
528 true_val,
529 false_val,
530 ..
531 } => {
532 let mut types = Vec::new();
533 if let Some(t) = ann.get_type(true_val) {
534 types.push(t);
535 }
536 if let Some(fv) = false_val {
537 if let Some(t) = ann.get_type(fv.as_ref()) {
538 types.push(t);
539 }
540 }
541 common_type(&types)
542 }
543
544 Expr::Coalesce(items) => {
546 let types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
547 common_type(&types)
548 }
549
550 Expr::NullIf { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
552
553 Expr::Function { name, args, .. } => infer_generic_function_type(name, args, ctx, ann),
555
556 Expr::TypedFunction { func, .. } => infer_typed_function_type(func, ann),
558
559 Expr::Subquery(sub) => infer_subquery_type(sub, ann),
561
562 Expr::Exists { .. } => Some(DataType::Boolean),
564
565 Expr::Between { .. }
567 | Expr::InList { .. }
568 | Expr::InSubquery { .. }
569 | Expr::IsNull { .. }
570 | Expr::IsBool { .. }
571 | Expr::Like { .. }
572 | Expr::ILike { .. }
573 | Expr::AnyOp { .. }
574 | Expr::AllOp { .. } => Some(DataType::Boolean),
575
576 Expr::Extract { .. } => Some(DataType::Int),
578
579 Expr::Interval { .. } => Some(DataType::Interval),
581
582 Expr::ArrayLiteral(items) => {
584 let elem_types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
585 let elem = common_type(&elem_types);
586 Some(DataType::Array(elem.map(Box::new)))
587 }
588
589 Expr::Tuple(items) => {
591 let types: Vec<DataType> = items
592 .iter()
593 .map(|e| ann.get_type(e).cloned().unwrap_or(DataType::Null))
594 .collect();
595 Some(DataType::Tuple(types))
596 }
597
598 Expr::ArrayIndex { expr: e, .. } => match ann.get_type(e.as_ref()) {
600 Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
601 _ => None,
602 },
603
604 Expr::JsonAccess { as_text, .. } => {
606 if *as_text {
607 Some(DataType::Text)
608 } else {
609 Some(DataType::Json)
610 }
611 }
612
613 Expr::Nested(inner) => ann.get_type(inner.as_ref()).cloned(),
615 Expr::Alias { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
616
617 Expr::Collate { .. } => Some(DataType::Varchar(None)),
619
620 Expr::TypeExpr(dt) => Some(dt.clone()),
622
623 Expr::Wildcard
625 | Expr::Star
626 | Expr::QualifiedWildcard { .. }
627 | Expr::Parameter(_)
628 | Expr::Lambda { .. }
629 | Expr::Default => None,
630 }
631}
632
633fn infer_number_type(s: &str) -> DataType {
638 if s.contains('.') || s.contains('e') || s.contains('E') {
639 DataType::Double
640 } else if let Ok(v) = s.parse::<i64>() {
641 if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
642 DataType::Int
643 } else {
644 DataType::BigInt
645 }
646 } else {
647 DataType::BigInt
649 }
650}
651
652fn infer_binary_op_type(
657 op: &BinaryOperator,
658 left: Option<&DataType>,
659 right: Option<&DataType>,
660) -> Option<DataType> {
661 use BinaryOperator::*;
662 match op {
663 Eq | Neq | Lt | Gt | LtEq | GtEq => Some(DataType::Boolean),
665
666 And | Or | Xor => Some(DataType::Boolean),
668
669 Concat => Some(DataType::Varchar(None)),
671
672 Plus | Minus | Multiply | Divide | Modulo => match (left, right) {
674 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
675 (Some(l), None) => Some(l.clone()),
676 (None, Some(r)) => Some(r.clone()),
677 (None, None) => None,
678 },
679
680 BitwiseAnd | BitwiseOr | BitwiseXor | ShiftLeft | ShiftRight => match (left, right) {
682 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
683 (Some(l), None) => Some(l.clone()),
684 (None, Some(r)) => Some(r.clone()),
685 (None, None) => Some(DataType::Int),
686 },
687
688 Arrow => Some(DataType::Json),
690 DoubleArrow => Some(DataType::Text),
691 }
692}
693
694fn infer_generic_function_type<S: Schema>(
699 name: &str,
700 args: &[Expr],
701 ctx: &AnnotationContext<S>,
702 ann: &TypeAnnotations,
703) -> Option<DataType> {
704 let upper = name.to_uppercase();
705 match upper.as_str() {
706 "COUNT" | "COUNT_BIG" => Some(DataType::BigInt),
708 "SUM" => args
709 .first()
710 .and_then(|a| ann.get_type(a))
711 .map(|t| coerce_sum_type(t)),
712 "AVG" => Some(DataType::Double),
713 "MIN" | "MAX" => args.first().and_then(|a| ann.get_type(a)).cloned(),
714 "VARIANCE" | "VAR_SAMP" | "VAR_POP" | "STDDEV" | "STDDEV_SAMP" | "STDDEV_POP" => {
715 Some(DataType::Double)
716 }
717 "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(DataType::BigInt),
718
719 "CONCAT" | "UPPER" | "LOWER" | "TRIM" | "LTRIM" | "RTRIM" | "LPAD" | "RPAD" | "REPLACE"
721 | "REVERSE" | "SUBSTRING" | "SUBSTR" | "LEFT" | "RIGHT" | "INITCAP" | "REPEAT"
722 | "TRANSLATE" | "FORMAT" | "CONCAT_WS" | "SPACE" | "REPLICATE" => {
723 Some(DataType::Varchar(None))
724 }
725 "LENGTH" | "LEN" | "CHAR_LENGTH" | "CHARACTER_LENGTH" | "OCTET_LENGTH" | "BIT_LENGTH" => {
726 Some(DataType::Int)
727 }
728 "POSITION" | "STRPOS" | "LOCATE" | "INSTR" | "CHARINDEX" => Some(DataType::Int),
729 "ASCII" => Some(DataType::Int),
730 "CHR" | "CHAR" => Some(DataType::Varchar(Some(1))),
731
732 "ABS" | "CEIL" | "CEILING" | "FLOOR" => args.first().and_then(|a| ann.get_type(a)).cloned(),
734 "ROUND" | "TRUNCATE" | "TRUNC" => args.first().and_then(|a| ann.get_type(a)).cloned(),
735 "SQRT" | "LN" | "LOG" | "LOG2" | "LOG10" | "EXP" | "POWER" | "POW" | "ACOS" | "ASIN"
736 | "ATAN" | "ATAN2" | "COS" | "SIN" | "TAN" | "COT" | "DEGREES" | "RADIANS" | "PI"
737 | "SIGN" => Some(DataType::Double),
738 "MOD" => {
739 match (
740 args.first().and_then(|a| ann.get_type(a)),
741 args.get(1).and_then(|a| ann.get_type(a)),
742 ) {
743 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
744 (Some(l), _) => Some(l.clone()),
745 (_, Some(r)) => Some(r.clone()),
746 _ => Some(DataType::Int),
747 }
748 }
749 "GREATEST" | "LEAST" => {
750 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
751 common_type(&types)
752 }
753 "RANDOM" | "RAND" => Some(DataType::Double),
754
755 "CURRENT_DATE" | "CURDATE" | "TODAY" => Some(DataType::Date),
757 "CURRENT_TIMESTAMP" | "NOW" | "GETDATE" | "SYSDATE" | "SYSTIMESTAMP" | "LOCALTIMESTAMP" => {
758 Some(DataType::Timestamp {
759 precision: None,
760 with_tz: false,
761 })
762 }
763 "CURRENT_TIME" | "CURTIME" => Some(DataType::Time { precision: None }),
764 "DATE" | "TO_DATE" | "DATE_TRUNC" | "DATE_ADD" | "DATE_SUB" | "DATEADD" | "DATESUB"
765 | "ADDDATE" | "SUBDATE" => Some(DataType::Date),
766 "TIMESTAMP" | "TO_TIMESTAMP" => Some(DataType::Timestamp {
767 precision: None,
768 with_tz: false,
769 }),
770 "YEAR" | "MONTH" | "DAY" | "DAYOFWEEK" | "DAYOFYEAR" | "HOUR" | "MINUTE" | "SECOND"
771 | "QUARTER" | "WEEK" | "EXTRACT" | "DATEDIFF" | "TIMESTAMPDIFF" | "MONTHS_BETWEEN" => {
772 Some(DataType::Int)
773 }
774
775 "CAST" | "TRY_CAST" | "SAFE_CAST" | "CONVERT" => None, "COALESCE" => {
780 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
781 common_type(&types)
782 }
783 "NULLIF" => args.first().and_then(|a| ann.get_type(a)).cloned(),
784 "IF" | "IIF" => {
785 args.get(1).and_then(|a| ann.get_type(a)).cloned()
787 }
788 "IFNULL" | "NVL" | "ISNULL" => {
789 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
790 common_type(&types)
791 }
792
793 "JSON_EXTRACT" | "JSON_QUERY" | "GET_JSON_OBJECT" => Some(DataType::Json),
795 "JSON_EXTRACT_SCALAR" | "JSON_VALUE" | "JSON_EXTRACT_PATH_TEXT" => {
796 Some(DataType::Varchar(None))
797 }
798 "TO_JSON" | "JSON_OBJECT" | "JSON_ARRAY" | "JSON_BUILD_OBJECT" | "JSON_BUILD_ARRAY" => {
799 Some(DataType::Json)
800 }
801 "PARSE_JSON" | "JSON_PARSE" | "JSON" => Some(DataType::Json),
802
803 "ARRAY_AGG" | "COLLECT_LIST" | "COLLECT_SET" => {
805 let elem = args.first().and_then(|a| ann.get_type(a)).cloned();
806 Some(DataType::Array(elem.map(Box::new)))
807 }
808 "ARRAY_LENGTH" | "ARRAY_SIZE" | "CARDINALITY" => Some(DataType::Int),
809 "ARRAY" | "ARRAY_CONSTRUCT" => {
810 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
811 let elem = common_type(&types);
812 Some(DataType::Array(elem.map(Box::new)))
813 }
814 "ARRAY_CONTAINS" | "ARRAY_POSITION" => Some(DataType::Boolean),
815
816 "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "NTILE" | "CUME_DIST" | "PERCENT_RANK" => {
818 Some(DataType::BigInt)
819 }
820
821 "MD5" | "SHA1" | "SHA" | "SHA2" | "SHA256" | "SHA512" => Some(DataType::Varchar(None)),
823 "HEX" | "TO_HEX" => Some(DataType::Varchar(None)),
824 "UNHEX" | "FROM_HEX" => Some(DataType::Varbinary(None)),
825 "CRC32" | "HASH" => Some(DataType::BigInt),
826
827 "TYPEOF" | "TYPE_OF" => Some(DataType::Varchar(None)),
829
830 _ => ctx.schema.get_udf_type(&upper).cloned(),
832 }
833}
834
835fn infer_typed_function_type(func: &TypedFunction, ann: &TypeAnnotations) -> Option<DataType> {
840 match func {
841 TypedFunction::DateAdd { .. }
843 | TypedFunction::DateSub { .. }
844 | TypedFunction::DateTrunc { .. }
845 | TypedFunction::TsOrDsToDate { .. } => Some(DataType::Date),
846 TypedFunction::DateDiff { .. } => Some(DataType::Int),
847 TypedFunction::CurrentDate => Some(DataType::Date),
848 TypedFunction::CurrentTimestamp => Some(DataType::Timestamp {
849 precision: None,
850 with_tz: false,
851 }),
852 TypedFunction::StrToTime { .. } => Some(DataType::Timestamp {
853 precision: None,
854 with_tz: false,
855 }),
856 TypedFunction::TimeToStr { .. } => Some(DataType::Varchar(None)),
857 TypedFunction::Year { .. } | TypedFunction::Month { .. } | TypedFunction::Day { .. } => {
858 Some(DataType::Int)
859 }
860
861 TypedFunction::Trim { .. }
863 | TypedFunction::Substring { .. }
864 | TypedFunction::Upper { .. }
865 | TypedFunction::Lower { .. }
866 | TypedFunction::Initcap { .. }
867 | TypedFunction::Replace { .. }
868 | TypedFunction::Reverse { .. }
869 | TypedFunction::Left { .. }
870 | TypedFunction::Right { .. }
871 | TypedFunction::Lpad { .. }
872 | TypedFunction::Rpad { .. }
873 | TypedFunction::ConcatWs { .. } => Some(DataType::Varchar(None)),
874 TypedFunction::Length { .. } => Some(DataType::Int),
875 TypedFunction::RegexpLike { .. } => Some(DataType::Boolean),
876 TypedFunction::RegexpExtract { .. } => Some(DataType::Varchar(None)),
877 TypedFunction::RegexpReplace { .. } => Some(DataType::Varchar(None)),
878 TypedFunction::Split { .. } => {
879 Some(DataType::Array(Some(Box::new(DataType::Varchar(None)))))
880 }
881
882 TypedFunction::Count { .. } => Some(DataType::BigInt),
884 TypedFunction::Sum { expr, .. } => ann.get_type(expr.as_ref()).map(|t| coerce_sum_type(t)),
885 TypedFunction::Avg { .. } => Some(DataType::Double),
886 TypedFunction::Min { expr } | TypedFunction::Max { expr } => {
887 ann.get_type(expr.as_ref()).cloned()
888 }
889 TypedFunction::ArrayAgg { expr, .. } => {
890 let elem = ann.get_type(expr.as_ref()).cloned();
891 Some(DataType::Array(elem.map(Box::new)))
892 }
893 TypedFunction::ApproxDistinct { .. } => Some(DataType::BigInt),
894 TypedFunction::Variance { .. } | TypedFunction::Stddev { .. } => Some(DataType::Double),
895
896 TypedFunction::ArrayConcat { arrays } => {
898 arrays.first().and_then(|a| ann.get_type(a)).cloned()
900 }
901 TypedFunction::ArrayContains { .. } => Some(DataType::Boolean),
902 TypedFunction::ArraySize { .. } => Some(DataType::Int),
903 TypedFunction::Explode { expr } => {
904 match ann.get_type(expr.as_ref()) {
906 Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
907 _ => None,
908 }
909 }
910 TypedFunction::GenerateSeries { .. } => Some(DataType::Int),
911 TypedFunction::Flatten { expr } => ann.get_type(expr.as_ref()).cloned(),
912
913 TypedFunction::JSONExtract { .. } => Some(DataType::Json),
915 TypedFunction::JSONExtractScalar { .. } => Some(DataType::Varchar(None)),
916 TypedFunction::ParseJSON { .. } | TypedFunction::JSONFormat { .. } => Some(DataType::Json),
917
918 TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => {
920 Some(DataType::BigInt)
921 }
922 TypedFunction::NTile { .. } => Some(DataType::BigInt),
923 TypedFunction::Lead { expr, .. }
924 | TypedFunction::Lag { expr, .. }
925 | TypedFunction::FirstValue { expr }
926 | TypedFunction::LastValue { expr } => ann.get_type(expr.as_ref()).cloned(),
927
928 TypedFunction::Abs { expr }
930 | TypedFunction::Ceil { expr }
931 | TypedFunction::Floor { expr } => ann.get_type(expr.as_ref()).cloned(),
932 TypedFunction::Round { expr, .. } => ann.get_type(expr.as_ref()).cloned(),
933 TypedFunction::Log { .. }
934 | TypedFunction::Ln { .. }
935 | TypedFunction::Pow { .. }
936 | TypedFunction::Sqrt { .. } => Some(DataType::Double),
937 TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => {
938 let types: Vec<&DataType> = exprs.iter().filter_map(|e| ann.get_type(e)).collect();
939 common_type(&types)
940 }
941 TypedFunction::Mod { left, right } => {
942 match (ann.get_type(left.as_ref()), ann.get_type(right.as_ref())) {
943 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
944 (Some(l), _) => Some(l.clone()),
945 (_, Some(r)) => Some(r.clone()),
946 _ => Some(DataType::Int),
947 }
948 }
949
950 TypedFunction::Hex { .. } | TypedFunction::Md5 { .. } | TypedFunction::Sha { .. } => {
952 Some(DataType::Varchar(None))
953 }
954 TypedFunction::Sha2 { .. } => Some(DataType::Varchar(None)),
955 TypedFunction::Unhex { .. } => Some(DataType::Varbinary(None)),
956 }
957}
958
959fn infer_subquery_type(sub: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
964 if let Statement::Select(sel) = sub {
966 if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
967 return ann.get_type(expr).cloned();
968 }
969 }
970 None
971}
972
973fn numeric_precedence(dt: &DataType) -> u8 {
979 match dt {
980 DataType::Boolean => 1,
981 DataType::TinyInt => 2,
982 DataType::SmallInt => 3,
983 DataType::Int | DataType::Serial => 4,
984 DataType::BigInt | DataType::BigSerial => 5,
985 DataType::Real | DataType::Float => 6,
986 DataType::Double => 7,
987 DataType::Decimal { .. } | DataType::Numeric { .. } => 8,
988 _ => 0,
989 }
990}
991
992fn coerce_numeric(left: &DataType, right: &DataType) -> DataType {
994 let lp = numeric_precedence(left);
995 let rp = numeric_precedence(right);
996 if lp == 0 && rp == 0 {
997 return left.clone();
999 }
1000 if lp >= rp {
1001 left.clone()
1002 } else {
1003 right.clone()
1004 }
1005}
1006
1007fn coerce_sum_type(input: &DataType) -> DataType {
1009 match input {
1010 DataType::TinyInt | DataType::SmallInt | DataType::Int | DataType::BigInt => {
1011 DataType::BigInt
1012 }
1013 DataType::Float | DataType::Real => DataType::Double,
1014 DataType::Double => DataType::Double,
1015 DataType::Decimal { precision, scale } => DataType::Decimal {
1016 precision: *precision,
1017 scale: *scale,
1018 },
1019 DataType::Numeric { precision, scale } => DataType::Numeric {
1020 precision: *precision,
1021 scale: *scale,
1022 },
1023 _ => DataType::BigInt,
1024 }
1025}
1026
1027fn common_type(types: &[&DataType]) -> Option<DataType> {
1029 if types.is_empty() {
1030 return None;
1031 }
1032 let mut result = types[0];
1033 for t in &types[1..] {
1034 if **t == DataType::Null {
1036 continue;
1037 }
1038 if *result == DataType::Null {
1039 result = t;
1040 continue;
1041 }
1042 let lp = numeric_precedence(result);
1044 let rp = numeric_precedence(t);
1045 if lp > 0 && rp > 0 {
1046 if rp > lp {
1047 result = t;
1048 }
1049 continue;
1050 }
1051 if is_string_type(result) && is_string_type(t) {
1053 result = if matches!(result, DataType::Text) || matches!(t, DataType::Text) {
1054 if matches!(result, DataType::Text) {
1055 result
1056 } else {
1057 t
1058 }
1059 } else {
1060 result };
1062 continue;
1063 }
1064 }
1066 Some(result.clone())
1067}
1068
1069fn is_string_type(dt: &DataType) -> bool {
1070 matches!(
1071 dt,
1072 DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String
1073 )
1074}
1075
1076#[cfg(test)]
1081mod tests {
1082 use super::*;
1083 use crate::dialects::Dialect;
1084 use crate::parser::Parser;
1085 use crate::schema::{MappingSchema, Schema};
1086
1087 fn setup_schema() -> MappingSchema {
1088 let mut schema = MappingSchema::new(Dialect::Ansi);
1089 schema
1090 .add_table(
1091 &["users"],
1092 vec![
1093 ("id".to_string(), DataType::Int),
1094 ("name".to_string(), DataType::Varchar(Some(255))),
1095 ("age".to_string(), DataType::Int),
1096 ("salary".to_string(), DataType::Double),
1097 ("active".to_string(), DataType::Boolean),
1098 (
1099 "created_at".to_string(),
1100 DataType::Timestamp {
1101 precision: None,
1102 with_tz: false,
1103 },
1104 ),
1105 ],
1106 )
1107 .unwrap();
1108 schema
1109 .add_table(
1110 &["orders"],
1111 vec![
1112 ("id".to_string(), DataType::Int),
1113 ("user_id".to_string(), DataType::Int),
1114 (
1115 "amount".to_string(),
1116 DataType::Decimal {
1117 precision: Some(10),
1118 scale: Some(2),
1119 },
1120 ),
1121 ("status".to_string(), DataType::Varchar(Some(50))),
1122 ],
1123 )
1124 .unwrap();
1125 schema
1126 }
1127
1128 fn parse_and_annotate(sql: &str, schema: &MappingSchema) -> (Statement, TypeAnnotations) {
1129 let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
1130 let ann = annotate_types(&stmt, schema);
1131 (stmt, ann)
1132 }
1133
1134 fn first_col_type(stmt: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
1136 if let Statement::Select(sel) = stmt {
1137 if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
1138 return ann.get_type(expr).cloned();
1139 }
1140 }
1141 None
1142 }
1143
1144 #[test]
1147 fn test_number_literal_int() {
1148 let schema = setup_schema();
1149 let (stmt, ann) = parse_and_annotate("SELECT 42", &schema);
1150 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1151 }
1152
1153 #[test]
1154 fn test_number_literal_big_int() {
1155 let schema = setup_schema();
1156 let (stmt, ann) = parse_and_annotate("SELECT 9999999999", &schema);
1157 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
1158 }
1159
1160 #[test]
1161 fn test_number_literal_double() {
1162 let schema = setup_schema();
1163 let (stmt, ann) = parse_and_annotate("SELECT 3.14", &schema);
1164 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1165 }
1166
1167 #[test]
1168 fn test_string_literal() {
1169 let schema = setup_schema();
1170 let (stmt, ann) = parse_and_annotate("SELECT 'hello'", &schema);
1171 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
1172 }
1173
1174 #[test]
1175 fn test_boolean_literal() {
1176 let schema = setup_schema();
1177 let (stmt, ann) = parse_and_annotate("SELECT TRUE", &schema);
1178 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1179 }
1180
1181 #[test]
1182 fn test_null_literal() {
1183 let schema = setup_schema();
1184 let (stmt, ann) = parse_and_annotate("SELECT NULL", &schema);
1185 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Null));
1186 }
1187
1188 #[test]
1191 fn test_column_type_from_schema() {
1192 let schema = setup_schema();
1193 let (stmt, ann) = parse_and_annotate("SELECT id FROM users", &schema);
1194 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1195 }
1196
1197 #[test]
1198 fn test_qualified_column_type() {
1199 let schema = setup_schema();
1200 let (stmt, ann) = parse_and_annotate("SELECT users.name FROM users", &schema);
1201 assert_eq!(
1202 first_col_type(&stmt, &ann),
1203 Some(DataType::Varchar(Some(255)))
1204 );
1205 }
1206
1207 #[test]
1208 fn test_aliased_table_column_type() {
1209 let schema = setup_schema();
1210 let (stmt, ann) = parse_and_annotate("SELECT u.salary FROM users AS u", &schema);
1211 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1212 }
1213
1214 #[test]
1217 fn test_int_plus_int() {
1218 let schema = setup_schema();
1219 let (stmt, ann) = parse_and_annotate("SELECT id + age FROM users", &schema);
1220 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1221 }
1222
1223 #[test]
1224 fn test_int_plus_double() {
1225 let schema = setup_schema();
1226 let (stmt, ann) = parse_and_annotate("SELECT id + salary FROM users", &schema);
1227 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1228 }
1229
1230 #[test]
1231 fn test_comparison_returns_boolean() {
1232 let schema = setup_schema();
1233 let (stmt, ann) = parse_and_annotate("SELECT id > 5 FROM users", &schema);
1234 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1235 }
1236
1237 #[test]
1238 fn test_and_returns_boolean() {
1239 let schema = setup_schema();
1240 let (stmt, ann) = parse_and_annotate("SELECT id > 5 AND age < 30 FROM users", &schema);
1241 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1242 }
1243
1244 #[test]
1247 fn test_cast_type() {
1248 let schema = setup_schema();
1249 let (stmt, ann) = parse_and_annotate("SELECT CAST(id AS BIGINT) FROM users", &schema);
1250 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
1251 }
1252
1253 #[test]
1254 fn test_cast_to_varchar() {
1255 let schema = setup_schema();
1256 let (stmt, ann) = parse_and_annotate("SELECT CAST(id AS VARCHAR) FROM users", &schema);
1257 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
1258 }
1259
1260 #[test]
1263 fn test_case_expression_type() {
1264 let schema = setup_schema();
1265 let (stmt, ann) = parse_and_annotate(
1266 "SELECT CASE WHEN id > 1 THEN salary ELSE 0.0 END FROM users",
1267 &schema,
1268 );
1269 let t = first_col_type(&stmt, &ann);
1270 assert!(
1271 matches!(t, Some(DataType::Double)),
1272 "Expected Double, got {t:?}"
1273 );
1274 }
1275
1276 #[test]
1279 fn test_count_returns_bigint() {
1280 let schema = setup_schema();
1281 let (stmt, ann) = parse_and_annotate("SELECT COUNT(*) FROM users", &schema);
1282 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
1283 }
1284
1285 #[test]
1286 fn test_sum_returns_bigint_for_int() {
1287 let schema = setup_schema();
1288 let (stmt, ann) = parse_and_annotate("SELECT SUM(id) FROM users", &schema);
1289 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
1290 }
1291
1292 #[test]
1293 fn test_avg_returns_double() {
1294 let schema = setup_schema();
1295 let (stmt, ann) = parse_and_annotate("SELECT AVG(age) FROM users", &schema);
1296 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1297 }
1298
1299 #[test]
1300 fn test_min_preserves_type() {
1301 let schema = setup_schema();
1302 let (stmt, ann) = parse_and_annotate("SELECT MIN(salary) FROM users", &schema);
1303 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1304 }
1305
1306 #[test]
1307 fn test_upper_returns_varchar() {
1308 let schema = setup_schema();
1309 let (stmt, ann) = parse_and_annotate("SELECT UPPER(name) FROM users", &schema);
1310 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
1311 }
1312
1313 #[test]
1314 fn test_length_returns_int() {
1315 let schema = setup_schema();
1316 let (stmt, ann) = parse_and_annotate("SELECT LENGTH(name) FROM users", &schema);
1317 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1318 }
1319
1320 #[test]
1323 fn test_between_returns_boolean() {
1324 let schema = setup_schema();
1325 let (stmt, ann) = parse_and_annotate("SELECT age BETWEEN 18 AND 65 FROM users", &schema);
1326 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1327 }
1328
1329 #[test]
1330 fn test_in_list_returns_boolean() {
1331 let schema = setup_schema();
1332 let (stmt, ann) = parse_and_annotate("SELECT id IN (1, 2, 3) FROM users", &schema);
1333 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1334 }
1335
1336 #[test]
1337 fn test_is_null_returns_boolean() {
1338 let schema = setup_schema();
1339 let (stmt, ann) = parse_and_annotate("SELECT name IS NULL FROM users", &schema);
1340 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1341 }
1342
1343 #[test]
1344 fn test_like_returns_boolean() {
1345 let schema = setup_schema();
1346 let (stmt, ann) = parse_and_annotate("SELECT name LIKE '%test%' FROM users", &schema);
1347 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1348 }
1349
1350 #[test]
1353 fn test_exists_returns_boolean() {
1354 let schema = setup_schema();
1355 let (stmt, ann) =
1356 parse_and_annotate("SELECT EXISTS (SELECT 1 FROM orders) FROM users", &schema);
1357 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1358 }
1359
1360 #[test]
1363 fn test_nested_expression_propagation() {
1364 let schema = setup_schema();
1365 let (stmt, ann) = parse_and_annotate("SELECT (id + age) * salary FROM users", &schema);
1366 let t = first_col_type(&stmt, &ann);
1367 assert!(
1369 matches!(t, Some(DataType::Double)),
1370 "Expected Double, got {t:?}"
1371 );
1372 }
1373
1374 #[test]
1377 fn test_extract_returns_int() {
1378 let schema = setup_schema();
1379 let (stmt, ann) =
1380 parse_and_annotate("SELECT EXTRACT(YEAR FROM created_at) FROM users", &schema);
1381 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1382 }
1383
1384 #[test]
1387 fn test_multiple_columns_annotated() {
1388 let schema = setup_schema();
1389 let (stmt, ann) = parse_and_annotate("SELECT id, name, salary FROM users", &schema);
1390 if let Statement::Select(sel) = &stmt {
1391 assert_eq!(sel.columns.len(), 3);
1392 if let SelectItem::Expr { expr, .. } = &sel.columns[0] {
1394 assert_eq!(ann.get_type(expr), Some(&DataType::Int));
1395 }
1396 if let SelectItem::Expr { expr, .. } = &sel.columns[1] {
1398 assert_eq!(ann.get_type(expr), Some(&DataType::Varchar(Some(255))));
1399 }
1400 if let SelectItem::Expr { expr, .. } = &sel.columns[2] {
1402 assert_eq!(ann.get_type(expr), Some(&DataType::Double));
1403 }
1404 }
1405 }
1406
1407 #[test]
1410 fn test_where_clause_annotated() {
1411 let schema = setup_schema();
1412 let stmt = Parser::new("SELECT id FROM users WHERE age > 21")
1415 .unwrap()
1416 .parse_statement()
1417 .unwrap();
1418 let ann = annotate_types(&stmt, &schema);
1419 if let Statement::Select(sel) = &stmt {
1420 if let Some(wh) = &sel.where_clause {
1421 assert_eq!(ann.get_type(wh), Some(&DataType::Boolean));
1422 }
1423 }
1424 }
1425
1426 #[test]
1429 fn test_int_and_bigint_coercion() {
1430 assert_eq!(
1431 coerce_numeric(&DataType::Int, &DataType::BigInt),
1432 DataType::BigInt
1433 );
1434 }
1435
1436 #[test]
1437 fn test_float_and_double_coercion() {
1438 assert_eq!(
1439 coerce_numeric(&DataType::Float, &DataType::Double),
1440 DataType::Double
1441 );
1442 }
1443
1444 #[test]
1445 fn test_int_and_double_coercion() {
1446 assert_eq!(
1447 coerce_numeric(&DataType::Int, &DataType::Double),
1448 DataType::Double
1449 );
1450 }
1451
1452 #[test]
1455 fn test_common_type_nulls_skipped() {
1456 let types = vec![&DataType::Null, &DataType::Int, &DataType::Null];
1457 assert_eq!(common_type(&types), Some(DataType::Int));
1458 }
1459
1460 #[test]
1461 fn test_common_type_numeric_widening() {
1462 let types = vec![&DataType::Int, &DataType::Double, &DataType::Float];
1463 assert_eq!(common_type(&types), Some(DataType::Double));
1464 }
1465
1466 #[test]
1467 fn test_common_type_empty() {
1468 let types: Vec<&DataType> = vec![];
1469 assert_eq!(common_type(&types), None);
1470 }
1471
1472 #[test]
1475 fn test_udf_return_type() {
1476 let mut schema = setup_schema();
1477 schema.add_udf("my_func", DataType::Varchar(None));
1478 let (stmt, ann) = parse_and_annotate("SELECT my_func(id) FROM users", &schema);
1479 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
1480 }
1481
1482 #[test]
1485 fn test_annotations_not_empty() {
1486 let schema = setup_schema();
1487 let (_, ann) = parse_and_annotate("SELECT id, name FROM users WHERE age > 21", &schema);
1488 assert!(!ann.is_empty());
1489 assert!(ann.len() >= 3);
1491 }
1492
1493 #[test]
1496 fn test_sum_decimal_preserves_type() {
1497 let schema = setup_schema();
1498 let (stmt, ann) = parse_and_annotate("SELECT SUM(amount) FROM orders", &schema);
1499 assert_eq!(
1500 first_col_type(&stmt, &ann),
1501 Some(DataType::Decimal {
1502 precision: Some(10),
1503 scale: Some(2)
1504 })
1505 );
1506 }
1507}