1use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20#[derive(Clone, Debug)]
22pub enum Validity {
23 NonNullable,
25 AllValid,
27 AllInvalid,
29 Array(ArrayRef),
33}
34
35impl Validity {
36 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
38
39 #[inline]
41 pub fn into_array(self) -> Option<ArrayRef> {
42 if let Self::Array(a) = self {
43 Some(a)
44 } else {
45 None
46 }
47 }
48
49 #[inline]
51 pub fn as_array(&self) -> Option<&ArrayRef> {
52 if let Self::Array(a) = self {
53 Some(a)
54 } else {
55 None
56 }
57 }
58
59 #[inline]
60 pub fn nullability(&self) -> Nullability {
61 if matches!(self, Self::NonNullable) {
62 Nullability::NonNullable
63 } else {
64 Nullability::Nullable
65 }
66 }
67
68 #[inline]
70 pub fn union_nullability(self, nullability: Nullability) -> Self {
71 match nullability {
72 Nullability::NonNullable => self,
73 Nullability::Nullable => self.into_nullable(),
74 }
75 }
76
77 #[inline]
78 pub fn all_valid(&self, len: usize) -> bool {
79 match self {
80 _ if len == 0 => true,
81 Validity::NonNullable | Validity::AllValid => true,
82 Validity::AllInvalid => false,
83 Validity::Array(array) => {
84 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
85 .vortex_expect("sum must be a usize")
86 == array.len()
87 }
88 }
89 }
90
91 #[inline]
92 pub fn all_invalid(&self, len: usize) -> bool {
93 match self {
94 _ if len == 0 => true,
95 Validity::NonNullable | Validity::AllValid => false,
96 Validity::AllInvalid => true,
97 Validity::Array(array) => {
98 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
99 .vortex_expect("sum must be a usize")
100 == 0
101 }
102 }
103 }
104
105 #[inline]
107 pub fn is_valid(&self, index: usize) -> bool {
108 match self {
109 Self::NonNullable | Self::AllValid => true,
110 Self::AllInvalid => false,
111 Self::Array(a) => {
112 let scalar = a.scalar_at(index);
113 scalar
114 .as_bool()
115 .value()
116 .vortex_expect("Validity must be non-nullable")
117 }
118 }
119 }
120
121 #[inline]
122 pub fn is_null(&self, index: usize) -> bool {
123 !self.is_valid(index)
124 }
125
126 #[inline]
127 pub fn slice(&self, range: Range<usize>) -> Self {
128 match self {
129 Self::Array(a) => Self::Array(a.slice(range)),
130 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
131 }
132 }
133
134 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
135 match self {
136 Self::NonNullable => match indices.validity_mask().boolean_buffer() {
137 AllOr::All => {
138 if indices.dtype().is_nullable() {
139 Ok(Self::AllValid)
140 } else {
141 Ok(Self::NonNullable)
142 }
143 }
144 AllOr::None => Ok(Self::AllInvalid),
145 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
146 },
147 Self::AllValid => match indices.validity_mask().boolean_buffer() {
148 AllOr::All => Ok(Self::AllValid),
149 AllOr::None => Ok(Self::AllInvalid),
150 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
151 },
152 Self::AllInvalid => Ok(Self::AllInvalid),
153 Self::Array(is_valid) => {
154 let maybe_is_valid = take(is_valid, indices)?;
155 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
157 Ok(Self::Array(is_valid))
158 }
159 }
160 }
161
162 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
166 match self {
169 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
170 Ok(v.clone())
171 }
172 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
173 }
174 }
175
176 #[inline]
180 pub fn mask(&self, mask: &Mask) -> Self {
181 match mask.boolean_buffer() {
182 AllOr::All => Validity::AllInvalid,
183 AllOr::None => self.clone(),
184 AllOr::Some(make_invalid) => match self {
185 Validity::NonNullable | Validity::AllValid => {
186 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
187 }
188 Validity::AllInvalid => Validity::AllInvalid,
189 Validity::Array(is_valid) => {
190 let is_valid = is_valid.to_bool();
191 let keep_valid = make_invalid.not();
192 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
193 }
194 },
195 }
196 }
197
198 #[inline]
199 pub fn to_mask(&self, length: usize) -> Mask {
200 match self {
201 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
202 Self::AllInvalid => Mask::AllFalse(length),
203 Self::Array(is_valid) => {
204 assert_eq!(
205 is_valid.len(),
206 length,
207 "Validity::Array length must equal to_logical's argument: {}, {}.",
208 is_valid.len(),
209 length,
210 );
211 is_valid.to_bool().to_mask()
212 }
213 }
214 }
215
216 #[inline]
218 pub fn and(self, rhs: Validity) -> Validity {
219 match (self, rhs) {
220 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
222 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
224 (Validity::Array(a), Validity::AllValid)
226 | (Validity::Array(a), Validity::NonNullable)
227 | (Validity::NonNullable, Validity::Array(a))
228 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
229 (Validity::NonNullable, Validity::AllValid)
231 | (Validity::AllValid, Validity::NonNullable)
232 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
233 (Validity::Array(lhs), Validity::Array(rhs)) => {
235 let lhs = lhs.to_bool();
236 let rhs = rhs.to_bool();
237
238 let lhs = lhs.boolean_buffer();
239 let rhs = rhs.boolean_buffer();
240
241 Validity::from(lhs.bitand(rhs))
242 }
243 }
244 }
245
246 pub fn patch(
247 self,
248 len: usize,
249 indices_offset: usize,
250 indices: &dyn Array,
251 patches: &Validity,
252 ) -> Self {
253 match (&self, patches) {
254 (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
255 (Validity::NonNullable, _) => {
256 vortex_panic!("Can't patch a non-nullable validity with nullable validity")
257 }
258 (_, Validity::NonNullable) => {
259 vortex_panic!("Can't patch a nullable validity with non-nullable validity")
260 }
261 (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
262 (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
263 _ => {}
264 };
265
266 let own_nullability = if self == Validity::NonNullable {
267 Nullability::NonNullable
268 } else {
269 Nullability::Nullable
270 };
271
272 let source = match self {
273 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
274 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
275 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
276 Validity::Array(a) => a.to_bool(),
277 };
278
279 let patch_values = match patches {
280 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
281 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
282 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
283 Validity::Array(a) => a.to_bool(),
284 };
285
286 let patches = Patches::new(
287 len,
288 indices_offset,
289 indices.to_array(),
290 patch_values.into_array(),
291 None,
293 );
294
295 Self::from_array(source.patch(&patches).into_array(), own_nullability)
296 }
297
298 #[inline]
300 pub fn into_nullable(self) -> Validity {
301 match self {
302 Self::NonNullable => Self::AllValid,
303 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
304 }
305 }
306
307 #[inline]
309 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
310 match self {
311 _ if len == 0 => Some(Validity::NonNullable),
312 Self::NonNullable => Some(Self::NonNullable),
313 Self::AllValid => Some(Self::NonNullable),
314 Self::AllInvalid => None,
315 Self::Array(is_valid) => {
316 is_valid
317 .statistics()
318 .compute_min::<bool>()
319 .vortex_expect("validity array must support min")
320 .then(|| {
321 Self::NonNullable
323 })
324 }
325 }
326 }
327
328 #[inline]
330 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
331 match nullability {
332 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
333 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
334 }),
335 Nullability::Nullable => Ok(self.into_nullable()),
336 }
337 }
338
339 #[inline]
341 pub fn copy_from_array(array: &dyn Array) -> Self {
342 Validity::from_mask(array.validity_mask(), array.dtype().nullability())
343 }
344
345 #[inline]
350 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
351 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
352 vortex_panic!("Expected a non-nullable boolean array")
353 }
354 match nullability {
355 Nullability::NonNullable => Self::NonNullable,
356 Nullability::Nullable => Self::Array(value),
357 }
358 }
359
360 #[inline]
362 pub fn maybe_len(&self) -> Option<usize> {
363 match self {
364 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
365 Self::Array(a) => Some(a.len()),
366 }
367 }
368
369 #[inline]
370 pub fn uncompressed_size(&self) -> usize {
371 if let Validity::Array(a) = self {
372 a.len().div_ceil(8)
373 } else {
374 0
375 }
376 }
377}
378
379impl PartialEq for Validity {
380 #[inline]
381 fn eq(&self, other: &Self) -> bool {
382 match (self, other) {
383 (Self::NonNullable, Self::NonNullable) => true,
384 (Self::AllValid, Self::AllValid) => true,
385 (Self::AllInvalid, Self::AllInvalid) => true,
386 (Self::Array(a), Self::Array(b)) => {
387 let a = a.to_bool();
388 let b = b.to_bool();
389 a.boolean_buffer() == b.boolean_buffer()
390 }
391 _ => false,
392 }
393 }
394}
395
396impl From<BooleanBuffer> for Validity {
397 #[inline]
398 fn from(value: BooleanBuffer) -> Self {
399 if value.count_set_bits() == value.len() {
400 Self::AllValid
401 } else if value.count_set_bits() == 0 {
402 Self::AllInvalid
403 } else {
404 Self::Array(BoolArray::from(value).into_array())
405 }
406 }
407}
408
409impl From<NullBuffer> for Validity {
410 #[inline]
411 fn from(value: NullBuffer) -> Self {
412 value.into_inner().into()
413 }
414}
415
416impl FromIterator<Mask> for Validity {
417 #[inline]
418 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
419 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
420 }
421}
422
423impl FromIterator<bool> for Validity {
424 #[inline]
425 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
426 Validity::from(BooleanBuffer::from_iter(iter))
427 }
428}
429
430impl From<Nullability> for Validity {
431 #[inline]
432 fn from(value: Nullability) -> Self {
433 match value {
434 Nullability::NonNullable => Validity::NonNullable,
435 Nullability::Nullable => Validity::AllValid,
436 }
437 }
438}
439
440impl Validity {
441 pub fn from_null_buffer(buffer: Option<NullBuffer>, nullability: Nullability) -> Self {
442 match buffer {
443 None => nullability.into(),
445 Some(nulls) => {
446 if nulls.null_count() == nulls.len() {
447 Validity::AllInvalid
448 } else {
449 Validity::Array(BoolArray::from(nulls.into_inner()).into_array())
450 }
451 }
452 }
453 }
454
455 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
456 assert!(
457 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
458 "NonNullable validity must be AllValid",
459 );
460 match mask {
461 Mask::AllTrue(_) => match nullability {
462 Nullability::NonNullable => Validity::NonNullable,
463 Nullability::Nullable => Validity::AllValid,
464 },
465 Mask::AllFalse(_) => Validity::AllInvalid,
466 Mask::Values(values) => Validity::Array(values.into_array()),
467 }
468 }
469}
470
471impl IntoArray for Mask {
472 #[inline]
473 fn into_array(self) -> ArrayRef {
474 match self {
475 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
476 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
477 Self::Values(a) => a.into_array(),
478 }
479 }
480}
481
482impl IntoArray for &MaskValues {
483 #[inline]
484 fn into_array(self) -> ArrayRef {
485 BoolArray::from_bool_buffer(self.boolean_buffer().clone(), Validity::NonNullable)
486 .into_array()
487 }
488}
489
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Shorthand for a `Validity::Array` whose per-element validity is `bits`.
    fn array_validity(bits: &[bool]) -> Validity {
        Validity::Array(BoolArray::from_iter(bits.iter().copied()).into_array())
    }

    /// Shorthand for the primitive index array `[0, 1]` carrying the given validity.
    fn indices_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![0, 1], validity).into_array()
    }

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        array_validity(&[true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        array_validity(&[true, false]),
        array_validity(&[true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        array_validity(&[false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        array_validity(&[true, false]),
        array_validity(&[false, false, true, false, false])
    )]
    #[case(
        array_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        array_validity(&[false, true, true, true, true])
    )]
    #[case(
        array_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        array_validity(&[false, true, false, true, false])
    )]
    #[case(
        array_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        array_validity(&[true, false]),
        array_validity(&[false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &indices, &patches);
        assert_eq!(patched, expected);
    }

    // Patching a non-nullable validity with a nullable one must panic.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    // A non-nullable validity cannot be built from a mask containing invalid entries.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    #[case(
        Validity::NonNullable,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}