vortex_dict/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::builders::builder_with_capacity;
6use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_dtype::{DType, Nullability};
10use vortex_error::VortexResult;
11use vortex_mask::{AllOr, Mask};
12use vortex_scalar::Scalar;
13
14use crate::{DictArray, DictVTable};
15
16impl CompareKernel for DictVTable {
17    fn compare(
18        &self,
19        lhs: &DictArray,
20        rhs: &dyn Array,
21        operator: Operator,
22    ) -> VortexResult<Option<ArrayRef>> {
23        // if we have more values than codes, it is faster to canonicalise first.
24        if lhs.values().len() > lhs.codes().len() {
25            return Ok(None);
26        }
27        // If the RHS is constant, then we just need to compare against our encoded values.
28        if let Some(rhs) = rhs.as_constant() {
29            let compare_result = compare(
30                lhs.values(),
31                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
32                operator,
33            )?;
34            return if operator == Operator::Eq {
35                let result_nullability =
36                    compare_result.dtype().nullability() | lhs.dtype().nullability();
37                dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
38            } else {
39                DictArray::try_new(lhs.codes().clone(), compare_result)
40                    .map(|a| a.into_array())
41                    .map(Some)
42            };
43        }
44
45        // It's a little more complex, but we could perform a comparison against the dictionary
46        // values in the future.
47        Ok(None)
48    }
49}
50
51register_kernel!(CompareKernelAdapter(DictVTable).lift());
52
53fn dict_equal_to(
54    values_compare: ArrayRef,
55    codes: &ArrayRef,
56    result_nullability: Nullability,
57) -> VortexResult<ArrayRef> {
58    let bool_result = values_compare.to_bool()?;
59    let result_validity = bool_result.validity_mask()?;
60    let bool_buffer = bool_result.boolean_buffer();
61    let (first_match, second_match) = match result_validity.boolean_buffer() {
62        AllOr::All => {
63            let mut indices_iter = bool_buffer.set_indices();
64            (indices_iter.next(), indices_iter.next())
65        }
66        AllOr::None => (None, None),
67        AllOr::Some(v) => {
68            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
69            (indices_iter.next(), indices_iter.next())
70        }
71    };
72
73    Ok(match (first_match, second_match) {
74        // Couldn't find a value match, so the result is all false
75        (None, _) => match result_validity {
76            Mask::AllTrue(_) => {
77                let mut result_builder =
78                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
79                result_builder.extend_from_array(
80                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
81                        .into_array(),
82                )?;
83                result_builder.set_validity(codes.validity_mask()?);
84                result_builder.finish()
85            }
86            Mask::AllFalse(_) => ConstantArray::new(
87                Scalar::null(DType::Bool(Nullability::Nullable)),
88                codes.len(),
89            )
90            .into_array(),
91            Mask::Values(_) => {
92                let mut result_builder =
93                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
94                result_builder.extend_from_array(
95                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
96                        .into_array(),
97                )?;
98                result_builder.set_validity(
99                    Validity::from_mask(result_validity, bool_result.dtype().nullability())
100                        .take(codes)?
101                        .to_mask(codes.len())?,
102                );
103                result_builder.finish()
104            }
105        },
106        // We found a single matching value so we can compare the codes directly.
107        // Note: the codes include nullability so we can just compare the codes directly, to the found code.
108        (Some(code), None) => cast(
109            &compare(
110                codes,
111                &cast(
112                    ConstantArray::new(code, codes.len()).as_ref(),
113                    codes.dtype(),
114                )?,
115                Operator::Eq,
116            )?,
117            &DType::Bool(result_nullability),
118        )?,
119        // more than one value matches
120        _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
121    })
122}
123
124#[cfg(test)]
125mod tests {
126    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
127    use vortex_array::compute::{Operator, compare};
128    use vortex_array::validity::Validity;
129    use vortex_array::{IntoArray, ToCanonical};
130    use vortex_buffer::buffer;
131    use vortex_dtype::Nullability;
132    use vortex_mask::Mask;
133    use vortex_scalar::Scalar;
134
135    use crate::DictArray;
136
137    #[test]
138    fn test_compare_value() {
139        let dict = DictArray::try_new(
140            buffer![0u32, 1, 2].into_array(),
141            buffer![1i32, 2, 3].into_array(),
142        )
143        .unwrap();
144
145        let res = compare(
146            dict.as_ref(),
147            ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
148            Operator::Eq,
149        )
150        .unwrap();
151        let res = res.to_bool().unwrap();
152        assert_eq!(
153            res.boolean_buffer().iter().collect::<Vec<_>>(),
154            vec![true, false, false]
155        );
156    }
157
158    #[test]
159    fn test_compare_non_eq() {
160        let dict = DictArray::try_new(
161            buffer![0u32, 1, 2].into_array(),
162            buffer![1i32, 2, 3].into_array(),
163        )
164        .unwrap();
165
166        let res = compare(
167            dict.as_ref(),
168            ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
169            Operator::Gt,
170        )
171        .unwrap();
172        let res = res.to_bool().unwrap();
173        assert_eq!(
174            res.boolean_buffer().iter().collect::<Vec<_>>(),
175            vec![false, true, true]
176        );
177    }
178
179    #[test]
180    fn test_compare_nullable() {
181        let dict = DictArray::try_new(
182            PrimitiveArray::new(
183                buffer![0u32, 1, 2],
184                Validity::from_iter([false, true, false]),
185            )
186            .into_array(),
187            PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(),
188        )
189        .unwrap();
190
191        let res = compare(
192            dict.as_ref(),
193            ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3).as_ref(),
194            Operator::Eq,
195        )
196        .unwrap();
197        let res = res.to_bool().unwrap();
198        assert_eq!(
199            res.boolean_buffer().iter().collect::<Vec<_>>(),
200            vec![false, false, false]
201        );
202        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
203        assert_eq!(
204            res.validity_mask().unwrap(),
205            Mask::from_iter([false, true, false])
206        );
207    }
208
209    #[test]
210    fn test_compare_null_values() {
211        let dict = DictArray::try_new(
212            buffer![0u32, 1, 2].into_array(),
213            PrimitiveArray::new(
214                buffer![1i32, 2, 0],
215                Validity::from_iter([true, true, false]),
216            )
217            .into_array(),
218        )
219        .unwrap();
220
221        let res = compare(
222            dict.as_ref(),
223            ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3).as_ref(),
224            Operator::Eq,
225        )
226        .unwrap();
227        let res = res.to_bool().unwrap();
228        assert_eq!(
229            res.boolean_buffer().iter().collect::<Vec<_>>(),
230            vec![false, false, false]
231        );
232        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
233        assert_eq!(
234            res.validity_mask().unwrap(),
235            Mask::from_iter([true, true, false])
236        );
237    }
238}