vortex_array/arrays/dict/
canonical.rs1use std::ops::Not;
5
6use vortex_buffer::BitBuffer;
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_mask::{AllOr, Mask};
10use vortex_scalar::Scalar;
11
12use super::{DictArray, DictVTable};
13use crate::arrays::{BoolArray, ConstantArray};
14use crate::compute::{Operator, cast, compare, mask, take};
15use crate::validity::Validity;
16use crate::vtable::CanonicalVTable;
17use crate::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
18
19impl CanonicalVTable<DictVTable> for DictVTable {
20 fn canonicalize(array: &DictArray) -> Canonical {
21 match array.dtype() {
22 DType::Utf8(_) | DType::Binary(_) => {
27 let canonical_values: ArrayRef = array.values().to_canonical().into_array();
28 take(&canonical_values, array.codes())
29 .vortex_expect("taking codes from dictionary values shouldn't fail")
30 .to_canonical()
31 }
32 DType::Bool(_) => {
33 dict_bool_take(array).vortex_expect("Canonicalizing dict bool array shouldn't fail")
34 }
35 _ => take(array.values(), array.codes())
36 .vortex_expect("taking codes from dictionary values shouldn't fail")
37 .to_canonical(),
38 }
39 }
40}
41
42fn dict_bool_take(dict_array: &DictArray) -> VortexResult<Canonical> {
43 let values = dict_array.values();
44 let codes = dict_array.codes();
45 let result_nullability = dict_array.dtype().nullability();
46
47 let bool_values = values.to_bool();
48 let result_validity = bool_values.validity_mask();
49 let bool_buffer = bool_values.bit_buffer();
50 let (first_match, second_match) = match result_validity.bit_buffer() {
51 AllOr::All => {
52 let mut indices_iter = bool_buffer.set_indices();
53 (indices_iter.next(), indices_iter.next())
54 }
55 AllOr::None => (None, None),
56 AllOr::Some(v) => {
57 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
58 (indices_iter.next(), indices_iter.next())
59 }
60 };
61
62 Ok(match (first_match, second_match) {
63 (None, _) => match result_validity {
65 Mask::AllTrue(_) => BoolArray::from_bit_buffer(
66 BitBuffer::new_unset(codes.len()),
67 Validity::copy_from_array(codes).union_nullability(result_nullability),
68 )
69 .to_canonical(),
70 Mask::AllFalse(_) => ConstantArray::new(
71 Scalar::null(DType::Bool(Nullability::Nullable)),
72 codes.len(),
73 )
74 .to_canonical(),
75 Mask::Values(_) => BoolArray::from_bit_buffer(
76 BitBuffer::new_unset(codes.len()),
77 Validity::from_mask(result_validity, result_nullability).take(codes)?,
78 )
79 .to_canonical(),
80 },
81 (Some(code), None) => match result_validity {
83 Mask::AllTrue(_) => cast(
84 &compare(
85 codes,
86 &cast(
87 ConstantArray::new(code, codes.len()).as_ref(),
88 codes.dtype(),
89 )?,
90 Operator::Eq,
91 )?,
92 &DType::Bool(result_nullability),
93 )?
94 .to_canonical(),
95 Mask::AllFalse(_) => ConstantArray::new(
96 Scalar::null(DType::Bool(Nullability::Nullable)),
97 codes.len(),
98 )
99 .to_canonical(),
100 Mask::Values(rv) => mask(
101 &compare(
102 codes,
103 &cast(
104 ConstantArray::new(code, codes.len()).as_ref(),
105 codes.dtype(),
106 )?,
107 Operator::Eq,
108 )?,
109 &Mask::from_buffer(
110 take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)?
111 .to_bool()
112 .bit_buffer()
113 .not(),
114 ),
115 )?
116 .to_canonical(),
117 },
118 _ => take(bool_values.as_ref(), codes)
120 .vortex_expect("taking codes from dictionary values shouldn't fail")
121 .to_canonical(),
122 })
123}