vortex_fastlanes/for/compute/
compare.rs1use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::ExecutionCtx;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::compute::Operator;
12use vortex_array::compute::compare;
13use vortex_array::expr::CompareKernel;
14use vortex_array::scalar::PValue;
15use vortex_array::scalar::Scalar;
16use vortex_dtype::NativePType;
17use vortex_dtype::Nullability;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexError;
20use vortex_error::VortexExpect as _;
21use vortex_error::VortexResult;
22
23use crate::FoRArray;
24use crate::FoRVTable;
25
26impl CompareKernel for FoRVTable {
27 fn compare(
28 lhs: &FoRArray,
29 rhs: &dyn Array,
30 operator: Operator,
31 _ctx: &mut ExecutionCtx,
32 ) -> VortexResult<Option<ArrayRef>> {
33 if let Some(constant) = rhs.as_constant()
34 && let Some(constant) = constant.as_primitive_opt()
35 {
36 match_each_integer_ptype!(constant.ptype(), |T| {
37 return compare_constant(
38 lhs,
39 constant
40 .typed_value::<T>()
41 .vortex_expect("null scalar handled in adaptor"),
42 rhs.dtype().nullability(),
43 operator,
44 );
45 })
46 }
47
48 Ok(None)
49 }
50}
51
52fn compare_constant<T>(
53 lhs: &FoRArray,
54 mut rhs: T,
55 nullability: Nullability,
56 operator: Operator,
57) -> VortexResult<Option<ArrayRef>>
58where
59 T: NativePType + WrappingSub + Shr<usize, Output = T>,
60 T: TryFrom<PValue, Error = VortexError>,
61 PValue: From<T>,
62{
63 if !matches!(operator, Operator::Eq | Operator::NotEq) {
66 return Ok(None);
67 }
68
69 let reference = lhs.reference_scalar();
70 let reference = reference.as_primitive().typed_value::<T>();
71
72 if let Some(reference) = reference {
74 rhs = rhs.wrapping_sub(&reference);
75 }
76
77 let rhs = Scalar::primitive(rhs, nullability);
80
81 compare(
82 lhs.encoded(),
83 ConstantArray::new(rhs, lhs.len()).as_ref(),
84 operator,
85 )
86 .map(Some)
87}
88
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;

    use super::*;

    /// Eq / NotEq against a constant are answered on the encoded values; ordering operators
    /// are declined.
    #[test]
    fn test_compare_constant() {
        let encoded = PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array();
        let lhs = FoRArray::try_new(encoded, Scalar::from(10)).unwrap();

        // 30 - 10 == 20, which is the second encoded value.
        let eq = compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            eq,
            BoolArray::from_iter([Some(false), Some(true), Some(false)])
        );

        // 12 - 10 == 2, which is the third encoded value.
        let not_eq = compare_constant(&lhs, 12i32, Nullability::NonNullable, Operator::NotEq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            not_eq,
            BoolArray::from_iter([Some(true), Some(true), Some(false)])
        );

        // Ordering comparisons are not handled and must yield `None`.
        for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] {
            let outcome = compare_constant(&lhs, 30i32, Nullability::NonNullable, op).unwrap();
            assert!(outcome.is_none());
        }
    }

    /// The nullability passed for the constant determines the nullability of the boolean
    /// result.
    #[test]
    fn test_compare_nullable_constant() {
        let encoded =
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array();
        let lhs = FoRArray::try_new(encoded, Scalar::from(0)).unwrap();

        for nullability in [Nullability::Nullable, Nullability::NonNullable] {
            let result = compare_constant(&lhs, 30i32, nullability, Operator::Eq)
                .unwrap()
                .unwrap();
            assert_eq!(result.dtype(), &DType::Bool(nullability));
        }
    }

    /// A constant below the reference (-1 - 10 == -11) matches none of the encoded values.
    #[test]
    fn compare_non_encodable_constant() {
        let encoded = PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array();
        let lhs = FoRArray::try_new(encoded, Scalar::from(10)).unwrap();

        let eq = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            eq,
            BoolArray::from_iter([Some(false), Some(false), Some(false)])
        );

        let not_eq = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::NotEq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            not_eq,
            BoolArray::from_iter([Some(true), Some(true), Some(true)])
        );
    }

    /// Subtracting the (very negative) reference from the constant overflows `i64`; the
    /// wrapped result must still equal the second encoded value.
    #[test]
    fn compare_large_constant() {
        let constant = 435090932899640449i64;
        let reference = Scalar::from(-9219218377546224477i64);

        // The `u64` literal is deliberately reinterpreted as its `i64` bit pattern.
        #[allow(clippy::cast_possible_truncation)]
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64 as i64],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        let eq = compare_constant(&lhs, constant, Nullability::Nullable, Operator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(eq, BoolArray::from_iter([Some(false), Some(true)]));

        let not_eq = compare_constant(&lhs, constant, Nullability::Nullable, Operator::NotEq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(not_eq, BoolArray::from_iter([Some(true), Some(false)]));
    }
}