Skip to main content

vortex_array/arrays/dict/vtable/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::Nullability;
5use vortex_error::VortexResult;
6
7use super::DictVTable;
8use crate::Array;
9use crate::IntoArray;
10use crate::arrays::dict::DictArray;
11use crate::builtins::ArrayBuiltins;
12use crate::scalar::Scalar;
13use crate::validity::Validity;
14use crate::vtable::ValidityVTable;
15
16impl ValidityVTable<DictVTable> for DictVTable {
17    fn validity(array: &DictArray) -> VortexResult<Validity> {
18        Ok(
19            match (array.codes().validity()?, array.values().validity()?) {
20                (
21                    Validity::NonNullable | Validity::AllValid,
22                    Validity::NonNullable | Validity::AllValid,
23                ) => {
24                    // Recall that we know the dictionary is nullable if we're in this function.
25                    Validity::AllValid
26                }
27                (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
28                (Validity::Array(codes_validity), Validity::NonNullable | Validity::AllValid) => {
29                    Validity::Array(codes_validity)
30                }
31                (Validity::AllValid | Validity::NonNullable, Validity::Array(values_validity)) => {
32                    // We know codes are all valid, so the cast is free.
33                    let codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
34                    Validity::Array(
35                        unsafe { DictArray::new_unchecked(codes, values_validity) }.into_array(),
36                    )
37                }
38                (Validity::Array(_codes_validity), Validity::Array(values_validity)) => {
39                    // Create a mask representing "is the value at codes[i] valid?"
40                    let values_valid_mask =
41                        unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
42                            .into_array();
43                    let values_valid_mask = values_valid_mask
44                        .fill_null(Scalar::bool(false, Nullability::NonNullable))?;
45
46                    Validity::Array(values_valid_mask)
47                }
48            },
49        )
50    }
51}