vortex_dict/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CompareKernel for DictVTable {
12    fn compare(
13        &self,
14        lhs: &DictArray,
15        rhs: &dyn Array,
16        operator: Operator,
17    ) -> VortexResult<Option<ArrayRef>> {
18        // if we have more values than codes, it is faster to canonicalise first.
19        if lhs.values().len() > lhs.codes().len() {
20            return Ok(None);
21        }
22
23        // If the RHS is constant, then we just need to compare against our encoded values.
24        if let Some(rhs) = rhs.as_constant() {
25            let compare_result = compare(
26                lhs.values(),
27                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
28                operator,
29            )?;
30
31            // SAFETY: values len preserved, codes all still point to valid values
32            let result = unsafe {
33                DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array()
34            };
35
36            // We canonicalize the result because dictionary-encoded bools is dumb.
37            return Ok(Some(result.to_canonical().into_array()));
38        }
39
40        // It's a little more complex, but we could perform a comparison against the dictionary
41        // values in the future.
42        Ok(None)
43    }
44}
45
// Hands the `CompareKernel` implementation for `DictVTable` to `register_kernel!`, wrapped in
// `CompareKernelAdapter` and lifted.
// NOTE(review): presumably this is what lets the generic `compare` entry point dispatch to the
// dictionary kernel above — confirm against the macro definition.
register_kernel!(CompareKernelAdapter(DictVTable).lift());
47
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Equality against a constant marks only the matching element as `true`.
    #[test]
    fn test_compare_value() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let constant = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// A non-equality operator (`Gt`) against a constant is evaluated per element.
    #[test]
    fn test_compare_non_eq() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let constant = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), constant.as_ref(), Operator::Gt).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Null codes over all-valid values: the result is nullable and its validity mask matches
    /// the validity of the codes.
    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let constant = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);
        let compared = compare(dict.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);

        let expected_validity = Mask::from_iter([false, true, false]);
        assert_eq!(bools.validity_mask(), expected_validity);
    }

    /// Non-null codes over values containing a null: the result is nullable and the null value
    /// shows up in the validity mask at the position of the code referencing it.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let constant = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);
        let compared = compare(dict.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);

        let expected_validity = Mask::from_iter([true, true, false]);
        assert_eq!(bools.validity_mask(), expected_validity);
    }
}