1mod error;
7mod scope;
8mod type_checker;
9
10pub use error::{AnalyzerError, AnalyzerErrorKind};
11pub use scope::{ColumnLookupResult, CteRef, Scope, ScopeColumn, ScopeTable};
12pub use type_checker::{TypeChecker, TypedExpr};
13
14use crate::ast::*;
15use crate::catalog::{Catalog, ColumnSchema, MemoryCatalog, TableSchema};
16use crate::error::{Error, Result};
17use crate::types::SqlType;
18
19pub struct Analyzer<C: Catalog = MemoryCatalog> {
21 catalog: C,
23 scopes: Vec<Scope>,
25 errors: Vec<AnalyzerError>,
27}
28
29#[derive(Debug, Clone)]
31pub struct AnalyzedQuery {
32 pub columns: Vec<OutputColumn>,
34 pub has_aggregation: bool,
36 pub has_window_functions: bool,
38}
39
40#[derive(Debug, Clone)]
42pub struct OutputColumn {
43 pub name: String,
45 pub data_type: SqlType,
47 pub nullable: bool,
49}
50
51impl<C: Catalog> Analyzer<C> {
52 pub fn with_catalog(catalog: C) -> Self {
54 Self {
55 catalog,
56 scopes: vec![Scope::new()],
57 errors: Vec::new(),
58 }
59 }
60
61 pub fn catalog(&self) -> &C {
63 &self.catalog
64 }
65
66 pub fn analyze(&mut self, stmt: &Statement) -> Result<()> {
68 self.errors.clear();
69 self.analyze_statement(stmt)
70 .map_err(|e| Error::analyzer(e.to_string()))
71 }
72
73 pub fn analyze_query_result(&mut self, query: &Query) -> Result<AnalyzedQuery> {
75 self.errors.clear();
76 self.analyze_query_internal(query)
77 .map_err(|e| Error::analyzer(e.to_string()))
78 }
79
80 pub fn errors(&self) -> &[AnalyzerError] {
82 &self.errors
83 }
84
85 fn analyze_statement(&mut self, stmt: &Statement) -> std::result::Result<(), AnalyzerError> {
87 match &stmt.kind {
88 StatementKind::Query(query) => {
89 self.analyze_query_internal(query)?;
90 Ok(())
91 }
92 StatementKind::Insert(insert) => self.analyze_insert(insert),
93 StatementKind::Update(update) => self.analyze_update(update),
94 StatementKind::Delete(delete) => self.analyze_delete(delete),
95 StatementKind::Merge(merge) => self.analyze_merge(merge),
96 StatementKind::CreateTable(create) => self.analyze_create_table(create),
97 StatementKind::CreateView(create) => self.analyze_create_view(create),
98 _ => Ok(()), }
100 }
101
102 fn analyze_query_internal(
104 &mut self,
105 query: &Query,
106 ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
107 if let Some(with) = &query.with {
109 self.analyze_with_clause(with)?;
110 }
111
112 let result = self.analyze_query_body(&query.body)?;
114
115 for order_item in &query.order_by {
117 self.analyze_expr(&order_item.expr)?;
118 }
119
120 if let Some(limit) = &query.limit {
122 if let Some(count) = &limit.count {
123 self.analyze_expr_expect_int(count)?;
124 }
125 if let Some(offset) = &limit.offset {
126 self.analyze_expr_expect_int(offset)?;
127 }
128 }
129
130 Ok(result)
131 }
132
133 fn analyze_with_clause(&mut self, with: &WithClause) -> std::result::Result<(), AnalyzerError> {
135 for cte in &with.ctes {
136 if self.current_scope().has_cte(&cte.name.value) {
138 return Err(AnalyzerError::new(AnalyzerErrorKind::DuplicateCte {
139 name: cte.name.value.clone(),
140 }));
141 }
142
143 let cte_result = self.analyze_query_internal(&cte.query)?;
145
146 let columns: Vec<ScopeColumn> = cte_result
148 .columns
149 .iter()
150 .enumerate()
151 .map(|(i, col)| {
152 ScopeColumn::new(
153 col.name.clone(),
154 col.data_type.clone(),
155 col.nullable,
156 cte.name.value.clone(),
157 i,
158 )
159 })
160 .collect();
161
162 self.current_scope_mut().add_cte(CteRef {
163 name: cte.name.value.clone(),
164 columns,
165 is_recursive: with.recursive,
166 });
167 }
168 Ok(())
169 }
170
171 fn analyze_query_body(
173 &mut self,
174 body: &QueryBody,
175 ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
176 match body {
177 QueryBody::Select(select) => self.analyze_select(select),
178 QueryBody::SetOperation { left, right, .. } => {
179 let left_result = self.analyze_query_body(left)?;
180 let right_result = self.analyze_query_body(right)?;
181
182 if left_result.columns.len() != right_result.columns.len() {
184 return Err(AnalyzerError::set_operation_column_mismatch(
185 left_result.columns.len(),
186 right_result.columns.len(),
187 ));
188 }
189
190 Ok(left_result)
192 }
193 QueryBody::Parenthesized(query) => self.analyze_query_internal(query),
194 }
195 }
196
197 fn analyze_select(
199 &mut self,
200 select: &Select,
201 ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
202 self.push_scope();
203
204 if let Some(from) = &select.from {
206 for table_ref in &from.tables {
207 self.analyze_table_ref(table_ref)?;
208 }
209 }
210
211 let has_group_by = select.group_by.is_some();
213 self.current_scope_mut().has_group_by = has_group_by;
214
215 if let Some(group_by) = &select.group_by {
216 for item in &group_by.items {
217 if let GroupByItem::Expr(expr) = item {
218 if let ExprKind::Identifier(ident) = &expr.kind {
219 self.current_scope_mut()
220 .group_by_columns
221 .push(ident.value.clone());
222 }
223 }
224 }
225 }
226
227 if let Some(where_clause) = &select.where_clause {
229 self.analyze_expr_expect_bool(where_clause)?;
230 }
231
232 let mut columns = Vec::new();
234 let mut has_aggregation = false;
235 let mut has_window_functions = false;
236
237 for item in &select.projection {
238 match item {
239 SelectItem::Expr { expr, alias } => {
240 let typed = self.analyze_expr(expr)?;
241 has_aggregation = has_aggregation || typed.contains_aggregate;
242 has_window_functions = has_window_functions || typed.contains_window;
243
244 let name = alias
245 .as_ref()
246 .map(|a| a.value.clone())
247 .or_else(|| self.expr_to_name(expr))
248 .unwrap_or_else(|| format!("_col{}", columns.len()));
249
250 columns.push(OutputColumn {
251 name,
252 data_type: typed.data_type,
253 nullable: typed.nullable,
254 });
255 }
256 SelectItem::Wildcard => {
257 for table in self.current_scope().all_tables() {
259 for col in &table.columns {
260 columns.push(OutputColumn {
261 name: col.name.clone(),
262 data_type: col.data_type.clone(),
263 nullable: col.nullable,
264 });
265 }
266 }
267 }
268 SelectItem::QualifiedWildcard { qualifier } => {
269 let table_name = qualifier
271 .parts
272 .last()
273 .map(|i| i.value.clone())
274 .unwrap_or_default();
275
276 if let Some(table) = self.current_scope().lookup_table(&table_name) {
277 for col in &table.columns {
278 columns.push(OutputColumn {
279 name: col.name.clone(),
280 data_type: col.data_type.clone(),
281 nullable: col.nullable,
282 });
283 }
284 } else {
285 return Err(AnalyzerError::table_not_found(&table_name));
286 }
287 }
288 SelectItem::WildcardExcept { qualifier, except } => {
289 let table_iter: Vec<_> = if let Some(q) = qualifier {
290 let table_name =
291 q.parts.last().map(|i| i.value.clone()).unwrap_or_default();
292 if let Some(table) = self.current_scope().lookup_table(&table_name) {
293 vec![table.clone()]
294 } else {
295 return Err(AnalyzerError::table_not_found(&table_name));
296 }
297 } else {
298 self.current_scope().all_tables().cloned().collect()
299 };
300
301 let except_names: Vec<String> =
302 except.iter().map(|i| i.value.to_lowercase()).collect();
303 for table in table_iter {
304 for col in &table.columns {
305 if !except_names.contains(&col.name.to_lowercase()) {
306 columns.push(OutputColumn {
307 name: col.name.clone(),
308 data_type: col.data_type.clone(),
309 nullable: col.nullable,
310 });
311 }
312 }
313 }
314 }
315 SelectItem::WildcardReplace { qualifier, replace } => {
316 let table_iter: Vec<_> = if let Some(q) = qualifier {
317 let table_name =
318 q.parts.last().map(|i| i.value.clone()).unwrap_or_default();
319 if let Some(table) = self.current_scope().lookup_table(&table_name) {
320 vec![table.clone()]
321 } else {
322 return Err(AnalyzerError::table_not_found(&table_name));
323 }
324 } else {
325 self.current_scope().all_tables().cloned().collect()
326 };
327
328 let replace_map: std::collections::HashMap<String, &Expr> = replace
329 .iter()
330 .map(|(expr, ident)| (ident.value.to_lowercase(), expr.as_ref()))
331 .collect();
332
333 for table in table_iter {
334 for col in &table.columns {
335 let col_lower = col.name.to_lowercase();
336 if let Some(replace_expr) = replace_map.get(&col_lower) {
337 let typed = self.analyze_expr(replace_expr)?;
338 columns.push(OutputColumn {
339 name: col.name.clone(),
340 data_type: typed.data_type,
341 nullable: typed.nullable,
342 });
343 } else {
344 columns.push(OutputColumn {
345 name: col.name.clone(),
346 data_type: col.data_type.clone(),
347 nullable: col.nullable,
348 });
349 }
350 }
351 }
352 }
353 }
354 }
355
356 if let Some(having) = &select.having {
358 if !has_group_by && !has_aggregation {
359 return Err(AnalyzerError::new(AnalyzerErrorKind::HavingWithoutGroupBy));
360 }
361 self.analyze_expr_expect_bool(having)?;
362 }
363
364 self.pop_scope();
365
366 Ok(AnalyzedQuery {
367 columns,
368 has_aggregation,
369 has_window_functions,
370 })
371 }
372
373 fn analyze_table_ref(
375 &mut self,
376 table_ref: &TableRef,
377 ) -> std::result::Result<(), AnalyzerError> {
378 match table_ref {
379 TableRef::Table { name, alias, .. } => {
380 let name_parts: Vec<String> = name.parts.iter().map(|i| i.value.clone()).collect();
381
382 let cte_name = name_parts.last().cloned().unwrap_or_default();
384 if let Some(cte) = self.lookup_cte(&cte_name) {
385 let table_alias = alias
386 .as_ref()
387 .map(|a| a.name.value.clone())
388 .unwrap_or_else(|| cte_name.clone());
389
390 let columns: Vec<ScopeColumn> = cte
391 .columns
392 .iter()
393 .map(|c| {
394 ScopeColumn::new(
395 c.name.clone(),
396 c.data_type.clone(),
397 c.nullable,
398 table_alias.clone(),
399 c.column_index,
400 )
401 })
402 .collect();
403
404 self.current_scope_mut().add_table(ScopeTable::new(
405 table_alias,
406 name_parts,
407 columns,
408 ));
409 return Ok(());
410 }
411
412 let table_schema = self
414 .catalog
415 .resolve_table(&name_parts)
416 .map_err(|_| AnalyzerError::table_not_found(&cte_name))?
417 .ok_or_else(|| AnalyzerError::table_not_found(&cte_name))?;
418
419 let table_alias = alias
420 .as_ref()
421 .map(|a| a.name.value.clone())
422 .unwrap_or_else(|| table_schema.name.clone());
423
424 let columns = self.table_schema_to_columns(&table_schema, &table_alias);
425 self.current_scope_mut().add_table(ScopeTable::new(
426 table_alias,
427 name_parts,
428 columns,
429 ));
430 }
431 TableRef::Subquery { query, alias } => {
432 let result = self.analyze_query_internal(query)?;
433
434 let alias_name = alias
435 .as_ref()
436 .map(|a| a.name.value.clone())
437 .unwrap_or_else(|| "_subquery".to_string());
438
439 let columns: Vec<ScopeColumn> = result
440 .columns
441 .iter()
442 .enumerate()
443 .map(|(i, col)| {
444 ScopeColumn::new(
445 col.name.clone(),
446 col.data_type.clone(),
447 col.nullable,
448 alias_name.clone(),
449 i,
450 )
451 })
452 .collect();
453
454 self.current_scope_mut().add_table(ScopeTable::new(
455 alias_name,
456 vec!["_subquery".to_string()],
457 columns,
458 ));
459 }
460 TableRef::Join {
461 left,
462 right,
463 condition,
464 ..
465 } => {
466 self.analyze_table_ref(left)?;
467 self.analyze_table_ref(right)?;
468
469 if let Some(JoinCondition::On(expr)) = condition {
470 self.analyze_expr_expect_bool(expr)?;
471 }
472 }
473 TableRef::Unnest { expr, alias, .. } => {
474 let typed = self.analyze_expr(expr)?;
475
476 let elem_type = match &typed.data_type {
477 SqlType::Array(elem) => (**elem).clone(),
478 _ => SqlType::Unknown,
479 };
480
481 let alias_name = alias
482 .as_ref()
483 .map(|a| a.name.value.clone())
484 .unwrap_or_else(|| "_unnest".to_string());
485
486 let columns = vec![ScopeColumn::new(
487 "value".to_string(),
488 elem_type,
489 true,
490 alias_name.clone(),
491 0,
492 )];
493
494 self.current_scope_mut().add_table(ScopeTable::new(
495 alias_name,
496 vec!["_unnest".to_string()],
497 columns,
498 ));
499 }
500 TableRef::Parenthesized(inner) => {
501 self.analyze_table_ref(inner)?;
502 }
503 TableRef::TableFunction { .. } => {
504 }
506 }
507 Ok(())
508 }
509
510 fn analyze_insert(
512 &mut self,
513 insert: &InsertStatement,
514 ) -> std::result::Result<(), AnalyzerError> {
515 let name_parts: Vec<String> = insert.table.parts.iter().map(|i| i.value.clone()).collect();
516 let table_name = name_parts.last().cloned().unwrap_or_default();
517
518 let table_schema = self
520 .catalog
521 .resolve_table(&name_parts)
522 .map_err(|_| AnalyzerError::table_not_found(&table_name))?
523 .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
524
525 for col in &insert.columns {
527 if table_schema.get_column(&col.value).is_none() {
528 return Err(AnalyzerError::column_not_found(
529 &col.value,
530 Some(table_name.clone()),
531 ));
532 }
533 }
534
535 match &insert.source {
537 InsertSource::Values(rows) => {
538 for row in rows {
539 for expr in row {
540 self.analyze_expr(expr)?;
541 }
542 }
543 }
544 InsertSource::Query(query) => {
545 self.analyze_query_internal(query)?;
546 }
547 InsertSource::DefaultValues => {}
548 }
549
550 Ok(())
551 }
552
553 fn analyze_update(
555 &mut self,
556 update: &UpdateStatement,
557 ) -> std::result::Result<(), AnalyzerError> {
558 self.push_scope();
559
560 let (name_parts, table_name, alias_opt) = self.extract_table_info(&update.table)?;
562
563 let table_schema = self
564 .catalog
565 .resolve_table(&name_parts)
566 .map_err(|_| AnalyzerError::table_not_found(&table_name))?
567 .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
568
569 let alias = alias_opt.unwrap_or_else(|| table_name.clone());
570
571 let columns = self.table_schema_to_columns(&table_schema, &alias);
572 self.current_scope_mut()
573 .add_table(ScopeTable::new(alias.clone(), name_parts, columns));
574
575 for assignment in &update.assignments {
577 match &assignment.target {
578 AssignmentTarget::Column(col) => {
579 if table_schema.get_column(&col.value).is_none() {
580 return Err(AnalyzerError::column_not_found(
581 &col.value,
582 Some(table_name.clone()),
583 ));
584 }
585 }
586 AssignmentTarget::Path(_) => {}
587 }
588 self.analyze_expr(&assignment.value)?;
589 }
590
591 if let Some(where_clause) = &update.where_clause {
593 self.analyze_expr_expect_bool(where_clause)?;
594 }
595
596 self.pop_scope();
597 Ok(())
598 }
599
600 fn extract_table_info(
602 &self,
603 table_ref: &TableRef,
604 ) -> std::result::Result<(Vec<String>, String, Option<String>), AnalyzerError> {
605 match table_ref {
606 TableRef::Table { name, alias, .. } => {
607 let name_parts: Vec<String> = name.parts.iter().map(|i| i.value.clone()).collect();
608 let table_name = name_parts.last().cloned().unwrap_or_default();
609 let alias_name = alias.as_ref().map(|a| a.name.value.clone());
610 Ok((name_parts, table_name, alias_name))
611 }
612 _ => Err(AnalyzerError::new(AnalyzerErrorKind::Other {
613 message: "Expected table reference".to_string(),
614 })),
615 }
616 }
617
618 fn analyze_delete(
620 &mut self,
621 delete: &DeleteStatement,
622 ) -> std::result::Result<(), AnalyzerError> {
623 self.push_scope();
624
625 let name_parts: Vec<String> = delete.table.parts.iter().map(|i| i.value.clone()).collect();
626 let table_name = name_parts.last().cloned().unwrap_or_default();
627
628 let table_schema = self
629 .catalog
630 .resolve_table(&name_parts)
631 .map_err(|_| AnalyzerError::table_not_found(&table_name))?
632 .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
633
634 let alias = delete
635 .alias
636 .as_ref()
637 .map(|a| a.name.value.clone())
638 .unwrap_or_else(|| table_name.clone());
639
640 let columns = self.table_schema_to_columns(&table_schema, &alias);
641 self.current_scope_mut()
642 .add_table(ScopeTable::new(alias, name_parts, columns));
643
644 if let Some(where_clause) = &delete.where_clause {
646 self.analyze_expr_expect_bool(where_clause)?;
647 }
648
649 self.pop_scope();
650 Ok(())
651 }
652
653 fn analyze_merge(&mut self, merge: &MergeStatement) -> std::result::Result<(), AnalyzerError> {
655 self.push_scope();
656
657 self.analyze_table_ref(&merge.target)?;
659
660 self.analyze_table_ref(&merge.source)?;
662
663 self.analyze_expr_expect_bool(&merge.on)?;
665
666 for clause in &merge.clauses {
668 match clause {
669 MergeClause::Matched { condition, action } => {
670 if let Some(cond) = condition {
671 self.analyze_expr_expect_bool(cond)?;
672 }
673 match action {
674 MergeMatchedAction::Update { assignments } => {
675 for assignment in assignments {
676 self.analyze_expr(&assignment.value)?;
677 }
678 }
679 MergeMatchedAction::Delete => {}
680 }
681 }
682 MergeClause::NotMatched { condition, action } => {
683 if let Some(cond) = condition {
684 self.analyze_expr_expect_bool(cond)?;
685 }
686 for expr in &action.values {
687 self.analyze_expr(expr)?;
688 }
689 }
690 MergeClause::NotMatchedBySource { condition, action } => {
691 if let Some(cond) = condition {
692 self.analyze_expr_expect_bool(cond)?;
693 }
694 match action {
695 MergeMatchedAction::Update { assignments } => {
696 for assignment in assignments {
697 self.analyze_expr(&assignment.value)?;
698 }
699 }
700 MergeMatchedAction::Delete => {}
701 }
702 }
703 }
704 }
705
706 self.pop_scope();
707 Ok(())
708 }
709
710 fn analyze_create_table(
712 &mut self,
713 create: &CreateTableStatement,
714 ) -> std::result::Result<(), AnalyzerError> {
715 if !create.if_not_exists {
717 let name_parts: Vec<String> =
718 create.name.parts.iter().map(|i| i.value.clone()).collect();
719 if let Ok(Some(_)) = self.catalog.resolve_table(&name_parts) {
720 return Err(AnalyzerError::new(AnalyzerErrorKind::Other {
721 message: format!("table '{}' already exists", create.name),
722 }));
723 }
724 }
725
726 for col in &create.columns {
728 let count = create
730 .columns
731 .iter()
732 .filter(|c| c.name.value.eq_ignore_ascii_case(&col.name.value))
733 .count();
734 if count > 1 {
735 return Err(AnalyzerError::new(AnalyzerErrorKind::DuplicateAlias {
736 name: col.name.value.clone(),
737 }));
738 }
739 }
740
741 Ok(())
742 }
743
744 fn analyze_create_view(
746 &mut self,
747 create: &CreateViewStatement,
748 ) -> std::result::Result<(), AnalyzerError> {
749 self.analyze_query_internal(&create.query)?;
751 Ok(())
752 }
753
754 fn analyze_expr(&self, expr: &Expr) -> std::result::Result<TypedExpr, AnalyzerError> {
758 let checker = TypeChecker::new(&self.catalog);
759 checker.check_expr(expr, self.current_scope())
760 }
761
762 fn analyze_expr_expect_bool(&self, expr: &Expr) -> std::result::Result<(), AnalyzerError> {
764 let typed = self.analyze_expr(expr)?;
765 if typed.data_type != SqlType::Bool
766 && typed.data_type != SqlType::Unknown
767 && typed.data_type != SqlType::Any
768 {
769 Err(AnalyzerError::type_mismatch(
770 SqlType::Bool,
771 typed.data_type,
772 "condition",
773 ))
774 } else {
775 Ok(())
776 }
777 }
778
779 fn analyze_expr_expect_int(&self, expr: &Expr) -> std::result::Result<(), AnalyzerError> {
781 let typed = self.analyze_expr(expr)?;
782 if !typed.data_type.is_integer()
783 && typed.data_type != SqlType::Unknown
784 && typed.data_type != SqlType::Any
785 {
786 Err(AnalyzerError::type_mismatch(
787 SqlType::Int64,
788 typed.data_type,
789 "LIMIT/OFFSET",
790 ))
791 } else {
792 Ok(())
793 }
794 }
795
796 fn table_schema_to_columns(&self, schema: &TableSchema, alias: &str) -> Vec<ScopeColumn> {
798 schema
799 .columns
800 .iter()
801 .enumerate()
802 .map(|(i, col)| {
803 ScopeColumn::new(
804 col.name.clone(),
805 self.column_schema_to_sql_type(col),
806 col.nullable,
807 alias.to_string(),
808 i,
809 )
810 })
811 .collect()
812 }
813
814 fn column_schema_to_sql_type(&self, col: &ColumnSchema) -> SqlType {
816 col.data_type.clone()
817 }
818
819 fn expr_to_name(&self, expr: &Expr) -> Option<String> {
821 match &expr.kind {
822 ExprKind::Identifier(ident) => Some(ident.value.clone()),
823 ExprKind::CompoundIdentifier(parts) => parts.last().map(|i| i.value.clone()),
824 ExprKind::Function(func) => func.name.parts.last().map(|i| i.value.clone()),
825 ExprKind::Aggregate(agg) => agg.function.name.parts.last().map(|i| i.value.clone()),
826 ExprKind::WindowFunction(wf) => wf.function.name.parts.last().map(|i| i.value.clone()),
827 _ => None,
828 }
829 }
830
831 fn push_scope(&mut self) {
833 self.scopes.push(Scope::new());
834 }
835
836 fn pop_scope(&mut self) {
838 self.scopes.pop();
839 }
840
841 fn current_scope(&self) -> &Scope {
843 self.scopes.last().expect("No scope available")
844 }
845
846 fn current_scope_mut(&mut self) -> &mut Scope {
848 self.scopes.last_mut().expect("No scope available")
849 }
850
851 fn lookup_cte(&self, name: &str) -> Option<CteRef> {
853 for scope in self.scopes.iter().rev() {
854 if let Some(cte) = scope.lookup_cte(name) {
855 return Some(cte.clone());
856 }
857 }
858 None
859 }
860}
861
862impl Default for Analyzer<MemoryCatalog> {
863 fn default() -> Self {
864 Self::new()
865 }
866}
867
868impl Analyzer<MemoryCatalog> {
869 pub fn new() -> Self {
871 let mut catalog = MemoryCatalog::new();
872 catalog.register_builtins();
873 Self::with_catalog(catalog)
874 }
875}
876
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::TableSchemaBuilder;
    use crate::parser::Parser;

    /// Builds a catalog with builtins plus the `users` and `orders` tables
    /// used by every test below.
    fn setup_test_catalog() -> MemoryCatalog {
        let users = TableSchemaBuilder::new("users")
            .column(ColumnSchema::new("id", SqlType::Int64).not_null())
            .column(ColumnSchema::new("name", SqlType::Varchar))
            .column(ColumnSchema::new("age", SqlType::Int64))
            .column(ColumnSchema::new("email", SqlType::Varchar))
            .build();

        let orders = TableSchemaBuilder::new("orders")
            .column(ColumnSchema::new("id", SqlType::Int64).not_null())
            .column(ColumnSchema::new("user_id", SqlType::Int64))
            .column(ColumnSchema::new("amount", SqlType::Float64))
            .column(ColumnSchema::new("created_at", SqlType::Timestamp))
            .build();

        let mut catalog = MemoryCatalog::new();
        catalog.register_builtins();
        catalog.add_table(users);
        catalog.add_table(orders);
        catalog
    }

    /// Parses `sql`, takes its first statement (which must be a query) and
    /// analyzes it against `catalog`.
    fn parse_and_analyze(sql: &str, catalog: MemoryCatalog) -> Result<AnalyzedQuery> {
        let mut parser = Parser::new(sql);
        let statements = parser.parse()?;
        let first = statements
            .into_iter()
            .next()
            .expect("Expected at least one statement");

        match first.kind {
            StatementKind::Query(query) => {
                let mut analyzer = Analyzer::with_catalog(catalog);
                analyzer.analyze_query_result(&query)
            }
            _ => panic!("Expected a query statement"),
        }
    }

    #[test]
    fn test_simple_select() {
        let analyzed =
            parse_and_analyze("SELECT id, name FROM users", setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 2);

        let first = &analyzed.columns[0];
        assert_eq!(first.name, "id");
        assert_eq!(first.data_type, SqlType::Int64);

        let second = &analyzed.columns[1];
        assert_eq!(second.name, "name");
        assert_eq!(second.data_type, SqlType::Varchar);
    }

    #[test]
    fn test_select_star() {
        let analyzed = parse_and_analyze("SELECT * FROM users", setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 4);
    }

    #[test]
    fn test_select_with_alias() {
        let sql = "SELECT id AS user_id, name AS username FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns[0].name, "user_id");
        assert_eq!(analyzed.columns[1].name, "username");
    }

    #[test]
    fn test_join() {
        let sql = "SELECT u.id, u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 3);
        assert_eq!(analyzed.columns[2].data_type, SqlType::Float64);
    }

    #[test]
    fn test_aggregate() {
        let sql = "SELECT COUNT(*), AVG(age) FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert!(analyzed.has_aggregation);
        assert_eq!(analyzed.columns.len(), 2);
    }

    #[test]
    fn test_table_not_found() {
        let failure =
            parse_and_analyze("SELECT * FROM nonexistent", setup_test_catalog()).unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_column_not_found() {
        let failure =
            parse_and_analyze("SELECT nonexistent FROM users", setup_test_catalog()).unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_ambiguous_column() {
        // `id` exists in both tables, so an unqualified reference is ambiguous.
        let failure =
            parse_and_analyze("SELECT id FROM users, orders", setup_test_catalog()).unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("ambiguous"));
    }

    #[test]
    fn test_where_clause_type_check() {
        let outcome = parse_and_analyze("SELECT * FROM users WHERE age > 21", setup_test_catalog());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_union() {
        let sql = "SELECT id, name FROM users UNION SELECT id, name FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 2);
    }

    #[test]
    fn test_cte() {
        let sql = "WITH active_users AS (SELECT id, name FROM users WHERE age > 18) SELECT * FROM active_users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 2);
    }
}