Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::ToCanonical;
25use crate::arrays::BoolArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::ScalarFnArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::compute::sum;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::optimizer::ArrayOptimizer;
33use crate::patches::Patches;
34use crate::scalar::Scalar;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::Operator;
37
38/// Validity information for an array
39#[derive(Clone, Debug)]
40pub enum Validity {
41    /// Items *can't* be null
42    NonNullable,
43    /// All items are valid
44    AllValid,
45    /// All items are null
46    AllInvalid,
47    /// The validity of each position in the array is determined by a boolean array.
48    ///
49    /// True values are valid, false values are invalid ("null").
50    Array(ArrayRef),
51}
52
53impl Validity {
54    /// Make a step towards canonicalising validity if necessary
55    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
56        match self {
57            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
58            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
59        }
60    }
61}
62
63impl Validity {
64    /// The [`DType`] of the underlying validity array (if it exists).
65    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
66
67    /// Convert the validity to an array representation.
68    pub fn to_array(&self, len: usize) -> ArrayRef {
69        match self {
70            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
71            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
72            Self::Array(a) => a.clone(),
73        }
74    }
75
76    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
77    #[inline]
78    pub fn into_array(self) -> Option<ArrayRef> {
79        if let Self::Array(a) = self {
80            Some(a)
81        } else {
82            None
83        }
84    }
85
86    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
87    #[inline]
88    pub fn as_array(&self) -> Option<&ArrayRef> {
89        if let Self::Array(a) = self {
90            Some(a)
91        } else {
92            None
93        }
94    }
95
96    #[inline]
97    pub fn nullability(&self) -> Nullability {
98        if matches!(self, Self::NonNullable) {
99            Nullability::NonNullable
100        } else {
101            Nullability::Nullable
102        }
103    }
104
105    /// The union nullability and validity.
106    #[inline]
107    pub fn union_nullability(self, nullability: Nullability) -> Self {
108        match nullability {
109            Nullability::NonNullable => self,
110            Nullability::Nullable => self.into_nullable(),
111        }
112    }
113
114    #[inline]
115    pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
116        Ok(match self {
117            _ if len == 0 => true,
118            Validity::NonNullable | Validity::AllValid => true,
119            Validity::AllInvalid => false,
120            Validity::Array(array) => {
121                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122                    .vortex_expect("sum must be a usize")
123                    == array.len()
124            }
125        })
126    }
127
128    #[inline]
129    pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
130        Ok(match self {
131            _ if len == 0 => true,
132            Validity::NonNullable | Validity::AllValid => false,
133            Validity::AllInvalid => true,
134            Validity::Array(array) => {
135                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
136                    .vortex_expect("sum must be a usize")
137                    == 0
138            }
139        })
140    }
141
142    /// Returns whether the `index` item is valid.
143    #[inline]
144    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
145        Ok(match self {
146            Self::NonNullable | Self::AllValid => true,
147            Self::AllInvalid => false,
148            Self::Array(a) => a
149                .scalar_at(index)
150                .vortex_expect("Validity array must support scalar_at")
151                .as_bool()
152                .value()
153                .vortex_expect("Validity must be non-nullable"),
154        })
155    }
156
157    #[inline]
158    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
159        Ok(!self.is_valid(index)?)
160    }
161
162    #[inline]
163    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
164        match self {
165            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
166            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
167        }
168    }
169
170    pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
171        match self {
172            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
173                AllOr::All => {
174                    if indices.dtype().is_nullable() {
175                        Ok(Self::AllValid)
176                    } else {
177                        Ok(Self::NonNullable)
178                    }
179                }
180                AllOr::None => Ok(Self::AllInvalid),
181                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
182            },
183            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
184                AllOr::All => Ok(Self::AllValid),
185                AllOr::None => Ok(Self::AllInvalid),
186                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
187            },
188            Self::AllInvalid => Ok(Self::AllInvalid),
189            Self::Array(is_valid) => {
190                let maybe_is_valid = is_valid.take(indices.to_array())?;
191                // Null indices invalidate that position.
192                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
193                Ok(Self::Array(is_valid))
194            }
195        }
196    }
197
198    // Invert the validity
199    pub fn not(&self) -> VortexResult<Self> {
200        match self {
201            Validity::NonNullable => Ok(Validity::NonNullable),
202            Validity::AllValid => Ok(Validity::AllInvalid),
203            Validity::AllInvalid => Ok(Validity::AllValid),
204            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
205        }
206    }
207
208    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
209    /// the mask is true.
210    ///
211    /// The result has length equal to the number of true values in mask.
212    ///
213    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
214    /// instead of eagerly filtering the values immediately.
215    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
216        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
217        //  if we happen to be NonNullable, AllValid, or AllInvalid.
218        match self {
219            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
220                Ok(v.clone())
221            }
222            Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
223        }
224    }
225
226    #[inline]
227    pub fn to_mask(&self, length: usize) -> Mask {
228        match self {
229            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
230            Self::AllInvalid => Mask::AllFalse(length),
231            Self::Array(is_valid) => {
232                assert_eq!(
233                    is_valid.len(),
234                    length,
235                    "Validity::Array length must equal to_logical's argument: {}, {}.",
236                    is_valid.len(),
237                    length,
238                );
239                is_valid.to_bool().to_mask()
240            }
241        }
242    }
243
244    /// Logically & two Validity values of the same length
245    #[inline]
246    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
247        Ok(match (self, rhs) {
248            // Should be pretty clear
249            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
250            // Any `AllInvalid` makes the output all invalid values
251            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
252            // All truthy values on one side, which makes no effect on an `Array` variant
253            (Validity::Array(a), Validity::AllValid)
254            | (Validity::Array(a), Validity::NonNullable)
255            | (Validity::NonNullable, Validity::Array(a))
256            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
257            // Both sides are all valid
258            (Validity::NonNullable, Validity::AllValid)
259            | (Validity::AllValid, Validity::NonNullable)
260            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
261            // Here we actually have to do some work
262            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
263                Binary
264                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
265                    .optimize()?,
266            ),
267        })
268    }
269
270    pub fn patch(
271        self,
272        len: usize,
273        indices_offset: usize,
274        indices: &ArrayRef,
275        patches: &Validity,
276        ctx: &mut ExecutionCtx,
277    ) -> VortexResult<Self> {
278        match (&self, patches) {
279            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
280            (Validity::NonNullable, _) => {
281                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
282            }
283            (_, Validity::NonNullable) => {
284                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
285            }
286            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
287            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
288            _ => {}
289        };
290
291        let own_nullability = if self == Validity::NonNullable {
292            Nullability::NonNullable
293        } else {
294            Nullability::Nullable
295        };
296
297        let source = match self {
298            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
299            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
300            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
301            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
302        };
303
304        let patch_values = match patches {
305            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
306            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
307            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
308            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
309        };
310
311        let patches = Patches::new(
312            len,
313            indices_offset,
314            indices.to_array(),
315            patch_values.into_array(),
316            // TODO(0ax1): chunk offsets
317            None,
318        )?;
319
320        Ok(Self::from_array(
321            source.patch(&patches, ctx)?.into_array(),
322            own_nullability,
323        ))
324    }
325
326    /// Convert into a nullable variant
327    #[inline]
328    pub fn into_nullable(self) -> Validity {
329        match self {
330            Self::NonNullable => Self::AllValid,
331            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
332        }
333    }
334
335    /// Convert into a non-nullable variant
336    #[inline]
337    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
338        match self {
339            _ if len == 0 => Some(Validity::NonNullable),
340            Self::NonNullable => Some(Self::NonNullable),
341            Self::AllValid => Some(Self::NonNullable),
342            Self::AllInvalid => None,
343            Self::Array(is_valid) => {
344                is_valid
345                    .statistics()
346                    .compute_min::<bool>()
347                    .vortex_expect("validity array must support min")
348                    .then(|| {
349                        // min true => all true
350                        Self::NonNullable
351                    })
352            }
353        }
354    }
355
356    /// Convert into a variant compatible with the given nullability, if possible.
357    #[inline]
358    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
359        match nullability {
360            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
361                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
362            }),
363            Nullability::Nullable => Ok(self.into_nullable()),
364        }
365    }
366
367    /// Create Validity by copying the given array's validity.
368    #[inline]
369    pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
370        Ok(Validity::from_mask(
371            array.validity_mask()?,
372            array.dtype().nullability(),
373        ))
374    }
375
376    /// Create Validity from boolean array with given nullability of the array.
377    ///
378    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
379    ///     as that is always nonnullable
380    #[inline]
381    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
382        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
383            vortex_panic!("Expected a non-nullable boolean array")
384        }
385        match nullability {
386            Nullability::NonNullable => Self::NonNullable,
387            Nullability::Nullable => Self::Array(value),
388        }
389    }
390
391    /// Returns the length of the validity array, if it exists.
392    #[inline]
393    pub fn maybe_len(&self) -> Option<usize> {
394        match self {
395            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
396            Self::Array(a) => Some(a.len()),
397        }
398    }
399
400    #[inline]
401    pub fn uncompressed_size(&self) -> usize {
402        if let Validity::Array(a) = self {
403            a.len().div_ceil(8)
404        } else {
405            0
406        }
407    }
408}
409
410impl PartialEq for Validity {
411    #[inline]
412    fn eq(&self, other: &Self) -> bool {
413        match (self, other) {
414            (Self::NonNullable, Self::NonNullable) => true,
415            (Self::AllValid, Self::AllValid) => true,
416            (Self::AllInvalid, Self::AllInvalid) => true,
417            (Self::Array(a), Self::Array(b)) => {
418                let a = a.to_bool();
419                let b = b.to_bool();
420                a.to_bit_buffer() == b.to_bit_buffer()
421            }
422            _ => false,
423        }
424    }
425}
426
427impl From<BitBuffer> for Validity {
428    #[inline]
429    fn from(value: BitBuffer) -> Self {
430        let true_count = value.true_count();
431        if true_count == value.len() {
432            Self::AllValid
433        } else if true_count == 0 {
434            Self::AllInvalid
435        } else {
436            Self::Array(BoolArray::from(value).into_array())
437        }
438    }
439}
440
441impl FromIterator<Mask> for Validity {
442    #[inline]
443    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
444        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
445    }
446}
447
448impl FromIterator<bool> for Validity {
449    #[inline]
450    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
451        Validity::from(BitBuffer::from_iter(iter))
452    }
453}
454
455impl From<Nullability> for Validity {
456    #[inline]
457    fn from(value: Nullability) -> Self {
458        Validity::from(&value)
459    }
460}
461
462impl From<&Nullability> for Validity {
463    #[inline]
464    fn from(value: &Nullability) -> Self {
465        match *value {
466            Nullability::NonNullable => Validity::NonNullable,
467            Nullability::Nullable => Validity::AllValid,
468        }
469    }
470}
471
472impl Validity {
473    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
474        if buffer.true_count() == buffer.len() {
475            nullability.into()
476        } else if buffer.true_count() == 0 {
477            Validity::AllInvalid
478        } else {
479            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
480        }
481    }
482
483    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
484        assert!(
485            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
486            "NonNullable validity must be AllValid",
487        );
488        match mask {
489            Mask::AllTrue(_) => match nullability {
490                Nullability::NonNullable => Validity::NonNullable,
491                Nullability::Nullable => Validity::AllValid,
492            },
493            Mask::AllFalse(_) => Validity::AllInvalid,
494            Mask::Values(values) => Validity::Array(values.into_array()),
495        }
496    }
497}
498
499impl IntoArray for Mask {
500    #[inline]
501    fn into_array(self) -> ArrayRef {
502        match self {
503            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
504            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
505            Self::Values(a) => a.into_array(),
506        }
507    }
508}
509
510impl IntoArray for &MaskValues {
511    #[inline]
512    fn into_array(self) -> ArrayRef {
513        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
514    }
515}
516
517#[cfg(test)]
518mod tests {
519    use rstest::rstest;
520    use vortex_buffer::Buffer;
521    use vortex_buffer::buffer;
522    use vortex_mask::Mask;
523
524    use crate::ArrayRef;
525    use crate::IntoArray;
526    use crate::LEGACY_SESSION;
527    use crate::VortexSessionExecute;
528    use crate::arrays::BoolArray;
529    use crate::arrays::PrimitiveArray;
530    use crate::dtype::Nullability;
531    use crate::validity::Validity;
532
533    #[rstest]
534    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
535    #[case(
536        Validity::AllValid,
537        5,
538        &[2, 4],
539        Validity::AllInvalid,
540        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
541    )]
542    #[case(
543        Validity::AllValid,
544        5,
545        &[2, 4],
546        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
547        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
548    )]
549    #[case(
550        Validity::AllInvalid,
551        5,
552        &[2, 4],
553        Validity::AllValid,
554        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
555    )]
556    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
557    #[case(
558        Validity::AllInvalid,
559        5,
560        &[2, 4],
561        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
562        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
563    )]
564    #[case(
565        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
566        5,
567        &[2, 4],
568        Validity::AllValid,
569        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
570    )]
571    #[case(
572        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
573        5,
574        &[2, 4],
575        Validity::AllInvalid,
576        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
577    )]
578    #[case(
579        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
580        5,
581        &[2, 4],
582        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
583        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
584    )]
585
586    fn patch_validity(
587        #[case] validity: Validity,
588        #[case] len: usize,
589        #[case] positions: &[u64],
590        #[case] patches: Validity,
591        #[case] expected: Validity,
592    ) {
593        let indices =
594            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
595        assert_eq!(
596            validity
597                .patch(
598                    len,
599                    0,
600                    &indices,
601                    &patches,
602                    &mut LEGACY_SESSION.create_execution_ctx()
603                )
604                .unwrap(),
605            expected
606        );
607    }
608
/// Patching index 4 of a length-2 validity must fail.
///
/// The source is nullable ([`Validity::AllValid`]) on purpose: a non-nullable source patched
/// with [`Validity::AllInvalid`] is rejected for the nullability mismatch before the indices
/// are ever examined, so the out-of-bounds index would not be exercised at all.
#[test]
#[should_panic]
fn out_of_bounds_patch() {
    Validity::AllValid
        .patch(
            2,
            0,
            // Same index type as the `patch_validity` cases.
            &buffer![4u64].into_array(),
            &Validity::AllInvalid,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )
        .unwrap();
}
622
/// An all-false mask cannot describe a non-nullable array, so `from_mask` must panic.
///
/// The expected message is pinned so that an unrelated panic cannot make the test pass.
#[test]
#[should_panic(expected = "NonNullable validity must be AllValid")]
fn into_validity_nullable() {
    Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
}
628
/// A mask with mixed values cannot describe a non-nullable array, so `from_mask` must panic.
///
/// The expected message is pinned so that an unrelated panic cannot make the test pass.
#[test]
#[should_panic(expected = "NonNullable validity must be AllValid")]
fn into_validity_nullable_array() {
    Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
}
634
635    #[rstest]
636    #[case(
637        Validity::AllValid,
638        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
639        Validity::from_iter(vec![true, false])
640    )]
641    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
642    #[case(
643        Validity::AllValid,
644        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
645        Validity::AllInvalid
646    )]
647    #[case(
648        Validity::NonNullable,
649        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
650        Validity::from_iter(vec![true, false])
651    )]
652    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
653    #[case(
654        Validity::NonNullable,
655        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
656        Validity::AllInvalid
657    )]
658    fn validity_take(
659        #[case] validity: Validity,
660        #[case] indices: ArrayRef,
661        #[case] expected: Validity,
662    ) {
663        assert_eq!(validity.take(&indices).unwrap(), expected);
664    }
665}