vortex_array/
validity.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Array validity and nullability behavior, used by arrays and compute functions.
5
6use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20/// Validity information for an array
21#[derive(Clone, Debug)]
22pub enum Validity {
23    /// Items *can't* be null
24    NonNullable,
25    /// All items are valid
26    AllValid,
27    /// All items are null
28    AllInvalid,
29    /// Specified items are null
30    Array(ArrayRef),
31}
32
33impl Validity {
34    /// The [`DType`] of the underlying validity array (if it exists).
35    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38        match self {
39            Self::NonNullable | Self::AllValid => Ok(0),
40            Self::AllInvalid => Ok(length),
41            Self::Array(a) => {
42                let validity_len = a.len();
43                if validity_len != length {
44                    vortex_bail!(
45                        "Validity array length {} doesn't match array length {}",
46                        validity_len,
47                        length
48                    )
49                }
50                let true_count = sum(a)?
51                    .as_primitive()
52                    .as_::<usize>()?
53                    .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54                Ok(length - true_count)
55            }
56        }
57    }
58
59    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
60    pub fn into_array(self) -> Option<ArrayRef> {
61        match self {
62            Self::Array(a) => Some(a),
63            _ => None,
64        }
65    }
66
67    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
68    pub fn as_array(&self) -> Option<&ArrayRef> {
69        match self {
70            Self::Array(a) => Some(a),
71            _ => None,
72        }
73    }
74
75    pub fn nullability(&self) -> Nullability {
76        match self {
77            Self::NonNullable => Nullability::NonNullable,
78            _ => Nullability::Nullable,
79        }
80    }
81
82    /// The union nullability and validity.
83    pub fn union_nullability(self, nullability: Nullability) -> Self {
84        match nullability {
85            Nullability::NonNullable => self,
86            Nullability::Nullable => self.into_nullable(),
87        }
88    }
89
90    pub fn all_valid(&self) -> VortexResult<bool> {
91        Ok(match self {
92            Validity::NonNullable | Validity::AllValid => true,
93            Validity::AllInvalid => false,
94            Validity::Array(array) => sum(array)
95                .map(|v| {
96                    v.as_primitive()
97                        .typed_value::<u64>()
98                        .map(|count| count == array.len() as u64)
99                })?
100                .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
101        })
102    }
103
104    pub fn all_invalid(&self) -> VortexResult<bool> {
105        Ok(match self {
106            Validity::NonNullable | Validity::AllValid => false,
107            Validity::AllInvalid => true,
108            Validity::Array(array) => sum(array)
109                .map(|v| {
110                    v.as_primitive()
111                        .typed_value::<u64>()
112                        .map(|count| count == 0u64)
113                })?
114                .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
115        })
116    }
117
118    /// Returns whether the `index` item is valid.
119    #[inline]
120    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
121        Ok(match self {
122            Self::NonNullable | Self::AllValid => true,
123            Self::AllInvalid => false,
124            Self::Array(a) => {
125                let scalar = a.scalar_at(index)?;
126                scalar
127                    .as_bool()
128                    .value()
129                    .vortex_expect("Validity must be non-nullable")
130            }
131        })
132    }
133
134    #[inline]
135    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
136        Ok(!self.is_valid(index)?)
137    }
138
139    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
140        match self {
141            Self::Array(a) => Ok(Self::Array(a.slice(start, stop)?)),
142            _ => Ok(self.clone()),
143        }
144    }
145
146    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
147        match self {
148            Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
149                AllOr::All => {
150                    if indices.dtype().is_nullable() {
151                        Ok(Self::AllValid)
152                    } else {
153                        Ok(Self::NonNullable)
154                    }
155                }
156                AllOr::None => Ok(Self::AllInvalid),
157                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
158            },
159            Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
160                AllOr::All => Ok(Self::AllValid),
161                AllOr::None => Ok(Self::AllInvalid),
162                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163            },
164            Self::AllInvalid => Ok(Self::AllInvalid),
165            Self::Array(is_valid) => {
166                let maybe_is_valid = take(is_valid, indices)?;
167                // Null indices invalidate that position.
168                let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
169                Ok(Self::Array(is_valid))
170            }
171        }
172    }
173
174    /// Keep only the entries for which the mask is true.
175    ///
176    /// The result has length equal to the number of true values in mask.
177    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
178        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
179        //  if we happen to be NonNullable, AllValid, or AllInvalid.
180        match self {
181            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
182                Ok(v.clone())
183            }
184            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
185        }
186    }
187
188    /// Set to false any entries for which the mask is true.
189    ///
190    /// The result is always nullable. The result has the same length as self.
191    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
192        match mask.boolean_buffer() {
193            AllOr::All => Ok(Validity::AllInvalid),
194            AllOr::None => Ok(self.clone()),
195            AllOr::Some(make_invalid) => Ok(match self {
196                Validity::NonNullable | Validity::AllValid => {
197                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
198                }
199                Validity::AllInvalid => Validity::AllInvalid,
200                Validity::Array(is_valid) => {
201                    let is_valid = is_valid.to_bool()?;
202                    let keep_valid = make_invalid.not();
203                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
204                }
205            }),
206        }
207    }
208
209    pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
210        Ok(match self {
211            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
212            Self::AllInvalid => Mask::AllFalse(length),
213            Self::Array(is_valid) => {
214                assert_eq!(
215                    is_valid.len(),
216                    length,
217                    "Validity::Array length must equal to_logical's argument: {}, {}.",
218                    is_valid.len(),
219                    length,
220                );
221                Mask::try_from(&is_valid.to_bool()?)?
222            }
223        })
224    }
225
226    /// Logically & two Validity values of the same length
227    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
228        let validity = match (self, rhs) {
229            // Should be pretty clear
230            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
231            // Any `AllInvalid` makes the output all invalid values
232            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
233            // All truthy values on one side, which makes no effect on an `Array` variant
234            (Validity::Array(a), Validity::AllValid)
235            | (Validity::Array(a), Validity::NonNullable)
236            | (Validity::NonNullable, Validity::Array(a))
237            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
238            // Both sides are all valid
239            (Validity::NonNullable, Validity::AllValid)
240            | (Validity::AllValid, Validity::NonNullable)
241            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
242            // Here we actually have to do some work
243            (Validity::Array(lhs), Validity::Array(rhs)) => {
244                let lhs = lhs.to_bool()?;
245                let rhs = rhs.to_bool()?;
246
247                let lhs = lhs.boolean_buffer();
248                let rhs = rhs.boolean_buffer();
249
250                Validity::from(lhs.bitand(rhs))
251            }
252        };
253
254        Ok(validity)
255    }
256
257    pub fn patch(
258        self,
259        len: usize,
260        indices_offset: usize,
261        indices: &dyn Array,
262        patches: &Validity,
263    ) -> VortexResult<Self> {
264        match (&self, patches) {
265            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
266            (Validity::NonNullable, _) => {
267                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
268            }
269            (_, Validity::NonNullable) => {
270                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
271            }
272            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
273            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
274            _ => {}
275        };
276
277        let own_nullability = if self == Validity::NonNullable {
278            Nullability::NonNullable
279        } else {
280            Nullability::Nullable
281        };
282
283        let source = match self {
284            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
285            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
286            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
287            Validity::Array(a) => a.to_bool()?,
288        };
289
290        let patch_values = match patches {
291            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
292            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
293            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
294            Validity::Array(a) => a.to_bool()?,
295        };
296
297        let patches = Patches::new(
298            len,
299            indices_offset,
300            indices.to_array(),
301            patch_values.into_array(),
302        );
303
304        Ok(Self::from_array(
305            source.patch(&patches)?.into_array(),
306            own_nullability,
307        ))
308    }
309
310    /// Convert into a nullable variant
311    pub fn into_nullable(self) -> Validity {
312        match self {
313            Self::NonNullable => Self::AllValid,
314            _ => self,
315        }
316    }
317
318    /// Convert into a non-nullable variant
319    pub fn into_non_nullable(self) -> Option<Validity> {
320        match self {
321            Self::NonNullable => Some(Self::NonNullable),
322            Self::AllValid => Some(Self::NonNullable),
323            Self::AllInvalid => None,
324            Self::Array(is_valid) => {
325                is_valid
326                    .statistics()
327                    .compute_min::<bool>()
328                    .vortex_expect("validity array must support min")
329                    .then(|| {
330                        // min true => all true
331                        Self::NonNullable
332                    })
333            }
334        }
335    }
336
337    /// Convert into a variant compatible with the given nullability, if possible.
338    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
339        match nullability {
340            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
341                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
342            }),
343            Nullability::Nullable => Ok(self.into_nullable()),
344        }
345    }
346
347    /// Create Validity by copying the given array's validity.
348    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
349        Ok(Validity::from_mask(
350            array.validity_mask()?,
351            array.dtype().nullability(),
352        ))
353    }
354
355    /// Create Validity from boolean array with given nullability of the array.
356    ///
357    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
358    ///     as that is always nonnullable
359    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
360        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
361            vortex_panic!("Expected a non-nullable boolean array")
362        }
363        match nullability {
364            Nullability::NonNullable => Self::NonNullable,
365            Nullability::Nullable => Self::Array(value),
366        }
367    }
368
369    /// Returns the length of the validity array, if it exists.
370    pub fn maybe_len(&self) -> Option<usize> {
371        match self {
372            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
373            Self::Array(a) => Some(a.len()),
374        }
375    }
376
377    pub fn uncompressed_size(&self) -> usize {
378        if let Validity::Array(a) = self {
379            a.len().div_ceil(8)
380        } else {
381            0
382        }
383    }
384
385    pub fn is_array(&self) -> bool {
386        matches!(self, Validity::Array(_))
387    }
388}
389
390impl PartialEq for Validity {
391    fn eq(&self, other: &Self) -> bool {
392        match (self, other) {
393            (Self::NonNullable, Self::NonNullable) => true,
394            (Self::AllValid, Self::AllValid) => true,
395            (Self::AllInvalid, Self::AllInvalid) => true,
396            (Self::Array(a), Self::Array(b)) => {
397                let a = a
398                    .to_bool()
399                    .vortex_expect("Failed to get Validity Array as BoolArray");
400                let b = b
401                    .to_bool()
402                    .vortex_expect("Failed to get Validity Array as BoolArray");
403                a.boolean_buffer() == b.boolean_buffer()
404            }
405            _ => false,
406        }
407    }
408}
409
410impl From<BooleanBuffer> for Validity {
411    fn from(value: BooleanBuffer) -> Self {
412        if value.count_set_bits() == value.len() {
413            Self::AllValid
414        } else if value.count_set_bits() == 0 {
415            Self::AllInvalid
416        } else {
417            Self::Array(BoolArray::from(value).into_array())
418        }
419    }
420}
421
422impl From<NullBuffer> for Validity {
423    fn from(value: NullBuffer) -> Self {
424        value.into_inner().into()
425    }
426}
427
428impl FromIterator<Mask> for Validity {
429    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
430        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
431    }
432}
433
434impl FromIterator<bool> for Validity {
435    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
436        Validity::from(BooleanBuffer::from_iter(iter))
437    }
438}
439
440impl From<Nullability> for Validity {
441    fn from(value: Nullability) -> Self {
442        match value {
443            Nullability::NonNullable => Validity::NonNullable,
444            Nullability::Nullable => Validity::AllValid,
445        }
446    }
447}
448
449impl Validity {
450    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
451        assert!(
452            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
453            "NonNullable validity must be AllValid",
454        );
455        match mask {
456            Mask::AllTrue(_) => match nullability {
457                Nullability::NonNullable => Validity::NonNullable,
458                Nullability::Nullable => Validity::AllValid,
459            },
460            Mask::AllFalse(_) => Validity::AllInvalid,
461            Mask::Values(values) => Validity::Array(values.into_array()),
462        }
463    }
464}
465
466impl IntoArray for Mask {
467    fn into_array(self) -> ArrayRef {
468        match self {
469            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
470            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
471            Self::Values(a) => a.into_array(),
472        }
473    }
474}
475
476impl IntoArray for &MaskValues {
477    fn into_array(self) -> ArrayRef {
478        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
479    }
480}
481
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a `Validity::Array` holding exactly the given per-item flags.
    fn flags<const N: usize>(is_valid: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(is_valid).into_array())
    }

    /// Builds the two-element index array `[0, 1]` with the given validity.
    fn indices_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![0, 1], validity).into_array()
    }

    /// Patching positions 2 and 4 of a length-5 validity, for every pairing of nullable
    /// source and patch variants.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        flags([true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        flags([true, false]),
        flags([true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        flags([false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        flags([true, false]),
        flags([false, false, true, false, false])
    )]
    #[case(
        flags([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        flags([false, true, true, true, true])
    )]
    #[case(
        flags([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        flags([false, true, false, true, false])
    )]
    #[case(
        flags([false, true, false, true, false]),
        5,
        &[2, 4],
        flags([true, false]),
        flags([false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches).unwrap();
        assert_eq!(patched, expected);
    }

    /// Unwrapping the result of patching a non-nullable validity with nullable patches at an
    /// out-of-range index must panic.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid)
            .unwrap();
    }

    /// An all-false mask cannot back a non-nullable validity.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let mask = Mask::AllFalse(10);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    /// A mask containing a `false` entry cannot back a non-nullable validity.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    /// Taking from an all-valid or non-nullable validity mirrors the validity of the indices.
    #[rstest]
    #[case(
        Validity::AllValid,
        indices_with(Validity::from_iter([true, false])),
        Validity::from_iter([true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        indices_with(Validity::AllInvalid),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        indices_with(Validity::from_iter([true, false])),
        Validity::from_iter([true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        indices_with(Validity::AllInvalid),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}