vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38        match self {
39            Self::NonNullable | Self::AllValid => Ok(0),
40            Self::AllInvalid => Ok(length),
41            Self::Array(a) => {
42                let validity_len = a.len();
43                if validity_len != length {
44                    vortex_bail!(
45                        "Validity array length {} doesn't match array length {}",
46                        validity_len,
47                        length
48                    )
49                }
50                let true_count = sum(a)?
51                    .as_primitive()
52                    .as_::<usize>()
53                    .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54                Ok(length - true_count)
55            }
56        }
57    }
58
59    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
60    pub fn into_array(self) -> Option<ArrayRef> {
61        if let Self::Array(a) = self {
62            Some(a)
63        } else {
64            None
65        }
66    }
67
68    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
69    pub fn as_array(&self) -> Option<&ArrayRef> {
70        if let Self::Array(a) = self {
71            Some(a)
72        } else {
73            None
74        }
75    }
76
77    pub fn nullability(&self) -> Nullability {
78        if matches!(self, Self::NonNullable) {
79            Nullability::NonNullable
80        } else {
81            Nullability::Nullable
82        }
83    }
84
85    /// The union nullability and validity.
86    pub fn union_nullability(self, nullability: Nullability) -> Self {
87        match nullability {
88            Nullability::NonNullable => self,
89            Nullability::Nullable => self.into_nullable(),
90        }
91    }
92
93    pub fn all_valid(&self) -> VortexResult<bool> {
94        Ok(match self {
95            Validity::NonNullable | Validity::AllValid => true,
96            Validity::AllInvalid => false,
97            Validity::Array(array) => sum(array)
98                .map(|v| {
99                    v.as_primitive()
100                        .typed_value::<u64>()
101                        .map(|count| count == array.len() as u64)
102                })?
103                .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
104        })
105    }
106
107    pub fn all_invalid(&self) -> VortexResult<bool> {
108        Ok(match self {
109            Validity::NonNullable | Validity::AllValid => false,
110            Validity::AllInvalid => true,
111            Validity::Array(array) => sum(array)
112                .map(|v| {
113                    v.as_primitive()
114                        .typed_value::<u64>()
115                        .map(|count| count == 0u64)
116                })?
117                .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
118        })
119    }
120
121    /// Returns whether the `index` item is valid.
122    #[inline]
123    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
124        Ok(match self {
125            Self::NonNullable | Self::AllValid => true,
126            Self::AllInvalid => false,
127            Self::Array(a) => {
128                let scalar = a.scalar_at(index);
129                scalar
130                    .as_bool()
131                    .value()
132                    .vortex_expect("Validity must be non-nullable")
133            }
134        })
135    }
136
137    #[inline]
138    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
139        Ok(!self.is_valid(index)?)
140    }
141
142    pub fn slice(&self, start: usize, stop: usize) -> Self {
143        match self {
144            Self::Array(a) => Self::Array(a.slice(start, stop)),
145            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
146        }
147    }
148
149    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
150        match self {
151            Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
152                AllOr::All => {
153                    if indices.dtype().is_nullable() {
154                        Ok(Self::AllValid)
155                    } else {
156                        Ok(Self::NonNullable)
157                    }
158                }
159                AllOr::None => Ok(Self::AllInvalid),
160                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
161            },
162            Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
163                AllOr::All => Ok(Self::AllValid),
164                AllOr::None => Ok(Self::AllInvalid),
165                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
166            },
167            Self::AllInvalid => Ok(Self::AllInvalid),
168            Self::Array(is_valid) => {
169                let maybe_is_valid = take(is_valid, indices)?;
170                // Null indices invalidate that position.
171                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
172                Ok(Self::Array(is_valid))
173            }
174        }
175    }
176
177    /// Keep only the entries for which the mask is true.
178    ///
179    /// The result has length equal to the number of true values in mask.
180    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
181        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
182        //  if we happen to be NonNullable, AllValid, or AllInvalid.
183        match self {
184            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
185                Ok(v.clone())
186            }
187            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
188        }
189    }
190
191    /// Set to false any entries for which the mask is true.
192    ///
193    /// The result is always nullable. The result has the same length as self.
194    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
195        match mask.boolean_buffer() {
196            AllOr::All => Ok(Validity::AllInvalid),
197            AllOr::None => Ok(self.clone()),
198            AllOr::Some(make_invalid) => Ok(match self {
199                Validity::NonNullable | Validity::AllValid => {
200                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
201                }
202                Validity::AllInvalid => Validity::AllInvalid,
203                Validity::Array(is_valid) => {
204                    let is_valid = is_valid.to_bool()?;
205                    let keep_valid = make_invalid.not();
206                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
207                }
208            }),
209        }
210    }
211
212    pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
213        Ok(match self {
214            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
215            Self::AllInvalid => Mask::AllFalse(length),
216            Self::Array(is_valid) => {
217                assert_eq!(
218                    is_valid.len(),
219                    length,
220                    "Validity::Array length must equal to_logical's argument: {}, {}.",
221                    is_valid.len(),
222                    length,
223                );
224                Mask::try_from(&is_valid.to_bool()?)?
225            }
226        })
227    }
228
229    /// Logically & two Validity values of the same length
230    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
231        let validity = match (self, rhs) {
232            // Should be pretty clear
233            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
234            // Any `AllInvalid` makes the output all invalid values
235            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
236            // All truthy values on one side, which makes no effect on an `Array` variant
237            (Validity::Array(a), Validity::AllValid)
238            | (Validity::Array(a), Validity::NonNullable)
239            | (Validity::NonNullable, Validity::Array(a))
240            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
241            // Both sides are all valid
242            (Validity::NonNullable, Validity::AllValid)
243            | (Validity::AllValid, Validity::NonNullable)
244            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
245            // Here we actually have to do some work
246            (Validity::Array(lhs), Validity::Array(rhs)) => {
247                let lhs = lhs.to_bool()?;
248                let rhs = rhs.to_bool()?;
249
250                let lhs = lhs.boolean_buffer();
251                let rhs = rhs.boolean_buffer();
252
253                Validity::from(lhs.bitand(rhs))
254            }
255        };
256
257        Ok(validity)
258    }
259
260    pub fn patch(
261        self,
262        len: usize,
263        indices_offset: usize,
264        indices: &dyn Array,
265        patches: &Validity,
266    ) -> VortexResult<Self> {
267        match (&self, patches) {
268            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
269            (Validity::NonNullable, _) => {
270                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
271            }
272            (_, Validity::NonNullable) => {
273                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
274            }
275            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
276            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
277            _ => {}
278        };
279
280        let own_nullability = if self == Validity::NonNullable {
281            Nullability::NonNullable
282        } else {
283            Nullability::Nullable
284        };
285
286        let source = match self {
287            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
288            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
289            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
290            Validity::Array(a) => a.to_bool()?,
291        };
292
293        let patch_values = match patches {
294            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
295            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
296            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
297            Validity::Array(a) => a.to_bool()?,
298        };
299
300        let patches = Patches::new(
301            len,
302            indices_offset,
303            indices.to_array(),
304            patch_values.into_array(),
305        );
306
307        Ok(Self::from_array(
308            source.patch(&patches)?.into_array(),
309            own_nullability,
310        ))
311    }
312
313    /// Convert into a nullable variant
314    pub fn into_nullable(self) -> Validity {
315        match self {
316            Self::NonNullable => Self::AllValid,
317            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
318        }
319    }
320
321    /// Convert into a non-nullable variant
322    pub fn into_non_nullable(self) -> Option<Validity> {
323        match self {
324            Self::NonNullable => Some(Self::NonNullable),
325            Self::AllValid => Some(Self::NonNullable),
326            Self::AllInvalid => None,
327            Self::Array(is_valid) => {
328                is_valid
329                    .statistics()
330                    .compute_min::<bool>()
331                    .vortex_expect("validity array must support min")
332                    .then(|| {
333                        // min true => all true
334                        Self::NonNullable
335                    })
336            }
337        }
338    }
339
340    /// Convert into a variant compatible with the given nullability, if possible.
341    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
342        match nullability {
343            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
344                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
345            }),
346            Nullability::Nullable => Ok(self.into_nullable()),
347        }
348    }
349
350    /// Create Validity by copying the given array's validity.
351    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
352        Ok(Validity::from_mask(
353            array.validity_mask()?,
354            array.dtype().nullability(),
355        ))
356    }
357
358    /// Create Validity from boolean array with given nullability of the array.
359    ///
360    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
361    ///     as that is always nonnullable
362    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
363        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
364            vortex_panic!("Expected a non-nullable boolean array")
365        }
366        match nullability {
367            Nullability::NonNullable => Self::NonNullable,
368            Nullability::Nullable => Self::Array(value),
369        }
370    }
371
372    /// Returns the length of the validity array, if it exists.
373    pub fn maybe_len(&self) -> Option<usize> {
374        match self {
375            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
376            Self::Array(a) => Some(a.len()),
377        }
378    }
379
380    pub fn uncompressed_size(&self) -> usize {
381        if let Validity::Array(a) = self {
382            a.len().div_ceil(8)
383        } else {
384            0
385        }
386    }
387
388    pub fn is_array(&self) -> bool {
389        matches!(self, Validity::Array(_))
390    }
391}
392
393impl PartialEq for Validity {
394    fn eq(&self, other: &Self) -> bool {
395        match (self, other) {
396            (Self::NonNullable, Self::NonNullable) => true,
397            (Self::AllValid, Self::AllValid) => true,
398            (Self::AllInvalid, Self::AllInvalid) => true,
399            (Self::Array(a), Self::Array(b)) => {
400                let a = a
401                    .to_bool()
402                    .vortex_expect("Failed to get Validity Array as BoolArray");
403                let b = b
404                    .to_bool()
405                    .vortex_expect("Failed to get Validity Array as BoolArray");
406                a.boolean_buffer() == b.boolean_buffer()
407            }
408            _ => false,
409        }
410    }
411}
412
413impl From<BooleanBuffer> for Validity {
414    fn from(value: BooleanBuffer) -> Self {
415        if value.count_set_bits() == value.len() {
416            Self::AllValid
417        } else if value.count_set_bits() == 0 {
418            Self::AllInvalid
419        } else {
420            Self::Array(BoolArray::from(value).into_array())
421        }
422    }
423}
424
425impl From<NullBuffer> for Validity {
426    fn from(value: NullBuffer) -> Self {
427        value.into_inner().into()
428    }
429}
430
431impl FromIterator<Mask> for Validity {
432    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
433        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
434    }
435}
436
437impl FromIterator<bool> for Validity {
438    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
439        Validity::from(BooleanBuffer::from_iter(iter))
440    }
441}
442
443impl From<Nullability> for Validity {
444    fn from(value: Nullability) -> Self {
445        match value {
446            Nullability::NonNullable => Validity::NonNullable,
447            Nullability::Nullable => Validity::AllValid,
448        }
449    }
450}
451
452impl Validity {
453    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
454        assert!(
455            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
456            "NonNullable validity must be AllValid",
457        );
458        match mask {
459            Mask::AllTrue(_) => match nullability {
460                Nullability::NonNullable => Validity::NonNullable,
461                Nullability::Nullable => Validity::AllValid,
462            },
463            Mask::AllFalse(_) => Validity::AllInvalid,
464            Mask::Values(values) => Validity::Array(values.into_array()),
465        }
466    }
467}
468
469impl IntoArray for Mask {
470    fn into_array(self) -> ArrayRef {
471        match self {
472            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
473            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
474            Self::Values(a) => a.into_array(),
475        }
476    }
477}
478
479impl IntoArray for &MaskValues {
480    fn into_array(self) -> ArrayRef {
481        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
482    }
483}
484
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a [`Validity::Array`] directly from per-item flags (`true` = valid), without the
    /// normalization that `Validity::from_iter` performs.
    fn array_validity<const N: usize>(flags: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(flags).into_array())
    }

    /// Patches positions 2 and 4 of a length-5 validity, covering every combination of
    /// `AllValid` / `AllInvalid` / `Array` for both the source and the patch values.
    #[rstest]
    // Source: AllValid
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid, 5, &[2, 4],
        Validity::AllInvalid,
        array_validity([true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid, 5, &[2, 4],
        array_validity([true, false]),
        array_validity([true, true, true, true, false])
    )]
    // Source: AllInvalid
    #[case(
        Validity::AllInvalid, 5, &[2, 4],
        Validity::AllValid,
        array_validity([false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid, 5, &[2, 4],
        array_validity([true, false]),
        array_validity([false, false, true, false, false])
    )]
    // Source: explicit array
    #[case(
        array_validity([false, true, false, true, false]), 5, &[2, 4],
        Validity::AllValid,
        array_validity([false, true, true, true, true])
    )]
    #[case(
        array_validity([false, true, false, true, false]), 5, &[2, 4],
        Validity::AllInvalid,
        array_validity([false, true, false, true, false])
    )]
    #[case(
        array_validity([false, true, false, true, false]), 5, &[2, 4],
        array_validity([true, false]),
        array_validity([false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices: ArrayRef =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches).unwrap();
        assert_eq!(patched, expected);
    }

    // NOTE(review): `Validity::patch` bails when a `NonNullable` validity is patched with a
    // nullable one before any index is examined, so this panics on that error rather than on
    // index 4 exceeding the length of 2. Confirm which condition this test is meant to cover.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid)
            .unwrap();
    }

    /// A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    /// A non-nullable validity cannot be built from explicit mask values containing a `false`.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    /// Taking through possibly-null indices: null index positions become invalid, and a
    /// `NonNullable` source stays `NonNullable` only for non-nullable indices.
    #[rstest]
    // Source: AllValid
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    // Source: NonNullable
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}