vortex_expr/
display.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
/// Selects which textual layout [`DisplayAs::fmt_as`] should produce for an expression.
///
/// The enum is `Copy` so that an implementor can forward the same format to several
/// children without re-constructing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisplayFormat {
    /// The compact rendering; presumably the single-line form used by ordinary `Display`
    /// output — TODO confirm against `DisplayAs` implementors.
    Compact,
    /// The per-node label rendering used when drawing an expression as a tree.
    Tree,
}
8
9/// Configurable display trait for expressions.
10pub trait DisplayAs {
11    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result;
12
13    fn child_names(&self) -> Option<Vec<String>> {
14        None
15    }
16}
17
18pub struct DisplayTreeExpr<'a>(pub &'a dyn crate::VortexExpr);
19
20impl std::fmt::Display for DisplayTreeExpr<'_> {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        pub use termtree::Tree;
23        fn make_tree(expr: &dyn crate::VortexExpr) -> Result<Tree<String>, std::fmt::Error> {
24            let node_name = TreeNodeDisplay(expr).to_string();
25
26            // Get child names for display purposes
27            let child_names = DisplayAs::child_names(expr);
28            let children = expr.children();
29
30            let child_trees: Result<Vec<Tree<String>>, _> = if let Some(names) = child_names
31                && names.len() == children.len()
32            {
33                children
34                    .iter()
35                    .zip(names.iter())
36                    .map(|(child, name)| {
37                        let child_tree = make_tree(child.as_ref())?;
38                        Ok(Tree::new(format!("{}: {}", name, child_tree.root))
39                            .with_leaves(child_tree.leaves))
40                    })
41                    .collect()
42            } else {
43                children
44                    .iter()
45                    .map(|child| make_tree(child.as_ref()))
46                    .collect()
47            };
48
49            Ok(Tree::new(node_name).with_leaves(child_trees?))
50        }
51
52        write!(f, "{}", make_tree(self.0)?)
53    }
54}
55
56struct TreeNodeDisplay<'a>(&'a dyn crate::VortexExpr);
57
58impl<'a> std::fmt::Display for TreeNodeDisplay<'a> {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        self.0.fmt_as(DisplayFormat::Tree, f)
61    }
62}
63
#[cfg(test)]
mod tests {
    // Tree-rendering tests: every test pins the exact output with an inline insta snapshot,
    // so a regression in labels or tree layout fails the test instead of only printing.
    use insta::assert_snapshot;
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{and, between, cast, eq, get_item, gt, lit, not, pack, root, select, select_exclude};

    #[test]
    fn tree_display_getitem() {
        let expr = get_item("x", root());
        assert_snapshot!(expr.display_tree().to_string(), @r"
        GetItem(x)
        └── Root
        ");
    }

    #[test]
    fn tree_display_binary() {
        let expr = gt(get_item("x", root()), lit(5));
        assert_snapshot!(expr.display_tree().to_string(), @r"
        Binary(>)
        ├── lhs: GetItem(x)
        │   └── Root
        └── rhs: Literal(value: 5i32, dtype: i32)
        ");
    }

    #[test]
    fn test_child_names_debug() {
        // Binary and Between expressions name their children; the names must show up as
        // `name: ` prefixes on the child labels.
        let binary_expr = gt(get_item("x", root()), lit(10));
        assert_snapshot!(binary_expr.display_tree().to_string(), @r"
        Binary(>)
        ├── lhs: GetItem(x)
        │   └── Root
        └── rhs: Literal(value: 10i32, dtype: i32)
        ");

        let between_expr = between(
            get_item("score", root()),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_snapshot!(between_expr.display_tree().to_string(), @r"
        Between
        ├── array: GetItem(score)
        │   └── Root
        ├── lower (NonStrict): Literal(value: 0i32, dtype: i32)
        └── upper (NonStrict): Literal(value: 100i32, dtype: i32)
        ");
    }

    #[test]
    fn test_display_tree() {
        let root_expr = root();
        assert_snapshot!(root_expr.display_tree().to_string(), @"Root");

        let lit_expr = lit(42);
        assert_snapshot!(lit_expr.display_tree().to_string(), @"Literal(value: 42i32, dtype: i32)");

        let get_item_expr = get_item("my_field", root());
        assert_snapshot!(get_item_expr.display_tree().to_string(), @r"
        GetItem(my_field)
        └── Root
        ");

        let binary_expr = gt(get_item("x", root()), lit(10));
        assert_snapshot!(binary_expr.display_tree().to_string(), @r"
        Binary(>)
        ├── lhs: GetItem(x)
        │   └── Root
        └── rhs: Literal(value: 10i32, dtype: i32)
        ");

        let complex_binary = and(
            eq(get_item("name", root()), lit("alice")),
            gt(get_item("age", root()), lit(18)),
        );
        assert_snapshot!(complex_binary.display_tree().to_string(), @r#"
        Binary(and)
        ├── lhs: Binary(=)
        │   ├── lhs: GetItem(name)
        │   │   └── Root
        │   └── rhs: Literal(value: "alice", dtype: utf8)
        └── rhs: Binary(>)
            ├── lhs: GetItem(age)
            │   └── Root
            └── rhs: Literal(value: 18i32, dtype: i32)
        "#);

        let select_expr = select(["name", "age"], root());
        assert_snapshot!(select_expr.display_tree().to_string(), @r#"
        Select(include): ["name", "age"]
        └── Root
        "#);

        let select_exclude_expr = select_exclude(["internal_id", "metadata"], root());
        assert_snapshot!(select_exclude_expr.display_tree().to_string(), @r#"
        Select(exclude): ["internal_id", "metadata"]
        └── Root
        "#);

        let cast_expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_snapshot!(cast_expr.display_tree().to_string(), @r"
        Cast(target: i64)
        └── GetItem(value)
            └── Root
        ");

        let not_expr = not(eq(get_item("active", root()), lit(true)));
        assert_snapshot!(not_expr.display_tree().to_string(), @r"
        Not
        └── Binary(=)
            ├── lhs: GetItem(active)
            │   └── Root
            └── rhs: Literal(value: true, dtype: bool)
        ");

        let between_expr = between(
            get_item("score", root()),
            lit(0),
            lit(100),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        assert_snapshot!(between_expr.display_tree().to_string(), @r"
        Between
        ├── array: GetItem(score)
        │   └── Root
        ├── lower (NonStrict): Literal(value: 0i32, dtype: i32)
        └── upper (NonStrict): Literal(value: 100i32, dtype: i32)
        ");

        // Test nested expression
        let nested_expr = select(
            ["result"],
            cast(
                between(
                    get_item("score", root()),
                    lit(50),
                    lit(100),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::NonStrict,
                    },
                ),
                DType::Bool(Nullability::NonNullable),
            ),
        );
        assert_snapshot!(nested_expr.display_tree().to_string(), @r#"
        Select(include): ["result"]
        └── Cast(target: bool)
            └── Between
                ├── array: GetItem(score)
                │   └── Root
                ├── lower (Strict): Literal(value: 50i32, dtype: i32)
                └── upper (NonStrict): Literal(value: 100i32, dtype: i32)
        "#);

        let select_from_pack_expr = select(
            ["fizz", "buzz"],
            pack(
                [
                    ("fizz", root()),
                    ("bar", lit(5)),
                    ("buzz", eq(lit(42), get_item("answer", root()))),
                ],
                Nullability::Nullable,
            ),
        );
        assert_snapshot!(select_from_pack_expr.display_tree().to_string(), @r#"
        Select(include): ["fizz", "buzz"]
        └── Pack
            ├── fizz: Root
            ├── bar: Literal(value: 5i32, dtype: i32)
            └── buzz: Binary(=)
                ├── lhs: Literal(value: 42i32, dtype: i32)
                └── rhs: GetItem(answer)
                    └── Root
        "#);
    }
}