1use std::fmt;
23
24use crate::ast::*;
25use crate::errors::{Result, SqlglotError};
26
/// Identifier of a step inside a [`Plan`]: the step's position in the plan's step list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StepId(usize);

impl fmt::Display for StepId {
    /// Formats the id as `step_<index>`; the Mermaid and DOT renderers also use this text as
    /// the node name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(index) = self;
        write!(f, "step_{}", index)
    }
}
40
/// An expression paired with an optional output name.
///
/// Besides select-list entries, the planner in this file also stores table-function call
/// arguments and UNNEST inputs as projections on `Scan` steps.
#[derive(Debug, Clone, PartialEq)]
pub struct Projection {
    /// The expression being produced.
    pub expr: Expr,
    /// Output name for the expression (for select-list items, the alias written in the query).
    pub alias: Option<String>,
}
53
/// One operator in a logical [`Plan`].
///
/// Every variant carries `projections` and `dependencies`. `dependencies` lists the ids of the
/// steps whose output this step consumes. The builder in this file leaves `projections` empty
/// on most steps; the exceptions are noted per variant.
#[derive(Debug, Clone, PartialEq)]
pub enum Step {
    /// Produces rows from a source that has no input steps.
    Scan {
        /// Source name. This is one of: a dot-joined `catalog.schema.table` name, a
        /// table-function name, `"UNNEST"`, or the empty string for a SELECT without FROM.
        table: String,
        /// Alias given to the source in the query, if any.
        alias: Option<String>,
        /// Contents depend on the source: empty for plain tables, the call arguments for table
        /// functions, and the unnested expression for UNNEST.
        projections: Vec<Projection>,
        /// Optional filter attached to the scan; the builder in this file always sets `None`.
        predicate: Option<Expr>,
        /// Always empty for scans built in this file.
        dependencies: Vec<StepId>,
    },
    /// Keeps only the rows satisfying `predicate` (used for both WHERE and HAVING).
    Filter {
        /// The WHERE or HAVING condition.
        predicate: Expr,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Computes output expressions. Used for the final select list and to wrap PIVOT/UNPIVOT
    /// sources.
    Project {
        /// The expressions to produce, with their output names.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Groups rows by `group_by` and computes `aggregations`. An empty `group_by` means a
    /// single (scalar) aggregate over all input rows.
    Aggregate {
        /// The GROUP BY expressions.
        group_by: Vec<Expr>,
        /// The aggregate calls to compute, with any alias they carried in the select list.
        aggregations: Vec<Projection>,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Orders rows by `order_by`.
    Sort {
        /// The ORDER BY items.
        order_by: Vec<OrderByItem>,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Joins its two dependencies, ordered `[left, right]`.
    Join {
        /// The kind of join (inner, left, cross, ...).
        join_type: JoinType,
        /// The ON condition, if any.
        condition: Option<Expr>,
        /// The USING column list (empty when no USING clause was given).
        using_columns: Vec<String>,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// `[left, right]` input steps.
        dependencies: Vec<StepId>,
    },
    /// Restricts how many rows pass through and where they start.
    Limit {
        /// Maximum row count, taken from LIMIT or, when absent, from FETCH FIRST.
        limit: Option<Expr>,
        /// Number of leading rows to skip (OFFSET).
        offset: Option<Expr>,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
    /// Combines the outputs of its two dependencies with the set operator `op`.
    SetOperation {
        /// Which set operator to apply.
        op: SetOperationType,
        /// `true` for the `ALL` form (for example `UNION ALL`), which keeps duplicate rows.
        all: bool,
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// `[left, right]` input steps.
        dependencies: Vec<StepId>,
    },
    /// Removes duplicate rows (SELECT DISTINCT).
    Distinct {
        /// Left empty by the builder in this file.
        projections: Vec<Projection>,
        /// The single input step.
        dependencies: Vec<StepId>,
    },
}
153
154impl Step {
155 #[must_use]
157 pub fn dependencies(&self) -> &[StepId] {
158 match self {
159 Step::Scan { dependencies, .. }
160 | Step::Filter { dependencies, .. }
161 | Step::Project { dependencies, .. }
162 | Step::Aggregate { dependencies, .. }
163 | Step::Sort { dependencies, .. }
164 | Step::Join { dependencies, .. }
165 | Step::Limit { dependencies, .. }
166 | Step::SetOperation { dependencies, .. }
167 | Step::Distinct { dependencies, .. } => dependencies,
168 }
169 }
170
171 #[must_use]
173 pub fn projections(&self) -> &[Projection] {
174 match self {
175 Step::Scan { projections, .. }
176 | Step::Filter { projections, .. }
177 | Step::Project { projections, .. }
178 | Step::Aggregate { projections, .. }
179 | Step::Sort { projections, .. }
180 | Step::Join { projections, .. }
181 | Step::Limit { projections, .. }
182 | Step::SetOperation { projections, .. }
183 | Step::Distinct { projections, .. } => projections,
184 }
185 }
186
187 #[must_use]
189 pub fn kind(&self) -> &'static str {
190 match self {
191 Step::Scan { .. } => "Scan",
192 Step::Filter { .. } => "Filter",
193 Step::Project { .. } => "Project",
194 Step::Aggregate { .. } => "Aggregate",
195 Step::Sort { .. } => "Sort",
196 Step::Join { .. } => "Join",
197 Step::Limit { .. } => "Limit",
198 Step::SetOperation { .. } => "SetOperation",
199 Step::Distinct { .. } => "Distinct",
200 }
201 }
202}
203
/// A logical query plan: a DAG of [`Step`]s stored in the order they were added.
///
/// The builder in this file adds a step only after the steps it depends on, so every
/// dependency id is smaller than the id of the step that uses it.
#[derive(Debug, Clone)]
pub struct Plan {
    /// All steps; a [`StepId`] is an index into this vector.
    steps: Vec<Step>,
    /// The step whose output is the result of the whole statement.
    root: StepId,
}
219
220impl Plan {
221 #[must_use]
223 pub fn root(&self) -> StepId {
224 self.root
225 }
226
227 #[must_use]
229 pub fn steps(&self) -> &[Step] {
230 &self.steps
231 }
232
233 #[must_use]
235 pub fn get(&self, id: StepId) -> Option<&Step> {
236 self.steps.get(id.0)
237 }
238
239 #[must_use]
241 pub fn len(&self) -> usize {
242 self.steps.len()
243 }
244
245 #[must_use]
247 pub fn is_empty(&self) -> bool {
248 self.steps.is_empty()
249 }
250
251 #[must_use]
253 pub fn to_mermaid(&self) -> String {
254 let mut out = String::from("graph TD\n");
255 for (i, step) in self.steps.iter().enumerate() {
256 let id = StepId(i);
257 let label = step_label(step);
258 out.push_str(&format!(" {id}[\"{label}\"]\n"));
259 for dep in step.dependencies() {
260 out.push_str(&format!(" {dep} --> {id}\n"));
261 }
262 }
263 out
264 }
265
266 #[must_use]
268 pub fn to_dot(&self) -> String {
269 let mut out = String::from("digraph plan {\n rankdir=BT;\n");
270 for (i, step) in self.steps.iter().enumerate() {
271 let id = StepId(i);
272 let label = step_label(step);
273 out.push_str(&format!(" {id} [label=\"{label}\"];\n"));
274 for dep in step.dependencies() {
275 out.push_str(&format!(" {dep} -> {id};\n"));
276 }
277 }
278 out.push_str("}\n");
279 out
280 }
281}
282
283impl fmt::Display for Plan {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 for (i, step) in self.steps.iter().enumerate() {
286 let id = StepId(i);
287 let root_marker = if id == self.root { " (root)" } else { "" };
288 writeln!(f, "{id}{root_marker}: {}", step_label(step))?;
289 for dep in step.dependencies() {
290 writeln!(f, " <- {dep}")?;
291 }
292 }
293 Ok(())
294 }
295}
296
297fn step_label(step: &Step) -> String {
299 match step {
300 Step::Scan {
301 table,
302 alias,
303 predicate,
304 ..
305 } => {
306 let name = alias.as_deref().unwrap_or(table.as_str());
307 if predicate.is_some() {
308 format!("Scan({name} + filter)")
309 } else {
310 format!("Scan({name})")
311 }
312 }
313 Step::Filter { .. } => "Filter".to_string(),
314 Step::Project { projections, .. } => {
315 let cols: Vec<_> = projections
316 .iter()
317 .map(|p| {
318 p.alias
319 .as_deref()
320 .unwrap_or_else(|| expr_short_name(&p.expr))
321 })
322 .collect();
323 if cols.len() <= 4 {
324 format!("Project({})", cols.join(", "))
325 } else {
326 format!("Project({} cols)", cols.len())
327 }
328 }
329 Step::Aggregate { group_by, .. } => {
330 if group_by.is_empty() {
331 "Aggregate(scalar)".to_string()
332 } else {
333 format!("Aggregate({} keys)", group_by.len())
334 }
335 }
336 Step::Sort { order_by, .. } => format!("Sort({} keys)", order_by.len()),
337 Step::Join { join_type, .. } => format!("Join({join_type:?})"),
338 Step::Limit { limit, offset, .. } => {
339 let mut parts = Vec::new();
340 if limit.is_some() {
341 parts.push("limit");
342 }
343 if offset.is_some() {
344 parts.push("offset");
345 }
346 format!("Limit({})", parts.join("+"))
347 }
348 Step::SetOperation { op, all, .. } => {
349 let all_str = if *all { " ALL" } else { "" };
350 format!("{op:?}{all_str}")
351 }
352 Step::Distinct { .. } => "Distinct".to_string(),
353 }
354}
355
356fn expr_short_name(expr: &Expr) -> &str {
358 match expr {
359 Expr::Column { name, .. } => name.as_str(),
360 Expr::Wildcard => "*",
361 _ => "expr",
362 }
363}
364
365pub fn plan(statement: &Statement) -> Result<Plan> {
380 let mut builder = PlanBuilder::new();
381 let _root = builder.plan_statement(statement)?;
382 Ok(builder.build())
383}
384
/// Accumulates [`Step`]s while walking a statement's AST.
struct PlanBuilder {
    /// Steps added so far; a step's [`StepId`] is its index in this vector.
    steps: Vec<Step>,
}
389
390impl PlanBuilder {
391 fn new() -> Self {
392 Self { steps: Vec::new() }
393 }
394
395 fn add_step(&mut self, step: Step) -> StepId {
396 let id = StepId(self.steps.len());
397 self.steps.push(step);
398 id
399 }
400
401 fn build(self) -> Plan {
402 let root = if self.steps.is_empty() {
403 StepId(0)
404 } else {
405 StepId(self.steps.len() - 1)
406 };
407 Plan {
408 steps: self.steps,
409 root,
410 }
411 }
412
413 fn plan_statement(&mut self, stmt: &Statement) -> Result<StepId> {
418 match stmt {
419 Statement::Select(sel) => self.plan_select(sel),
420 Statement::SetOperation(set_op) => self.plan_set_operation(set_op),
421 _ => Err(SqlglotError::Internal(format!(
422 "Planner does not support {:?} statements",
423 std::mem::discriminant(stmt)
424 ))),
425 }
426 }
427
428 fn plan_select(&mut self, sel: &SelectStatement) -> Result<StepId> {
433 let mut current = if let Some(from) = &sel.from {
435 self.plan_table_source(&from.source)?
436 } else {
437 self.add_step(Step::Scan {
439 table: String::new(),
440 alias: None,
441 projections: vec![],
442 predicate: None,
443 dependencies: vec![],
444 })
445 };
446
447 for join in &sel.joins {
449 let right = self.plan_table_source(&join.table)?;
450 let projections = vec![]; current = self.add_step(Step::Join {
452 join_type: join.join_type.clone(),
453 condition: join.on.clone(),
454 using_columns: join.using.clone(),
455 projections,
456 dependencies: vec![current, right],
457 });
458 }
459
460 if let Some(pred) = &sel.where_clause {
462 current = self.add_step(Step::Filter {
463 predicate: pred.clone(),
464 projections: vec![],
465 dependencies: vec![current],
466 });
467 }
468
469 if !sel.group_by.is_empty() || has_aggregates(&sel.columns) {
471 let aggregations = extract_aggregates(&sel.columns);
472 current = self.add_step(Step::Aggregate {
473 group_by: sel.group_by.clone(),
474 aggregations,
475 projections: vec![],
476 dependencies: vec![current],
477 });
478 }
479
480 if let Some(having) = &sel.having {
482 current = self.add_step(Step::Filter {
483 predicate: having.clone(),
484 projections: vec![],
485 dependencies: vec![current],
486 });
487 }
488
489 if sel.distinct {
491 current = self.add_step(Step::Distinct {
492 projections: vec![],
493 dependencies: vec![current],
494 });
495 }
496
497 if !sel.order_by.is_empty() {
499 current = self.add_step(Step::Sort {
500 order_by: sel.order_by.clone(),
501 projections: vec![],
502 dependencies: vec![current],
503 });
504 }
505
506 if sel.limit.is_some() || sel.offset.is_some() || sel.fetch_first.is_some() {
508 let limit = sel.limit.clone().or_else(|| sel.fetch_first.clone());
509 current = self.add_step(Step::Limit {
510 limit,
511 offset: sel.offset.clone(),
512 projections: vec![],
513 dependencies: vec![current],
514 });
515 }
516
517 let projections = select_items_to_projections(&sel.columns);
519 if !projections.is_empty() {
520 current = self.add_step(Step::Project {
521 projections,
522 dependencies: vec![current],
523 });
524 }
525
526 Ok(current)
527 }
528
529 fn plan_table_source(&mut self, source: &TableSource) -> Result<StepId> {
534 match source {
535 TableSource::Table(tref) => {
536 let table = fully_qualified_name(tref);
537 Ok(self.add_step(Step::Scan {
538 table,
539 alias: tref.alias.clone(),
540 projections: vec![],
541 predicate: None,
542 dependencies: vec![],
543 }))
544 }
545 TableSource::Subquery { query, alias: _ } => self.plan_statement(query),
546 TableSource::Lateral { source } => self.plan_table_source(source),
547 TableSource::TableFunction { name, args, alias } => Ok(self.add_step(Step::Scan {
548 table: name.clone(),
549 alias: alias.clone(),
550 projections: args
551 .iter()
552 .map(|a| Projection {
553 expr: a.clone(),
554 alias: None,
555 })
556 .collect(),
557 predicate: None,
558 dependencies: vec![],
559 })),
560 TableSource::Unnest { expr, alias, .. } => Ok(self.add_step(Step::Scan {
561 table: "UNNEST".to_string(),
562 alias: alias.clone(),
563 projections: vec![Projection {
564 expr: *expr.clone(),
565 alias: None,
566 }],
567 predicate: None,
568 dependencies: vec![],
569 })),
570 TableSource::Pivot { source, alias, .. }
571 | TableSource::Unpivot { source, alias, .. } => {
572 let inner = self.plan_table_source(source)?;
575 Ok(self.add_step(Step::Project {
577 projections: vec![Projection {
578 expr: Expr::Wildcard,
579 alias: alias.clone(),
580 }],
581 dependencies: vec![inner],
582 }))
583 }
584 }
585 }
586
587 fn plan_set_operation(&mut self, set_op: &SetOperationStatement) -> Result<StepId> {
592 let left = self.plan_statement(&set_op.left)?;
593 let right = self.plan_statement(&set_op.right)?;
594
595 let mut current = self.add_step(Step::SetOperation {
596 op: set_op.op.clone(),
597 all: set_op.all,
598 projections: vec![],
599 dependencies: vec![left, right],
600 });
601
602 if !set_op.order_by.is_empty() {
603 current = self.add_step(Step::Sort {
604 order_by: set_op.order_by.clone(),
605 projections: vec![],
606 dependencies: vec![current],
607 });
608 }
609
610 if set_op.limit.is_some() || set_op.offset.is_some() {
611 current = self.add_step(Step::Limit {
612 limit: set_op.limit.clone(),
613 offset: set_op.offset.clone(),
614 projections: vec![],
615 dependencies: vec![current],
616 });
617 }
618
619 Ok(current)
620 }
621}
622
623fn fully_qualified_name(tref: &TableRef) -> String {
629 let mut parts = Vec::new();
630 if let Some(catalog) = &tref.catalog {
631 parts.push(catalog.as_str());
632 }
633 if let Some(schema) = &tref.schema {
634 parts.push(schema.as_str());
635 }
636 parts.push(tref.name.as_str());
637 parts.join(".")
638}
639
640fn select_items_to_projections(items: &[SelectItem]) -> Vec<Projection> {
642 items
643 .iter()
644 .map(|item| match item {
645 SelectItem::Wildcard => Projection {
646 expr: Expr::Wildcard,
647 alias: None,
648 },
649 SelectItem::QualifiedWildcard { table } => Projection {
650 expr: Expr::QualifiedWildcard {
651 table: table.clone(),
652 },
653 alias: None,
654 },
655 SelectItem::Expr { expr, alias } => Projection {
656 expr: expr.clone(),
657 alias: alias.clone(),
658 },
659 })
660 .collect()
661}
662
663fn has_aggregates(items: &[SelectItem]) -> bool {
665 items.iter().any(|item| match item {
666 SelectItem::Expr { expr, .. } => expr_has_aggregate(expr),
667 _ => false,
668 })
669}
670
671fn expr_has_aggregate(expr: &Expr) -> bool {
673 match expr {
674 Expr::Function { name, .. } => is_aggregate_name(name),
675 Expr::TypedFunction { func, .. } => typed_function_is_aggregate(func),
676 Expr::BinaryOp { left, right, .. } => expr_has_aggregate(left) || expr_has_aggregate(right),
677 Expr::UnaryOp { expr, .. } => expr_has_aggregate(expr),
678 Expr::Cast { expr, .. } | Expr::TryCast { expr, .. } => expr_has_aggregate(expr),
679 Expr::Case {
680 operand,
681 when_clauses,
682 else_clause,
683 } => {
684 operand.as_ref().is_some_and(|e| expr_has_aggregate(e))
685 || when_clauses
686 .iter()
687 .any(|(cond, result)| expr_has_aggregate(cond) || expr_has_aggregate(result))
688 || else_clause.as_ref().is_some_and(|e| expr_has_aggregate(e))
689 }
690 Expr::Alias { expr, .. } => expr_has_aggregate(expr),
691 _ => false,
692 }
693}
694
/// Upper-case names of untyped function calls that are treated as aggregates.
///
/// NOTE(review): FIRST_VALUE / LAST_VALUE / NTH_VALUE are window functions in standard SQL.
/// Listing them here makes queries using them plan an Aggregate step — confirm this is
/// intended.
const AGGREGATE_FUNCTION_NAMES: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
    "GROUP_CONCAT",
    "STRING_AGG",
    "ARRAY_AGG",
    "LISTAGG",
    "COLLECT_LIST",
    "COLLECT_SET",
    "ANY_VALUE",
    "APPROX_COUNT_DISTINCT",
    "PERCENTILE_CONT",
    "PERCENTILE_DISC",
    "STDDEV",
    "STDDEV_POP",
    "STDDEV_SAMP",
    "VARIANCE",
    "VAR_POP",
    "VAR_SAMP",
    "CORR",
    "COVAR_POP",
    "COVAR_SAMP",
    "FIRST_VALUE",
    "LAST_VALUE",
    "NTH_VALUE",
    "BIT_AND",
    "BIT_OR",
    "BIT_XOR",
    "BOOL_AND",
    "BOOL_OR",
    "EVERY",
];

/// Returns `true` if `name` names an aggregate function.
///
/// The comparison is case-insensitive: `name` is upper-cased with `str::to_uppercase` before
/// the lookup.
fn is_aggregate_name(name: &str) -> bool {
    let upper = name.to_uppercase();
    AGGREGATE_FUNCTION_NAMES.contains(&upper.as_str())
}
734
735fn typed_function_is_aggregate(func: &TypedFunction) -> bool {
737 matches!(
738 func,
739 TypedFunction::Count { .. }
740 | TypedFunction::Sum { .. }
741 | TypedFunction::Avg { .. }
742 | TypedFunction::Min { .. }
743 | TypedFunction::Max { .. }
744 | TypedFunction::ArrayAgg { .. }
745 | TypedFunction::ApproxDistinct { .. }
746 | TypedFunction::Variance { .. }
747 | TypedFunction::Stddev { .. }
748 )
749}
750
751fn extract_aggregates(items: &[SelectItem]) -> Vec<Projection> {
753 let mut aggs = Vec::new();
754 for item in items {
755 if let SelectItem::Expr { expr, alias } = item {
756 collect_aggregates(expr, alias, &mut aggs);
757 }
758 }
759 aggs
760}
761
762fn collect_aggregates(expr: &Expr, alias: &Option<String>, out: &mut Vec<Projection>) {
763 match expr {
764 Expr::Function { name, .. } if is_aggregate_name(name) => {
765 out.push(Projection {
766 expr: expr.clone(),
767 alias: alias.clone(),
768 });
769 }
770 Expr::TypedFunction { func, .. } if typed_function_is_aggregate(func) => {
771 out.push(Projection {
772 expr: expr.clone(),
773 alias: alias.clone(),
774 });
775 }
776 Expr::BinaryOp { left, right, .. } => {
777 collect_aggregates(left, &None, out);
778 collect_aggregates(right, &None, out);
779 }
780 Expr::Alias { expr: inner, name } => {
781 collect_aggregates(inner, &Some(name.clone()), out);
782 }
783 _ => {}
784 }
785}
786
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` as ANSI SQL and plans it, panicking if either stage fails.
    fn plan_sql(sql: &str) -> Plan {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        plan(&ast).unwrap()
    }

    /// Returns the kind of every step planned for `sql`, in step-id order.
    fn step_kinds(sql: &str) -> Vec<&'static str> {
        plan_sql(sql).steps().iter().map(Step::kind).collect()
    }

    /// Asserts that the plan for `sql` contains at least one step of each `expected` kind.
    fn assert_has_kinds(sql: &str, expected: &[&'static str]) {
        let kinds = step_kinds(sql);
        for kind in expected {
            assert!(
                kinds.contains(kind),
                "expected a {kind} step for {sql:?}, got {kinds:?}"
            );
        }
    }

    /// Returns the join type of the first Join step planned for `sql`.
    fn first_join_type(sql: &str) -> JoinType {
        plan_sql(sql)
            .steps()
            .iter()
            .find_map(|step| match step {
                Step::Join { join_type, .. } => Some(join_type.clone()),
                _ => None,
            })
            .expect("expected Join step")
    }

    #[test]
    fn test_simple_select() {
        let p = plan_sql("SELECT a, b FROM t");
        assert!(p.len() >= 2);
        let root = p.get(p.root()).expect("root step should exist");
        assert_eq!(root.kind(), "Project");
    }

    #[test]
    fn test_select_with_where() {
        assert_has_kinds("SELECT a FROM t WHERE a > 1", &["Scan", "Filter", "Project"]);
    }

    #[test]
    fn test_select_with_order_by() {
        assert_has_kinds("SELECT a FROM t ORDER BY a", &["Sort"]);
    }

    #[test]
    fn test_select_with_group_by() {
        assert_has_kinds("SELECT a, COUNT(*) FROM t GROUP BY a", &["Aggregate"]);
    }

    #[test]
    fn test_select_with_having() {
        assert_has_kinds(
            "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1",
            &["Aggregate", "Filter"],
        );
    }

    #[test]
    fn test_join() {
        assert_has_kinds("SELECT a.x FROM a JOIN b ON a.id = b.id", &["Join"]);
    }

    #[test]
    fn test_multiple_joins() {
        let joins = step_kinds("SELECT a.x FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id")
            .into_iter()
            .filter(|kind| *kind == "Join")
            .count();
        assert_eq!(joins, 2);
    }

    #[test]
    fn test_union() {
        assert_has_kinds("SELECT a FROM t1 UNION ALL SELECT b FROM t2", &["SetOperation"]);
    }

    #[test]
    fn test_limit_offset() {
        assert_has_kinds("SELECT a FROM t LIMIT 10 OFFSET 5", &["Limit"]);
    }

    #[test]
    fn test_distinct() {
        assert_has_kinds("SELECT DISTINCT a FROM t", &["Distinct"]);
    }

    #[test]
    fn test_subquery_in_from() {
        let p = plan_sql("SELECT x FROM (SELECT a AS x FROM t) sub");
        assert!(p.len() >= 3);
    }

    #[test]
    fn test_complex_query() {
        assert_has_kinds(
            "SELECT a, SUM(b) AS total FROM t WHERE c > 0 GROUP BY a HAVING SUM(b) > 10 ORDER BY total DESC LIMIT 5",
            &["Scan", "Filter", "Aggregate", "Sort", "Limit", "Project"],
        );
    }

    #[test]
    fn test_dag_dependencies() {
        let p = plan_sql("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id");
        for (i, step) in p.steps().iter().enumerate() {
            if let Some(dep) = step.dependencies().iter().find(|dep| dep.0 >= i) {
                panic!("step {i} depends on {dep} which is not earlier");
            }
        }
    }

    #[test]
    fn test_mermaid_output() {
        let mermaid = plan_sql("SELECT a FROM t WHERE a > 1").to_mermaid();
        assert!(mermaid.starts_with("graph TD"));
        assert!(mermaid.contains("Scan"));
    }

    #[test]
    fn test_dot_output() {
        let dot = plan_sql("SELECT a FROM t WHERE a > 1").to_dot();
        assert!(dot.starts_with("digraph plan"));
        assert!(dot.contains("Scan"));
    }

    #[test]
    fn test_display() {
        let rendered = plan_sql("SELECT a FROM t").to_string();
        assert!(rendered.contains("(root)"));
    }

    #[test]
    fn test_ddl_rejected() {
        let ast = parse("CREATE TABLE t (a INT)", Dialect::Ansi).unwrap();
        assert!(plan(&ast).is_err());
    }

    #[test]
    fn test_no_from_select() {
        assert!(!plan_sql("SELECT 1 + 2").is_empty());
    }

    #[test]
    fn test_left_join() {
        assert_eq!(
            first_join_type("SELECT a.x FROM a LEFT JOIN b ON a.id = b.id"),
            JoinType::Left
        );
    }

    #[test]
    fn test_cross_join() {
        assert_eq!(
            first_join_type("SELECT a.x FROM a CROSS JOIN b"),
            JoinType::Cross
        );
    }

    #[test]
    fn test_union_with_order_limit() {
        assert_has_kinds(
            "SELECT a FROM t1 UNION SELECT b FROM t2 ORDER BY 1 LIMIT 10",
            &["SetOperation", "Sort", "Limit"],
        );
    }
}