vortex_dict/compute/
compare.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictVTable};
12
13impl CompareKernel for DictVTable {
14    fn compare(
15        &self,
16        lhs: &DictArray,
17        rhs: &dyn Array,
18        operator: Operator,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // if we have more values than codes, it is faster to canonicalise first.
21        if lhs.values().len() > lhs.codes().len() {
22            return Ok(None);
23        }
24        // If the RHS is constant, then we just need to compare against our encoded values.
25        if let Some(rhs) = rhs.as_constant() {
26            let compare_result = compare(
27                lhs.values(),
28                ConstantArray::new(rhs, lhs.values().len()).as_ref(),
29                operator,
30            )?;
31            return if operator == Operator::Eq {
32                let result_nullability =
33                    compare_result.dtype().nullability() | lhs.dtype().nullability();
34                dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35            } else {
36                DictArray::try_new(lhs.codes().clone(), compare_result)
37                    .map(|a| a.into_array())
38                    .map(Some)
39            };
40        }
41
42        // It's a little more complex, but we could perform a comparison against the dictionary
43        // values in the future.
44        Ok(None)
45    }
46}
47
48register_kernel!(CompareKernelAdapter(DictVTable).lift());
49
50fn dict_equal_to(
51    values_compare: ArrayRef,
52    codes: &ArrayRef,
53    result_nullability: Nullability,
54) -> VortexResult<ArrayRef> {
55    let bool_result = values_compare.to_bool()?;
56    let result_validity = bool_result.validity_mask()?;
57    let bool_buffer = bool_result.boolean_buffer();
58    let (first_match, second_match) = match result_validity.boolean_buffer() {
59        AllOr::All => {
60            let mut indices_iter = bool_buffer.set_indices();
61            (indices_iter.next(), indices_iter.next())
62        }
63        AllOr::None => (None, None),
64        AllOr::Some(v) => {
65            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
66            (indices_iter.next(), indices_iter.next())
67        }
68    };
69
70    Ok(match (first_match, second_match) {
71        // Couldn't find a value match, so the result is all false
72        (None, _) => match result_validity {
73            Mask::AllTrue(_) => {
74                let mut result_builder =
75                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
76                result_builder.extend_from_array(
77                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
78                        .into_array(),
79                )?;
80                result_builder.set_validity(codes.validity_mask()?);
81                result_builder.finish()
82            }
83            Mask::AllFalse(_) => ConstantArray::new(
84                Scalar::null(DType::Bool(Nullability::Nullable)),
85                codes.len(),
86            )
87            .into_array(),
88            Mask::Values(_) => {
89                let mut result_builder =
90                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
91                result_builder.extend_from_array(
92                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
93                        .into_array(),
94                )?;
95                result_builder.set_validity(
96                    Validity::from_mask(result_validity, bool_result.dtype().nullability())
97                        .take(codes)?
98                        .to_mask(codes.len())?,
99                );
100                result_builder.finish()
101            }
102        },
103        // We found a single matching value so we can compare the codes directly.
104        // Note: the codes include nullability so we can just compare the codes directly, to the found code.
105        (Some(code), None) => cast(
106            &compare(
107                codes,
108                &cast(
109                    ConstantArray::new(code, codes.len()).as_ref(),
110                    codes.dtype(),
111                )?,
112                Operator::Eq,
113            )?,
114            &DType::Bool(result_nullability),
115        )?,
116        // more than one value matches
117        _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
118    })
119}
120
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Number of codes in every dictionary used by these tests.
    const LEN: usize = 3;

    /// Dictionary with codes `[0, 1, 2]` over values `[1, 2, 3]` and no nulls anywhere.
    fn plain_dict() -> DictArray {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        DictArray::try_new(codes, values).unwrap()
    }

    /// Compares `dict` against a constant array of `rhs` repeated `LEN` times.
    fn compare_to_constant(dict: DictArray, rhs: Scalar, operator: Operator) -> ArrayRef {
        let rhs = ConstantArray::new(rhs, LEN);
        compare(dict.as_ref(), rhs.as_ref(), operator).unwrap()
    }

    #[test]
    fn test_compare_value() {
        let outcome = compare_to_constant(plain_dict(), Scalar::from(1i32), Operator::Eq)
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let outcome = compare_to_constant(plain_dict(), Scalar::from(1i32), Operator::Gt)
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    #[test]
    fn test_compare_nullable() {
        // Nulls live in the codes; the values are all valid.
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let outcome = compare_to_constant(
            dict,
            Scalar::primitive(4i32, Nullability::Nullable),
            Operator::Eq,
        )
        .to_bool()
        .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([false, true, false])
        );
    }

    #[test]
    fn test_compare_null_values() {
        // Nulls live in the values; the codes are all valid.
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let outcome = compare_to_constant(
            dict,
            Scalar::primitive(4i32, Nullability::NonNullable),
            Operator::Eq,
        )
        .to_bool()
        .unwrap();

        let bits: Vec<bool> = outcome.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(outcome.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            outcome.validity_mask().unwrap(),
            Mask::from_iter([true, true, false])
        );
    }
}