vortex_array/expr/transform/
replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::Nullability;
5use vortex_dtype::StructFields;
6use vortex_error::VortexExpect;
7
8use crate::expr::Expression;
9use crate::expr::exprs::get_item::col;
10use crate::expr::exprs::pack::pack;
11use crate::expr::exprs::root::root;
12use crate::expr::traversal::NodeExt;
13use crate::expr::traversal::Transformed;
14use crate::expr::traversal::TraversalOrder;
15
16/// Replaces all occurrences of `needle` in the expression `expr` with `replacement`.
17pub fn replace(expr: Expression, needle: &Expression, replacement: Expression) -> Expression {
18    expr.transform_down(|node| {
19        if &node == needle {
20            Ok(Transformed {
21                value: replacement.clone(),
22                // If there is a match with a needle there can be no more matches in that subtree.
23                order: TraversalOrder::Skip,
24                changed: true,
25            })
26        } else {
27            Ok(Transformed::no(node))
28        }
29    })
30    .vortex_expect("ReplaceVisitor should not fail")
31    .into_inner()
32}
33
34/// Expand the `root` expression with a pack of the given struct fields.
35pub fn replace_root_fields(expr: Expression, fields: &StructFields) -> Expression {
36    replace(
37        expr,
38        &root(),
39        pack(
40            fields
41                .names()
42                .iter()
43                .map(|name| (name.clone(), col(name.clone()))),
44            Nullability::NonNullable,
45        ),
46    )
47}
48
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::NonNullable;

    use super::replace;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;

    /// A needle equal to the whole expression replaces the entire tree.
    #[test]
    fn test_replace_full_tree() {
        let e = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let needle = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let replacement = lit(42);
        let replaced_expr = replace(e, &needle, replacement.clone());
        assert_eq!(&replaced_expr, &replacement);
    }

    /// A needle matching a single leaf only replaces that leaf.
    #[test]
    fn test_replace_leaf() {
        let e = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let needle = lit(2);
        let replacement = lit(42);
        let replaced_expr = replace(e, &needle, replacement);
        assert_eq!(replaced_expr.to_string(), "pack(a: 1i32, b: 42i32)");
    }

    /// Every occurrence of the needle is replaced, not just the first.
    #[test]
    fn test_replace_multiple_occurrences() {
        let e = pack([("a", lit(2)), ("b", lit(2))], NonNullable);
        let replaced_expr = replace(e, &lit(2), lit(42));
        assert_eq!(replaced_expr.to_string(), "pack(a: 42i32, b: 42i32)");
    }

    /// An expression without any occurrence of the needle is returned unchanged.
    #[test]
    fn test_replace_no_match() {
        let e = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let replaced_expr = replace(e.clone(), &lit(3), lit(42));
        assert_eq!(&replaced_expr, &e);
    }

    /// The substituted replacement is not traversed again, so a replacement that itself
    /// contains the needle is inserted as-is rather than expanded recursively.
    #[test]
    fn test_replacement_containing_needle_is_not_revisited() {
        let needle = lit(2);
        let replacement = pack([("inner", lit(2))], NonNullable);
        let replaced_expr = replace(lit(2), &needle, replacement.clone());
        assert_eq!(&replaced_expr, &replacement);
    }
}