vortex_array/
validity.rs

1//! Array validity and nullability behavior, used by arrays and compute functions.
2
3use std::fmt::Debug;
4use std::ops::{BitAnd, Not};
5
6use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
9use vortex_mask::{AllOr, Mask, MaskValues};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray};
13use crate::compute::{fill_null, filter, scalar_at, slice, take};
14use crate::patches::Patches;
15use crate::{Array, ArrayRef, ArrayVariants, IntoArray, ToCanonical};
16
17/// Validity information for an array
18#[derive(Clone, Debug)]
19pub enum Validity {
20    /// Items *can't* be null
21    NonNullable,
22    /// All items are valid
23    AllValid,
24    /// All items are null
25    AllInvalid,
26    /// Specified items are null
27    Array(ArrayRef),
28}
29
30impl Validity {
31    /// The [`DType`] of the underlying validity array (if it exists).
32    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
33
34    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
35        match self {
36            Self::NonNullable | Self::AllValid => Ok(0),
37            Self::AllInvalid => Ok(length),
38            Self::Array(a) => {
39                let validity_len = a.len();
40                if validity_len != length {
41                    vortex_bail!(
42                        "Validity array length {} doesn't match array length {}",
43                        validity_len,
44                        length
45                    )
46                }
47                let true_count = a
48                    .as_bool_typed()
49                    .vortex_expect("Validity array must be boolean")
50                    .true_count()?;
51                Ok(length - true_count)
52            }
53        }
54    }
55
56    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
57    pub fn into_array(self) -> Option<ArrayRef> {
58        match self {
59            Self::Array(a) => Some(a),
60            _ => None,
61        }
62    }
63
64    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
65    pub fn as_array(&self) -> Option<&ArrayRef> {
66        match self {
67            Self::Array(a) => Some(a),
68            _ => None,
69        }
70    }
71
72    pub fn nullability(&self) -> Nullability {
73        match self {
74            Self::NonNullable => Nullability::NonNullable,
75            _ => Nullability::Nullable,
76        }
77    }
78
79    pub fn all_valid(&self) -> VortexResult<bool> {
80        Ok(match self {
81            Validity::NonNullable | Validity::AllValid => true,
82            Validity::AllInvalid => false,
83            Validity::Array(array) => {
84                // TODO(ngates): replace with SUM compute function
85                array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
86            }
87        })
88    }
89
90    pub fn all_invalid(&self) -> VortexResult<bool> {
91        Ok(match self {
92            Validity::NonNullable | Validity::AllValid => false,
93            Validity::AllInvalid => true,
94            Validity::Array(array) => {
95                // TODO(ngates): replace with SUM compute function
96                array.to_bool()?.boolean_buffer().count_set_bits() == 0
97            }
98        })
99    }
100
101    /// Returns whether the `index` item is valid.
102    #[inline]
103    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
104        Ok(match self {
105            Self::NonNullable | Self::AllValid => true,
106            Self::AllInvalid => false,
107            Self::Array(a) => {
108                let scalar = scalar_at(a, index)?;
109                scalar
110                    .as_bool()
111                    .value()
112                    .vortex_expect("Validity must be non-nullable")
113            }
114        })
115    }
116
117    #[inline]
118    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
119        Ok(!self.is_valid(index)?)
120    }
121
122    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
123        match self {
124            Self::Array(a) => Ok(Self::Array(slice(a, start, stop)?)),
125            _ => Ok(self.clone()),
126        }
127    }
128
129    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
130        match self {
131            Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
132                AllOr::All => {
133                    if indices.dtype().is_nullable() {
134                        Ok(Self::AllValid)
135                    } else {
136                        Ok(Self::NonNullable)
137                    }
138                }
139                AllOr::None => Ok(Self::AllInvalid),
140                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
141            },
142            Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
143                AllOr::All => Ok(Self::AllValid),
144                AllOr::None => Ok(Self::AllInvalid),
145                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
146            },
147            Self::AllInvalid => Ok(Self::AllInvalid),
148            Self::Array(is_valid) => {
149                let maybe_is_valid = take(is_valid, indices)?;
150                // Null indices invalidite that position.
151                let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
152                Ok(Self::Array(is_valid))
153            }
154        }
155    }
156
157    /// Keep only the entries for which the mask is true.
158    ///
159    /// The result has length equal to the number of true values in mask.
160    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
161        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
162        //  if we happen to be NonNullable, AllValid, or AllInvalid.
163        match self {
164            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
165                Ok(v.clone())
166            }
167            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
168        }
169    }
170
171    /// Set to false any entries for which the mask is true.
172    ///
173    /// The result is always nullable. The result has the same length as self.
174    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
175        match mask.boolean_buffer() {
176            AllOr::All => Ok(Validity::AllInvalid),
177            AllOr::None => Ok(self.clone()),
178            AllOr::Some(make_invalid) => Ok(match self {
179                Validity::NonNullable | Validity::AllValid => {
180                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
181                }
182                Validity::AllInvalid => Validity::AllInvalid,
183                Validity::Array(is_valid) => {
184                    let is_valid = is_valid.to_bool()?;
185                    let keep_valid = make_invalid.not();
186                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
187                }
188            }),
189        }
190    }
191
192    // TODO(ngates): rename to to_mask
193    pub fn to_logical(&self, length: usize) -> VortexResult<Mask> {
194        Ok(match self {
195            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
196            Self::AllInvalid => Mask::AllFalse(length),
197            Self::Array(is_valid) => {
198                assert_eq!(
199                    is_valid.len(),
200                    length,
201                    "Validity::Array length must equal to_logical's argument: {}, {}.",
202                    is_valid.len(),
203                    length,
204                );
205                Mask::try_from(&is_valid.to_bool()?)?
206            }
207        })
208    }
209
210    /// Logically & two Validity values of the same length
211    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
212        let validity = match (self, rhs) {
213            // Should be pretty clear
214            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
215            // Any `AllInvalid` makes the output all invalid values
216            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
217            // All truthy values on one side, which makes no effect on an `Array` variant
218            (Validity::Array(a), Validity::AllValid)
219            | (Validity::Array(a), Validity::NonNullable)
220            | (Validity::NonNullable, Validity::Array(a))
221            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
222            // Both sides are all valid
223            (Validity::NonNullable, Validity::AllValid)
224            | (Validity::AllValid, Validity::NonNullable)
225            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
226            // Here we actually have to do some work
227            (Validity::Array(lhs), Validity::Array(rhs)) => {
228                let lhs = lhs.to_bool()?;
229                let rhs = rhs.to_bool()?;
230
231                let lhs = lhs.boolean_buffer();
232                let rhs = rhs.boolean_buffer();
233
234                Validity::from(lhs.bitand(rhs))
235            }
236        };
237
238        Ok(validity)
239    }
240
241    pub fn patch(
242        self,
243        len: usize,
244        indices_offset: usize,
245        indices: &dyn Array,
246        patches: &Validity,
247    ) -> VortexResult<Self> {
248        match (&self, patches) {
249            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
250            (Validity::NonNullable, _) => {
251                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
252            }
253            (_, Validity::NonNullable) => {
254                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
255            }
256            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
257            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
258            _ => {}
259        };
260
261        let own_nullability = if self == Validity::NonNullable {
262            Nullability::NonNullable
263        } else {
264            Nullability::Nullable
265        };
266
267        let source = match self {
268            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
269            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
270            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
271            Validity::Array(a) => a.to_bool()?,
272        };
273
274        let patch_values = match patches {
275            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
276            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
277            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
278            Validity::Array(a) => a.to_bool()?,
279        };
280
281        let patches = Patches::new(
282            len,
283            indices_offset,
284            indices.to_array(),
285            patch_values.into_array(),
286        );
287
288        Ok(Self::from_array(
289            source.patch(&patches)?.into_array(),
290            own_nullability,
291        ))
292    }
293
294    /// Convert into a nullable variant
295    pub fn into_nullable(self) -> Validity {
296        match self {
297            Self::NonNullable => Self::AllValid,
298            _ => self,
299        }
300    }
301
302    /// Convert into a non-nullable variant
303    pub fn into_non_nullable(self) -> Option<Validity> {
304        match self {
305            Self::NonNullable => Some(Self::NonNullable),
306            Self::AllValid => Some(Self::NonNullable),
307            Self::AllInvalid => None,
308            Self::Array(is_valid) => {
309                is_valid
310                    .statistics()
311                    .compute_min::<bool>()
312                    .vortex_expect("validity array must support min")
313                    .then(|| {
314                        // min true => all true
315                        Self::NonNullable
316                    })
317            }
318        }
319    }
320
321    /// Convert into a variant compatible with the given nullability, if possible.
322    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
323        match nullability {
324            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
325                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
326            }),
327            Nullability::Nullable => Ok(self.into_nullable()),
328        }
329    }
330
331    /// Create Validity by copying the given array's validity.
332    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
333        Ok(Validity::from_mask(
334            array.validity_mask()?,
335            array.dtype().nullability(),
336        ))
337    }
338
339    /// Create Validity from boolean array with given nullability of the array.
340    ///
341    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
342    ///     as that is always nonnullable
343    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
344        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
345            vortex_panic!("Expected a non-nullable boolean array")
346        }
347        match nullability {
348            Nullability::NonNullable => Self::NonNullable,
349            Nullability::Nullable => Self::Array(value),
350        }
351    }
352
353    /// Returns the length of the validity array, if it exists.
354    pub fn maybe_len(&self) -> Option<usize> {
355        match self {
356            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
357            Self::Array(a) => Some(a.len()),
358        }
359    }
360
361    pub fn uncompressed_size(&self) -> usize {
362        if let Validity::Array(a) = self {
363            a.len().div_ceil(8)
364        } else {
365            0
366        }
367    }
368}
369
370impl PartialEq for Validity {
371    fn eq(&self, other: &Self) -> bool {
372        match (self, other) {
373            (Self::NonNullable, Self::NonNullable) => true,
374            (Self::AllValid, Self::AllValid) => true,
375            (Self::AllInvalid, Self::AllInvalid) => true,
376            (Self::Array(a), Self::Array(b)) => {
377                let a = a
378                    .to_bool()
379                    .vortex_expect("Failed to get Validity Array as BoolArray");
380                let b = b
381                    .to_bool()
382                    .vortex_expect("Failed to get Validity Array as BoolArray");
383                a.boolean_buffer() == b.boolean_buffer()
384            }
385            _ => false,
386        }
387    }
388}
389
390impl From<BooleanBuffer> for Validity {
391    fn from(value: BooleanBuffer) -> Self {
392        if value.count_set_bits() == value.len() {
393            Self::AllValid
394        } else if value.count_set_bits() == 0 {
395            Self::AllInvalid
396        } else {
397            Self::Array(BoolArray::from(value).into_array())
398        }
399    }
400}
401
402impl From<NullBuffer> for Validity {
403    fn from(value: NullBuffer) -> Self {
404        value.into_inner().into()
405    }
406}
407
408impl FromIterator<Mask> for Validity {
409    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
410        let validities: Vec<Mask> = iter.into_iter().collect();
411
412        // If they're all valid, then return a single validity.
413        if validities.iter().all(|v| v.all_true()) {
414            return Self::AllValid;
415        }
416        // If they're all invalid, then return a single invalidity.
417        if validities.iter().all(|v| v.all_false()) {
418            return Self::AllInvalid;
419        }
420
421        // Else, construct the boolean buffer
422        let mut buffer = BooleanBufferBuilder::new(validities.iter().map(|v| v.len()).sum());
423        for validity in validities {
424            match validity {
425                Mask::AllTrue(count) => buffer.append_n(count, true),
426                Mask::AllFalse(count) => buffer.append_n(count, false),
427                Mask::Values(values) => {
428                    buffer.append_buffer(values.boolean_buffer());
429                }
430            };
431        }
432        let bool_array = BoolArray::from(buffer.finish());
433        Self::Array(bool_array.into_array())
434    }
435}
436
437impl FromIterator<bool> for Validity {
438    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
439        Validity::from(BooleanBuffer::from_iter(iter))
440    }
441}
442
443impl From<Nullability> for Validity {
444    fn from(value: Nullability) -> Self {
445        match value {
446            Nullability::NonNullable => Validity::NonNullable,
447            Nullability::Nullable => Validity::AllValid,
448        }
449    }
450}
451
452impl Validity {
453    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
454        assert!(
455            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
456            "NonNullable validity must be AllValid",
457        );
458        match mask {
459            Mask::AllTrue(_) => match nullability {
460                Nullability::NonNullable => Validity::NonNullable,
461                Nullability::Nullable => Validity::AllValid,
462            },
463            Mask::AllFalse(_) => Validity::AllInvalid,
464            Mask::Values(values) => Validity::Array(values.into_array()),
465        }
466    }
467}
468
469impl IntoArray for Mask {
470    fn into_array(self) -> ArrayRef {
471        match self {
472            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
473            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
474            Self::Values(a) => a.into_array(),
475        }
476    }
477}
478
479impl IntoArray for &MaskValues {
480    fn into_array(self) -> ArrayRef {
481        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
482    }
483}
484
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a `Validity::Array` from literal bits (`true` = valid).
    fn bits<const N: usize>(values: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(values).into_array())
    }

    // Every combination of {AllValid, AllInvalid, Array} source × patch validity, patching
    // positions 2 and 4 of a 5-element array.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, bits([true, true, false, true, false]))]
    #[case(Validity::AllValid, 5, &[2, 4], bits([true, false]), bits([true, true, true, true, false]))]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, bits([false, false, true, false, true]))]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], bits([true, false]), bits([false, false, true, false, false]))]
    #[case(bits([false, true, false, true, false]), 5, &[2, 4], Validity::AllValid, bits([false, true, true, true, true]))]
    #[case(bits([false, true, false, true, false]), 5, &[2, 4], Validity::AllInvalid, bits([false, true, false, true, false]))]
    #[case(bits([false, true, false, true, false]), 5, &[2, 4], bits([true, false]), bits([false, true, true, true, false]))]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices = PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable);
        let patched = validity
            .patch(len, 0, &indices.into_array(), &patches)
            .unwrap();
        assert_eq!(patched, expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid)
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let mask = Mask::AllFalse(10);
        let _ = Validity::from_mask(mask, Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter([true, false]);
        let _ = Validity::from_mask(mask, Nullability::NonNullable);
    }

    // Taking with valid, all-valid and all-null index arrays from nullable and non-nullable
    // sources.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}