1use std::collections::HashMap;
38
39use crate::ast::*;
40use crate::schema::Schema;
41
/// Side table mapping AST expression nodes to their inferred [`DataType`].
///
/// Entries are keyed by the *address* of each `Expr`, so a lookup only makes
/// sense while the annotated AST is still alive and has not been moved or
/// rebuilt since annotation. The stored pointers are used purely as identity
/// keys; nothing in this type dereferences them.
pub struct TypeAnnotations {
    /// Inferred type per expression node, keyed by the node's address.
    types: HashMap<*const Expr, DataType>,
}
54
// SAFETY: the `*const Expr` keys are only hashed and compared as opaque
// addresses; `TypeAnnotations` never dereferences them, so sharing or sending
// the map across threads cannot cause a data race through those pointers.
// NOTE(review): this also relies on `DataType` itself being `Send + Sync` —
// confirm against its definition.
unsafe impl Send for TypeAnnotations {}
unsafe impl Sync for TypeAnnotations {}
59
60impl TypeAnnotations {
61 fn new() -> Self {
62 Self {
63 types: HashMap::new(),
64 }
65 }
66
67 fn set(&mut self, expr: &Expr, dt: DataType) {
68 self.types.insert(expr as *const Expr, dt);
69 }
70
71 #[must_use]
73 pub fn get_type(&self, expr: &Expr) -> Option<&DataType> {
74 self.types.get(&(expr as *const Expr))
75 }
76
77 #[must_use]
79 pub fn len(&self) -> usize {
80 self.types.len()
81 }
82
83 #[must_use]
85 pub fn is_empty(&self) -> bool {
86 self.types.is_empty()
87 }
88}
89
90#[must_use]
102pub fn annotate_types<S: Schema>(stmt: &Statement, schema: &S) -> TypeAnnotations {
103 let mut ann = TypeAnnotations::new();
104 let mut ctx = AnnotationContext::new(schema);
105 annotate_statement(stmt, &mut ctx, &mut ann);
106 ann
107}
108
109struct AnnotationContext<'s, S: Schema> {
115 schema: &'s S,
116 table_aliases: HashMap<String, Vec<String>>,
118}
119
120impl<'s, S: Schema> AnnotationContext<'s, S> {
121 fn new(schema: &'s S) -> Self {
122 Self {
123 schema,
124 table_aliases: HashMap::new(),
125 }
126 }
127
128 fn register_table(&mut self, table_ref: &TableRef) {
130 let path = vec![table_ref.name.clone()];
131 let alias = table_ref
132 .alias
133 .as_deref()
134 .unwrap_or(&table_ref.name)
135 .to_string();
136 self.table_aliases.insert(alias, path);
137 }
138
139 fn resolve_column_type(&self, table: Option<&str>, column: &str) -> Option<DataType> {
141 if let Some(tbl) = table {
142 if let Some(path) = self.table_aliases.get(tbl) {
144 let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
145 return self.schema.get_column_type(&path_refs, column).ok();
146 }
147 return self.schema.get_column_type(&[tbl], column).ok();
149 }
150 for path in self.table_aliases.values() {
152 let path_refs: Vec<&str> = path.iter().map(String::as_str).collect();
153 if let Ok(dt) = self.schema.get_column_type(&path_refs, column) {
154 return Some(dt);
155 }
156 }
157 None
158 }
159}
160
161fn annotate_statement<S: Schema>(
166 stmt: &Statement,
167 ctx: &mut AnnotationContext<S>,
168 ann: &mut TypeAnnotations,
169) {
170 match stmt {
171 Statement::Select(sel) => annotate_select(sel, ctx, ann),
172 Statement::SetOperation(set_op) => {
173 annotate_statement(&set_op.left, ctx, ann);
174 annotate_statement(&set_op.right, ctx, ann);
175 }
176 Statement::Insert(ins) => {
177 if let InsertSource::Query(q) = &ins.source {
178 annotate_statement(q, ctx, ann);
179 }
180 for row in match &ins.source {
181 InsertSource::Values(rows) => rows.as_slice(),
182 _ => &[],
183 } {
184 for expr in row {
185 annotate_expr(expr, ctx, ann);
186 }
187 }
188 }
189 Statement::Update(upd) => {
190 for (_, expr) in &upd.assignments {
191 annotate_expr(expr, ctx, ann);
192 }
193 if let Some(wh) = &upd.where_clause {
194 annotate_expr(wh, ctx, ann);
195 }
196 }
197 Statement::Delete(del) => {
198 if let Some(wh) = &del.where_clause {
199 annotate_expr(wh, ctx, ann);
200 }
201 }
202 Statement::Expression(expr) => {
203 annotate_expr(expr, ctx, ann);
204 }
205 Statement::Explain(expl) => {
206 annotate_statement(&expl.statement, ctx, ann);
207 }
208 _ => {}
210 }
211}
212
213fn annotate_select<S: Schema>(
214 sel: &SelectStatement,
215 ctx: &mut AnnotationContext<S>,
216 ann: &mut TypeAnnotations,
217) {
218 for cte in &sel.ctes {
220 annotate_statement(&cte.query, ctx, ann);
221 }
222
223 if let Some(from) = &sel.from {
225 register_table_source(&from.source, ctx);
226 }
227 for join in &sel.joins {
228 register_table_source(&join.table, ctx);
229 }
230
231 if let Some(wh) = &sel.where_clause {
233 annotate_expr(wh, ctx, ann);
234 }
235
236 for item in &sel.columns {
238 if let SelectItem::Expr { expr, .. } = item {
239 annotate_expr(expr, ctx, ann);
240 }
241 }
242
243 for expr in &sel.group_by {
245 annotate_expr(expr, ctx, ann);
246 }
247
248 if let Some(having) = &sel.having {
250 annotate_expr(having, ctx, ann);
251 }
252
253 for ob in &sel.order_by {
255 annotate_expr(&ob.expr, ctx, ann);
256 }
257
258 if let Some(limit) = &sel.limit {
260 annotate_expr(limit, ctx, ann);
261 }
262 if let Some(offset) = &sel.offset {
263 annotate_expr(offset, ctx, ann);
264 }
265 if let Some(fetch) = &sel.fetch_first {
266 annotate_expr(fetch, ctx, ann);
267 }
268
269 if let Some(qualify) = &sel.qualify {
271 annotate_expr(qualify, ctx, ann);
272 }
273
274 for join in &sel.joins {
276 if let Some(on) = &join.on {
277 annotate_expr(on, ctx, ann);
278 }
279 }
280}
281
282fn register_table_source<S: Schema>(source: &TableSource, ctx: &mut AnnotationContext<S>) {
283 match source {
284 TableSource::Table(tref) => ctx.register_table(tref),
285 TableSource::Subquery { alias, .. } => {
286 let _ = alias;
289 }
290 TableSource::TableFunction { alias, .. } => {
291 let _ = alias;
292 }
293 TableSource::Lateral { source } => register_table_source(source, ctx),
294 TableSource::Pivot { source, .. } | TableSource::Unpivot { source, .. } => {
295 register_table_source(source, ctx);
296 }
297 TableSource::Unnest { .. } => {}
298 }
299}
300
301fn annotate_expr<S: Schema>(expr: &Expr, ctx: &AnnotationContext<S>, ann: &mut TypeAnnotations) {
306 annotate_children(expr, ctx, ann);
308
309 let dt = infer_type(expr, ctx, ann);
310 if let Some(t) = dt {
311 ann.set(expr, t);
312 }
313}
314
/// Recursively annotates the direct sub-expressions of `expr`.
///
/// Each child is passed to `annotate_expr`, so by the time this returns every
/// visited child already has its type recorded in `ann` (when one could be
/// inferred). Nested statements (`IN (subquery)`, scalar subqueries, `EXISTS`)
/// are annotated with a fresh context that shares only the schema.
fn annotate_children<S: Schema>(
    expr: &Expr,
    ctx: &AnnotationContext<S>,
    ann: &mut TypeAnnotations,
) {
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            annotate_expr(left, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        Expr::UnaryOp { expr: inner, .. } => annotate_expr(inner, ctx, ann),
        Expr::Function { args, filter, .. } => {
            for arg in args {
                annotate_expr(arg, ctx, ann);
            }
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        Expr::Between {
            expr: e, low, high, ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(low, ctx, ann);
            annotate_expr(high, ctx, ann);
        }
        Expr::InList { expr: e, list, .. } => {
            annotate_expr(e, ctx, ann);
            for item in list {
                annotate_expr(item, ctx, ann);
            }
        }
        Expr::InSubquery {
            expr: e, subquery, ..
        } => {
            annotate_expr(e, ctx, ann);
            // The subquery gets its own table scope; only the schema is shared.
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::IsNull { expr: e, .. } | Expr::IsBool { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Like {
            expr: e,
            pattern,
            escape,
            ..
        }
        | Expr::ILike {
            expr: e,
            pattern,
            escape,
            ..
        } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(pattern, ctx, ann);
            if let Some(esc) = escape {
                annotate_expr(esc, ctx, ann);
            }
        }
        Expr::Case {
            operand,
            when_clauses,
            else_clause,
        } => {
            if let Some(op) = operand {
                annotate_expr(op, ctx, ann);
            }
            for (cond, result) in when_clauses {
                annotate_expr(cond, ctx, ann);
                annotate_expr(result, ctx, ann);
            }
            if let Some(el) = else_clause {
                annotate_expr(el, ctx, ann);
            }
        }
        Expr::Nested(inner) => annotate_expr(inner, ctx, ann),
        Expr::Cast { expr: e, .. } | Expr::TryCast { expr: e, .. } => {
            annotate_expr(e, ctx, ann);
        }
        Expr::Extract { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Interval { value, .. } => annotate_expr(value, ctx, ann),
        Expr::ArrayLiteral(items) | Expr::Tuple(items) | Expr::Coalesce(items) => {
            for item in items {
                annotate_expr(item, ctx, ann);
            }
        }
        Expr::If {
            condition,
            true_val,
            false_val,
        } => {
            annotate_expr(condition, ctx, ann);
            annotate_expr(true_val, ctx, ann);
            if let Some(fv) = false_val {
                annotate_expr(fv, ctx, ann);
            }
        }
        Expr::NullIf { expr: e, r#else } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(r#else, ctx, ann);
        }
        Expr::Collate { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::Alias { expr: e, .. } => annotate_expr(e, ctx, ann),
        Expr::ArrayIndex { expr: e, index } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(index, ctx, ann);
        }
        Expr::JsonAccess { expr: e, path, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(path, ctx, ann);
        }
        Expr::Lambda { body, .. } => annotate_expr(body, ctx, ann),
        Expr::AnyOp { expr: e, right, .. } | Expr::AllOp { expr: e, right, .. } => {
            annotate_expr(e, ctx, ann);
            annotate_expr(right, ctx, ann);
        }
        Expr::Subquery(sub) => {
            // Fresh table scope for the scalar subquery.
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(sub, &mut sub_ctx, ann);
        }
        Expr::Exists { subquery, .. } => {
            // Fresh table scope for the EXISTS subquery.
            let mut sub_ctx = AnnotationContext::new(ctx.schema);
            annotate_statement(subquery, &mut sub_ctx, ann);
        }
        Expr::TypedFunction { func, filter, .. } => {
            annotate_typed_function_children(func, ctx, ann);
            if let Some(f) = filter {
                annotate_expr(f, ctx, ann);
            }
        }
        Expr::Cube { exprs } | Expr::Rollup { exprs } | Expr::GroupingSets { sets: exprs } => {
            for item in exprs {
                annotate_expr(item, ctx, ann);
            }
        }
        // Treated as leaves: nothing below these nodes is visited.
        // NOTE(review): `Commented` and `SimilarTo` are matched with `{ .. }`,
        // so any sub-expressions they may carry are not annotated — confirm
        // that this is intended.
        Expr::Column { .. }
        | Expr::Number(_)
        | Expr::StringLiteral(_)
        | Expr::NationalStringLiteral(_)
        | Expr::Boolean(_)
        | Expr::Null
        | Expr::Wildcard
        | Expr::Star
        | Expr::Parameter(_)
        | Expr::TypeExpr(_)
        | Expr::QualifiedWildcard { .. }
        | Expr::Default
        | Expr::Commented { .. }
        | Expr::SimilarTo { .. } => {}
    }
}
469
/// Annotates every child expression of a typed function.
///
/// Child discovery is delegated to `TypedFunction::walk_children`; each child
/// it hands to the callback is annotated with `annotate_expr`.
fn annotate_typed_function_children<S: Schema>(
    func: &TypedFunction,
    ctx: &AnnotationContext<S>,
    ann: &mut TypeAnnotations,
) {
    func.walk_children(&mut |child| {
        annotate_expr(child, ctx, ann);
        // Presumably `true` tells `walk_children` to keep visiting — verify
        // against its definition.
        true
    });
}
482
483fn infer_type<S: Schema>(
488 expr: &Expr,
489 ctx: &AnnotationContext<S>,
490 ann: &TypeAnnotations,
491) -> Option<DataType> {
492 match expr {
493 Expr::Number(s) => Some(infer_number_type(s)),
495 Expr::StringLiteral(_) | Expr::NationalStringLiteral(_) => Some(DataType::Varchar(None)),
496 Expr::Boolean(_) => Some(DataType::Boolean),
497 Expr::Null => Some(DataType::Null),
498
499 Expr::Column { table, name, .. } => ctx.resolve_column_type(table.as_deref(), name),
501
502 Expr::BinaryOp { left, op, right } => {
504 infer_binary_op_type(op, ann.get_type(left), ann.get_type(right))
505 }
506
507 Expr::UnaryOp { op, expr: inner } => match op {
509 UnaryOperator::Not => Some(DataType::Boolean),
510 UnaryOperator::Minus | UnaryOperator::Plus => ann.get_type(inner).cloned(),
511 UnaryOperator::BitwiseNot => ann.get_type(inner).cloned(),
512 },
513
514 Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => Some(data_type.clone()),
516
517 Expr::Case {
519 when_clauses,
520 else_clause,
521 ..
522 } => {
523 let mut result_types: Vec<&DataType> = Vec::new();
524 for (_, result) in when_clauses {
525 if let Some(t) = ann.get_type(result) {
526 result_types.push(t);
527 }
528 }
529 if let Some(el) = else_clause {
530 if let Some(t) = ann.get_type(el.as_ref()) {
531 result_types.push(t);
532 }
533 }
534 common_type(&result_types)
535 }
536
537 Expr::If {
539 true_val,
540 false_val,
541 ..
542 } => {
543 let mut types = Vec::new();
544 if let Some(t) = ann.get_type(true_val) {
545 types.push(t);
546 }
547 if let Some(fv) = false_val {
548 if let Some(t) = ann.get_type(fv.as_ref()) {
549 types.push(t);
550 }
551 }
552 common_type(&types)
553 }
554
555 Expr::Coalesce(items) => {
557 let types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
558 common_type(&types)
559 }
560
561 Expr::NullIf { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
563
564 Expr::Function { name, args, .. } => infer_generic_function_type(name, args, ctx, ann),
566
567 Expr::TypedFunction { func, .. } => infer_typed_function_type(func, ann),
569
570 Expr::Subquery(sub) => infer_subquery_type(sub, ann),
572
573 Expr::Exists { .. } => Some(DataType::Boolean),
575
576 Expr::Between { .. }
578 | Expr::InList { .. }
579 | Expr::InSubquery { .. }
580 | Expr::IsNull { .. }
581 | Expr::IsBool { .. }
582 | Expr::Like { .. }
583 | Expr::ILike { .. }
584 | Expr::AnyOp { .. }
585 | Expr::AllOp { .. } => Some(DataType::Boolean),
586
587 Expr::Extract { .. } => Some(DataType::Int),
589
590 Expr::Interval { .. } => Some(DataType::Interval),
592
593 Expr::ArrayLiteral(items) => {
595 let elem_types: Vec<&DataType> = items.iter().filter_map(|e| ann.get_type(e)).collect();
596 let elem = common_type(&elem_types);
597 Some(DataType::Array(elem.map(Box::new)))
598 }
599
600 Expr::Tuple(items) => {
602 let types: Vec<DataType> = items
603 .iter()
604 .map(|e| ann.get_type(e).cloned().unwrap_or(DataType::Null))
605 .collect();
606 Some(DataType::Tuple(types))
607 }
608
609 Expr::ArrayIndex { expr: e, .. } => match ann.get_type(e.as_ref()) {
611 Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
612 _ => None,
613 },
614
615 Expr::JsonAccess { as_text, .. } => {
617 if *as_text {
618 Some(DataType::Text)
619 } else {
620 Some(DataType::Json)
621 }
622 }
623
624 Expr::Nested(inner) => ann.get_type(inner.as_ref()).cloned(),
626 Expr::Alias { expr: e, .. } => ann.get_type(e.as_ref()).cloned(),
627
628 Expr::Collate { .. } => Some(DataType::Varchar(None)),
630
631 Expr::TypeExpr(dt) => Some(dt.clone()),
633
634 Expr::Wildcard
636 | Expr::Star
637 | Expr::QualifiedWildcard { .. }
638 | Expr::Parameter(_)
639 | Expr::Lambda { .. }
640 | Expr::Default
641 | Expr::Cube { .. }
642 | Expr::Rollup { .. }
643 | Expr::GroupingSets { .. }
644 | Expr::SimilarTo { .. }
645 | Expr::Commented { .. } => None,
646 }
647}
648
649fn infer_number_type(s: &str) -> DataType {
654 if s.contains('.') || s.contains('e') || s.contains('E') {
655 DataType::Double
656 } else if let Ok(v) = s.parse::<i64>() {
657 if v >= i32::MIN as i64 && v <= i32::MAX as i64 {
658 DataType::Int
659 } else {
660 DataType::BigInt
661 }
662 } else {
663 DataType::BigInt
665 }
666}
667
668fn infer_binary_op_type(
673 op: &BinaryOperator,
674 left: Option<&DataType>,
675 right: Option<&DataType>,
676) -> Option<DataType> {
677 use BinaryOperator::*;
678 match op {
679 Eq | Neq | Lt | Gt | LtEq | GtEq | AtArrow | ArrowAt => Some(DataType::Boolean),
681
682 And | Or | Xor => Some(DataType::Boolean),
684
685 Concat => Some(DataType::Varchar(None)),
687
688 Plus | Minus | Multiply | Divide | Modulo => match (left, right) {
690 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
691 (Some(l), None) => Some(l.clone()),
692 (None, Some(r)) => Some(r.clone()),
693 (None, None) => None,
694 },
695
696 BitwiseAnd | BitwiseOr | BitwiseXor | ShiftLeft | ShiftRight => match (left, right) {
698 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
699 (Some(l), None) => Some(l.clone()),
700 (None, Some(r)) => Some(r.clone()),
701 (None, None) => Some(DataType::Int),
702 },
703
704 Arrow => Some(DataType::Json),
706 DoubleArrow => Some(DataType::Text),
707 }
708}
709
710fn infer_generic_function_type<S: Schema>(
715 name: &str,
716 args: &[Expr],
717 ctx: &AnnotationContext<S>,
718 ann: &TypeAnnotations,
719) -> Option<DataType> {
720 let upper = name.to_uppercase();
721 match upper.as_str() {
722 "COUNT" | "COUNT_BIG" => Some(DataType::BigInt),
724 "SUM" => args
725 .first()
726 .and_then(|a| ann.get_type(a))
727 .map(|t| coerce_sum_type(t)),
728 "AVG" => Some(DataType::Double),
729 "MIN" | "MAX" => args.first().and_then(|a| ann.get_type(a)).cloned(),
730 "VARIANCE" | "VAR_SAMP" | "VAR_POP" | "STDDEV" | "STDDEV_SAMP" | "STDDEV_POP" => {
731 Some(DataType::Double)
732 }
733 "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(DataType::BigInt),
734
735 "CONCAT" | "UPPER" | "LOWER" | "TRIM" | "LTRIM" | "RTRIM" | "LPAD" | "RPAD" | "REPLACE"
737 | "REVERSE" | "SUBSTRING" | "SUBSTR" | "LEFT" | "RIGHT" | "INITCAP" | "REPEAT"
738 | "TRANSLATE" | "FORMAT" | "CONCAT_WS" | "SPACE" | "REPLICATE" => {
739 Some(DataType::Varchar(None))
740 }
741 "LENGTH" | "LEN" | "CHAR_LENGTH" | "CHARACTER_LENGTH" | "OCTET_LENGTH" | "BIT_LENGTH" => {
742 Some(DataType::Int)
743 }
744 "POSITION" | "STRPOS" | "LOCATE" | "INSTR" | "CHARINDEX" => Some(DataType::Int),
745 "ASCII" => Some(DataType::Int),
746 "CHR" | "CHAR" => Some(DataType::Varchar(Some(1))),
747
748 "ABS" | "CEIL" | "CEILING" | "FLOOR" => args.first().and_then(|a| ann.get_type(a)).cloned(),
750 "ROUND" | "TRUNCATE" | "TRUNC" => args.first().and_then(|a| ann.get_type(a)).cloned(),
751 "SQRT" | "LN" | "LOG" | "LOG2" | "LOG10" | "EXP" | "POWER" | "POW" | "ACOS" | "ASIN"
752 | "ATAN" | "ATAN2" | "COS" | "SIN" | "TAN" | "COT" | "DEGREES" | "RADIANS" | "PI"
753 | "SIGN" => Some(DataType::Double),
754 "MOD" => {
755 match (
756 args.first().and_then(|a| ann.get_type(a)),
757 args.get(1).and_then(|a| ann.get_type(a)),
758 ) {
759 (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
760 (Some(l), _) => Some(l.clone()),
761 (_, Some(r)) => Some(r.clone()),
762 _ => Some(DataType::Int),
763 }
764 }
765 "GREATEST" | "LEAST" => {
766 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
767 common_type(&types)
768 }
769 "RANDOM" | "RAND" => Some(DataType::Double),
770
771 "CURRENT_DATE" | "CURDATE" | "TODAY" => Some(DataType::Date),
773 "CURRENT_TIMESTAMP" | "NOW" | "GETDATE" | "SYSDATE" | "SYSTIMESTAMP" | "LOCALTIMESTAMP" => {
774 Some(DataType::Timestamp {
775 precision: None,
776 with_tz: false,
777 })
778 }
779 "CURRENT_TIME" | "CURTIME" => Some(DataType::Time { precision: None }),
780 "DATE" | "TO_DATE" | "DATE_TRUNC" | "DATE_ADD" | "DATE_SUB" | "DATEADD" | "DATESUB"
781 | "ADDDATE" | "SUBDATE" => Some(DataType::Date),
782 "TIMESTAMP" | "TO_TIMESTAMP" => Some(DataType::Timestamp {
783 precision: None,
784 with_tz: false,
785 }),
786 "YEAR" | "MONTH" | "DAY" | "DAYOFWEEK" | "DAYOFYEAR" | "HOUR" | "MINUTE" | "SECOND"
787 | "QUARTER" | "WEEK" | "EXTRACT" | "DATEDIFF" | "TIMESTAMPDIFF" | "MONTHS_BETWEEN" => {
788 Some(DataType::Int)
789 }
790
791 "CAST" | "TRY_CAST" | "SAFE_CAST" | "CONVERT" => None, "COALESCE" => {
796 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
797 common_type(&types)
798 }
799 "NULLIF" => args.first().and_then(|a| ann.get_type(a)).cloned(),
800 "IF" | "IIF" => {
801 args.get(1).and_then(|a| ann.get_type(a)).cloned()
803 }
804 "IFNULL" | "NVL" | "ISNULL" => {
805 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
806 common_type(&types)
807 }
808
809 "JSON_EXTRACT" | "JSON_QUERY" | "GET_JSON_OBJECT" => Some(DataType::Json),
811 "JSON_EXTRACT_SCALAR" | "JSON_VALUE" | "JSON_EXTRACT_PATH_TEXT" => {
812 Some(DataType::Varchar(None))
813 }
814 "TO_JSON" | "JSON_OBJECT" | "JSON_ARRAY" | "JSON_BUILD_OBJECT" | "JSON_BUILD_ARRAY" => {
815 Some(DataType::Json)
816 }
817 "PARSE_JSON" | "JSON_PARSE" | "JSON" => Some(DataType::Json),
818
819 "ARRAY_AGG" | "COLLECT_LIST" | "COLLECT_SET" => {
821 let elem = args.first().and_then(|a| ann.get_type(a)).cloned();
822 Some(DataType::Array(elem.map(Box::new)))
823 }
824 "ARRAY_LENGTH" | "ARRAY_SIZE" | "CARDINALITY" => Some(DataType::Int),
825 "ARRAY" | "ARRAY_CONSTRUCT" => {
826 let types: Vec<&DataType> = args.iter().filter_map(|a| ann.get_type(a)).collect();
827 let elem = common_type(&types);
828 Some(DataType::Array(elem.map(Box::new)))
829 }
830 "ARRAY_CONTAINS" | "ARRAY_POSITION" => Some(DataType::Boolean),
831
832 "ROW_NUMBER" | "RANK" | "DENSE_RANK" | "NTILE" | "CUME_DIST" | "PERCENT_RANK" => {
834 Some(DataType::BigInt)
835 }
836
837 "MD5" | "SHA1" | "SHA" | "SHA2" | "SHA256" | "SHA512" => Some(DataType::Varchar(None)),
839 "HEX" | "TO_HEX" => Some(DataType::Varchar(None)),
840 "UNHEX" | "FROM_HEX" => Some(DataType::Varbinary(None)),
841 "CRC32" | "HASH" => Some(DataType::BigInt),
842
843 "TYPEOF" | "TYPE_OF" => Some(DataType::Varchar(None)),
845
846 _ => ctx.schema.get_udf_type(&upper).cloned(),
848 }
849}
850
/// Return type of a function that the parser recognised as a `TypedFunction`.
///
/// Types that depend on an argument are read from `ann`, so the function's
/// children must already be annotated; those branches yield `None` when the
/// argument's type is unknown (array-producing branches instead yield an array
/// with an unknown element type).
fn infer_typed_function_type(func: &TypedFunction, ann: &TypeAnnotations) -> Option<DataType> {
    match func {
        // ── Date / time ─────────────────────────────────────────────
        TypedFunction::DateAdd { .. }
        | TypedFunction::DateSub { .. }
        | TypedFunction::DateTrunc { .. }
        | TypedFunction::TsOrDsToDate { .. } => Some(DataType::Date),
        TypedFunction::DateDiff { .. } => Some(DataType::Int),
        TypedFunction::CurrentDate => Some(DataType::Date),
        TypedFunction::CurrentTime => Some(DataType::Time { precision: None }),
        TypedFunction::CurrentTimestamp => Some(DataType::Timestamp {
            precision: None,
            with_tz: false,
        }),
        TypedFunction::StrToTime { .. } => Some(DataType::Timestamp {
            precision: None,
            with_tz: false,
        }),
        TypedFunction::TimeToStr { .. } => Some(DataType::Varchar(None)),
        TypedFunction::Year { .. } | TypedFunction::Month { .. } | TypedFunction::Day { .. } => {
            Some(DataType::Int)
        }

        // ── Strings ─────────────────────────────────────────────────
        TypedFunction::Trim { .. }
        | TypedFunction::Substring { .. }
        | TypedFunction::Upper { .. }
        | TypedFunction::Lower { .. }
        | TypedFunction::Initcap { .. }
        | TypedFunction::Replace { .. }
        | TypedFunction::Reverse { .. }
        | TypedFunction::Left { .. }
        | TypedFunction::Right { .. }
        | TypedFunction::Lpad { .. }
        | TypedFunction::Rpad { .. }
        | TypedFunction::ConcatWs { .. } => Some(DataType::Varchar(None)),
        TypedFunction::Length { .. } => Some(DataType::Int),
        TypedFunction::RegexpLike { .. } => Some(DataType::Boolean),
        TypedFunction::RegexpExtract { .. } => Some(DataType::Varchar(None)),
        TypedFunction::RegexpReplace { .. } => Some(DataType::Varchar(None)),
        TypedFunction::Split { .. } => {
            Some(DataType::Array(Some(Box::new(DataType::Varchar(None)))))
        }

        // ── Aggregates ──────────────────────────────────────────────
        TypedFunction::Count { .. } => Some(DataType::BigInt),
        TypedFunction::Sum { expr, .. } => ann.get_type(expr.as_ref()).map(|t| coerce_sum_type(t)),
        TypedFunction::Avg { .. } => Some(DataType::Double),
        // MIN / MAX keep the type of their argument.
        TypedFunction::Min { expr } | TypedFunction::Max { expr } => {
            ann.get_type(expr.as_ref()).cloned()
        }
        TypedFunction::ArrayAgg { expr, .. } => {
            let elem = ann.get_type(expr.as_ref()).cloned();
            Some(DataType::Array(elem.map(Box::new)))
        }
        TypedFunction::ApproxDistinct { .. } => Some(DataType::BigInt),
        TypedFunction::Variance { .. } | TypedFunction::Stddev { .. } => Some(DataType::Double),
        TypedFunction::GroupConcat { .. } => Some(DataType::Varchar(None)),

        // ── Arrays ──────────────────────────────────────────────────
        // Concatenation takes the type of the first array operand.
        TypedFunction::ArrayConcat { arrays } => {
            arrays.first().and_then(|a| ann.get_type(a)).cloned()
        }
        TypedFunction::ArrayContains { .. } => Some(DataType::Boolean),
        TypedFunction::ArraySize { .. } => Some(DataType::Int),
        // Exploding an array with a known element type yields that element type.
        TypedFunction::Explode { expr } => {
            match ann.get_type(expr.as_ref()) {
                Some(DataType::Array(Some(elem))) => Some(elem.as_ref().clone()),
                _ => None,
            }
        }
        TypedFunction::GenerateSeries { .. } => Some(DataType::Int),
        // Flatten reports the same type as its input.
        TypedFunction::Flatten { expr } => ann.get_type(expr.as_ref()).cloned(),

        // ── JSON ────────────────────────────────────────────────────
        TypedFunction::JSONExtract { .. } => Some(DataType::Json),
        TypedFunction::JSONExtractScalar { .. } => Some(DataType::Varchar(None)),
        TypedFunction::ParseJSON { .. } | TypedFunction::JSONFormat { .. } => Some(DataType::Json),

        // ── Window functions ────────────────────────────────────────
        TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => {
            Some(DataType::BigInt)
        }
        TypedFunction::NTile { .. } => Some(DataType::BigInt),
        // Value-returning window functions keep the type of their argument.
        TypedFunction::Lead { expr, .. }
        | TypedFunction::Lag { expr, .. }
        | TypedFunction::FirstValue { expr }
        | TypedFunction::LastValue { expr } => ann.get_type(expr.as_ref()).cloned(),

        // ── Math ────────────────────────────────────────────────────
        TypedFunction::Abs { expr }
        | TypedFunction::Ceil { expr }
        | TypedFunction::Floor { expr } => ann.get_type(expr.as_ref()).cloned(),
        TypedFunction::Round { expr, .. } => ann.get_type(expr.as_ref()).cloned(),
        TypedFunction::Log { .. }
        | TypedFunction::Ln { .. }
        | TypedFunction::Pow { .. }
        | TypedFunction::Sqrt { .. } => Some(DataType::Double),
        TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => {
            let types: Vec<&DataType> = exprs.iter().filter_map(|e| ann.get_type(e)).collect();
            common_type(&types)
        }
        // MOD: numeric coercion of both operands, else the one known operand,
        // else `Int`.
        TypedFunction::Mod { left, right } => {
            match (ann.get_type(left.as_ref()), ann.get_type(right.as_ref())) {
                (Some(l), Some(r)) => Some(coerce_numeric(l, r)),
                (Some(l), _) => Some(l.clone()),
                (_, Some(r)) => Some(r.clone()),
                _ => Some(DataType::Int),
            }
        }

        // ── Hashing / encoding ──────────────────────────────────────
        TypedFunction::Hex { .. } | TypedFunction::Md5 { .. } | TypedFunction::Sha { .. } => {
            Some(DataType::Varchar(None))
        }
        TypedFunction::Sha2 { .. } => Some(DataType::Varchar(None)),
        TypedFunction::Unhex { .. } => Some(DataType::Varbinary(None)),
    }
}
976
977fn infer_subquery_type(sub: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
982 if let Statement::Select(sel) = sub {
984 if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
985 return ann.get_type(expr).cloned();
986 }
987 }
988 None
989}
990
991fn numeric_precedence(dt: &DataType) -> u8 {
997 match dt {
998 DataType::Boolean => 1,
999 DataType::TinyInt => 2,
1000 DataType::SmallInt => 3,
1001 DataType::Int | DataType::Serial => 4,
1002 DataType::BigInt | DataType::BigSerial => 5,
1003 DataType::Real | DataType::Float => 6,
1004 DataType::Double => 7,
1005 DataType::Decimal { .. } | DataType::Numeric { .. } => 8,
1006 _ => 0,
1007 }
1008}
1009
1010fn coerce_numeric(left: &DataType, right: &DataType) -> DataType {
1012 let lp = numeric_precedence(left);
1013 let rp = numeric_precedence(right);
1014 if lp == 0 && rp == 0 {
1015 return left.clone();
1017 }
1018 if lp >= rp {
1019 left.clone()
1020 } else {
1021 right.clone()
1022 }
1023}
1024
1025fn coerce_sum_type(input: &DataType) -> DataType {
1027 match input {
1028 DataType::TinyInt | DataType::SmallInt | DataType::Int | DataType::BigInt => {
1029 DataType::BigInt
1030 }
1031 DataType::Float | DataType::Real => DataType::Double,
1032 DataType::Double => DataType::Double,
1033 DataType::Decimal { precision, scale } => DataType::Decimal {
1034 precision: *precision,
1035 scale: *scale,
1036 },
1037 DataType::Numeric { precision, scale } => DataType::Numeric {
1038 precision: *precision,
1039 scale: *scale,
1040 },
1041 _ => DataType::BigInt,
1042 }
1043}
1044
1045fn common_type(types: &[&DataType]) -> Option<DataType> {
1047 if types.is_empty() {
1048 return None;
1049 }
1050 let mut result = types[0];
1051 for t in &types[1..] {
1052 if **t == DataType::Null {
1054 continue;
1055 }
1056 if *result == DataType::Null {
1057 result = t;
1058 continue;
1059 }
1060 let lp = numeric_precedence(result);
1062 let rp = numeric_precedence(t);
1063 if lp > 0 && rp > 0 {
1064 if rp > lp {
1065 result = t;
1066 }
1067 continue;
1068 }
1069 if is_string_type(result) && is_string_type(t) {
1071 result = if matches!(result, DataType::Text) || matches!(t, DataType::Text) {
1072 if matches!(result, DataType::Text) {
1073 result
1074 } else {
1075 t
1076 }
1077 } else {
1078 result };
1080 continue;
1081 }
1082 }
1084 Some(result.clone())
1085}
1086
1087fn is_string_type(dt: &DataType) -> bool {
1088 matches!(
1089 dt,
1090 DataType::Varchar(_) | DataType::Char(_) | DataType::Text | DataType::String
1091 )
1092}
1093
1094#[cfg(test)]
1099mod tests {
1100 use super::*;
1101 use crate::dialects::Dialect;
1102 use crate::parser::Parser;
1103 use crate::schema::{MappingSchema, Schema};
1104
1105 fn setup_schema() -> MappingSchema {
1106 let mut schema = MappingSchema::new(Dialect::Ansi);
1107 schema
1108 .add_table(
1109 &["users"],
1110 vec![
1111 ("id".to_string(), DataType::Int),
1112 ("name".to_string(), DataType::Varchar(Some(255))),
1113 ("age".to_string(), DataType::Int),
1114 ("salary".to_string(), DataType::Double),
1115 ("active".to_string(), DataType::Boolean),
1116 (
1117 "created_at".to_string(),
1118 DataType::Timestamp {
1119 precision: None,
1120 with_tz: false,
1121 },
1122 ),
1123 ],
1124 )
1125 .unwrap();
1126 schema
1127 .add_table(
1128 &["orders"],
1129 vec![
1130 ("id".to_string(), DataType::Int),
1131 ("user_id".to_string(), DataType::Int),
1132 (
1133 "amount".to_string(),
1134 DataType::Decimal {
1135 precision: Some(10),
1136 scale: Some(2),
1137 },
1138 ),
1139 ("status".to_string(), DataType::Varchar(Some(50))),
1140 ],
1141 )
1142 .unwrap();
1143 schema
1144 }
1145
1146 fn parse_and_annotate(sql: &str, schema: &MappingSchema) -> (Statement, TypeAnnotations) {
1147 let stmt = Parser::new(sql).unwrap().parse_statement().unwrap();
1148 let ann = annotate_types(&stmt, schema);
1149 (stmt, ann)
1150 }
1151
1152 fn first_col_type(stmt: &Statement, ann: &TypeAnnotations) -> Option<DataType> {
1154 if let Statement::Select(sel) = stmt {
1155 if let Some(SelectItem::Expr { expr, .. }) = sel.columns.first() {
1156 return ann.get_type(expr).cloned();
1157 }
1158 }
1159 None
1160 }
1161
1162 #[test]
1165 fn test_number_literal_int() {
1166 let schema = setup_schema();
1167 let (stmt, ann) = parse_and_annotate("SELECT 42", &schema);
1168 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1169 }
1170
1171 #[test]
1172 fn test_number_literal_big_int() {
1173 let schema = setup_schema();
1174 let (stmt, ann) = parse_and_annotate("SELECT 9999999999", &schema);
1175 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::BigInt));
1176 }
1177
1178 #[test]
1179 fn test_number_literal_double() {
1180 let schema = setup_schema();
1181 let (stmt, ann) = parse_and_annotate("SELECT 3.14", &schema);
1182 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1183 }
1184
1185 #[test]
1186 fn test_string_literal() {
1187 let schema = setup_schema();
1188 let (stmt, ann) = parse_and_annotate("SELECT 'hello'", &schema);
1189 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Varchar(None)));
1190 }
1191
1192 #[test]
1193 fn test_boolean_literal() {
1194 let schema = setup_schema();
1195 let (stmt, ann) = parse_and_annotate("SELECT TRUE", &schema);
1196 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1197 }
1198
1199 #[test]
1200 fn test_null_literal() {
1201 let schema = setup_schema();
1202 let (stmt, ann) = parse_and_annotate("SELECT NULL", &schema);
1203 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Null));
1204 }
1205
1206 #[test]
1209 fn test_column_type_from_schema() {
1210 let schema = setup_schema();
1211 let (stmt, ann) = parse_and_annotate("SELECT id FROM users", &schema);
1212 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1213 }
1214
1215 #[test]
1216 fn test_qualified_column_type() {
1217 let schema = setup_schema();
1218 let (stmt, ann) = parse_and_annotate("SELECT users.name FROM users", &schema);
1219 assert_eq!(
1220 first_col_type(&stmt, &ann),
1221 Some(DataType::Varchar(Some(255)))
1222 );
1223 }
1224
1225 #[test]
1226 fn test_aliased_table_column_type() {
1227 let schema = setup_schema();
1228 let (stmt, ann) = parse_and_annotate("SELECT u.salary FROM users AS u", &schema);
1229 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1230 }
1231
1232 #[test]
1235 fn test_int_plus_int() {
1236 let schema = setup_schema();
1237 let (stmt, ann) = parse_and_annotate("SELECT id + age FROM users", &schema);
1238 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Int));
1239 }
1240
1241 #[test]
1242 fn test_int_plus_double() {
1243 let schema = setup_schema();
1244 let (stmt, ann) = parse_and_annotate("SELECT id + salary FROM users", &schema);
1245 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Double));
1246 }
1247
1248 #[test]
1249 fn test_comparison_returns_boolean() {
1250 let schema = setup_schema();
1251 let (stmt, ann) = parse_and_annotate("SELECT id > 5 FROM users", &schema);
1252 assert_eq!(first_col_type(&stmt, &ann), Some(DataType::Boolean));
1253 }
1254
/// Two comparisons joined with `AND` are expected to be annotated as
/// `DataType::Boolean`.
#[test]
fn test_and_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT id > 5 AND age < 30 FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Boolean);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1261
/// `CAST(id AS BIGINT)` is expected to be annotated with the cast target,
/// `DataType::BigInt`.
#[test]
fn test_cast_type() {
    let catalog = setup_schema();
    let sql = "SELECT CAST(id AS BIGINT) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::BigInt));
}
1270
/// `CAST(id AS VARCHAR)` (no length given) is expected to be annotated as
/// `DataType::Varchar(None)`.
#[test]
fn test_cast_to_varchar() {
    let catalog = setup_schema();
    let sql = "SELECT CAST(id AS VARCHAR) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Varchar(None));
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1277
/// A `CASE` expression whose branches are `salary` and the literal `0.0` is
/// expected to be annotated as `DataType::Double`.
#[test]
fn test_case_expression_type() {
    let catalog = setup_schema();
    let sql = "SELECT CASE WHEN id > 1 THEN salary ELSE 0.0 END FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    // The local stays named `t` because the failure message captures it by name.
    let t = first_col_type(&parsed, &annotations);
    let is_double = matches!(t, Some(DataType::Double));
    assert!(is_double, "Expected Double, got {t:?}");
}
1293
/// `COUNT(*)` is expected to be annotated as `DataType::BigInt`.
#[test]
fn test_count_returns_bigint() {
    let catalog = setup_schema();
    let sql = "SELECT COUNT(*) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::BigInt));
}
1302
/// `SUM(id)` over `users` is expected to be annotated as `DataType::BigInt`.
#[test]
fn test_sum_returns_bigint_for_int() {
    let catalog = setup_schema();
    let sql = "SELECT SUM(id) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::BigInt);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1309
/// `AVG(age)` is expected to be annotated as `DataType::Double`.
#[test]
fn test_avg_returns_double() {
    let catalog = setup_schema();
    let sql = "SELECT AVG(age) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Double));
}
1316
/// `MIN(salary)` is expected to be annotated as `DataType::Double`.
#[test]
fn test_min_preserves_type() {
    let catalog = setup_schema();
    let sql = "SELECT MIN(salary) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Double);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1323
/// `UPPER(name)` is expected to be annotated as `DataType::Varchar(None)`.
#[test]
fn test_upper_returns_varchar() {
    let catalog = setup_schema();
    let sql = "SELECT UPPER(name) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Varchar(None)));
}
1330
/// `LENGTH(name)` is expected to be annotated as `DataType::Int`.
#[test]
fn test_length_returns_int() {
    let catalog = setup_schema();
    let sql = "SELECT LENGTH(name) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Int);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1337
/// A `BETWEEN` predicate is expected to be annotated as `DataType::Boolean`.
#[test]
fn test_between_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT age BETWEEN 18 AND 65 FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Boolean));
}
1346
/// An `IN (...)` list predicate is expected to be annotated as
/// `DataType::Boolean`.
#[test]
fn test_in_list_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT id IN (1, 2, 3) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Boolean);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1353
/// An `IS NULL` predicate is expected to be annotated as `DataType::Boolean`.
#[test]
fn test_is_null_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT name IS NULL FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Boolean));
}
1360
/// A `LIKE` predicate is expected to be annotated as `DataType::Boolean`.
#[test]
fn test_like_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT name LIKE '%test%' FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Boolean);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1367
/// An `EXISTS (subquery)` expression is expected to be annotated as
/// `DataType::Boolean`.
#[test]
fn test_exists_returns_boolean() {
    let catalog = setup_schema();
    let sql = "SELECT EXISTS (SELECT 1 FROM orders) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Boolean));
}
1377
/// The nested arithmetic `(id + age) * salary` is expected to be annotated
/// as `DataType::Double` at the outermost expression.
#[test]
fn test_nested_expression_propagation() {
    let catalog = setup_schema();
    let sql = "SELECT (id + age) * salary FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    // The local stays named `t` because the failure message captures it by name.
    let t = first_col_type(&parsed, &annotations);
    let is_double = matches!(t, Some(DataType::Double));
    assert!(is_double, "Expected Double, got {t:?}");
}
1391
/// `EXTRACT(YEAR FROM created_at)` is expected to be annotated as
/// `DataType::Int`.
#[test]
fn test_extract_returns_int() {
    let catalog = setup_schema();
    let sql = "SELECT EXTRACT(YEAR FROM created_at) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = Some(DataType::Int);
    assert_eq!(first_col_type(&parsed, &annotations), expected);
}
1401
/// Every projected column of `SELECT id, name, salary FROM users` must carry
/// its own annotation: `Int`, `Varchar(Some(255))` and `Double`, in that order.
///
/// The statement and select-item shapes are matched strictly: anything other
/// than a `Statement::Select` made of `SelectItem::Expr` items fails the test
/// rather than silently skipping the assertions.
#[test]
fn test_multiple_columns_annotated() {
    let schema = setup_schema();
    let (stmt, ann) = parse_and_annotate("SELECT id, name, salary FROM users", &schema);

    let sel = match &stmt {
        Statement::Select(sel) => sel,
        _ => panic!("expected a SELECT statement"),
    };

    let expected = [
        DataType::Int,
        DataType::Varchar(Some(255)),
        DataType::Double,
    ];
    assert_eq!(sel.columns.len(), expected.len());

    for (idx, (item, want)) in sel.columns.iter().zip(expected.iter()).enumerate() {
        match item {
            SelectItem::Expr { expr, .. } => {
                assert_eq!(ann.get_type(expr), Some(want), "column {idx}");
            }
            _ => panic!("select item {idx} is not a plain expression"),
        }
    }
}
1424
/// The `WHERE` predicate (`age > 21`) must itself be annotated, and its type
/// must be `DataType::Boolean`.
///
/// The statement is parsed directly with `Parser` and annotated with
/// `annotate_types`. A non-`Select` statement or a missing `WHERE` clause
/// fails the test instead of letting it pass without checking anything.
#[test]
fn test_where_clause_annotated() {
    let schema = setup_schema();
    let stmt = Parser::new("SELECT id FROM users WHERE age > 21")
        .unwrap()
        .parse_statement()
        .unwrap();
    let ann = annotate_types(&stmt, &schema);

    let sel = match &stmt {
        Statement::Select(sel) => sel,
        _ => panic!("expected a SELECT statement"),
    };
    let wh = sel
        .where_clause
        .as_ref()
        .expect("parsed SELECT should carry a WHERE clause");

    assert_eq!(ann.get_type(wh), Some(&DataType::Boolean));
}
1443
/// `coerce_numeric` on `Int` and `BigInt` is expected to yield `BigInt`.
#[test]
fn test_int_and_bigint_coercion() {
    let widened = coerce_numeric(&DataType::Int, &DataType::BigInt);
    assert_eq!(widened, DataType::BigInt);
}
1453
/// `coerce_numeric` on `Float` and `Double` is expected to yield `Double`.
#[test]
fn test_float_and_double_coercion() {
    let widened = coerce_numeric(&DataType::Float, &DataType::Double);
    assert_eq!(widened, DataType::Double);
}
1461
/// `coerce_numeric` on `Int` and `Double` is expected to yield `Double`.
#[test]
fn test_int_and_double_coercion() {
    let widened = coerce_numeric(&DataType::Int, &DataType::Double);
    assert_eq!(widened, DataType::Double);
}
1469
/// `common_type` over `[Null, Int, Null]` is expected to yield `Some(Int)`.
#[test]
fn test_common_type_nulls_skipped() {
    let candidates = vec![&DataType::Null, &DataType::Int, &DataType::Null];
    let unified = common_type(&candidates);
    assert_eq!(unified, Some(DataType::Int));
}
1477
/// `common_type` over `[Int, Double, Float]` is expected to yield
/// `Some(Double)`.
#[test]
fn test_common_type_numeric_widening() {
    let candidates = vec![&DataType::Int, &DataType::Double, &DataType::Float];
    let unified = common_type(&candidates);
    assert_eq!(unified, Some(DataType::Double));
}
1483
/// `common_type` over an empty list is expected to yield `None`.
#[test]
fn test_common_type_empty() {
    let candidates: Vec<&DataType> = Vec::new();
    let unified = common_type(&candidates);
    assert_eq!(unified, None);
}
1489
/// A call to a UDF registered on the schema (`my_func`, registered with
/// `Varchar(None)`) is expected to be annotated with that registered type.
#[test]
fn test_udf_return_type() {
    let mut catalog = setup_schema();
    catalog.add_udf("my_func", DataType::Varchar(None));
    let sql = "SELECT my_func(id) FROM users";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let inferred = first_col_type(&parsed, &annotations);
    assert_eq!(inferred, Some(DataType::Varchar(None)));
}
1499
/// Annotating a query with two projected columns and a `WHERE` predicate is
/// expected to record at least three annotations.
#[test]
fn test_annotations_not_empty() {
    let catalog = setup_schema();
    let sql = "SELECT id, name FROM users WHERE age > 21";
    // Only the annotation table is inspected; the parsed statement is discarded.
    let (_, ann) = parse_and_annotate(sql, &catalog);
    assert!(!ann.is_empty());
    assert!(ann.len() >= 3);
}
1510
/// `SUM(amount)` over `orders` is expected to be annotated as
/// `Decimal { precision: Some(10), scale: Some(2) }`.
#[test]
fn test_sum_decimal_preserves_type() {
    let catalog = setup_schema();
    let sql = "SELECT SUM(amount) FROM orders";
    let (parsed, annotations) = parse_and_annotate(sql, &catalog);
    let expected = DataType::Decimal {
        precision: Some(10),
        scale: Some(2),
    };
    assert_eq!(first_col_type(&parsed, &annotations), Some(expected));
}
1525}