vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::Range;
9
10use vortex_buffer::BitBuffer;
11use vortex_dtype::DType;
12use vortex_dtype::Nullability;
13use vortex_error::VortexExpect as _;
14use vortex_error::VortexResult;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20use vortex_scalar::Scalar;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::ToCanonical;
26use crate::arrays::BoolArray;
27use crate::arrays::ConstantArray;
28use crate::compute::fill_null;
29use crate::compute::filter;
30use crate::compute::sum;
31use crate::compute::take;
32use crate::patches::Patches;
33
34/// Validity information for an array
35#[derive(Clone, Debug)]
36pub enum Validity {
37    /// Items *can't* be null
38    NonNullable,
39    /// All items are valid
40    AllValid,
41    /// All items are null
42    AllInvalid,
43    /// The validity of each position in the array is determined by a boolean array.
44    ///
45    /// True values are valid, false values are invalid ("null").
46    Array(ArrayRef),
47}
48
49impl Validity {
50    /// The [`DType`] of the underlying validity array (if it exists).
51    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
52
53    /// Convert the validity to an array representation.
54    pub fn to_array(&self, len: usize) -> ArrayRef {
55        match self {
56            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
57            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
58            Self::Array(a) => a.clone(),
59        }
60    }
61
62    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
63    #[inline]
64    pub fn into_array(self) -> Option<ArrayRef> {
65        if let Self::Array(a) = self {
66            Some(a)
67        } else {
68            None
69        }
70    }
71
72    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
73    #[inline]
74    pub fn as_array(&self) -> Option<&ArrayRef> {
75        if let Self::Array(a) = self {
76            Some(a)
77        } else {
78            None
79        }
80    }
81
82    #[inline]
83    pub fn nullability(&self) -> Nullability {
84        if matches!(self, Self::NonNullable) {
85            Nullability::NonNullable
86        } else {
87            Nullability::Nullable
88        }
89    }
90
91    /// The union nullability and validity.
92    #[inline]
93    pub fn union_nullability(self, nullability: Nullability) -> Self {
94        match nullability {
95            Nullability::NonNullable => self,
96            Nullability::Nullable => self.into_nullable(),
97        }
98    }
99
100    #[inline]
101    pub fn all_valid(&self, len: usize) -> bool {
102        match self {
103            _ if len == 0 => true,
104            Validity::NonNullable | Validity::AllValid => true,
105            Validity::AllInvalid => false,
106            Validity::Array(array) => {
107                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
108                    .vortex_expect("sum must be a usize")
109                    == array.len()
110            }
111        }
112    }
113
114    #[inline]
115    pub fn all_invalid(&self, len: usize) -> bool {
116        match self {
117            _ if len == 0 => true,
118            Validity::NonNullable | Validity::AllValid => false,
119            Validity::AllInvalid => true,
120            Validity::Array(array) => {
121                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122                    .vortex_expect("sum must be a usize")
123                    == 0
124            }
125        }
126    }
127
128    /// Returns whether the `index` item is valid.
129    #[inline]
130    pub fn is_valid(&self, index: usize) -> bool {
131        match self {
132            Self::NonNullable | Self::AllValid => true,
133            Self::AllInvalid => false,
134            Self::Array(a) => {
135                let scalar = a.scalar_at(index);
136                scalar
137                    .as_bool()
138                    .value()
139                    .vortex_expect("Validity must be non-nullable")
140            }
141        }
142    }
143
144    #[inline]
145    pub fn is_null(&self, index: usize) -> bool {
146        !self.is_valid(index)
147    }
148
149    #[inline]
150    pub fn slice(&self, range: Range<usize>) -> Self {
151        match self {
152            Self::Array(a) => Self::Array(a.slice(range)),
153            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
154        }
155    }
156
157    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
158        match self {
159            Self::NonNullable => match indices.validity_mask().bit_buffer() {
160                AllOr::All => {
161                    if indices.dtype().is_nullable() {
162                        Ok(Self::AllValid)
163                    } else {
164                        Ok(Self::NonNullable)
165                    }
166                }
167                AllOr::None => Ok(Self::AllInvalid),
168                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
169            },
170            Self::AllValid => match indices.validity_mask().bit_buffer() {
171                AllOr::All => Ok(Self::AllValid),
172                AllOr::None => Ok(Self::AllInvalid),
173                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
174            },
175            Self::AllInvalid => Ok(Self::AllInvalid),
176            Self::Array(is_valid) => {
177                let maybe_is_valid = take(is_valid, indices)?;
178                // Null indices invalidate that position.
179                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
180                Ok(Self::Array(is_valid))
181            }
182        }
183    }
184
185    /// Keep only the entries for which the mask is true.
186    ///
187    /// The result has length equal to the number of true values in mask.
188    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
189        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
190        //  if we happen to be NonNullable, AllValid, or AllInvalid.
191        match self {
192            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
193                Ok(v.clone())
194            }
195            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
196        }
197    }
198
199    /// Set to false any entries for which the mask is true.
200    ///
201    /// The result is always nullable. The result has the same length as self.
202    #[inline]
203    pub fn mask(&self, mask: &Mask) -> Self {
204        match mask.bit_buffer() {
205            AllOr::All => Validity::AllInvalid,
206            AllOr::None => self.clone().into_nullable(),
207            AllOr::Some(make_invalid) => match self {
208                Validity::NonNullable | Validity::AllValid => {
209                    Validity::Array(BoolArray::from(!make_invalid).into_array())
210                }
211                Validity::AllInvalid => Validity::AllInvalid,
212                Validity::Array(is_valid) => {
213                    let is_valid = is_valid.to_bool();
214                    Validity::from(is_valid.bit_buffer() & !make_invalid)
215                }
216            },
217        }
218    }
219
220    #[inline]
221    pub fn to_mask(&self, length: usize) -> Mask {
222        match self {
223            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
224            Self::AllInvalid => Mask::AllFalse(length),
225            Self::Array(is_valid) => {
226                assert_eq!(
227                    is_valid.len(),
228                    length,
229                    "Validity::Array length must equal to_logical's argument: {}, {}.",
230                    is_valid.len(),
231                    length,
232                );
233                is_valid.to_bool().to_mask()
234            }
235        }
236    }
237
238    /// Logically & two Validity values of the same length
239    #[inline]
240    pub fn and(self, rhs: Validity) -> Validity {
241        match (self, rhs) {
242            // Should be pretty clear
243            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
244            // Any `AllInvalid` makes the output all invalid values
245            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
246            // All truthy values on one side, which makes no effect on an `Array` variant
247            (Validity::Array(a), Validity::AllValid)
248            | (Validity::Array(a), Validity::NonNullable)
249            | (Validity::NonNullable, Validity::Array(a))
250            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
251            // Both sides are all valid
252            (Validity::NonNullable, Validity::AllValid)
253            | (Validity::AllValid, Validity::NonNullable)
254            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
255            // Here we actually have to do some work
256            (Validity::Array(lhs), Validity::Array(rhs)) => {
257                let lhs = lhs.to_bool();
258                let rhs = rhs.to_bool();
259
260                let lhs = lhs.bit_buffer();
261                let rhs = rhs.bit_buffer();
262
263                Validity::from(lhs.bitand(rhs))
264            }
265        }
266    }
267
268    pub fn patch(
269        self,
270        len: usize,
271        indices_offset: usize,
272        indices: &dyn Array,
273        patches: &Validity,
274    ) -> Self {
275        match (&self, patches) {
276            (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
277            (Validity::NonNullable, _) => {
278                vortex_panic!("Can't patch a non-nullable validity with nullable validity")
279            }
280            (_, Validity::NonNullable) => {
281                vortex_panic!("Can't patch a nullable validity with non-nullable validity")
282            }
283            (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
284            (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
285            _ => {}
286        };
287
288        let own_nullability = if self == Validity::NonNullable {
289            Nullability::NonNullable
290        } else {
291            Nullability::Nullable
292        };
293
294        let source = match self {
295            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298            Validity::Array(a) => a.to_bool(),
299        };
300
301        let patch_values = match patches {
302            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305            Validity::Array(a) => a.to_bool(),
306        };
307
308        let patches = Patches::new(
309            len,
310            indices_offset,
311            indices.to_array(),
312            patch_values.into_array(),
313            // TODO(0ax1): chunk offsets
314            None,
315        );
316
317        Self::from_array(source.patch(&patches).into_array(), own_nullability)
318    }
319
320    /// Convert into a nullable variant
321    #[inline]
322    pub fn into_nullable(self) -> Validity {
323        match self {
324            Self::NonNullable => Self::AllValid,
325            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
326        }
327    }
328
329    /// Convert into a non-nullable variant
330    #[inline]
331    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
332        match self {
333            _ if len == 0 => Some(Validity::NonNullable),
334            Self::NonNullable => Some(Self::NonNullable),
335            Self::AllValid => Some(Self::NonNullable),
336            Self::AllInvalid => None,
337            Self::Array(is_valid) => {
338                is_valid
339                    .statistics()
340                    .compute_min::<bool>()
341                    .vortex_expect("validity array must support min")
342                    .then(|| {
343                        // min true => all true
344                        Self::NonNullable
345                    })
346            }
347        }
348    }
349
350    /// Convert into a variant compatible with the given nullability, if possible.
351    #[inline]
352    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
353        match nullability {
354            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
355                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
356            }),
357            Nullability::Nullable => Ok(self.into_nullable()),
358        }
359    }
360
361    /// Create Validity by copying the given array's validity.
362    #[inline]
363    pub fn copy_from_array(array: &dyn Array) -> Self {
364        Validity::from_mask(array.validity_mask(), array.dtype().nullability())
365    }
366
367    /// Create Validity from boolean array with given nullability of the array.
368    ///
369    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
370    ///     as that is always nonnullable
371    #[inline]
372    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
373        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
374            vortex_panic!("Expected a non-nullable boolean array")
375        }
376        match nullability {
377            Nullability::NonNullable => Self::NonNullable,
378            Nullability::Nullable => Self::Array(value),
379        }
380    }
381
382    /// Returns the length of the validity array, if it exists.
383    #[inline]
384    pub fn maybe_len(&self) -> Option<usize> {
385        match self {
386            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
387            Self::Array(a) => Some(a.len()),
388        }
389    }
390
391    #[inline]
392    pub fn uncompressed_size(&self) -> usize {
393        if let Validity::Array(a) = self {
394            a.len().div_ceil(8)
395        } else {
396            0
397        }
398    }
399}
400
401impl PartialEq for Validity {
402    #[inline]
403    fn eq(&self, other: &Self) -> bool {
404        match (self, other) {
405            (Self::NonNullable, Self::NonNullable) => true,
406            (Self::AllValid, Self::AllValid) => true,
407            (Self::AllInvalid, Self::AllInvalid) => true,
408            (Self::Array(a), Self::Array(b)) => {
409                let a = a.to_bool();
410                let b = b.to_bool();
411                a.bit_buffer() == b.bit_buffer()
412            }
413            _ => false,
414        }
415    }
416}
417
418impl From<BitBuffer> for Validity {
419    #[inline]
420    fn from(value: BitBuffer) -> Self {
421        let true_count = value.true_count();
422        if true_count == value.len() {
423            Self::AllValid
424        } else if true_count == 0 {
425            Self::AllInvalid
426        } else {
427            Self::Array(BoolArray::from(value).into_array())
428        }
429    }
430}
431
432impl FromIterator<Mask> for Validity {
433    #[inline]
434    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
435        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
436    }
437}
438
439impl FromIterator<bool> for Validity {
440    #[inline]
441    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
442        Validity::from(BitBuffer::from_iter(iter))
443    }
444}
445
446impl From<Nullability> for Validity {
447    #[inline]
448    fn from(value: Nullability) -> Self {
449        match value {
450            Nullability::NonNullable => Validity::NonNullable,
451            Nullability::Nullable => Validity::AllValid,
452        }
453    }
454}
455
456impl Validity {
457    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
458        if buffer.true_count() == buffer.len() {
459            nullability.into()
460        } else if buffer.true_count() == 0 {
461            Validity::AllInvalid
462        } else {
463            Validity::Array(BoolArray::from_bit_buffer(buffer, Validity::NonNullable).into_array())
464        }
465    }
466
467    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
468        assert!(
469            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
470            "NonNullable validity must be AllValid",
471        );
472        match mask {
473            Mask::AllTrue(_) => match nullability {
474                Nullability::NonNullable => Validity::NonNullable,
475                Nullability::Nullable => Validity::AllValid,
476            },
477            Mask::AllFalse(_) => Validity::AllInvalid,
478            Mask::Values(values) => Validity::Array(values.into_array()),
479        }
480    }
481}
482
483impl IntoArray for Mask {
484    #[inline]
485    fn into_array(self) -> ArrayRef {
486        match self {
487            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
488            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
489            Self::Values(a) => a.into_array(),
490        }
491    }
492}
493
494impl IntoArray for &MaskValues {
495    #[inline]
496    fn into_array(self) -> ArrayRef {
497        BoolArray::from_bit_buffer(self.bit_buffer().clone(), Validity::NonNullable).into_array()
498    }
499}
500
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Builds a [`Validity::Array`] from literal booleans (true = valid).
    fn bools<const N: usize>(bits: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(bits).into_array())
    }

    /// Builds the two-element index array `[0, 1]` carrying the given validity.
    fn indices_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![0, 1], validity).into_array()
    }

    // Each case patches positions 2 and 4 of a length-5 validity.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        bools([true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        bools([true, false]),
        bools([true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        bools([false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        bools([true, false]),
        bools([false, false, true, false, false])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        bools([false, true, true, true, true])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        bools([false, true, false, true, false])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        bools([true, false]),
        bools([false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches);
        assert_eq!(patched, expected);
    }

    // NOTE(review): with the `patch` implementation in this file, a `NonNullable` validity
    // patched with `AllInvalid` panics on the nullability mismatch before any index is looked
    // at; confirm this test still exercises the out-of-bounds path its name suggests.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable.patch(2, 0, &indices, &Validity::AllInvalid);
    }

    // A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let mask = Mask::AllFalse(10);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    // A non-nullable validity cannot be built from a mask containing a false bit.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    // Taking from an all-valid source: the result follows the validity of the indices.
    #[rstest]
    #[case(
        Validity::AllValid,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    #[case(
        Validity::NonNullable,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }

    // Masking nothing out of a non-nullable validity still yields a nullable result.
    #[test]
    fn mask_non_nullable() {
        let masked = Validity::NonNullable.mask(&Mask::AllFalse(2));
        assert_eq!(Validity::AllValid, masked)
    }
}