Skip to main content

vortex_fastlanes/for/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::ExecutionCtx;
10use vortex_array::IntoArray;
11use vortex_array::arrays::ConstantArray;
12use vortex_array::builtins::ArrayBuiltins;
13use vortex_array::dtype::NativePType;
14use vortex_array::dtype::Nullability;
15use vortex_array::match_each_integer_ptype;
16use vortex_array::scalar::PValue;
17use vortex_array::scalar::Scalar;
18use vortex_array::scalar_fn::fns::binary::CompareKernel;
19use vortex_array::scalar_fn::fns::operators::CompareOperator;
20use vortex_array::scalar_fn::fns::operators::Operator;
21use vortex_error::VortexError;
22use vortex_error::VortexExpect as _;
23use vortex_error::VortexResult;
24
25use crate::FoRArray;
26use crate::FoRVTable;
27
28impl CompareKernel for FoRVTable {
29    fn compare(
30        lhs: &FoRArray,
31        rhs: &dyn Array,
32        operator: CompareOperator,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        if let Some(constant) = rhs.as_constant()
36            && let Some(constant) = constant.as_primitive_opt()
37        {
38            match_each_integer_ptype!(constant.ptype(), |T| {
39                return compare_constant(
40                    lhs,
41                    constant
42                        .typed_value::<T>()
43                        .vortex_expect("null scalar handled in adaptor"),
44                    rhs.dtype().nullability(),
45                    operator,
46                );
47            })
48        }
49
50        Ok(None)
51    }
52}
53
54fn compare_constant<T>(
55    lhs: &FoRArray,
56    mut rhs: T,
57    nullability: Nullability,
58    operator: CompareOperator,
59) -> VortexResult<Option<ArrayRef>>
60where
61    T: NativePType + WrappingSub + Shr<usize, Output = T>,
62    T: TryFrom<PValue, Error = VortexError>,
63    PValue: From<T>,
64{
65    // For now, we only support equals and not equals. Comparisons are a little more fiddly to
66    // get right regarding how to handle overflow and the wrapping subtraction.
67    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
68        return Ok(None);
69    }
70
71    let reference = lhs.reference_scalar();
72    let reference = reference.as_primitive().typed_value::<T>();
73
74    // We encode the RHS into the FoR domain.
75    if let Some(reference) = reference {
76        rhs = rhs.wrapping_sub(&reference);
77    }
78
79    // Wrap up the RHS into a scalar and cast to the encoded DType (this will be the equivalent
80    // unsigned integer type).
81    let rhs = Scalar::primitive(rhs, nullability);
82
83    lhs.encoded()
84        .binary(
85            ConstantArray::new(rhs, lhs.len()).into_array(),
86            Operator::from(operator),
87        )
88        .map(Some)
89}
90
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::DType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use super::*;

    #[test]
    fn test_compare_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 30, 12
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([false, true, false].map(Some)));

        let result = compare_constant(
            &lhs,
            12i32,
            Nullability::NonNullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false].map(Some)));

        // Ordering comparisons are not pushed down.
        for op in [
            CompareOperator::Lt,
            CompareOperator::Lte,
            CompareOperator::Gt,
            CompareOperator::Gte,
        ] {
            assert!(
                compare_constant(&lhs, 30i32, Nullability::NonNullable, op)
                    .unwrap()
                    .is_none()
            );
        }
    }

    #[test]
    fn test_compare_nullable_constant() {
        let reference = Scalar::from(0);
        // Decoded values: 0, 20, 2 (the reference is zero)
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            reference,
        )
        .unwrap();

        // With a non-nullable `lhs`, the result's nullability follows the constant's.
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::Nullable, CompareOperator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, CompareOperator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::NonNullable)
        );

        // A nullable constant still produces correct values.
        let result = compare_constant(&lhs, 20i32, Nullability::Nullable, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([false, true, false].map(Some)));
    }

    #[test]
    fn compare_non_encodable_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 20, 11
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        // -1 lies below the reference; its encoded form (-11) matches no stored value.
        let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, CompareOperator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, false].map(Some))
        );

        let result = compare_constant(
            &lhs,
            -1i32,
            Nullability::NonNullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true].map(Some)));
    }

    #[test]
    fn compare_large_constant() {
        let reference = Scalar::from(-9219218377546224477i64);
        // The second encoded value is a u64 bit pattern deliberately reinterpreted as i64.
        // Adding the reference with wrapping arithmetic decodes it to 435090932899640449.
        #[allow(clippy::cast_possible_truncation)]
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64 as i64],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            CompareOperator::Eq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(false), Some(true)]));

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            CompareOperator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(true), Some(false)]));
    }
}