vortex_array/expr/transform/
replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::{Nullability, StructFields};
5use vortex_error::{VortexExpect, VortexResult};
6
7use crate::expr::Expression;
8use crate::expr::exprs::get_item::col;
9use crate::expr::exprs::pack::pack;
10use crate::expr::exprs::root::root;
11use crate::expr::traversal::{NodeExt, Transformed};
12
13/// Replaces all occurrences of `needle` in the expression `expr` with `replacement`.
14pub fn replace(expr: Expression, needle: &Expression, replacement: Expression) -> Expression {
15    expr.transform_up(|node| replace_transformer(node, needle, &replacement))
16        .vortex_expect("ReplaceVisitor should not fail")
17        .into_inner()
18}
19
20/// Expand the `root` expression with a pack of the given struct fields.
21pub fn replace_root_fields(expr: Expression, fields: &StructFields) -> Expression {
22    replace(
23        expr,
24        &root(),
25        pack(
26            fields
27                .names()
28                .iter()
29                .map(|name| (name.clone(), col(name.clone()))),
30            Nullability::NonNullable,
31        ),
32    )
33}
34
35fn replace_transformer(
36    node: Expression,
37    needle: &Expression,
38    replacement: &Expression,
39) -> VortexResult<Transformed<Expression>> {
40    if &node == needle {
41        Ok(Transformed::yes(replacement.clone()))
42    } else {
43        Ok(Transformed::no(node))
44    }
45}
46
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::NonNullable;

    use super::replace;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;

    #[test]
    fn test_replace_full_tree() {
        // Build the pack twice so the needle is an independent but equal tree.
        let make_pack = || pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let needle = get_item("b", make_pack());
        let replacement = lit(42);

        let replaced_expr = replace(get_item("b", make_pack()), &needle, replacement.clone());

        // The entire expression matched the needle, so only the replacement remains.
        assert_eq!(&replaced_expr, &replacement);
    }

    #[test]
    fn test_replace_leaf() {
        let haystack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);

        let replaced_expr = replace(haystack, &lit(2), lit(42));

        // Only the literal under field `b` equals `lit(2)`.
        assert_eq!(replaced_expr.to_string(), "pack(a: 1i32, b: 42i32)");
    }
}