1use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20#[derive(Clone, Debug)]
22pub enum Validity {
23 NonNullable,
25 AllValid,
27 AllInvalid,
29 Array(ArrayRef),
31}
32
33impl Validity {
34 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37 #[inline]
39 pub fn into_array(self) -> Option<ArrayRef> {
40 if let Self::Array(a) = self {
41 Some(a)
42 } else {
43 None
44 }
45 }
46
47 #[inline]
49 pub fn as_array(&self) -> Option<&ArrayRef> {
50 if let Self::Array(a) = self {
51 Some(a)
52 } else {
53 None
54 }
55 }
56
57 #[inline]
58 pub fn nullability(&self) -> Nullability {
59 if matches!(self, Self::NonNullable) {
60 Nullability::NonNullable
61 } else {
62 Nullability::Nullable
63 }
64 }
65
66 #[inline]
68 pub fn union_nullability(self, nullability: Nullability) -> Self {
69 match nullability {
70 Nullability::NonNullable => self,
71 Nullability::Nullable => self.into_nullable(),
72 }
73 }
74
75 #[inline]
76 pub fn all_valid(&self) -> bool {
77 match self {
78 Validity::NonNullable | Validity::AllValid => true,
79 Validity::AllInvalid => false,
80 Validity::Array(array) => {
81 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
82 .vortex_expect("sum must be a usize")
83 == array.len()
84 }
85 }
86 }
87
88 #[inline]
89 pub fn all_invalid(&self) -> bool {
90 match self {
91 Validity::NonNullable | Validity::AllValid => false,
92 Validity::AllInvalid => true,
93 Validity::Array(array) => {
94 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
95 .vortex_expect("sum must be a usize")
96 == 0
97 }
98 }
99 }
100
101 #[inline]
103 pub fn is_valid(&self, index: usize) -> bool {
104 match self {
105 Self::NonNullable | Self::AllValid => true,
106 Self::AllInvalid => false,
107 Self::Array(a) => {
108 let scalar = a.scalar_at(index);
109 scalar
110 .as_bool()
111 .value()
112 .vortex_expect("Validity must be non-nullable")
113 }
114 }
115 }
116
117 #[inline]
118 pub fn is_null(&self, index: usize) -> bool {
119 !self.is_valid(index)
120 }
121
122 #[inline]
123 pub fn slice(&self, range: Range<usize>) -> Self {
124 match self {
125 Self::Array(a) => Self::Array(a.slice(range)),
126 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
127 }
128 }
129
130 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
131 match self {
132 Self::NonNullable => match indices.validity_mask().boolean_buffer() {
133 AllOr::All => {
134 if indices.dtype().is_nullable() {
135 Ok(Self::AllValid)
136 } else {
137 Ok(Self::NonNullable)
138 }
139 }
140 AllOr::None => Ok(Self::AllInvalid),
141 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
142 },
143 Self::AllValid => match indices.validity_mask().boolean_buffer() {
144 AllOr::All => Ok(Self::AllValid),
145 AllOr::None => Ok(Self::AllInvalid),
146 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
147 },
148 Self::AllInvalid => Ok(Self::AllInvalid),
149 Self::Array(is_valid) => {
150 let maybe_is_valid = take(is_valid, indices)?;
151 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
153 Ok(Self::Array(is_valid))
154 }
155 }
156 }
157
158 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
162 match self {
165 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
166 Ok(v.clone())
167 }
168 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
169 }
170 }
171
172 #[inline]
176 pub fn mask(&self, mask: &Mask) -> Self {
177 match mask.boolean_buffer() {
178 AllOr::All => Validity::AllInvalid,
179 AllOr::None => self.clone(),
180 AllOr::Some(make_invalid) => match self {
181 Validity::NonNullable | Validity::AllValid => {
182 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
183 }
184 Validity::AllInvalid => Validity::AllInvalid,
185 Validity::Array(is_valid) => {
186 let is_valid = is_valid.to_bool();
187 let keep_valid = make_invalid.not();
188 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
189 }
190 },
191 }
192 }
193
194 #[inline]
195 pub fn to_mask(&self, length: usize) -> Mask {
196 match self {
197 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
198 Self::AllInvalid => Mask::AllFalse(length),
199 Self::Array(is_valid) => {
200 assert_eq!(
201 is_valid.len(),
202 length,
203 "Validity::Array length must equal to_logical's argument: {}, {}.",
204 is_valid.len(),
205 length,
206 );
207 is_valid.to_bool().to_mask()
208 }
209 }
210 }
211
212 #[inline]
214 pub fn and(self, rhs: Validity) -> Validity {
215 match (self, rhs) {
216 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
218 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
220 (Validity::Array(a), Validity::AllValid)
222 | (Validity::Array(a), Validity::NonNullable)
223 | (Validity::NonNullable, Validity::Array(a))
224 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
225 (Validity::NonNullable, Validity::AllValid)
227 | (Validity::AllValid, Validity::NonNullable)
228 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
229 (Validity::Array(lhs), Validity::Array(rhs)) => {
231 let lhs = lhs.to_bool();
232 let rhs = rhs.to_bool();
233
234 let lhs = lhs.boolean_buffer();
235 let rhs = rhs.boolean_buffer();
236
237 Validity::from(lhs.bitand(rhs))
238 }
239 }
240 }
241
242 pub fn patch(
243 self,
244 len: usize,
245 indices_offset: usize,
246 indices: &dyn Array,
247 patches: &Validity,
248 ) -> Self {
249 match (&self, patches) {
250 (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
251 (Validity::NonNullable, _) => {
252 vortex_panic!("Can't patch a non-nullable validity with nullable validity")
253 }
254 (_, Validity::NonNullable) => {
255 vortex_panic!("Can't patch a nullable validity with non-nullable validity")
256 }
257 (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
258 (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
259 _ => {}
260 };
261
262 let own_nullability = if self == Validity::NonNullable {
263 Nullability::NonNullable
264 } else {
265 Nullability::Nullable
266 };
267
268 let source = match self {
269 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
270 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
271 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
272 Validity::Array(a) => a.to_bool(),
273 };
274
275 let patch_values = match patches {
276 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
277 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
278 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
279 Validity::Array(a) => a.to_bool(),
280 };
281
282 let patches = Patches::new(
283 len,
284 indices_offset,
285 indices.to_array(),
286 patch_values.into_array(),
287 );
288
289 Self::from_array(source.patch(&patches).into_array(), own_nullability)
290 }
291
292 #[inline]
294 pub fn into_nullable(self) -> Validity {
295 match self {
296 Self::NonNullable => Self::AllValid,
297 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
298 }
299 }
300
301 #[inline]
303 pub fn into_non_nullable(self) -> Option<Validity> {
304 match self {
305 Self::NonNullable => Some(Self::NonNullable),
306 Self::AllValid => Some(Self::NonNullable),
307 Self::AllInvalid => None,
308 Self::Array(is_valid) => {
309 is_valid
310 .statistics()
311 .compute_min::<bool>()
312 .vortex_expect("validity array must support min")
313 .then(|| {
314 Self::NonNullable
316 })
317 }
318 }
319 }
320
321 #[inline]
323 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
324 match nullability {
325 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
326 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
327 }),
328 Nullability::Nullable => Ok(self.into_nullable()),
329 }
330 }
331
332 #[inline]
334 pub fn copy_from_array(array: &dyn Array) -> Self {
335 Validity::from_mask(array.validity_mask(), array.dtype().nullability())
336 }
337
338 #[inline]
343 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
344 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
345 vortex_panic!("Expected a non-nullable boolean array")
346 }
347 match nullability {
348 Nullability::NonNullable => Self::NonNullable,
349 Nullability::Nullable => Self::Array(value),
350 }
351 }
352
353 #[inline]
355 pub fn maybe_len(&self) -> Option<usize> {
356 match self {
357 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
358 Self::Array(a) => Some(a.len()),
359 }
360 }
361
362 #[inline]
363 pub fn uncompressed_size(&self) -> usize {
364 if let Validity::Array(a) = self {
365 a.len().div_ceil(8)
366 } else {
367 0
368 }
369 }
370}
371
372impl PartialEq for Validity {
373 #[inline]
374 fn eq(&self, other: &Self) -> bool {
375 match (self, other) {
376 (Self::NonNullable, Self::NonNullable) => true,
377 (Self::AllValid, Self::AllValid) => true,
378 (Self::AllInvalid, Self::AllInvalid) => true,
379 (Self::Array(a), Self::Array(b)) => {
380 let a = a.to_bool();
381 let b = b.to_bool();
382 a.boolean_buffer() == b.boolean_buffer()
383 }
384 _ => false,
385 }
386 }
387}
388
389impl From<BooleanBuffer> for Validity {
390 #[inline]
391 fn from(value: BooleanBuffer) -> Self {
392 if value.count_set_bits() == value.len() {
393 Self::AllValid
394 } else if value.count_set_bits() == 0 {
395 Self::AllInvalid
396 } else {
397 Self::Array(BoolArray::from(value).into_array())
398 }
399 }
400}
401
402impl From<NullBuffer> for Validity {
403 #[inline]
404 fn from(value: NullBuffer) -> Self {
405 value.into_inner().into()
406 }
407}
408
409impl FromIterator<Mask> for Validity {
410 #[inline]
411 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
412 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
413 }
414}
415
416impl FromIterator<bool> for Validity {
417 #[inline]
418 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
419 Validity::from(BooleanBuffer::from_iter(iter))
420 }
421}
422
423impl From<Nullability> for Validity {
424 #[inline]
425 fn from(value: Nullability) -> Self {
426 match value {
427 Nullability::NonNullable => Validity::NonNullable,
428 Nullability::Nullable => Validity::AllValid,
429 }
430 }
431}
432
433impl Validity {
434 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
435 assert!(
436 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
437 "NonNullable validity must be AllValid",
438 );
439 match mask {
440 Mask::AllTrue(_) => match nullability {
441 Nullability::NonNullable => Validity::NonNullable,
442 Nullability::Nullable => Validity::AllValid,
443 },
444 Mask::AllFalse(_) => Validity::AllInvalid,
445 Mask::Values(values) => Validity::Array(values.into_array()),
446 }
447 }
448}
449
450impl IntoArray for Mask {
451 #[inline]
452 fn into_array(self) -> ArrayRef {
453 match self {
454 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
455 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
456 Self::Values(a) => a.into_array(),
457 }
458 }
459}
460
461impl IntoArray for &MaskValues {
462 #[inline]
463 fn into_array(self) -> ArrayRef {
464 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
465 }
466}
467
468#[cfg(test)]
469mod tests {
470 use rstest::rstest;
471 use vortex_buffer::{Buffer, buffer};
472 use vortex_dtype::Nullability;
473 use vortex_mask::Mask;
474
475 use crate::arrays::{BoolArray, PrimitiveArray};
476 use crate::validity::Validity;
477 use crate::{ArrayRef, IntoArray};
478
479 #[rstest]
480 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
481 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
482 )]
483 #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
484 )]
485 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
486 )]
487 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
488 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
489 )]
490 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
491 )]
492 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
493 )]
494 #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
495 )]
496 fn patch_validity(
497 #[case] validity: Validity,
498 #[case] len: usize,
499 #[case] positions: &[u64],
500 #[case] patches: Validity,
501 #[case] expected: Validity,
502 ) {
503 let indices =
504 PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
505 assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
506 }
507
508 #[test]
509 #[should_panic]
510 fn out_of_bounds_patch() {
511 Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
512 }
513
514 #[test]
515 #[should_panic]
516 fn into_validity_nullable() {
517 Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
518 }
519
520 #[test]
521 #[should_panic]
522 fn into_validity_nullable_array() {
523 Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
524 }
525
526 #[rstest]
527 #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
528 #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
529 #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
530 #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
531 #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
532 #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
533 fn validity_take(
534 #[case] validity: Validity,
535 #[case] indices: ArrayRef,
536 #[case] expected: Validity,
537 ) {
538 assert_eq!(validity.take(&indices).unwrap(), expected);
539 }
540}