vortex_array/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38        match self {
39            Self::NonNullable | Self::AllValid => Ok(0),
40            Self::AllInvalid => Ok(length),
41            Self::Array(a) => {
42                let validity_len = a.len();
43                if validity_len != length {
44                    vortex_bail!(
45                        "Validity array length {} doesn't match array length {}",
46                        validity_len,
47                        length
48                    )
49                }
50                let true_count = sum(a)?
51                    .as_primitive()
52                    .as_::<usize>()?
53                    .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54                Ok(length - true_count)
55            }
56        }
57    }
58
59    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
60    pub fn into_array(self) -> Option<ArrayRef> {
61        match self {
62            Self::Array(a) => Some(a),
63            _ => None,
64        }
65    }
66
67    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
68    pub fn as_array(&self) -> Option<&ArrayRef> {
69        match self {
70            Self::Array(a) => Some(a),
71            _ => None,
72        }
73    }
74
75    pub fn nullability(&self) -> Nullability {
76        match self {
77            Self::NonNullable => Nullability::NonNullable,
78            _ => Nullability::Nullable,
79        }
80    }
81
82    /// The union nullability and validity.
83    pub fn union_nullability(self, nullability: Nullability) -> Self {
84        match nullability {
85            Nullability::NonNullable => self,
86            Nullability::Nullable => self.into_nullable(),
87        }
88    }
89
90    pub fn all_valid(&self) -> VortexResult<bool> {
91        Ok(match self {
92            Validity::NonNullable | Validity::AllValid => true,
93            Validity::AllInvalid => false,
94            Validity::Array(array) => {
95                // TODO(ngates): replace with SUM compute function
96                array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
97            }
98        })
99    }
100
101    pub fn all_invalid(&self) -> VortexResult<bool> {
102        Ok(match self {
103            Validity::NonNullable | Validity::AllValid => false,
104            Validity::AllInvalid => true,
105            Validity::Array(array) => {
106                // TODO(ngates): replace with SUM compute function
107                array.to_bool()?.boolean_buffer().count_set_bits() == 0
108            }
109        })
110    }
111
112    /// Returns whether the `index` item is valid.
113    #[inline]
114    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
115        Ok(match self {
116            Self::NonNullable | Self::AllValid => true,
117            Self::AllInvalid => false,
118            Self::Array(a) => {
119                let scalar = a.scalar_at(index)?;
120                scalar
121                    .as_bool()
122                    .value()
123                    .vortex_expect("Validity must be non-nullable")
124            }
125        })
126    }
127
128    #[inline]
129    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
130        Ok(!self.is_valid(index)?)
131    }
132
133    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
134        match self {
135            Self::Array(a) => Ok(Self::Array(a.slice(start, stop)?)),
136            _ => Ok(self.clone()),
137        }
138    }
139
140    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
141        match self {
142            Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
143                AllOr::All => {
144                    if indices.dtype().is_nullable() {
145                        Ok(Self::AllValid)
146                    } else {
147                        Ok(Self::NonNullable)
148                    }
149                }
150                AllOr::None => Ok(Self::AllInvalid),
151                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
152            },
153            Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
154                AllOr::All => Ok(Self::AllValid),
155                AllOr::None => Ok(Self::AllInvalid),
156                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
157            },
158            Self::AllInvalid => Ok(Self::AllInvalid),
159            Self::Array(is_valid) => {
160                let maybe_is_valid = take(is_valid, indices)?;
161                // Null indices invalidate that position.
162                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
163                Ok(Self::Array(is_valid))
164            }
165        }
166    }
167
168    /// Keep only the entries for which the mask is true.
169    ///
170    /// The result has length equal to the number of true values in mask.
171    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
172        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
173        //  if we happen to be NonNullable, AllValid, or AllInvalid.
174        match self {
175            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
176                Ok(v.clone())
177            }
178            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
179        }
180    }
181
182    /// Set to false any entries for which the mask is true.
183    ///
184    /// The result is always nullable. The result has the same length as self.
185    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
186        match mask.boolean_buffer() {
187            AllOr::All => Ok(Validity::AllInvalid),
188            AllOr::None => Ok(self.clone()),
189            AllOr::Some(make_invalid) => Ok(match self {
190                Validity::NonNullable | Validity::AllValid => {
191                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
192                }
193                Validity::AllInvalid => Validity::AllInvalid,
194                Validity::Array(is_valid) => {
195                    let is_valid = is_valid.to_bool()?;
196                    let keep_valid = make_invalid.not();
197                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
198                }
199            }),
200        }
201    }
202
203    pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
204        Ok(match self {
205            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
206            Self::AllInvalid => Mask::AllFalse(length),
207            Self::Array(is_valid) => {
208                assert_eq!(
209                    is_valid.len(),
210                    length,
211                    "Validity::Array length must equal to_logical's argument: {}, {}.",
212                    is_valid.len(),
213                    length,
214                );
215                Mask::try_from(&is_valid.to_bool()?)?
216            }
217        })
218    }
219
220    /// Logically & two Validity values of the same length
221    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
222        let validity = match (self, rhs) {
223            // Should be pretty clear
224            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
225            // Any `AllInvalid` makes the output all invalid values
226            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
227            // All truthy values on one side, which makes no effect on an `Array` variant
228            (Validity::Array(a), Validity::AllValid)
229            | (Validity::Array(a), Validity::NonNullable)
230            | (Validity::NonNullable, Validity::Array(a))
231            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
232            // Both sides are all valid
233            (Validity::NonNullable, Validity::AllValid)
234            | (Validity::AllValid, Validity::NonNullable)
235            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
236            // Here we actually have to do some work
237            (Validity::Array(lhs), Validity::Array(rhs)) => {
238                let lhs = lhs.to_bool()?;
239                let rhs = rhs.to_bool()?;
240
241                let lhs = lhs.boolean_buffer();
242                let rhs = rhs.boolean_buffer();
243
244                Validity::from(lhs.bitand(rhs))
245            }
246        };
247
248        Ok(validity)
249    }
250
251    pub fn patch(
252        self,
253        len: usize,
254        indices_offset: usize,
255        indices: &dyn Array,
256        patches: &Validity,
257    ) -> VortexResult<Self> {
258        match (&self, patches) {
259            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
260            (Validity::NonNullable, _) => {
261                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
262            }
263            (_, Validity::NonNullable) => {
264                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
265            }
266            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
267            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
268            _ => {}
269        };
270
271        let own_nullability = if self == Validity::NonNullable {
272            Nullability::NonNullable
273        } else {
274            Nullability::Nullable
275        };
276
277        let source = match self {
278            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
279            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
280            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
281            Validity::Array(a) => a.to_bool()?,
282        };
283
284        let patch_values = match patches {
285            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
286            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
287            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
288            Validity::Array(a) => a.to_bool()?,
289        };
290
291        let patches = Patches::new(
292            len,
293            indices_offset,
294            indices.to_array(),
295            patch_values.into_array(),
296        );
297
298        Ok(Self::from_array(
299            source.patch(&patches)?.into_array(),
300            own_nullability,
301        ))
302    }
303
304    /// Convert into a nullable variant
305    pub fn into_nullable(self) -> Validity {
306        match self {
307            Self::NonNullable => Self::AllValid,
308            _ => self,
309        }
310    }
311
312    /// Convert into a non-nullable variant
313    pub fn into_non_nullable(self) -> Option<Validity> {
314        match self {
315            Self::NonNullable => Some(Self::NonNullable),
316            Self::AllValid => Some(Self::NonNullable),
317            Self::AllInvalid => None,
318            Self::Array(is_valid) => {
319                is_valid
320                    .statistics()
321                    .compute_min::<bool>()
322                    .vortex_expect("validity array must support min")
323                    .then(|| {
324                        // min true => all true
325                        Self::NonNullable
326                    })
327            }
328        }
329    }
330
331    /// Convert into a variant compatible with the given nullability, if possible.
332    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
333        match nullability {
334            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
335                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
336            }),
337            Nullability::Nullable => Ok(self.into_nullable()),
338        }
339    }
340
341    /// Create Validity by copying the given array's validity.
342    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
343        Ok(Validity::from_mask(
344            array.validity_mask()?,
345            array.dtype().nullability(),
346        ))
347    }
348
349    /// Create Validity from boolean array with given nullability of the array.
350    ///
351    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
352    ///     as that is always nonnullable
353    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
354        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
355            vortex_panic!("Expected a non-nullable boolean array")
356        }
357        match nullability {
358            Nullability::NonNullable => Self::NonNullable,
359            Nullability::Nullable => Self::Array(value),
360        }
361    }
362
363    /// Returns the length of the validity array, if it exists.
364    pub fn maybe_len(&self) -> Option<usize> {
365        match self {
366            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
367            Self::Array(a) => Some(a.len()),
368        }
369    }
370
371    pub fn uncompressed_size(&self) -> usize {
372        if let Validity::Array(a) = self {
373            a.len().div_ceil(8)
374        } else {
375            0
376        }
377    }
378
379    pub fn is_array(&self) -> bool {
380        matches!(self, Validity::Array(_))
381    }
382}
383
384impl PartialEq for Validity {
385    fn eq(&self, other: &Self) -> bool {
386        match (self, other) {
387            (Self::NonNullable, Self::NonNullable) => true,
388            (Self::AllValid, Self::AllValid) => true,
389            (Self::AllInvalid, Self::AllInvalid) => true,
390            (Self::Array(a), Self::Array(b)) => {
391                let a = a
392                    .to_bool()
393                    .vortex_expect("Failed to get Validity Array as BoolArray");
394                let b = b
395                    .to_bool()
396                    .vortex_expect("Failed to get Validity Array as BoolArray");
397                a.boolean_buffer() == b.boolean_buffer()
398            }
399            _ => false,
400        }
401    }
402}
403
404impl From<BooleanBuffer> for Validity {
405    fn from(value: BooleanBuffer) -> Self {
406        if value.count_set_bits() == value.len() {
407            Self::AllValid
408        } else if value.count_set_bits() == 0 {
409            Self::AllInvalid
410        } else {
411            Self::Array(BoolArray::from(value).into_array())
412        }
413    }
414}
415
416impl From<NullBuffer> for Validity {
417    fn from(value: NullBuffer) -> Self {
418        value.into_inner().into()
419    }
420}
421
422impl FromIterator<Mask> for Validity {
423    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
424        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
425    }
426}
427
428impl FromIterator<bool> for Validity {
429    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
430        Validity::from(BooleanBuffer::from_iter(iter))
431    }
432}
433
434impl From<Nullability> for Validity {
435    fn from(value: Nullability) -> Self {
436        match value {
437            Nullability::NonNullable => Validity::NonNullable,
438            Nullability::Nullable => Validity::AllValid,
439        }
440    }
441}
442
443impl Validity {
444    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
445        assert!(
446            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
447            "NonNullable validity must be AllValid",
448        );
449        match mask {
450            Mask::AllTrue(_) => match nullability {
451                Nullability::NonNullable => Validity::NonNullable,
452                Nullability::Nullable => Validity::AllValid,
453            },
454            Mask::AllFalse(_) => Validity::AllInvalid,
455            Mask::Values(values) => Validity::Array(values.into_array()),
456        }
457    }
458}
459
460impl IntoArray for Mask {
461    fn into_array(self) -> ArrayRef {
462        match self {
463            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
464            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
465            Self::Values(a) => a.into_array(),
466        }
467    }
468}
469
470impl IntoArray for &MaskValues {
471    fn into_array(self) -> ArrayRef {
472        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
473    }
474}
475
476#[cfg(test)]
477mod tests {
478    use rstest::rstest;
479    use vortex_buffer::{Buffer, buffer};
480    use vortex_dtype::Nullability;
481    use vortex_mask::Mask;
482
483    use crate::arrays::{BoolArray, PrimitiveArray};
484    use crate::validity::Validity;
485    use crate::{ArrayRef, IntoArray};
486
487    #[rstest]
488    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
489    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
490    )]
491    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
492    )]
493    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
494    )]
495    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
496    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
497    )]
498    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
499    )]
500    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
501    )]
502    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
503    )]
504    fn patch_validity(
505        #[case] validity: Validity,
506        #[case] len: usize,
507        #[case] positions: &[u64],
508        #[case] patches: Validity,
509        #[case] expected: Validity,
510    ) {
511        let indices =
512            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
513        assert_eq!(
514            validity.patch(len, 0, &indices, &patches).unwrap(),
515            expected
516        );
517    }
518
519    #[test]
520    #[should_panic]
521    fn out_of_bounds_patch() {
522        Validity::NonNullable
523            .patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
524            .unwrap();
525    }
526
527    #[test]
528    #[should_panic]
529    fn into_validity_nullable() {
530        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
531    }
532
533    #[test]
534    #[should_panic]
535    fn into_validity_nullable_array() {
536        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
537    }
538
539    #[rstest]
540    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
541    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
542    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
543    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
544    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
545    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
546    fn validity_take(
547        #[case] validity: Validity,
548        #[case] indices: ArrayRef,
549        #[case] expected: Validity,
550    ) {
551        assert_eq!(validity.take(&indices).unwrap(), expected);
552    }
553}