vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
38    #[inline]
39    pub fn into_array(self) -> Option<ArrayRef> {
40        if let Self::Array(a) = self {
41            Some(a)
42        } else {
43            None
44        }
45    }
46
47    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
48    #[inline]
49    pub fn as_array(&self) -> Option<&ArrayRef> {
50        if let Self::Array(a) = self {
51            Some(a)
52        } else {
53            None
54        }
55    }
56
57    #[inline]
58    pub fn nullability(&self) -> Nullability {
59        if matches!(self, Self::NonNullable) {
60            Nullability::NonNullable
61        } else {
62            Nullability::Nullable
63        }
64    }
65
66    /// The union nullability and validity.
67    #[inline]
68    pub fn union_nullability(self, nullability: Nullability) -> Self {
69        match nullability {
70            Nullability::NonNullable => self,
71            Nullability::Nullable => self.into_nullable(),
72        }
73    }
74
75    #[inline]
76    pub fn all_valid(&self) -> bool {
77        match self {
78            Validity::NonNullable | Validity::AllValid => true,
79            Validity::AllInvalid => false,
80            Validity::Array(array) => {
81                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
82                    .vortex_expect("sum must be a usize")
83                    == array.len()
84            }
85        }
86    }
87
88    #[inline]
89    pub fn all_invalid(&self) -> bool {
90        match self {
91            Validity::NonNullable | Validity::AllValid => false,
92            Validity::AllInvalid => true,
93            Validity::Array(array) => {
94                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
95                    .vortex_expect("sum must be a usize")
96                    == 0
97            }
98        }
99    }
100
101    /// Returns whether the `index` item is valid.
102    #[inline]
103    pub fn is_valid(&self, index: usize) -> bool {
104        match self {
105            Self::NonNullable | Self::AllValid => true,
106            Self::AllInvalid => false,
107            Self::Array(a) => {
108                let scalar = a.scalar_at(index);
109                scalar
110                    .as_bool()
111                    .value()
112                    .vortex_expect("Validity must be non-nullable")
113            }
114        }
115    }
116
117    #[inline]
118    pub fn is_null(&self, index: usize) -> bool {
119        !self.is_valid(index)
120    }
121
122    #[inline]
123    pub fn slice(&self, range: Range<usize>) -> Self {
124        match self {
125            Self::Array(a) => Self::Array(a.slice(range)),
126            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
127        }
128    }
129
130    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
131        match self {
132            Self::NonNullable => match indices.validity_mask().boolean_buffer() {
133                AllOr::All => {
134                    if indices.dtype().is_nullable() {
135                        Ok(Self::AllValid)
136                    } else {
137                        Ok(Self::NonNullable)
138                    }
139                }
140                AllOr::None => Ok(Self::AllInvalid),
141                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
142            },
143            Self::AllValid => match indices.validity_mask().boolean_buffer() {
144                AllOr::All => Ok(Self::AllValid),
145                AllOr::None => Ok(Self::AllInvalid),
146                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
147            },
148            Self::AllInvalid => Ok(Self::AllInvalid),
149            Self::Array(is_valid) => {
150                let maybe_is_valid = take(is_valid, indices)?;
151                // Null indices invalidate that position.
152                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
153                Ok(Self::Array(is_valid))
154            }
155        }
156    }
157
158    /// Keep only the entries for which the mask is true.
159    ///
160    /// The result has length equal to the number of true values in mask.
161    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
162        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
163        //  if we happen to be NonNullable, AllValid, or AllInvalid.
164        match self {
165            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
166                Ok(v.clone())
167            }
168            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
169        }
170    }
171
172    /// Set to false any entries for which the mask is true.
173    ///
174    /// The result is always nullable. The result has the same length as self.
175    #[inline]
176    pub fn mask(&self, mask: &Mask) -> Self {
177        match mask.boolean_buffer() {
178            AllOr::All => Validity::AllInvalid,
179            AllOr::None => self.clone(),
180            AllOr::Some(make_invalid) => match self {
181                Validity::NonNullable | Validity::AllValid => {
182                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
183                }
184                Validity::AllInvalid => Validity::AllInvalid,
185                Validity::Array(is_valid) => {
186                    let is_valid = is_valid.to_bool();
187                    let keep_valid = make_invalid.not();
188                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
189                }
190            },
191        }
192    }
193
194    #[inline]
195    pub fn to_mask(&self, length: usize) -> Mask {
196        match self {
197            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
198            Self::AllInvalid => Mask::AllFalse(length),
199            Self::Array(is_valid) => {
200                assert_eq!(
201                    is_valid.len(),
202                    length,
203                    "Validity::Array length must equal to_logical's argument: {}, {}.",
204                    is_valid.len(),
205                    length,
206                );
207                is_valid.to_bool().to_mask()
208            }
209        }
210    }
211
212    /// Logically & two Validity values of the same length
213    #[inline]
214    pub fn and(self, rhs: Validity) -> Validity {
215        match (self, rhs) {
216            // Should be pretty clear
217            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
218            // Any `AllInvalid` makes the output all invalid values
219            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
220            // All truthy values on one side, which makes no effect on an `Array` variant
221            (Validity::Array(a), Validity::AllValid)
222            | (Validity::Array(a), Validity::NonNullable)
223            | (Validity::NonNullable, Validity::Array(a))
224            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
225            // Both sides are all valid
226            (Validity::NonNullable, Validity::AllValid)
227            | (Validity::AllValid, Validity::NonNullable)
228            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
229            // Here we actually have to do some work
230            (Validity::Array(lhs), Validity::Array(rhs)) => {
231                let lhs = lhs.to_bool();
232                let rhs = rhs.to_bool();
233
234                let lhs = lhs.boolean_buffer();
235                let rhs = rhs.boolean_buffer();
236
237                Validity::from(lhs.bitand(rhs))
238            }
239        }
240    }
241
242    pub fn patch(
243        self,
244        len: usize,
245        indices_offset: usize,
246        indices: &dyn Array,
247        patches: &Validity,
248    ) -> Self {
249        match (&self, patches) {
250            (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
251            (Validity::NonNullable, _) => {
252                vortex_panic!("Can't patch a non-nullable validity with nullable validity")
253            }
254            (_, Validity::NonNullable) => {
255                vortex_panic!("Can't patch a nullable validity with non-nullable validity")
256            }
257            (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
258            (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
259            _ => {}
260        };
261
262        let own_nullability = if self == Validity::NonNullable {
263            Nullability::NonNullable
264        } else {
265            Nullability::Nullable
266        };
267
268        let source = match self {
269            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
270            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
271            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
272            Validity::Array(a) => a.to_bool(),
273        };
274
275        let patch_values = match patches {
276            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
277            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
278            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
279            Validity::Array(a) => a.to_bool(),
280        };
281
282        let patches = Patches::new(
283            len,
284            indices_offset,
285            indices.to_array(),
286            patch_values.into_array(),
287        );
288
289        Self::from_array(source.patch(&patches).into_array(), own_nullability)
290    }
291
292    /// Convert into a nullable variant
293    #[inline]
294    pub fn into_nullable(self) -> Validity {
295        match self {
296            Self::NonNullable => Self::AllValid,
297            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
298        }
299    }
300
301    /// Convert into a non-nullable variant
302    #[inline]
303    pub fn into_non_nullable(self) -> Option<Validity> {
304        match self {
305            Self::NonNullable => Some(Self::NonNullable),
306            Self::AllValid => Some(Self::NonNullable),
307            Self::AllInvalid => None,
308            Self::Array(is_valid) => {
309                is_valid
310                    .statistics()
311                    .compute_min::<bool>()
312                    .vortex_expect("validity array must support min")
313                    .then(|| {
314                        // min true => all true
315                        Self::NonNullable
316                    })
317            }
318        }
319    }
320
321    /// Convert into a variant compatible with the given nullability, if possible.
322    #[inline]
323    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
324        match nullability {
325            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
326                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
327            }),
328            Nullability::Nullable => Ok(self.into_nullable()),
329        }
330    }
331
332    /// Create Validity by copying the given array's validity.
333    #[inline]
334    pub fn copy_from_array(array: &dyn Array) -> Self {
335        Validity::from_mask(array.validity_mask(), array.dtype().nullability())
336    }
337
338    /// Create Validity from boolean array with given nullability of the array.
339    ///
340    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
341    ///     as that is always nonnullable
342    #[inline]
343    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
344        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
345            vortex_panic!("Expected a non-nullable boolean array")
346        }
347        match nullability {
348            Nullability::NonNullable => Self::NonNullable,
349            Nullability::Nullable => Self::Array(value),
350        }
351    }
352
353    /// Returns the length of the validity array, if it exists.
354    #[inline]
355    pub fn maybe_len(&self) -> Option<usize> {
356        match self {
357            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
358            Self::Array(a) => Some(a.len()),
359        }
360    }
361
362    #[inline]
363    pub fn uncompressed_size(&self) -> usize {
364        if let Validity::Array(a) = self {
365            a.len().div_ceil(8)
366        } else {
367            0
368        }
369    }
370}
371
372impl PartialEq for Validity {
373    #[inline]
374    fn eq(&self, other: &Self) -> bool {
375        match (self, other) {
376            (Self::NonNullable, Self::NonNullable) => true,
377            (Self::AllValid, Self::AllValid) => true,
378            (Self::AllInvalid, Self::AllInvalid) => true,
379            (Self::Array(a), Self::Array(b)) => {
380                let a = a.to_bool();
381                let b = b.to_bool();
382                a.boolean_buffer() == b.boolean_buffer()
383            }
384            _ => false,
385        }
386    }
387}
388
389impl From<BooleanBuffer> for Validity {
390    #[inline]
391    fn from(value: BooleanBuffer) -> Self {
392        if value.count_set_bits() == value.len() {
393            Self::AllValid
394        } else if value.count_set_bits() == 0 {
395            Self::AllInvalid
396        } else {
397            Self::Array(BoolArray::from(value).into_array())
398        }
399    }
400}
401
402impl From<NullBuffer> for Validity {
403    #[inline]
404    fn from(value: NullBuffer) -> Self {
405        value.into_inner().into()
406    }
407}
408
409impl FromIterator<Mask> for Validity {
410    #[inline]
411    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
412        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
413    }
414}
415
416impl FromIterator<bool> for Validity {
417    #[inline]
418    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
419        Validity::from(BooleanBuffer::from_iter(iter))
420    }
421}
422
423impl From<Nullability> for Validity {
424    #[inline]
425    fn from(value: Nullability) -> Self {
426        match value {
427            Nullability::NonNullable => Validity::NonNullable,
428            Nullability::Nullable => Validity::AllValid,
429        }
430    }
431}
432
433impl Validity {
434    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
435        assert!(
436            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
437            "NonNullable validity must be AllValid",
438        );
439        match mask {
440            Mask::AllTrue(_) => match nullability {
441                Nullability::NonNullable => Validity::NonNullable,
442                Nullability::Nullable => Validity::AllValid,
443            },
444            Mask::AllFalse(_) => Validity::AllInvalid,
445            Mask::Values(values) => Validity::Array(values.into_array()),
446        }
447    }
448}
449
450impl IntoArray for Mask {
451    #[inline]
452    fn into_array(self) -> ArrayRef {
453        match self {
454            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
455            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
456            Self::Values(a) => a.into_array(),
457        }
458    }
459}
460
461impl IntoArray for &MaskValues {
462    #[inline]
463    fn into_array(self) -> ArrayRef {
464        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
465    }
466}
467
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid, 5, &[2, 4], Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid, 5, &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid, 5, &[2, 4], Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid, 5, &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5, &[2, 4], Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5, &[2, 4], Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5, &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        // Both sides are nullable and the index is unsigned, so the nullability check passes.
        // The panic must then come from index 4 being outside an array of length 2.
        Validity::AllValid.patch(2, 0, &buffer![4u64].into_array(), &Validity::AllInvalid);
    }

    #[test]
    #[should_panic]
    fn patch_non_nullable_with_nullable() {
        Validity::NonNullable.patch(2, 0, &buffer![1u64].into_array(), &Validity::AllInvalid);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}