vortex_expr/transform/replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::{Nullability, StructFields};
5use vortex_error::{VortexExpect, VortexResult};
6
7use crate::traversal::{NodeExt, Transformed};
8use crate::{ExprRef, col, pack, root};
9
10/// Replaces all occurrences of `needle` in the expression `expr` with `replacement`.
11pub fn replace(expr: ExprRef, needle: &ExprRef, replacement: ExprRef) -> ExprRef {
12    expr.transform_up(|node| replace_transformer(node, needle, &replacement))
13        .vortex_expect("ReplaceVisitor should not fail")
14        .into_inner()
15}
16
17/// Expand the `root` expression with a pack of the given struct fields.
18pub fn replace_root_fields(expr: ExprRef, fields: &StructFields) -> ExprRef {
19    replace(
20        expr,
21        &root(),
22        pack(
23            fields
24                .names()
25                .iter()
26                .map(|name| (name.clone(), col(name.clone()))),
27            Nullability::NonNullable,
28        ),
29    )
30}
31
32fn replace_transformer(
33    node: ExprRef,
34    needle: &ExprRef,
35    replacement: &ExprRef,
36) -> VortexResult<Transformed<ExprRef>> {
37    if &node == needle {
38        Ok(Transformed::yes(replacement.clone()))
39    } else {
40        Ok(Transformed::no(node))
41    }
42}
43
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::NonNullable;

    use super::replace;
    use crate::{get_item, lit, pack, root};

    /// A needle equal to the whole expression replaces the entire tree.
    #[test]
    fn test_replace_full_tree() {
        let e = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let needle = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let replacement = lit(42);
        let replaced_expr = replace(e, &needle, replacement.clone());
        assert_eq!(&replaced_expr, &replacement);
    }

    /// A needle matching a leaf only swaps that leaf.
    #[test]
    fn test_replace_leaf() {
        let e = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let needle = lit(2);
        let replacement = lit(42);
        let replaced_expr = replace(e, &needle, replacement.clone());
        assert_eq!(replaced_expr.to_string(), "pack(a: 1i32, b: 42i32)");
    }

    /// A needle that never occurs leaves the expression unchanged.
    #[test]
    fn test_replace_no_match() {
        let e = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let replaced_expr = replace(e.clone(), &lit(3), lit(42));
        assert_eq!(&replaced_expr, &e);
    }

    /// A replacement that itself contains the needle is inserted once and not rewritten again.
    ///
    /// `replace_root_fields` needs this when it swaps `root()` for a pack of columns that
    /// (presumably) read from `root()`.
    #[test]
    fn test_replacement_containing_needle_is_not_revisited() {
        let e = get_item("a", root());
        let replacement = get_item("x", root());
        let replaced_expr = replace(e, &root(), replacement);
        assert_eq!(&replaced_expr, &get_item("a", get_item("x", root())));
    }
}