vortex_array/arrays/dict/vtable/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexExpect;
7use vortex_mask::AllOr;
8use vortex_mask::Mask;
9
10use super::DictVTable;
11use crate::Array;
12use crate::ToCanonical;
13use crate::arrays::dict::DictArray;
14use crate::vtable::ValidityVTable;
15
16impl ValidityVTable<DictVTable> for DictVTable {
17    fn is_valid(array: &DictArray, index: usize) -> bool {
18        let scalar = array.codes().scalar_at(index);
19
20        if scalar.is_null() {
21            return false;
22        };
23        let values_index: usize = scalar
24            .as_ref()
25            .try_into()
26            .vortex_expect("Failed to convert dictionary code to usize");
27        array.values().is_valid(values_index)
28    }
29
30    fn all_valid(array: &DictArray) -> bool {
31        array.codes().all_valid() && array.values().all_valid()
32    }
33
34    fn all_invalid(array: &DictArray) -> bool {
35        array.codes().all_invalid() || array.values().all_invalid()
36    }
37
38    fn validity_mask(array: &DictArray) -> Mask {
39        let codes_validity = array.codes().validity_mask();
40        match codes_validity.bit_buffer() {
41            AllOr::All => {
42                let primitive_codes = array.codes().to_primitive();
43                let values_mask = array.values().validity_mask();
44                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
45                    let codes_slice = primitive_codes.as_slice::<P>();
46                    BitBuffer::collect_bool(array.len(), |idx| {
47                        #[allow(clippy::cast_possible_truncation)]
48                        values_mask.value(codes_slice[idx] as usize)
49                    })
50                });
51                Mask::from_buffer(is_valid_buffer)
52            }
53            AllOr::None => Mask::AllFalse(array.len()),
54            AllOr::Some(validity_buff) => {
55                let primitive_codes = array.codes().to_primitive();
56                let values_mask = array.values().validity_mask();
57                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
58                    let codes_slice = primitive_codes.as_slice::<P>();
59                    #[allow(clippy::cast_possible_truncation)]
60                    BitBuffer::collect_bool(array.len(), |idx| {
61                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
62                    })
63                });
64                Mask::from_buffer(is_valid_buffer)
65            }
66        }
67    }
68}