Skip to main content

vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::scalar_fn::ScalarFnArrayExt;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::Nullability;
30use crate::optimizer::ArrayOptimizer;
31use crate::patches::Patches;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::Operator;
35
36/// Validity information for an array
37#[derive(Clone)]
38pub enum Validity {
39    /// Items *can't* be null
40    NonNullable,
41    /// All items are valid
42    AllValid,
43    /// All items are null
44    AllInvalid,
45    /// The validity of each position in the array is determined by a boolean array.
46    ///
47    /// True values are valid, false values are invalid ("null").
48    Array(ArrayRef),
49}
50
51impl Debug for Validity {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            Self::NonNullable => write!(f, "NonNullable"),
55            Self::AllValid => write!(f, "AllValid"),
56            Self::AllInvalid => write!(f, "AllInvalid"),
57            Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
58        }
59    }
60}
61
62impl Validity {
63    /// Make a step towards canonicalising validity if necessary
64    pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
65        match self {
66            v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
67            Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
68        }
69    }
70}
71
72impl Validity {
73    /// The [`DType`] of the underlying validity array (if it exists).
74    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
75
76    /// Convert the validity to an array representation.
77    pub fn to_array(&self, len: usize) -> ArrayRef {
78        match self {
79            Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
80            Self::AllInvalid => ConstantArray::new(false, len).into_array(),
81            Self::Array(a) => a.clone(),
82        }
83    }
84
85    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
86    #[inline]
87    pub fn into_array(self) -> Option<ArrayRef> {
88        if let Self::Array(a) = self {
89            Some(a)
90        } else {
91            None
92        }
93    }
94
95    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
96    #[inline]
97    pub fn as_array(&self) -> Option<&ArrayRef> {
98        if let Self::Array(a) = self {
99            Some(a)
100        } else {
101            None
102        }
103    }
104
105    #[inline]
106    pub fn nullability(&self) -> Nullability {
107        if matches!(self, Self::NonNullable) {
108            Nullability::NonNullable
109        } else {
110            Nullability::Nullable
111        }
112    }
113
114    /// The union nullability and validity.
115    #[inline]
116    pub fn union_nullability(self, nullability: Nullability) -> Self {
117        match nullability {
118            Nullability::NonNullable => self,
119            Nullability::Nullable => self.into_nullable(),
120        }
121    }
122
123    /// Returns whether the `index` item is valid.
124    #[inline]
125    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
126        Ok(match self {
127            Self::NonNullable | Self::AllValid => true,
128            Self::AllInvalid => false,
129            Self::Array(a) => a
130                .scalar_at(index)
131                .vortex_expect("Validity array must support scalar_at")
132                .as_bool()
133                .value()
134                .vortex_expect("Validity must be non-nullable"),
135        })
136    }
137
138    #[inline]
139    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
140        Ok(!self.is_valid(index)?)
141    }
142
143    #[inline]
144    pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
145        match self {
146            Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
147            Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
148        }
149    }
150
151    pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
152        match self {
153            Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
154                AllOr::All => {
155                    if indices.dtype().is_nullable() {
156                        Ok(Self::AllValid)
157                    } else {
158                        Ok(Self::NonNullable)
159                    }
160                }
161                AllOr::None => Ok(Self::AllInvalid),
162                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163            },
164            Self::AllValid => match indices.validity_mask()?.bit_buffer() {
165                AllOr::All => Ok(Self::AllValid),
166                AllOr::None => Ok(Self::AllInvalid),
167                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
168            },
169            Self::AllInvalid => Ok(Self::AllInvalid),
170            Self::Array(is_valid) => {
171                let maybe_is_valid = is_valid.take(indices.to_array())?;
172                // Null indices invalidate that position.
173                let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
174                Ok(Self::Array(is_valid))
175            }
176        }
177    }
178
179    // Invert the validity
180    pub fn not(&self) -> VortexResult<Self> {
181        match self {
182            Validity::NonNullable => Ok(Validity::NonNullable),
183            Validity::AllValid => Ok(Validity::AllInvalid),
184            Validity::AllInvalid => Ok(Validity::AllValid),
185            Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
186        }
187    }
188
189    /// Lazily filters a [`Validity`] with a selection mask, which keeps only the entries for which
190    /// the mask is true.
191    ///
192    /// The result has length equal to the number of true values in mask.
193    ///
194    /// If the validity is a [`Validity::Array`], then this lazily wraps it in a `FilterArray`
195    /// instead of eagerly filtering the values immediately.
196    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
197        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
198        //  if we happen to be NonNullable, AllValid, or AllInvalid.
199        match self {
200            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
201                Ok(v.clone())
202            }
203            Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
204        }
205    }
206
207    pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
208        match self {
209            Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
210            Self::AllInvalid => Ok(Mask::AllFalse(length)),
211            Self::Array(arr) => {
212                assert_eq!(
213                    arr.len(),
214                    length,
215                    "Validity::Array length must equal to_logical's argument: {}, {}.",
216                    arr.len(),
217                    length,
218                );
219                // TODO(ngates): I'm not sure execution should take arrays by ownership.
220                //  If so we should fix call sites to clone and this function takes self.
221                arr.clone().execute::<Mask>(ctx)
222            }
223        }
224    }
225
226    /// Compare two Validity values of the same length by executing them into masks if necessary.
227    pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
228        match (self, other) {
229            (Validity::NonNullable, Validity::NonNullable) => Ok(true),
230            (Validity::AllValid, Validity::AllValid) => Ok(true),
231            (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
232            (Validity::Array(a), Validity::Array(b)) => {
233                let a = a.clone().execute::<Mask>(ctx)?;
234                let b = b.clone().execute::<Mask>(ctx)?;
235                Ok(a == b)
236            }
237            _ => Ok(false),
238        }
239    }
240
241    /// Logically & two Validity values of the same length
242    #[inline]
243    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
244        Ok(match (self, rhs) {
245            // Should be pretty clear
246            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
247            // Any `AllInvalid` makes the output all invalid values
248            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
249            // All truthy values on one side, which makes no effect on an `Array` variant
250            (Validity::Array(a), Validity::AllValid)
251            | (Validity::Array(a), Validity::NonNullable)
252            | (Validity::NonNullable, Validity::Array(a))
253            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
254            // Both sides are all valid
255            (Validity::NonNullable, Validity::AllValid)
256            | (Validity::AllValid, Validity::NonNullable)
257            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
258            // Here we actually have to do some work
259            (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
260                Binary
261                    .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
262                    .optimize()?,
263            ),
264        })
265    }
266
267    pub fn patch(
268        self,
269        len: usize,
270        indices_offset: usize,
271        indices: &ArrayRef,
272        patches: &Validity,
273        ctx: &mut ExecutionCtx,
274    ) -> VortexResult<Self> {
275        match (&self, patches) {
276            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
277            (Validity::NonNullable, _) => {
278                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
279            }
280            (_, Validity::NonNullable) => {
281                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
282            }
283            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
284            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
285            _ => {}
286        };
287
288        let own_nullability = if matches!(self, Validity::NonNullable) {
289            Nullability::NonNullable
290        } else {
291            Nullability::Nullable
292        };
293
294        let source = match self {
295            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
299        };
300
301        let patch_values = match patches {
302            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305            Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
306        };
307
308        let patches = Patches::new(
309            len,
310            indices_offset,
311            indices.to_array(),
312            patch_values.into_array(),
313            // TODO(0ax1): chunk offsets
314            None,
315        )?;
316
317        Ok(Self::from_array(
318            source.patch(&patches, ctx)?.into_array(),
319            own_nullability,
320        ))
321    }
322
323    /// Convert into a nullable variant
324    #[inline]
325    pub fn into_nullable(self) -> Validity {
326        match self {
327            Self::NonNullable => Self::AllValid,
328            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
329        }
330    }
331
332    /// Convert into a non-nullable variant
333    #[inline]
334    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
335        match self {
336            _ if len == 0 => Some(Validity::NonNullable),
337            Self::NonNullable => Some(Self::NonNullable),
338            Self::AllValid => Some(Self::NonNullable),
339            Self::AllInvalid => None,
340            Self::Array(is_valid) => {
341                is_valid
342                    .statistics()
343                    .compute_min::<bool>()
344                    .vortex_expect("validity array must support min")
345                    .then(|| {
346                        // min true => all true
347                        Self::NonNullable
348                    })
349            }
350        }
351    }
352
353    /// Convert into a variant compatible with the given nullability, if possible.
354    #[inline]
355    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
356        match nullability {
357            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
358                vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
359            }),
360            Nullability::Nullable => Ok(self.into_nullable()),
361        }
362    }
363
364    /// Create Validity by copying the given array's validity.
365    #[inline]
366    pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
367        Ok(Validity::from_mask(
368            array.validity_mask()?,
369            array.dtype().nullability(),
370        ))
371    }
372
373    /// Create Validity from boolean array with given nullability of the array.
374    ///
375    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
376    ///     as that is always nonnullable
377    #[inline]
378    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
379        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
380            vortex_panic!("Expected a non-nullable boolean array")
381        }
382        match nullability {
383            Nullability::NonNullable => Self::NonNullable,
384            Nullability::Nullable => Self::Array(value),
385        }
386    }
387
388    /// Returns the length of the validity array, if it exists.
389    #[inline]
390    pub fn maybe_len(&self) -> Option<usize> {
391        match self {
392            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
393            Self::Array(a) => Some(a.len()),
394        }
395    }
396
397    #[inline]
398    pub fn uncompressed_size(&self) -> usize {
399        if let Validity::Array(a) = self {
400            a.len().div_ceil(8)
401        } else {
402            0
403        }
404    }
405}
406
407impl From<BitBuffer> for Validity {
408    #[inline]
409    fn from(value: BitBuffer) -> Self {
410        let true_count = value.true_count();
411        if true_count == value.len() {
412            Self::AllValid
413        } else if true_count == 0 {
414            Self::AllInvalid
415        } else {
416            Self::Array(BoolArray::from(value).into_array())
417        }
418    }
419}
420
421impl FromIterator<Mask> for Validity {
422    #[inline]
423    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
424        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
425    }
426}
427
428impl FromIterator<bool> for Validity {
429    #[inline]
430    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
431        Validity::from(BitBuffer::from_iter(iter))
432    }
433}
434
435impl From<Nullability> for Validity {
436    #[inline]
437    fn from(value: Nullability) -> Self {
438        Validity::from(&value)
439    }
440}
441
442impl From<&Nullability> for Validity {
443    #[inline]
444    fn from(value: &Nullability) -> Self {
445        match *value {
446            Nullability::NonNullable => Validity::NonNullable,
447            Nullability::Nullable => Validity::AllValid,
448        }
449    }
450}
451
452impl Validity {
453    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
454        if buffer.true_count() == buffer.len() {
455            nullability.into()
456        } else if buffer.true_count() == 0 {
457            Validity::AllInvalid
458        } else {
459            Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
460        }
461    }
462
463    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
464        assert!(
465            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
466            "NonNullable validity must be AllValid",
467        );
468        match mask {
469            Mask::AllTrue(_) => match nullability {
470                Nullability::NonNullable => Validity::NonNullable,
471                Nullability::Nullable => Validity::AllValid,
472            },
473            Mask::AllFalse(_) => Validity::AllInvalid,
474            Mask::Values(values) => Validity::Array(values.into_array()),
475        }
476    }
477}
478
479impl IntoArray for Mask {
480    #[inline]
481    fn into_array(self) -> ArrayRef {
482        match self {
483            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
484            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
485            Self::Values(a) => a.into_array(),
486        }
487    }
488}
489
490impl IntoArray for &MaskValues {
491    #[inline]
492    fn into_array(self) -> ArrayRef {
493        BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
494    }
495}
496
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        // A single execution context serves both the patch and the comparison.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let patched = validity
            .patch(len, 0, &indices, &patches, &mut ctx)
            .unwrap();
        assert!(patched.mask_eq(&expected, &mut ctx).unwrap());
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(
                2,
                0,
                &buffer![4].into_array(),
                &Validity::AllInvalid,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(
            validity
                .take(&indices)
                .unwrap()
                .mask_eq(&expected, &mut ctx)
                .unwrap()
        );
    }

    #[test]
    fn is_valid_reads_validity_array() {
        let validity = Validity::Array(BoolArray::from_iter([true, false]).into_array());
        assert!(validity.is_valid(0).unwrap());
        assert!(!validity.is_valid(1).unwrap());
        assert!(validity.is_null(1).unwrap());
    }

    #[test]
    fn from_bit_buffer_collapses_uniform_buffers() {
        assert!(matches!(
            Validity::from_bit_buffer(BitBuffer::new_set(4), Nullability::NonNullable),
            Validity::NonNullable
        ));
        assert!(matches!(
            Validity::from_bit_buffer(BitBuffer::new_set(4), Nullability::Nullable),
            Validity::AllValid
        ));
        assert!(matches!(
            Validity::from_bit_buffer(BitBuffer::new_unset(4), Nullability::Nullable),
            Validity::AllInvalid
        ));
    }

    #[test]
    fn and_short_circuits_constant_variants() {
        let array = Validity::Array(BoolArray::from_iter([true, false]).into_array());
        assert!(matches!(
            array.and(Validity::AllInvalid).unwrap(),
            Validity::AllInvalid
        ));
        assert!(matches!(
            Validity::AllValid.and(Validity::NonNullable).unwrap(),
            Validity::AllValid
        ));
    }
}
652                .mask_eq(&expected, &mut ctx)
653                .unwrap()
654        );
655    }
656}