Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_dtype::DType;
11use vortex_dtype::Nullability;
12use vortex_error::VortexExpect as _;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20
21use crate::Array;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::ExecutionCtx;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::arrays::BoolArray;
28use crate::arrays::ConstantArray;
29use crate::arrays::ScalarFnArrayExt;
30use crate::builtins::ArrayBuiltins;
31use crate::compute::sum;
32use crate::expr::Binary;
33use crate::expr::Operator;
34use crate::optimizer::ArrayOptimizer;
35use crate::patches::Patches;
36use crate::scalar::Scalar;
37
38/// Validity information for an array
39#[derive(Clone, Debug)]
40pub enum Validity {
41    /// Items *can't* be null
42    NonNullable,
43    /// All items are valid
44    AllValid,
45    /// All items are null
46    AllInvalid,
47    /// The validity of each position in the array is determined by a boolean array.
48    ///
49    /// True values are valid, false values are invalid ("null").
50    Array(ArrayRef),
51}
52
53impl Validity {
54    /// Make a step towards canonicalising validity if necessary
55    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
56        match self {
57            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
58            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
59        }
60    }
61}
62
63impl Validity {
64    /// The [`DType`] of the underlying validity array (if it exists).
65    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
66
67    /// Convert the validity to an array representation.
68    pub fn to_array(&self, len: usize) -> ArrayRef {
69        match self {
70            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
71            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
72            Self::Array(a) => a.clone(),
73        }
74    }
75
76    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
77    #[inline]
78    pub fn into_array(self) -> Option<ArrayRef> {
79        if let Self::Array(a) = self {
80            Some(a)
81        } else {
82            None
83        }
84    }
85
86    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
87    #[inline]
88    pub fn as_array(&self) -> Option<&ArrayRef> {
89        if let Self::Array(a) = self {
90            Some(a)
91        } else {
92            None
93        }
94    }
95
96    #[inline]
97    pub fn nullability(&self) -> Nullability {
98        if matches!(self, Self::NonNullable) {
99            Nullability::NonNullable
100        } else {
101            Nullability::Nullable
102        }
103    }
104
105    /// The union nullability and validity.
106    #[inline]
107    pub fn union_nullability(self, nullability: Nullability) -> Self {
108        match nullability {
109            Nullability::NonNullable => self,
110            Nullability::Nullable => self.into_nullable(),
111        }
112    }
113
114    #[inline]
115    pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
116        Ok(match self {
117            _ if len == 0 => true,
118            Validity::NonNullable | Validity::AllValid => true,
119            Validity::AllInvalid => false,
120            Validity::Array(array) => {
121                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122                    .vortex_expect("sum must be a usize")
123                    == array.len()
124            }
125        })
126    }
127
128    #[inline]
129    pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
130        Ok(match self {
131            _ if len == 0 => true,
132            Validity::NonNullable | Validity::AllValid => false,
133            Validity::AllInvalid => true,
134            Validity::Array(array) => {
135                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
136                    .vortex_expect("sum must be a usize")
137                    == 0
138            }
139        })
140    }
141
142    /// Returns whether the `index` item is valid.
143    #[inline]
144    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
145        Ok(match self {
146            Self::NonNullable | Self::AllValid => true,
147            Self::AllInvalid => false,
148            Self::Array(a) => a
149                .scalar_at(index)
150                .vortex_expect("Validity array must support scalar_at")
151                .as_bool()
152                .value()
153                .vortex_expect("Validity must be non-nullable"),
154        })
155    }
156
157    #[inline]
158    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
159        Ok(!self.is_valid(index)?)
160    }
161
162    #[inline]
163    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
164        match self {
165            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
166            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
167        }
168    }
169
170    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
171        match self {
172            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
173                AllOr::All => {
174                    if indices.dtype().is_nullable() {
175                        Ok(Self::AllValid)
176                    } else {
177                        Ok(Self::NonNullable)
178                    }
179                }
180                AllOr::None => Ok(Self::AllInvalid),
181                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
182            },
183            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
184                AllOr::All => Ok(Self::AllValid),
185                AllOr::None => Ok(Self::AllInvalid),
186                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
187            },
188            Self::AllInvalid => Ok(Self::AllInvalid),
189            Self::Array(is_valid) => {
190                let maybe_is_valid = is_valid
191                    .take(indices.to_array())?
192                    .to_canonical()?
193                    .into_array();
194                // Null indices invalidate that position.
195                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
196                Ok(Self::Array(is_valid))
197            }
198        }
199    }
200
201    // Invert the validity
202    pub fn not(&self) -> VortexResult<Self> {
203        match self {
204            Validity::NonNullable => Ok(Validity::NonNullable),
205            Validity::AllValid => Ok(Validity::AllInvalid),
206            Validity::AllInvalid => Ok(Validity::AllValid),
207            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
208        }
209    }
210
211    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
212    /// the mask is true.
213    ///
214    /// The result has length equal to the number of true values in mask.
215    ///
216    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
217    /// instead of eagerly filtering the values immediately.
218    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
219        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
220        //  if we happen to be NonNullable, AllValid, or AllInvalid.
221        match self {
222            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
223                Ok(v.clone())
224            }
225            Validity::Array(arr) => Ok(Validity::Array(
226                arr.filter(mask.clone())?
227                    // TODO(connor): This is wrong!!! We should not be eagerly decompressing the
228                    // validity array.
229                    .to_canonical()?
230                    .into_array(),
231            )),
232        }
233    }
234
235    #[inline]
236    pub fn to_mask(&self, length: usize) -> Mask {
237        match self {
238            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
239            Self::AllInvalid => Mask::AllFalse(length),
240            Self::Array(is_valid) => {
241                assert_eq!(
242                    is_valid.len(),
243                    length,
244                    "Validity::Array length must equal to_logical's argument: {}, {}.",
245                    is_valid.len(),
246                    length,
247                );
248                is_valid.to_bool().to_mask()
249            }
250        }
251    }
252
253    /// Logically & two Validity values of the same length
254    #[inline]
255    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
256        Ok(match (self, rhs) {
257            // Should be pretty clear
258            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
259            // Any `AllInvalid` makes the output all invalid values
260            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
261            // All truthy values on one side, which makes no effect on an `Array` variant
262            (Validity::Array(a), Validity::AllValid)
263            | (Validity::Array(a), Validity::NonNullable)
264            | (Validity::NonNullable, Validity::Array(a))
265            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
266            // Both sides are all valid
267            (Validity::NonNullable, Validity::AllValid)
268            | (Validity::AllValid, Validity::NonNullable)
269            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
270            // Here we actually have to do some work
271            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
272                Binary
273                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
274                    .optimize()?,
275            ),
276        })
277    }
278
279    pub fn patch(
280        self,
281        len: usize,
282        indices_offset: usize,
283        indices: &dyn Array,
284        patches: &Validity,
285    ) -> VortexResult<Self> {
286        match (&self, patches) {
287            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
288            (Validity::NonNullable, _) => {
289                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
290            }
291            (_, Validity::NonNullable) => {
292                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
293            }
294            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
295            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
296            _ => {}
297        };
298
299        let own_nullability = if self == Validity::NonNullable {
300            Nullability::NonNullable
301        } else {
302            Nullability::Nullable
303        };
304
305        let source = match self {
306            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
307            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
308            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
309            Validity::Array(a) => a.to_bool(),
310        };
311
312        let patch_values = match patches {
313            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
314            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
315            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
316            Validity::Array(a) => a.to_bool(),
317        };
318
319        let patches = Patches::new(
320            len,
321            indices_offset,
322            indices.to_array(),
323            patch_values.into_array(),
324            // TODO(0ax1): chunk offsets
325            None,
326        )?;
327
328        Ok(Self::from_array(
329            source.patch(&patches)?.into_array(),
330            own_nullability,
331        ))
332    }
333
334    /// Convert into a nullable variant
335    #[inline]
336    pub fn into_nullable(self) -> Validity {
337        match self {
338            Self::NonNullable => Self::AllValid,
339            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
340        }
341    }
342
343    /// Convert into a non-nullable variant
344    #[inline]
345    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
346        match self {
347            _ if len == 0 => Some(Validity::NonNullable),
348            Self::NonNullable => Some(Self::NonNullable),
349            Self::AllValid => Some(Self::NonNullable),
350            Self::AllInvalid => None,
351            Self::Array(is_valid) => {
352                is_valid
353                    .statistics()
354                    .compute_min::<bool>()
355                    .vortex_expect("validity array must support min")
356                    .then(|| {
357                        // min true => all true
358                        Self::NonNullable
359                    })
360            }
361        }
362    }
363
364    /// Convert into a variant compatible with the given nullability, if possible.
365    #[inline]
366    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
367        match nullability {
368            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
369                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
370            }),
371            Nullability::Nullable => Ok(self.into_nullable()),
372        }
373    }
374
375    /// Create Validity by copying the given array's validity.
376    #[inline]
377    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
378        Ok(Validity::from_mask(
379            array.validity_mask()?,
380            array.dtype().nullability(),
381        ))
382    }
383
384    /// Create Validity from boolean array with given nullability of the array.
385    ///
386    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
387    ///     as that is always nonnullable
388    #[inline]
389    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
390        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
391            vortex_panic!("Expected a non-nullable boolean array")
392        }
393        match nullability {
394            Nullability::NonNullable => Self::NonNullable,
395            Nullability::Nullable => Self::Array(value),
396        }
397    }
398
399    /// Returns the length of the validity array, if it exists.
400    #[inline]
401    pub fn maybe_len(&self) -> Option<usize> {
402        match self {
403            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
404            Self::Array(a) => Some(a.len()),
405        }
406    }
407
408    #[inline]
409    pub fn uncompressed_size(&self) -> usize {
410        if let Validity::Array(a) = self {
411            a.len().div_ceil(8)
412        } else {
413            0
414        }
415    }
416}
417
418impl PartialEq for Validity {
419    #[inline]
420    fn eq(&self, other: &Self) -> bool {
421        match (self, other) {
422            (Self::NonNullable, Self::NonNullable) => true,
423            (Self::AllValid, Self::AllValid) => true,
424            (Self::AllInvalid, Self::AllInvalid) => true,
425            (Self::Array(a), Self::Array(b)) => {
426                let a = a.to_bool();
427                let b = b.to_bool();
428                a.to_bit_buffer() == b.to_bit_buffer()
429            }
430            _ => false,
431        }
432    }
433}
434
435impl From<BitBuffer> for Validity {
436    #[inline]
437    fn from(value: BitBuffer) -> Self {
438        let true_count = value.true_count();
439        if true_count == value.len() {
440            Self::AllValid
441        } else if true_count == 0 {
442            Self::AllInvalid
443        } else {
444            Self::Array(BoolArray::from(value).into_array())
445        }
446    }
447}
448
449impl FromIterator<Mask> for Validity {
450    #[inline]
451    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
452        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
453    }
454}
455
456impl FromIterator<bool> for Validity {
457    #[inline]
458    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
459        Validity::from(BitBuffer::from_iter(iter))
460    }
461}
462
463impl From<Nullability> for Validity {
464    #[inline]
465    fn from(value: Nullability) -> Self {
466        Validity::from(&value)
467    }
468}
469
470impl From<&Nullability> for Validity {
471    #[inline]
472    fn from(value: &Nullability) -> Self {
473        match *value {
474            Nullability::NonNullable => Validity::NonNullable,
475            Nullability::Nullable => Validity::AllValid,
476        }
477    }
478}
479
480impl Validity {
481    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
482        if buffer.true_count() == buffer.len() {
483            nullability.into()
484        } else if buffer.true_count() == 0 {
485            Validity::AllInvalid
486        } else {
487            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
488        }
489    }
490
491    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
492        assert!(
493            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
494            "NonNullable validity must be AllValid",
495        );
496        match mask {
497            Mask::AllTrue(_) => match nullability {
498                Nullability::NonNullable => Validity::NonNullable,
499                Nullability::Nullable => Validity::AllValid,
500            },
501            Mask::AllFalse(_) => Validity::AllInvalid,
502            Mask::Values(values) => Validity::Array(values.into_array()),
503        }
504    }
505}
506
507impl IntoArray for Mask {
508    #[inline]
509    fn into_array(self) -> ArrayRef {
510        match self {
511            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
512            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
513            Self::Values(a) => a.into_array(),
514        }
515    }
516}
517
518impl IntoArray for &MaskValues {
519    #[inline]
520    fn into_array(self) -> ArrayRef {
521        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
522    }
523}
524
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    // Each case: (validity being patched, its length, patch positions, patch validity,
    // expected result). Positions 2 and 4 of a length-5 validity are patched throughout.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let position_buffer = Buffer::copy_from(positions);
        let indices = PrimitiveArray::new(position_buffer, Validity::NonNullable).into_array();

        let patched = validity.patch(len, 0, &indices, &patches).unwrap();

        assert_eq!(patched, expected);
    }

    // NOTE(review): with a `NonNullable` receiver and `AllInvalid` patches, `patch` bails on
    // the nullability mismatch before it looks at the indices, so the panic observed here may
    // not come from the out-of-bounds index — confirm this exercises what the name says.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid)
            .unwrap();
    }

    // A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let all_false = Mask::AllFalse(10);
        Validity::from_mask(all_false, Nullability::NonNullable);
    }

    // A non-nullable validity cannot be built from a mask with explicit values.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mixed = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mixed, Nullability::NonNullable);
    }

    // Each case: (validity being taken from, indices, expected result). The validity of the
    // indices drives the outcome for the constant `AllValid` / `NonNullable` receivers.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}