Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::ToCanonical;
25use crate::arrays::BoolArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::scalar_fn::ScalarFnArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::compute::sum;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::optimizer::ArrayOptimizer;
33use crate::patches::Patches;
34use crate::scalar::Scalar;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::Operator;
37
38/// Validity information for an array
39#[derive(Clone)]
40pub enum Validity {
41    /// Items *can't* be null
42    NonNullable,
43    /// All items are valid
44    AllValid,
45    /// All items are null
46    AllInvalid,
47    /// The validity of each position in the array is determined by a boolean array.
48    ///
49    /// True values are valid, false values are invalid ("null").
50    Array(ArrayRef),
51}
52
53impl Debug for Validity {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::NonNullable => write!(f, "NonNullable"),
57            Self::AllValid => write!(f, "AllValid"),
58            Self::AllInvalid => write!(f, "AllInvalid"),
59            Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
60        }
61    }
62}
63
64impl Validity {
65    /// Make a step towards canonicalising validity if necessary
66    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
67        match self {
68            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
69            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
70        }
71    }
72}
73
74impl Validity {
75    /// The [`DType`] of the underlying validity array (if it exists).
76    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
77
78    /// Convert the validity to an array representation.
79    pub fn to_array(&self, len: usize) -> ArrayRef {
80        match self {
81            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
82            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
83            Self::Array(a) => a.clone(),
84        }
85    }
86
87    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
88    #[inline]
89    pub fn into_array(self) -> Option<ArrayRef> {
90        if let Self::Array(a) = self {
91            Some(a)
92        } else {
93            None
94        }
95    }
96
97    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
98    #[inline]
99    pub fn as_array(&self) -> Option<&ArrayRef> {
100        if let Self::Array(a) = self {
101            Some(a)
102        } else {
103            None
104        }
105    }
106
107    #[inline]
108    pub fn nullability(&self) -> Nullability {
109        if matches!(self, Self::NonNullable) {
110            Nullability::NonNullable
111        } else {
112            Nullability::Nullable
113        }
114    }
115
116    /// The union nullability and validity.
117    #[inline]
118    pub fn union_nullability(self, nullability: Nullability) -> Self {
119        match nullability {
120            Nullability::NonNullable => self,
121            Nullability::Nullable => self.into_nullable(),
122        }
123    }
124
125    #[inline]
126    pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
127        Ok(match self {
128            _ if len == 0 => true,
129            Validity::NonNullable | Validity::AllValid => true,
130            Validity::AllInvalid => false,
131            Validity::Array(array) => {
132                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
133                    .vortex_expect("sum must be a usize")
134                    == array.len()
135            }
136        })
137    }
138
139    #[inline]
140    pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
141        Ok(match self {
142            _ if len == 0 => true,
143            Validity::NonNullable | Validity::AllValid => false,
144            Validity::AllInvalid => true,
145            Validity::Array(array) => {
146                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
147                    .vortex_expect("sum must be a usize")
148                    == 0
149            }
150        })
151    }
152
153    /// Returns whether the `index` item is valid.
154    #[inline]
155    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
156        Ok(match self {
157            Self::NonNullable | Self::AllValid => true,
158            Self::AllInvalid => false,
159            Self::Array(a) => a
160                .scalar_at(index)
161                .vortex_expect("Validity array must support scalar_at")
162                .as_bool()
163                .value()
164                .vortex_expect("Validity must be non-nullable"),
165        })
166    }
167
168    #[inline]
169    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
170        Ok(!self.is_valid(index)?)
171    }
172
173    #[inline]
174    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
175        match self {
176            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
177            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
178        }
179    }
180
181    pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
182        match self {
183            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
184                AllOr::All => {
185                    if indices.dtype().is_nullable() {
186                        Ok(Self::AllValid)
187                    } else {
188                        Ok(Self::NonNullable)
189                    }
190                }
191                AllOr::None => Ok(Self::AllInvalid),
192                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
193            },
194            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
195                AllOr::All => Ok(Self::AllValid),
196                AllOr::None => Ok(Self::AllInvalid),
197                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
198            },
199            Self::AllInvalid => Ok(Self::AllInvalid),
200            Self::Array(is_valid) => {
201                let maybe_is_valid = is_valid.take(indices.to_array())?;
202                // Null indices invalidate that position.
203                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
204                Ok(Self::Array(is_valid))
205            }
206        }
207    }
208
209    // Invert the validity
210    pub fn not(&self) -> VortexResult<Self> {
211        match self {
212            Validity::NonNullable => Ok(Validity::NonNullable),
213            Validity::AllValid => Ok(Validity::AllInvalid),
214            Validity::AllInvalid => Ok(Validity::AllValid),
215            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
216        }
217    }
218
219    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
220    /// the mask is true.
221    ///
222    /// The result has length equal to the number of true values in mask.
223    ///
224    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
225    /// instead of eagerly filtering the values immediately.
226    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
227        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
228        //  if we happen to be NonNullable, AllValid, or AllInvalid.
229        match self {
230            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
231                Ok(v.clone())
232            }
233            Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
234        }
235    }
236
237    #[inline]
238    pub fn to_mask(&self, length: usize) -> Mask {
239        match self {
240            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
241            Self::AllInvalid => Mask::AllFalse(length),
242            Self::Array(is_valid) => {
243                assert_eq!(
244                    is_valid.len(),
245                    length,
246                    "Validity::Array length must equal to_logical's argument: {}, {}.",
247                    is_valid.len(),
248                    length,
249                );
250                is_valid.to_bool().to_mask()
251            }
252        }
253    }
254
255    /// Logically & two Validity values of the same length
256    #[inline]
257    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
258        Ok(match (self, rhs) {
259            // Should be pretty clear
260            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
261            // Any `AllInvalid` makes the output all invalid values
262            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
263            // All truthy values on one side, which makes no effect on an `Array` variant
264            (Validity::Array(a), Validity::AllValid)
265            | (Validity::Array(a), Validity::NonNullable)
266            | (Validity::NonNullable, Validity::Array(a))
267            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
268            // Both sides are all valid
269            (Validity::NonNullable, Validity::AllValid)
270            | (Validity::AllValid, Validity::NonNullable)
271            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
272            // Here we actually have to do some work
273            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
274                Binary
275                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
276                    .optimize()?,
277            ),
278        })
279    }
280
281    pub fn patch(
282        self,
283        len: usize,
284        indices_offset: usize,
285        indices: &ArrayRef,
286        patches: &Validity,
287        ctx: &mut ExecutionCtx,
288    ) -> VortexResult<Self> {
289        match (&self, patches) {
290            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
291            (Validity::NonNullable, _) => {
292                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
293            }
294            (_, Validity::NonNullable) => {
295                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
296            }
297            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
298            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
299            _ => {}
300        };
301
302        let own_nullability = if self == Validity::NonNullable {
303            Nullability::NonNullable
304        } else {
305            Nullability::Nullable
306        };
307
308        let source = match self {
309            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
310            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
311            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
312            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
313        };
314
315        let patch_values = match patches {
316            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
317            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
318            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
319            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
320        };
321
322        let patches = Patches::new(
323            len,
324            indices_offset,
325            indices.to_array(),
326            patch_values.into_array(),
327            // TODO(0ax1): chunk offsets
328            None,
329        )?;
330
331        Ok(Self::from_array(
332            source.patch(&patches, ctx)?.into_array(),
333            own_nullability,
334        ))
335    }
336
337    /// Convert into a nullable variant
338    #[inline]
339    pub fn into_nullable(self) -> Validity {
340        match self {
341            Self::NonNullable => Self::AllValid,
342            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
343        }
344    }
345
346    /// Convert into a non-nullable variant
347    #[inline]
348    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
349        match self {
350            _ if len == 0 => Some(Validity::NonNullable),
351            Self::NonNullable => Some(Self::NonNullable),
352            Self::AllValid => Some(Self::NonNullable),
353            Self::AllInvalid => None,
354            Self::Array(is_valid) => {
355                is_valid
356                    .statistics()
357                    .compute_min::<bool>()
358                    .vortex_expect("validity array must support min")
359                    .then(|| {
360                        // min true => all true
361                        Self::NonNullable
362                    })
363            }
364        }
365    }
366
367    /// Convert into a variant compatible with the given nullability, if possible.
368    #[inline]
369    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
370        match nullability {
371            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
372                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
373            }),
374            Nullability::Nullable => Ok(self.into_nullable()),
375        }
376    }
377
378    /// Create Validity by copying the given array's validity.
379    #[inline]
380    pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
381        Ok(Validity::from_mask(
382            array.validity_mask()?,
383            array.dtype().nullability(),
384        ))
385    }
386
387    /// Create Validity from boolean array with given nullability of the array.
388    ///
389    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
390    ///     as that is always nonnullable
391    #[inline]
392    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
393        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
394            vortex_panic!("Expected a non-nullable boolean array")
395        }
396        match nullability {
397            Nullability::NonNullable => Self::NonNullable,
398            Nullability::Nullable => Self::Array(value),
399        }
400    }
401
402    /// Returns the length of the validity array, if it exists.
403    #[inline]
404    pub fn maybe_len(&self) -> Option<usize> {
405        match self {
406            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
407            Self::Array(a) => Some(a.len()),
408        }
409    }
410
411    #[inline]
412    pub fn uncompressed_size(&self) -> usize {
413        if let Validity::Array(a) = self {
414            a.len().div_ceil(8)
415        } else {
416            0
417        }
418    }
419}
420
421impl PartialEq for Validity {
422    #[inline]
423    fn eq(&self, other: &Self) -> bool {
424        match (self, other) {
425            (Self::NonNullable, Self::NonNullable) => true,
426            (Self::AllValid, Self::AllValid) => true,
427            (Self::AllInvalid, Self::AllInvalid) => true,
428            (Self::Array(a), Self::Array(b)) => {
429                let a = a.to_bool();
430                let b = b.to_bool();
431                a.to_bit_buffer() == b.to_bit_buffer()
432            }
433            _ => false,
434        }
435    }
436}
437
438impl From<BitBuffer> for Validity {
439    #[inline]
440    fn from(value: BitBuffer) -> Self {
441        let true_count = value.true_count();
442        if true_count == value.len() {
443            Self::AllValid
444        } else if true_count == 0 {
445            Self::AllInvalid
446        } else {
447            Self::Array(BoolArray::from(value).into_array())
448        }
449    }
450}
451
452impl FromIterator<Mask> for Validity {
453    #[inline]
454    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
455        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
456    }
457}
458
459impl FromIterator<bool> for Validity {
460    #[inline]
461    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
462        Validity::from(BitBuffer::from_iter(iter))
463    }
464}
465
466impl From<Nullability> for Validity {
467    #[inline]
468    fn from(value: Nullability) -> Self {
469        Validity::from(&value)
470    }
471}
472
473impl From<&Nullability> for Validity {
474    #[inline]
475    fn from(value: &Nullability) -> Self {
476        match *value {
477            Nullability::NonNullable => Validity::NonNullable,
478            Nullability::Nullable => Validity::AllValid,
479        }
480    }
481}
482
483impl Validity {
484    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
485        if buffer.true_count() == buffer.len() {
486            nullability.into()
487        } else if buffer.true_count() == 0 {
488            Validity::AllInvalid
489        } else {
490            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
491        }
492    }
493
494    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
495        assert!(
496            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
497            "NonNullable validity must be AllValid",
498        );
499        match mask {
500            Mask::AllTrue(_) => match nullability {
501                Nullability::NonNullable => Validity::NonNullable,
502                Nullability::Nullable => Validity::AllValid,
503            },
504            Mask::AllFalse(_) => Validity::AllInvalid,
505            Mask::Values(values) => Validity::Array(values.into_array()),
506        }
507    }
508}
509
510impl IntoArray for Mask {
511    #[inline]
512    fn into_array(self) -> ArrayRef {
513        match self {
514            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
515            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
516            Self::Values(a) => a.into_array(),
517        }
518    }
519}
520
521impl IntoArray for &MaskValues {
522    #[inline]
523    fn into_array(self) -> ArrayRef {
524        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
525    }
526}
527
528#[cfg(test)]
529mod tests {
530    use rstest::rstest;
531    use vortex_buffer::Buffer;
532    use vortex_buffer::buffer;
533    use vortex_mask::Mask;
534
535    use crate::ArrayRef;
536    use crate::IntoArray;
537    use crate::LEGACY_SESSION;
538    use crate::VortexSessionExecute;
539    use crate::arrays::PrimitiveArray;
540    use crate::dtype::Nullability;
541    use crate::validity::BoolArray;
542    use crate::validity::Validity;
543
544    #[rstest]
545    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
546    #[case(
547        Validity::AllValid,
548        5,
549        &[2, 4],
550        Validity::AllInvalid,
551        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
552    )]
553    #[case(
554        Validity::AllValid,
555        5,
556        &[2, 4],
557        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
558        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
559    )]
560    #[case(
561        Validity::AllInvalid,
562        5,
563        &[2, 4],
564        Validity::AllValid,
565        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
566    )]
567    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
568    #[case(
569        Validity::AllInvalid,
570        5,
571        &[2, 4],
572        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
573        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
574    )]
575    #[case(
576        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
577        5,
578        &[2, 4],
579        Validity::AllValid,
580        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
581    )]
582    #[case(
583        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
584        5,
585        &[2, 4],
586        Validity::AllInvalid,
587        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
588    )]
589    #[case(
590        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
591        5,
592        &[2, 4],
593        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
594        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
595    )]
596
597    fn patch_validity(
598        #[case] validity: Validity,
599        #[case] len: usize,
600        #[case] positions: &[u64],
601        #[case] patches: Validity,
602        #[case] expected: Validity,
603    ) {
604        let indices =
605            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
606        assert_eq!(
607            validity
608                .patch(
609                    len,
610                    0,
611                    &indices,
612                    &patches,
613                    &mut LEGACY_SESSION.create_execution_ctx()
614                )
615                .unwrap(),
616            expected
617        );
618    }
619
620    #[test]
621    #[should_panic]
622    fn out_of_bounds_patch() {
623        Validity::NonNullable
624            .patch(
625                2,
626                0,
627                &buffer![4].into_array(),
628                &Validity::AllInvalid,
629                &mut LEGACY_SESSION.create_execution_ctx(),
630            )
631            .unwrap();
632    }
633
634    #[test]
635    #[should_panic]
636    fn into_validity_nullable() {
637        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
638    }
639
640    #[test]
641    #[should_panic]
642    fn into_validity_nullable_array() {
643        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
644    }
645
646    #[rstest]
647    #[case(
648        Validity::AllValid,
649        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
650        Validity::from_iter(vec![true, false])
651    )]
652    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
653    #[case(
654        Validity::AllValid,
655        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
656        Validity::AllInvalid
657    )]
658    #[case(
659        Validity::NonNullable,
660        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
661        Validity::from_iter(vec![true, false])
662    )]
663    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
664    #[case(
665        Validity::NonNullable,
666        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
667        Validity::AllInvalid
668    )]
669    fn validity_take(
670        #[case] validity: Validity,
671        #[case] indices: ArrayRef,
672        #[case] expected: Validity,
673    ) {
674        assert_eq!(validity.take(&indices).unwrap(), expected);
675    }
676}