vortex_array/arrays/dict/
canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5
6use vortex_buffer::BitBuffer;
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect, VortexResult};
9use vortex_mask::{AllOr, Mask};
10use vortex_scalar::Scalar;
11
12use super::{DictArray, DictVTable};
13use crate::arrays::{BoolArray, ConstantArray};
14use crate::compute::{Operator, cast, compare, mask, take};
15use crate::validity::Validity;
16use crate::vtable::CanonicalVTable;
17use crate::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
18
19impl CanonicalVTable<DictVTable> for DictVTable {
20    fn canonicalize(array: &DictArray) -> Canonical {
21        match array.dtype() {
22            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
23            // decompression to construct the views child array.
24            // For this case, it is *always* faster to decompress the values first and then create
25            // copies of the view pointers.
26            DType::Utf8(_) | DType::Binary(_) => {
27                let canonical_values: ArrayRef = array.values().to_canonical().into_array();
28                take(&canonical_values, array.codes())
29                    .vortex_expect("taking codes from dictionary values shouldn't fail")
30                    .to_canonical()
31            }
32            DType::Bool(_) => {
33                dict_bool_take(array).vortex_expect("Canonicalizing dict bool array shouldn't fail")
34            }
35            _ => take(array.values(), array.codes())
36                .vortex_expect("taking codes from dictionary values shouldn't fail")
37                .to_canonical(),
38        }
39    }
40}
41
42fn dict_bool_take(dict_array: &DictArray) -> VortexResult<Canonical> {
43    let values = dict_array.values();
44    let codes = dict_array.codes();
45    let result_nullability = dict_array.dtype().nullability();
46
47    let bool_values = values.to_bool();
48    let result_validity = bool_values.validity_mask();
49    let bool_buffer = bool_values.bit_buffer();
50    let (first_match, second_match) = match result_validity.bit_buffer() {
51        AllOr::All => {
52            let mut indices_iter = bool_buffer.set_indices();
53            (indices_iter.next(), indices_iter.next())
54        }
55        AllOr::None => (None, None),
56        AllOr::Some(v) => {
57            let mut indices_iter = bool_buffer.set_indices().filter(|i| v.value(*i));
58            (indices_iter.next(), indices_iter.next())
59        }
60    };
61
62    Ok(match (first_match, second_match) {
63        // Couldn't find a value match, so the result is all false
64        (None, _) => match result_validity {
65            Mask::AllTrue(_) => BoolArray::from_bit_buffer(
66                BitBuffer::new_unset(codes.len()),
67                Validity::copy_from_array(codes).union_nullability(result_nullability),
68            )
69            .to_canonical(),
70            Mask::AllFalse(_) => ConstantArray::new(
71                Scalar::null(DType::Bool(Nullability::Nullable)),
72                codes.len(),
73            )
74            .to_canonical(),
75            Mask::Values(_) => BoolArray::from_bit_buffer(
76                BitBuffer::new_unset(codes.len()),
77                Validity::from_mask(result_validity, result_nullability).take(codes)?,
78            )
79            .to_canonical(),
80        },
81        // We found a single matching value so we can compare the codes directly.
82        (Some(code), None) => match result_validity {
83            Mask::AllTrue(_) => cast(
84                &compare(
85                    codes,
86                    &cast(
87                        ConstantArray::new(code, codes.len()).as_ref(),
88                        codes.dtype(),
89                    )?,
90                    Operator::Eq,
91                )?,
92                &DType::Bool(result_nullability),
93            )?
94            .to_canonical(),
95            Mask::AllFalse(_) => ConstantArray::new(
96                Scalar::null(DType::Bool(Nullability::Nullable)),
97                codes.len(),
98            )
99            .to_canonical(),
100            Mask::Values(rv) => mask(
101                &compare(
102                    codes,
103                    &cast(
104                        ConstantArray::new(code, codes.len()).as_ref(),
105                        codes.dtype(),
106                    )?,
107                    Operator::Eq,
108                )?,
109                &Mask::from_buffer(
110                    take(BoolArray::from(rv.bit_buffer().clone()).as_ref(), codes)?
111                        .to_bool()
112                        .bit_buffer()
113                        .not(),
114                ),
115            )?
116            .to_canonical(),
117        },
118        // more than one value matches
119        _ => take(bool_values.as_ref(), codes)
120            .vortex_expect("taking codes from dictionary values shouldn't fail")
121            .to_canonical(),
122    })
123}