Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::ToCanonical;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::bool::BoolArrayExt;
27use crate::arrays::scalar_fn::ScalarFnFactoryExt;
28use crate::builtins::ArrayBuiltins;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::optimizer::ArrayOptimizer;
32use crate::patches::Patches;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::Operator;
36
37/// Validity information for an array
38#[derive(Clone)]
39pub enum Validity {
40    /// Items *can't* be null
41    NonNullable,
42    /// All items are valid
43    AllValid,
44    /// All items are null
45    AllInvalid,
46    /// The validity of each position in the array is determined by a boolean array.
47    ///
48    /// True values are valid, false values are invalid ("null").
49    Array(ArrayRef),
50}
51
52impl Debug for Validity {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            Self::NonNullable => write!(f, "NonNullable"),
56            Self::AllValid => write!(f, "AllValid"),
57            Self::AllInvalid => write!(f, "AllInvalid"),
58            Self::Array(arr) => write!(f, "SomeValid({})", arr.display_values()),
59        }
60    }
61}
62
63impl Validity {
64    /// Make a step towards canonicalising validity if necessary
65    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
66        match self {
67            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
68            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
69        }
70    }
71}
72
73impl Validity {
74    /// The [`DType`] of the underlying validity array (if it exists).
75    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
76
77    /// Convert the validity to an array representation.
78    pub fn to_array(&self, len: usize) -> ArrayRef {
79        match self {
80            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
81            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
82            Self::Array(a) => a.clone(),
83        }
84    }
85
86    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
87    #[inline]
88    pub fn into_array(self) -> Option<ArrayRef> {
89        if let Self::Array(a) = self {
90            Some(a)
91        } else {
92            None
93        }
94    }
95
96    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
97    #[inline]
98    pub fn as_array(&self) -> Option<&ArrayRef> {
99        if let Self::Array(a) = self {
100            Some(a)
101        } else {
102            None
103        }
104    }
105
106    #[inline]
107    pub fn nullability(&self) -> Nullability {
108        if matches!(self, Self::NonNullable) {
109            Nullability::NonNullable
110        } else {
111            Nullability::Nullable
112        }
113    }
114
115    /// The union nullability and validity.
116    #[inline]
117    pub fn union_nullability(self, nullability: Nullability) -> Self {
118        match nullability {
119            Nullability::NonNullable => self,
120            Nullability::Nullable => self.into_nullable(),
121        }
122    }
123
124    /// Returns whether the `index` item is valid.
125    #[inline]
126    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
127        Ok(match self {
128            Self::NonNullable | Self::AllValid => true,
129            Self::AllInvalid => false,
130            Self::Array(a) => a
131                .scalar_at(index)
132                .vortex_expect("Validity array must support scalar_at")
133                .as_bool()
134                .value()
135                .vortex_expect("Validity must be non-nullable"),
136        })
137    }
138
139    #[inline]
140    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
141        Ok(!self.is_valid(index)?)
142    }
143
144    #[inline]
145    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
146        match self {
147            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
148            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
149        }
150    }
151
152    pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
153        match self {
154            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
155                AllOr::All => {
156                    if indices.dtype().is_nullable() {
157                        Ok(Self::AllValid)
158                    } else {
159                        Ok(Self::NonNullable)
160                    }
161                }
162                AllOr::None => Ok(Self::AllInvalid),
163                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
164            },
165            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
166                AllOr::All => Ok(Self::AllValid),
167                AllOr::None => Ok(Self::AllInvalid),
168                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
169            },
170            Self::AllInvalid => Ok(Self::AllInvalid),
171            Self::Array(is_valid) => {
172                let maybe_is_valid = is_valid.take(indices.clone())?;
173                // Null indices invalidate that position.
174                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
175                Ok(Self::Array(is_valid))
176            }
177        }
178    }
179
180    // Invert the validity
181    pub fn not(&self) -> VortexResult<Self> {
182        match self {
183            Validity::NonNullable => Ok(Validity::NonNullable),
184            Validity::AllValid => Ok(Validity::AllInvalid),
185            Validity::AllInvalid => Ok(Validity::AllValid),
186            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
187        }
188    }
189
190    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
191    /// the mask is true.
192    ///
193    /// The result has length equal to the number of true values in mask.
194    ///
195    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
196    /// instead of eagerly filtering the values immediately.
197    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
198        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
199        //  if we happen to be NonNullable, AllValid, or AllInvalid.
200        match self {
201            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
202                Ok(v.clone())
203            }
204            Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
205        }
206    }
207
208    /// Converts this validity into a [`Mask`] of the given length.
209    ///
210    /// Valid elements are `true` and invalid elements are `false`.
211    pub fn to_mask(&self, length: usize) -> Mask {
212        match self {
213            Self::NonNullable | Self::AllValid => Mask::new_true(length),
214            Self::AllInvalid => Mask::new_false(length),
215            Self::Array(a) => a.to_bool().to_mask(),
216        }
217    }
218
219    pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
220        match self {
221            Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
222            Self::AllInvalid => Ok(Mask::AllFalse(length)),
223            Self::Array(arr) => {
224                assert_eq!(
225                    arr.len(),
226                    length,
227                    "Validity::Array length must equal to_logical's argument: {}, {}.",
228                    arr.len(),
229                    length,
230                );
231                // TODO(ngates): I'm not sure execution should take arrays by ownership.
232                //  If so we should fix call sites to clone and this function takes self.
233                arr.clone().execute::<Mask>(ctx)
234            }
235        }
236    }
237
238    /// Compare two Validity values of the same length by executing them into masks if necessary.
239    pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
240        match (self, other) {
241            (Validity::NonNullable, Validity::NonNullable) => Ok(true),
242            (Validity::AllValid, Validity::AllValid) => Ok(true),
243            (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
244            (Validity::Array(a), Validity::Array(b)) => {
245                let a = a.clone().execute::<Mask>(ctx)?;
246                let b = b.clone().execute::<Mask>(ctx)?;
247                Ok(a == b)
248            }
249            _ => Ok(false),
250        }
251    }
252
253    /// Logically & two Validity values of the same length
254    #[inline]
255    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
256        Ok(match (self, rhs) {
257            // Should be pretty clear
258            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
259            // Any `AllInvalid` makes the output all invalid values
260            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
261            // All truthy values on one side, which makes no effect on an `Array` variant
262            (Validity::Array(a), Validity::AllValid)
263            | (Validity::Array(a), Validity::NonNullable)
264            | (Validity::NonNullable, Validity::Array(a))
265            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
266            // Both sides are all valid
267            (Validity::NonNullable, Validity::AllValid)
268            | (Validity::AllValid, Validity::NonNullable)
269            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
270            // Here we actually have to do some work
271            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
272                Binary
273                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
274                    .optimize()?,
275            ),
276        })
277    }
278
279    pub fn patch(
280        self,
281        len: usize,
282        indices_offset: usize,
283        indices: &ArrayRef,
284        patches: &Validity,
285        ctx: &mut ExecutionCtx,
286    ) -> VortexResult<Self> {
287        match (&self, patches) {
288            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
289            (Validity::NonNullable, _) => {
290                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
291            }
292            (_, Validity::NonNullable) => {
293                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
294            }
295            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
296            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
297            _ => {}
298        };
299
300        let own_nullability = if matches!(self, Validity::NonNullable) {
301            Nullability::NonNullable
302        } else {
303            Nullability::Nullable
304        };
305
306        let source = match self {
307            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
308            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
309            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
310            Validity::Array(a) => a.execute::<BoolArray>(ctx)?,
311        };
312
313        let patch_values = match patches {
314            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
315            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
316            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
317            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
318        };
319
320        let patches = Patches::new(
321            len,
322            indices_offset,
323            indices.clone(),
324            patch_values.into_array(),
325            // TODO(0ax1): chunk offsets
326            None,
327        )?;
328
329        Ok(Self::from_array(
330            source.patch(&patches, ctx)?.into_array(),
331            own_nullability,
332        ))
333    }
334
335    /// Convert into a nullable variant
336    #[inline]
337    pub fn into_nullable(self) -> Validity {
338        match self {
339            Self::NonNullable => Self::AllValid,
340            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
341        }
342    }
343
344    /// Convert into a non-nullable variant
345    #[inline]
346    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
347        match self {
348            _ if len == 0 => Some(Validity::NonNullable),
349            Self::NonNullable => Some(Self::NonNullable),
350            Self::AllValid => Some(Self::NonNullable),
351            Self::AllInvalid => None,
352            Self::Array(is_valid) => {
353                is_valid
354                    .statistics()
355                    .compute_min::<bool>()
356                    .vortex_expect("validity array must support min")
357                    .then(|| {
358                        // min true => all true
359                        Self::NonNullable
360                    })
361            }
362        }
363    }
364
365    /// Convert into a variant compatible with the given nullability, if possible.
366    #[inline]
367    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
368        match nullability {
369            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
370                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
371            }),
372            Nullability::Nullable => Ok(self.into_nullable()),
373        }
374    }
375
376    /// Create Validity by copying the given array's validity.
377    #[inline]
378    pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
379        Ok(Validity::from_mask(
380            array.validity_mask()?,
381            array.dtype().nullability(),
382        ))
383    }
384
385    /// Create Validity from boolean array with given nullability of the array.
386    ///
387    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
388    ///     as that is always nonnullable
389    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
390        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
391            vortex_panic!("Expected a non-nullable boolean array")
392        }
393        match nullability {
394            Nullability::NonNullable => Self::NonNullable,
395            Nullability::Nullable => Self::Array(value),
396        }
397    }
398
399    /// Returns the length of the validity array, if it exists.
400    #[inline]
401    pub fn maybe_len(&self) -> Option<usize> {
402        match self {
403            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
404            Self::Array(a) => Some(a.len()),
405        }
406    }
407
408    #[inline]
409    pub fn uncompressed_size(&self) -> usize {
410        if let Validity::Array(a) = self {
411            a.len().div_ceil(8)
412        } else {
413            0
414        }
415    }
416}
417
418impl From<BitBuffer> for Validity {
419    #[inline]
420    fn from(value: BitBuffer) -> Self {
421        let true_count = value.true_count();
422        if true_count == value.len() {
423            Self::AllValid
424        } else if true_count == 0 {
425            Self::AllInvalid
426        } else {
427            Self::Array(BoolArray::from(value).into_array())
428        }
429    }
430}
431
432impl FromIterator<Mask> for Validity {
433    #[inline]
434    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
435        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
436    }
437}
438
439impl FromIterator<bool> for Validity {
440    #[inline]
441    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
442        Validity::from(BitBuffer::from_iter(iter))
443    }
444}
445
446impl From<Nullability> for Validity {
447    #[inline]
448    fn from(value: Nullability) -> Self {
449        Validity::from(&value)
450    }
451}
452
453impl From<&Nullability> for Validity {
454    #[inline]
455    fn from(value: &Nullability) -> Self {
456        match *value {
457            Nullability::NonNullable => Validity::NonNullable,
458            Nullability::Nullable => Validity::AllValid,
459        }
460    }
461}
462
463impl Validity {
464    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
465        if buffer.true_count() == buffer.len() {
466            nullability.into()
467        } else if buffer.true_count() == 0 {
468            Validity::AllInvalid
469        } else {
470            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
471        }
472    }
473
474    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
475        assert!(
476            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
477            "NonNullable validity must be AllValid",
478        );
479        match mask {
480            Mask::AllTrue(_) => match nullability {
481                Nullability::NonNullable => Validity::NonNullable,
482                Nullability::Nullable => Validity::AllValid,
483            },
484            Mask::AllFalse(_) => Validity::AllInvalid,
485            Mask::Values(values) => Validity::Array(values.into_array()),
486        }
487    }
488}
489
490impl IntoArray for Mask {
491    #[inline]
492    fn into_array(self) -> ArrayRef {
493        match self {
494            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
495            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
496            Self::Values(a) => a.into_array(),
497        }
498    }
499}
500
501impl IntoArray for &MaskValues {
502    #[inline]
503    fn into_array(self) -> ArrayRef {
504        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
505    }
506}
507
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        // A single execution context serves both the patch and the comparison.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let patched = validity
            .patch(len, 0, &indices, &patches, &mut ctx)
            .unwrap();

        assert!(patched.mask_eq(&expected, &mut ctx).unwrap());
    }

    /// Patching a non-nullable validity with a nullable one is rejected by the
    /// nullability guard in `Validity::patch`.
    #[test]
    #[should_panic]
    fn patch_non_nullable_with_nullable() {
        Validity::NonNullable
            .patch(
                2,
                0,
                &buffer![4].into_array(),
                &Validity::AllInvalid,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
    }

    /// A patch index past the end of the validity must not silently succeed. The source is
    /// nullable and the index unsigned, so the nullability and index-type checks are not what
    /// trips here.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::AllValid
            .patch(
                2,
                0,
                &buffer![4u64].into_array(),
                &Validity::AllInvalid,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let taken = validity.take(&indices).unwrap();
        assert!(taken.mask_eq(&expected, &mut ctx).unwrap());
    }
}