vortex_array/arrays/dict/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::DictArray;
7use super::DictVTable;
8use crate::Array;
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::compute::CompareKernel;
13use crate::compute::CompareKernelAdapter;
14use crate::compute::Operator;
15use crate::compute::compare;
16use crate::register_kernel;
17
18impl CompareKernel for DictVTable {
19    fn compare(
20        &self,
21        lhs: &DictArray,
22        rhs: &dyn Array,
23        operator: Operator,
24    ) -> VortexResult<Option<ArrayRef>> {
25        // if we have more values than codes, it is faster to canonicalise first.
26        if lhs.values().len() > lhs.codes().len() {
27            return Ok(None);
28        }
29
30        // If the RHS is constant, then we just need to compare against our encoded values.
31        if let Some(rhs) = rhs.as_constant() {
32            let compare_result = compare(
33                lhs.values(),
34                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
35                operator,
36            )?;
37
38            // SAFETY: values len preserved, codes all still point to valid values
39            let result = unsafe {
40                DictArray::new_unchecked(lhs.codes().clone(), compare_result)
41                    .set_all_values_referenced(lhs.has_all_values_referenced())
42                    .into_array()
43            };
44
45            // We canonicalize the result because dictionary-encoded bools is dumb.
46            return Ok(Some(result.to_canonical().into_array()));
47        }
48
49        // It's a little more complex, but we could perform a comparison against the dictionary
50        // values in the future.
51        Ok(None)
52    }
53}
54
55register_kernel!(CompareKernelAdapter(DictVTable).lift());
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::validity::Validity;

    /// Number of rows in every array built by these tests.
    const ROWS: usize = 3;

    /// Compares `dict` against `constant` repeated for [`ROWS`] rows, panicking on error.
    fn compare_to_constant(dict: &DictArray, constant: Scalar, operator: Operator) -> ArrayRef {
        let rhs = ConstantArray::new(constant, ROWS);
        compare(dict.as_ref(), rhs.as_ref(), operator).unwrap()
    }

    /// Equality against a constant that matches only the first value.
    #[test]
    fn test_compare_value() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let outcome = compare_to_constant(&dict, Scalar::from(1i32), Operator::Eq).to_bool();

        let bits: Vec<bool> = outcome.bit_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// A non-equality operator (`Gt`) against a constant.
    #[test]
    fn test_compare_non_eq() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let outcome = compare_to_constant(&dict, Scalar::from(1i32), Operator::Gt).to_bool();

        let bits: Vec<bool> = outcome.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Null codes: the validity of the codes must carry over to the comparison result.
    #[test]
    fn test_compare_nullable() {
        let code_validity = Validity::from_iter([false, true, false]);
        let codes = PrimitiveArray::new(buffer![0u32, 1, 2], code_validity).into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let needle = Scalar::primitive(4i32, Nullability::Nullable);
        let outcome = compare_to_constant(&dict, needle, Operator::Eq).to_bool();

        let bits: Vec<bool> = outcome.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask(),
            Mask::from_iter([false, true, false])
        );
    }

    /// Null dictionary values: the validity of the values must carry over to the result.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let value_validity = Validity::from_iter([true, true, false]);
        let values = PrimitiveArray::new(buffer![1i32, 2, 0], value_validity).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let needle = Scalar::primitive(4i32, Nullability::NonNullable);
        let outcome = compare_to_constant(&dict, needle, Operator::Eq).to_bool();

        let bits: Vec<bool> = outcome.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask(),
            Mask::from_iter([true, true, false])
        );
    }
}