vortex_dict/compute/
compare.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::builders::builder_with_capacity;
3use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, cast, compare};
4use vortex_array::validity::Validity;
5use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::VortexResult;
8use vortex_mask::{AllOr, Mask};
9use vortex_scalar::Scalar;
10
11use crate::{DictArray, DictEncoding};
12
13impl CompareKernel for DictEncoding {
14    fn compare(
15        &self,
16        lhs: &DictArray,
17        rhs: &dyn Array,
18        operator: Operator,
19    ) -> VortexResult<Option<ArrayRef>> {
20        // if we have more values than codes, it is faster to canonicalise first.
21        if lhs.values().len() > lhs.codes().len() {
22            return Ok(None);
23        }
24        // If the RHS is constant, then we just need to compare against our encoded values.
25        if let Some(rhs) = rhs.as_constant() {
26            let compare_result = compare(
27                lhs.values(),
28                &ConstantArray::new(rhs, lhs.values().len()),
29                operator,
30            )?;
31            return if operator == Operator::Eq {
32                let result_nullability =
33                    compare_result.dtype().nullability() | lhs.dtype().nullability();
34                dict_equal_to(compare_result, lhs.codes(), result_nullability).map(Some)
35            } else {
36                DictArray::try_new(lhs.codes().clone(), compare_result)
37                    .map(|a| a.into_array())
38                    .map(Some)
39            };
40        }
41
42        // It's a little more complex, but we could perform a comparison against the dictionary
43        // values in the future.
44        Ok(None)
45    }
46}
47
48register_kernel!(CompareKernelAdapter(DictEncoding).lift());
49
50fn dict_equal_to(
51    values_compare: ArrayRef,
52    codes: &ArrayRef,
53    result_nullability: Nullability,
54) -> VortexResult<ArrayRef> {
55    let bool_result = values_compare.to_bool()?;
56    let result_validity = bool_result.validity_mask()?;
57    let bool_buffer = bool_result.boolean_buffer();
58    let (first_match, second_match) = match result_validity.boolean_buffer() {
59        AllOr::All => {
60            let mut indices_iter = bool_buffer.set_indices();
61            (indices_iter.next(), indices_iter.next())
62        }
63        AllOr::None => (None, None),
64        AllOr::Some(v) => {
65            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
66            (indices_iter.next(), indices_iter.next())
67        }
68    };
69
70    Ok(match (first_match, second_match) {
71        // Couldn't find a value match, so the result is all false
72        (None, _) => match result_validity {
73            Mask::AllTrue(_) => {
74                let mut result_builder =
75                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
76                result_builder.extend_from_array(
77                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
78                        .into_array(),
79                )?;
80                result_builder.set_validity(codes.validity_mask()?);
81                result_builder.finish()
82            }
83            Mask::AllFalse(_) => ConstantArray::new(
84                Scalar::null(DType::Bool(Nullability::Nullable)),
85                codes.len(),
86            )
87            .into_array(),
88            Mask::Values(_) => {
89                let mut result_builder =
90                    builder_with_capacity(&DType::Bool(result_nullability), codes.len());
91                result_builder.extend_from_array(
92                    &ConstantArray::new(Scalar::bool(false, result_nullability), codes.len())
93                        .into_array(),
94                )?;
95                result_builder.set_validity(
96                    Validity::from_mask(result_validity, bool_result.dtype().nullability())
97                        .take(codes)?
98                        .to_mask(codes.len())?,
99                );
100                result_builder.finish()
101            }
102        },
103        // We found a single matching value so we can compare the codes directly.
104        // Note: the codes include nullability so we can just compare the codes directly, to the found code.
105        (Some(code), None) => cast(
106            &compare(
107                codes,
108                &cast(&ConstantArray::new(code, codes.len()), codes.dtype())?,
109                Operator::Eq,
110            )?,
111            &DType::Bool(result_nullability),
112        )?,
113        // more than one value matches
114        _ => DictArray::try_new(codes.clone(), bool_result.into_array())?.into_array(),
115    })
116}
117
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    /// Builds the all-valid dictionary `[1, 2, 3]` referenced by codes `[0, 1, 2]`.
    fn simple_dict() -> DictArray {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        DictArray::try_new(codes, values).unwrap()
    }

    /// Compares `dict` with a length-3 constant `rhs` and returns the raw boolean bits of the
    /// result (validity is not inspected).
    fn compare_bits(dict: &DictArray, rhs: Scalar, operator: Operator) -> Vec<bool> {
        let result = compare(dict, &ConstantArray::new(rhs, 3), operator)
            .unwrap()
            .to_bool()
            .unwrap();
        result.boolean_buffer().iter().collect()
    }

    #[test]
    fn test_compare_value() {
        let bits = compare_bits(&simple_dict(), Scalar::from(1i32), Operator::Eq);
        assert_eq!(bits, vec![true, false, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let bits = compare_bits(&simple_dict(), Scalar::from(1i32), Operator::Gt);
        assert_eq!(bits, vec![false, true, true]);
    }

    #[test]
    fn test_compare_nullable() {
        // Codes 0 and 2 are null; no dictionary value equals 4.
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);
        let result = compare(&dict, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = result.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(result.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            result.validity_mask().unwrap(),
            Mask::from_iter([false, true, false])
        );
    }

    #[test]
    fn test_compare_null_values() {
        // The third dictionary value is null; no valid value equals 4.
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);
        let result = compare(&dict, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let bits: Vec<bool> = result.boolean_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(result.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            result.validity_mask().unwrap(),
            Mask::from_iter([true, true, false])
        );
    }
}