vortex_fastlanes/for/compute/
compare.rs1use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::ExecutionCtx;
10use vortex_array::IntoArray;
11use vortex_array::arrays::ConstantArray;
12use vortex_array::builtins::ArrayBuiltins;
13use vortex_array::dtype::NativePType;
14use vortex_array::dtype::Nullability;
15use vortex_array::match_each_integer_ptype;
16use vortex_array::scalar::PValue;
17use vortex_array::scalar::Scalar;
18use vortex_array::scalar_fn::fns::binary::CompareKernel;
19use vortex_array::scalar_fn::fns::operators::CompareOperator;
20use vortex_array::scalar_fn::fns::operators::Operator;
21use vortex_error::VortexError;
22use vortex_error::VortexExpect as _;
23use vortex_error::VortexResult;
24
25use crate::FoRArray;
26use crate::FoRVTable;
27
28impl CompareKernel for FoRVTable {
29 fn compare(
30 lhs: &FoRArray,
31 rhs: &dyn Array,
32 operator: CompareOperator,
33 _ctx: &mut ExecutionCtx,
34 ) -> VortexResult<Option<ArrayRef>> {
35 if let Some(constant) = rhs.as_constant()
36 && let Some(constant) = constant.as_primitive_opt()
37 {
38 match_each_integer_ptype!(constant.ptype(), |T| {
39 return compare_constant(
40 lhs,
41 constant
42 .typed_value::<T>()
43 .vortex_expect("null scalar handled in adaptor"),
44 rhs.dtype().nullability(),
45 operator,
46 );
47 })
48 }
49
50 Ok(None)
51 }
52}
53
54fn compare_constant<T>(
55 lhs: &FoRArray,
56 mut rhs: T,
57 nullability: Nullability,
58 operator: CompareOperator,
59) -> VortexResult<Option<ArrayRef>>
60where
61 T: NativePType + WrappingSub + Shr<usize, Output = T>,
62 T: TryFrom<PValue, Error = VortexError>,
63 PValue: From<T>,
64{
65 if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
68 return Ok(None);
69 }
70
71 let reference = lhs.reference_scalar();
72 let reference = reference.as_primitive().typed_value::<T>();
73
74 if let Some(reference) = reference {
76 rhs = rhs.wrapping_sub(&reference);
77 }
78
79 let rhs = Scalar::primitive(rhs, nullability);
82
83 lhs.encoded()
84 .binary(
85 ConstantArray::new(rhs, lhs.len()).into_array(),
86 Operator::from(operator),
87 )
88 .map(Some)
89}
90
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::DType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use super::*;

    /// Eq/NotEq are pushed down; ordering operators are declined with `None`.
    #[test]
    fn test_compare_constant() {
        let reference = Scalar::from(10);
        // Encoded [0, 20, 2] with reference 10 decodes to [10, 30, 12].
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([false, true, false].map(Some)));

        let result = compare_constant(
            &lhs,
            12i32,
            Nullability::NonNullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false].map(Some)));

        for op in [
            CompareOperator::Lt,
            CompareOperator::Lte,
            CompareOperator::Gt,
            CompareOperator::Gte,
        ] {
            assert!(
                compare_constant(&lhs, 30i32, Nullability::NonNullable, op)
                    .unwrap()
                    .is_none()
            );
        }
    }

    /// The result's nullability follows the nullability passed for the constant.
    #[test]
    fn test_compare_nullable_constant() {
        let reference = Scalar::from(0);
        // Reference 0: decoded values equal the encoded values [0, 20, 2].
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            reference,
        )
        .unwrap();

        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::Nullable, CompareOperator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::NonNullable)
        );
    }

    /// A constant below the reference matches nothing (Eq) and everything (NotEq).
    #[test]
    fn compare_non_encodable_constant() {
        let reference = Scalar::from(10);
        // Encoded [0, 10, 1] with reference 10 decodes to [10, 20, 11].
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, false].map(Some))
        );

        let result = compare_constant(
            &lhs,
            -1i32,
            Nullability::NonNullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true].map(Some)));
    }

    /// Shifting the constant must wrap exactly like the encoding does.
    #[test]
    fn compare_large_constant() {
        let reference = Scalar::from(-9219218377546224477i64);
        // The second encoded value exceeds i64::MAX, so its bits are reinterpreted as i64.
        // `cast_signed` makes the reinterpretation explicit. (A same-width u64 -> i64 `as`
        // cast wraps rather than truncates.) With wrapping addition it decodes to
        // -9219218377546224477 + 9654309310445864926 = 435090932899640449.
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64.cast_signed()],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            CompareOperator::Eq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(false), Some(true)]));

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(true), Some(false)]));
    }
}