Skip to main content

vortex_array/arrays/masked/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use smallvec::smallvec;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10
11use crate::ArrayRef;
12use crate::LEGACY_SESSION;
13use crate::VortexSessionExecute;
14use crate::array::Array;
15use crate::array::ArrayParts;
16use crate::array::TypedArrayRef;
17use crate::array::child_to_validity;
18use crate::array::validity_to_child;
19use crate::array_slots;
20use crate::arrays::Masked;
21use crate::validity::Validity;
22
23#[array_slots(Masked)]
24pub struct MaskedSlots {
25    /// The underlying child array being masked.
26    pub child: ArrayRef,
27    /// The validity bitmap defining which elements are non-null.
28    pub validity: Option<ArrayRef>,
29}
30
/// Data payload of a `Masked` array.
///
/// This is a field-less unit struct; the child and validity arrays are attached to the array as
/// slots rather than being stored here.
#[derive(Debug, Clone)]
pub struct MaskedData;
33
34impl Display for MaskedData {
35    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
36        Ok(())
37    }
38}
39
40pub trait MaskedArrayExt: TypedArrayRef<Masked> + MaskedArraySlotsExt {
41    fn masked_validity(&self) -> Validity {
42        child_to_validity(
43            self.as_ref().slots()[MaskedSlots::VALIDITY].as_ref(),
44            self.as_ref().dtype().nullability(),
45        )
46    }
47}
48impl<T: TypedArrayRef<Masked>> MaskedArrayExt for T {}
49
50impl MaskedData {
51    pub(crate) fn try_new(
52        child_len: usize,
53        child_all_valid: bool,
54        validity: Validity,
55    ) -> VortexResult<Self> {
56        if matches!(validity, Validity::NonNullable) {
57            vortex_bail!("MaskedArray must have nullable validity, got {validity:?}")
58        }
59
60        if !child_all_valid {
61            vortex_bail!("MaskedArray children must not have nulls");
62        }
63
64        if let Some(validity_len) = validity.maybe_len()
65            && validity_len != child_len
66        {
67            vortex_bail!("Validity must be the same length as a MaskedArray's child");
68        }
69
70        // MaskedArray's nullability is determined solely by its validity, not the child's dtype.
71        // The child can have nullable dtype but must not have any actual null values.
72        Ok(Self)
73    }
74}
75
76impl Array<Masked> {
77    /// Constructs a new `MaskedArray`.
78    pub fn try_new(child: ArrayRef, validity: Validity) -> VortexResult<Self> {
79        let dtype = child.dtype().as_nullable();
80        let len = child.len();
81        let validity_slot = validity_to_child(&validity, len);
82        let data = MaskedData::try_new(
83            len,
84            child.all_valid(&mut LEGACY_SESSION.create_execution_ctx())?,
85            validity,
86        )?;
87        Ok(unsafe {
88            Array::from_parts_unchecked(
89                ArrayParts::new(Masked, dtype, len, data)
90                    .with_slots(smallvec![Some(child), validity_slot]),
91            )
92        })
93    }
94}