vibesql/analyzer/
mod.rs

1//! Semantic analyzer for SQL statements.
2//!
3//! This module provides semantic analysis for parsed SQL AST,
4//! including type checking, name resolution, and validation.
5
6mod error;
7mod scope;
8mod type_checker;
9
10pub use error::{AnalyzerError, AnalyzerErrorKind};
11pub use scope::{ColumnLookupResult, CteRef, Scope, ScopeColumn, ScopeTable};
12pub use type_checker::{TypeChecker, TypedExpr};
13
14use crate::ast::*;
15use crate::catalog::{Catalog, ColumnSchema, MemoryCatalog, TableSchema};
16use crate::error::{Error, Result};
17use crate::types::SqlType;
18
19/// Semantic analyzer for SQL statements.
20pub struct Analyzer<C: Catalog = MemoryCatalog> {
21    /// The catalog for resolving tables and functions.
22    catalog: C,
23    /// Scope stack.
24    scopes: Vec<Scope>,
25    /// Accumulated errors (for error recovery).
26    errors: Vec<AnalyzerError>,
27}
28
29/// Analysis result for a query.
30#[derive(Debug, Clone)]
31pub struct AnalyzedQuery {
32    /// The output columns.
33    pub columns: Vec<OutputColumn>,
34    /// Whether the query has aggregation.
35    pub has_aggregation: bool,
36    /// Whether the query uses window functions.
37    pub has_window_functions: bool,
38}
39
40/// An output column from a query.
41#[derive(Debug, Clone)]
42pub struct OutputColumn {
43    /// Column name (or alias).
44    pub name: String,
45    /// Data type.
46    pub data_type: SqlType,
47    /// Whether the column is nullable.
48    pub nullable: bool,
49}
50
51impl<C: Catalog> Analyzer<C> {
52    /// Create a new analyzer with the given catalog.
53    pub fn with_catalog(catalog: C) -> Self {
54        Self {
55            catalog,
56            scopes: vec![Scope::new()],
57            errors: Vec::new(),
58        }
59    }
60
61    /// Get the catalog.
62    pub fn catalog(&self) -> &C {
63        &self.catalog
64    }
65
66    /// Analyze a statement.
67    pub fn analyze(&mut self, stmt: &Statement) -> Result<()> {
68        self.errors.clear();
69        self.analyze_statement(stmt)
70            .map_err(|e| Error::analyzer(e.to_string()))
71    }
72
73    /// Analyze a query and return column information.
74    pub fn analyze_query_result(&mut self, query: &Query) -> Result<AnalyzedQuery> {
75        self.errors.clear();
76        self.analyze_query_internal(query)
77            .map_err(|e| Error::analyzer(e.to_string()))
78    }
79
80    /// Get any accumulated errors.
81    pub fn errors(&self) -> &[AnalyzerError] {
82        &self.errors
83    }
84
85    /// Analyze a statement.
86    fn analyze_statement(&mut self, stmt: &Statement) -> std::result::Result<(), AnalyzerError> {
87        match &stmt.kind {
88            StatementKind::Query(query) => {
89                self.analyze_query_internal(query)?;
90                Ok(())
91            }
92            StatementKind::Insert(insert) => self.analyze_insert(insert),
93            StatementKind::Update(update) => self.analyze_update(update),
94            StatementKind::Delete(delete) => self.analyze_delete(delete),
95            StatementKind::Merge(merge) => self.analyze_merge(merge),
96            StatementKind::CreateTable(create) => self.analyze_create_table(create),
97            StatementKind::CreateView(create) => self.analyze_create_view(create),
98            _ => Ok(()), // Other statements don't need deep analysis
99        }
100    }
101
102    /// Analyze a query.
103    fn analyze_query_internal(
104        &mut self,
105        query: &Query,
106    ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
107        // Process WITH clause first (CTEs)
108        if let Some(with) = &query.with {
109            self.analyze_with_clause(with)?;
110        }
111
112        // Analyze the main query body
113        let result = self.analyze_query_body(&query.body)?;
114
115        // Analyze ORDER BY
116        for order_item in &query.order_by {
117            self.analyze_expr(&order_item.expr)?;
118        }
119
120        // Analyze LIMIT/OFFSET
121        if let Some(limit) = &query.limit {
122            if let Some(count) = &limit.count {
123                self.analyze_expr_expect_int(count)?;
124            }
125            if let Some(offset) = &limit.offset {
126                self.analyze_expr_expect_int(offset)?;
127            }
128        }
129
130        Ok(result)
131    }
132
133    /// Analyze a WITH clause.
134    fn analyze_with_clause(&mut self, with: &WithClause) -> std::result::Result<(), AnalyzerError> {
135        for cte in &with.ctes {
136            // Check for duplicate CTE names
137            if self.current_scope().has_cte(&cte.name.value) {
138                return Err(AnalyzerError::new(AnalyzerErrorKind::DuplicateCte {
139                    name: cte.name.value.clone(),
140                }));
141            }
142
143            // Analyze the CTE query
144            let cte_result = self.analyze_query_internal(&cte.query)?;
145
146            // Add CTE to scope
147            let columns: Vec<ScopeColumn> = cte_result
148                .columns
149                .iter()
150                .enumerate()
151                .map(|(i, col)| {
152                    ScopeColumn::new(
153                        col.name.clone(),
154                        col.data_type.clone(),
155                        col.nullable,
156                        cte.name.value.clone(),
157                        i,
158                    )
159                })
160                .collect();
161
162            self.current_scope_mut().add_cte(CteRef {
163                name: cte.name.value.clone(),
164                columns,
165                is_recursive: with.recursive,
166            });
167        }
168        Ok(())
169    }
170
171    /// Analyze a query body (SELECT, UNION, etc.).
172    fn analyze_query_body(
173        &mut self,
174        body: &QueryBody,
175    ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
176        match body {
177            QueryBody::Select(select) => self.analyze_select(select),
178            QueryBody::SetOperation { left, right, .. } => {
179                let left_result = self.analyze_query_body(left)?;
180                let right_result = self.analyze_query_body(right)?;
181
182                // Check column count matches
183                if left_result.columns.len() != right_result.columns.len() {
184                    return Err(AnalyzerError::set_operation_column_mismatch(
185                        left_result.columns.len(),
186                        right_result.columns.len(),
187                    ));
188                }
189
190                // Result uses left side column names
191                Ok(left_result)
192            }
193            QueryBody::Parenthesized(query) => self.analyze_query_internal(query),
194        }
195    }
196
197    /// Analyze a SELECT statement.
198    fn analyze_select(
199        &mut self,
200        select: &Select,
201    ) -> std::result::Result<AnalyzedQuery, AnalyzerError> {
202        self.push_scope();
203
204        // First, analyze FROM clause to populate scope with tables
205        if let Some(from) = &select.from {
206            for table_ref in &from.tables {
207                self.analyze_table_ref(table_ref)?;
208            }
209        }
210
211        // Check for GROUP BY
212        let has_group_by = select.group_by.is_some();
213        self.current_scope_mut().has_group_by = has_group_by;
214
215        if let Some(group_by) = &select.group_by {
216            for item in &group_by.items {
217                if let GroupByItem::Expr(expr) = item {
218                    if let ExprKind::Identifier(ident) = &expr.kind {
219                        self.current_scope_mut()
220                            .group_by_columns
221                            .push(ident.value.clone());
222                    }
223                }
224            }
225        }
226
227        // Analyze WHERE clause
228        if let Some(where_clause) = &select.where_clause {
229            self.analyze_expr_expect_bool(where_clause)?;
230        }
231
232        // Analyze SELECT items
233        let mut columns = Vec::new();
234        let mut has_aggregation = false;
235        let mut has_window_functions = false;
236
237        for item in &select.projection {
238            match item {
239                SelectItem::Expr { expr, alias } => {
240                    let typed = self.analyze_expr(expr)?;
241                    has_aggregation = has_aggregation || typed.contains_aggregate;
242                    has_window_functions = has_window_functions || typed.contains_window;
243
244                    let name = alias
245                        .as_ref()
246                        .map(|a| a.value.clone())
247                        .or_else(|| self.expr_to_name(expr))
248                        .unwrap_or_else(|| format!("_col{}", columns.len()));
249
250                    columns.push(OutputColumn {
251                        name,
252                        data_type: typed.data_type,
253                        nullable: typed.nullable,
254                    });
255                }
256                SelectItem::Wildcard => {
257                    // Expand * to all columns from all tables in scope
258                    for table in self.current_scope().all_tables() {
259                        for col in &table.columns {
260                            columns.push(OutputColumn {
261                                name: col.name.clone(),
262                                data_type: col.data_type.clone(),
263                                nullable: col.nullable,
264                            });
265                        }
266                    }
267                }
268                SelectItem::QualifiedWildcard { qualifier } => {
269                    // Expand table.* to all columns from that table
270                    let table_name = qualifier
271                        .parts
272                        .last()
273                        .map(|i| i.value.clone())
274                        .unwrap_or_default();
275
276                    if let Some(table) = self.current_scope().lookup_table(&table_name) {
277                        for col in &table.columns {
278                            columns.push(OutputColumn {
279                                name: col.name.clone(),
280                                data_type: col.data_type.clone(),
281                                nullable: col.nullable,
282                            });
283                        }
284                    } else {
285                        return Err(AnalyzerError::table_not_found(&table_name));
286                    }
287                }
288                SelectItem::WildcardExcept { qualifier, except } => {
289                    let table_iter: Vec<_> = if let Some(q) = qualifier {
290                        let table_name =
291                            q.parts.last().map(|i| i.value.clone()).unwrap_or_default();
292                        if let Some(table) = self.current_scope().lookup_table(&table_name) {
293                            vec![table.clone()]
294                        } else {
295                            return Err(AnalyzerError::table_not_found(&table_name));
296                        }
297                    } else {
298                        self.current_scope().all_tables().cloned().collect()
299                    };
300
301                    let except_names: Vec<String> =
302                        except.iter().map(|i| i.value.to_lowercase()).collect();
303                    for table in table_iter {
304                        for col in &table.columns {
305                            if !except_names.contains(&col.name.to_lowercase()) {
306                                columns.push(OutputColumn {
307                                    name: col.name.clone(),
308                                    data_type: col.data_type.clone(),
309                                    nullable: col.nullable,
310                                });
311                            }
312                        }
313                    }
314                }
315                SelectItem::WildcardReplace { qualifier, replace } => {
316                    let table_iter: Vec<_> = if let Some(q) = qualifier {
317                        let table_name =
318                            q.parts.last().map(|i| i.value.clone()).unwrap_or_default();
319                        if let Some(table) = self.current_scope().lookup_table(&table_name) {
320                            vec![table.clone()]
321                        } else {
322                            return Err(AnalyzerError::table_not_found(&table_name));
323                        }
324                    } else {
325                        self.current_scope().all_tables().cloned().collect()
326                    };
327
328                    let replace_map: std::collections::HashMap<String, &Expr> = replace
329                        .iter()
330                        .map(|(expr, ident)| (ident.value.to_lowercase(), expr.as_ref()))
331                        .collect();
332
333                    for table in table_iter {
334                        for col in &table.columns {
335                            let col_lower = col.name.to_lowercase();
336                            if let Some(replace_expr) = replace_map.get(&col_lower) {
337                                let typed = self.analyze_expr(replace_expr)?;
338                                columns.push(OutputColumn {
339                                    name: col.name.clone(),
340                                    data_type: typed.data_type,
341                                    nullable: typed.nullable,
342                                });
343                            } else {
344                                columns.push(OutputColumn {
345                                    name: col.name.clone(),
346                                    data_type: col.data_type.clone(),
347                                    nullable: col.nullable,
348                                });
349                            }
350                        }
351                    }
352                }
353            }
354        }
355
356        // Analyze HAVING clause
357        if let Some(having) = &select.having {
358            if !has_group_by && !has_aggregation {
359                return Err(AnalyzerError::new(AnalyzerErrorKind::HavingWithoutGroupBy));
360            }
361            self.analyze_expr_expect_bool(having)?;
362        }
363
364        self.pop_scope();
365
366        Ok(AnalyzedQuery {
367            columns,
368            has_aggregation,
369            has_window_functions,
370        })
371    }
372
373    /// Analyze a table reference in FROM clause.
374    fn analyze_table_ref(
375        &mut self,
376        table_ref: &TableRef,
377    ) -> std::result::Result<(), AnalyzerError> {
378        match table_ref {
379            TableRef::Table { name, alias, .. } => {
380                let name_parts: Vec<String> = name.parts.iter().map(|i| i.value.clone()).collect();
381
382                // First check if it's a CTE (search all parent scopes)
383                let cte_name = name_parts.last().cloned().unwrap_or_default();
384                if let Some(cte) = self.lookup_cte(&cte_name) {
385                    let table_alias = alias
386                        .as_ref()
387                        .map(|a| a.name.value.clone())
388                        .unwrap_or_else(|| cte_name.clone());
389
390                    let columns: Vec<ScopeColumn> = cte
391                        .columns
392                        .iter()
393                        .map(|c| {
394                            ScopeColumn::new(
395                                c.name.clone(),
396                                c.data_type.clone(),
397                                c.nullable,
398                                table_alias.clone(),
399                                c.column_index,
400                            )
401                        })
402                        .collect();
403
404                    self.current_scope_mut().add_table(ScopeTable::new(
405                        table_alias,
406                        name_parts,
407                        columns,
408                    ));
409                    return Ok(());
410                }
411
412                // Look up table in catalog
413                let table_schema = self
414                    .catalog
415                    .resolve_table(&name_parts)
416                    .map_err(|_| AnalyzerError::table_not_found(&cte_name))?
417                    .ok_or_else(|| AnalyzerError::table_not_found(&cte_name))?;
418
419                let table_alias = alias
420                    .as_ref()
421                    .map(|a| a.name.value.clone())
422                    .unwrap_or_else(|| table_schema.name.clone());
423
424                let columns = self.table_schema_to_columns(&table_schema, &table_alias);
425                self.current_scope_mut().add_table(ScopeTable::new(
426                    table_alias,
427                    name_parts,
428                    columns,
429                ));
430            }
431            TableRef::Subquery { query, alias } => {
432                let result = self.analyze_query_internal(query)?;
433
434                let alias_name = alias
435                    .as_ref()
436                    .map(|a| a.name.value.clone())
437                    .unwrap_or_else(|| "_subquery".to_string());
438
439                let columns: Vec<ScopeColumn> = result
440                    .columns
441                    .iter()
442                    .enumerate()
443                    .map(|(i, col)| {
444                        ScopeColumn::new(
445                            col.name.clone(),
446                            col.data_type.clone(),
447                            col.nullable,
448                            alias_name.clone(),
449                            i,
450                        )
451                    })
452                    .collect();
453
454                self.current_scope_mut().add_table(ScopeTable::new(
455                    alias_name,
456                    vec!["_subquery".to_string()],
457                    columns,
458                ));
459            }
460            TableRef::Join {
461                left,
462                right,
463                condition,
464                ..
465            } => {
466                self.analyze_table_ref(left)?;
467                self.analyze_table_ref(right)?;
468
469                if let Some(JoinCondition::On(expr)) = condition {
470                    self.analyze_expr_expect_bool(expr)?;
471                }
472            }
473            TableRef::Unnest { expr, alias, .. } => {
474                let typed = self.analyze_expr(expr)?;
475
476                let elem_type = match &typed.data_type {
477                    SqlType::Array(elem) => (**elem).clone(),
478                    _ => SqlType::Unknown,
479                };
480
481                let alias_name = alias
482                    .as_ref()
483                    .map(|a| a.name.value.clone())
484                    .unwrap_or_else(|| "_unnest".to_string());
485
486                let columns = vec![ScopeColumn::new(
487                    "value".to_string(),
488                    elem_type,
489                    true,
490                    alias_name.clone(),
491                    0,
492                )];
493
494                self.current_scope_mut().add_table(ScopeTable::new(
495                    alias_name,
496                    vec!["_unnest".to_string()],
497                    columns,
498                ));
499            }
500            TableRef::Parenthesized(inner) => {
501                self.analyze_table_ref(inner)?;
502            }
503            TableRef::TableFunction { .. } => {
504                // Table functions would need special handling
505            }
506        }
507        Ok(())
508    }
509
510    /// Analyze an INSERT statement.
511    fn analyze_insert(
512        &mut self,
513        insert: &InsertStatement,
514    ) -> std::result::Result<(), AnalyzerError> {
515        let name_parts: Vec<String> = insert.table.parts.iter().map(|i| i.value.clone()).collect();
516        let table_name = name_parts.last().cloned().unwrap_or_default();
517
518        // Verify table exists
519        let table_schema = self
520            .catalog
521            .resolve_table(&name_parts)
522            .map_err(|_| AnalyzerError::table_not_found(&table_name))?
523            .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
524
525        // Verify columns if specified
526        for col in &insert.columns {
527            if table_schema.get_column(&col.value).is_none() {
528                return Err(AnalyzerError::column_not_found(
529                    &col.value,
530                    Some(table_name.clone()),
531                ));
532            }
533        }
534
535        // Analyze the source
536        match &insert.source {
537            InsertSource::Values(rows) => {
538                for row in rows {
539                    for expr in row {
540                        self.analyze_expr(expr)?;
541                    }
542                }
543            }
544            InsertSource::Query(query) => {
545                self.analyze_query_internal(query)?;
546            }
547            InsertSource::DefaultValues => {}
548        }
549
550        Ok(())
551    }
552
553    /// Analyze an UPDATE statement.
554    fn analyze_update(
555        &mut self,
556        update: &UpdateStatement,
557    ) -> std::result::Result<(), AnalyzerError> {
558        self.push_scope();
559
560        // Add target table to scope - need to extract name from TableRef
561        let (name_parts, table_name, alias_opt) = self.extract_table_info(&update.table)?;
562
563        let table_schema = self
564            .catalog
565            .resolve_table(&name_parts)
566            .map_err(|_| AnalyzerError::table_not_found(&table_name))?
567            .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
568
569        let alias = alias_opt.unwrap_or_else(|| table_name.clone());
570
571        let columns = self.table_schema_to_columns(&table_schema, &alias);
572        self.current_scope_mut()
573            .add_table(ScopeTable::new(alias.clone(), name_parts, columns));
574
575        // Analyze assignments
576        for assignment in &update.assignments {
577            match &assignment.target {
578                AssignmentTarget::Column(col) => {
579                    if table_schema.get_column(&col.value).is_none() {
580                        return Err(AnalyzerError::column_not_found(
581                            &col.value,
582                            Some(table_name.clone()),
583                        ));
584                    }
585                }
586                AssignmentTarget::Path(_) => {}
587            }
588            self.analyze_expr(&assignment.value)?;
589        }
590
591        // Analyze WHERE clause
592        if let Some(where_clause) = &update.where_clause {
593            self.analyze_expr_expect_bool(where_clause)?;
594        }
595
596        self.pop_scope();
597        Ok(())
598    }
599
600    /// Extract table name information from a TableRef.
601    fn extract_table_info(
602        &self,
603        table_ref: &TableRef,
604    ) -> std::result::Result<(Vec<String>, String, Option<String>), AnalyzerError> {
605        match table_ref {
606            TableRef::Table { name, alias, .. } => {
607                let name_parts: Vec<String> = name.parts.iter().map(|i| i.value.clone()).collect();
608                let table_name = name_parts.last().cloned().unwrap_or_default();
609                let alias_name = alias.as_ref().map(|a| a.name.value.clone());
610                Ok((name_parts, table_name, alias_name))
611            }
612            _ => Err(AnalyzerError::new(AnalyzerErrorKind::Other {
613                message: "Expected table reference".to_string(),
614            })),
615        }
616    }
617
618    /// Analyze a DELETE statement.
619    fn analyze_delete(
620        &mut self,
621        delete: &DeleteStatement,
622    ) -> std::result::Result<(), AnalyzerError> {
623        self.push_scope();
624
625        let name_parts: Vec<String> = delete.table.parts.iter().map(|i| i.value.clone()).collect();
626        let table_name = name_parts.last().cloned().unwrap_or_default();
627
628        let table_schema = self
629            .catalog
630            .resolve_table(&name_parts)
631            .map_err(|_| AnalyzerError::table_not_found(&table_name))?
632            .ok_or_else(|| AnalyzerError::table_not_found(&table_name))?;
633
634        let alias = delete
635            .alias
636            .as_ref()
637            .map(|a| a.name.value.clone())
638            .unwrap_or_else(|| table_name.clone());
639
640        let columns = self.table_schema_to_columns(&table_schema, &alias);
641        self.current_scope_mut()
642            .add_table(ScopeTable::new(alias, name_parts, columns));
643
644        // Analyze WHERE clause
645        if let Some(where_clause) = &delete.where_clause {
646            self.analyze_expr_expect_bool(where_clause)?;
647        }
648
649        self.pop_scope();
650        Ok(())
651    }
652
653    /// Analyze a MERGE statement.
654    fn analyze_merge(&mut self, merge: &MergeStatement) -> std::result::Result<(), AnalyzerError> {
655        self.push_scope();
656
657        // Analyze target table
658        self.analyze_table_ref(&merge.target)?;
659
660        // Analyze source table
661        self.analyze_table_ref(&merge.source)?;
662
663        // Analyze ON condition
664        self.analyze_expr_expect_bool(&merge.on)?;
665
666        // Analyze WHEN clauses
667        for clause in &merge.clauses {
668            match clause {
669                MergeClause::Matched { condition, action } => {
670                    if let Some(cond) = condition {
671                        self.analyze_expr_expect_bool(cond)?;
672                    }
673                    match action {
674                        MergeMatchedAction::Update { assignments } => {
675                            for assignment in assignments {
676                                self.analyze_expr(&assignment.value)?;
677                            }
678                        }
679                        MergeMatchedAction::Delete => {}
680                    }
681                }
682                MergeClause::NotMatched { condition, action } => {
683                    if let Some(cond) = condition {
684                        self.analyze_expr_expect_bool(cond)?;
685                    }
686                    for expr in &action.values {
687                        self.analyze_expr(expr)?;
688                    }
689                }
690                MergeClause::NotMatchedBySource { condition, action } => {
691                    if let Some(cond) = condition {
692                        self.analyze_expr_expect_bool(cond)?;
693                    }
694                    match action {
695                        MergeMatchedAction::Update { assignments } => {
696                            for assignment in assignments {
697                                self.analyze_expr(&assignment.value)?;
698                            }
699                        }
700                        MergeMatchedAction::Delete => {}
701                    }
702                }
703            }
704        }
705
706        self.pop_scope();
707        Ok(())
708    }
709
710    /// Analyze a CREATE TABLE statement.
711    fn analyze_create_table(
712        &mut self,
713        create: &CreateTableStatement,
714    ) -> std::result::Result<(), AnalyzerError> {
715        // Check that the table doesn't already exist (unless IF NOT EXISTS)
716        if !create.if_not_exists {
717            let name_parts: Vec<String> =
718                create.name.parts.iter().map(|i| i.value.clone()).collect();
719            if let Ok(Some(_)) = self.catalog.resolve_table(&name_parts) {
720                return Err(AnalyzerError::new(AnalyzerErrorKind::Other {
721                    message: format!("table '{}' already exists", create.name),
722                }));
723            }
724        }
725
726        // Validate column definitions
727        for col in &create.columns {
728            // Check for duplicate column names
729            let count = create
730                .columns
731                .iter()
732                .filter(|c| c.name.value.eq_ignore_ascii_case(&col.name.value))
733                .count();
734            if count > 1 {
735                return Err(AnalyzerError::new(AnalyzerErrorKind::DuplicateAlias {
736                    name: col.name.value.clone(),
737                }));
738            }
739        }
740
741        Ok(())
742    }
743
744    /// Analyze a CREATE VIEW statement.
745    fn analyze_create_view(
746        &mut self,
747        create: &CreateViewStatement,
748    ) -> std::result::Result<(), AnalyzerError> {
749        // Analyze the view query
750        self.analyze_query_internal(&create.query)?;
751        Ok(())
752    }
753
754    // === Helper methods ===
755
756    /// Analyze an expression and return its typed result.
757    fn analyze_expr(&self, expr: &Expr) -> std::result::Result<TypedExpr, AnalyzerError> {
758        let checker = TypeChecker::new(&self.catalog);
759        checker.check_expr(expr, self.current_scope())
760    }
761
762    /// Analyze an expression and expect a boolean result.
763    fn analyze_expr_expect_bool(&self, expr: &Expr) -> std::result::Result<(), AnalyzerError> {
764        let typed = self.analyze_expr(expr)?;
765        if typed.data_type != SqlType::Bool
766            && typed.data_type != SqlType::Unknown
767            && typed.data_type != SqlType::Any
768        {
769            Err(AnalyzerError::type_mismatch(
770                SqlType::Bool,
771                typed.data_type,
772                "condition",
773            ))
774        } else {
775            Ok(())
776        }
777    }
778
779    /// Analyze an expression and expect an integer result.
780    fn analyze_expr_expect_int(&self, expr: &Expr) -> std::result::Result<(), AnalyzerError> {
781        let typed = self.analyze_expr(expr)?;
782        if !typed.data_type.is_integer()
783            && typed.data_type != SqlType::Unknown
784            && typed.data_type != SqlType::Any
785        {
786            Err(AnalyzerError::type_mismatch(
787                SqlType::Int64,
788                typed.data_type,
789                "LIMIT/OFFSET",
790            ))
791        } else {
792            Ok(())
793        }
794    }
795
796    /// Convert a table schema to column references.
797    fn table_schema_to_columns(&self, schema: &TableSchema, alias: &str) -> Vec<ScopeColumn> {
798        schema
799            .columns
800            .iter()
801            .enumerate()
802            .map(|(i, col)| {
803                ScopeColumn::new(
804                    col.name.clone(),
805                    self.column_schema_to_sql_type(col),
806                    col.nullable,
807                    alias.to_string(),
808                    i,
809                )
810            })
811            .collect()
812    }
813
814    /// Convert a column schema to SqlType.
815    fn column_schema_to_sql_type(&self, col: &ColumnSchema) -> SqlType {
816        col.data_type.clone()
817    }
818
819    /// Try to derive a name from an expression.
820    fn expr_to_name(&self, expr: &Expr) -> Option<String> {
821        match &expr.kind {
822            ExprKind::Identifier(ident) => Some(ident.value.clone()),
823            ExprKind::CompoundIdentifier(parts) => parts.last().map(|i| i.value.clone()),
824            ExprKind::Function(func) => func.name.parts.last().map(|i| i.value.clone()),
825            ExprKind::Aggregate(agg) => agg.function.name.parts.last().map(|i| i.value.clone()),
826            ExprKind::WindowFunction(wf) => wf.function.name.parts.last().map(|i| i.value.clone()),
827            _ => None,
828        }
829    }
830
831    /// Push a new scope.
832    fn push_scope(&mut self) {
833        self.scopes.push(Scope::new());
834    }
835
836    /// Pop the current scope.
837    fn pop_scope(&mut self) {
838        self.scopes.pop();
839    }
840
841    /// Get the current scope.
842    fn current_scope(&self) -> &Scope {
843        self.scopes.last().expect("No scope available")
844    }
845
846    /// Get the current scope mutably.
847    fn current_scope_mut(&mut self) -> &mut Scope {
848        self.scopes.last_mut().expect("No scope available")
849    }
850
851    /// Look up a CTE in all scopes (current and parents).
852    fn lookup_cte(&self, name: &str) -> Option<CteRef> {
853        for scope in self.scopes.iter().rev() {
854            if let Some(cte) = scope.lookup_cte(name) {
855                return Some(cte.clone());
856            }
857        }
858        None
859    }
860}
861
862impl Default for Analyzer<MemoryCatalog> {
863    fn default() -> Self {
864        Self::new()
865    }
866}
867
868impl Analyzer<MemoryCatalog> {
869    /// Create a new analyzer with an empty memory catalog.
870    pub fn new() -> Self {
871        let mut catalog = MemoryCatalog::new();
872        catalog.register_builtins();
873        Self::with_catalog(catalog)
874    }
875}
876
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::TableSchemaBuilder;
    use crate::parser::Parser;

    /// Build the catalog shared by the tests: `register_builtins` is called,
    /// then the `users` and `orders` fixture tables are added.
    fn setup_test_catalog() -> MemoryCatalog {
        let users = TableSchemaBuilder::new("users")
            .column(ColumnSchema::new("id", SqlType::Int64).not_null())
            .column(ColumnSchema::new("name", SqlType::Varchar))
            .column(ColumnSchema::new("age", SqlType::Int64))
            .column(ColumnSchema::new("email", SqlType::Varchar))
            .build();

        let orders = TableSchemaBuilder::new("orders")
            .column(ColumnSchema::new("id", SqlType::Int64).not_null())
            .column(ColumnSchema::new("user_id", SqlType::Int64))
            .column(ColumnSchema::new("amount", SqlType::Float64))
            .column(ColumnSchema::new("created_at", SqlType::Timestamp))
            .build();

        let mut catalog = MemoryCatalog::new();
        catalog.register_builtins();
        catalog.add_table(users);
        catalog.add_table(orders);
        catalog
    }

    /// Parse `sql`, take the first statement, and analyze it as a query
    /// against `catalog`.
    ///
    /// Panics if parsing yields no statement or the first statement is not a
    /// query; parse errors are returned via `?`.
    fn parse_and_analyze(sql: &str, catalog: MemoryCatalog) -> Result<AnalyzedQuery> {
        let mut parser = Parser::new(sql);
        let first = parser
            .parse()?
            .into_iter()
            .next()
            .expect("Expected at least one statement");

        match first.kind {
            StatementKind::Query(query) => {
                let mut analyzer = Analyzer::with_catalog(catalog);
                analyzer.analyze_query_result(&query)
            }
            _ => panic!("Expected a query statement"),
        }
    }

    #[test]
    fn test_simple_select() {
        let analyzed =
            parse_and_analyze("SELECT id, name FROM users", setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 2);

        let id_col = &analyzed.columns[0];
        assert_eq!(id_col.name, "id");
        assert_eq!(id_col.data_type, SqlType::Int64);

        let name_col = &analyzed.columns[1];
        assert_eq!(name_col.name, "name");
        assert_eq!(name_col.data_type, SqlType::Varchar);
    }

    #[test]
    fn test_select_star() {
        let analyzed = parse_and_analyze("SELECT * FROM users", setup_test_catalog()).unwrap();

        // `users` is declared with four columns in the fixture catalog.
        assert_eq!(analyzed.columns.len(), 4);
    }

    #[test]
    fn test_select_with_alias() {
        let sql = "SELECT id AS user_id, name AS username FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns[0].name, "user_id");
        assert_eq!(analyzed.columns[1].name, "username");
    }

    #[test]
    fn test_join() {
        let sql = "SELECT u.id, u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 3);
        // `o.amount` is declared as Float64 on `orders`.
        assert_eq!(analyzed.columns[2].data_type, SqlType::Float64);
    }

    #[test]
    fn test_aggregate() {
        let sql = "SELECT COUNT(*), AVG(age) FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert!(analyzed.has_aggregation);
        assert_eq!(analyzed.columns.len(), 2);
    }

    #[test]
    fn test_table_not_found() {
        let failure =
            parse_and_analyze("SELECT * FROM nonexistent", setup_test_catalog()).unwrap_err();

        let message = failure.to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_column_not_found() {
        let failure =
            parse_and_analyze("SELECT nonexistent FROM users", setup_test_catalog()).unwrap_err();

        let message = failure.to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_ambiguous_column() {
        // Both `users` and `orders` declare an `id` column.
        let failure =
            parse_and_analyze("SELECT id FROM users, orders", setup_test_catalog()).unwrap_err();

        let message = failure.to_string();
        assert!(message.contains("ambiguous"));
    }

    #[test]
    fn test_where_clause_type_check() {
        // Valid: the comparison is a boolean condition.
        let outcome = parse_and_analyze("SELECT * FROM users WHERE age > 21", setup_test_catalog());

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_union() {
        let sql = "SELECT id, name FROM users UNION SELECT id, name FROM users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        assert_eq!(analyzed.columns.len(), 2);
    }

    #[test]
    fn test_cte() {
        let sql = "WITH active_users AS (SELECT id, name FROM users WHERE age > 18) SELECT * FROM active_users";
        let analyzed = parse_and_analyze(sql, setup_test_catalog()).unwrap();

        // The CTE projects two columns, so `SELECT *` over it yields two.
        assert_eq!(analyzed.columns.len(), 2);
    }
}