// vortex_dict/compute/compare.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareFn, Operator, compare, try_cast};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, ToCanonical};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictEncoding};
12
13impl CompareFn<&DictArray> for DictEncoding {
14    fn compare(
15        &self,
16        lhs: &DictArray,
17        rhs: &dyn Array,
18        operator: Operator,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // if we have more values than codes, it is faster to canonicalise first.
21        if lhs.values().len() > lhs.codes().len() {
22            return Ok(None);
23        }
24        // If the RHS is constant, then we just need to compare against our encoded values.
25        if let Some(rhs) = rhs.as_constant() {
26            let compare_result = compare(
27                lhs.values(),
28                &ConstantArray::new(rhs, lhs.values().len()),
29                operator,
30            )?;
31            return if operator == Operator::Eq {
32                let result_nullability =
33                    compare_result.dtype().nullability() | lhs.dtype().nullability();
34                dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35            } else {
36                DictArray::try_new(lhs.codes().clone(), compare_result)
37                    .map(|a| a.into_array())
38                    .map(Some)
39            };
40        }
41
42        // It's a little more complex, but we could perform a comparison against the dictionary
43        // values in the future.
44        Ok(None)
45    }
46}
47
48fn dict_equal_to(
49    values_compare: ArrayRef,
50    codes: &ArrayRef,
51    result_nullability: Nullability,
52) -> VortexResult<ArrayRef> {
53    let bool_result = values_compare.to_bool()?;
54    let result_validity = bool_result.validity_mask()?;
55    let bool_buffer = bool_result.boolean_buffer();
56    let (first_match, second_match) = match result_validity.boolean_buffer() {
57        AllOr::All => {
58            let mut indices_iter = bool_buffer.set_indices();
59            (indices_iter.next(), indices_iter.next())
60        }
61        AllOr::None => (None, None),
62        AllOr::Some(v) => {
63            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
64            (indices_iter.next(), indices_iter.next())
65        }
66    };
67
68    Ok(match (first_match, second_match) {
69        // Couldn't find a value match, so the result is all false
70        (None, _) => match result_validity {
71            Mask::AllTrue(_) => {
72                let mut result_builder =
73                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
74                result_builder.extend_from_array(
75                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
76                        .into_array(),
77                )?;
78                result_builder.set_validity(codes.validity_mask()?);
79                result_builder.finish()
80            }
81            Mask::AllFalse(_) => ConstantArray::new(
82                Scalar::null(DType::Bool(Nullability::Nullable)),
83                codes.len(),
84            )
85            .into_array(),
86            Mask::Values(_) => {
87                let mut result_builder =
88                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
89                result_builder.extend_from_array(
90                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
91                        .into_array(),
92                )?;
93                result_builder.set_validity(
94                    Validity::from_mask(result_validity, bool_result.dtype().nullability())
95                        .take(codes)?
96                        .to_mask(codes.len())?,
97                );
98                result_builder.finish()
99            }
100        },
101        // We found a single matching value so we can compare the codes directly.
102        // Note: the codes include nullability so we can just compare the codes directly, to the found code.
103        (Some(code), None) => try_cast(
104            &compare(
105                codes,
106                &try_cast(&ConstantArray::new(code, codes.len()), codes.dtype())?,
107                Operator::Eq,
108            )?,
109            &DType::Bool(result_nullability),
110        )?,
111        // more than one value matches
112        _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
113    })
114}
115
116#[cfg(test)]
117mod tests {
118    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
119    use vortex_array::compute::{Operator, compare};
120    use vortex_array::validity::Validity;
121    use vortex_array::{Array, IntoArray, ToCanonical};
122    use vortex_buffer::buffer;
123    use vortex_dtype::Nullability;
124    use vortex_mask::Mask;
125    use vortex_scalar::Scalar;
126
127    use crate::DictArray;
128
129    #[test]
130    fn test_compare_value() {
131        let dict = DictArray::try_new(
132            buffer![0u32, 1, 2].into_array(),
133            buffer![1i32, 2, 3].into_array(),
134        )
135        .unwrap();
136
137        let res = compare(
138            &dict,
139            &ConstantArray::new(Scalar::from(1i32), 3),
140            Operator::Eq,
141        )
142        .unwrap();
143        let res = res.to_bool().unwrap();
144        assert_eq!(
145            res.boolean_buffer().iter().collect::<Vec<_>>(),
146            vec![true, false, false]
147        );
148    }
149
150    #[test]
151    fn test_compare_non_eq() {
152        let dict = DictArray::try_new(
153            buffer![0u32, 1, 2].into_array(),
154            buffer![1i32, 2, 3].into_array(),
155        )
156        .unwrap();
157
158        let res = compare(
159            &dict,
160            &ConstantArray::new(Scalar::from(1i32), 3),
161            Operator::Gt,
162        )
163        .unwrap();
164        let res = res.to_bool().unwrap();
165        assert_eq!(
166            res.boolean_buffer().iter().collect::<Vec<_>>(),
167            vec![false, true, true]
168        );
169    }
170
171    #[test]
172    fn test_compare_nullable() {
173        let dict = DictArray::try_new(
174            PrimitiveArray::new(
175                buffer![0u32, 1, 2],
176                Validity::from_iter([false, true, false]),
177            )
178            .into_array(),
179            PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(),
180        )
181        .unwrap();
182
183        let res = compare(
184            &dict,
185            &ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3),
186            Operator::Eq,
187        )
188        .unwrap();
189        let res = res.to_bool().unwrap();
190        assert_eq!(
191            res.boolean_buffer().iter().collect::<Vec<_>>(),
192            vec![false, false, false]
193        );
194        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
195        assert_eq!(
196            res.validity_mask().unwrap(),
197            Mask::from_iter([false, true, false])
198        );
199    }
200
201    #[test]
202    fn test_compare_null_values() {
203        let dict = DictArray::try_new(
204            buffer![0u32, 1, 2].into_array(),
205            PrimitiveArray::new(
206                buffer![1i32, 2, 0],
207                Validity::from_iter([true, true, false]),
208            )
209            .into_array(),
210        )
211        .unwrap();
212
213        let res = compare(
214            &dict,
215            &ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3),
216            Operator::Eq,
217        )
218        .unwrap();
219        let res = res.to_bool().unwrap();
220        assert_eq!(
221            res.boolean_buffer().iter().collect::<Vec<_>>(),
222            vec![false, false, false]
223        );
224        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
225        assert_eq!(
226            res.validity_mask().unwrap(),
227            Mask::from_iter([true, true, false])
228        );
229    }
230}