vortex_array/arrays/dict/vtable/
canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5
6use vortex_buffer::BitBuffer;
7use vortex_dtype::DType;
8use vortex_dtype::Nullability;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_scalar::Scalar;
14
15use super::DictVTable;
16use crate::Array;
17use crate::ArrayRef;
18use crate::Canonical;
19use crate::IntoArray;
20use crate::ToCanonical;
21use crate::arrays::BoolArray;
22use crate::arrays::ConstantArray;
23use crate::arrays::dict::DictArray;
24use crate::compute::Operator;
25use crate::compute::cast;
26use crate::compute::compare;
27use crate::compute::mask;
28use crate::compute::take;
29use crate::validity::Validity;
30use crate::vtable::CanonicalVTable;
31
32impl CanonicalVTable<DictVTable> for DictVTable {
33    fn canonicalize(array: &DictArray) -> Canonical {
34        match array.dtype() {
35            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
36            // decompression to construct the views child array.
37            // For this case, it is *always* faster to decompress the values first and then create
38            // copies of the view pointers.
39            DType::Utf8(_) | DType::Binary(_) => {
40                let canonical_values: ArrayRef = array.values().to_canonical().into_array();
41                take(&canonical_values, array.codes())
42                    .vortex_expect("taking codes from dictionary values shouldn't fail")
43                    .to_canonical()
44            }
45            DType::Bool(_) => {
46                dict_bool_take(array).vortex_expect("Canonicalizing dict bool array shouldn't fail")
47            }
48            _ => take(array.values(), array.codes())
49                .vortex_expect("taking codes from dictionary values shouldn't fail")
50                .to_canonical(),
51        }
52    }
53}
54
55fn dict_bool_take(dict_array: &DictArray) -> VortexResult<Canonical> {
56    let values = dict_array.values();
57    let codes = dict_array.codes();
58    let result_nullability = dict_array.dtype().nullability();
59
60    let bool_values = values.to_bool();
61    let result_validity = bool_values.validity_mask();
62    let bool_buffer = bool_values.bit_buffer();
63    let (first_match, second_match) = match result_validity.bit_buffer() {
64        AllOr::All => {
65            let mut indices_iter = bool_buffer.set_indices();
66            (indices_iter.next(), indices_iter.next())
67        }
68        AllOr::None => (None, None),
69        AllOr::Some(v) => {
70            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
71            (indices_iter.next(), indices_iter.next())
72        }
73    };
74
75    Ok(match (first_match, second_match) {
76        // Couldn't find a value match, so the result is all false.
77        (None, _) => match result_validity {
78            Mask::AllTrue(_) => BoolArray::from_bit_buffer(
79                BitBuffer::new_unset(codes.len()),
80                Validity::copy_from_array(codes).union_nullability(result_nullability),
81            )
82            .to_canonical(),
83            Mask::AllFalse(_) => ConstantArray::new(
84                Scalar::null(DType::Bool(Nullability::Nullable)),
85                codes.len(),
86            )
87            .to_canonical(),
88            Mask::Values(_) => BoolArray::from_bit_buffer(
89                BitBuffer::new_unset(codes.len()),
90                Validity::from_mask(result_validity, result_nullability).take(codes)?,
91            )
92            .to_canonical(),
93        },
94        // We found a single matching value so we can compare the codes directly.
95        (Some(code), None) => match result_validity {
96            Mask::AllTrue(_) => cast(
97                &compare(
98                    codes,
99                    &cast(
100                        ConstantArray::new(code, codes.len()).as_ref(),
101                        codes.dtype(),
102                    )?,
103                    Operator::Eq,
104                )?,
105                &DType::Bool(result_nullability),
106            )?
107            .to_canonical(),
108            Mask::AllFalse(_) => ConstantArray::new(
109                Scalar::null(DType::Bool(Nullability::Nullable)),
110                codes.len(),
111            )
112            .to_canonical(),
113            Mask::Values(rv) => mask(
114                &compare(
115                    codes,
116                    &cast(
117                        ConstantArray::new(code, codes.len()).as_ref(),
118                        codes.dtype(),
119                    )?,
120                    Operator::Eq,
121                )?,
122                &Mask::from_buffer(
123                    take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)?
124                        .to_bool()
125                        .bit_buffer()
126                        .not(),
127                ),
128            )?
129            .to_canonical(),
130        },
131        // More than one value matches.
132        _ => take(bool_values.as_ref(), codes)
133            .vortex_expect("taking codes from dictionary values shouldn't fail")
134            .to_canonical(),
135    })
136}