1use crate::optimizer_integration::StorageBackend;
20use crate::sql::ast::*;
21use super::aggregate::{AggDef, AggFunc, HashAggregateNode};
22use super::filter::FilterNode;
23use super::join::{HashJoinNode, NestedLoopJoinNode};
24use super::limit::LimitNode;
25use super::node::PlanNode;
26use super::project::{ProjectExpr, ProjectNode};
27use super::scan::{EmptyNode, SeqScanNode};
28use super::sort::{SortKey, SortNode};
29use super::types::Schema;
30use sochdb_core::Result;
31use std::sync::Arc;
32
33pub struct QueryPlanner {
35 storage: Arc<dyn StorageBackend>,
36}
37
38impl QueryPlanner {
39 pub fn new(storage: Arc<dyn StorageBackend>) -> Self {
40 Self { storage }
41 }
42
43 pub fn plan_select(&self, select: &SelectStmt) -> Result<Box<dyn PlanNode>> {
45 let mut node = self.plan_from(&select.from)?;
47
48 if let Some(where_expr) = &select.where_clause {
50 node = Box::new(FilterNode::new(node, where_expr.clone()));
51 }
52
53 let has_aggregates = self.has_aggregate_in_select(&select.columns);
55 let has_group_by = !select.group_by.is_empty();
56
57 if has_aggregates || has_group_by {
58 let (agg_defs, group_by_exprs) =
60 self.extract_aggregates(&select.columns, &select.group_by)?;
61 node = Box::new(HashAggregateNode::new(node, group_by_exprs, agg_defs));
62
63 if let Some(having) = &select.having {
65 node = Box::new(FilterNode::new(node, having.clone()));
66 }
67 } else {
68 let needs_projection = !self.is_wildcard_only(&select.columns);
70 if needs_projection {
71 let exprs = self.plan_select_exprs(&select.columns, node.schema())?;
72 if !exprs.is_empty() {
73 node = Box::new(ProjectNode::new(node, exprs));
74 }
75 }
76 }
77
78 if !select.order_by.is_empty() {
83 let sort_keys = self.plan_order_by(&select.order_by)?;
84 node = Box::new(SortNode::new(node, sort_keys));
85 }
86
87 let limit = self.extract_usize(&select.limit)?;
89 let offset = self.extract_usize(&select.offset)?.unwrap_or(0);
90 if limit.is_some() || offset > 0 {
91 node = Box::new(LimitNode::new(node, limit, offset));
92 }
93
94 Ok(node)
95 }
96
97 fn plan_from(&self, from: &Option<FromClause>) -> Result<Box<dyn PlanNode>> {
102 let from = match from {
103 Some(f) => f,
104 None => {
105 return Ok(Box::new(super::scan::ValuesNode::new(
107 Schema::empty(),
108 vec![vec![]],
109 )));
110 }
111 };
112
113 if from.tables.is_empty() {
114 return Ok(Box::new(EmptyNode::new(Schema::empty())));
115 }
116
117 let mut node = self.plan_table_ref(&from.tables[0])?;
119
120 for table_ref in from.tables.iter().skip(1) {
122 let right = self.plan_table_ref(table_ref)?;
123 node = Box::new(NestedLoopJoinNode::new(
124 node,
125 right,
126 None, JoinType::Cross,
128 ));
129 }
130
131 Ok(node)
132 }
133
134 fn plan_table_ref(&self, table_ref: &TableRef) -> Result<Box<dyn PlanNode>> {
135 match table_ref {
136 TableRef::Table { name, alias } => {
137 let table_name = name.name().to_string();
138 Ok(Box::new(SeqScanNode::new(
140 self.storage.clone(),
141 table_name,
142 vec!["*".to_string()],
143 alias.as_deref(),
144 )))
145 }
146
147 TableRef::Join {
148 left,
149 join_type,
150 right,
151 condition,
152 } => self.plan_join(left, *join_type, right, condition),
153
154 TableRef::Subquery { query, alias: _ } => self.plan_select(query),
155
156 TableRef::Function { .. } => Err(sochdb_core::SochDBError::Internal(
157 "Table-valued functions not yet supported".into(),
158 )),
159 }
160 }
161
162 fn plan_join(
163 &self,
164 left_ref: &TableRef,
165 join_type: JoinType,
166 right_ref: &TableRef,
167 condition: &Option<JoinCondition>,
168 ) -> Result<Box<dyn PlanNode>> {
169 let left = self.plan_table_ref(left_ref)?;
170 let right = self.plan_table_ref(right_ref)?;
171
172 match condition {
173 Some(JoinCondition::On(expr)) => {
174 if let Some((left_key, right_key)) = self.extract_equi_keys(expr) {
176 Ok(Box::new(HashJoinNode::new(
177 left, right, left_key, right_key, join_type,
178 )))
179 } else {
180 Ok(Box::new(NestedLoopJoinNode::new(
182 left,
183 right,
184 Some(expr.clone()),
185 join_type,
186 )))
187 }
188 }
189 Some(JoinCondition::Using(columns)) => {
190 if let Some(col) = columns.first() {
192 let left_key = Expr::Column(ColumnRef::new(col.clone()));
193 let right_key = Expr::Column(ColumnRef::new(col.clone()));
194 Ok(Box::new(HashJoinNode::new(
195 left, right, left_key, right_key, join_type,
196 )))
197 } else {
198 Ok(Box::new(NestedLoopJoinNode::new(
199 left, right, None, JoinType::Cross,
200 )))
201 }
202 }
203 Some(JoinCondition::Natural) | None => {
204 if join_type == JoinType::Cross {
205 Ok(Box::new(NestedLoopJoinNode::new(
206 left, right, None, JoinType::Cross,
207 )))
208 } else {
209 Ok(Box::new(NestedLoopJoinNode::new(
212 left, right, None, JoinType::Cross,
213 )))
214 }
215 }
216 }
217 }
218
219 fn extract_equi_keys(&self, expr: &Expr) -> Option<(Expr, Expr)> {
222 match expr {
223 Expr::BinaryOp {
224 left,
225 op: BinaryOperator::Eq,
226 right,
227 } => Some((*left.clone(), *right.clone())),
228 _ => None,
229 }
230 }
231
232 fn is_wildcard_only(&self, items: &[SelectItem]) -> bool {
237 items.len() == 1 && matches!(&items[0], SelectItem::Wildcard)
238 }
239
240 fn plan_select_exprs(
241 &self,
242 items: &[SelectItem],
243 _input_schema: &Schema,
244 ) -> Result<Vec<ProjectExpr>> {
245 let mut exprs = Vec::new();
246
247 for item in items {
248 match item {
249 SelectItem::Wildcard => {
250 return Ok(vec![]);
252 }
253 SelectItem::QualifiedWildcard(_table) => {
254 return Ok(vec![]);
256 }
257 SelectItem::Expr { expr, alias } => {
258 let name = alias.clone().unwrap_or_else(|| match expr {
259 Expr::Column(col) => col.column.clone(),
260 Expr::Function(func) => {
261 let args_str = if func.args.is_empty() {
262 "*".to_string()
263 } else {
264 "...".to_string()
265 };
266 format!("{}({})", func.name.name(), args_str)
267 }
268 _ => "?column?".to_string(),
269 });
270 exprs.push(ProjectExpr {
271 expr: expr.clone(),
272 alias: name,
273 });
274 }
275 }
276 }
277
278 Ok(exprs)
279 }
280
281 fn has_aggregate_in_select(&self, items: &[SelectItem]) -> bool {
286 for item in items {
287 if let SelectItem::Expr { expr, .. } = item {
288 if self.expr_has_aggregate(expr) {
289 return true;
290 }
291 }
292 }
293 false
294 }
295
296 fn expr_has_aggregate(&self, expr: &Expr) -> bool {
297 match expr {
298 Expr::Function(func) => {
299 let name = func.name.name().to_uppercase();
300 matches!(
301 name.as_str(),
302 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX"
303 )
304 }
305 Expr::BinaryOp { left, right, .. } => {
306 self.expr_has_aggregate(left) || self.expr_has_aggregate(right)
307 }
308 Expr::UnaryOp { expr, .. } => self.expr_has_aggregate(expr),
309 _ => false,
310 }
311 }
312
313 fn extract_aggregates(
314 &self,
315 items: &[SelectItem],
316 group_by: &[Expr],
317 ) -> Result<(Vec<AggDef>, Vec<Expr>)> {
318 let mut agg_defs = Vec::new();
319
320 for item in items {
321 if let SelectItem::Expr { expr, alias } = item {
322 if let Some(agg_def) = self.try_extract_agg(expr, alias)? {
323 agg_defs.push(agg_def);
324 }
325 }
327 }
328
329 Ok((agg_defs, group_by.to_vec()))
330 }
331
332 fn try_extract_agg(
333 &self,
334 expr: &Expr,
335 alias: &Option<String>,
336 ) -> Result<Option<AggDef>> {
337 match expr {
338 Expr::Function(func) => {
339 let name = func.name.name().to_uppercase();
340 let func_type = match name.as_str() {
341 "COUNT" => {
342 if func.distinct {
343 Some(AggFunc::CountDistinct)
344 } else {
345 Some(AggFunc::Count)
346 }
347 }
348 "SUM" => Some(AggFunc::Sum),
349 "AVG" => Some(AggFunc::Avg),
350 "MIN" => Some(AggFunc::Min),
351 "MAX" => Some(AggFunc::Max),
352 _ => None,
353 };
354
355 if let Some(func_type) = func_type {
356 let agg_expr = if func.args.is_empty()
357 || (func.args.len() == 1
358 && matches!(&func.args[0], Expr::Column(c) if c.column == "*"))
359 {
360 None } else {
362 Some(func.args[0].clone())
363 };
364
365 let output_name = alias.clone().unwrap_or_else(|| {
366 let args_str = if func.args.is_empty() {
367 "*".to_string()
368 } else {
369 match &func.args[0] {
370 Expr::Column(c) => c.column.clone(),
371 _ => "expr".to_string(),
372 }
373 };
374 format!("{}({})", name.to_lowercase(), args_str)
375 });
376
377 Ok(Some(AggDef {
378 func: func_type,
379 expr: agg_expr,
380 alias: output_name,
381 }))
382 } else {
383 Ok(None)
384 }
385 }
386 _ => Ok(None),
387 }
388 }
389
390 fn plan_order_by(&self, items: &[OrderByItem]) -> Result<Vec<SortKey>> {
395 Ok(items
396 .iter()
397 .map(|item| SortKey {
398 expr: item.expr.clone(),
399 ascending: item.asc,
400 nulls_first: item.nulls_first.unwrap_or(!item.asc),
401 })
402 .collect())
403 }
404
405 fn extract_usize(&self, expr: &Option<Expr>) -> Result<Option<usize>> {
410 match expr {
411 Some(Expr::Literal(Literal::Integer(n))) => Ok(Some(*n as usize)),
412 Some(_) => Err(sochdb_core::SochDBError::Internal(
413 "LIMIT/OFFSET must be an integer literal".into(),
414 )),
415 None => Ok(None),
416 }
417 }
418}
419
420pub fn explain_select(select: &SelectStmt, _storage: &Arc<dyn StorageBackend>) -> String {
422 let mut lines = Vec::new();
423
424 if let Some(from) = &select.from {
426 for table_ref in &from.tables {
427 explain_table_ref(table_ref, &mut lines, 0);
428 }
429 }
430
431 if select.where_clause.is_some() {
432 lines.push(" Filter (WHERE)".to_string());
433 }
434
435 if !select.group_by.is_empty() {
436 let cols: Vec<String> = select.group_by.iter().map(|e| format!("{:?}", e)).collect();
437 lines.push(format!(" HashAggregate [group_by={}]", cols.join(", ")));
438 }
439
440 if select.having.is_some() {
441 lines.push(" Filter (HAVING)".to_string());
442 }
443
444 let has_agg = select.columns.iter().any(|item| {
446 if let SelectItem::Expr { expr, .. } = item {
447 matches!(expr, Expr::Function(f) if {
448 let n = f.name.name().to_uppercase();
449 matches!(n.as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
450 })
451 } else {
452 false
453 }
454 });
455 if has_agg && select.group_by.is_empty() {
456 lines.push(" HashAggregate [global]".to_string());
457 }
458
459 let col_names: Vec<String> = select.columns.iter().map(|item| {
460 match item {
461 SelectItem::Wildcard => "*".to_string(),
462 SelectItem::QualifiedWildcard(t) => format!("{}.*", t),
463 SelectItem::Expr { expr, alias } => {
464 alias.clone().unwrap_or_else(|| format!("{:?}", expr))
465 }
466 }
467 }).collect();
468 lines.push(format!(" Project [{}]", col_names.join(", ")));
469
470 if !select.order_by.is_empty() {
471 let orders: Vec<String> = select.order_by.iter().map(|o| {
472 let dir = if o.asc { "ASC" } else { "DESC" };
473 format!("{:?} {}", o.expr, dir)
474 }).collect();
475 lines.push(format!(" Sort [{}]", orders.join(", ")));
476 }
477
478 if select.limit.is_some() || select.offset.is_some() {
479 lines.push(format!(
480 " Limit [limit={:?}, offset={:?}]",
481 select.limit, select.offset
482 ));
483 }
484
485 lines.join("\n")
486}
487
488fn explain_table_ref(table_ref: &TableRef, lines: &mut Vec<String>, depth: usize) {
489 let indent = " ".repeat(depth);
490 match table_ref {
491 TableRef::Table { name, alias } => {
492 let alias_str = alias.as_ref().map_or(String::new(), |a| format!(" AS {}", a));
493 lines.push(format!("{}SeqScan [table={}{}]", indent, name, alias_str));
494 }
495 TableRef::Join {
496 left,
497 join_type,
498 right,
499 condition,
500 } => {
501 let jt = match join_type {
502 JoinType::Inner => "INNER",
503 JoinType::Left => "LEFT",
504 JoinType::Right => "RIGHT",
505 JoinType::Full => "FULL",
506 JoinType::Cross => "CROSS",
507 };
508 let cond_str = match condition {
509 Some(JoinCondition::On(expr)) => format!(" ON {:?}", expr),
510 Some(JoinCondition::Using(cols)) => format!(" USING({})", cols.join(", ")),
511 Some(JoinCondition::Natural) => " NATURAL".to_string(),
512 None => String::new(),
513 };
514 lines.push(format!("{}{} JOIN{}", indent, jt, cond_str));
515 explain_table_ref(left, lines, depth + 1);
516 explain_table_ref(right, lines, depth + 1);
517 }
518 TableRef::Subquery { alias, .. } => {
519 lines.push(format!("{}Subquery [alias={}]", indent, alias));
520 }
521 TableRef::Function { name, .. } => {
522 lines.push(format!("{}Function [{}]", indent, name));
523 }
524 }
525}