Skip to main content

vortex_array/expr/transform/
coerce.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Expression-level type coercion pass.
5
6use vortex_error::VortexResult;
7
8use crate::dtype::DType;
9use crate::expr::Expression;
10use crate::expr::cast;
11use crate::expr::traversal::NodeExt;
12use crate::expr::traversal::Transformed;
13use crate::scalar_fn::fns::literal::Literal;
14use crate::scalar_fn::fns::root::Root;
15
16/// Rewrite an expression tree to insert casts where a scalar function's `coerce_args` demands
17/// a different type than what the child currently produces.
18///
19/// The rewrite is bottom-up: children are coerced first, then each parent node checks whether
20/// its children match the coerced argument types.
21pub fn coerce_expression(expr: Expression, scope: &DType) -> VortexResult<Expression> {
22    // We capture scope by reference for the closure.
23    let scope = scope.clone();
24    expr.transform_up(|node| {
25        // Leaf nodes (Root, Literal) have no children to coerce.
26        if node.is::<Root>() || node.is::<Literal>() || node.children().is_empty() {
27            return Ok(Transformed::no(node));
28        }
29
30        // Compute the current child return types.
31        let child_dtypes: Vec<DType> = node
32            .children()
33            .iter()
34            .map(|c| c.return_dtype(&scope))
35            .collect::<VortexResult<_>>()?;
36
37        // Ask the scalar function what types it wants.
38        let coerced_dtypes = node.scalar_fn().coerce_args(&child_dtypes)?;
39
40        // If nothing changed, skip.
41        if child_dtypes == coerced_dtypes {
42            return Ok(Transformed::no(node));
43        }
44
45        // Build new children, inserting casts where needed.
46        let new_children: Vec<Expression> = node
47            .children()
48            .iter()
49            .zip(coerced_dtypes.iter())
50            .map(|(child, target)| {
51                let child_dtype = child.return_dtype(&scope)?;
52                if child_dtype.eq_ignore_nullability(target)
53                    && child_dtype.nullability() == target.nullability()
54                {
55                    Ok(child.clone())
56                } else {
57                    Ok(cast(child.clone(), target.clone()))
58                }
59            })
60            .collect::<VortexResult<_>>()?;
61
62        let new_expr = node.with_children(new_children)?;
63        Ok(Transformed::yes(new_expr))
64    })
65    .map(|t| t.into_inner())
66}
67
68#[cfg(test)]
69mod tests {
70    use vortex_error::VortexResult;
71
72    use crate::dtype::DType;
73    use crate::dtype::Nullability::NonNullable;
74    use crate::dtype::PType;
75    use crate::dtype::StructFields;
76    use crate::expr::col;
77    use crate::expr::lit;
78    use crate::expr::transform::coerce::coerce_expression;
79    use crate::scalar::Scalar;
80    use crate::scalar_fn::ScalarFnVTableExt;
81    use crate::scalar_fn::fns::binary::Binary;
82    use crate::scalar_fn::fns::cast::Cast;
83    use crate::scalar_fn::fns::operators::Operator;
84
85    fn test_scope() -> DType {
86        DType::Struct(
87            StructFields::new(
88                ["x", "y"].into(),
89                vec![
90                    DType::Primitive(PType::I32, NonNullable),
91                    DType::Primitive(PType::I64, NonNullable),
92                ],
93            ),
94            NonNullable,
95        )
96    }
97
98    #[test]
99    fn mixed_type_comparison_inserts_cast() -> VortexResult<()> {
100        let scope = test_scope();
101        // x (I32) < y (I64) => should cast x to I64
102        let expr = Binary.new_expr(Operator::Lt, [col("x"), col("y")]);
103        let coerced = coerce_expression(expr, &scope)?;
104
105        // The LHS child should now be a cast expression
106        assert!(coerced.child(0).is::<Cast>());
107        // The coerced LHS should return I64
108        assert_eq!(
109            coerced.child(0).return_dtype(&scope)?,
110            DType::Primitive(PType::I64, NonNullable)
111        );
112        // The RHS should be unchanged
113        assert!(!coerced.child(1).is::<Cast>());
114        Ok(())
115    }
116
117    #[test]
118    fn same_type_comparison_no_cast() -> VortexResult<()> {
119        let scope = test_scope();
120        // x (I32) < x (I32) => no cast needed
121        let expr = Binary.new_expr(Operator::Lt, [col("x"), col("x")]);
122        let coerced = coerce_expression(expr, &scope)?;
123
124        // Neither child should be a cast
125        assert!(!coerced.child(0).is::<Cast>());
126        assert!(!coerced.child(1).is::<Cast>());
127        Ok(())
128    }
129
130    #[test]
131    fn mixed_type_arithmetic_coerces_both() -> VortexResult<()> {
132        let scope = DType::Struct(
133            StructFields::new(
134                ["a", "b"].into(),
135                vec![
136                    DType::Primitive(PType::U8, NonNullable),
137                    DType::Primitive(PType::I32, NonNullable),
138                ],
139            ),
140            NonNullable,
141        );
142        // a (U8) + b (I32) => both should be coerced to I32
143        // U8 + I32: unsigned_signed_supertype(U8, I32) => max(1,4)=4 => I64
144        let expr = Binary.new_expr(Operator::Add, [col("a"), col("b")]);
145        let coerced = coerce_expression(expr, &scope)?;
146
147        // LHS (U8) should be cast
148        assert!(coerced.child(0).is::<Cast>());
149        // Both should return the same supertype
150        let lhs_dt = coerced.child(0).return_dtype(&scope)?;
151        let rhs_dt = coerced.child(1).return_dtype(&scope)?;
152        assert_eq!(lhs_dt, rhs_dt);
153        Ok(())
154    }
155
156    #[test]
157    fn boolean_operators_no_coercion() -> VortexResult<()> {
158        let scope = DType::Struct(
159            StructFields::new(
160                ["p", "q"].into(),
161                vec![DType::Bool(NonNullable), DType::Bool(NonNullable)],
162            ),
163            NonNullable,
164        );
165        let expr = Binary.new_expr(Operator::And, [col("p"), col("q")]);
166        let coerced = coerce_expression(expr, &scope)?;
167
168        assert!(!coerced.child(0).is::<Cast>());
169        assert!(!coerced.child(1).is::<Cast>());
170        Ok(())
171    }
172
173    #[test]
174    fn literal_coercion() -> VortexResult<()> {
175        let scope = DType::Struct(
176            StructFields::new(
177                ["x"].into(),
178                vec![DType::Primitive(PType::I64, NonNullable)],
179            ),
180            NonNullable,
181        );
182        // x (I64) + 1i32 => literal should be cast to I64
183        let expr = Binary.new_expr(Operator::Add, [col("x"), lit(Scalar::from(1i32))]);
184        let coerced = coerce_expression(expr, &scope)?;
185
186        // The RHS (literal) should be cast to I64
187        assert!(coerced.child(1).is::<Cast>());
188        assert_eq!(
189            coerced.child(1).return_dtype(&scope)?,
190            DType::Primitive(PType::I64, NonNullable)
191        );
192        Ok(())
193    }
194}