vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38        match self {
39            Self::NonNullable | Self::AllValid => Ok(0),
40            Self::AllInvalid => Ok(length),
41            Self::Array(a) => {
42                let validity_len = a.len();
43                if validity_len != length {
44                    vortex_bail!(
45                        "Validity array length {} doesn't match array length {}",
46                        validity_len,
47                        length
48                    )
49                }
50                let true_count = sum(a)?
51                    .as_primitive()
52                    .as_::<usize>()
53                    .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54                Ok(length - true_count)
55            }
56        }
57    }
58
59    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
60    pub fn into_array(self) -> Option<ArrayRef> {
61        if let Self::Array(a) = self {
62            Some(a)
63        } else {
64            None
65        }
66    }
67
68    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
69    pub fn as_array(&self) -> Option<&ArrayRef> {
70        if let Self::Array(a) = self {
71            Some(a)
72        } else {
73            None
74        }
75    }
76
77    pub fn nullability(&self) -> Nullability {
78        if matches!(self, Self::NonNullable) {
79            Nullability::NonNullable
80        } else {
81            Nullability::Nullable
82        }
83    }
84
85    /// The union nullability and validity.
86    pub fn union_nullability(self, nullability: Nullability) -> Self {
87        match nullability {
88            Nullability::NonNullable => self,
89            Nullability::Nullable => self.into_nullable(),
90        }
91    }
92
93    pub fn all_valid(&self) -> bool {
94        match self {
95            Validity::NonNullable | Validity::AllValid => true,
96            Validity::AllInvalid => false,
97            Validity::Array(array) => {
98                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
99                    .vortex_expect("sum must be a usize")
100                    == array.len()
101            }
102        }
103    }
104
105    pub fn all_invalid(&self) -> bool {
106        match self {
107            Validity::NonNullable | Validity::AllValid => false,
108            Validity::AllInvalid => true,
109            Validity::Array(array) => {
110                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
111                    .vortex_expect("sum must be a usize")
112                    == 0
113            }
114        }
115    }
116
117    /// Returns whether the `index` item is valid.
118    #[inline]
119    pub fn is_valid(&self, index: usize) -> bool {
120        match self {
121            Self::NonNullable | Self::AllValid => true,
122            Self::AllInvalid => false,
123            Self::Array(a) => {
124                let scalar = a.scalar_at(index);
125                scalar
126                    .as_bool()
127                    .value()
128                    .vortex_expect("Validity must be non-nullable")
129            }
130        }
131    }
132
133    #[inline]
134    pub fn is_null(&self, index: usize) -> bool {
135        !self.is_valid(index)
136    }
137
138    pub fn slice(&self, range: Range<usize>) -> Self {
139        match self {
140            Self::Array(a) => Self::Array(a.slice(range)),
141            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
142        }
143    }
144
145    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
146        match self {
147            Self::NonNullable => match indices.validity_mask().boolean_buffer() {
148                AllOr::All => {
149                    if indices.dtype().is_nullable() {
150                        Ok(Self::AllValid)
151                    } else {
152                        Ok(Self::NonNullable)
153                    }
154                }
155                AllOr::None => Ok(Self::AllInvalid),
156                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
157            },
158            Self::AllValid => match indices.validity_mask().boolean_buffer() {
159                AllOr::All => Ok(Self::AllValid),
160                AllOr::None => Ok(Self::AllInvalid),
161                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
162            },
163            Self::AllInvalid => Ok(Self::AllInvalid),
164            Self::Array(is_valid) => {
165                let maybe_is_valid = take(is_valid, indices)?;
166                // Null indices invalidate that position.
167                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
168                Ok(Self::Array(is_valid))
169            }
170        }
171    }
172
173    /// Keep only the entries for which the mask is true.
174    ///
175    /// The result has length equal to the number of true values in mask.
176    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
177        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
178        //  if we happen to be NonNullable, AllValid, or AllInvalid.
179        match self {
180            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
181                Ok(v.clone())
182            }
183            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
184        }
185    }
186
187    /// Set to false any entries for which the mask is true.
188    ///
189    /// The result is always nullable. The result has the same length as self.
190    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
191        match mask.boolean_buffer() {
192            AllOr::All => Ok(Validity::AllInvalid),
193            AllOr::None => Ok(self.clone()),
194            AllOr::Some(make_invalid) => Ok(match self {
195                Validity::NonNullable | Validity::AllValid => {
196                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
197                }
198                Validity::AllInvalid => Validity::AllInvalid,
199                Validity::Array(is_valid) => {
200                    let is_valid = is_valid.to_bool()?;
201                    let keep_valid = make_invalid.not();
202                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
203                }
204            }),
205        }
206    }
207
208    pub fn to_mask(&self, length: usize) -> Mask {
209        match self {
210            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
211            Self::AllInvalid => Mask::AllFalse(length),
212            Self::Array(is_valid) => {
213                assert_eq!(
214                    is_valid.len(),
215                    length,
216                    "Validity::Array length must equal to_logical's argument: {}, {}.",
217                    is_valid.len(),
218                    length,
219                );
220                Mask::from(
221                    &is_valid
222                        .to_bool()
223                        .vortex_expect("validity array must be bool"),
224                )
225            }
226        }
227    }
228
229    /// Logically & two Validity values of the same length
230    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
231        let validity = match (self, rhs) {
232            // Should be pretty clear
233            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
234            // Any `AllInvalid` makes the output all invalid values
235            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
236            // All truthy values on one side, which makes no effect on an `Array` variant
237            (Validity::Array(a), Validity::AllValid)
238            | (Validity::Array(a), Validity::NonNullable)
239            | (Validity::NonNullable, Validity::Array(a))
240            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
241            // Both sides are all valid
242            (Validity::NonNullable, Validity::AllValid)
243            | (Validity::AllValid, Validity::NonNullable)
244            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
245            // Here we actually have to do some work
246            (Validity::Array(lhs), Validity::Array(rhs)) => {
247                let lhs = lhs.to_bool()?;
248                let rhs = rhs.to_bool()?;
249
250                let lhs = lhs.boolean_buffer();
251                let rhs = rhs.boolean_buffer();
252
253                Validity::from(lhs.bitand(rhs))
254            }
255        };
256
257        Ok(validity)
258    }
259
260    pub fn patch(
261        self,
262        len: usize,
263        indices_offset: usize,
264        indices: &dyn Array,
265        patches: &Validity,
266    ) -> VortexResult<Self> {
267        match (&self, patches) {
268            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
269            (Validity::NonNullable, _) => {
270                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
271            }
272            (_, Validity::NonNullable) => {
273                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
274            }
275            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
276            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
277            _ => {}
278        };
279
280        let own_nullability = if self == Validity::NonNullable {
281            Nullability::NonNullable
282        } else {
283            Nullability::Nullable
284        };
285
286        let source = match self {
287            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
288            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
289            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
290            Validity::Array(a) => a.to_bool()?,
291        };
292
293        let patch_values = match patches {
294            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
295            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
296            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
297            Validity::Array(a) => a.to_bool()?,
298        };
299
300        let patches = Patches::new(
301            len,
302            indices_offset,
303            indices.to_array(),
304            patch_values.into_array(),
305        );
306
307        Ok(Self::from_array(
308            source.patch(&patches)?.into_array(),
309            own_nullability,
310        ))
311    }
312
313    /// Convert into a nullable variant
314    pub fn into_nullable(self) -> Validity {
315        match self {
316            Self::NonNullable => Self::AllValid,
317            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
318        }
319    }
320
321    /// Convert into a non-nullable variant
322    pub fn into_non_nullable(self) -> Option<Validity> {
323        match self {
324            Self::NonNullable => Some(Self::NonNullable),
325            Self::AllValid => Some(Self::NonNullable),
326            Self::AllInvalid => None,
327            Self::Array(is_valid) => {
328                is_valid
329                    .statistics()
330                    .compute_min::<bool>()
331                    .vortex_expect("validity array must support min")
332                    .then(|| {
333                        // min true => all true
334                        Self::NonNullable
335                    })
336            }
337        }
338    }
339
340    /// Convert into a variant compatible with the given nullability, if possible.
341    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
342        match nullability {
343            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
344                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
345            }),
346            Nullability::Nullable => Ok(self.into_nullable()),
347        }
348    }
349
350    /// Create Validity by copying the given array's validity.
351    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
352        Ok(Validity::from_mask(
353            array.validity_mask(),
354            array.dtype().nullability(),
355        ))
356    }
357
358    /// Create Validity from boolean array with given nullability of the array.
359    ///
360    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
361    ///     as that is always nonnullable
362    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
363        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
364            vortex_panic!("Expected a non-nullable boolean array")
365        }
366        match nullability {
367            Nullability::NonNullable => Self::NonNullable,
368            Nullability::Nullable => Self::Array(value),
369        }
370    }
371
372    /// Returns the length of the validity array, if it exists.
373    pub fn maybe_len(&self) -> Option<usize> {
374        match self {
375            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
376            Self::Array(a) => Some(a.len()),
377        }
378    }
379
380    pub fn uncompressed_size(&self) -> usize {
381        if let Validity::Array(a) = self {
382            a.len().div_ceil(8)
383        } else {
384            0
385        }
386    }
387
388    pub fn is_array(&self) -> bool {
389        matches!(self, Validity::Array(_))
390    }
391}
392
393impl PartialEq for Validity {
394    fn eq(&self, other: &Self) -> bool {
395        match (self, other) {
396            (Self::NonNullable, Self::NonNullable) => true,
397            (Self::AllValid, Self::AllValid) => true,
398            (Self::AllInvalid, Self::AllInvalid) => true,
399            (Self::Array(a), Self::Array(b)) => {
400                let a = a
401                    .to_bool()
402                    .vortex_expect("Failed to get Validity Array as BoolArray");
403                let b = b
404                    .to_bool()
405                    .vortex_expect("Failed to get Validity Array as BoolArray");
406                a.boolean_buffer() == b.boolean_buffer()
407            }
408            _ => false,
409        }
410    }
411}
412
413impl From<BooleanBuffer> for Validity {
414    fn from(value: BooleanBuffer) -> Self {
415        if value.count_set_bits() == value.len() {
416            Self::AllValid
417        } else if value.count_set_bits() == 0 {
418            Self::AllInvalid
419        } else {
420            Self::Array(BoolArray::from(value).into_array())
421        }
422    }
423}
424
425impl From<NullBuffer> for Validity {
426    fn from(value: NullBuffer) -> Self {
427        value.into_inner().into()
428    }
429}
430
431impl FromIterator<Mask> for Validity {
432    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
433        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
434    }
435}
436
437impl FromIterator<bool> for Validity {
438    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
439        Validity::from(BooleanBuffer::from_iter(iter))
440    }
441}
442
443impl From<Nullability> for Validity {
444    fn from(value: Nullability) -> Self {
445        match value {
446            Nullability::NonNullable => Validity::NonNullable,
447            Nullability::Nullable => Validity::AllValid,
448        }
449    }
450}
451
452impl Validity {
453    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
454        assert!(
455            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
456            "NonNullable validity must be AllValid",
457        );
458        match mask {
459            Mask::AllTrue(_) => match nullability {
460                Nullability::NonNullable => Validity::NonNullable,
461                Nullability::Nullable => Validity::AllValid,
462            },
463            Mask::AllFalse(_) => Validity::AllInvalid,
464            Mask::Values(values) => Validity::Array(values.into_array()),
465        }
466    }
467}
468
469impl IntoArray for Mask {
470    fn into_array(self) -> ArrayRef {
471        match self {
472            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
473            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
474            Self::Values(a) => a.into_array(),
475        }
476    }
477}
478
479impl IntoArray for &MaskValues {
480    fn into_array(self) -> ArrayRef {
481        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
482    }
483}
484
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a [`Validity::Array`] from literal validity bits (`true` = valid).
    fn bits(values: &[bool]) -> Validity {
        Validity::Array(BoolArray::from_iter(values.iter().copied()).into_array())
    }

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        bits(&[true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        bits(&[true, false]),
        bits(&[true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        bits(&[false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        bits(&[true, false]),
        bits(&[false, false, true, false, false])
    )]
    #[case(
        bits(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        bits(&[false, true, true, true, true])
    )]
    #[case(
        bits(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        bits(&[false, true, false, true, false])
    )]
    #[case(
        bits(&[false, true, false, true, false]),
        5,
        &[2, 4],
        bits(&[true, false]),
        bits(&[false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(
            validity.patch(len, 0, &indices, &patches).unwrap(),
            expected
        );
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        // The source must be nullable: patching a `NonNullable` validity with a nullable one is
        // rejected before the index is ever looked at, which would make this test pass for the
        // wrong reason. Index 4 (unsigned, like real patch indices) is past `len == 2`.
        Validity::AllValid
            .patch(2, 0, &buffer![4u64].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}