vortex_fastlanes/for/compute/
compare.rs1use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::CompareKernel;
11use vortex_array::compute::CompareKernelAdapter;
12use vortex_array::compute::Operator;
13use vortex_array::compute::compare;
14use vortex_array::register_kernel;
15use vortex_dtype::NativePType;
16use vortex_dtype::Nullability;
17use vortex_dtype::match_each_integer_ptype;
18use vortex_error::VortexError;
19use vortex_error::VortexExpect as _;
20use vortex_error::VortexResult;
21use vortex_scalar::PValue;
22use vortex_scalar::PrimitiveScalar;
23use vortex_scalar::Scalar;
24
25use crate::FoRArray;
26use crate::FoRVTable;
27
28impl CompareKernel for FoRVTable {
29 fn compare(
30 &self,
31 lhs: &FoRArray,
32 rhs: &dyn Array,
33 operator: Operator,
34 ) -> VortexResult<Option<ArrayRef>> {
35 if let Some(constant) = rhs.as_constant()
36 && let Ok(constant) = PrimitiveScalar::try_from(&constant)
37 {
38 match_each_integer_ptype!(constant.ptype(), |T| {
39 return compare_constant(
40 lhs,
41 constant
42 .typed_value::<T>()
43 .vortex_expect("null scalar handled in top-level"),
44 rhs.dtype().nullability(),
45 operator,
46 );
47 })
48 }
49
50 Ok(None)
51 }
52}
53
54register_kernel!(CompareKernelAdapter(FoRVTable).lift());
55
56fn compare_constant<T>(
57 lhs: &FoRArray,
58 mut rhs: T,
59 nullability: Nullability,
60 operator: Operator,
61) -> VortexResult<Option<ArrayRef>>
62where
63 T: NativePType + WrappingSub + Shr<usize, Output = T>,
64 T: TryFrom<PValue, Error = VortexError>,
65 PValue: From<T>,
66{
67 if !matches!(operator, Operator::Eq | Operator::NotEq) {
70 return Ok(None);
71 }
72
73 let reference = lhs.reference_scalar();
74 let reference = reference.as_primitive().typed_value::<T>();
75
76 if let Some(reference) = reference {
78 rhs = rhs.wrapping_sub(&reference);
79 }
80
81 let rhs = Scalar::primitive(rhs, nullability);
84
85 compare(
86 lhs.encoded(),
87 ConstantArray::new(rhs, lhs.len()).as_ref(),
88 operator,
89 )
90 .map(Some)
91}
92
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;

    use super::*;

    /// Eq/NotEq are pushed down; ordering operators decline with `None`.
    #[test]
    fn test_compare_constant() {
        // Reference 10 over encoded [0, 20, 2] decodes to [10, 30, 12].
        let reference = Scalar::from(10);
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq),
            [false, true, false],
        );
        assert_result(
            compare_constant(&lhs, 12i32, Nullability::NonNullable, Operator::NotEq),
            [true, true, false],
        );
        for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] {
            assert!(
                compare_constant(&lhs, 30i32, Nullability::NonNullable, op)
                    .unwrap()
                    .is_none()
            );
        }
    }

    /// The result's nullability follows the nullability passed for the constant.
    #[test]
    fn test_compare_nullable_constant() {
        let reference = Scalar::from(0);
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            reference,
        )
        .unwrap();

        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::Nullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::NonNullable)
        );
    }

    /// A constant below the reference maps to a wrapped encoded value and matches nothing.
    #[test]
    fn compare_non_encodable_constant() {
        let reference = Scalar::from(10);
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::Eq),
            [false, false, false],
        );
        assert_result(
            compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::NotEq),
            [true, true, true],
        );
    }

    /// Values whose decoding overflows i64 still compare correctly via wrapping arithmetic.
    #[test]
    fn compare_large_constant() {
        let reference = Scalar::from(-9219218377546224477i64);
        // `cast_signed` reinterprets the u64 bit pattern as a two's-complement i64. It states
        // the intent explicitly instead of an `as` cast guarded by an unrelated lint allow.
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64.cast_signed()],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(
                &lhs,
                435090932899640449i64,
                Nullability::Nullable,
                Operator::Eq,
            ),
            [false, true],
        );
        assert_result(
            compare_constant(
                &lhs,
                435090932899640449i64,
                Nullability::Nullable,
                Operator::NotEq,
            ),
            [true, false],
        );
    }

    /// Unsigned ptypes: a constant below the reference wraps around in the encoded domain.
    #[test]
    fn compare_unsigned_constant_wraps() {
        // Reference 250 over encoded [0, 5, 2] decodes to [250, 255, 252].
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer![0u8, 5, 2], Validity::NonNullable).into_array(),
            Scalar::from(250u8),
        )
        .unwrap();

        assert_result(
            compare_constant(&lhs, 255u8, Nullability::NonNullable, Operator::Eq),
            [false, true, false],
        );
        // 3 - 250 wraps to 9 in u8, which is not among the encoded values.
        assert_result(
            compare_constant(&lhs, 3u8, Nullability::NonNullable, Operator::Eq),
            [false, false, false],
        );
        assert_result(
            compare_constant(&lhs, 3u8, Nullability::NonNullable, Operator::NotEq),
            [true, true, true],
        );
    }

    /// Unwraps a pushed-down comparison and checks its boolean bits against `expected`.
    fn assert_result<T: IntoIterator<Item = bool>>(
        result: VortexResult<Option<ArrayRef>>,
        expected: T,
    ) {
        let result = result.unwrap().unwrap().to_bool();
        assert_eq!(result.bit_buffer(), &BitBuffer::from_iter(expected));
    }
}