1use std::fmt::Debug;
4use std::ops::{BitAnd, Not};
5
6use arrow_buffer::{BooleanBuffer, NullBuffer};
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
9use vortex_mask::{AllOr, Mask, MaskValues};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray};
13use crate::compute::{fill_null, filter, scalar_at, slice, take};
14use crate::patches::Patches;
15use crate::{Array, ArrayRef, ArrayVariants, IntoArray, ToCanonical};
16
/// Describes which positions of an array hold valid (non-null) values.
#[derive(Clone, Debug)]
pub enum Validity {
    /// Reports [`Nullability::NonNullable`]; every index is treated as valid.
    NonNullable,
    /// Reports [`Nullability::Nullable`]; every index is treated as valid.
    AllValid,
    /// Reports [`Nullability::Nullable`]; every index is treated as invalid.
    AllInvalid,
    /// Per-position validity stored in an array; `is_valid` returns the boolean found at the
    /// requested index.
    Array(ArrayRef),
}
29
30impl Validity {
/// Non-nullable boolean [`DType`], the same type `from_array` requires of a validity array.
pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
33
34 pub fn null_count(&self, length: usize) -> VortexResult<usize> {
35 match self {
36 Self::NonNullable | Self::AllValid => Ok(0),
37 Self::AllInvalid => Ok(length),
38 Self::Array(a) => {
39 let validity_len = a.len();
40 if validity_len != length {
41 vortex_bail!(
42 "Validity array length {} doesn't match array length {}",
43 validity_len,
44 length
45 )
46 }
47 let true_count = a
48 .as_bool_typed()
49 .vortex_expect("Validity array must be boolean")
50 .true_count()?;
51 Ok(length - true_count)
52 }
53 }
54 }
55
56 pub fn into_array(self) -> Option<ArrayRef> {
58 match self {
59 Self::Array(a) => Some(a),
60 _ => None,
61 }
62 }
63
64 pub fn as_array(&self) -> Option<&ArrayRef> {
66 match self {
67 Self::Array(a) => Some(a),
68 _ => None,
69 }
70 }
71
72 pub fn nullability(&self) -> Nullability {
73 match self {
74 Self::NonNullable => Nullability::NonNullable,
75 _ => Nullability::Nullable,
76 }
77 }
78
79 pub fn union_nullability(self, nullability: Nullability) -> Self {
81 match nullability {
82 Nullability::NonNullable => self,
83 Nullability::Nullable => self.into_nullable(),
84 }
85 }
86
87 pub fn all_valid(&self) -> VortexResult<bool> {
88 Ok(match self {
89 Validity::NonNullable | Validity::AllValid => true,
90 Validity::AllInvalid => false,
91 Validity::Array(array) => {
92 array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
94 }
95 })
96 }
97
98 pub fn all_invalid(&self) -> VortexResult<bool> {
99 Ok(match self {
100 Validity::NonNullable | Validity::AllValid => false,
101 Validity::AllInvalid => true,
102 Validity::Array(array) => {
103 array.to_bool()?.boolean_buffer().count_set_bits() == 0
105 }
106 })
107 }
108
109 #[inline]
111 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
112 Ok(match self {
113 Self::NonNullable | Self::AllValid => true,
114 Self::AllInvalid => false,
115 Self::Array(a) => {
116 let scalar = scalar_at(a, index)?;
117 scalar
118 .as_bool()
119 .value()
120 .vortex_expect("Validity must be non-nullable")
121 }
122 })
123 }
124
125 #[inline]
126 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
127 Ok(!self.is_valid(index)?)
128 }
129
130 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
131 match self {
132 Self::Array(a) => Ok(Self::Array(slice(a, start, stop)?)),
133 _ => Ok(self.clone()),
134 }
135 }
136
137 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
138 match self {
139 Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
140 AllOr::All => {
141 if indices.dtype().is_nullable() {
142 Ok(Self::AllValid)
143 } else {
144 Ok(Self::NonNullable)
145 }
146 }
147 AllOr::None => Ok(Self::AllInvalid),
148 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
149 },
150 Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
151 AllOr::All => Ok(Self::AllValid),
152 AllOr::None => Ok(Self::AllInvalid),
153 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
154 },
155 Self::AllInvalid => Ok(Self::AllInvalid),
156 Self::Array(is_valid) => {
157 let maybe_is_valid = take(is_valid, indices)?;
158 let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
160 Ok(Self::Array(is_valid))
161 }
162 }
163 }
164
165 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
169 match self {
172 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
173 Ok(v.clone())
174 }
175 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
176 }
177 }
178
179 pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
183 match mask.boolean_buffer() {
184 AllOr::All => Ok(Validity::AllInvalid),
185 AllOr::None => Ok(self.clone()),
186 AllOr::Some(make_invalid) => Ok(match self {
187 Validity::NonNullable | Validity::AllValid => {
188 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
189 }
190 Validity::AllInvalid => Validity::AllInvalid,
191 Validity::Array(is_valid) => {
192 let is_valid = is_valid.to_bool()?;
193 let keep_valid = make_invalid.not();
194 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
195 }
196 }),
197 }
198 }
199
200 pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
201 Ok(match self {
202 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
203 Self::AllInvalid => Mask::AllFalse(length),
204 Self::Array(is_valid) => {
205 assert_eq!(
206 is_valid.len(),
207 length,
208 "Validity::Array length must equal to_logical's argument: {}, {}.",
209 is_valid.len(),
210 length,
211 );
212 Mask::try_from(&is_valid.to_bool()?)?
213 }
214 })
215 }
216
217 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
219 let validity = match (self, rhs) {
220 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
222 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
224 (Validity::Array(a), Validity::AllValid)
226 | (Validity::Array(a), Validity::NonNullable)
227 | (Validity::NonNullable, Validity::Array(a))
228 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
229 (Validity::NonNullable, Validity::AllValid)
231 | (Validity::AllValid, Validity::NonNullable)
232 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
233 (Validity::Array(lhs), Validity::Array(rhs)) => {
235 let lhs = lhs.to_bool()?;
236 let rhs = rhs.to_bool()?;
237
238 let lhs = lhs.boolean_buffer();
239 let rhs = rhs.boolean_buffer();
240
241 Validity::from(lhs.bitand(rhs))
242 }
243 };
244
245 Ok(validity)
246 }
247
248 pub fn patch(
249 self,
250 len: usize,
251 indices_offset: usize,
252 indices: &dyn Array,
253 patches: &Validity,
254 ) -> VortexResult<Self> {
255 match (&self, patches) {
256 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
257 (Validity::NonNullable, _) => {
258 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
259 }
260 (_, Validity::NonNullable) => {
261 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
262 }
263 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
264 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
265 _ => {}
266 };
267
268 let own_nullability = if self == Validity::NonNullable {
269 Nullability::NonNullable
270 } else {
271 Nullability::Nullable
272 };
273
274 let source = match self {
275 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
276 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
277 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
278 Validity::Array(a) => a.to_bool()?,
279 };
280
281 let patch_values = match patches {
282 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
283 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
284 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
285 Validity::Array(a) => a.to_bool()?,
286 };
287
288 let patches = Patches::new(
289 len,
290 indices_offset,
291 indices.to_array(),
292 patch_values.into_array(),
293 );
294
295 Ok(Self::from_array(
296 source.patch(&patches)?.into_array(),
297 own_nullability,
298 ))
299 }
300
301 pub fn into_nullable(self) -> Validity {
303 match self {
304 Self::NonNullable => Self::AllValid,
305 _ => self,
306 }
307 }
308
309 pub fn into_non_nullable(self) -> Option<Validity> {
311 match self {
312 Self::NonNullable => Some(Self::NonNullable),
313 Self::AllValid => Some(Self::NonNullable),
314 Self::AllInvalid => None,
315 Self::Array(is_valid) => {
316 is_valid
317 .statistics()
318 .compute_min::<bool>()
319 .vortex_expect("validity array must support min")
320 .then(|| {
321 Self::NonNullable
323 })
324 }
325 }
326 }
327
328 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
330 match nullability {
331 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
332 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
333 }),
334 Nullability::Nullable => Ok(self.into_nullable()),
335 }
336 }
337
338 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
340 Ok(Validity::from_mask(
341 array.validity_mask()?,
342 array.dtype().nullability(),
343 ))
344 }
345
346 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
351 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
352 vortex_panic!("Expected a non-nullable boolean array")
353 }
354 match nullability {
355 Nullability::NonNullable => Self::NonNullable,
356 Nullability::Nullable => Self::Array(value),
357 }
358 }
359
360 pub fn maybe_len(&self) -> Option<usize> {
362 match self {
363 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
364 Self::Array(a) => Some(a.len()),
365 }
366 }
367
368 pub fn uncompressed_size(&self) -> usize {
369 if let Validity::Array(a) = self {
370 a.len().div_ceil(8)
371 } else {
372 0
373 }
374 }
375
376 pub fn is_array(&self) -> bool {
377 matches!(self, Validity::Array(_))
378 }
379}
380
381impl PartialEq for Validity {
382 fn eq(&self, other: &Self) -> bool {
383 match (self, other) {
384 (Self::NonNullable, Self::NonNullable) => true,
385 (Self::AllValid, Self::AllValid) => true,
386 (Self::AllInvalid, Self::AllInvalid) => true,
387 (Self::Array(a), Self::Array(b)) => {
388 let a = a
389 .to_bool()
390 .vortex_expect("Failed to get Validity Array as BoolArray");
391 let b = b
392 .to_bool()
393 .vortex_expect("Failed to get Validity Array as BoolArray");
394 a.boolean_buffer() == b.boolean_buffer()
395 }
396 _ => false,
397 }
398 }
399}
400
401impl From<BooleanBuffer> for Validity {
402 fn from(value: BooleanBuffer) -> Self {
403 if value.count_set_bits() == value.len() {
404 Self::AllValid
405 } else if value.count_set_bits() == 0 {
406 Self::AllInvalid
407 } else {
408 Self::Array(BoolArray::from(value).into_array())
409 }
410 }
411}
412
413impl From<NullBuffer> for Validity {
414 fn from(value: NullBuffer) -> Self {
415 value.into_inner().into()
416 }
417}
418
419impl FromIterator<Mask> for Validity {
420 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
421 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
422 }
423}
424
425impl FromIterator<bool> for Validity {
426 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
427 Validity::from(BooleanBuffer::from_iter(iter))
428 }
429}
430
431impl From<Nullability> for Validity {
432 fn from(value: Nullability) -> Self {
433 match value {
434 Nullability::NonNullable => Validity::NonNullable,
435 Nullability::Nullable => Validity::AllValid,
436 }
437 }
438}
439
440impl Validity {
441 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
442 assert!(
443 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
444 "NonNullable validity must be AllValid",
445 );
446 match mask {
447 Mask::AllTrue(_) => match nullability {
448 Nullability::NonNullable => Validity::NonNullable,
449 Nullability::Nullable => Validity::AllValid,
450 },
451 Mask::AllFalse(_) => Validity::AllInvalid,
452 Mask::Values(values) => Validity::Array(values.into_array()),
453 }
454 }
455}
456
457impl IntoArray for Mask {
458 fn into_array(self) -> ArrayRef {
459 match self {
460 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
461 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
462 Self::Values(a) => a.into_array(),
463 }
464 }
465}
466
467impl IntoArray for &MaskValues {
468 fn into_array(self) -> ArrayRef {
469 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
470 }
471}
472
473#[cfg(test)]
474mod tests {
475 use rstest::rstest;
476 use vortex_buffer::{Buffer, buffer};
477 use vortex_dtype::Nullability;
478 use vortex_mask::Mask;
479
480 use crate::array::Array;
481 use crate::arrays::{BoolArray, PrimitiveArray};
482 use crate::validity::Validity;
483 use crate::{ArrayRef, IntoArray};
484
485 #[rstest]
486 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
487 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
488 )]
489 #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
490 )]
491 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
492 )]
493 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
494 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
495 )]
496 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
497 )]
498 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
499 )]
500 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
501 )]
502 fn patch_validity(
503 #[case] validity: Validity,
504 #[case] len: usize,
505 #[case] positions: &[u64],
506 #[case] patches: Validity,
507 #[case] expected: Validity,
508 ) {
509 let indices =
510 PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
511 assert_eq!(
512 validity.patch(len, 0, &indices, &patches).unwrap(),
513 expected
514 );
515 }
516
// NOTE(review): with a `NonNullable` receiver and an `AllInvalid` patch, `patch` returns its
// "Can't patch a non-nullable validity with nullable validity" error before any index is
// inspected, so the panic observed here comes from `unwrap` on that error rather than from
// the out-of-range index. Confirm whether a nullable receiver was intended.
#[test]
#[should_panic]
fn out_of_bounds_patch() {
    let indices = buffer![4].into_array();
    let patched = Validity::NonNullable.patch(2, 0, &indices, &Validity::AllInvalid);
    patched.unwrap();
}
524
// `from_mask` must panic when a non-nullable validity is requested from an all-false mask.
#[test]
#[should_panic]
fn into_validity_nullable() {
    let all_false = Mask::AllFalse(10);
    Validity::from_mask(all_false, Nullability::NonNullable);
}
530
// `from_mask` must panic when a non-nullable validity is requested from a mask built from
// mixed values.
#[test]
#[should_panic]
fn into_validity_nullable_array() {
    let mixed = Mask::from_iter(vec![true, false]);
    Validity::from_mask(mixed, Nullability::NonNullable);
}
536
537 #[rstest]
538 #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
539 #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
540 #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
541 #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
542 #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
543 #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
544 fn validity_take(
545 #[case] validity: Validity,
546 #[case] indices: ArrayRef,
547 #[case] expected: Validity,
548 ) {
549 assert_eq!(validity.take(&indices).unwrap(), expected);
550 }
551}