Skip to main content

vortex_fastlanes/for/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Shr;
5
6use num_traits::WrappingSub;
7use vortex_array::Array;
8use vortex_array::ArrayRef;
9use vortex_array::ExecutionCtx;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::compute::Operator;
12use vortex_array::compute::compare;
13use vortex_array::expr::CompareKernel;
14use vortex_array::scalar::PValue;
15use vortex_array::scalar::Scalar;
16use vortex_dtype::NativePType;
17use vortex_dtype::Nullability;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexError;
20use vortex_error::VortexExpect as _;
21use vortex_error::VortexResult;
22
23use crate::FoRArray;
24use crate::FoRVTable;
25
26impl CompareKernel for FoRVTable {
27    fn compare(
28        lhs: &FoRArray,
29        rhs: &dyn Array,
30        operator: Operator,
31        _ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        if let Some(constant) = rhs.as_constant()
34            && let Some(constant) = constant.as_primitive_opt()
35        {
36            match_each_integer_ptype!(constant.ptype(), |T| {
37                return compare_constant(
38                    lhs,
39                    constant
40                        .typed_value::<T>()
41                        .vortex_expect("null scalar handled in adaptor"),
42                    rhs.dtype().nullability(),
43                    operator,
44                );
45            })
46        }
47
48        Ok(None)
49    }
50}
51
52fn compare_constant<T>(
53    lhs: &FoRArray,
54    mut rhs: T,
55    nullability: Nullability,
56    operator: Operator,
57) -> VortexResult<Option<ArrayRef>>
58where
59    T: NativePType + WrappingSub + Shr<usize, Output = T>,
60    T: TryFrom<PValue, Error = VortexError>,
61    PValue: From<T>,
62{
63    // For now, we only support equals and not equals. Comparisons are a little more fiddly to
64    // get right regarding how to handle overflow and the wrapping subtraction.
65    if !matches!(operator, Operator::Eq | Operator::NotEq) {
66        return Ok(None);
67    }
68
69    let reference = lhs.reference_scalar();
70    let reference = reference.as_primitive().typed_value::<T>();
71
72    // We encode the RHS into the FoR domain.
73    if let Some(reference) = reference {
74        rhs = rhs.wrapping_sub(&reference);
75    }
76
77    // Wrap up the RHS into a scalar and cast to the encoded DType (this will be the equivalent
78    // unsigned integer type).
79    let rhs = Scalar::primitive(rhs, nullability);
80
81    compare(
82        lhs.encoded(),
83        ConstantArray::new(rhs, lhs.len()).as_ref(),
84        operator,
85    )
86    .map(Some)
87}
88
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;

    use super::*;

    #[test]
    fn test_compare_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 30, 12
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([false, true, false].map(Some)));

        let result = compare_constant(&lhs, 12i32, Nullability::NonNullable, Operator::NotEq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false].map(Some)));

        // Ordering comparisons are not pushed down.
        for op in [Operator::Lt, Operator::Lte, Operator::Gt, Operator::Gte] {
            assert!(
                compare_constant(&lhs, 30i32, Nullability::NonNullable, op)
                    .unwrap()
                    .is_none()
            );
        }
    }

    #[test]
    fn test_compare_nullable_constant() {
        let reference = Scalar::from(0);
        // Decoded values: 0, 20, 2
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 20, 2), Validity::NonNullable).into_array(),
            reference,
        )
        .unwrap();

        // The result's nullability follows the constant's nullability.
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::Nullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            compare_constant(&lhs, 30i32, Nullability::NonNullable, Operator::Eq)
                .unwrap()
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn compare_non_encodable_constant() {
        let reference = Scalar::from(10);
        // Decoded values: 10, 20, 11
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(buffer!(0i32, 10, 1), Validity::AllValid).into_array(),
            reference,
        )
        .unwrap();

        // -1 lies below the reference, so its FoR-domain value (-11) matches no element.
        let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::Eq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, false].map(Some))
        );

        let result = compare_constant(&lhs, -1i32, Nullability::NonNullable, Operator::NotEq)
            .unwrap()
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true].map(Some)));
    }

    #[test]
    fn compare_large_constant() {
        let reference = Scalar::from(-9219218377546224477i64);
        // The second encoded value is the bit pattern of 9654309310445864926u64, which exceeds
        // i64::MAX. Decoding wraps: reference + encoded == 435090932899640449.
        let lhs = FoRArray::try_new(
            PrimitiveArray::new(
                buffer![0i64, 9654309310445864926u64.cast_signed()],
                Validity::AllValid,
            )
            .into_array(),
            reference,
        )
        .unwrap();

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            Operator::Eq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(false), Some(true)]));

        let result = compare_constant(
            &lhs,
            435090932899640449i64,
            Nullability::Nullable,
            Operator::NotEq,
        )
        .unwrap()
        .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([Some(true), Some(false)]));
    }
}