vortex_fastlanes/for/compute/compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::CompareKernel;
11use vortex_array::compute::CompareKernelAdapter;
12use vortex_array::compute::Operator;
13use vortex_array::compute::compare;
14use vortex_array::register_kernel;
15use vortex_dtype::NativePType;
16use vortex_dtype::Nullability;
17use vortex_dtype::match_each_integer_ptype;
18use vortex_error::VortexError;
19use vortex_error::VortexExpect as _;
20use vortex_error::VortexResult;
21use vortex_scalar::PValue;
22use vortex_scalar::PrimitiveScalar;
23use vortex_scalar::Scalar;
24
25use crate::FoRArray;
26use crate::FoRVTable;
27
28impl CompareKernel for FoRVTable {
29    fn compare(
30        &self,
31        lhs: &FoRArray,
32        rhs: &dyn Array,
33        operator: Operator,
34    ) -> VortexResult<Option<ArrayRef>> {
35        if let Some(constant) = rhs.as_constant()
36            && let Ok(constant) = PrimitiveScalar::try_from(&constant)
37        {
38            match_each_integer_ptype!(constant.ptype(), |T| {
39                return compare_constant(
40                    lhs,
41                    constant
42                        .typed_value::<T>()
43                        .vortex_expect("null scalar handled in top-level"),
44                    rhs.dtype().nullability(),
45                    operator,
46                );
47            })
48        }
49
50        Ok(None)
51    }
52}
53
54register_kernel!(CompareKernelAdapter(FoRVTable).lift());
55
56fn compare_constant<T>(
57    lhs: &FoRArray,
58    mut rhs: T,
59    nullability: Nullability,
60    operator: Operator,
61) -> VortexResult<Option<ArrayRef>>
62where
63    T: NativePType + WrappingSub + Shr<usize, Output = T>,
64    T: TryFrom<PValue, Error = VortexError>,
65    PValue: From<T>,
66{
67    // For now, we only support equals and not equals. Comparisons are a little more fiddly to
68    // get right regarding how to handle overflow and the wrapping subtraction.
69    if !matches!(operator, Operator::Eq | Operator::NotEq) {
70        return Ok(None);
71    }
72
73    let reference = lhs.reference_scalar();
74    let reference = reference.as_primitive().typed_value::<T>();
75
76    // We encode the RHS into the FoR domain.
77    if let Some(reference) = reference {
78        rhs = rhs.wrapping_sub(&reference);
79    }
80
81    // Wrap up the RHS into a scalar and cast to the encoded DType (this will be the equivalent
82    // unsigned integer type).
83    let rhs = Scalar::primitive(rhs, nullability);
84
85    compare(
86        lhs.encoded(),
87        ConstantArray::new(rhs, lhs.len()).as_ref(),
88        operator,
89    )
90    .map(Some)
91}
92
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;

    use super::*;

    #[test]
    fn test_compare_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 30, 12
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq),
            [false, true, false],
        );
        assert_result(
            compare_constant(&lhs, 12i32, Nullability::NonNullable, Operator::NotEq),
            [true, true, false],
        );
        // Ordering comparisons are not pushed down.
        for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] {
            assert!(
                compare_constant(&lhs, 30i32, Nullability::NonNullable, op)
                    .unwrap()
                    .is_none()
            );
        }
    }

    #[test]
    fn test_compare_nullable_constant() {
        let reference = Scalar::from(0);
        // Decoded values: 0, 20, 2 (the reference is zero)
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            reference,
        )
        .unwrap();

        // The result nullability follows the constant's nullability.
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::Nullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn compare_non_encodable_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 20, 11. The constant -1 encodes to -11, which matches none of them.
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::Eq),
            [false, false, false],
        );
        assert_result(
            compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::NotEq),
            [true, true, true],
        );
    }

    #[test]
    fn compare_large_constant() {
        let reference = Scalar::from(-9219218377546224477i64);
        // The second encoded value exceeds i64::MAX. The `as` cast deliberately wraps it, matching
        // the wrapping subtraction used by the encoding; it decodes to 435090932899640449.
        #[allow(clippy::cast_possible_truncation)]
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64 as i64],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        assert_result(
            compare_constant(
                &lhs,
                435090932899640449i64,
                Nullability::Nullable,
                Operator::Eq,
            ),
            [false, true],
        );
        assert_result(
            compare_constant(
                &lhs,
                435090932899640449i64,
                Nullability::Nullable,
                Operator::NotEq,
            ),
            [true, false],
        );
    }

    #[test]
    fn compare_through_top_level_compare() {
        // Exercises the `CompareKernel` entry point through the public `compare` function. The
        // result must be correct whichever implementation `compare` ends up dispatching to.
        // Decoded values: 10, 30, 12
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            Scalar::from(10),
        )
        .unwrap()
        .into_array();
        let rhs = ConstantArray::new(30i32, lhs.len()).into_array();

        assert_result(
            compare(&lhs, &rhs, Operator::Eq).map(Some),
            [false, true, false],
        );
        assert_result(
            compare(&lhs, &rhs, Operator::NotEq).map(Some),
            [true, false, true],
        );
    }

    /// Unwraps a pushed-down comparison result and checks its boolean values.
    fn assert_result<T: IntoIterator<Item = bool>>(
        result: VortexResult<Option<ArrayRef>>,
        expected: T,
    ) {
        let result = result.unwrap().unwrap().to_bool();
        assert_eq!(result.bit_buffer(), &BitBuffer::from_iter(expected));
    }
}