Skip to main content

vortex_array/arrays/dict/vtable/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use super::Dict;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::array::ValidityVTable;
10use crate::arrays::DictArray;
11use crate::arrays::dict::DictArraySlotsExt;
12use crate::builtins::ArrayBuiltins;
13use crate::dtype::Nullability;
14use crate::scalar::Scalar;
15use crate::validity::Validity;
16
17impl ValidityVTable<Dict> for Dict {
18    fn validity(array: ArrayView<'_, Dict>) -> VortexResult<Validity> {
19        Ok(
20            match (array.codes().validity()?, array.values().validity()?) {
21                (
22                    Validity::NonNullable | Validity::AllValid,
23                    Validity::NonNullable | Validity::AllValid,
24                ) => {
25                    // Recall that we know the dictionary is nullable if we're in this function.
26                    Validity::AllValid
27                }
28                (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
29                (Validity::Array(codes_validity), Validity::NonNullable | Validity::AllValid) => {
30                    Validity::Array(codes_validity)
31                }
32                (Validity::AllValid | Validity::NonNullable, Validity::Array(values_validity)) => {
33                    // We know codes are all valid, so the cast is free.
34                    let codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
35                    Validity::Array(
36                        unsafe { DictArray::new_unchecked(codes, values_validity) }.into_array(),
37                    )
38                }
39                (Validity::Array(_codes_validity), Validity::Array(values_validity)) => {
40                    // Create a mask representing "is the value at codes[i] valid?"
41                    let values_valid_mask =
42                        unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
43                            .into_array();
44                    let values_valid_mask = values_valid_mask
45                        .fill_null(Scalar::bool(false, Nullability::NonNullable))?;
46
47                    Validity::Array(values_valid_mask)
48                }
49            },
50        )
51    }
52}