vortex_array/arrays/dict/vtable/
validity.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_mask::AllOr;
9use vortex_mask::Mask;
10
11use super::DictVTable;
12use crate::Array;
13use crate::IntoArray;
14use crate::ToCanonical;
15use crate::arrays::dict::DictArray;
16use crate::validity::Validity;
17use crate::vtable::ValidityVTable;
18
19impl ValidityVTable<DictVTable> for DictVTable {
20 fn is_valid(array: &DictArray, index: usize) -> bool {
21 let scalar = array.codes().scalar_at(index);
22
23 if scalar.is_null() {
24 return false;
25 };
26 let values_index: usize = scalar
27 .as_ref()
28 .try_into()
29 .vortex_expect("Failed to convert dictionary code to usize");
30 array.values().is_valid(values_index)
31 }
32
33 fn all_valid(array: &DictArray) -> bool {
34 array.codes().all_valid() && array.values().all_valid()
35 }
36
37 fn all_invalid(array: &DictArray) -> bool {
38 array.codes().all_invalid() || array.values().all_invalid()
39 }
40
41 fn validity(array: &DictArray) -> VortexResult<Validity> {
42 Ok(
43 match (array.codes().validity()?, array.values().validity()?) {
44 (
45 Validity::NonNullable | Validity::AllValid,
46 Validity::NonNullable | Validity::AllValid,
47 ) => {
48 Validity::AllValid
50 }
51 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
52 (Validity::Array(codes_validity), Validity::NonNullable | Validity::AllValid) => {
53 Validity::Array(codes_validity)
54 }
55 (Validity::AllValid | Validity::NonNullable, Validity::Array(values_validity)) => {
56 Validity::Array(
57 unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
58 .into_array(),
59 )
60 }
61 (Validity::Array(_), Validity::Array(values_validity)) => {
62 unsafe { DictArray::new_unchecked(array.codes().clone(), values_validity) }
64 .into_array()
65 .validity()?
66 }
67 },
68 )
69 }
70
71 fn validity_mask(array: &DictArray) -> Mask {
72 let codes_validity = array.codes().validity_mask();
73 match codes_validity.bit_buffer() {
74 AllOr::All => {
75 let primitive_codes = array.codes().to_primitive();
76 let values_mask = array.values().validity_mask();
77 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
78 let codes_slice = primitive_codes.as_slice::<P>();
79 BitBuffer::collect_bool(array.len(), |idx| {
80 #[allow(clippy::cast_possible_truncation)]
81 values_mask.value(codes_slice[idx] as usize)
82 })
83 });
84 Mask::from_buffer(is_valid_buffer)
85 }
86 AllOr::None => Mask::AllFalse(array.len()),
87 AllOr::Some(validity_buff) => {
88 let primitive_codes = array.codes().to_primitive();
89 let values_mask = array.values().validity_mask();
90 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
91 let codes_slice = primitive_codes.as_slice::<P>();
92 #[allow(clippy::cast_possible_truncation)]
93 BitBuffer::collect_bool(array.len(), |idx| {
94 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
95 })
96 });
97 Mask::from_buffer(is_valid_buffer)
98 }
99 }
100 }
101}