vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
38    #[inline]
39    pub fn into_array(self) -> Option<ArrayRef> {
40        if let Self::Array(a) = self {
41            Some(a)
42        } else {
43            None
44        }
45    }
46
47    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
48    #[inline]
49    pub fn as_array(&self) -> Option<&ArrayRef> {
50        if let Self::Array(a) = self {
51            Some(a)
52        } else {
53            None
54        }
55    }
56
57    #[inline]
58    pub fn nullability(&self) -> Nullability {
59        if matches!(self, Self::NonNullable) {
60            Nullability::NonNullable
61        } else {
62            Nullability::Nullable
63        }
64    }
65
66    /// The union nullability and validity.
67    #[inline]
68    pub fn union_nullability(self, nullability: Nullability) -> Self {
69        match nullability {
70            Nullability::NonNullable => self,
71            Nullability::Nullable => self.into_nullable(),
72        }
73    }
74
75    #[inline]
76    pub fn all_valid(&self, len: usize) -> bool {
77        match self {
78            _ if len == 0 => true,
79            Validity::NonNullable | Validity::AllValid => true,
80            Validity::AllInvalid => false,
81            Validity::Array(array) => {
82                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
83                    .vortex_expect("sum must be a usize")
84                    == array.len()
85            }
86        }
87    }
88
89    #[inline]
90    pub fn all_invalid(&self, len: usize) -> bool {
91        match self {
92            _ if len == 0 => true,
93            Validity::NonNullable | Validity::AllValid => false,
94            Validity::AllInvalid => true,
95            Validity::Array(array) => {
96                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
97                    .vortex_expect("sum must be a usize")
98                    == 0
99            }
100        }
101    }
102
103    /// Returns whether the `index` item is valid.
104    #[inline]
105    pub fn is_valid(&self, index: usize) -> bool {
106        match self {
107            Self::NonNullable | Self::AllValid => true,
108            Self::AllInvalid => false,
109            Self::Array(a) => {
110                let scalar = a.scalar_at(index);
111                scalar
112                    .as_bool()
113                    .value()
114                    .vortex_expect("Validity must be non-nullable")
115            }
116        }
117    }
118
119    #[inline]
120    pub fn is_null(&self, index: usize) -> bool {
121        !self.is_valid(index)
122    }
123
124    #[inline]
125    pub fn slice(&self, range: Range<usize>) -> Self {
126        match self {
127            Self::Array(a) => Self::Array(a.slice(range)),
128            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
129        }
130    }
131
132    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
133        match self {
134            Self::NonNullable => match indices.validity_mask().boolean_buffer() {
135                AllOr::All => {
136                    if indices.dtype().is_nullable() {
137                        Ok(Self::AllValid)
138                    } else {
139                        Ok(Self::NonNullable)
140                    }
141                }
142                AllOr::None => Ok(Self::AllInvalid),
143                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
144            },
145            Self::AllValid => match indices.validity_mask().boolean_buffer() {
146                AllOr::All => Ok(Self::AllValid),
147                AllOr::None => Ok(Self::AllInvalid),
148                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
149            },
150            Self::AllInvalid => Ok(Self::AllInvalid),
151            Self::Array(is_valid) => {
152                let maybe_is_valid = take(is_valid, indices)?;
153                // Null indices invalidate that position.
154                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
155                Ok(Self::Array(is_valid))
156            }
157        }
158    }
159
160    /// Keep only the entries for which the mask is true.
161    ///
162    /// The result has length equal to the number of true values in mask.
163    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
164        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
165        //  if we happen to be NonNullable, AllValid, or AllInvalid.
166        match self {
167            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
168                Ok(v.clone())
169            }
170            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
171        }
172    }
173
174    /// Set to false any entries for which the mask is true.
175    ///
176    /// The result is always nullable. The result has the same length as self.
177    #[inline]
178    pub fn mask(&self, mask: &Mask) -> Self {
179        match mask.boolean_buffer() {
180            AllOr::All => Validity::AllInvalid,
181            AllOr::None => self.clone(),
182            AllOr::Some(make_invalid) => match self {
183                Validity::NonNullable | Validity::AllValid => {
184                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
185                }
186                Validity::AllInvalid => Validity::AllInvalid,
187                Validity::Array(is_valid) => {
188                    let is_valid = is_valid.to_bool();
189                    let keep_valid = make_invalid.not();
190                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
191                }
192            },
193        }
194    }
195
196    #[inline]
197    pub fn to_mask(&self, length: usize) -> Mask {
198        match self {
199            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
200            Self::AllInvalid => Mask::AllFalse(length),
201            Self::Array(is_valid) => {
202                assert_eq!(
203                    is_valid.len(),
204                    length,
205                    "Validity::Array length must equal to_logical's argument: {}, {}.",
206                    is_valid.len(),
207                    length,
208                );
209                is_valid.to_bool().to_mask()
210            }
211        }
212    }
213
214    /// Logically & two Validity values of the same length
215    #[inline]
216    pub fn and(self, rhs: Validity) -> Validity {
217        match (self, rhs) {
218            // Should be pretty clear
219            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
220            // Any `AllInvalid` makes the output all invalid values
221            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
222            // All truthy values on one side, which makes no effect on an `Array` variant
223            (Validity::Array(a), Validity::AllValid)
224            | (Validity::Array(a), Validity::NonNullable)
225            | (Validity::NonNullable, Validity::Array(a))
226            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
227            // Both sides are all valid
228            (Validity::NonNullable, Validity::AllValid)
229            | (Validity::AllValid, Validity::NonNullable)
230            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
231            // Here we actually have to do some work
232            (Validity::Array(lhs), Validity::Array(rhs)) => {
233                let lhs = lhs.to_bool();
234                let rhs = rhs.to_bool();
235
236                let lhs = lhs.boolean_buffer();
237                let rhs = rhs.boolean_buffer();
238
239                Validity::from(lhs.bitand(rhs))
240            }
241        }
242    }
243
244    pub fn patch(
245        self,
246        len: usize,
247        indices_offset: usize,
248        indices: &dyn Array,
249        patches: &Validity,
250    ) -> Self {
251        match (&self, patches) {
252            (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
253            (Validity::NonNullable, _) => {
254                vortex_panic!("Can't patch a non-nullable validity with nullable validity")
255            }
256            (_, Validity::NonNullable) => {
257                vortex_panic!("Can't patch a nullable validity with non-nullable validity")
258            }
259            (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
260            (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
261            _ => {}
262        };
263
264        let own_nullability = if self == Validity::NonNullable {
265            Nullability::NonNullable
266        } else {
267            Nullability::Nullable
268        };
269
270        let source = match self {
271            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
272            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
273            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
274            Validity::Array(a) => a.to_bool(),
275        };
276
277        let patch_values = match patches {
278            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
279            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
280            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
281            Validity::Array(a) => a.to_bool(),
282        };
283
284        let patches = Patches::new(
285            len,
286            indices_offset,
287            indices.to_array(),
288            patch_values.into_array(),
289        );
290
291        Self::from_array(source.patch(&patches).into_array(), own_nullability)
292    }
293
294    /// Convert into a nullable variant
295    #[inline]
296    pub fn into_nullable(self) -> Validity {
297        match self {
298            Self::NonNullable => Self::AllValid,
299            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
300        }
301    }
302
303    /// Convert into a non-nullable variant
304    #[inline]
305    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
306        match self {
307            _ if len == 0 => Some(Validity::NonNullable),
308            Self::NonNullable => Some(Self::NonNullable),
309            Self::AllValid => Some(Self::NonNullable),
310            Self::AllInvalid => None,
311            Self::Array(is_valid) => {
312                is_valid
313                    .statistics()
314                    .compute_min::<bool>()
315                    .vortex_expect("validity array must support min")
316                    .then(|| {
317                        // min true => all true
318                        Self::NonNullable
319                    })
320            }
321        }
322    }
323
324    /// Convert into a variant compatible with the given nullability, if possible.
325    #[inline]
326    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
327        match nullability {
328            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
329                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
330            }),
331            Nullability::Nullable => Ok(self.into_nullable()),
332        }
333    }
334
335    /// Create Validity by copying the given array's validity.
336    #[inline]
337    pub fn copy_from_array(array: &dyn Array) -> Self {
338        Validity::from_mask(array.validity_mask(), array.dtype().nullability())
339    }
340
341    /// Create Validity from boolean array with given nullability of the array.
342    ///
343    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
344    ///     as that is always non-nullable
345    #[inline]
346    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
347        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
348            vortex_panic!("Expected a non-nullable boolean array")
349        }
350        match nullability {
351            Nullability::NonNullable => Self::NonNullable,
352            Nullability::Nullable => Self::Array(value),
353        }
354    }
355
356    /// Returns the length of the validity array, if it exists.
357    #[inline]
358    pub fn maybe_len(&self) -> Option<usize> {
359        match self {
360            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
361            Self::Array(a) => Some(a.len()),
362        }
363    }
364
365    #[inline]
366    pub fn uncompressed_size(&self) -> usize {
367        if let Validity::Array(a) = self {
368            a.len().div_ceil(8)
369        } else {
370            0
371        }
372    }
373}
374
375impl PartialEq for Validity {
376    #[inline]
377    fn eq(&self, other: &Self) -> bool {
378        match (self, other) {
379            (Self::NonNullable, Self::NonNullable) => true,
380            (Self::AllValid, Self::AllValid) => true,
381            (Self::AllInvalid, Self::AllInvalid) => true,
382            (Self::Array(a), Self::Array(b)) => {
383                let a = a.to_bool();
384                let b = b.to_bool();
385                a.boolean_buffer() == b.boolean_buffer()
386            }
387            _ => false,
388        }
389    }
390}
391
392impl From<BooleanBuffer> for Validity {
393    #[inline]
394    fn from(value: BooleanBuffer) -> Self {
395        if value.count_set_bits() == value.len() {
396            Self::AllValid
397        } else if value.count_set_bits() == 0 {
398            Self::AllInvalid
399        } else {
400            Self::Array(BoolArray::from(value).into_array())
401        }
402    }
403}
404
405impl From<NullBuffer> for Validity {
406    #[inline]
407    fn from(value: NullBuffer) -> Self {
408        value.into_inner().into()
409    }
410}
411
412impl FromIterator<Mask> for Validity {
413    #[inline]
414    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
415        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
416    }
417}
418
419impl FromIterator<bool> for Validity {
420    #[inline]
421    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
422        Validity::from(BooleanBuffer::from_iter(iter))
423    }
424}
425
426impl From<Nullability> for Validity {
427    #[inline]
428    fn from(value: Nullability) -> Self {
429        match value {
430            Nullability::NonNullable => Validity::NonNullable,
431            Nullability::Nullable => Validity::AllValid,
432        }
433    }
434}
435
436impl Validity {
437    pub fn from_null_buffer(buffer: Option<NullBuffer>, nullability: Nullability) -> Self {
438        match buffer {
439            // If there are no nulls, then we infer from nullability
440            None => nullability.into(),
441            Some(nulls) => {
442                if nulls.null_count() == nulls.len() {
443                    Validity::AllInvalid
444                } else {
445                    Validity::Array(BoolArray::from(nulls.into_inner()).into_array())
446                }
447            }
448        }
449    }
450
451    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
452        assert!(
453            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
454            "NonNullable validity must be AllValid",
455        );
456        match mask {
457            Mask::AllTrue(_) => match nullability {
458                Nullability::NonNullable => Validity::NonNullable,
459                Nullability::Nullable => Validity::AllValid,
460            },
461            Mask::AllFalse(_) => Validity::AllInvalid,
462            Mask::Values(values) => Validity::Array(values.into_array()),
463        }
464    }
465}
466
467impl IntoArray for Mask {
468    #[inline]
469    fn into_array(self) -> ArrayRef {
470        match self {
471            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
472            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
473            Self::Values(a) => a.into_array(),
474        }
475    }
476}
477
478impl IntoArray for &MaskValues {
479    #[inline]
480    fn into_array(self) -> ArrayRef {
481        BoolArray::from_bool_buffer(self.boolean_buffer().clone(), Validity::NonNullable)
482            .into_array()
483    }
484}
485
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Shorthand for an explicit [`Validity::Array`] built from literal booleans.
    fn explicit<const N: usize>(bits: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(bits).into_array())
    }

    // Every case patches positions 2 and 4 of a length-5 validity.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, explicit([true, true, false, true, false]))]
    #[case(Validity::AllValid, 5, &[2, 4], explicit([true, false]), explicit([true, true, true, true, false]))]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, explicit([false, false, true, false, true]))]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], explicit([true, false]), explicit([false, false, true, false, false]))]
    #[case(explicit([false, true, false, true, false]), 5, &[2, 4], Validity::AllValid, explicit([false, true, true, true, true]))]
    #[case(explicit([false, true, false, true, false]), 5, &[2, 4], Validity::AllInvalid, explicit([false, true, false, true, false]))]
    #[case(explicit([false, true, false, true, false]), 5, &[2, 4], explicit([true, false]), explicit([false, true, true, true, false]))]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let index_values = Buffer::copy_from(positions);
        let indices = PrimitiveArray::new(index_values, Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches);
        assert_eq!(patched, expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable.patch(2, 0, &indices, &Validity::AllInvalid);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let _ = Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        let _ = Validity::from_mask(mask, Nullability::NonNullable);
    }

    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter([true, false])).into_array(), Validity::from_iter([true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter([true, false])).into_array(), Validity::from_iter([true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}