vortex_array/arrays/dict/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::{DictArray, DictVTable};
7use crate::arrays::ConstantArray;
8use crate::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl CompareKernel for DictVTable {
12    fn compare(
13        &self,
14        lhs: &DictArray,
15        rhs: &dyn Array,
16        operator: Operator,
17    ) -> VortexResult<Option<ArrayRef>> {
18        // if we have more values than codes, it is faster to canonicalise first.
19        if lhs.values().len() > lhs.codes().len() {
20            return Ok(None);
21        }
22
23        // If the RHS is constant, then we just need to compare against our encoded values.
24        if let Some(rhs) = rhs.as_constant() {
25            let compare_result = compare(
26                lhs.values(),
27                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
28                operator,
29            )?;
30
31            // SAFETY: values len preserved, codes all still point to valid values
32            let result = unsafe {
33                DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array()
34            };
35
36            // We canonicalize the result because dictionary-encoded bools is dumb.
37            return Ok(Some(result.to_canonical().into_array()));
38        }
39
40        // It's a little more complex, but we could perform a comparison against the dictionary
41        // values in the future.
42        Ok(None)
43    }
44}
45
// Register the `DictVTable` compare kernel, wrapped in `CompareKernelAdapter`, through
// `register_kernel!`.
// NOTE(review): presumably this is what lets the generic `compare` entry point dispatch
// to the impl above — confirm against the macro definition.
register_kernel!(CompareKernelAdapter(DictVTable).lift());
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::dict::DictArray;
    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::{Operator, compare};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// `Eq` against a constant is true only at the position whose decoded value matches.
    #[test]
    fn test_compare_value() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// A non-equality operator (`Gt`) is evaluated against the decoded values as well.
    #[test]
    fn test_compare_non_eq() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Null codes show up as null positions in the comparison result.
    #[test]
    fn test_compare_nullable() {
        let code_validity = Validity::from_iter([false, true, false]);
        let codes = PrimitiveArray::new(buffer![0u32, 1, 2], code_validity).into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([false, true, false]));
    }

    /// Null dictionary values show up as null positions in the comparison result.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let value_validity = Validity::from_iter([true, true, false]);
        let values = PrimitiveArray::new(buffer![1i32, 2, 0], value_validity).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([true, true, false]));
    }
}