vortex_dict/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::builders::builder_with_capacity;
6use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
7use vortex_array::validity::Validity;
8use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
9use vortex_dtype::{DType, Nullability};
10use vortex_error::VortexResult;
11use vortex_mask::{AllOr, Mask};
12use vortex_scalar::Scalar;
13
14use crate::{DictArray, DictVTable};
15
16impl CompareKernel for DictVTable {
17    fn compare(
18        &self,
19        lhs: &DictArray,
20        rhs: &dyn Array,
21        operator: Operator,
22    ) -> VortexResult<Option<ArrayRef>> {
23        // if we have more values than codes, it is faster to canonicalise first.
24        if lhs.values().len() > lhs.codes().len() {
25            return Ok(None);
26        }
27        // If the RHS is constant, then we just need to compare against our encoded values.
28        if let Some(rhs) = rhs.as_constant() {
29            let compare_result = compare(
30                lhs.values(),
31                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
32                operator,
33            )?;
34            return if operator == Operator::Eq {
35                let result_nullability =
36                    compare_result.dtype().nullability() | lhs.dtype().nullability();
37                dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
38            } else {
39                // SAFETY: values len preserved, codes all still point to valid values
40                unsafe {
41                    Ok(Some(
42                        DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array(),
43                    ))
44                }
45            };
46        }
47
48        // It's a little more complex, but we could perform a comparison against the dictionary
49        // values in the future.
50        Ok(None)
51    }
52}
53
54register_kernel!(CompareKernelAdapter(DictVTable).lift());
55
56fn dict_equal_to(
57    values_compare: ArrayRef,
58    codes: &ArrayRef,
59    result_nullability: Nullability,
60) -> VortexResult<ArrayRef> {
61    let bool_result = values_compare.to_bool()?;
62    let result_validity = bool_result.validity_mask();
63    let bool_buffer = bool_result.boolean_buffer();
64    let (first_match, second_match) = match result_validity.boolean_buffer() {
65        AllOr::All => {
66            let mut indices_iter = bool_buffer.set_indices();
67            (indices_iter.next(), indices_iter.next())
68        }
69        AllOr::None => (None, None),
70        AllOr::Some(v) => {
71            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
72            (indices_iter.next(), indices_iter.next())
73        }
74    };
75
76    Ok(match (first_match, second_match) {
77        // Couldn't find a value match, so the result is all false
78        (None, _) => match result_validity {
79            Mask::AllTrue(_) => {
80                let mut result_builder =
81                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
82                result_builder.extend_from_array(
83                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
84                        .into_array(),
85                )?;
86                result_builder.set_validity(codes.validity_mask());
87                result_builder.finish()
88            }
89            Mask::AllFalse(_) => ConstantArray::new(
90                Scalar::null(DType::Bool(Nullability::Nullable)),
91                codes.len(),
92            )
93            .into_array(),
94            Mask::Values(_) => {
95                let mut result_builder =
96                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
97                result_builder.extend_from_array(
98                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
99                        .into_array(),
100                )?;
101                result_builder.set_validity(
102                    Validity::from_mask(result_validity, bool_result.dtype().nullability())
103                        .take(codes)?
104                        .to_mask(codes.len()),
105                );
106                result_builder.finish()
107            }
108        },
109        // We found a single matching value so we can compare the codes directly.
110        // Note: the codes include nullability so we can just compare the codes directly, to the found code.
111        (Some(code), None) => cast(
112            &compare(
113                codes,
114                &cast(
115                    ConstantArray::new(code, codes.len()).as_ref(),
116                    codes.dtype(),
117                )?,
118                Operator::Eq,
119            )?,
120            &DType::Bool(result_nullability),
121        )?,
122        // more than one value matches
123        _ => unsafe {
124            DictArray::new_unchecked(codes.clone(), bool_result.into_array()).into_array()
125        },
126    })
127}
128
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Non-nullable dictionary with codes `[0, 1, 2]` over values `[1, 2, 3]`.
    fn dense_dict() -> DictArray {
        DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap()
    }

    #[test]
    fn test_compare_value() {
        let dict = dense_dict();
        let rhs = ConstantArray::new(Scalar::from(1i32), 3);

        let bools = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .and_then(|array| array.to_bool())
            .unwrap();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, [true, false, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let dict = dense_dict();
        let rhs = ConstantArray::new(Scalar::from(1i32), 3);

        let bools = compare(dict.as_ref(), rhs.as_ref(), Operator::Gt)
            .and_then(|array| array.to_bool())
            .unwrap();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, [false, true, true]);
    }

    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);

        let bools = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .and_then(|array| array.to_bool())
            .unwrap();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, [false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([false, true, false]));
    }

    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);

        let bools = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq)
            .and_then(|array| array.to_bool())
            .unwrap();

        let bits: Vec<bool> = bools.boolean_buffer().iter().collect();
        assert_eq!(bits, [false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([true, true, false]));
    }
}