vortex_array/
validity.rs

1//! Array validity and nullability behavior, used by arrays and compute functions.
2
3use std::fmt::Debug;
4use std::ops::{BitAnd, Not};
5
6use arrow_buffer::{BooleanBuffer, NullBuffer};
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
9use vortex_mask::{AllOr, Mask, MaskValues};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray};
13use crate::compute::{fill_null, filter, scalar_at, slice, take};
14use crate::patches::Patches;
15use crate::{Array, ArrayRef, ArrayVariants, IntoArray, ToCanonical};
16
17/// Validity information for an array
18#[derive(Clone, Debug)]
19pub enum Validity {
20    /// Items *can't* be null
21    NonNullable,
22    /// All items are valid
23    AllValid,
24    /// All items are null
25    AllInvalid,
26    /// Specified items are null
27    Array(ArrayRef),
28}
29
30impl Validity {
31    /// The [`DType`] of the underlying validity array (if it exists).
32    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
33
34    pub fn null_count(&self, length: usize) -> VortexResult<usize> {
35        match self {
36            Self::NonNullable | Self::AllValid => Ok(0),
37            Self::AllInvalid => Ok(length),
38            Self::Array(a) => {
39                let validity_len = a.len();
40                if validity_len != length {
41                    vortex_bail!(
42                        "Validity array length {} doesn't match array length {}",
43                        validity_len,
44                        length
45                    )
46                }
47                let true_count = a
48                    .as_bool_typed()
49                    .vortex_expect("Validity array must be boolean")
50                    .true_count()?;
51                Ok(length - true_count)
52            }
53        }
54    }
55
56    /// If Validity is [`Validity::Array`], returns the array, otherwise returns `None`.
57    pub fn into_array(self) -> Option<ArrayRef> {
58        match self {
59            Self::Array(a) => Some(a),
60            _ => None,
61        }
62    }
63
64    /// If Validity is [`Validity::Array`], returns a reference to the array array, otherwise returns `None`.
65    pub fn as_array(&self) -> Option<&ArrayRef> {
66        match self {
67            Self::Array(a) => Some(a),
68            _ => None,
69        }
70    }
71
72    pub fn nullability(&self) -> Nullability {
73        match self {
74            Self::NonNullable => Nullability::NonNullable,
75            _ => Nullability::Nullable,
76        }
77    }
78
79    /// The union nullability and validity.
80    pub fn union_nullability(self, nullability: Nullability) -> Self {
81        match nullability {
82            Nullability::NonNullable => self,
83            Nullability::Nullable => self.into_nullable(),
84        }
85    }
86
87    pub fn all_valid(&self) -> VortexResult<bool> {
88        Ok(match self {
89            Validity::NonNullable | Validity::AllValid => true,
90            Validity::AllInvalid => false,
91            Validity::Array(array) => {
92                // TODO(ngates): replace with SUM compute function
93                array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
94            }
95        })
96    }
97
98    pub fn all_invalid(&self) -> VortexResult<bool> {
99        Ok(match self {
100            Validity::NonNullable | Validity::AllValid => false,
101            Validity::AllInvalid => true,
102            Validity::Array(array) => {
103                // TODO(ngates): replace with SUM compute function
104                array.to_bool()?.boolean_buffer().count_set_bits() == 0
105            }
106        })
107    }
108
109    /// Returns whether the `index` item is valid.
110    #[inline]
111    pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
112        Ok(match self {
113            Self::NonNullable | Self::AllValid => true,
114            Self::AllInvalid => false,
115            Self::Array(a) => {
116                let scalar = scalar_at(a, index)?;
117                scalar
118                    .as_bool()
119                    .value()
120                    .vortex_expect("Validity must be non-nullable")
121            }
122        })
123    }
124
125    #[inline]
126    pub fn is_null(&self, index: usize) -> VortexResult<bool> {
127        Ok(!self.is_valid(index)?)
128    }
129
130    pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
131        match self {
132            Self::Array(a) => Ok(Self::Array(slice(a, start, stop)?)),
133            _ => Ok(self.clone()),
134        }
135    }
136
137    pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
138        match self {
139            Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
140                AllOr::All => {
141                    if indices.dtype().is_nullable() {
142                        Ok(Self::AllValid)
143                    } else {
144                        Ok(Self::NonNullable)
145                    }
146                }
147                AllOr::None => Ok(Self::AllInvalid),
148                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
149            },
150            Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
151                AllOr::All => Ok(Self::AllValid),
152                AllOr::None => Ok(Self::AllInvalid),
153                AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
154            },
155            Self::AllInvalid => Ok(Self::AllInvalid),
156            Self::Array(is_valid) => {
157                let maybe_is_valid = take(is_valid, indices)?;
158                // Null indices invalidate that position.
159                let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
160                Ok(Self::Array(is_valid))
161            }
162        }
163    }
164
165    /// Keep only the entries for which the mask is true.
166    ///
167    /// The result has length equal to the number of true values in mask.
168    pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
169        // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily
170        //  if we happen to be NonNullable, AllValid, or AllInvalid.
171        match self {
172            v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
173                Ok(v.clone())
174            }
175            Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
176        }
177    }
178
179    /// Set to false any entries for which the mask is true.
180    ///
181    /// The result is always nullable. The result has the same length as self.
182    pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
183        match mask.boolean_buffer() {
184            AllOr::All => Ok(Validity::AllInvalid),
185            AllOr::None => Ok(self.clone()),
186            AllOr::Some(make_invalid) => Ok(match self {
187                Validity::NonNullable | Validity::AllValid => {
188                    Validity::Array(BoolArray::from(make_invalid.not()).into_array())
189                }
190                Validity::AllInvalid => Validity::AllInvalid,
191                Validity::Array(is_valid) => {
192                    let is_valid = is_valid.to_bool()?;
193                    let keep_valid = make_invalid.not();
194                    Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
195                }
196            }),
197        }
198    }
199
200    pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
201        Ok(match self {
202            Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
203            Self::AllInvalid => Mask::AllFalse(length),
204            Self::Array(is_valid) => {
205                assert_eq!(
206                    is_valid.len(),
207                    length,
208                    "Validity::Array length must equal to_logical's argument: {}, {}.",
209                    is_valid.len(),
210                    length,
211                );
212                Mask::try_from(&is_valid.to_bool()?)?
213            }
214        })
215    }
216
217    /// Logically & two Validity values of the same length
218    pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
219        let validity = match (self, rhs) {
220            // Should be pretty clear
221            (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
222            // Any `AllInvalid` makes the output all invalid values
223            (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
224            // All truthy values on one side, which makes no effect on an `Array` variant
225            (Validity::Array(a), Validity::AllValid)
226            | (Validity::Array(a), Validity::NonNullable)
227            | (Validity::NonNullable, Validity::Array(a))
228            | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
229            // Both sides are all valid
230            (Validity::NonNullable, Validity::AllValid)
231            | (Validity::AllValid, Validity::NonNullable)
232            | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
233            // Here we actually have to do some work
234            (Validity::Array(lhs), Validity::Array(rhs)) => {
235                let lhs = lhs.to_bool()?;
236                let rhs = rhs.to_bool()?;
237
238                let lhs = lhs.boolean_buffer();
239                let rhs = rhs.boolean_buffer();
240
241                Validity::from(lhs.bitand(rhs))
242            }
243        };
244
245        Ok(validity)
246    }
247
248    pub fn patch(
249        self,
250        len: usize,
251        indices_offset: usize,
252        indices: &dyn Array,
253        patches: &Validity,
254    ) -> VortexResult<Self> {
255        match (&self, patches) {
256            (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
257            (Validity::NonNullable, _) => {
258                vortex_bail!("Can't patch a non-nullable validity with nullable validity")
259            }
260            (_, Validity::NonNullable) => {
261                vortex_bail!("Can't patch a nullable validity with non-nullable validity")
262            }
263            (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
264            (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
265            _ => {}
266        };
267
268        let own_nullability = if self == Validity::NonNullable {
269            Nullability::NonNullable
270        } else {
271            Nullability::Nullable
272        };
273
274        let source = match self {
275            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
276            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
277            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
278            Validity::Array(a) => a.to_bool()?,
279        };
280
281        let patch_values = match patches {
282            Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
283            Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
284            Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
285            Validity::Array(a) => a.to_bool()?,
286        };
287
288        let patches = Patches::new(
289            len,
290            indices_offset,
291            indices.to_array(),
292            patch_values.into_array(),
293        );
294
295        Ok(Self::from_array(
296            source.patch(&patches)?.into_array(),
297            own_nullability,
298        ))
299    }
300
301    /// Convert into a nullable variant
302    pub fn into_nullable(self) -> Validity {
303        match self {
304            Self::NonNullable => Self::AllValid,
305            _ => self,
306        }
307    }
308
309    /// Convert into a non-nullable variant
310    pub fn into_non_nullable(self) -> Option<Validity> {
311        match self {
312            Self::NonNullable => Some(Self::NonNullable),
313            Self::AllValid => Some(Self::NonNullable),
314            Self::AllInvalid => None,
315            Self::Array(is_valid) => {
316                is_valid
317                    .statistics()
318                    .compute_min::<bool>()
319                    .vortex_expect("validity array must support min")
320                    .then(|| {
321                        // min true => all true
322                        Self::NonNullable
323                    })
324            }
325        }
326    }
327
328    /// Convert into a variant compatible with the given nullability, if possible.
329    pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
330        match nullability {
331            Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
332                vortex_err!("Cannot cast array with invalid values to non-nullable type.")
333            }),
334            Nullability::Nullable => Ok(self.into_nullable()),
335        }
336    }
337
338    /// Create Validity by copying the given array's validity.
339    pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
340        Ok(Validity::from_mask(
341            array.validity_mask()?,
342            array.dtype().nullability(),
343        ))
344    }
345
346    /// Create Validity from boolean array with given nullability of the array.
347    ///
348    /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself
349    ///     as that is always nonnullable
350    fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
351        if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
352            vortex_panic!("Expected a non-nullable boolean array")
353        }
354        match nullability {
355            Nullability::NonNullable => Self::NonNullable,
356            Nullability::Nullable => Self::Array(value),
357        }
358    }
359
360    /// Returns the length of the validity array, if it exists.
361    pub fn maybe_len(&self) -> Option<usize> {
362        match self {
363            Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
364            Self::Array(a) => Some(a.len()),
365        }
366    }
367
368    pub fn uncompressed_size(&self) -> usize {
369        if let Validity::Array(a) = self {
370            a.len().div_ceil(8)
371        } else {
372            0
373        }
374    }
375
376    pub fn is_array(&self) -> bool {
377        matches!(self, Validity::Array(_))
378    }
379}
380
381impl PartialEq for Validity {
382    fn eq(&self, other: &Self) -> bool {
383        match (self, other) {
384            (Self::NonNullable, Self::NonNullable) => true,
385            (Self::AllValid, Self::AllValid) => true,
386            (Self::AllInvalid, Self::AllInvalid) => true,
387            (Self::Array(a), Self::Array(b)) => {
388                let a = a
389                    .to_bool()
390                    .vortex_expect("Failed to get Validity Array as BoolArray");
391                let b = b
392                    .to_bool()
393                    .vortex_expect("Failed to get Validity Array as BoolArray");
394                a.boolean_buffer() == b.boolean_buffer()
395            }
396            _ => false,
397        }
398    }
399}
400
401impl From<BooleanBuffer> for Validity {
402    fn from(value: BooleanBuffer) -> Self {
403        if value.count_set_bits() == value.len() {
404            Self::AllValid
405        } else if value.count_set_bits() == 0 {
406            Self::AllInvalid
407        } else {
408            Self::Array(BoolArray::from(value).into_array())
409        }
410    }
411}
412
413impl From<NullBuffer> for Validity {
414    fn from(value: NullBuffer) -> Self {
415        value.into_inner().into()
416    }
417}
418
419impl FromIterator<Mask> for Validity {
420    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
421        Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
422    }
423}
424
425impl FromIterator<bool> for Validity {
426    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
427        Validity::from(BooleanBuffer::from_iter(iter))
428    }
429}
430
431impl From<Nullability> for Validity {
432    fn from(value: Nullability) -> Self {
433        match value {
434            Nullability::NonNullable => Validity::NonNullable,
435            Nullability::Nullable => Validity::AllValid,
436        }
437    }
438}
439
440impl Validity {
441    pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
442        assert!(
443            nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
444            "NonNullable validity must be AllValid",
445        );
446        match mask {
447            Mask::AllTrue(_) => match nullability {
448                Nullability::NonNullable => Validity::NonNullable,
449                Nullability::Nullable => Validity::AllValid,
450            },
451            Mask::AllFalse(_) => Validity::AllInvalid,
452            Mask::Values(values) => Validity::Array(values.into_array()),
453        }
454    }
455}
456
457impl IntoArray for Mask {
458    fn into_array(self) -> ArrayRef {
459        match self {
460            Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
461            Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
462            Self::Values(a) => a.into_array(),
463        }
464    }
465}
466
467impl IntoArray for &MaskValues {
468    fn into_array(self) -> ArrayRef {
469        BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
470    }
471}
472
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Shorthand for a [`Validity::Array`] backed by the given booleans.
    fn bits(values: &[bool]) -> Validity {
        Validity::Array(BoolArray::from_iter(values.iter().copied()).into_array())
    }

    /// Shorthand for the primitive array `[0, 1]` carrying the given validity.
    fn zero_one(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![0, 1], validity).into_array()
    }

    #[rstest]
    // source: all valid
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, bits(&[true, true, false, true, false]))]
    #[case(Validity::AllValid, 5, &[2, 4], bits(&[true, false]), bits(&[true, true, true, true, false]))]
    // source: all invalid
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, bits(&[false, false, true, false, true]))]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], bits(&[true, false]), bits(&[false, false, true, false, false]))]
    // source: explicit array
    #[case(bits(&[false, true, false, true, false]), 5, &[2, 4], Validity::AllValid, bits(&[false, true, true, true, true]))]
    #[case(bits(&[false, true, false, true, false]), 5, &[2, 4], Validity::AllInvalid, bits(&[false, true, false, true, false]))]
    #[case(bits(&[false, true, false, true, false]), 5, &[2, 4], bits(&[true, false]), bits(&[false, true, true, true, false]))]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches).unwrap();
        assert_eq!(patched, expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid)
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    #[rstest]
    // source: all valid
    #[case(Validity::AllValid, zero_one(Validity::from_iter(vec![true, false])), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, zero_one(Validity::AllInvalid), Validity::AllInvalid)]
    // source: non-nullable
    #[case(Validity::NonNullable, zero_one(Validity::from_iter(vec![true, false])), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, zero_one(Validity::AllInvalid), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}