vortex_array/arrays/dict/vtable/
canonical.rs1use std::ops::Not;
5
6use vortex_buffer::BitBuffer;
7use vortex_dtype::DType;
8use vortex_dtype::Nullability;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_scalar::Scalar;
14
15use super::DictVTable;
16use crate::Array;
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::IntoArray;
20use crate::ToCanonical;
21use crate::arrays::BoolArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::dict::DictArray;
24use crate::compute::Operator;
25use crate::compute::cast;
26use crate::compute::compare;
27use crate::compute::mask;
28use crate::compute::take;
29use crate::validity::Validity;
30use crate::vtable::CanonicalVTable;
31
32impl CanonicalVTable<DictVTable> for DictVTable {
33 fn canonicalize(array: &DictArray) -> Canonical {
34 match array.dtype() {
35 DType::Utf8(_) | DType::Binary(_) => {
40 let canonical_values: ArrayRef = array.values().to_canonical().into_array();
41 take(&canonical_values, array.codes())
42 .vortex_expect("taking codes from dictionary values shouldn't fail")
43 .to_canonical()
44 }
45 DType::Bool(_) => {
46 dict_bool_take(array).vortex_expect("Canonicalizing dict bool array shouldn't fail")
47 }
48 _ => take(array.values(), array.codes())
49 .vortex_expect("taking codes from dictionary values shouldn't fail")
50 .to_canonical(),
51 }
52 }
53}
54
55fn dict_bool_take(dict_array: &DictArray) -> VortexResult<Canonical> {
56 let values = dict_array.values();
57 let codes = dict_array.codes();
58 let result_nullability = dict_array.dtype().nullability();
59
60 let bool_values = values.to_bool();
61 let result_validity = bool_values.validity_mask();
62 let bool_buffer = bool_values.bit_buffer();
63 let (first_match, second_match) = match result_validity.bit_buffer() {
64 AllOr::All => {
65 let mut indices_iter = bool_buffer.set_indices();
66 (indices_iter.next(), indices_iter.next())
67 }
68 AllOr::None => (None, None),
69 AllOr::Some(v) => {
70 let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
71 (indices_iter.next(), indices_iter.next())
72 }
73 };
74
75 Ok(match (first_match, second_match) {
76 (None, _) => match result_validity {
78 Mask::AllTrue(_) => BoolArray::from_bit_buffer(
79 BitBuffer::new_unset(codes.len()),
80 Validity::copy_from_array(codes).union_nullability(result_nullability),
81 )
82 .to_canonical(),
83 Mask::AllFalse(_) => ConstantArray::new(
84 Scalar::null(DType::Bool(Nullability::Nullable)),
85 codes.len(),
86 )
87 .to_canonical(),
88 Mask::Values(_) => BoolArray::from_bit_buffer(
89 BitBuffer::new_unset(codes.len()),
90 Validity::from_mask(result_validity, result_nullability).take(codes)?,
91 )
92 .to_canonical(),
93 },
94 (Some(code), None) => match result_validity {
96 Mask::AllTrue(_) => cast(
97 &compare(
98 codes,
99 &cast(
100 ConstantArray::new(code, codes.len()).as_ref(),
101 codes.dtype(),
102 )?,
103 Operator::Eq,
104 )?,
105 &DType::Bool(result_nullability),
106 )?
107 .to_canonical(),
108 Mask::AllFalse(_) => ConstantArray::new(
109 Scalar::null(DType::Bool(Nullability::Nullable)),
110 codes.len(),
111 )
112 .to_canonical(),
113 Mask::Values(rv) => mask(
114 &compare(
115 codes,
116 &cast(
117 ConstantArray::new(code, codes.len()).as_ref(),
118 codes.dtype(),
119 )?,
120 Operator::Eq,
121 )?,
122 &Mask::from_buffer(
123 take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)?
124 .to_bool()
125 .bit_buffer()
126 .not(),
127 ),
128 )?
129 .to_canonical(),
130 },
131 _ => take(bool_values.as_ref(), codes)
133 .vortex_expect("taking codes from dictionary values shouldn't fail")
134 .to_canonical(),
135 })
136}