1use std::fmt::Debug;
4use std::ops::{BitAnd, Not};
5
6use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
9use vortex_mask::{AllOr, Mask, MaskValues};
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray};
13use crate::compute::{fill_null, filter, scalar_at, slice, take};
14use crate::patches::Patches;
15use crate::{Array, ArrayRef, ArrayVariants, IntoArray, ToCanonical};
16
/// Validity (null-ness) information for an array.
///
/// The three unit variants describe every element at once and carry no length of their own;
/// [`Validity::Array`] carries one boolean per element, where `true` marks a valid element.
#[derive(Clone, Debug)]
pub enum Validity {
    /// The owning array's type is non-nullable, so every element is valid.
    NonNullable,
    /// The owning array is nullable, but every element is valid.
    AllValid,
    /// Every element is null.
    AllInvalid,
    /// Per-element validity: `true` means valid, `false` means null.
    Array(ArrayRef),
}
29
30impl Validity {
/// A non-nullable boolean [`DType`], the type required of the array held by
/// [`Validity::Array`].
pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
33
34 pub fn null_count(&self, length: usize) -> VortexResult<usize> {
35 match self {
36 Self::NonNullable | Self::AllValid => Ok(0),
37 Self::AllInvalid => Ok(length),
38 Self::Array(a) => {
39 let validity_len = a.len();
40 if validity_len != length {
41 vortex_bail!(
42 "Validity array length {} doesn't match array length {}",
43 validity_len,
44 length
45 )
46 }
47 let true_count = a
48 .as_bool_typed()
49 .vortex_expect("Validity array must be boolean")
50 .true_count()?;
51 Ok(length - true_count)
52 }
53 }
54 }
55
56 pub fn into_array(self) -> Option<ArrayRef> {
58 match self {
59 Self::Array(a) => Some(a),
60 _ => None,
61 }
62 }
63
64 pub fn as_array(&self) -> Option<&ArrayRef> {
66 match self {
67 Self::Array(a) => Some(a),
68 _ => None,
69 }
70 }
71
72 pub fn nullability(&self) -> Nullability {
73 match self {
74 Self::NonNullable => Nullability::NonNullable,
75 _ => Nullability::Nullable,
76 }
77 }
78
79 pub fn all_valid(&self) -> VortexResult<bool> {
80 Ok(match self {
81 Validity::NonNullable | Validity::AllValid => true,
82 Validity::AllInvalid => false,
83 Validity::Array(array) => {
84 array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
86 }
87 })
88 }
89
90 pub fn all_invalid(&self) -> VortexResult<bool> {
91 Ok(match self {
92 Validity::NonNullable | Validity::AllValid => false,
93 Validity::AllInvalid => true,
94 Validity::Array(array) => {
95 array.to_bool()?.boolean_buffer().count_set_bits() == 0
97 }
98 })
99 }
100
101 #[inline]
103 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
104 Ok(match self {
105 Self::NonNullable | Self::AllValid => true,
106 Self::AllInvalid => false,
107 Self::Array(a) => {
108 let scalar = scalar_at(a, index)?;
109 scalar
110 .as_bool()
111 .value()
112 .vortex_expect("Validity must be non-nullable")
113 }
114 })
115 }
116
117 #[inline]
118 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
119 Ok(!self.is_valid(index)?)
120 }
121
122 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
123 match self {
124 Self::Array(a) => Ok(Self::Array(slice(a, start, stop)?)),
125 _ => Ok(self.clone()),
126 }
127 }
128
129 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
130 match self {
131 Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
132 AllOr::All => {
133 if indices.dtype().is_nullable() {
134 Ok(Self::AllValid)
135 } else {
136 Ok(Self::NonNullable)
137 }
138 }
139 AllOr::None => Ok(Self::AllInvalid),
140 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
141 },
142 Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
143 AllOr::All => Ok(Self::AllValid),
144 AllOr::None => Ok(Self::AllInvalid),
145 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
146 },
147 Self::AllInvalid => Ok(Self::AllInvalid),
148 Self::Array(is_valid) => {
149 let maybe_is_valid = take(is_valid, indices)?;
150 let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
152 Ok(Self::Array(is_valid))
153 }
154 }
155 }
156
157 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
161 match self {
164 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
165 Ok(v.clone())
166 }
167 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
168 }
169 }
170
171 pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
175 match mask.boolean_buffer() {
176 AllOr::All => Ok(Validity::AllInvalid),
177 AllOr::None => Ok(self.clone()),
178 AllOr::Some(make_invalid) => Ok(match self {
179 Validity::NonNullable | Validity::AllValid => {
180 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
181 }
182 Validity::AllInvalid => Validity::AllInvalid,
183 Validity::Array(is_valid) => {
184 let is_valid = is_valid.to_bool()?;
185 let keep_valid = make_invalid.not();
186 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
187 }
188 }),
189 }
190 }
191
192 pub fn to_logical(&self, length: usize) -> VortexResult<Mask> {
194 Ok(match self {
195 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
196 Self::AllInvalid => Mask::AllFalse(length),
197 Self::Array(is_valid) => {
198 assert_eq!(
199 is_valid.len(),
200 length,
201 "Validity::Array length must equal to_logical's argument: {}, {}.",
202 is_valid.len(),
203 length,
204 );
205 Mask::try_from(&is_valid.to_bool()?)?
206 }
207 })
208 }
209
210 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
212 let validity = match (self, rhs) {
213 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
215 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
217 (Validity::Array(a), Validity::AllValid)
219 | (Validity::Array(a), Validity::NonNullable)
220 | (Validity::NonNullable, Validity::Array(a))
221 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
222 (Validity::NonNullable, Validity::AllValid)
224 | (Validity::AllValid, Validity::NonNullable)
225 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
226 (Validity::Array(lhs), Validity::Array(rhs)) => {
228 let lhs = lhs.to_bool()?;
229 let rhs = rhs.to_bool()?;
230
231 let lhs = lhs.boolean_buffer();
232 let rhs = rhs.boolean_buffer();
233
234 Validity::from(lhs.bitand(rhs))
235 }
236 };
237
238 Ok(validity)
239 }
240
241 pub fn patch(
242 self,
243 len: usize,
244 indices_offset: usize,
245 indices: &dyn Array,
246 patches: &Validity,
247 ) -> VortexResult<Self> {
248 match (&self, patches) {
249 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
250 (Validity::NonNullable, _) => {
251 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
252 }
253 (_, Validity::NonNullable) => {
254 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
255 }
256 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
257 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
258 _ => {}
259 };
260
261 let own_nullability = if self == Validity::NonNullable {
262 Nullability::NonNullable
263 } else {
264 Nullability::Nullable
265 };
266
267 let source = match self {
268 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
269 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
270 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
271 Validity::Array(a) => a.to_bool()?,
272 };
273
274 let patch_values = match patches {
275 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
276 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
277 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
278 Validity::Array(a) => a.to_bool()?,
279 };
280
281 let patches = Patches::new(
282 len,
283 indices_offset,
284 indices.to_array(),
285 patch_values.into_array(),
286 );
287
288 Ok(Self::from_array(
289 source.patch(&patches)?.into_array(),
290 own_nullability,
291 ))
292 }
293
294 pub fn into_nullable(self) -> Validity {
296 match self {
297 Self::NonNullable => Self::AllValid,
298 _ => self,
299 }
300 }
301
302 pub fn into_non_nullable(self) -> Option<Validity> {
304 match self {
305 Self::NonNullable => Some(Self::NonNullable),
306 Self::AllValid => Some(Self::NonNullable),
307 Self::AllInvalid => None,
308 Self::Array(is_valid) => {
309 is_valid
310 .statistics()
311 .compute_min::<bool>()
312 .vortex_expect("validity array must support min")
313 .then(|| {
314 Self::NonNullable
316 })
317 }
318 }
319 }
320
321 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
323 match nullability {
324 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
325 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
326 }),
327 Nullability::Nullable => Ok(self.into_nullable()),
328 }
329 }
330
331 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
333 Ok(Validity::from_mask(
334 array.validity_mask()?,
335 array.dtype().nullability(),
336 ))
337 }
338
339 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
344 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
345 vortex_panic!("Expected a non-nullable boolean array")
346 }
347 match nullability {
348 Nullability::NonNullable => Self::NonNullable,
349 Nullability::Nullable => Self::Array(value),
350 }
351 }
352
353 pub fn maybe_len(&self) -> Option<usize> {
355 match self {
356 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
357 Self::Array(a) => Some(a.len()),
358 }
359 }
360
361 pub fn uncompressed_size(&self) -> usize {
362 if let Validity::Array(a) = self {
363 a.len().div_ceil(8)
364 } else {
365 0
366 }
367 }
368}
369
370impl PartialEq for Validity {
371 fn eq(&self, other: &Self) -> bool {
372 match (self, other) {
373 (Self::NonNullable, Self::NonNullable) => true,
374 (Self::AllValid, Self::AllValid) => true,
375 (Self::AllInvalid, Self::AllInvalid) => true,
376 (Self::Array(a), Self::Array(b)) => {
377 let a = a
378 .to_bool()
379 .vortex_expect("Failed to get Validity Array as BoolArray");
380 let b = b
381 .to_bool()
382 .vortex_expect("Failed to get Validity Array as BoolArray");
383 a.boolean_buffer() == b.boolean_buffer()
384 }
385 _ => false,
386 }
387 }
388}
389
390impl From<BooleanBuffer> for Validity {
391 fn from(value: BooleanBuffer) -> Self {
392 if value.count_set_bits() == value.len() {
393 Self::AllValid
394 } else if value.count_set_bits() == 0 {
395 Self::AllInvalid
396 } else {
397 Self::Array(BoolArray::from(value).into_array())
398 }
399 }
400}
401
402impl From<NullBuffer> for Validity {
403 fn from(value: NullBuffer) -> Self {
404 value.into_inner().into()
405 }
406}
407
408impl FromIterator<Mask> for Validity {
409 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
410 let validities: Vec<Mask> = iter.into_iter().collect();
411
412 if validities.iter().all(|v| v.all_true()) {
414 return Self::AllValid;
415 }
416 if validities.iter().all(|v| v.all_false()) {
418 return Self::AllInvalid;
419 }
420
421 let mut buffer = BooleanBufferBuilder::new(validities.iter().map(|v| v.len()).sum());
423 for validity in validities {
424 match validity {
425 Mask::AllTrue(count) => buffer.append_n(count, true),
426 Mask::AllFalse(count) => buffer.append_n(count, false),
427 Mask::Values(values) => {
428 buffer.append_buffer(values.boolean_buffer());
429 }
430 };
431 }
432 let bool_array = BoolArray::from(buffer.finish());
433 Self::Array(bool_array.into_array())
434 }
435}
436
437impl FromIterator<bool> for Validity {
438 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
439 Validity::from(BooleanBuffer::from_iter(iter))
440 }
441}
442
443impl From<Nullability> for Validity {
444 fn from(value: Nullability) -> Self {
445 match value {
446 Nullability::NonNullable => Validity::NonNullable,
447 Nullability::Nullable => Validity::AllValid,
448 }
449 }
450}
451
452impl Validity {
453 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
454 assert!(
455 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
456 "NonNullable validity must be AllValid",
457 );
458 match mask {
459 Mask::AllTrue(_) => match nullability {
460 Nullability::NonNullable => Validity::NonNullable,
461 Nullability::Nullable => Validity::AllValid,
462 },
463 Mask::AllFalse(_) => Validity::AllInvalid,
464 Mask::Values(values) => Validity::Array(values.into_array()),
465 }
466 }
467}
468
469impl IntoArray for Mask {
470 fn into_array(self) -> ArrayRef {
471 match self {
472 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
473 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
474 Self::Values(a) => a.into_array(),
475 }
476 }
477}
478
479impl IntoArray for &MaskValues {
480 fn into_array(self) -> ArrayRef {
481 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
482 }
483}
484
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::array::Array;
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// Builds a `Validity::Array` holding exactly the given flags, with no collapsing into a
    /// unit variant.
    fn bools<const N: usize>(flags: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(flags).into_array())
    }

    /// Patches positions 2 and 4 of a five-element validity with each combination of nullable
    /// receiver and nullable patch validity.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        bools([true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        bools([true, false]),
        bools([true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        bools([false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        bools([true, false]),
        bools([false, false, true, false, false])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        bools([false, true, true, true, true])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        bools([false, true, false, true, false])
    )]
    #[case(
        bools([false, true, false, true, false]),
        5,
        &[2, 4],
        bools([true, false]),
        bools([false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let patch_indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &patch_indices, &patches).unwrap();
        assert_eq!(patched, expected);
    }

    // NOTE(review): with a `NonNullable` receiver and a nullable patch validity, `patch` returns
    // its "Can't patch a non-nullable validity with nullable validity" error before any index is
    // looked at, so the panic observed here comes from `unwrap` on that error rather than from
    // an out-of-bounds index.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let base = Validity::NonNullable;
        base.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    /// A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let all_false = Mask::AllFalse(10);
        Validity::from_mask(all_false, Nullability::NonNullable);
    }

    /// A non-nullable validity cannot be built from a mask that contains a `false`.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mixed = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mixed, Nullability::NonNullable);
    }

    /// Taking from a validity without nulls yields the validity of the indices themselves.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }
}