1use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20#[derive(Clone, Debug)]
22pub enum Validity {
23 NonNullable,
25 AllValid,
27 AllInvalid,
29 Array(ArrayRef),
31}
32
33impl Validity {
34 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37 pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38 match self {
39 Self::NonNullable | Self::AllValid => Ok(0),
40 Self::AllInvalid => Ok(length),
41 Self::Array(a) => {
42 let validity_len = a.len();
43 if validity_len != length {
44 vortex_bail!(
45 "Validity array length {} doesn't match array length {}",
46 validity_len,
47 length
48 )
49 }
50 let true_count = sum(a)?
51 .as_primitive()
52 .as_::<usize>()?
53 .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54 Ok(length - true_count)
55 }
56 }
57 }
58
59 pub fn into_array(self) -> Option<ArrayRef> {
61 match self {
62 Self::Array(a) => Some(a),
63 _ => None,
64 }
65 }
66
67 pub fn as_array(&self) -> Option<&ArrayRef> {
69 match self {
70 Self::Array(a) => Some(a),
71 _ => None,
72 }
73 }
74
75 pub fn nullability(&self) -> Nullability {
76 match self {
77 Self::NonNullable => Nullability::NonNullable,
78 _ => Nullability::Nullable,
79 }
80 }
81
82 pub fn union_nullability(self, nullability: Nullability) -> Self {
84 match nullability {
85 Nullability::NonNullable => self,
86 Nullability::Nullable => self.into_nullable(),
87 }
88 }
89
90 pub fn all_valid(&self) -> VortexResult<bool> {
91 Ok(match self {
92 Validity::NonNullable | Validity::AllValid => true,
93 Validity::AllInvalid => false,
94 Validity::Array(array) => sum(array)
95 .map(|v| {
96 v.as_primitive()
97 .typed_value::<u64>()
98 .map(|count| count == array.len() as u64)
99 })?
100 .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
101 })
102 }
103
104 pub fn all_invalid(&self) -> VortexResult<bool> {
105 Ok(match self {
106 Validity::NonNullable | Validity::AllValid => false,
107 Validity::AllInvalid => true,
108 Validity::Array(array) => sum(array)
109 .map(|v| {
110 v.as_primitive()
111 .typed_value::<u64>()
112 .map(|count| count == 0u64)
113 })?
114 .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
115 })
116 }
117
118 #[inline]
120 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
121 Ok(match self {
122 Self::NonNullable | Self::AllValid => true,
123 Self::AllInvalid => false,
124 Self::Array(a) => {
125 let scalar = a.scalar_at(index)?;
126 scalar
127 .as_bool()
128 .value()
129 .vortex_expect("Validity must be non-nullable")
130 }
131 })
132 }
133
134 #[inline]
135 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
136 Ok(!self.is_valid(index)?)
137 }
138
139 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
140 match self {
141 Self::Array(a) => Ok(Self::Array(a.slice(start, stop)?)),
142 _ => Ok(self.clone()),
143 }
144 }
145
146 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
147 match self {
148 Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
149 AllOr::All => {
150 if indices.dtype().is_nullable() {
151 Ok(Self::AllValid)
152 } else {
153 Ok(Self::NonNullable)
154 }
155 }
156 AllOr::None => Ok(Self::AllInvalid),
157 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
158 },
159 Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
160 AllOr::All => Ok(Self::AllValid),
161 AllOr::None => Ok(Self::AllInvalid),
162 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163 },
164 Self::AllInvalid => Ok(Self::AllInvalid),
165 Self::Array(is_valid) => {
166 let maybe_is_valid = take(is_valid, indices)?;
167 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
169 Ok(Self::Array(is_valid))
170 }
171 }
172 }
173
174 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
178 match self {
181 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
182 Ok(v.clone())
183 }
184 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
185 }
186 }
187
188 pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
192 match mask.boolean_buffer() {
193 AllOr::All => Ok(Validity::AllInvalid),
194 AllOr::None => Ok(self.clone()),
195 AllOr::Some(make_invalid) => Ok(match self {
196 Validity::NonNullable | Validity::AllValid => {
197 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
198 }
199 Validity::AllInvalid => Validity::AllInvalid,
200 Validity::Array(is_valid) => {
201 let is_valid = is_valid.to_bool()?;
202 let keep_valid = make_invalid.not();
203 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
204 }
205 }),
206 }
207 }
208
209 pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
210 Ok(match self {
211 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
212 Self::AllInvalid => Mask::AllFalse(length),
213 Self::Array(is_valid) => {
214 assert_eq!(
215 is_valid.len(),
216 length,
217 "Validity::Array length must equal to_logical's argument: {}, {}.",
218 is_valid.len(),
219 length,
220 );
221 Mask::try_from(&is_valid.to_bool()?)?
222 }
223 })
224 }
225
226 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
228 let validity = match (self, rhs) {
229 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
231 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
233 (Validity::Array(a), Validity::AllValid)
235 | (Validity::Array(a), Validity::NonNullable)
236 | (Validity::NonNullable, Validity::Array(a))
237 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
238 (Validity::NonNullable, Validity::AllValid)
240 | (Validity::AllValid, Validity::NonNullable)
241 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
242 (Validity::Array(lhs), Validity::Array(rhs)) => {
244 let lhs = lhs.to_bool()?;
245 let rhs = rhs.to_bool()?;
246
247 let lhs = lhs.boolean_buffer();
248 let rhs = rhs.boolean_buffer();
249
250 Validity::from(lhs.bitand(rhs))
251 }
252 };
253
254 Ok(validity)
255 }
256
257 pub fn patch(
258 self,
259 len: usize,
260 indices_offset: usize,
261 indices: &dyn Array,
262 patches: &Validity,
263 ) -> VortexResult<Self> {
264 match (&self, patches) {
265 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
266 (Validity::NonNullable, _) => {
267 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
268 }
269 (_, Validity::NonNullable) => {
270 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
271 }
272 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
273 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
274 _ => {}
275 };
276
277 let own_nullability = if self == Validity::NonNullable {
278 Nullability::NonNullable
279 } else {
280 Nullability::Nullable
281 };
282
283 let source = match self {
284 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
285 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
286 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
287 Validity::Array(a) => a.to_bool()?,
288 };
289
290 let patch_values = match patches {
291 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
292 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
293 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
294 Validity::Array(a) => a.to_bool()?,
295 };
296
297 let patches = Patches::new(
298 len,
299 indices_offset,
300 indices.to_array(),
301 patch_values.into_array(),
302 );
303
304 Ok(Self::from_array(
305 source.patch(&patches)?.into_array(),
306 own_nullability,
307 ))
308 }
309
310 pub fn into_nullable(self) -> Validity {
312 match self {
313 Self::NonNullable => Self::AllValid,
314 _ => self,
315 }
316 }
317
318 pub fn into_non_nullable(self) -> Option<Validity> {
320 match self {
321 Self::NonNullable => Some(Self::NonNullable),
322 Self::AllValid => Some(Self::NonNullable),
323 Self::AllInvalid => None,
324 Self::Array(is_valid) => {
325 is_valid
326 .statistics()
327 .compute_min::<bool>()
328 .vortex_expect("validity array must support min")
329 .then(|| {
330 Self::NonNullable
332 })
333 }
334 }
335 }
336
337 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
339 match nullability {
340 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
341 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
342 }),
343 Nullability::Nullable => Ok(self.into_nullable()),
344 }
345 }
346
347 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
349 Ok(Validity::from_mask(
350 array.validity_mask()?,
351 array.dtype().nullability(),
352 ))
353 }
354
355 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
360 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
361 vortex_panic!("Expected a non-nullable boolean array")
362 }
363 match nullability {
364 Nullability::NonNullable => Self::NonNullable,
365 Nullability::Nullable => Self::Array(value),
366 }
367 }
368
369 pub fn maybe_len(&self) -> Option<usize> {
371 match self {
372 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
373 Self::Array(a) => Some(a.len()),
374 }
375 }
376
377 pub fn uncompressed_size(&self) -> usize {
378 if let Validity::Array(a) = self {
379 a.len().div_ceil(8)
380 } else {
381 0
382 }
383 }
384
385 pub fn is_array(&self) -> bool {
386 matches!(self, Validity::Array(_))
387 }
388}
389
390impl PartialEq for Validity {
391 fn eq(&self, other: &Self) -> bool {
392 match (self, other) {
393 (Self::NonNullable, Self::NonNullable) => true,
394 (Self::AllValid, Self::AllValid) => true,
395 (Self::AllInvalid, Self::AllInvalid) => true,
396 (Self::Array(a), Self::Array(b)) => {
397 let a = a
398 .to_bool()
399 .vortex_expect("Failed to get Validity Array as BoolArray");
400 let b = b
401 .to_bool()
402 .vortex_expect("Failed to get Validity Array as BoolArray");
403 a.boolean_buffer() == b.boolean_buffer()
404 }
405 _ => false,
406 }
407 }
408}
409
410impl From<BooleanBuffer> for Validity {
411 fn from(value: BooleanBuffer) -> Self {
412 if value.count_set_bits() == value.len() {
413 Self::AllValid
414 } else if value.count_set_bits() == 0 {
415 Self::AllInvalid
416 } else {
417 Self::Array(BoolArray::from(value).into_array())
418 }
419 }
420}
421
422impl From<NullBuffer> for Validity {
423 fn from(value: NullBuffer) -> Self {
424 value.into_inner().into()
425 }
426}
427
428impl FromIterator<Mask> for Validity {
429 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
430 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
431 }
432}
433
434impl FromIterator<bool> for Validity {
435 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
436 Validity::from(BooleanBuffer::from_iter(iter))
437 }
438}
439
440impl From<Nullability> for Validity {
441 fn from(value: Nullability) -> Self {
442 match value {
443 Nullability::NonNullable => Validity::NonNullable,
444 Nullability::Nullable => Validity::AllValid,
445 }
446 }
447}
448
449impl Validity {
450 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
451 assert!(
452 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
453 "NonNullable validity must be AllValid",
454 );
455 match mask {
456 Mask::AllTrue(_) => match nullability {
457 Nullability::NonNullable => Validity::NonNullable,
458 Nullability::Nullable => Validity::AllValid,
459 },
460 Mask::AllFalse(_) => Validity::AllInvalid,
461 Mask::Values(values) => Validity::Array(values.into_array()),
462 }
463 }
464}
465
466impl IntoArray for Mask {
467 fn into_array(self) -> ArrayRef {
468 match self {
469 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
470 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
471 Self::Values(a) => a.into_array(),
472 }
473 }
474}
475
476impl IntoArray for &MaskValues {
477 fn into_array(self) -> ArrayRef {
478 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
479 }
480}
481
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    // Each case: (base validity, array length, patch positions, patch validity, expected).
    // Positions 2 and 4 of a length-5 array are overwritten with the patch validity, and the
    // other positions keep the base validity.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Build non-nullable u64 indices for the patch positions; the offset is 0.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(
            validity.patch(len, 0, &indices, &patches).unwrap(),
            expected
        );
    }

    // NOTE(review): despite its name, this panics because `patch` bails when patching a
    // non-nullable validity with nullable patches; index 4 on a length-2 array is never
    // reached. Consider renaming it or adding a real out-of-bounds case.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    // A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // A non-nullable validity cannot be built from a mask that contains any `false`.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Taking from a constant validity: null indices make the output invalid, and
    // non-nullable indices on a non-nullable validity stay `NonNullable`.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}
559}