Skip to main content

sochdb_query/executor/
planner.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2
3//! Query planner: SQL AST → Volcano operator tree.
4//!
5//! Converts a parsed SQL `SelectStmt` into a tree of physical operators
6//! that can be executed row-at-a-time via the Volcano model.
7//!
8//! ```text
9//! SelectStmt
10//!   → FROM → SeqScan / Join tree
11//!   → WHERE → Filter
12//!   → GROUP BY + aggregates → HashAggregate
13//!   → HAVING → Filter
14//!   → SELECT → Project
15//!   → ORDER BY → Sort
16//!   → LIMIT/OFFSET → Limit
17//! ```
18
19use crate::optimizer_integration::StorageBackend;
20use crate::sql::ast::*;
21use super::aggregate::{AggDef, AggFunc, HashAggregateNode};
22use super::filter::FilterNode;
23use super::join::{HashJoinNode, NestedLoopJoinNode};
24use super::limit::LimitNode;
25use super::node::PlanNode;
26use super::project::{ProjectExpr, ProjectNode};
27use super::scan::{EmptyNode, SeqScanNode};
28use super::sort::{SortKey, SortNode};
29use super::types::Schema;
30use sochdb_core::Result;
31use std::sync::Arc;
32
/// Query planner that converts a parsed SQL AST into a Volcano operator tree.
///
/// The planner itself holds no per-query state; it only keeps a shared handle
/// to the storage backend, which is cloned into every sequential scan node
/// it creates.
pub struct QueryPlanner {
    /// Shared storage backend handed to the scan nodes built by this planner.
    storage: Arc<dyn StorageBackend>,
}
37
38impl QueryPlanner {
39    pub fn new(storage: Arc<dyn StorageBackend>) -> Self {
40        Self { storage }
41    }
42
43    /// Plan a SELECT statement into an operator tree.
44    pub fn plan_select(&self, select: &SelectStmt) -> Result<Box<dyn PlanNode>> {
45        // 1. FROM clause → base scan / join tree
46        let mut node = self.plan_from(&select.from)?;
47
48        // 2. WHERE clause → Filter
49        if let Some(where_expr) = &select.where_clause {
50            node = Box::new(FilterNode::new(node, where_expr.clone()));
51        }
52
53        // 3. Detect aggregates in SELECT list
54        let has_aggregates = self.has_aggregate_in_select(&select.columns);
55        let has_group_by = !select.group_by.is_empty();
56
57        if has_aggregates || has_group_by {
58            // GROUP BY + aggregates → HashAggregate
59            let (agg_defs, group_by_exprs) =
60                self.extract_aggregates(&select.columns, &select.group_by)?;
61            node = Box::new(HashAggregateNode::new(node, group_by_exprs, agg_defs));
62
63            // HAVING → Filter (operates on aggregate output)
64            if let Some(having) = &select.having {
65                node = Box::new(FilterNode::new(node, having.clone()));
66            }
67        } else {
68            // 4. SELECT → Project (non-aggregate case)
69            let needs_projection = !self.is_wildcard_only(&select.columns);
70            if needs_projection {
71                let exprs = self.plan_select_exprs(&select.columns, node.schema())?;
72                if !exprs.is_empty() {
73                    node = Box::new(ProjectNode::new(node, exprs));
74                }
75            }
76        }
77
78        // 5. DISTINCT — implement as a sort + dedup or hash-based
79        // (simplified: not yet implemented, would need a DistinctNode)
80
81        // 6. ORDER BY → Sort
82        if !select.order_by.is_empty() {
83            let sort_keys = self.plan_order_by(&select.order_by)?;
84            node = Box::new(SortNode::new(node, sort_keys));
85        }
86
87        // 7. LIMIT / OFFSET → Limit
88        let limit = self.extract_usize(&select.limit)?;
89        let offset = self.extract_usize(&select.offset)?.unwrap_or(0);
90        if limit.is_some() || offset > 0 {
91            node = Box::new(LimitNode::new(node, limit, offset));
92        }
93
94        Ok(node)
95    }
96
97    // ========================================================================
98    // FROM clause planning
99    // ========================================================================
100
101    fn plan_from(&self, from: &Option<FromClause>) -> Result<Box<dyn PlanNode>> {
102        let from = match from {
103            Some(f) => f,
104            None => {
105                // No FROM: return a single empty row (for SELECT 1+1, etc.)
106                return Ok(Box::new(super::scan::ValuesNode::new(
107                    Schema::empty(),
108                    vec![vec![]],
109                )));
110            }
111        };
112
113        if from.tables.is_empty() {
114            return Ok(Box::new(EmptyNode::new(Schema::empty())));
115        }
116
117        // Plan first table
118        let mut node = self.plan_table_ref(&from.tables[0])?;
119
120        // Implicit cross join for multiple tables in FROM
121        for table_ref in from.tables.iter().skip(1) {
122            let right = self.plan_table_ref(table_ref)?;
123            node = Box::new(NestedLoopJoinNode::new(
124                node,
125                right,
126                None, // CROSS JOIN
127                JoinType::Cross,
128            ));
129        }
130
131        Ok(node)
132    }
133
134    fn plan_table_ref(&self, table_ref: &TableRef) -> Result<Box<dyn PlanNode>> {
135        match table_ref {
136            TableRef::Table { name, alias } => {
137                let table_name = name.name().to_string();
138                // Start with wildcard scan; projection will be added later
139                Ok(Box::new(SeqScanNode::new(
140                    self.storage.clone(),
141                    table_name,
142                    vec!["*".to_string()],
143                    alias.as_deref(),
144                )))
145            }
146
147            TableRef::Join {
148                left,
149                join_type,
150                right,
151                condition,
152            } => self.plan_join(left, *join_type, right, condition),
153
154            TableRef::Subquery { query, alias: _ } => self.plan_select(query),
155
156            TableRef::Function { .. } => Err(sochdb_core::SochDBError::Internal(
157                "Table-valued functions not yet supported".into(),
158            )),
159        }
160    }
161
162    fn plan_join(
163        &self,
164        left_ref: &TableRef,
165        join_type: JoinType,
166        right_ref: &TableRef,
167        condition: &Option<JoinCondition>,
168    ) -> Result<Box<dyn PlanNode>> {
169        let left = self.plan_table_ref(left_ref)?;
170        let right = self.plan_table_ref(right_ref)?;
171
172        match condition {
173            Some(JoinCondition::On(expr)) => {
174                // Try to detect equi-join for HashJoin optimization
175                if let Some((left_key, right_key)) = self.extract_equi_keys(expr) {
176                    Ok(Box::new(HashJoinNode::new(
177                        left, right, left_key, right_key, join_type,
178                    )))
179                } else {
180                    // Theta join — use nested loop
181                    Ok(Box::new(NestedLoopJoinNode::new(
182                        left,
183                        right,
184                        Some(expr.clone()),
185                        join_type,
186                    )))
187                }
188            }
189            Some(JoinCondition::Using(columns)) => {
190                // USING(col) → equi-join on col = col
191                if let Some(col) = columns.first() {
192                    let left_key = Expr::Column(ColumnRef::new(col.clone()));
193                    let right_key = Expr::Column(ColumnRef::new(col.clone()));
194                    Ok(Box::new(HashJoinNode::new(
195                        left, right, left_key, right_key, join_type,
196                    )))
197                } else {
198                    Ok(Box::new(NestedLoopJoinNode::new(
199                        left, right, None, JoinType::Cross,
200                    )))
201                }
202            }
203            Some(JoinCondition::Natural) | None => {
204                if join_type == JoinType::Cross {
205                    Ok(Box::new(NestedLoopJoinNode::new(
206                        left, right, None, JoinType::Cross,
207                    )))
208                } else {
209                    // Natural join — would need schema introspection to find common columns
210                    // For now, fall back to cross join
211                    Ok(Box::new(NestedLoopJoinNode::new(
212                        left, right, None, JoinType::Cross,
213                    )))
214                }
215            }
216        }
217    }
218
219    /// Try to extract equi-join keys from an ON expression.
220    /// Returns (left_key_expr, right_key_expr) if the expression is `a.x = b.y`.
221    fn extract_equi_keys(&self, expr: &Expr) -> Option<(Expr, Expr)> {
222        match expr {
223            Expr::BinaryOp {
224                left,
225                op: BinaryOperator::Eq,
226                right,
227            } => Some((*left.clone(), *right.clone())),
228            _ => None,
229        }
230    }
231
232    // ========================================================================
233    // SELECT list / Projection
234    // ========================================================================
235
236    fn is_wildcard_only(&self, items: &[SelectItem]) -> bool {
237        items.len() == 1 && matches!(&items[0], SelectItem::Wildcard)
238    }
239
240    fn plan_select_exprs(
241        &self,
242        items: &[SelectItem],
243        _input_schema: &Schema,
244    ) -> Result<Vec<ProjectExpr>> {
245        let mut exprs = Vec::new();
246
247        for item in items {
248            match item {
249                SelectItem::Wildcard => {
250                    // Wildcard — pass-through handled separately
251                    return Ok(vec![]);
252                }
253                SelectItem::QualifiedWildcard(_table) => {
254                    // table.* — would need schema lookup
255                    return Ok(vec![]);
256                }
257                SelectItem::Expr { expr, alias } => {
258                    let name = alias.clone().unwrap_or_else(|| match expr {
259                        Expr::Column(col) => col.column.clone(),
260                        Expr::Function(func) => {
261                            let args_str = if func.args.is_empty() {
262                                "*".to_string()
263                            } else {
264                                "...".to_string()
265                            };
266                            format!("{}({})", func.name.name(), args_str)
267                        }
268                        _ => "?column?".to_string(),
269                    });
270                    exprs.push(ProjectExpr {
271                        expr: expr.clone(),
272                        alias: name,
273                    });
274                }
275            }
276        }
277
278        Ok(exprs)
279    }
280
281    // ========================================================================
282    // Aggregate detection and extraction
283    // ========================================================================
284
285    fn has_aggregate_in_select(&self, items: &[SelectItem]) -> bool {
286        for item in items {
287            if let SelectItem::Expr { expr, .. } = item {
288                if self.expr_has_aggregate(expr) {
289                    return true;
290                }
291            }
292        }
293        false
294    }
295
296    fn expr_has_aggregate(&self, expr: &Expr) -> bool {
297        match expr {
298            Expr::Function(func) => {
299                let name = func.name.name().to_uppercase();
300                matches!(
301                    name.as_str(),
302                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX"
303                )
304            }
305            Expr::BinaryOp { left, right, .. } => {
306                self.expr_has_aggregate(left) || self.expr_has_aggregate(right)
307            }
308            Expr::UnaryOp { expr, .. } => self.expr_has_aggregate(expr),
309            _ => false,
310        }
311    }
312
313    fn extract_aggregates(
314        &self,
315        items: &[SelectItem],
316        group_by: &[Expr],
317    ) -> Result<(Vec<AggDef>, Vec<Expr>)> {
318        let mut agg_defs = Vec::new();
319
320        for item in items {
321            if let SelectItem::Expr { expr, alias } = item {
322                if let Some(agg_def) = self.try_extract_agg(expr, alias)? {
323                    agg_defs.push(agg_def);
324                }
325                // Group-by columns are handled automatically by HashAggregateNode
326            }
327        }
328
329        Ok((agg_defs, group_by.to_vec()))
330    }
331
332    fn try_extract_agg(
333        &self,
334        expr: &Expr,
335        alias: &Option<String>,
336    ) -> Result<Option<AggDef>> {
337        match expr {
338            Expr::Function(func) => {
339                let name = func.name.name().to_uppercase();
340                let func_type = match name.as_str() {
341                    "COUNT" => {
342                        if func.distinct {
343                            Some(AggFunc::CountDistinct)
344                        } else {
345                            Some(AggFunc::Count)
346                        }
347                    }
348                    "SUM" => Some(AggFunc::Sum),
349                    "AVG" => Some(AggFunc::Avg),
350                    "MIN" => Some(AggFunc::Min),
351                    "MAX" => Some(AggFunc::Max),
352                    _ => None,
353                };
354
355                if let Some(func_type) = func_type {
356                    let agg_expr = if func.args.is_empty()
357                        || (func.args.len() == 1
358                            && matches!(&func.args[0], Expr::Column(c) if c.column == "*"))
359                    {
360                        None // COUNT(*)
361                    } else {
362                        Some(func.args[0].clone())
363                    };
364
365                    let output_name = alias.clone().unwrap_or_else(|| {
366                        let args_str = if func.args.is_empty() {
367                            "*".to_string()
368                        } else {
369                            match &func.args[0] {
370                                Expr::Column(c) => c.column.clone(),
371                                _ => "expr".to_string(),
372                            }
373                        };
374                        format!("{}({})", name.to_lowercase(), args_str)
375                    });
376
377                    Ok(Some(AggDef {
378                        func: func_type,
379                        expr: agg_expr,
380                        alias: output_name,
381                    }))
382                } else {
383                    Ok(None)
384                }
385            }
386            _ => Ok(None),
387        }
388    }
389
390    // ========================================================================
391    // ORDER BY
392    // ========================================================================
393
394    fn plan_order_by(&self, items: &[OrderByItem]) -> Result<Vec<SortKey>> {
395        Ok(items
396            .iter()
397            .map(|item| SortKey {
398                expr: item.expr.clone(),
399                ascending: item.asc,
400                nulls_first: item.nulls_first.unwrap_or(!item.asc),
401            })
402            .collect())
403    }
404
405    // ========================================================================
406    // Utilities
407    // ========================================================================
408
409    fn extract_usize(&self, expr: &Option<Expr>) -> Result<Option<usize>> {
410        match expr {
411            Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
412            Some(_) => Err(sochdb_core::SochDBError::Internal(
413                "LIMIT/OFFSET must be an integer literal".into(),
414            )),
415            None => Ok(None),
416        }
417    }
418}
419
420/// Generate a textual EXPLAIN representation for a SELECT statement.
421pub fn explain_select(select: &SelectStmt, _storage: &Arc<dyn StorageBackend>) -> String {
422    let mut lines = Vec::new();
423
424    // Simplified EXPLAIN output
425    if let Some(from) = &select.from {
426        for table_ref in &from.tables {
427            explain_table_ref(table_ref, &mut lines, 0);
428        }
429    }
430
431    if select.where_clause.is_some() {
432        lines.push("  Filter (WHERE)".to_string());
433    }
434
435    if !select.group_by.is_empty() {
436        let cols: Vec<String> = select.group_by.iter().map(|e| format!("{:?}", e)).collect();
437        lines.push(format!("  HashAggregate [group_by={}]", cols.join(", ")));
438    }
439
440    if select.having.is_some() {
441        lines.push("  Filter (HAVING)".to_string());
442    }
443
444    // Check for aggregates in SELECT
445    let has_agg = select.columns.iter().any(|item| {
446        if let SelectItem::Expr { expr, .. } = item {
447            matches!(expr, Expr::Function(f) if {
448                let n = f.name.name().to_uppercase();
449                matches!(n.as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
450            })
451        } else {
452            false
453        }
454    });
455    if has_agg && select.group_by.is_empty() {
456        lines.push("  HashAggregate [global]".to_string());
457    }
458
459    let col_names: Vec<String> = select.columns.iter().map(|item| {
460        match item {
461            SelectItem::Wildcard => "*".to_string(),
462            SelectItem::QualifiedWildcard(t) => format!("{}.*", t),
463            SelectItem::Expr { expr, alias } => {
464                alias.clone().unwrap_or_else(|| format!("{:?}", expr))
465            }
466        }
467    }).collect();
468    lines.push(format!("  Project [{}]", col_names.join(", ")));
469
470    if !select.order_by.is_empty() {
471        let orders: Vec<String> = select.order_by.iter().map(|o| {
472            let dir = if o.asc { "ASC" } else { "DESC" };
473            format!("{:?} {}", o.expr, dir)
474        }).collect();
475        lines.push(format!("  Sort [{}]", orders.join(", ")));
476    }
477
478    if select.limit.is_some() || select.offset.is_some() {
479        lines.push(format!(
480            "  Limit [limit={:?}, offset={:?}]",
481            select.limit, select.offset
482        ));
483    }
484
485    lines.join("\n")
486}
487
488fn explain_table_ref(table_ref: &TableRef, lines: &mut Vec<String>, depth: usize) {
489    let indent = "  ".repeat(depth);
490    match table_ref {
491        TableRef::Table { name, alias } => {
492            let alias_str = alias.as_ref().map_or(String::new(), |a| format!(" AS {}", a));
493            lines.push(format!("{}SeqScan [table={}{}]", indent, name, alias_str));
494        }
495        TableRef::Join {
496            left,
497            join_type,
498            right,
499            condition,
500        } => {
501            let jt = match join_type {
502                JoinType::Inner => "INNER",
503                JoinType::Left => "LEFT",
504                JoinType::Right => "RIGHT",
505                JoinType::Full => "FULL",
506                JoinType::Cross => "CROSS",
507            };
508            let cond_str = match condition {
509                Some(JoinCondition::On(expr)) => format!(" ON {:?}", expr),
510                Some(JoinCondition::Using(cols)) => format!(" USING({})", cols.join(", ")),
511                Some(JoinCondition::Natural) => " NATURAL".to_string(),
512                None => String::new(),
513            };
514            lines.push(format!("{}{} JOIN{}", indent, jt, cond_str));
515            explain_table_ref(left, lines, depth + 1);
516            explain_table_ref(right, lines, depth + 1);
517        }
518        TableRef::Subquery { alias, .. } => {
519            lines.push(format!("{}Subquery [alias={}]", indent, alias));
520        }
521        TableRef::Function { name, .. } => {
522            lines.push(format!("{}Function [{}]", indent, name));
523        }
524    }
525}