1use std::fmt;
23
24use crate::ast::*;
25use crate::errors::{Result, SqlglotError};
26
/// Identifier of a step inside a [`Plan`].
///
/// Wraps the step's index in the plan's step list, so ids are only meaningful
/// for the plan that produced them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StepId(usize);

impl fmt::Display for StepId {
    /// Writes the id as `step_<index>` (e.g. `step_0`); formatter flags such as
    /// width are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StepId(index) = *self;
        write!(f, "step_{index}")
    }
}
40
/// One output expression of a [`Step`], optionally given a name.
#[derive(Debug, Clone, PartialEq)]
pub struct Projection {
    /// Expression that produces the output value.
    pub expr: Expr,
    /// Optional output name; for select-list items this is the item's alias.
    pub alias: Option<String>,
}
53
/// One operator node in a [`Plan`].
///
/// Every variant carries `projections` (the step's output expressions) and
/// `dependencies` (ids of the steps whose output it consumes). The planner in
/// this module fills `projections` only for `Project` steps and for `Scan`s of
/// table functions / `UNNEST`; other steps are built with empty projections.
#[derive(Debug, Clone, PartialEq)]
pub enum Step {
    /// Reads rows from a base table, a table function or `UNNEST`.
    Scan {
        /// `catalog.schema.name` for tables, the function name for table
        /// functions, `"UNNEST"` for `UNNEST`, or empty for a SELECT without FROM.
        table: String,
        /// Alias given in the query, if any; shown instead of `table` in labels.
        alias: Option<String>,
        /// Table-function arguments or the `UNNEST` expression; empty for tables.
        projections: Vec<Projection>,
        /// Optional filter attached to the scan; the planner here never sets it.
        predicate: Option<Expr>,
        /// Inputs of the scan (always empty as built by this planner).
        dependencies: Vec<StepId>,
    },
    /// Keeps rows matching `predicate` (used for both WHERE and HAVING).
    Filter {
        /// Boolean condition to evaluate per row.
        predicate: Expr,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
    /// Computes the select list (or a wildcard for PIVOT/UNPIVOT sources).
    Project {
        /// Output expressions, in select-list order.
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
    /// Groups rows and evaluates aggregate functions.
    Aggregate {
        /// GROUP BY keys; empty means a single scalar aggregation.
        group_by: Vec<Expr>,
        /// Aggregate calls extracted from the select list.
        aggregations: Vec<Projection>,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
    /// Orders rows by the ORDER BY items.
    Sort {
        /// ORDER BY items, in priority order.
        order_by: Vec<OrderByItem>,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
    /// Joins two inputs.
    Join {
        /// Kind of join (inner, left, cross, ...).
        join_type: JoinType,
        /// ON condition, if the join has one.
        condition: Option<Expr>,
        /// Column names from a USING clause (empty if none).
        using_columns: Vec<String>,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// `[left, right]` inputs.
        dependencies: Vec<StepId>,
    },
    /// Restricts the number of rows and/or skips leading rows.
    Limit {
        /// Row limit (from LIMIT, or FETCH FIRST when LIMIT is absent).
        limit: Option<Expr>,
        /// Number of rows to skip (OFFSET).
        offset: Option<Expr>,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
    /// Combines two inputs with a set operator (e.g. UNION).
    SetOperation {
        /// Which set operator to apply.
        op: SetOperationType,
        /// `true` for the `ALL` form (e.g. `UNION ALL`).
        all: bool,
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// `[left, right]` inputs.
        dependencies: Vec<StepId>,
    },
    /// Removes duplicate rows (SELECT DISTINCT).
    Distinct {
        /// Output expressions (left empty by this planner).
        projections: Vec<Projection>,
        /// Single input step.
        dependencies: Vec<StepId>,
    },
}
153
154impl Step {
155 #[must_use]
157 pub fn dependencies(&self) -> &[StepId] {
158 match self {
159 Step::Scan { dependencies, .. }
160 | Step::Filter { dependencies, .. }
161 | Step::Project { dependencies, .. }
162 | Step::Aggregate { dependencies, .. }
163 | Step::Sort { dependencies, .. }
164 | Step::Join { dependencies, .. }
165 | Step::Limit { dependencies, .. }
166 | Step::SetOperation { dependencies, .. }
167 | Step::Distinct { dependencies, .. } => dependencies,
168 }
169 }
170
171 #[must_use]
173 pub fn projections(&self) -> &[Projection] {
174 match self {
175 Step::Scan { projections, .. }
176 | Step::Filter { projections, .. }
177 | Step::Project { projections, .. }
178 | Step::Aggregate { projections, .. }
179 | Step::Sort { projections, .. }
180 | Step::Join { projections, .. }
181 | Step::Limit { projections, .. }
182 | Step::SetOperation { projections, .. }
183 | Step::Distinct { projections, .. } => projections,
184 }
185 }
186
187 #[must_use]
189 pub fn kind(&self) -> &'static str {
190 match self {
191 Step::Scan { .. } => "Scan",
192 Step::Filter { .. } => "Filter",
193 Step::Project { .. } => "Project",
194 Step::Aggregate { .. } => "Aggregate",
195 Step::Sort { .. } => "Sort",
196 Step::Join { .. } => "Join",
197 Step::Limit { .. } => "Limit",
198 Step::SetOperation { .. } => "SetOperation",
199 Step::Distinct { .. } => "Distinct",
200 }
201 }
202}
203
/// A logical query plan: a DAG of [`Step`]s stored in insertion order.
///
/// Steps are referenced by [`StepId`], which indexes into `steps`. The planner
/// adds a step only after its dependencies, so every dependency id is smaller
/// than the id of the step that uses it.
#[derive(Debug, Clone)]
pub struct Plan {
    /// All steps; `StepId(i)` refers to `steps[i]`.
    steps: Vec<Step>,
    /// The step whose output is the result of the planned statement.
    root: StepId,
}
219
220impl Plan {
221 #[must_use]
223 pub fn root(&self) -> StepId {
224 self.root
225 }
226
227 #[must_use]
229 pub fn steps(&self) -> &[Step] {
230 &self.steps
231 }
232
233 #[must_use]
235 pub fn get(&self, id: StepId) -> Option<&Step> {
236 self.steps.get(id.0)
237 }
238
239 #[must_use]
241 pub fn len(&self) -> usize {
242 self.steps.len()
243 }
244
245 #[must_use]
247 pub fn is_empty(&self) -> bool {
248 self.steps.is_empty()
249 }
250
251 #[must_use]
253 pub fn to_mermaid(&self) -> String {
254 let mut out = String::from("graph TD\n");
255 for (i, step) in self.steps.iter().enumerate() {
256 let id = StepId(i);
257 let label = step_label(step);
258 out.push_str(&format!(" {id}[\"{label}\"]\n"));
259 for dep in step.dependencies() {
260 out.push_str(&format!(" {dep} --> {id}\n"));
261 }
262 }
263 out
264 }
265
266 #[must_use]
268 pub fn to_dot(&self) -> String {
269 let mut out = String::from("digraph plan {\n rankdir=BT;\n");
270 for (i, step) in self.steps.iter().enumerate() {
271 let id = StepId(i);
272 let label = step_label(step);
273 out.push_str(&format!(" {id} [label=\"{label}\"];\n"));
274 for dep in step.dependencies() {
275 out.push_str(&format!(" {dep} -> {id};\n"));
276 }
277 }
278 out.push_str("}\n");
279 out
280 }
281}
282
283impl fmt::Display for Plan {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 for (i, step) in self.steps.iter().enumerate() {
286 let id = StepId(i);
287 let root_marker = if id == self.root { " (root)" } else { "" };
288 writeln!(f, "{id}{root_marker}: {}", step_label(step))?;
289 for dep in step.dependencies() {
290 writeln!(f, " <- {dep}")?;
291 }
292 }
293 Ok(())
294 }
295}
296
297fn step_label(step: &Step) -> String {
299 match step {
300 Step::Scan {
301 table,
302 alias,
303 predicate,
304 ..
305 } => {
306 let name = alias.as_deref().unwrap_or(table.as_str());
307 if predicate.is_some() {
308 format!("Scan({name} + filter)")
309 } else {
310 format!("Scan({name})")
311 }
312 }
313 Step::Filter { .. } => "Filter".to_string(),
314 Step::Project { projections, .. } => {
315 let cols: Vec<_> = projections
316 .iter()
317 .map(|p| {
318 p.alias
319 .as_deref()
320 .unwrap_or_else(|| expr_short_name(&p.expr))
321 })
322 .collect();
323 if cols.len() <= 4 {
324 format!("Project({})", cols.join(", "))
325 } else {
326 format!("Project({} cols)", cols.len())
327 }
328 }
329 Step::Aggregate { group_by, .. } => {
330 if group_by.is_empty() {
331 "Aggregate(scalar)".to_string()
332 } else {
333 format!("Aggregate({} keys)", group_by.len())
334 }
335 }
336 Step::Sort { order_by, .. } => format!("Sort({} keys)", order_by.len()),
337 Step::Join { join_type, .. } => format!("Join({join_type:?})"),
338 Step::Limit { limit, offset, .. } => {
339 let mut parts = Vec::new();
340 if limit.is_some() {
341 parts.push("limit");
342 }
343 if offset.is_some() {
344 parts.push("offset");
345 }
346 format!("Limit({})", parts.join("+"))
347 }
348 Step::SetOperation { op, all, .. } => {
349 let all_str = if *all { " ALL" } else { "" };
350 format!("{op:?}{all_str}")
351 }
352 Step::Distinct { .. } => "Distinct".to_string(),
353 }
354}
355
356fn expr_short_name(expr: &Expr) -> &str {
358 match expr {
359 Expr::Column { name, .. } => name.as_str(),
360 Expr::Wildcard => "*",
361 _ => "expr",
362 }
363}
364
365pub fn plan(statement: &Statement) -> Result<Plan> {
380 let mut builder = PlanBuilder::new();
381 let _root = builder.plan_statement(statement)?;
382 Ok(builder.build())
383}
384
/// Incrementally builds a [`Plan`] by appending steps.
struct PlanBuilder {
    /// Steps added so far; a step's index here is its [`StepId`].
    steps: Vec<Step>,
}
389
390impl PlanBuilder {
391 fn new() -> Self {
392 Self { steps: Vec::new() }
393 }
394
395 fn add_step(&mut self, step: Step) -> StepId {
396 let id = StepId(self.steps.len());
397 self.steps.push(step);
398 id
399 }
400
401 fn build(self) -> Plan {
402 let root = if self.steps.is_empty() {
403 StepId(0)
404 } else {
405 StepId(self.steps.len() - 1)
406 };
407 Plan {
408 steps: self.steps,
409 root,
410 }
411 }
412
413 fn plan_statement(&mut self, stmt: &Statement) -> Result<StepId> {
418 match stmt {
419 Statement::Select(sel) => self.plan_select(sel),
420 Statement::SetOperation(set_op) => self.plan_set_operation(set_op),
421 _ => Err(SqlglotError::Internal(format!(
422 "Planner does not support {:?} statements",
423 std::mem::discriminant(stmt)
424 ))),
425 }
426 }
427
428 fn plan_select(&mut self, sel: &SelectStatement) -> Result<StepId> {
433 let mut current = if let Some(from) = &sel.from {
435 self.plan_table_source(&from.source)?
436 } else {
437 self.add_step(Step::Scan {
439 table: String::new(),
440 alias: None,
441 projections: vec![],
442 predicate: None,
443 dependencies: vec![],
444 })
445 };
446
447 for join in &sel.joins {
449 let right = self.plan_table_source(&join.table)?;
450 let projections = vec![]; current = self.add_step(Step::Join {
452 join_type: join.join_type.clone(),
453 condition: join.on.clone(),
454 using_columns: join.using.clone(),
455 projections,
456 dependencies: vec![current, right],
457 });
458 }
459
460 if let Some(pred) = &sel.where_clause {
462 current = self.add_step(Step::Filter {
463 predicate: pred.clone(),
464 projections: vec![],
465 dependencies: vec![current],
466 });
467 }
468
469 if !sel.group_by.is_empty() || has_aggregates(&sel.columns) {
471 let aggregations = extract_aggregates(&sel.columns);
472 current = self.add_step(Step::Aggregate {
473 group_by: sel.group_by.clone(),
474 aggregations,
475 projections: vec![],
476 dependencies: vec![current],
477 });
478 }
479
480 if let Some(having) = &sel.having {
482 current = self.add_step(Step::Filter {
483 predicate: having.clone(),
484 projections: vec![],
485 dependencies: vec![current],
486 });
487 }
488
489 if sel.distinct {
491 current = self.add_step(Step::Distinct {
492 projections: vec![],
493 dependencies: vec![current],
494 });
495 }
496
497 if !sel.order_by.is_empty() {
499 current = self.add_step(Step::Sort {
500 order_by: sel.order_by.clone(),
501 projections: vec![],
502 dependencies: vec![current],
503 });
504 }
505
506 if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() {
508 let limit = sel.limit.clone().or_else(|| sel.fetch_first.clone());
509 current = self.add_step(Step::Limit {
510 limit,
511 offset: sel.offset.clone(),
512 projections: vec![],
513 dependencies: vec![current],
514 });
515 }
516
517 let projections = select_items_to_projections(&sel.columns);
519 if !projections.is_empty() {
520 current = self.add_step(Step::Project {
521 projections,
522 dependencies: vec![current],
523 });
524 }
525
526 Ok(current)
527 }
528
529 fn plan_table_source(&mut self, source: &TableSource) -> Result<StepId> {
534 match source {
535 TableSource::Table(tref) => {
536 let table = fully_qualified_name(tref);
537 Ok(self.add_step(Step::Scan {
538 table,
539 alias: tref.alias.clone(),
540 projections: vec![],
541 predicate: None,
542 dependencies: vec![],
543 }))
544 }
545 TableSource::Subquery {
546 query, alias: _, ..
547 } => self.plan_statement(query),
548 TableSource::Lateral { source } => self.plan_table_source(source),
549 TableSource::TableFunction {
550 name, args, alias, ..
551 } => Ok(self.add_step(Step::Scan {
552 table: name.clone(),
553 alias: alias.clone(),
554 projections: args
555 .iter()
556 .map(|a| Projection {
557 expr: a.clone(),
558 alias: None,
559 })
560 .collect(),
561 predicate: None,
562 dependencies: vec![],
563 })),
564 TableSource::Unnest { expr, alias, .. } => Ok(self.add_step(Step::Scan {
565 table: "UNNEST".to_string(),
566 alias: alias.clone(),
567 projections: vec![Projection {
568 expr: *expr.clone(),
569 alias: None,
570 }],
571 predicate: None,
572 dependencies: vec![],
573 })),
574 TableSource::Pivot { source, alias, .. }
575 | TableSource::Unpivot { source, alias, .. } => {
576 let inner = self.plan_table_source(source)?;
579 Ok(self.add_step(Step::Project {
581 projections: vec![Projection {
582 expr: Expr::Wildcard,
583 alias: alias.clone(),
584 }],
585 dependencies: vec![inner],
586 }))
587 }
588 }
589 }
590
591 fn plan_set_operation(&mut self, set_op: &SetOperationStatement) -> Result<StepId> {
596 let left = self.plan_statement(&set_op.left)?;
597 let right = self.plan_statement(&set_op.right)?;
598
599 let mut current = self.add_step(Step::SetOperation {
600 op: set_op.op.clone(),
601 all: set_op.all,
602 projections: vec![],
603 dependencies: vec![left, right],
604 });
605
606 if !set_op.order_by.is_empty() {
607 current = self.add_step(Step::Sort {
608 order_by: set_op.order_by.clone(),
609 projections: vec![],
610 dependencies: vec![current],
611 });
612 }
613
614 if set_op.limit.is_some() || set_op.offset.is_some() {
615 current = self.add_step(Step::Limit {
616 limit: set_op.limit.clone(),
617 offset: set_op.offset.clone(),
618 projections: vec![],
619 dependencies: vec![current],
620 });
621 }
622
623 Ok(current)
624 }
625}
626
627fn fully_qualified_name(tref: &TableRef) -> String {
633 let mut parts = Vec::new();
634 if let Some(catalog) = &tref.catalog {
635 parts.push(catalog.as_str());
636 }
637 if let Some(schema) = &tref.schema {
638 parts.push(schema.as_str());
639 }
640 parts.push(tref.name.as_str());
641 parts.join(".")
642}
643
644fn select_items_to_projections(items: &[SelectItem]) -> Vec<Projection> {
646 items
647 .iter()
648 .map(|item| match item {
649 SelectItem::Wildcard => Projection {
650 expr: Expr::Wildcard,
651 alias: None,
652 },
653 SelectItem::QualifiedWildcard { table } => Projection {
654 expr: Expr::QualifiedWildcard {
655 table: table.clone(),
656 },
657 alias: None,
658 },
659 SelectItem::Expr { expr, alias, .. } => Projection {
660 expr: expr.clone(),
661 alias: alias.clone(),
662 },
663 })
664 .collect()
665}
666
667fn has_aggregates(items: &[SelectItem]) -> bool {
669 items.iter().any(|item| match item {
670 SelectItem::Expr { expr, .. } => expr_has_aggregate(expr),
671 _ => false,
672 })
673}
674
675fn expr_has_aggregate(expr: &Expr) -> bool {
677 match expr {
678 Expr::Function { name, .. } => is_aggregate_name(name),
679 Expr::TypedFunction { func, .. } => typed_function_is_aggregate(func),
680 Expr::BinaryOp { left, right, .. } => expr_has_aggregate(left) || expr_has_aggregate(right),
681 Expr::UnaryOp { expr, .. } => expr_has_aggregate(expr),
682 Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr_has_aggregate(expr),
683 Expr::Case {
684 operand,
685 when_clauses,
686 else_clause,
687 } => {
688 operand.as_ref().is_some_and(|e| expr_has_aggregate(e))
689 || when_clauses
690 .iter()
691 .any(|(cond, result)| expr_has_aggregate(cond) || expr_has_aggregate(result))
692 || else_clause.as_ref().is_some_and(|e| expr_has_aggregate(e))
693 }
694 Expr::Alias { expr, .. } => expr_has_aggregate(expr),
695 _ => false,
696 }
697}
698
/// `true` if `name` (compared case-insensitively after Unicode upper-casing)
/// is one of the function names the planner treats as aggregates.
fn is_aggregate_name(name: &str) -> bool {
    /// Upper-cased aggregate function names recognised for untyped calls.
    const AGGREGATE_NAMES: &[&str] = &[
        "COUNT",
        "SUM",
        "AVG",
        "MIN",
        "MAX",
        "GROUP_CONCAT",
        "STRING_AGG",
        "ARRAY_AGG",
        "LISTAGG",
        "COLLECT_LIST",
        "COLLECT_SET",
        "ANY_VALUE",
        "APPROX_COUNT_DISTINCT",
        "PERCENTILE_CONT",
        "PERCENTILE_DISC",
        "STDDEV",
        "STDDEV_POP",
        "STDDEV_SAMP",
        "VARIANCE",
        "VAR_POP",
        "VAR_SAMP",
        "CORR",
        "COVAR_POP",
        "COVAR_SAMP",
        "FIRST_VALUE",
        "LAST_VALUE",
        "NTH_VALUE",
        "BIT_AND",
        "BIT_OR",
        "BIT_XOR",
        "BOOL_AND",
        "BOOL_OR",
        "EVERY",
    ];

    let upper = name.to_uppercase();
    AGGREGATE_NAMES.contains(&upper.as_str())
}
738
739fn typed_function_is_aggregate(func: &TypedFunction) -> bool {
741 matches!(
742 func,
743 TypedFunction::Count { .. }
744 | TypedFunction::Sum { .. }
745 | TypedFunction::Avg { .. }
746 | TypedFunction::Min { .. }
747 | TypedFunction::Max { .. }
748 | TypedFunction::ArrayAgg { .. }
749 | TypedFunction::ApproxDistinct { .. }
750 | TypedFunction::Variance { .. }
751 | TypedFunction::Stddev { .. }
752 | TypedFunction::GroupConcat { .. }
753 )
754}
755
756fn extract_aggregates(items: &[SelectItem]) -> Vec<Projection> {
758 let mut aggs = Vec::new();
759 for item in items {
760 if let SelectItem::Expr { expr, alias, .. } = item {
761 collect_aggregates(expr, alias, &mut aggs);
762 }
763 }
764 aggs
765}
766
767fn collect_aggregates(expr: &Expr, alias: &Option<String>, out: &mut Vec<Projection>) {
768 match expr {
769 Expr::Function { name, .. } if is_aggregate_name(name) => {
770 out.push(Projection {
771 expr: expr.clone(),
772 alias: alias.clone(),
773 });
774 }
775 Expr::TypedFunction { func, .. } if typed_function_is_aggregate(func) => {
776 out.push(Projection {
777 expr: expr.clone(),
778 alias: alias.clone(),
779 });
780 }
781 Expr::BinaryOp { left, right, .. } => {
782 collect_aggregates(left, &None, out);
783 collect_aggregates(right, &None, out);
784 }
785 Expr::Alias { expr: inner, name } => {
786 collect_aggregates(inner, &Some(name.clone()), out);
787 }
788 _ => {}
789 }
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` with the ANSI dialect and plans it; panics if either stage fails.
    fn plan_sql(sql: &str) -> Plan {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        plan(&ast).unwrap()
    }

    /// Kind of every step in `p`, in step-id order.
    fn step_kinds(p: &Plan) -> Vec<&'static str> {
        p.steps().iter().map(Step::kind).collect()
    }

    /// Join type of the first `Join` step in `p`; panics if there is none.
    fn first_join_type(p: &Plan) -> JoinType {
        p.steps()
            .iter()
            .find_map(|step| match step {
                Step::Join { join_type, .. } => Some(join_type.clone()),
                _ => None,
            })
            .expect("expected Join step")
    }

    #[test]
    fn test_simple_select() {
        let p = plan_sql("SELECT a, b FROM t");
        assert!(p.len() >= 2);
        assert_eq!(p.get(p.root()).unwrap().kind(), "Project");
    }

    #[test]
    fn test_select_with_where() {
        let kinds = step_kinds(&plan_sql("SELECT a FROM t WHERE a > 1"));
        for expected in ["Scan", "Filter", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_select_with_order_by() {
        let kinds = step_kinds(&plan_sql("SELECT a FROM t ORDER BY a"));
        assert!(kinds.contains(&"Sort"));
    }

    #[test]
    fn test_select_with_group_by() {
        let kinds = step_kinds(&plan_sql("SELECT a, COUNT(*) FROM t GROUP BY a"));
        assert!(kinds.contains(&"Aggregate"));
    }

    #[test]
    fn test_select_with_having() {
        let kinds = step_kinds(&plan_sql(
            "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
        ));
        for expected in ["Aggregate", "Filter"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_join() {
        let kinds = step_kinds(&plan_sql("SELECT a.x FROM a JOIN b ON a.id = b.id"));
        assert!(kinds.contains(&"Join"));
    }

    #[test]
    fn test_multiple_joins() {
        let kinds = step_kinds(&plan_sql(
            "SELECT a.x FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id",
        ));
        let join_count = kinds.iter().filter(|&&kind| kind == "Join").count();
        assert_eq!(join_count, 2);
    }

    #[test]
    fn test_union() {
        let kinds = step_kinds(&plan_sql("SELECT a FROM t1 UNION ALL SELECT b FROM t2"));
        assert!(kinds.contains(&"SetOperation"));
    }

    #[test]
    fn test_limit_offset() {
        let kinds = step_kinds(&plan_sql("SELECT a FROM t LIMIT 10 OFFSET 5"));
        assert!(kinds.contains(&"Limit"));
    }

    #[test]
    fn test_distinct() {
        let kinds = step_kinds(&plan_sql("SELECT DISTINCT a FROM t"));
        assert!(kinds.contains(&"Distinct"));
    }

    #[test]
    fn test_subquery_in_from() {
        let p = plan_sql("SELECT x FROM (SELECT a AS x FROM t) sub");
        assert!(p.len() >= 3);
    }

    #[test]
    fn test_complex_query() {
        let kinds = step_kinds(&plan_sql(
            "SELECT a, SUM(b) AS total FROM t WHERE c > 0 GROUP BY a HAVING SUM(b) > 10 ORDER BY total DESC LIMIT 5",
        ));
        for expected in ["Scan", "Filter", "Aggregate", "Sort", "Limit", "Project"] {
            assert!(kinds.contains(&expected));
        }
    }

    #[test]
    fn test_dag_dependencies() {
        let p = plan_sql("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id");
        for (i, step) in p.steps().iter().enumerate() {
            for dep in step.dependencies() {
                assert!(dep.0 < i, "step {i} depends on {dep} which is not earlier");
            }
        }
    }

    #[test]
    fn test_mermaid_output() {
        let mermaid = plan_sql("SELECT a FROM t WHERE a > 1").to_mermaid();
        assert!(mermaid.starts_with("graph TD"));
        assert!(mermaid.contains("Scan"));
    }

    #[test]
    fn test_dot_output() {
        let dot = plan_sql("SELECT a FROM t WHERE a > 1").to_dot();
        assert!(dot.starts_with("digraph plan"));
        assert!(dot.contains("Scan"));
    }

    #[test]
    fn test_display() {
        let display = plan_sql("SELECT a FROM t").to_string();
        assert!(display.contains("(root)"));
    }

    #[test]
    fn test_ddl_rejected() {
        let ast = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
        assert!(plan(&ast).is_err());
    }

    #[test]
    fn test_no_from_select() {
        assert!(!plan_sql("SELECT 1 + 2").is_empty());
    }

    #[test]
    fn test_left_join() {
        let p = plan_sql("SELECT a.x FROM a LEFT JOIN b ON a.id = b.id");
        assert_eq!(first_join_type(&p), JoinType::Left);
    }

    #[test]
    fn test_cross_join() {
        let p = plan_sql("SELECT a.x FROM a CROSS JOIN b");
        assert_eq!(first_join_type(&p), JoinType::Cross);
    }

    #[test]
    fn test_union_with_order_limit() {
        let kinds = step_kinds(&plan_sql(
            "SELECT a FROM t1 UNION SELECT b FROM t2 ORDER BY 1 LIMIT 10",
        ));
        for expected in ["SetOperation", "Sort", "Limit"] {
            assert!(kinds.contains(&expected));
        }
    }
}
1011}