vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::Range;
9
10use vortex_buffer::BitBuffer;
11use vortex_dtype::DType;
12use vortex_dtype::Nullability;
13use vortex_error::VortexExpect as _;
14use vortex_error::VortexResult;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20use vortex_scalar::Scalar;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::ToCanonical;
26use crate::arrays::BoolArray;
27use crate::arrays::ConstantArray;
28use crate::compute::fill_null;
29use crate::compute::filter;
30use crate::compute::sum;
31use crate::compute::take;
32use crate::patches::Patches;
33
/// Validity information for an array
///
/// Describes which positions of an array hold a value and which are null. Only the
/// [`Validity::Array`] variant carries per-position data; the other variants describe every
/// position at once. [`Validity::NonNullable`] is the only variant that reports a
/// non-nullable [`Nullability`]; [`Validity::AllValid`] is nullable but currently has no nulls.
#[derive(Clone, Debug)]
pub enum Validity {
    /// Items *can't* be null
    NonNullable,
    /// All items are valid
    AllValid,
    /// All items are null
    AllInvalid,
    /// The validity of each position in the array is determined by a boolean array.
    ///
    /// True values are valid, false values are invalid ("null").
    Array(ArrayRef),
}
48
49impl Validity {
    /// The [`DType`] of the underlying validity array (if it exists).
    ///
    /// This is a non-nullable boolean type: the validity array itself never contains nulls.
    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
52
53    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
54    #[inline]
55    pub fn into_array(self) -> Option<ArrayRef> {
56        if let Self::Array(a) = self {
57            Some(a)
58        } else {
59            None
60        }
61    }
62
63    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
64    #[inline]
65    pub fn as_array(&self) -> Option<&ArrayRef> {
66        if let Self::Array(a) = self {
67            Some(a)
68        } else {
69            None
70        }
71    }
72
73    #[inline]
74    pub fn nullability(&self) -> Nullability {
75        if matches!(self, Self::NonNullable) {
76            Nullability::NonNullable
77        } else {
78            Nullability::Nullable
79        }
80    }
81
82    /// The union nullability and validity.
83    #[inline]
84    pub fn union_nullability(self, nullability: Nullability) -> Self {
85        match nullability {
86            Nullability::NonNullable => self,
87            Nullability::Nullable => self.into_nullable(),
88        }
89    }
90
91    #[inline]
92    pub fn all_valid(&self, len: usize) -> bool {
93        match self {
94            _ if len == 0 => true,
95            Validity::NonNullable | Validity::AllValid => true,
96            Validity::AllInvalid => false,
97            Validity::Array(array) => {
98                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
99                    .vortex_expect("sum must be a usize")
100                    == array.len()
101            }
102        }
103    }
104
105    #[inline]
106    pub fn all_invalid(&self, len: usize) -> bool {
107        match self {
108            _ if len == 0 => true,
109            Validity::NonNullable | Validity::AllValid => false,
110            Validity::AllInvalid => true,
111            Validity::Array(array) => {
112                usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
113                    .vortex_expect("sum must be a usize")
114                    == 0
115            }
116        }
117    }
118
119    /// Returns whether the `index` item is valid.
120    #[inline]
121    pub fn is_valid(&self, index: usize) -> bool {
122        match self {
123            Self::NonNullable | Self::AllValid => true,
124            Self::AllInvalid => false,
125            Self::Array(a) => {
126                let scalar = a.scalar_at(index);
127                scalar
128                    .as_bool()
129                    .value()
130                    .vortex_expect("Validity must be non-nullable")
131            }
132        }
133    }
134
135    #[inline]
136    pub fn is_null(&self, index: usize) -> bool {
137        !self.is_valid(index)
138    }
139
140    #[inline]
141    pub fn slice(&self, range: Range<usize>) -> Self {
142        match self {
143            Self::Array(a) => Self::Array(a.slice(range)),
144            Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
145        }
146    }
147
148    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
149        match self {
150            Self::NonNullable => match indices.validity_mask().bit_buffer() {
151                AllOr::All => {
152                    if indices.dtype().is_nullable() {
153                        Ok(Self::AllValid)
154                    } else {
155                        Ok(Self::NonNullable)
156                    }
157                }
158                AllOr::None => Ok(Self::AllInvalid),
159                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
160            },
161            Self::AllValid => match indices.validity_mask().bit_buffer() {
162                AllOr::All => Ok(Self::AllValid),
163                AllOr::None => Ok(Self::AllInvalid),
164                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
165            },
166            Self::AllInvalid => Ok(Self::AllInvalid),
167            Self::Array(is_valid) => {
168                let maybe_is_valid = take(is_valid, indices)?;
169                // Null indices invalidate that position.
170                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
171                Ok(Self::Array(is_valid))
172            }
173        }
174    }
175
176    /// Keep only the entries for which the mask is true.
177    ///
178    /// The result has length equal to the number of true values in mask.
179    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
180        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
181        //  if we happen to be NonNullable, AllValid, or AllInvalid.
182        match self {
183            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
184                Ok(v.clone())
185            }
186            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
187        }
188    }
189
190    /// Set to false any entries for which the mask is true.
191    ///
192    /// The result is always nullable. The result has the same length as self.
193    #[inline]
194    pub fn mask(&self, mask: &Mask) -> Self {
195        match mask.bit_buffer() {
196            AllOr::All => Validity::AllInvalid,
197            AllOr::None => self.clone().into_nullable(),
198            AllOr::Some(make_invalid) => match self {
199                Validity::NonNullable | Validity::AllValid => {
200                    Validity::Array(BoolArray::from(!make_invalid).into_array())
201                }
202                Validity::AllInvalid => Validity::AllInvalid,
203                Validity::Array(is_valid) => {
204                    let is_valid = is_valid.to_bool();
205                    Validity::from(is_valid.bit_buffer() & !make_invalid)
206                }
207            },
208        }
209    }
210
211    #[inline]
212    pub fn to_mask(&self, length: usize) -> Mask {
213        match self {
214            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
215            Self::AllInvalid => Mask::AllFalse(length),
216            Self::Array(is_valid) => {
217                assert_eq!(
218                    is_valid.len(),
219                    length,
220                    "Validity::Array length must equal to_logical's argument: {}, {}.",
221                    is_valid.len(),
222                    length,
223                );
224                is_valid.to_bool().to_mask()
225            }
226        }
227    }
228
229    /// Logically & two Validity values of the same length
230    #[inline]
231    pub fn and(self, rhs: Validity) -> Validity {
232        match (self, rhs) {
233            // Should be pretty clear
234            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
235            // Any `AllInvalid` makes the output all invalid values
236            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
237            // All truthy values on one side, which makes no effect on an `Array` variant
238            (Validity::Array(a), Validity::AllValid)
239            | (Validity::Array(a), Validity::NonNullable)
240            | (Validity::NonNullable, Validity::Array(a))
241            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
242            // Both sides are all valid
243            (Validity::NonNullable, Validity::AllValid)
244            | (Validity::AllValid, Validity::NonNullable)
245            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
246            // Here we actually have to do some work
247            (Validity::Array(lhs), Validity::Array(rhs)) => {
248                let lhs = lhs.to_bool();
249                let rhs = rhs.to_bool();
250
251                let lhs = lhs.bit_buffer();
252                let rhs = rhs.bit_buffer();
253
254                Validity::from(lhs.bitand(rhs))
255            }
256        }
257    }
258
259    pub fn patch(
260        self,
261        len: usize,
262        indices_offset: usize,
263        indices: &dyn Array,
264        patches: &Validity,
265    ) -> Self {
266        match (&self, patches) {
267            (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
268            (Validity::NonNullable, _) => {
269                vortex_panic!("Can't patch a non-nullable validity with nullable validity")
270            }
271            (_, Validity::NonNullable) => {
272                vortex_panic!("Can't patch a nullable validity with non-nullable validity")
273            }
274            (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
275            (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
276            _ => {}
277        };
278
279        let own_nullability = if self == Validity::NonNullable {
280            Nullability::NonNullable
281        } else {
282            Nullability::Nullable
283        };
284
285        let source = match self {
286            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
287            Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
288            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
289            Validity::Array(a) => a.to_bool(),
290        };
291
292        let patch_values = match patches {
293            Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
294            Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
295            Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
296            Validity::Array(a) => a.to_bool(),
297        };
298
299        let patches = Patches::new(
300            len,
301            indices_offset,
302            indices.to_array(),
303            patch_values.into_array(),
304            // TODO(0ax1): chunk offsets
305            None,
306        );
307
308        Self::from_array(source.patch(&patches).into_array(), own_nullability)
309    }
310
311    /// Convert into a nullable variant
312    #[inline]
313    pub fn into_nullable(self) -> Validity {
314        match self {
315            Self::NonNullable => Self::AllValid,
316            Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
317        }
318    }
319
320    /// Convert into a non-nullable variant
321    #[inline]
322    pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
323        match self {
324            _ if len == 0 => Some(Validity::NonNullable),
325            Self::NonNullable => Some(Self::NonNullable),
326            Self::AllValid => Some(Self::NonNullable),
327            Self::AllInvalid => None,
328            Self::Array(is_valid) => {
329                is_valid
330                    .statistics()
331                    .compute_min::<bool>()
332                    .vortex_expect("validity array must support min")
333                    .then(|| {
334                        // min true => all true
335                        Self::NonNullable
336                    })
337            }
338        }
339    }
340
341    /// Convert into a variant compatible with the given nullability, if possible.
342    #[inline]
343    pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
344        match nullability {
345            Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
346                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
347            }),
348            Nullability::Nullable => Ok(self.into_nullable()),
349        }
350    }
351
352    /// Create Validity by copying the given array's validity.
353    #[inline]
354    pub fn copy_from_array(array: &dyn Array) -> Self {
355        Validity::from_mask(array.validity_mask(), array.dtype().nullability())
356    }
357
358    /// Create Validity from boolean array with given nullability of the array.
359    ///
360    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
361    ///     as that is always nonnullable
362    #[inline]
363    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
364        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
365            vortex_panic!("Expected a non-nullable boolean array")
366        }
367        match nullability {
368            Nullability::NonNullable => Self::NonNullable,
369            Nullability::Nullable => Self::Array(value),
370        }
371    }
372
373    /// Returns the length of the validity array, if it exists.
374    #[inline]
375    pub fn maybe_len(&self) -> Option<usize> {
376        match self {
377            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
378            Self::Array(a) => Some(a.len()),
379        }
380    }
381
382    #[inline]
383    pub fn uncompressed_size(&self) -> usize {
384        if let Validity::Array(a) = self {
385            a.len().div_ceil(8)
386        } else {
387            0
388        }
389    }
390}
391
392impl PartialEq for Validity {
393    #[inline]
394    fn eq(&self, other: &Self) -> bool {
395        match (self, other) {
396            (Self::NonNullable, Self::NonNullable) => true,
397            (Self::AllValid, Self::AllValid) => true,
398            (Self::AllInvalid, Self::AllInvalid) => true,
399            (Self::Array(a), Self::Array(b)) => {
400                let a = a.to_bool();
401                let b = b.to_bool();
402                a.bit_buffer() == b.bit_buffer()
403            }
404            _ => false,
405        }
406    }
407}
408
409impl From<BitBuffer> for Validity {
410    #[inline]
411    fn from(value: BitBuffer) -> Self {
412        let true_count = value.true_count();
413        if true_count == value.len() {
414            Self::AllValid
415        } else if true_count == 0 {
416            Self::AllInvalid
417        } else {
418            Self::Array(BoolArray::from(value).into_array())
419        }
420    }
421}
422
423impl FromIterator<Mask> for Validity {
424    #[inline]
425    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
426        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
427    }
428}
429
430impl FromIterator<bool> for Validity {
431    #[inline]
432    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
433        Validity::from(BitBuffer::from_iter(iter))
434    }
435}
436
437impl From<Nullability> for Validity {
438    #[inline]
439    fn from(value: Nullability) -> Self {
440        match value {
441            Nullability::NonNullable => Validity::NonNullable,
442            Nullability::Nullable => Validity::AllValid,
443        }
444    }
445}
446
447impl Validity {
448    pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
449        if buffer.true_count() == buffer.len() {
450            nullability.into()
451        } else if buffer.true_count() == 0 {
452            Validity::AllInvalid
453        } else {
454            Validity::Array(BoolArray::from_bit_buffer(buffer, Validity::NonNullable).into_array())
455        }
456    }
457
458    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
459        assert!(
460            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
461            "NonNullable validity must be AllValid",
462        );
463        match mask {
464            Mask::AllTrue(_) => match nullability {
465                Nullability::NonNullable => Validity::NonNullable,
466                Nullability::Nullable => Validity::AllValid,
467            },
468            Mask::AllFalse(_) => Validity::AllInvalid,
469            Mask::Values(values) => Validity::Array(values.into_array()),
470        }
471    }
472}
473
474impl IntoArray for Mask {
475    #[inline]
476    fn into_array(self) -> ArrayRef {
477        match self {
478            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
479            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
480            Self::Values(a) => a.into_array(),
481        }
482    }
483}
484
485impl IntoArray for &MaskValues {
486    #[inline]
487    fn into_array(self) -> ArrayRef {
488        BoolArray::from_bit_buffer(self.bit_buffer().clone(), Validity::NonNullable).into_array()
489    }
490}
491
// Unit tests for `Validity`: patching, taking, masking, and the `from_mask` preconditions.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    // Each case is (initial validity, length, patched positions, patch validity, expected).
    // The cases cover every pairing of the nullable variants (AllValid, AllInvalid, Array) for
    // the initial validity and the patch validity, always patching positions 2 and 4 of 5.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // The indices themselves are non-nullable; only the patch values carry validity.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
    }

    // Patching index 4 into a length-2 validity must panic.
    // NOTE(review): `Validity::patch` panics for a (NonNullable, AllInvalid) pairing on the
    // nullability mismatch before it looks at the indices, so this test may not actually
    // exercise an out-of-bounds check — confirm the intended inputs.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    // A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // A non-nullable validity cannot be built from a mask containing a false value.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case is (validity being taken from, indices, expected). The indices vary between
    // partially null, non-nullable, and entirely null.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid
    )]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }

    // Masking always yields a nullable validity, even when the mask invalidates nothing.
    #[test]
    fn mask_non_nullable() {
        assert_eq!(
            Validity::AllValid,
            Validity::NonNullable.mask(&Mask::AllFalse(2))
        )
    }
}
577        )
578    }
579}