vortex_array/arrays/dict/vtable/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_mask::AllOr;
9use vortex_mask::Mask;
10
11use super::DictVTable;
12use crate::Array;
13use crate::IntoArray;
14use crate::ToCanonical;
15use crate::arrays::dict::DictArray;
16use crate::validity::Validity;
17use crate::vtable::ValidityVTable;
18
19impl ValidityVTable<DictVTable> for DictVTable {
20    fn is_valid(array: &DictArray, index: usize) -> bool {
21        let scalar = array.codes().scalar_at(index);
22
23        if scalar.is_null() {
24            return false;
25        };
26        let values_index: usize = scalar
27            .as_ref()
28            .try_into()
29            .vortex_expect("Failed to convert dictionary code to usize");
30        array.values().is_valid(values_index)
31    }
32
33    fn all_valid(array: &DictArray) -> bool {
34        array.codes().all_valid() && array.values().all_valid()
35    }
36
37    fn all_invalid(array: &DictArray) -> bool {
38        array.codes().all_invalid() || array.values().all_invalid()
39    }
40
41    fn validity(array: &DictArray) -> VortexResult<Validity> {
42        Ok(
43            match (array.codes().validity()?, array.values().validity()?) {
44                (
45                    Validity::NonNullable | Validity::AllValid,
46                    Validity::NonNullable | Validity::AllValid,
47                ) => {
48                    // Recall that we know the dictionary is nullable if we're in this function.
49                    Validity::AllValid
50                }
51                (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
52                (Validity::Array(codes_validity), Validity::NonNullable | Validity::AllValid) => {
53                    Validity::Array(codes_validity)
54                }
55                (Validity::AllValid | Validity::NonNullable, Validity::Array(values_validity)) => {
56                    Validity::Array(
57                        unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
58                            .into_array(),
59                    )
60                }
61                (Validity::Array(_), Validity::Array(values_validity)) => {
62                    // We essentially create is_not_null(Dict(codes, is_not_null(values)))
63                    unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
64                        .into_array()
65                        .validity()?
66                }
67            },
68        )
69    }
70
71    fn validity_mask(array: &DictArray) -> Mask {
72        let codes_validity = array.codes().validity_mask();
73        match codes_validity.bit_buffer() {
74            AllOr::All => {
75                let primitive_codes = array.codes().to_primitive();
76                let values_mask = array.values().validity_mask();
77                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
78                    let codes_slice = primitive_codes.as_slice::<P>();
79                    BitBuffer::collect_bool(array.len(), |idx| {
80                        #[allow(clippy::cast_possible_truncation)]
81                        values_mask.value(codes_slice[idx] as usize)
82                    })
83                });
84                Mask::from_buffer(is_valid_buffer)
85            }
86            AllOr::None => Mask::AllFalse(array.len()),
87            AllOr::Some(validity_buff) => {
88                let primitive_codes = array.codes().to_primitive();
89                let values_mask = array.values().validity_mask();
90                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
91                    let codes_slice = primitive_codes.as_slice::<P>();
92                    #[allow(clippy::cast_possible_truncation)]
93                    BitBuffer::collect_bool(array.len(), |idx| {
94                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
95                    })
96                });
97                Mask::from_buffer(is_valid_buffer)
98            }
99        }
100    }
101}