1use std::fmt;
2
3use substrait::proto::plan_rel::RelType as PlanRelType;
4use substrait::proto::read_rel::ReadType;
5use substrait::proto::rel::RelType;
6use substrait::proto::rel_common::EmitKind;
7use substrait::proto::{
8 AggregateFunction, AggregateRel, Expression, FilterRel, NamedStruct, PlanRel, ProjectRel,
9 ReadRel, Rel, RelCommon, RelRoot, Type,
10};
11
12use super::expressions::Reference;
13use super::types::Name;
14use super::{PlanError, Scope, Textify};
15
/// Provides a short, static display label for a relation (e.g. `"Read"`, `"Filter"`).
pub trait NamedRelation {
    /// Returns the label for this relation's variant.
    fn name(&self) -> &'static str;
}
19
20impl NamedRelation for Rel {
21 fn name(&self) -> &'static str {
22 match self.rel_type.as_ref() {
23 None => "UnknownRel",
24 Some(RelType::Read(_)) => "Read",
25 Some(RelType::Filter(_)) => "Filter",
26 Some(RelType::Project(_)) => "Project",
27 Some(RelType::Fetch(_)) => "Fetch",
28 Some(RelType::Aggregate(_)) => "Aggregate",
29 Some(RelType::Sort(_)) => "Sort",
30 Some(RelType::HashJoin(_)) => "HashJoin",
31 Some(RelType::Exchange(_)) => "Exchange",
32 Some(RelType::Join(_)) => "Join",
33 Some(RelType::Set(_)) => "Set",
34 Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
35 Some(RelType::Cross(_)) => "Cross",
36 Some(RelType::Reference(_)) => "Reference",
37 Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
38 Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
39 Some(RelType::Write(_)) => "Write",
40 Some(RelType::Ddl(_)) => "Ddl",
41 Some(RelType::Update(_)) => "Update",
42 Some(RelType::MergeJoin(_)) => "MergeJoin",
43 Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
44 Some(RelType::Window(_)) => "Window",
45 Some(RelType::Expand(_)) => "Expand",
46 }
47 }
48}
49
/// One renderable piece of a relation's textual form: an argument shown before
/// `=>` or an output column shown after it.
#[derive(Debug, Clone)]
pub enum Value<'a> {
    /// A single name, rendered through `ctx.display`.
    Name(Name<'a>),
    /// A multi-part table name, rendered with `.` between segments.
    TableName(Vec<Name<'a>>),
    /// A schema column rendered as `name:type`; either half may be absent,
    /// in which case `ctx.expect` handles the missing part.
    Field(Option<Name<'a>>, Option<&'a Type>),
    /// Comma-separated values wrapped in parentheses.
    Tuple(Vec<Value<'a>>),
    /// Comma-separated values wrapped in square brackets.
    List(Vec<Value<'a>>),
    /// A field reference by index (the tests show it rendered like `$0`).
    Reference(i32),
    /// A Substrait expression, rendered through `ctx.display`.
    Expression(&'a Expression),
    /// An aggregate function call (e.g. a measure of an `AggregateRel`).
    AggregateFunction(&'a AggregateFunction),
    /// Placeholder for data that could not be produced; rendering passes the
    /// error to `ctx.failure`.
    Missing(PlanError),
}
62
63impl<'a> Value<'a> {
64 pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
65 match maybe_value {
66 Some(s) => s,
67 None => Value::Missing(f()),
68 }
69 }
70}
71
72impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
73 fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
74 match token {
75 Ok(value) => Value::TableName(value),
76 Err(err) => Value::Missing(err),
77 }
78 }
79}
80
81impl<'a> Textify for Value<'a> {
82 fn name() -> &'static str {
83 "Value"
84 }
85
86 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
87 match self {
88 Value::Name(name) => write!(w, "{}", ctx.display(name)),
89 Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
90 Value::Field(name, typ) => {
91 write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
92 }
93 Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
94 Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
95 Value::Reference(i) => write!(w, "{}", Reference(*i)),
96 Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
97 Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
98 Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
99 }
100 }
101}
102
103fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
104 let mut fields = schema
105 .r#struct
106 .as_ref()
107 .map(|s| s.types.iter())
108 .into_iter()
109 .flatten();
110 let mut names = schema.names.iter();
111
112 let mut values = Vec::new();
116 loop {
117 let field = fields.next();
118 let name = names.next().map(|n| Name(n));
119 if field.is_none() && name.is_none() {
120 break;
121 }
122
123 values.push(Value::Field(name, field));
124 }
125
126 values
127}
128
/// A relation's output columns paired with its emit kind. Its `Textify`
/// impl decides which columns to print based on `emit`.
struct Emitted<'a> {
    /// Every column the relation produces, before any emit mapping.
    pub values: &'a [Value<'a>],
    /// The relation's emit kind, if one was set.
    pub emit: Option<&'a EmitKind>,
}
133
134impl<'a> Emitted<'a> {
135 pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
136 Self { values, emit }
137 }
138
139 pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
140 write!(w, "{}", ctx.separated(self.values.iter(), ", "))
141 }
142}
143
144impl<'a> Textify for Emitted<'a> {
145 fn name() -> &'static str {
146 "Emitted"
147 }
148
149 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
150 if ctx.options().show_emit {
151 return self.write_direct(ctx, w);
152 }
153
154 let indices = match &self.emit {
155 Some(EmitKind::Emit(e)) => &e.output_mapping,
156 Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
157 None => return self.write_direct(ctx, w),
158 };
159
160 for (i, &index) in indices.iter().enumerate() {
161 if i > 0 {
162 write!(w, ", ")?;
163 }
164
165 write!(w, "{}", ctx.expect(self.values.get(index as usize)))?;
166 }
167
168 Ok(())
169 }
170}
171
/// An intermediate, renderable form of a Substrait relation: a label,
/// arguments, output columns, an optional emit kind, and child relations.
pub struct Relation<'a> {
    /// Label printed before the brackets, e.g. `Read` or `Filter`.
    pub name: &'a str,
    /// Printed before `=>`; when empty, the `=>` is omitted entirely.
    pub arguments: Vec<Value<'a>>,
    /// Every column the relation produces, before `emit` is applied.
    pub columns: Vec<Value<'a>>,
    /// Optional emit kind controlling which columns are shown.
    pub emit: Option<&'a EmitKind>,
    /// Input relations, rendered one indent level deeper; `None` entries are skipped.
    pub children: Vec<Option<Relation<'a>>>,
}
180
181impl Textify for Relation<'_> {
182 fn name() -> &'static str {
183 "Relation"
184 }
185
186 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
187 let args = ctx.separated(self.arguments.iter(), ", ");
188 let cols = Emitted::new(&self.columns, self.emit);
189
190 let indent = ctx.indent();
191 let name = self.name;
192 let cols = ctx.display(&cols);
193 if self.arguments.is_empty() {
194 write!(w, "{indent}{name}[{cols}]")?;
195 } else {
196 write!(w, "{indent}{name}[{args} => {cols}]")?;
197 }
198 let child_scope = ctx.push_indent();
199 for child in self.children.iter().flatten() {
200 writeln!(w)?;
201 child.textify(&child_scope, w)?;
202 }
203 Ok(())
204 }
205}
206
207impl<'a> Relation<'a> {
208 pub fn emitted(&self) -> usize {
209 match self.emit {
210 Some(EmitKind::Emit(e)) => e.output_mapping.len(),
211 Some(EmitKind::Direct(_)) => self.columns.len(),
212 None => self.columns.len(),
213 }
214 }
215}
216
/// A table name given as its path segments; rendered with `.` between segments.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);
219
220impl<'a> Textify for TableName<'a> {
221 fn name() -> &'static str {
222 "TableName"
223 }
224
225 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
226 let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
227 write!(w, "{}", ctx.separated(names.iter(), "."))
228 }
229}
230
231pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
232 match rel {
233 Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
234 _ => Err(PlanError::unimplemented(
235 "ReadRel",
236 Some("table_name"),
237 format!("Unexpected read type {rel:?}") as String,
238 )),
239 }
240}
241
242impl<'a> From<&'a ReadRel> for Relation<'a> {
243 fn from(rel: &'a ReadRel) -> Self {
244 let name = get_table_name(rel.read_type.as_ref());
245 let named: Value = match name {
246 Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
247 Err(e) => Value::Missing(e),
248 };
249
250 let columns = match rel.base_schema {
251 Some(ref schema) => schema_to_values(schema),
252 None => {
253 let err = PlanError::unimplemented(
254 "ReadRel",
255 Some("base_schema"),
256 "Base schema is required",
257 );
258 vec![Value::Missing(err)]
259 }
260 };
261 let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
262
263 Relation {
264 name: "Read",
265 arguments: vec![named],
266 columns,
267 emit,
268 children: vec![],
269 }
270 }
271}
272
273pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
274 rel.as_ref().and_then(|c| c.emit_kind.as_ref())
275}
276
277impl<'a> Relation<'a> {
278 pub fn input_refs(&self) -> Vec<Value<'a>> {
286 let len = self.emitted();
287 (0..len).map(|i| Value::Reference(i as i32)).collect()
288 }
289
290 pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
294 let mut children = vec![];
295 let mut inputs = 0;
296
297 for maybe_rel in refs {
298 match maybe_rel {
299 Some(rel) => {
300 let child = Relation::from(rel);
301 inputs += child.emitted();
302 children.push(Some(child));
303 }
304 None => children.push(None),
305 }
306 }
307
308 (children, inputs)
309 }
310}
311
312impl<'a> From<&'a FilterRel> for Relation<'a> {
313 fn from(rel: &'a FilterRel) -> Self {
314 let condition = rel
315 .condition
316 .as_ref()
317 .map(|c| Value::Expression(c.as_ref()));
318 let condition = Value::expect(condition, || {
319 PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
320 });
321 let emit = get_emit(rel.common.as_ref());
322 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
323 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
324
325 Relation {
326 name: "Filter",
327 arguments: vec![condition],
328 columns,
329 emit,
330 children,
331 }
332 }
333}
334
335impl<'a> From<&'a ProjectRel> for Relation<'a> {
336 fn from(rel: &'a ProjectRel) -> Self {
337 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
338 let expressions = rel.expressions.iter().map(Value::Expression);
339 let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
340 columns.extend(expressions);
341
342 Relation {
343 name: "Project",
344 arguments: vec![],
345 columns,
346 emit: get_emit(rel.common.as_ref()),
347 children,
348 }
349 }
350}
351
352impl<'a> From<&'a Rel> for Relation<'a> {
353 fn from(rel: &'a Rel) -> Self {
354 match rel.rel_type.as_ref() {
355 Some(RelType::Read(r)) => Relation::from(r.as_ref()),
356 Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
357 Some(RelType::Project(r)) => Relation::from(r.as_ref()),
358 Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
359 _ => todo!(),
360 }
361 }
362}
363
364impl<'a> From<&'a AggregateRel> for Relation<'a> {
365 fn from(rel: &'a AggregateRel) -> Self {
375 let arguments = rel
377 .grouping_expressions
378 .iter()
379 .map(Value::Expression)
380 .collect();
381
382 let mut all_outputs: Vec<Value> = vec![];
384
385 let input_field_count = rel.grouping_expressions.len();
388 for i in 0..input_field_count {
389 all_outputs.push(Value::Reference(i as i32));
390 }
391
392 for m in &rel.measures {
395 if let Some(agg_fn) = m.measure.as_ref() {
396 all_outputs.push(Value::AggregateFunction(agg_fn));
397 }
398 }
399
400 let emit = get_emit(rel.common.as_ref());
402
403 Relation {
404 name: "Aggregate",
405 arguments,
406 columns: all_outputs,
407 emit,
408 children: rel
409 .input
410 .as_ref()
411 .map(|c| Some(Relation::from(c.as_ref())))
412 .into_iter()
413 .collect(),
414 }
415 }
416}
417
418impl Textify for RelRoot {
419 fn name() -> &'static str {
420 "RelRoot"
421 }
422
423 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
424 let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
425
426 write!(
427 w,
428 "{}Root[{}]",
429 ctx.indent(),
430 ctx.separated(names.iter(), ", ")
431 )?;
432 let child_scope = ctx.push_indent();
433 for child in self.input.iter() {
434 let child = Relation::from(child);
435 writeln!(w)?;
436 child.textify(&child_scope, w)?;
437 }
438
439 Ok(())
440 }
441}
442
443impl Textify for PlanRelType {
444 fn name() -> &'static str {
445 "PlanRelType"
446 }
447
448 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
449 match self {
450 PlanRelType::Rel(rel) => Relation::from(rel).textify(ctx, w),
451 PlanRelType::Root(root) => root.textify(ctx, w),
452 }
453 }
454}
455
456impl Textify for PlanRel {
457 fn name() -> &'static str {
458 "PlanRel"
459 }
460
461 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
464 write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
465 }
466}
467
#[cfg(test)]
mod tests {
    use substrait::proto::expression::literal::LiteralType;
    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::rel_common::Emit;
    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
    };

    use super::*;
    use crate::fixtures::TestContext;

    #[test]
    fn test_read_rel() {
        let ctx = TestContext::new();

        let read_rel = ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names: vec!["col1".into(), "column 2".into()],
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types: vec![
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                        Type {
                            kind: Some(Kind::String(ptype::String {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                    ],
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["some_db".into(), "test_table".into()],
                advanced_extension: None,
            })),
        };

        let rel = Relation::from(&read_rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(
            result,
            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
        );
    }

    #[test]
    fn test_filter_rel() {
        let ctx = TestContext::new()
            .with_uri(1, "test_uri")
            .with_function(1, 10, "gt");

        let read_rel = ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names: vec!["col1".into(), "col2".into()],
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types: vec![
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                    ],
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["test_table".into()],
                advanced_extension: None,
            })),
        };

        let filter_expr = Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                function_reference: 10,
                arguments: vec![
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Reference(0).into())),
                    },
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Expression {
                            rex_type: Some(RexType::Literal(Literal {
                                literal_type: Some(LiteralType::I32(10)),
                                nullable: false,
                                type_variation_reference: 0,
                            })),
                        })),
                    },
                ],
                options: vec![],
                output_type: None,
                // `args` is deprecated upstream but must still be set in a struct literal.
                #[allow(deprecated)]
                args: vec![],
            })),
        };

        let filter_rel = FilterRel {
            common: None,
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(read_rel))),
            })),
            condition: Some(Box::new(filter_expr)),
            advanced_extension: None,
        };

        let rel = Rel {
            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
        };

        let rel = Relation::from(&rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        let expected = r#"
Filter[gt($0, 10:i32) => $0, $1]
  Read[test_table => col1:i32?, col2:i32?]"#
            .trim_start();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_aggregate_function_textify() {
        let ctx = TestContext::new()
            .with_uri(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let agg_fn = AggregateFunction {
            function_reference: 10,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        crate::parser::expressions::reference(1),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            // `args` is deprecated upstream but must still be set in a struct literal.
            #[allow(deprecated)]
            args: vec![],
        };

        let value = Value::AggregateFunction(&agg_fn);
        let (result, errors) = ctx.textify(&value);

        // Failure messages carry the diagnostics, so no ad-hoc printing is needed.
        assert!(
            errors.is_empty(),
            "Expected no errors, got: {errors:?} (output: {result})"
        );
        assert_eq!(result, "sum($1)");
    }

    #[test]
    fn test_aggregate_relation_textify() {
        let ctx = TestContext::new()
            .with_uri(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let agg_fn1 = AggregateFunction {
            function_reference: 10,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        crate::parser::expressions::reference(1),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            // `args` is deprecated upstream but must still be set in a struct literal.
            #[allow(deprecated)]
            args: vec![],
        };

        let agg_fn2 = AggregateFunction {
            function_reference: 11,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        crate::parser::expressions::reference(1),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            // `args` is deprecated upstream but must still be set in a struct literal.
            #[allow(deprecated)]
            args: vec![],
        };

        let aggregate_rel = AggregateRel {
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(ReadRel {
                    common: None,
                    base_schema: Some(NamedStruct {
                        names: vec!["category".into(), "amount".into()],
                        r#struct: Some(Struct {
                            type_variation_reference: 0,
                            types: vec![
                                Type {
                                    kind: Some(Kind::String(ptype::String {
                                        type_variation_reference: 0,
                                        nullability: Nullability::Nullable as i32,
                                    })),
                                },
                                Type {
                                    kind: Some(Kind::Fp64(ptype::Fp64 {
                                        type_variation_reference: 0,
                                        nullability: Nullability::Nullable as i32,
                                    })),
                                },
                            ],
                            nullability: Nullability::Nullable as i32,
                        }),
                    }),
                    filter: None,
                    best_effort_filter: None,
                    projection: None,
                    advanced_extension: None,
                    read_type: Some(ReadType::NamedTable(NamedTable {
                        names: vec!["orders".into()],
                        advanced_extension: None,
                    })),
                }))),
            })),
            grouping_expressions: vec![Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    crate::parser::expressions::reference(0),
                ))),
            }],
            groupings: vec![],
            measures: vec![
                aggregate_rel::Measure {
                    measure: Some(agg_fn1),
                    filter: None,
                },
                aggregate_rel::Measure {
                    measure: Some(agg_fn2),
                    filter: None,
                },
            ],
            common: Some(RelCommon {
                // Emit only the two measures (output columns 1 and 2), dropping the grouping key.
                emit_kind: Some(EmitKind::Emit(Emit {
                    output_mapping: vec![1, 2],
                })),
                ..Default::default()
            }),
            advanced_extension: None,
        };

        let relation = Relation::from(&aggregate_rel);
        let (result, errors) = ctx.textify(&relation);

        // Failure messages carry the full rendered output instead of ad-hoc printing.
        assert!(
            errors.is_empty(),
            "Expected no errors, got: {errors:?} (output: {result})"
        );
        assert!(
            result.contains("Aggregate[$0 => sum($1), count($1)]"),
            "Unexpected aggregate rendering:\n{result}"
        );
    }
}