vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// The validity of each position in the array is determined by a boolean array.
30    ///
31    /// True values are valid, false values are invalid ("null").
32    Array(ArrayRef),
33}
34
35impl Validity {
36    /// The [`DType`] of the underlying validity array (if it exists).
37    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
38
39    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
40    #[inline]
41    pub fn into_array(self) -> Option<ArrayRef> {
42        if let Self::Array(a) = self {
43            Some(a)
44        } else {
45            None
46        }
47    }
48
49    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
50    #[inline]
51    pub fn as_array(&self) -> Option<&ArrayRef> {
52        if let Self::Array(a) = self {
53            Some(a)
54        } else {
55            None
56        }
57    }
58
59    #[inline]
60    pub fn nullability(&self) -> Nullability {
61        if matches!(self, Self::NonNullable) {
62            Nullability::NonNullable
63        } else {
64            Nullability::Nullable
65        }
66    }
67
68    /// The union nullability and validity.
69    #[inline]
70    pub fn union_nullability(self, nullability: Nullability) -> Self {
71        match nullability {
72            Nullability::NonNullable => self,
73            Nullability::Nullable => self.into_nullable(),
74        }
75    }
76
77    #[inline]
78    pub fn all_valid(&self, len: usize) -> bool {
79        match self {
80            _ if len == 0 => true,
81            Validity::NonNullable | Validity::AllValid => true,
82            Validity::AllInvalid => false,
83            Validity::Array(array) => {
84                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
85                    .vortex_expect("sum must be a usize")
86                    == array.len()
87            }
88        }
89    }
90
91    #[inline]
92    pub fn all_invalid(&self, len: usize) -> bool {
93        match self {
94            _ if len == 0 => true,
95            Validity::NonNullable | Validity::AllValid => false,
96            Validity::AllInvalid => true,
97            Validity::Array(array) => {
98                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
99                    .vortex_expect("sum must be a usize")
100                    == 0
101            }
102        }
103    }
104
105    /// Returns whether the `index` item is valid.
106    #[inline]
107    pub fn is_valid(&self, index: usize) -> bool {
108        match self {
109            Self::NonNullable | Self::AllValid => true,
110            Self::AllInvalid => false,
111            Self::Array(a) => {
112                let scalar = a.scalar_at(index);
113                scalar
114                    .as_bool()
115                    .value()
116                    .vortex_expect("Validity must be non-nullable")
117            }
118        }
119    }
120
121    #[inline]
122    pub fn is_null(&self, index: usize) -> bool {
123        !self.is_valid(index)
124    }
125
126    #[inline]
127    pub fn slice(&self, range: Range<usize>) -> Self {
128        match self {
129            Self::Array(a) => Self::Array(a.slice(range)),
130            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
131        }
132    }
133
134    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
135        match self {
136            Self::NonNullable => match indices.validity_mask().boolean_buffer() {
137                AllOr::All => {
138                    if indices.dtype().is_nullable() {
139                        Ok(Self::AllValid)
140                    } else {
141                        Ok(Self::NonNullable)
142                    }
143                }
144                AllOr::None => Ok(Self::AllInvalid),
145                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
146            },
147            Self::AllValid => match indices.validity_mask().boolean_buffer() {
148                AllOr::All => Ok(Self::AllValid),
149                AllOr::None => Ok(Self::AllInvalid),
150                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
151            },
152            Self::AllInvalid => Ok(Self::AllInvalid),
153            Self::Array(is_valid) => {
154                let maybe_is_valid = take(is_valid, indices)?;
155                // Null indices invalidate that position.
156                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
157                Ok(Self::Array(is_valid))
158            }
159        }
160    }
161
162    /// Keep only the entries for which the mask is true.
163    ///
164    /// The result has length equal to the number of true values in mask.
165    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
166        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
167        //  if we happen to be NonNullable, AllValid, or AllInvalid.
168        match self {
169            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
170                Ok(v.clone())
171            }
172            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
173        }
174    }
175
176    /// Set to false any entries for which the mask is true.
177    ///
178    /// The result is always nullable. The result has the same length as self.
179    #[inline]
180    pub fn mask(&self, mask: &Mask) -> Self {
181        match mask.boolean_buffer() {
182            AllOr::All => Validity::AllInvalid,
183            AllOr::None => self.clone(),
184            AllOr::Some(make_invalid) => match self {
185                Validity::NonNullable | Validity::AllValid => {
186                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
187                }
188                Validity::AllInvalid => Validity::AllInvalid,
189                Validity::Array(is_valid) => {
190                    let is_valid = is_valid.to_bool();
191                    let keep_valid = make_invalid.not();
192                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
193                }
194            },
195        }
196    }
197
198    #[inline]
199    pub fn to_mask(&self, length: usize) -> Mask {
200        match self {
201            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
202            Self::AllInvalid => Mask::AllFalse(length),
203            Self::Array(is_valid) => {
204                assert_eq!(
205                    is_valid.len(),
206                    length,
207                    "Validity::Array length must equal to_logical's argument: {}, {}.",
208                    is_valid.len(),
209                    length,
210                );
211                is_valid.to_bool().to_mask()
212            }
213        }
214    }
215
216    /// Logically & two Validity values of the same length
217    #[inline]
218    pub fn and(self, rhs: Validity) -> Validity {
219        match (self, rhs) {
220            // Should be pretty clear
221            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
222            // Any `AllInvalid` makes the output all invalid values
223            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
224            // All truthy values on one side, which makes no effect on an `Array` variant
225            (Validity::Array(a), Validity::AllValid)
226            | (Validity::Array(a), Validity::NonNullable)
227            | (Validity::NonNullable, Validity::Array(a))
228            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
229            // Both sides are all valid
230            (Validity::NonNullable, Validity::AllValid)
231            | (Validity::AllValid, Validity::NonNullable)
232            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
233            // Here we actually have to do some work
234            (Validity::Array(lhs), Validity::Array(rhs)) => {
235                let lhs = lhs.to_bool();
236                let rhs = rhs.to_bool();
237
238                let lhs = lhs.boolean_buffer();
239                let rhs = rhs.boolean_buffer();
240
241                Validity::from(lhs.bitand(rhs))
242            }
243        }
244    }
245
246    pub fn patch(
247        self,
248        len: usize,
249        indices_offset: usize,
250        indices: &dyn Array,
251        patches: &Validity,
252    ) -> Self {
253        match (&self, patches) {
254            (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
255            (Validity::NonNullable, _) => {
256                vortex_panic!("Can't patch a non-nullable validity with nullable validity")
257            }
258            (_, Validity::NonNullable) => {
259                vortex_panic!("Can't patch a nullable validity with non-nullable validity")
260            }
261            (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
262            (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
263            _ => {}
264        };
265
266        let own_nullability = if self == Validity::NonNullable {
267            Nullability::NonNullable
268        } else {
269            Nullability::Nullable
270        };
271
272        let source = match self {
273            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
274            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
275            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
276            Validity::Array(a) => a.to_bool(),
277        };
278
279        let patch_values = match patches {
280            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
281            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
282            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
283            Validity::Array(a) => a.to_bool(),
284        };
285
286        let patches = Patches::new(
287            len,
288            indices_offset,
289            indices.to_array(),
290            patch_values.into_array(),
291            // TODO(0ax1): chunk offsets
292            None,
293        );
294
295        Self::from_array(source.patch(&patches).into_array(), own_nullability)
296    }
297
298    /// Convert into a nullable variant
299    #[inline]
300    pub fn into_nullable(self) -> Validity {
301        match self {
302            Self::NonNullable => Self::AllValid,
303            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
304        }
305    }
306
307    /// Convert into a non-nullable variant
308    #[inline]
309    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
310        match self {
311            _ if len == 0 => Some(Validity::NonNullable),
312            Self::NonNullable => Some(Self::NonNullable),
313            Self::AllValid => Some(Self::NonNullable),
314            Self::AllInvalid => None,
315            Self::Array(is_valid) => {
316                is_valid
317                    .statistics()
318                    .compute_min::<bool>()
319                    .vortex_expect("validity array must support min")
320                    .then(|| {
321                        // min true => all true
322                        Self::NonNullable
323                    })
324            }
325        }
326    }
327
328    /// Convert into a variant compatible with the given nullability, if possible.
329    #[inline]
330    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
331        match nullability {
332            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
333                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
334            }),
335            Nullability::Nullable => Ok(self.into_nullable()),
336        }
337    }
338
339    /// Create Validity by copying the given array's validity.
340    #[inline]
341    pub fn copy_from_array(array: &dyn Array) -> Self {
342        Validity::from_mask(array.validity_mask(), array.dtype().nullability())
343    }
344
345    /// Create Validity from boolean array with given nullability of the array.
346    ///
347    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
348    ///     as that is always non-nullable
349    #[inline]
350    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
351        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
352            vortex_panic!("Expected a non-nullable boolean array")
353        }
354        match nullability {
355            Nullability::NonNullable => Self::NonNullable,
356            Nullability::Nullable => Self::Array(value),
357        }
358    }
359
360    /// Returns the length of the validity array, if it exists.
361    #[inline]
362    pub fn maybe_len(&self) -> Option<usize> {
363        match self {
364            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
365            Self::Array(a) => Some(a.len()),
366        }
367    }
368
369    #[inline]
370    pub fn uncompressed_size(&self) -> usize {
371        if let Validity::Array(a) = self {
372            a.len().div_ceil(8)
373        } else {
374            0
375        }
376    }
377}
378
379impl PartialEq for Validity {
380    #[inline]
381    fn eq(&self, other: &Self) -> bool {
382        match (self, other) {
383            (Self::NonNullable, Self::NonNullable) => true,
384            (Self::AllValid, Self::AllValid) => true,
385            (Self::AllInvalid, Self::AllInvalid) => true,
386            (Self::Array(a), Self::Array(b)) => {
387                let a = a.to_bool();
388                let b = b.to_bool();
389                a.boolean_buffer() == b.boolean_buffer()
390            }
391            _ => false,
392        }
393    }
394}
395
396impl From<BooleanBuffer> for Validity {
397    #[inline]
398    fn from(value: BooleanBuffer) -> Self {
399        if value.count_set_bits() == value.len() {
400            Self::AllValid
401        } else if value.count_set_bits() == 0 {
402            Self::AllInvalid
403        } else {
404            Self::Array(BoolArray::from(value).into_array())
405        }
406    }
407}
408
409impl From<NullBuffer> for Validity {
410    #[inline]
411    fn from(value: NullBuffer) -> Self {
412        value.into_inner().into()
413    }
414}
415
416impl FromIterator<Mask> for Validity {
417    #[inline]
418    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
419        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
420    }
421}
422
423impl FromIterator<bool> for Validity {
424    #[inline]
425    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
426        Validity::from(BooleanBuffer::from_iter(iter))
427    }
428}
429
430impl From<Nullability> for Validity {
431    #[inline]
432    fn from(value: Nullability) -> Self {
433        match value {
434            Nullability::NonNullable => Validity::NonNullable,
435            Nullability::Nullable => Validity::AllValid,
436        }
437    }
438}
439
440impl Validity {
441    pub fn from_null_buffer(buffer: Option<NullBuffer>, nullability: Nullability) -> Self {
442        match buffer {
443            // If there are no nulls, then we infer from nullability
444            None => nullability.into(),
445            Some(nulls) => {
446                if nulls.null_count() == nulls.len() {
447                    Validity::AllInvalid
448                } else {
449                    Validity::Array(BoolArray::from(nulls.into_inner()).into_array())
450                }
451            }
452        }
453    }
454
455    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
456        assert!(
457            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
458            "NonNullable validity must be AllValid",
459        );
460        match mask {
461            Mask::AllTrue(_) => match nullability {
462                Nullability::NonNullable => Validity::NonNullable,
463                Nullability::Nullable => Validity::AllValid,
464            },
465            Mask::AllFalse(_) => Validity::AllInvalid,
466            Mask::Values(values) => Validity::Array(values.into_array()),
467        }
468    }
469}
470
471impl IntoArray for Mask {
472    #[inline]
473    fn into_array(self) -> ArrayRef {
474        match self {
475            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
476            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
477            Self::Values(a) => a.into_array(),
478        }
479    }
480}
481
482impl IntoArray for &MaskValues {
483    #[inline]
484    fn into_array(self) -> ArrayRef {
485        BoolArray::from_bool_buffer(self.boolean_buffer().clone(), Validity::NonNullable)
486            .into_array()
487    }
488}
489
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        // Both sides are nullable, so the nullability checks in `patch` pass and the panic has
        // to come from index 4 lying outside an array of length 2. (A `NonNullable` receiver
        // would panic on the nullability mismatch instead, masking the out-of-bounds case.)
        Validity::AllValid.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    /// A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    /// A non-nullable validity cannot be built from a mask that contains any `false`.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}