Skip to main content

sqlglot_rust/planner/
mod.rs

1//! Logical query planner.
2//!
3//! Generates a logical execution plan (a DAG of [`Step`]s) from an
4//! optimized SQL AST. Inspired by Python sqlglot's `planner.py`.
5//!
6//! The planner sits between the optimizer and the executor: the optimizer
7//! rewrites the AST, then the planner produces a plan that an execution
8//! engine can follow.
9//!
10//! # Example
11//!
12//! ```rust
13//! use sqlglot_rust::parser::parse;
14//! use sqlglot_rust::dialects::Dialect;
15//! use sqlglot_rust::planner::{plan, Plan};
16//!
17//! let ast = parse("SELECT a, b FROM t WHERE a > 1 ORDER BY b", Dialect::Ansi).unwrap();
18//! let p = plan(&ast).unwrap();
19//! println!("{}", p.to_mermaid());
20//! ```
21
22use std::fmt;
23
24use crate::ast::*;
25use crate::errors::{Result, SqlglotError};
26
27// ═══════════════════════════════════════════════════════════════════════
28// Step ID
29// ═══════════════════════════════════════════════════════════════════════
30
/// Opaque identifier for a step within a plan.
///
/// Wraps the step's index in the plan's step list: the builder assigns it
/// from the list length when a step is added, and [`Plan::get`] uses it to
/// index back into that list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StepId(usize);
34
35impl fmt::Display for StepId {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        write!(f, "step_{}", self.0)
38    }
39}
40
41// ═══════════════════════════════════════════════════════════════════════
42// Column projection
43// ═══════════════════════════════════════════════════════════════════════
44
/// A projected column in a plan step.
///
/// Pairs an expression with the optional name under which it is exposed.
#[derive(Debug, Clone, PartialEq)]
pub struct Projection {
    /// The expression being projected.
    pub expr: Expr,
    /// Output alias (if any). `None` for wildcards and unaliased expressions.
    pub alias: Option<String>,
}
53
54// ═══════════════════════════════════════════════════════════════════════
55// Plan step types
56// ═══════════════════════════════════════════════════════════════════════
57
/// A single step in the logical execution plan.
///
/// Every variant carries `projections` and `dependencies` fields so they can
/// be read uniformly through [`Step::projections`] and
/// [`Step::dependencies`].
#[derive(Debug, Clone, PartialEq)]
pub enum Step {
    /// Full table scan with optional filter pushdown.
    ///
    /// The builder in this module also uses this variant for table
    /// functions, `UNNEST`, and the virtual single-row source of a SELECT
    /// without FROM (in which case `table` is empty).
    Scan {
        /// Fully-qualified table name.
        table: String,
        /// Alias for the table (if any).
        alias: Option<String>,
        /// Projected columns.
        projections: Vec<Projection>,
        /// Predicate pushed down to the scan. The builder in this module
        /// never populates it; it is available for later rewrites.
        predicate: Option<Expr>,
        /// IDs of steps this step depends on (always empty for a scan).
        dependencies: Vec<StepId>,
    },
    /// Filter (WHERE / HAVING) applied to its input.
    Filter {
        /// The filter predicate.
        predicate: Expr,
        /// Projected columns.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Projection (SELECT list evaluation).
    Project {
        /// Output projections.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Aggregation (GROUP BY + aggregate functions).
    Aggregate {
        /// GROUP BY keys. Empty for a scalar aggregate over the whole input.
        group_by: Vec<Expr>,
        /// Aggregate expressions (COUNT, SUM, etc.).
        aggregations: Vec<Projection>,
        /// Projected output columns.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Sort (ORDER BY).
    Sort {
        /// Order-by items.
        order_by: Vec<OrderByItem>,
        /// Projected columns (pass-through).
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Join two inputs.
    Join {
        /// Type of join.
        join_type: JoinType,
        /// Join condition (ON clause).
        condition: Option<Expr>,
        /// USING columns (if specified instead of ON).
        using_columns: Vec<String>,
        /// Projected columns.
        projections: Vec<Projection>,
        /// Two input steps: [left, right].
        dependencies: Vec<StepId>,
    },
    /// LIMIT / OFFSET.
    Limit {
        /// Row limit. The builder also stores a FETCH FIRST count here when
        /// no LIMIT is present.
        limit: Option<Expr>,
        /// Row offset.
        offset: Option<Expr>,
        /// Projected columns (pass-through).
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// UNION / INTERSECT / EXCEPT.
    SetOperation {
        /// The kind of set operation.
        op: SetOperationType,
        /// Whether ALL (no deduplication).
        all: bool,
        /// Projected columns from the combined result.
        projections: Vec<Projection>,
        /// Two input steps: [left, right].
        dependencies: Vec<StepId>,
    },
    /// DISTINCT elimination.
    Distinct {
        /// Projected columns.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
}
153
154impl Step {
155    /// Returns the list of step IDs this step depends on.
156    #[must_use]
157    pub fn dependencies(&self) -> &[StepId] {
158        match self {
159            Step::Scan { dependencies, .. }
160            | Step::Filter { dependencies, .. }
161            | Step::Project { dependencies, .. }
162            | Step::Aggregate { dependencies, .. }
163            | Step::Sort { dependencies, .. }
164            | Step::Join { dependencies, .. }
165            | Step::Limit { dependencies, .. }
166            | Step::SetOperation { dependencies, .. }
167            | Step::Distinct { dependencies, .. } => dependencies,
168        }
169    }
170
171    /// Returns the projected columns of this step.
172    #[must_use]
173    pub fn projections(&self) -> &[Projection] {
174        match self {
175            Step::Scan { projections, .. }
176            | Step::Filter { projections, .. }
177            | Step::Project { projections, .. }
178            | Step::Aggregate { projections, .. }
179            | Step::Sort { projections, .. }
180            | Step::Join { projections, .. }
181            | Step::Limit { projections, .. }
182            | Step::SetOperation { projections, .. }
183            | Step::Distinct { projections, .. } => projections,
184        }
185    }
186
187    /// A short human-readable label for the step type.
188    #[must_use]
189    pub fn kind(&self) -> &'static str {
190        match self {
191            Step::Scan { .. } => "Scan",
192            Step::Filter { .. } => "Filter",
193            Step::Project { .. } => "Project",
194            Step::Aggregate { .. } => "Aggregate",
195            Step::Sort { .. } => "Sort",
196            Step::Join { .. } => "Join",
197            Step::Limit { .. } => "Limit",
198            Step::SetOperation { .. } => "SetOperation",
199            Step::Distinct { .. } => "Distinct",
200        }
201    }
202}
203
204// ═══════════════════════════════════════════════════════════════════════
205// Plan
206// ═══════════════════════════════════════════════════════════════════════
207
/// A logical execution plan — a directed acyclic graph (DAG) of steps.
///
/// Steps are stored in topological order: a step's dependencies always
/// have a smaller [`StepId`] than the step itself.
///
/// Both fields are private; plans are produced by [`plan`] and inspected
/// through the accessor methods.
#[derive(Debug, Clone)]
pub struct Plan {
    /// All steps in topological order. A step's [`StepId`] is its index here.
    steps: Vec<Step>,
    /// The "root" step that produces the final result.
    root: StepId,
}
219
220impl Plan {
221    /// Returns the root step ID.
222    #[must_use]
223    pub fn root(&self) -> StepId {
224        self.root
225    }
226
227    /// Returns a reference to all steps.
228    #[must_use]
229    pub fn steps(&self) -> &[Step] {
230        &self.steps
231    }
232
233    /// Looks up a step by its ID.
234    #[must_use]
235    pub fn get(&self, id: StepId) -> Option<&Step> {
236        self.steps.get(id.0)
237    }
238
239    /// Number of steps in the plan.
240    #[must_use]
241    pub fn len(&self) -> usize {
242        self.steps.len()
243    }
244
245    /// Whether the plan has zero steps.
246    #[must_use]
247    pub fn is_empty(&self) -> bool {
248        self.steps.is_empty()
249    }
250
251    /// Render the plan as a Mermaid flowchart.
252    #[must_use]
253    pub fn to_mermaid(&self) -> String {
254        let mut out = String::from("graph TD\n");
255        for (i, step) in self.steps.iter().enumerate() {
256            let id = StepId(i);
257            let label = step_label(step);
258            out.push_str(&format!("    {id}[\"{label}\"]\n"));
259            for dep in step.dependencies() {
260                out.push_str(&format!("    {dep} --> {id}\n"));
261            }
262        }
263        out
264    }
265
266    /// Render the plan as a DOT (Graphviz) digraph.
267    #[must_use]
268    pub fn to_dot(&self) -> String {
269        let mut out = String::from("digraph plan {\n    rankdir=BT;\n");
270        for (i, step) in self.steps.iter().enumerate() {
271            let id = StepId(i);
272            let label = step_label(step);
273            out.push_str(&format!("    {id} [label=\"{label}\"];\n"));
274            for dep in step.dependencies() {
275                out.push_str(&format!("    {dep} -> {id};\n"));
276            }
277        }
278        out.push_str("}\n");
279        out
280    }
281}
282
283impl fmt::Display for Plan {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        for (i, step) in self.steps.iter().enumerate() {
286            let id = StepId(i);
287            let root_marker = if id == self.root { " (root)" } else { "" };
288            writeln!(f, "{id}{root_marker}: {}", step_label(step))?;
289            for dep in step.dependencies() {
290                writeln!(f, "  <- {dep}")?;
291            }
292        }
293        Ok(())
294    }
295}
296
297/// Produce a concise label for visualization.
298fn step_label(step: &Step) -> String {
299    match step {
300        Step::Scan {
301            table,
302            alias,
303            predicate,
304            ..
305        } => {
306            let name = alias.as_deref().unwrap_or(table.as_str());
307            if predicate.is_some() {
308                format!("Scan({name} + filter)")
309            } else {
310                format!("Scan({name})")
311            }
312        }
313        Step::Filter { .. } => "Filter".to_string(),
314        Step::Project { projections, .. } => {
315            let cols: Vec<_> = projections
316                .iter()
317                .map(|p| {
318                    p.alias
319                        .as_deref()
320                        .unwrap_or_else(|| expr_short_name(&p.expr))
321                })
322                .collect();
323            if cols.len() <= 4 {
324                format!("Project({})", cols.join(", "))
325            } else {
326                format!("Project({} cols)", cols.len())
327            }
328        }
329        Step::Aggregate { group_by, .. } => {
330            if group_by.is_empty() {
331                "Aggregate(scalar)".to_string()
332            } else {
333                format!("Aggregate({} keys)", group_by.len())
334            }
335        }
336        Step::Sort { order_by, .. } => format!("Sort({} keys)", order_by.len()),
337        Step::Join { join_type, .. } => format!("Join({join_type:?})"),
338        Step::Limit { limit, offset, .. } => {
339            let mut parts = Vec::new();
340            if limit.is_some() {
341                parts.push("limit");
342            }
343            if offset.is_some() {
344                parts.push("offset");
345            }
346            format!("Limit({})", parts.join("+"))
347        }
348        Step::SetOperation { op, all, .. } => {
349            let all_str = if *all { " ALL" } else { "" };
350            format!("{op:?}{all_str}")
351        }
352        Step::Distinct { .. } => "Distinct".to_string(),
353    }
354}
355
356/// Short name for an expression (used in labels).
357fn expr_short_name(expr: &Expr) -> &str {
358    match expr {
359        Expr::Column { name, .. } => name.as_str(),
360        Expr::Wildcard => "*",
361        _ => "expr",
362    }
363}
364
365// ═══════════════════════════════════════════════════════════════════════
366// Plan builder
367// ═══════════════════════════════════════════════════════════════════════
368
369/// Build a logical execution plan from a parsed SQL statement.
370///
371/// The statement should ideally be optimized first (via
372/// [`crate::optimizer::optimize`]) for the best plan quality, but this
373/// is not required.
374///
375/// # Errors
376///
377/// Returns [`SqlglotError`] when the statement cannot be planned (e.g.,
378/// DDL statements, unsupported constructs).
379pub fn plan(statement: &Statement) -> Result<Plan> {
380    let mut builder = PlanBuilder::new();
381    let _root = builder.plan_statement(statement)?;
382    Ok(builder.build())
383}
384
/// Internal builder that accumulates steps.
///
/// Steps are only ever appended, so each new step can reference the steps
/// added before it by their IDs.
struct PlanBuilder {
    /// Steps added so far, in insertion order.
    steps: Vec<Step>,
}
389
390impl PlanBuilder {
391    fn new() -> Self {
392        Self { steps: Vec::new() }
393    }
394
395    fn add_step(&mut self, step: Step) -> StepId {
396        let id = StepId(self.steps.len());
397        self.steps.push(step);
398        id
399    }
400
401    fn build(self) -> Plan {
402        let root = if self.steps.is_empty() {
403            StepId(0)
404        } else {
405            StepId(self.steps.len() - 1)
406        };
407        Plan {
408            steps: self.steps,
409            root,
410        }
411    }
412
413    // ───────────────────────────────────────────────────────────────
414    // Statement dispatch
415    // ───────────────────────────────────────────────────────────────
416
417    fn plan_statement(&mut self, stmt: &Statement) -> Result<StepId> {
418        match stmt {
419            Statement::Select(sel) => self.plan_select(sel),
420            Statement::SetOperation(set_op) => self.plan_set_operation(set_op),
421            _ => Err(SqlglotError::Internal(format!(
422                "Planner does not support {:?} statements",
423                std::mem::discriminant(stmt)
424            ))),
425        }
426    }
427
428    // ───────────────────────────────────────────────────────────────
429    // SELECT
430    // ───────────────────────────────────────────────────────────────
431
432    fn plan_select(&mut self, sel: &SelectStatement) -> Result<StepId> {
433        // 1. Resolve FROM source(s)
434        let mut current = if let Some(from) = &sel.from {
435            self.plan_table_source(&from.source)?
436        } else {
437            // No FROM — single-row virtual scan (e.g., SELECT 1+2)
438            self.add_step(Step::Scan {
439                table: String::new(),
440                alias: None,
441                projections: vec![],
442                predicate: None,
443                dependencies: vec![],
444            })
445        };
446
447        // 2. JOINs
448        for join in &sel.joins {
449            let right = self.plan_table_source(&join.table)?;
450            let projections = vec![]; // pass-through
451            current = self.add_step(Step::Join {
452                join_type: join.join_type.clone(),
453                condition: join.on.clone(),
454                using_columns: join.using.clone(),
455                projections,
456                dependencies: vec![current, right],
457            });
458        }
459
460        // 3. WHERE
461        if let Some(pred) = &sel.where_clause {
462            current = self.add_step(Step::Filter {
463                predicate: pred.clone(),
464                projections: vec![],
465                dependencies: vec![current],
466            });
467        }
468
469        // 4. GROUP BY / Aggregation
470        if !sel.group_by.is_empty() || has_aggregates(&sel.columns) {
471            let aggregations = extract_aggregates(&sel.columns);
472            current = self.add_step(Step::Aggregate {
473                group_by: sel.group_by.clone(),
474                aggregations,
475                projections: vec![],
476                dependencies: vec![current],
477            });
478        }
479
480        // 5. HAVING
481        if let Some(having) = &sel.having {
482            current = self.add_step(Step::Filter {
483                predicate: having.clone(),
484                projections: vec![],
485                dependencies: vec![current],
486            });
487        }
488
489        // 6. DISTINCT
490        if sel.distinct {
491            current = self.add_step(Step::Distinct {
492                projections: vec![],
493                dependencies: vec![current],
494            });
495        }
496
497        // 7. ORDER BY
498        if !sel.order_by.is_empty() {
499            current = self.add_step(Step::Sort {
500                order_by: sel.order_by.clone(),
501                projections: vec![],
502                dependencies: vec![current],
503            });
504        }
505
506        // 8. LIMIT / OFFSET
507        if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() {
508            let limit = sel.limit.clone().or_else(|| sel.fetch_first.clone());
509            current = self.add_step(Step::Limit {
510                limit,
511                offset: sel.offset.clone(),
512                projections: vec![],
513                dependencies: vec![current],
514            });
515        }
516
517        // 9. Project (SELECT columns)
518        let projections = select_items_to_projections(&sel.columns);
519        if !projections.is_empty() {
520            current = self.add_step(Step::Project {
521                projections,
522                dependencies: vec![current],
523            });
524        }
525
526        Ok(current)
527    }
528
529    // ───────────────────────────────────────────────────────────────
530    // Table sources
531    // ───────────────────────────────────────────────────────────────
532
533    fn plan_table_source(&mut self, source: &TableSource) -> Result<StepId> {
534        match source {
535            TableSource::Table(tref) => {
536                let table = fully_qualified_name(tref);
537                Ok(self.add_step(Step::Scan {
538                    table,
539                    alias: tref.alias.clone(),
540                    projections: vec![],
541                    predicate: None,
542                    dependencies: vec![],
543                }))
544            }
545            TableSource::Subquery { query, alias: _ } => self.plan_statement(query),
546            TableSource::Lateral { source } => self.plan_table_source(source),
547            TableSource::TableFunction { name, args, alias } => Ok(self.add_step(Step::Scan {
548                table: name.clone(),
549                alias: alias.clone(),
550                projections: args
551                    .iter()
552                    .map(|a| Projection {
553                        expr: a.clone(),
554                        alias: None,
555                    })
556                    .collect(),
557                predicate: None,
558                dependencies: vec![],
559            })),
560            TableSource::Unnest { expr, alias, .. } => Ok(self.add_step(Step::Scan {
561                table: "UNNEST".to_string(),
562                alias: alias.clone(),
563                projections: vec![Projection {
564                    expr: *expr.clone(),
565                    alias: None,
566                }],
567                predicate: None,
568                dependencies: vec![],
569            })),
570            TableSource::Pivot { source, alias, .. }
571            | TableSource::Unpivot { source, alias, .. } => {
572                // Plan the underlying source; the pivot/unpivot is treated
573                // as a virtual scan wrapping it.
574                let inner = self.plan_table_source(source)?;
575                // For simplicity, wrap pivot/unpivot into a project.
576                Ok(self.add_step(Step::Project {
577                    projections: vec![Projection {
578                        expr: Expr::Wildcard,
579                        alias: alias.clone(),
580                    }],
581                    dependencies: vec![inner],
582                }))
583            }
584        }
585    }
586
587    // ───────────────────────────────────────────────────────────────
588    // Set operations
589    // ───────────────────────────────────────────────────────────────
590
591    fn plan_set_operation(&mut self, set_op: &SetOperationStatement) -> Result<StepId> {
592        let left = self.plan_statement(&set_op.left)?;
593        let right = self.plan_statement(&set_op.right)?;
594
595        let mut current = self.add_step(Step::SetOperation {
596            op: set_op.op.clone(),
597            all: set_op.all,
598            projections: vec![],
599            dependencies: vec![left, right],
600        });
601
602        if !set_op.order_by.is_empty() {
603            current = self.add_step(Step::Sort {
604                order_by: set_op.order_by.clone(),
605                projections: vec![],
606                dependencies: vec![current],
607            });
608        }
609
610        if set_op.limit.is_some() || set_op.offset.is_some() {
611            current = self.add_step(Step::Limit {
612                limit: set_op.limit.clone(),
613                offset: set_op.offset.clone(),
614                projections: vec![],
615                dependencies: vec![current],
616            });
617        }
618
619        Ok(current)
620    }
621}
622
623// ═══════════════════════════════════════════════════════════════════════
624// Helpers
625// ═══════════════════════════════════════════════════════════════════════
626
627/// Build a fully qualified table name from a [`TableRef`].
628fn fully_qualified_name(tref: &TableRef) -> String {
629    let mut parts = Vec::new();
630    if let Some(catalog) = &tref.catalog {
631        parts.push(catalog.as_str());
632    }
633    if let Some(schema) = &tref.schema {
634        parts.push(schema.as_str());
635    }
636    parts.push(tref.name.as_str());
637    parts.join(".")
638}
639
640/// Convert SELECT items to projections.
641fn select_items_to_projections(items: &[SelectItem]) -> Vec<Projection> {
642    items
643        .iter()
644        .map(|item| match item {
645            SelectItem::Wildcard => Projection {
646                expr: Expr::Wildcard,
647                alias: None,
648            },
649            SelectItem::QualifiedWildcard { table } => Projection {
650                expr: Expr::QualifiedWildcard {
651                    table: table.clone(),
652                },
653                alias: None,
654            },
655            SelectItem::Expr { expr, alias } => Projection {
656                expr: expr.clone(),
657                alias: alias.clone(),
658            },
659        })
660        .collect()
661}
662
663/// Check whether any SELECT items contain aggregate functions.
664fn has_aggregates(items: &[SelectItem]) -> bool {
665    items.iter().any(|item| match item {
666        SelectItem::Expr { expr, .. } => expr_has_aggregate(expr),
667        _ => false,
668    })
669}
670
671/// Recursively check whether an expression contains an aggregate function.
672fn expr_has_aggregate(expr: &Expr) -> bool {
673    match expr {
674        Expr::Function { name, .. } => is_aggregate_name(name),
675        Expr::TypedFunction { func, .. } => typed_function_is_aggregate(func),
676        Expr::BinaryOp { left, right, .. } => expr_has_aggregate(left) || expr_has_aggregate(right),
677        Expr::UnaryOp { expr, .. } => expr_has_aggregate(expr),
678        Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr_has_aggregate(expr),
679        Expr::Case {
680            operand,
681            when_clauses,
682            else_clause,
683        } => {
684            operand.as_ref().is_some_and(|e| expr_has_aggregate(e))
685                || when_clauses
686                    .iter()
687                    .any(|(cond, result)| expr_has_aggregate(cond) || expr_has_aggregate(result))
688                || else_clause.as_ref().is_some_and(|e| expr_has_aggregate(e))
689        }
690        Expr::Alias { expr, .. } => expr_has_aggregate(expr),
691        _ => false,
692    }
693}
694
/// Returns `true` when `name` is one of the well-known aggregate function
/// names.
///
/// The comparison ignores ASCII case only. Every known name is plain ASCII,
/// so this avoids allocating an upper-cased copy on each call and prevents
/// non-ASCII look-alikes (for example a dotless `ı`, which Unicode
/// upper-cases to `I`) from being mistaken for an aggregate.
fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATE_NAMES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "COLLECT_LIST",
        "COLLECT_SET",
        "ANY_VALUE",
        "APPROX_COUNT_DISTINCT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "FIRST_VALUE",
        "LAST_VALUE",
        "NTH_VALUE",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "EVERY",
    ];

    AGGREGATE_NAMES
        .iter()
        .any(|known| known.eq_ignore_ascii_case(name))
}
734
735/// Check whether a TypedFunction variant is an aggregate.
736fn typed_function_is_aggregate(func: &TypedFunction) -> bool {
737    matches!(
738        func,
739        TypedFunction::Count { .. }
740            | TypedFunction::Sum { .. }
741            | TypedFunction::Avg { .. }
742            | TypedFunction::Min { .. }
743            | TypedFunction::Max { .. }
744            | TypedFunction::ArrayAgg { .. }
745            | TypedFunction::ApproxDistinct { .. }
746            | TypedFunction::Variance { .. }
747            | TypedFunction::Stddev { .. }
748    )
749}
750
751/// Extract aggregation projections from SELECT items.
752fn extract_aggregates(items: &[SelectItem]) -> Vec<Projection> {
753    let mut aggs = Vec::new();
754    for item in items {
755        if let SelectItem::Expr { expr, alias } = item {
756            collect_aggregates(expr, alias, &mut aggs);
757        }
758    }
759    aggs
760}
761
762fn collect_aggregates(expr: &Expr, alias: &Option<String>, out: &mut Vec<Projection>) {
763    match expr {
764        Expr::Function { name, .. } if is_aggregate_name(name) => {
765            out.push(Projection {
766                expr: expr.clone(),
767                alias: alias.clone(),
768            });
769        }
770        Expr::TypedFunction { func, .. } if typed_function_is_aggregate(func) => {
771            out.push(Projection {
772                expr: expr.clone(),
773                alias: alias.clone(),
774            });
775        }
776        Expr::BinaryOp { left, right, .. } => {
777            collect_aggregates(left, &None, out);
778            collect_aggregates(right, &None, out);
779        }
780        Expr::Alias { expr: inner, name } => {
781            collect_aggregates(inner, &Some(name.clone()), out);
782        }
783        _ => {}
784    }
785}
786
787// ═══════════════════════════════════════════════════════════════════════
788// Tests
789// ═══════════════════════════════════════════════════════════════════════
790
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` with the ANSI dialect and plans it, panicking if either
    /// stage fails.
    fn plan_sql(sql: &str) -> Plan {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        plan(&ast).unwrap()
    }

    /// Step kinds of the plan for `sql`, in plan order.
    fn kinds_of(sql: &str) -> Vec<&'static str> {
        plan_sql(sql).steps().iter().map(Step::kind).collect()
    }

    /// Join type of the first Join step in the plan for `sql`.
    fn first_join_type(sql: &str) -> JoinType {
        plan_sql(sql)
            .steps()
            .iter()
            .find_map(|step| match step {
                Step::Join { join_type, .. } => Some(join_type.clone()),
                _ => None,
            })
            .expect("expected Join step")
    }

    #[test]
    fn test_simple_select() {
        let p = plan_sql("SELECT a, b FROM t");
        assert!(p.len() >= 2); // Scan + Project
        assert_eq!(p.get(p.root()).unwrap().kind(), "Project");
    }

    #[test]
    fn test_select_with_where() {
        // Scan -> Filter -> Project
        let kinds = kinds_of("SELECT a FROM t WHERE a > 1");
        for expected in ["Scan", "Filter", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_select_with_order_by() {
        let kinds = kinds_of("SELECT a FROM t ORDER BY a");
        assert!(kinds.contains(&"Sort"));
    }

    #[test]
    fn test_select_with_group_by() {
        let kinds = kinds_of("SELECT a, COUNT(*) FROM t GROUP BY a");
        assert!(kinds.contains(&"Aggregate"));
    }

    #[test]
    fn test_select_with_having() {
        let kinds = kinds_of("SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1");
        // Should have Aggregate and a Filter for HAVING
        assert!(kinds.contains(&"Aggregate"));
        assert!(kinds.contains(&"Filter"));
    }

    #[test]
    fn test_join() {
        let kinds = kinds_of("SELECT a.x FROM a JOIN b ON a.id = b.id");
        assert!(kinds.contains(&"Join"));
    }

    #[test]
    fn test_multiple_joins() {
        let kinds = kinds_of("SELECT a.x FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id");
        let join_count = kinds.iter().filter(|kind| **kind == "Join").count();
        assert_eq!(join_count, 2);
    }

    #[test]
    fn test_union() {
        let kinds = kinds_of("SELECT a FROM t1 UNION ALL SELECT b FROM t2");
        assert!(kinds.contains(&"SetOperation"));
    }

    #[test]
    fn test_limit_offset() {
        let kinds = kinds_of("SELECT a FROM t LIMIT 10 OFFSET 5");
        assert!(kinds.contains(&"Limit"));
    }

    #[test]
    fn test_distinct() {
        let kinds = kinds_of("SELECT DISTINCT a FROM t");
        assert!(kinds.contains(&"Distinct"));
    }

    #[test]
    fn test_subquery_in_from() {
        let p = plan_sql("SELECT x FROM (SELECT a AS x FROM t) sub");
        // Inner scan + inner project + outer project
        assert!(p.len() >= 3);
    }

    #[test]
    fn test_complex_query() {
        let kinds = kinds_of(
            "SELECT a, SUM(b) AS total FROM t WHERE c > 0 GROUP BY a HAVING SUM(b) > 10 ORDER BY total DESC LIMIT 5",
        );
        // "Filter" covers both WHERE and HAVING.
        for expected in ["Scan", "Filter", "Aggregate", "Sort", "Limit", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_dag_dependencies() {
        let p = plan_sql("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id");
        // Every step's dependencies should reference valid earlier steps
        for (i, step) in p.steps().iter().enumerate() {
            for dep in step.dependencies() {
                assert!(dep.0 < i, "step {i} depends on {dep} which is not earlier");
            }
        }
    }

    #[test]
    fn test_mermaid_output() {
        let mermaid = plan_sql("SELECT a FROM t WHERE a > 1").to_mermaid();
        assert!(mermaid.starts_with("graph TD"));
        assert!(mermaid.contains("Scan"));
    }

    #[test]
    fn test_dot_output() {
        let dot = plan_sql("SELECT a FROM t WHERE a > 1").to_dot();
        assert!(dot.starts_with("digraph plan"));
        assert!(dot.contains("Scan"));
    }

    #[test]
    fn test_display() {
        let display = plan_sql("SELECT a FROM t").to_string();
        assert!(display.contains("(root)"));
    }

    #[test]
    fn test_ddl_rejected() {
        let ast = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
        assert!(plan(&ast).is_err());
    }

    #[test]
    fn test_no_from_select() {
        let p = plan_sql("SELECT 1 + 2");
        assert!(!p.is_empty());
    }

    #[test]
    fn test_left_join() {
        let join_type = first_join_type("SELECT a.x FROM a LEFT JOIN b ON a.id = b.id");
        assert_eq!(join_type, JoinType::Left);
    }

    #[test]
    fn test_cross_join() {
        let join_type = first_join_type("SELECT a.x FROM a CROSS JOIN b");
        assert_eq!(join_type, JoinType::Cross);
    }

    #[test]
    fn test_union_with_order_limit() {
        let kinds = kinds_of("SELECT a FROM t1 UNION SELECT b FROM t2 ORDER BY 1 LIMIT 10");
        for expected in ["SetOperation", "Sort", "Limit"] {
            assert!(kinds.contains(&expected));
        }
    }
}
1006}