Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::scalar_fn::ScalarFnArrayExt;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::Nullability;
30use crate::optimizer::ArrayOptimizer;
31use crate::patches::Patches;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::Operator;
35
36/// Validity information for an array
37#[derive(Clone)]
38pub enum Validity {
39    /// Items *can't* be null
40    NonNullable,
41    /// All items are valid
42    AllValid,
43    /// All items are null
44    AllInvalid,
45    /// The validity of each position in the array is determined by a boolean array.
46    ///
47    /// True values are valid, false values are invalid ("null").
48    Array(ArrayRef),
49}
50
51impl Debug for Validity {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::NonNullable => write!(f, "NonNullable"),
55            Self::AllValid => write!(f, "AllValid"),
56            Self::AllInvalid => write!(f, "AllInvalid"),
57            Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
58        }
59    }
60}
61
62impl Validity {
63    /// Make a step towards canonicalising validity if necessary
64    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
65        match self {
66            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
67            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
68        }
69    }
70}
71
72impl Validity {
73    /// The [`DType`] of the underlying validity array (if it exists).
74    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
75
76    /// Convert the validity to an array representation.
77    pub fn to_array(&self, len: usize) -> ArrayRef {
78        match self {
79            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
80            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
81            Self::Array(a) => a.clone(),
82        }
83    }
84
85    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
86    #[inline]
87    pub fn into_array(self) -> Option<ArrayRef> {
88        if let Self::Array(a) = self {
89            Some(a)
90        } else {
91            None
92        }
93    }
94
95    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
96    #[inline]
97    pub fn as_array(&self) -> Option<&ArrayRef> {
98        if let Self::Array(a) = self {
99            Some(a)
100        } else {
101            None
102        }
103    }
104
105    #[inline]
106    pub fn nullability(&self) -> Nullability {
107        if matches!(self, Self::NonNullable) {
108            Nullability::NonNullable
109        } else {
110            Nullability::Nullable
111        }
112    }
113
114    /// The union nullability and validity.
115    #[inline]
116    pub fn union_nullability(self, nullability: Nullability) -> Self {
117        match nullability {
118            Nullability::NonNullable => self,
119            Nullability::Nullable => self.into_nullable(),
120        }
121    }
122
123    /// Returns whether the `index` item is valid.
124    #[inline]
125    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
126        Ok(match self {
127            Self::NonNullable | Self::AllValid => true,
128            Self::AllInvalid => false,
129            Self::Array(a) => a
130                .scalar_at(index)
131                .vortex_expect("Validity array must support scalar_at")
132                .as_bool()
133                .value()
134                .vortex_expect("Validity must be non-nullable"),
135        })
136    }
137
138    #[inline]
139    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
140        Ok(!self.is_valid(index)?)
141    }
142
143    #[inline]
144    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
145        match self {
146            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
147            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
148        }
149    }
150
151    pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
152        match self {
153            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
154                AllOr::All => {
155                    if indices.dtype().is_nullable() {
156                        Ok(Self::AllValid)
157                    } else {
158                        Ok(Self::NonNullable)
159                    }
160                }
161                AllOr::None => Ok(Self::AllInvalid),
162                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163            },
164            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
165                AllOr::All => Ok(Self::AllValid),
166                AllOr::None => Ok(Self::AllInvalid),
167                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
168            },
169            Self::AllInvalid => Ok(Self::AllInvalid),
170            Self::Array(is_valid) => {
171                let maybe_is_valid = is_valid.take(indices.to_array())?;
172                // Null indices invalidate that position.
173                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
174                Ok(Self::Array(is_valid))
175            }
176        }
177    }
178
179    // Invert the validity
180    pub fn not(&self) -> VortexResult<Self> {
181        match self {
182            Validity::NonNullable => Ok(Validity::NonNullable),
183            Validity::AllValid => Ok(Validity::AllInvalid),
184            Validity::AllInvalid => Ok(Validity::AllValid),
185            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
186        }
187    }
188
189    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
190    /// the mask is true.
191    ///
192    /// The result has length equal to the number of true values in mask.
193    ///
194    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
195    /// instead of eagerly filtering the values immediately.
196    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
197        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
198        //  if we happen to be NonNullable, AllValid, or AllInvalid.
199        match self {
200            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
201                Ok(v.clone())
202            }
203            Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
204        }
205    }
206
207    pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
208        match self {
209            Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
210            Self::AllInvalid => Ok(Mask::AllFalse(length)),
211            Self::Array(arr) => {
212                assert_eq!(
213                    arr.len(),
214                    length,
215                    "Validity::Array length must equal to_logical's argument: {}, {}.",
216                    arr.len(),
217                    length,
218                );
219                // TODO(ngates): I'm not sure execution should take arrays by ownership.
220                //  If so we should fix call sites to clone and this function takes self.
221                arr.clone().execute::<Mask>(ctx)
222            }
223        }
224    }
225
226    /// Compare two Validity values of the same length by executing them into masks if necessary.
227    pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
228        match (self, other) {
229            (Validity::NonNullable, Validity::NonNullable) => Ok(true),
230            (Validity::AllValid, Validity::AllValid) => Ok(true),
231            (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
232            (Validity::Array(a), Validity::Array(b)) => {
233                let a = a.clone().execute::<Mask>(ctx)?;
234                let b = b.clone().execute::<Mask>(ctx)?;
235                Ok(a == b)
236            }
237            _ => Ok(false),
238        }
239    }
240
241    /// Logically & two Validity values of the same length
242    #[inline]
243    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
244        Ok(match (self, rhs) {
245            // Should be pretty clear
246            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
247            // Any `AllInvalid` makes the output all invalid values
248            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
249            // All truthy values on one side, which makes no effect on an `Array` variant
250            (Validity::Array(a), Validity::AllValid)
251            | (Validity::Array(a), Validity::NonNullable)
252            | (Validity::NonNullable, Validity::Array(a))
253            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
254            // Both sides are all valid
255            (Validity::NonNullable, Validity::AllValid)
256            | (Validity::AllValid, Validity::NonNullable)
257            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
258            // Here we actually have to do some work
259            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
260                Binary
261                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
262                    .optimize()?,
263            ),
264        })
265    }
266
267    pub fn patch(
268        self,
269        len: usize,
270        indices_offset: usize,
271        indices: &ArrayRef,
272        patches: &Validity,
273        ctx: &mut ExecutionCtx,
274    ) -> VortexResult<Self> {
275        match (&self, patches) {
276            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
277            (Validity::NonNullable, _) => {
278                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
279            }
280            (_, Validity::NonNullable) => {
281                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
282            }
283            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
284            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
285            _ => {}
286        };
287
288        let own_nullability = if matches!(self, Validity::NonNullable) {
289            Nullability::NonNullable
290        } else {
291            Nullability::Nullable
292        };
293
294        let source = match self {
295            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
299        };
300
301        let patch_values = match patches {
302            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
306        };
307
308        let patches = Patches::new(
309            len,
310            indices_offset,
311            indices.to_array(),
312            patch_values.into_array(),
313            // TODO(0ax1): chunk offsets
314            None,
315        )?;
316
317        Ok(Self::from_array(
318            source.patch(&patches, ctx)?.into_array(),
319            own_nullability,
320        ))
321    }
322
323    /// Convert into a nullable variant
324    #[inline]
325    pub fn into_nullable(self) -> Validity {
326        match self {
327            Self::NonNullable => Self::AllValid,
328            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
329        }
330    }
331
332    /// Convert into a non-nullable variant
333    #[inline]
334    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
335        match self {
336            _ if len == 0 => Some(Validity::NonNullable),
337            Self::NonNullable => Some(Self::NonNullable),
338            Self::AllValid => Some(Self::NonNullable),
339            Self::AllInvalid => None,
340            Self::Array(is_valid) => {
341                is_valid
342                    .statistics()
343                    .compute_min::<bool>()
344                    .vortex_expect("validity array must support min")
345                    .then(|| {
346                        // min true => all true
347                        Self::NonNullable
348                    })
349            }
350        }
351    }
352
353    /// Convert into a variant compatible with the given nullability, if possible.
354    #[inline]
355    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
356        match nullability {
357            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
358                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
359            }),
360            Nullability::Nullable => Ok(self.into_nullable()),
361        }
362    }
363
364    /// Create Validity by copying the given array's validity.
365    #[inline]
366    pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
367        Ok(Validity::from_mask(
368            array.validity_mask()?,
369            array.dtype().nullability(),
370        ))
371    }
372
373    /// Create Validity from boolean array with given nullability of the array.
374    ///
375    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
376    ///     as that is always nonnullable
377    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
378        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
379            vortex_panic!("Expected a non-nullable boolean array")
380        }
381        match nullability {
382            Nullability::NonNullable => Self::NonNullable,
383            Nullability::Nullable => Self::Array(value),
384        }
385    }
386
387    /// Returns the length of the validity array, if it exists.
388    #[inline]
389    pub fn maybe_len(&self) -> Option<usize> {
390        match self {
391            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
392            Self::Array(a) => Some(a.len()),
393        }
394    }
395
396    #[inline]
397    pub fn uncompressed_size(&self) -> usize {
398        if let Validity::Array(a) = self {
399            a.len().div_ceil(8)
400        } else {
401            0
402        }
403    }
404}
405
406impl From<BitBuffer> for Validity {
407    #[inline]
408    fn from(value: BitBuffer) -> Self {
409        let true_count = value.true_count();
410        if true_count == value.len() {
411            Self::AllValid
412        } else if true_count == 0 {
413            Self::AllInvalid
414        } else {
415            Self::Array(BoolArray::from(value).into_array())
416        }
417    }
418}
419
420impl FromIterator<Mask> for Validity {
421    #[inline]
422    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
423        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
424    }
425}
426
427impl FromIterator<bool> for Validity {
428    #[inline]
429    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
430        Validity::from(BitBuffer::from_iter(iter))
431    }
432}
433
434impl From<Nullability> for Validity {
435    #[inline]
436    fn from(value: Nullability) -> Self {
437        Validity::from(&value)
438    }
439}
440
441impl From<&Nullability> for Validity {
442    #[inline]
443    fn from(value: &Nullability) -> Self {
444        match *value {
445            Nullability::NonNullable => Validity::NonNullable,
446            Nullability::Nullable => Validity::AllValid,
447        }
448    }
449}
450
451impl Validity {
452    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
453        if buffer.true_count() == buffer.len() {
454            nullability.into()
455        } else if buffer.true_count() == 0 {
456            Validity::AllInvalid
457        } else {
458            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
459        }
460    }
461
462    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
463        assert!(
464            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
465            "NonNullable validity must be AllValid",
466        );
467        match mask {
468            Mask::AllTrue(_) => match nullability {
469                Nullability::NonNullable => Validity::NonNullable,
470                Nullability::Nullable => Validity::AllValid,
471            },
472            Mask::AllFalse(_) => Validity::AllInvalid,
473            Mask::Values(values) => Validity::Array(values.into_array()),
474        }
475    }
476}
477
478impl IntoArray for Mask {
479    #[inline]
480    fn into_array(self) -> ArrayRef {
481        match self {
482            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
483            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
484            Self::Values(a) => a.into_array(),
485        }
486    }
487}
488
489impl IntoArray for &MaskValues {
490    #[inline]
491    fn into_array(self) -> ArrayRef {
492        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
493    }
494}
495
496#[cfg(test)]
497mod tests {
498    use rstest::rstest;
499    use vortex_buffer::Buffer;
500    use vortex_buffer::buffer;
501    use vortex_mask::Mask;
502
503    use crate::ArrayRef;
504    use crate::IntoArray;
505    use crate::LEGACY_SESSION;
506    use crate::VortexSessionExecute;
507    use crate::arrays::PrimitiveArray;
508    use crate::dtype::Nullability;
509    use crate::validity::BoolArray;
510    use crate::validity::Validity;
511
512    #[rstest]
513    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
514    #[case(
515        Validity::AllValid,
516        5,
517        &[2, 4],
518        Validity::AllInvalid,
519        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
520    )]
521    #[case(
522        Validity::AllValid,
523        5,
524        &[2, 4],
525        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
526        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
527    )]
528    #[case(
529        Validity::AllInvalid,
530        5,
531        &[2, 4],
532        Validity::AllValid,
533        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
534    )]
535    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
536    #[case(
537        Validity::AllInvalid,
538        5,
539        &[2, 4],
540        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
541        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
542    )]
543    #[case(
544        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
545        5,
546        &[2, 4],
547        Validity::AllValid,
548        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
549    )]
550    #[case(
551        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
552        5,
553        &[2, 4],
554        Validity::AllInvalid,
555        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
556    )]
557    #[case(
558        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
559        5,
560        &[2, 4],
561        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
562        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
563    )]
564
565    fn patch_validity(
566        #[case] validity: Validity,
567        #[case] len: usize,
568        #[case] positions: &[u64],
569        #[case] patches: Validity,
570        #[case] expected: Validity,
571    ) {
572        let indices =
573            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
574
575        let mut ctx = LEGACY_SESSION.create_execution_ctx();
576
577        assert!(
578            validity
579                .patch(
580                    len,
581                    0,
582                    &indices,
583                    &patches,
584                    &mut LEGACY_SESSION.create_execution_ctx(),
585                )
586                .unwrap()
587                .mask_eq(&expected, &mut ctx)
588                .unwrap()
589        );
590    }
591
592    #[test]
593    #[should_panic]
594    fn out_of_bounds_patch() {
595        Validity::NonNullable
596            .patch(
597                2,
598                0,
599                &buffer![4].into_array(),
600                &Validity::AllInvalid,
601                &mut LEGACY_SESSION.create_execution_ctx(),
602            )
603            .unwrap();
604    }
605
606    #[test]
607    #[should_panic]
608    fn into_validity_nullable() {
609        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
610    }
611
612    #[test]
613    #[should_panic]
614    fn into_validity_nullable_array() {
615        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
616    }
617
618    #[rstest]
619    #[case(
620        Validity::AllValid,
621        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
622        Validity::from_iter(vec![true, false])
623    )]
624    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
625    #[case(
626        Validity::AllValid,
627        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
628        Validity::AllInvalid
629    )]
630    #[case(
631        Validity::NonNullable,
632        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
633        Validity::from_iter(vec![true, false])
634    )]
635    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
636    #[case(
637        Validity::NonNullable,
638        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
639        Validity::AllInvalid
640    )]
641    fn validity_take(
642        #[case] validity: Validity,
643        #[case] indices: ArrayRef,
644        #[case] expected: Validity,
645    ) {
646        let mut ctx = LEGACY_SESSION.create_execution_ctx();
647        assert!(
648            validity
649                .take(&indices)
650                .unwrap()
651                .mask_eq(&expected, &mut ctx)
652                .unwrap()
653        );
654    }
655}