vortex_array/expr/transform/
coerce.rs1use vortex_error::VortexResult;
7
8use crate::dtype::DType;
9use crate::expr::Expression;
10use crate::expr::cast;
11use crate::expr::traversal::NodeExt;
12use crate::expr::traversal::Transformed;
13use crate::scalar_fn::fns::literal::Literal;
14use crate::scalar_fn::fns::root::Root;
15
16pub fn coerce_expression(expr: Expression, scope: &DType) -> VortexResult<Expression> {
22 let scope = scope.clone();
24 expr.transform_up(|node| {
25 if node.is::<Root>() || node.is::<Literal>() || node.children().is_empty() {
27 return Ok(Transformed::no(node));
28 }
29
30 let child_dtypes: Vec<DType> = node
32 .children()
33 .iter()
34 .map(|c| c.return_dtype(&scope))
35 .collect::<VortexResult<_>>()?;
36
37 let coerced_dtypes = node.scalar_fn().coerce_args(&child_dtypes)?;
39
40 if child_dtypes == coerced_dtypes {
42 return Ok(Transformed::no(node));
43 }
44
45 let new_children: Vec<Expression> = node
47 .children()
48 .iter()
49 .zip(coerced_dtypes.iter())
50 .map(|(child, target)| {
51 let child_dtype = child.return_dtype(&scope)?;
52 if child_dtype.eq_ignore_nullability(target)
53 && child_dtype.nullability() == target.nullability()
54 {
55 Ok(child.clone())
56 } else {
57 Ok(cast(child.clone(), target.clone()))
58 }
59 })
60 .collect::<VortexResult<_>>()?;
61
62 let new_expr = node.with_children(new_children)?;
63 Ok(Transformed::yes(new_expr))
64 })
65 .map(|t| t.into_inner())
66}
67
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::transform::coerce::coerce_expression;
    use crate::scalar::Scalar;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::scalar_fn::fns::binary::Binary;
    use crate::scalar_fn::fns::cast::Cast;
    use crate::scalar_fn::fns::operators::Operator;

    /// Non-nullable struct scope with fields `x: i32` and `y: i64`.
    fn test_scope() -> DType {
        DType::Struct(
            StructFields::new(
                ["x", "y"].into(),
                vec![
                    DType::Primitive(PType::I32, NonNullable),
                    DType::Primitive(PType::I64, NonNullable),
                ],
            ),
            NonNullable,
        )
    }

    /// Comparing i32 with i64 casts the i32 side up and leaves the i64 side alone.
    #[test]
    fn mixed_type_comparison_inserts_cast() -> VortexResult<()> {
        let scope = test_scope();
        let expr = Binary.new_expr(Operator::Lt, [col("x"), col("y")]);
        let coerced = coerce_expression(expr, &scope)?;

        assert!(coerced.child(0).is::<Cast>());
        assert_eq!(
            coerced.child(0).return_dtype(&scope)?,
            DType::Primitive(PType::I64, NonNullable)
        );
        assert!(!coerced.child(1).is::<Cast>());
        Ok(())
    }

    /// Operands that already share a type are left untouched.
    #[test]
    fn same_type_comparison_no_cast() -> VortexResult<()> {
        let scope = test_scope();
        let expr = Binary.new_expr(Operator::Lt, [col("x"), col("x")]);
        let coerced = coerce_expression(expr, &scope)?;

        assert!(!coerced.child(0).is::<Cast>());
        assert!(!coerced.child(1).is::<Cast>());
        Ok(())
    }

    /// Arithmetic on u8 and i32 ends with both operands at the same type.
    #[test]
    fn mixed_type_arithmetic_coerces_both() -> VortexResult<()> {
        let scope = DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![
                    DType::Primitive(PType::U8, NonNullable),
                    DType::Primitive(PType::I32, NonNullable),
                ],
            ),
            NonNullable,
        );
        let expr = Binary.new_expr(Operator::Add, [col("a"), col("b")]);
        let coerced = coerce_expression(expr, &scope)?;

        assert!(coerced.child(0).is::<Cast>());
        let lhs_dt = coerced.child(0).return_dtype(&scope)?;
        let rhs_dt = coerced.child(1).return_dtype(&scope)?;
        assert_eq!(lhs_dt, rhs_dt);
        Ok(())
    }

    /// Boolean connectives over boolean columns need no casts.
    #[test]
    fn boolean_operators_no_coercion() -> VortexResult<()> {
        let scope = DType::Struct(
            StructFields::new(
                ["p", "q"].into(),
                vec![DType::Bool(NonNullable), DType::Bool(NonNullable)],
            ),
            NonNullable,
        );
        let expr = Binary.new_expr(Operator::And, [col("p"), col("q")]);
        let coerced = coerce_expression(expr, &scope)?;

        assert!(!coerced.child(0).is::<Cast>());
        assert!(!coerced.child(1).is::<Cast>());
        Ok(())
    }

    /// A literal operand is cast to match a wider column operand.
    #[test]
    fn literal_coercion() -> VortexResult<()> {
        let scope = DType::Struct(
            StructFields::new(
                ["x"].into(),
                vec![DType::Primitive(PType::I64, NonNullable)],
            ),
            NonNullable,
        );
        let expr = Binary.new_expr(Operator::Add, [col("x"), lit(Scalar::from(1i32))]);
        let coerced = coerce_expression(expr, &scope)?;

        assert!(coerced.child(1).is::<Cast>());
        assert_eq!(
            coerced.child(1).return_dtype(&scope)?,
            DType::Primitive(PType::I64, NonNullable)
        );
        Ok(())
    }

    /// Coercion reaches comparisons nested under a boolean connective.
    #[test]
    fn nested_children_are_coerced() -> VortexResult<()> {
        let scope = test_scope();
        let lhs = Binary.new_expr(Operator::Lt, [col("x"), col("y")]);
        let rhs = Binary.new_expr(Operator::Gt, [col("y"), col("x")]);
        let expr = Binary.new_expr(Operator::And, [lhs, rhs]);
        let coerced = coerce_expression(expr, &scope)?;

        // The `And` already has boolean operands; only the nested comparisons change.
        assert!(!coerced.child(0).is::<Cast>());
        assert!(!coerced.child(1).is::<Cast>());
        assert!(coerced.child(0).child(0).is::<Cast>());
        assert!(!coerced.child(0).child(1).is::<Cast>());
        assert!(!coerced.child(1).child(0).is::<Cast>());
        assert!(coerced.child(1).child(1).is::<Cast>());
        Ok(())
    }

    /// Running the pass on an already-coerced expression adds no further casts.
    #[test]
    fn coercion_is_idempotent() -> VortexResult<()> {
        let scope = test_scope();
        let expr = Binary.new_expr(Operator::Lt, [col("x"), col("y")]);
        let once = coerce_expression(expr, &scope)?;
        let twice = coerce_expression(once, &scope)?;

        assert!(twice.child(0).is::<Cast>());
        assert!(!twice.child(0).child(0).is::<Cast>());
        assert!(!twice.child(1).is::<Cast>());
        Ok(())
    }
}