vortex_array/compute/conformance/
binary_numeric.rs1use itertools::Itertools;
25use num_traits::Num;
26use vortex_error::VortexExpect;
27use vortex_error::vortex_err;
28use vortex_error::vortex_panic;
29
30use crate::Array;
31use crate::ArrayRef;
32use crate::IntoArray;
33use crate::LEGACY_SESSION;
34use crate::RecursiveCanonical;
35use crate::ToCanonical;
36use crate::VortexSessionExecute;
37use crate::arrays::ConstantArray;
38use crate::builtins::ArrayBuiltins;
39use crate::dtype::DType;
40use crate::dtype::NativePType;
41use crate::dtype::PType;
42use crate::scalar::NumericOperator;
43use crate::scalar::PrimitiveScalar;
44use crate::scalar::Scalar;
45
46fn to_vec_of_scalar(array: &ArrayRef) -> Vec<Scalar> {
47 (0..array.len())
49 .map(|index| {
50 array
51 .scalar_at(index)
52 .vortex_expect("scalar_at should succeed in conformance test")
53 })
54 .collect_vec()
55}
56
57fn test_binary_numeric_conformance<T: NativePType + Num + Copy>(array: ArrayRef)
82where
83 Scalar: From<T>,
84{
85 test_standard_binary_numeric::<T>(array.clone());
87
88 test_binary_numeric_edge_cases(array);
90}
91
92fn test_standard_binary_numeric<T: NativePType + Num + Copy>(array: ArrayRef)
93where
94 Scalar: From<T>,
95{
96 let canonicalized_array = array.to_primitive();
97 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
98
99 let one = T::from(1)
100 .ok_or_else(|| vortex_err!("could not convert 1 into array native type"))
101 .vortex_expect("operation should succeed in conformance test");
102 let scalar_one = Scalar::from(one)
103 .cast(array.dtype())
104 .vortex_expect("operation should succeed in conformance test");
105
106 let operators: [NumericOperator; 4] = [
107 NumericOperator::Add,
108 NumericOperator::Sub,
109 NumericOperator::Mul,
110 NumericOperator::Div,
111 ];
112
113 for operator in operators {
114 let op = operator;
115 let rhs_const = ConstantArray::new(scalar_one.clone(), array.len()).into_array();
116
117 let result = array
119 .binary(rhs_const.clone(), op.into())
120 .vortex_expect("apply shouldn't fail")
121 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
122 .map(|c| c.0.into_array());
123
124 let Ok(result) = result else {
127 continue;
128 };
129
130 println!("result {}", result.display_tree());
131 println!("result {}", result.display_values());
132
133 let actual_values = to_vec_of_scalar(&result);
134
135 let expected_results: Vec<Option<Scalar>> = original_values
137 .iter()
138 .map(|x| {
139 x.as_primitive()
140 .checked_binary_numeric(&scalar_one.as_primitive(), op)
141 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
142 })
143 .collect();
144
145 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
147 if let Some(expected_value) = expected {
148 assert_eq!(
149 actual,
150 expected_value,
151 "Binary numeric operation failed for encoding {} at index {}: \
152 ({array:?})[{idx}] {operator:?} {scalar_one} \
153 expected {expected_value:?}, got {actual:?}",
154 array.encoding_id(),
155 idx,
156 );
157 }
158 }
159
160 let result = rhs_const.binary(array.clone(), op.into()).and_then(|a| {
162 a.execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
163 .map(|c| c.0.into_array())
164 });
165
166 let Ok(result) = result else {
168 continue;
169 };
170
171 let actual_values = to_vec_of_scalar(&result);
172
173 let expected_results: Vec<Option<Scalar>> = original_values
175 .iter()
176 .map(|x| {
177 scalar_one
178 .as_primitive()
179 .checked_binary_numeric(&x.as_primitive(), op)
180 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
181 })
182 .collect();
183
184 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
186 if let Some(expected_value) = expected {
187 assert_eq!(
188 actual,
189 expected_value,
190 "Binary numeric operation failed for encoding {} at index {}: \
191 {scalar_one} {operator:?} ({array:?})[{idx}] \
192 expected {expected_value:?}, got {actual:?}",
193 array.encoding_id(),
194 idx,
195 );
196 }
197 }
198 }
199}
200
201pub fn test_binary_numeric_array(array: ArrayRef) {
218 match array.dtype() {
219 DType::Primitive(ptype, _) => match ptype {
220 PType::I8 => test_binary_numeric_conformance::<i8>(array),
221 PType::I16 => test_binary_numeric_conformance::<i16>(array),
222 PType::I32 => test_binary_numeric_conformance::<i32>(array),
223 PType::I64 => test_binary_numeric_conformance::<i64>(array),
224 PType::U8 => test_binary_numeric_conformance::<u8>(array),
225 PType::U16 => test_binary_numeric_conformance::<u16>(array),
226 PType::U32 => test_binary_numeric_conformance::<u32>(array),
227 PType::U64 => test_binary_numeric_conformance::<u64>(array),
228 PType::F16 => {
229 eprintln!("Skipping f16 binary numeric tests (not supported)");
231 }
232 PType::F32 => test_binary_numeric_conformance::<f32>(array),
233 PType::F64 => test_binary_numeric_conformance::<f64>(array),
234 },
235 dtype => vortex_panic!(
236 "Binary numeric tests are only supported for primitive numeric types, got {dtype}",
237 ),
238 }
239}
240
241fn test_binary_numeric_edge_cases(array: ArrayRef) {
249 match array.dtype() {
250 DType::Primitive(ptype, _) => match ptype {
251 PType::I8 => test_binary_numeric_edge_cases_signed::<i8>(array),
252 PType::I16 => test_binary_numeric_edge_cases_signed::<i16>(array),
253 PType::I32 => test_binary_numeric_edge_cases_signed::<i32>(array),
254 PType::I64 => test_binary_numeric_edge_cases_signed::<i64>(array),
255 PType::U8 => test_binary_numeric_edge_cases_unsigned::<u8>(array),
256 PType::U16 => test_binary_numeric_edge_cases_unsigned::<u16>(array),
257 PType::U32 => test_binary_numeric_edge_cases_unsigned::<u32>(array),
258 PType::U64 => test_binary_numeric_edge_cases_unsigned::<u64>(array),
259 PType::F16 => {
260 eprintln!("Skipping f16 edge case tests (not supported)");
261 }
262 PType::F32 => test_binary_numeric_edge_cases_float::<f32>(array),
263 PType::F64 => test_binary_numeric_edge_cases_float::<f64>(array),
264 },
265 dtype => vortex_panic!(
266 "Binary numeric edge case tests are only supported for primitive numeric types, got {dtype}"
267 ),
268 }
269}
270
271fn test_binary_numeric_edge_cases_signed<T>(array: ArrayRef)
272where
273 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded + num_traits::Signed,
274 Scalar: From<T>,
275{
276 test_binary_numeric_with_scalar(array.clone(), T::zero());
278
279 test_binary_numeric_with_scalar(array.clone(), -T::one());
281
282 test_binary_numeric_with_scalar(array.clone(), T::max_value());
284
285 test_binary_numeric_with_scalar(array, T::min_value());
287}
288
289fn test_binary_numeric_edge_cases_unsigned<T>(array: ArrayRef)
290where
291 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Bounded,
292 Scalar: From<T>,
293{
294 test_binary_numeric_with_scalar(array.clone(), T::zero());
296
297 test_binary_numeric_with_scalar(array, T::max_value());
299}
300
301fn test_binary_numeric_edge_cases_float<T>(array: ArrayRef)
302where
303 T: NativePType + Num + Copy + std::fmt::Debug + num_traits::Float,
304 Scalar: From<T>,
305{
306 test_binary_numeric_with_scalar(array.clone(), T::zero());
308
309 test_binary_numeric_with_scalar(array.clone(), -T::one());
311
312 test_binary_numeric_with_scalar(array.clone(), T::max_value());
314
315 test_binary_numeric_with_scalar(array.clone(), T::min_value());
317
318 test_binary_numeric_with_scalar(array.clone(), T::epsilon());
320
321 test_binary_numeric_with_scalar(array.clone(), T::min_positive_value());
323
324 test_binary_numeric_with_scalar(array.clone(), T::nan());
326 test_binary_numeric_with_scalar(array.clone(), T::infinity());
327 test_binary_numeric_with_scalar(array, T::neg_infinity());
328}
329
330fn test_binary_numeric_with_scalar<T>(array: ArrayRef, scalar_value: T)
331where
332 T: NativePType + Num + Copy + std::fmt::Debug,
333 Scalar: From<T>,
334{
335 let canonicalized_array = array.to_primitive();
336 let original_values = to_vec_of_scalar(&canonicalized_array.into_array());
337
338 let scalar = Scalar::from(scalar_value)
339 .cast(array.dtype())
340 .vortex_expect("operation should succeed in conformance test");
341
342 let operators = if scalar_value == T::zero() {
344 vec![
346 NumericOperator::Add,
347 NumericOperator::Sub,
348 NumericOperator::Mul,
349 ]
350 } else {
351 vec![
352 NumericOperator::Add,
353 NumericOperator::Sub,
354 NumericOperator::Mul,
355 NumericOperator::Div,
356 ]
357 };
358
359 for operator in operators {
360 let op = operator;
361 let rhs_const = ConstantArray::new(scalar.clone(), array.len()).into_array();
362
363 let result = array
365 .binary(rhs_const, op.into())
366 .vortex_expect("apply failed")
367 .execute::<RecursiveCanonical>(&mut LEGACY_SESSION.create_execution_ctx())
368 .map(|x| x.0.into_array());
369
370 if result.is_err() {
373 continue;
374 }
375
376 let result = result.vortex_expect("operation should succeed in conformance test");
377 let actual_values = to_vec_of_scalar(&result);
378
379 let expected_results: Vec<Option<Scalar>> = original_values
381 .iter()
382 .map(|x| {
383 x.as_primitive()
384 .checked_binary_numeric(&scalar.as_primitive(), op)
385 .map(<Scalar as From<PrimitiveScalar<'_>>>::from)
386 })
387 .collect();
388
389 for (idx, (actual, expected)) in actual_values.iter().zip(&expected_results).enumerate() {
391 if let Some(expected_value) = expected {
392 assert_eq!(
393 actual,
394 expected_value,
395 "Binary numeric operation failed for encoding {} at index {} with scalar {:?}: \
396 ({array:?})[{idx}] {operator:?} {scalar} \
397 expected {expected_value:?}, got {actual:?}",
398 array.encoding_id(),
399 idx,
400 scalar_value,
401 );
402 }
403 }
404 }
405}