1use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
/// Per-element validity (null-ness) information for an array.
///
/// The first three variants describe validity without any backing storage; only
/// [`Validity::Array`] carries an explicit per-element bitmap.
#[derive(Clone, Debug)]
pub enum Validity {
    /// The array is not nullable: every element is valid and nulls cannot be represented.
    NonNullable,
    /// The array is nullable, but every element is currently valid.
    AllValid,
    /// The array is nullable, and every element is null.
    AllInvalid,
    /// The array is nullable; validity is given by this boolean array, where `true` marks a
    /// valid (non-null) element. It is expected to have dtype [`Validity::DTYPE`]
    /// (non-nullable bool).
    Array(ArrayRef),
}
32
33impl Validity {
34 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37 pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38 match self {
39 Self::NonNullable | Self::AllValid => Ok(0),
40 Self::AllInvalid => Ok(length),
41 Self::Array(a) => {
42 let validity_len = a.len();
43 if validity_len != length {
44 vortex_bail!(
45 "Validity array length {} doesn't match array length {}",
46 validity_len,
47 length
48 )
49 }
50 let true_count = sum(a)?
51 .as_primitive()
52 .as_::<usize>()
53 .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54 Ok(length - true_count)
55 }
56 }
57 }
58
59 pub fn into_array(self) -> Option<ArrayRef> {
61 if let Self::Array(a) = self {
62 Some(a)
63 } else {
64 None
65 }
66 }
67
68 pub fn as_array(&self) -> Option<&ArrayRef> {
70 if let Self::Array(a) = self {
71 Some(a)
72 } else {
73 None
74 }
75 }
76
77 pub fn nullability(&self) -> Nullability {
78 if matches!(self, Self::NonNullable) {
79 Nullability::NonNullable
80 } else {
81 Nullability::Nullable
82 }
83 }
84
85 pub fn union_nullability(self, nullability: Nullability) -> Self {
87 match nullability {
88 Nullability::NonNullable => self,
89 Nullability::Nullable => self.into_nullable(),
90 }
91 }
92
93 pub fn all_valid(&self) -> VortexResult<bool> {
94 Ok(match self {
95 Validity::NonNullable | Validity::AllValid => true,
96 Validity::AllInvalid => false,
97 Validity::Array(array) => sum(array)
98 .map(|v| {
99 v.as_primitive()
100 .typed_value::<u64>()
101 .map(|count| count == array.len() as u64)
102 })?
103 .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
104 })
105 }
106
107 pub fn all_invalid(&self) -> VortexResult<bool> {
108 Ok(match self {
109 Validity::NonNullable | Validity::AllValid => false,
110 Validity::AllInvalid => true,
111 Validity::Array(array) => sum(array)
112 .map(|v| {
113 v.as_primitive()
114 .typed_value::<u64>()
115 .map(|count| count == 0u64)
116 })?
117 .ok_or_else(|| vortex_err!("Failed to compute sum for validity array"))?,
118 })
119 }
120
121 #[inline]
123 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
124 Ok(match self {
125 Self::NonNullable | Self::AllValid => true,
126 Self::AllInvalid => false,
127 Self::Array(a) => {
128 let scalar = a.scalar_at(index);
129 scalar
130 .as_bool()
131 .value()
132 .vortex_expect("Validity must be non-nullable")
133 }
134 })
135 }
136
137 #[inline]
138 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
139 Ok(!self.is_valid(index)?)
140 }
141
142 pub fn slice(&self, start: usize, stop: usize) -> Self {
143 match self {
144 Self::Array(a) => Self::Array(a.slice(start, stop)),
145 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
146 }
147 }
148
149 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
150 match self {
151 Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
152 AllOr::All => {
153 if indices.dtype().is_nullable() {
154 Ok(Self::AllValid)
155 } else {
156 Ok(Self::NonNullable)
157 }
158 }
159 AllOr::None => Ok(Self::AllInvalid),
160 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
161 },
162 Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
163 AllOr::All => Ok(Self::AllValid),
164 AllOr::None => Ok(Self::AllInvalid),
165 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
166 },
167 Self::AllInvalid => Ok(Self::AllInvalid),
168 Self::Array(is_valid) => {
169 let maybe_is_valid = take(is_valid, indices)?;
170 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
172 Ok(Self::Array(is_valid))
173 }
174 }
175 }
176
177 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
181 match self {
184 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
185 Ok(v.clone())
186 }
187 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
188 }
189 }
190
191 pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
195 match mask.boolean_buffer() {
196 AllOr::All => Ok(Validity::AllInvalid),
197 AllOr::None => Ok(self.clone()),
198 AllOr::Some(make_invalid) => Ok(match self {
199 Validity::NonNullable | Validity::AllValid => {
200 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
201 }
202 Validity::AllInvalid => Validity::AllInvalid,
203 Validity::Array(is_valid) => {
204 let is_valid = is_valid.to_bool()?;
205 let keep_valid = make_invalid.not();
206 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
207 }
208 }),
209 }
210 }
211
212 pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
213 Ok(match self {
214 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
215 Self::AllInvalid => Mask::AllFalse(length),
216 Self::Array(is_valid) => {
217 assert_eq!(
218 is_valid.len(),
219 length,
220 "Validity::Array length must equal to_logical's argument: {}, {}.",
221 is_valid.len(),
222 length,
223 );
224 Mask::try_from(&is_valid.to_bool()?)?
225 }
226 })
227 }
228
229 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
231 let validity = match (self, rhs) {
232 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
234 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
236 (Validity::Array(a), Validity::AllValid)
238 | (Validity::Array(a), Validity::NonNullable)
239 | (Validity::NonNullable, Validity::Array(a))
240 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
241 (Validity::NonNullable, Validity::AllValid)
243 | (Validity::AllValid, Validity::NonNullable)
244 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
245 (Validity::Array(lhs), Validity::Array(rhs)) => {
247 let lhs = lhs.to_bool()?;
248 let rhs = rhs.to_bool()?;
249
250 let lhs = lhs.boolean_buffer();
251 let rhs = rhs.boolean_buffer();
252
253 Validity::from(lhs.bitand(rhs))
254 }
255 };
256
257 Ok(validity)
258 }
259
260 pub fn patch(
261 self,
262 len: usize,
263 indices_offset: usize,
264 indices: &dyn Array,
265 patches: &Validity,
266 ) -> VortexResult<Self> {
267 match (&self, patches) {
268 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
269 (Validity::NonNullable, _) => {
270 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
271 }
272 (_, Validity::NonNullable) => {
273 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
274 }
275 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
276 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
277 _ => {}
278 };
279
280 let own_nullability = if self == Validity::NonNullable {
281 Nullability::NonNullable
282 } else {
283 Nullability::Nullable
284 };
285
286 let source = match self {
287 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
288 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
289 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
290 Validity::Array(a) => a.to_bool()?,
291 };
292
293 let patch_values = match patches {
294 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
295 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
296 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
297 Validity::Array(a) => a.to_bool()?,
298 };
299
300 let patches = Patches::new(
301 len,
302 indices_offset,
303 indices.to_array(),
304 patch_values.into_array(),
305 );
306
307 Ok(Self::from_array(
308 source.patch(&patches)?.into_array(),
309 own_nullability,
310 ))
311 }
312
313 pub fn into_nullable(self) -> Validity {
315 match self {
316 Self::NonNullable => Self::AllValid,
317 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
318 }
319 }
320
321 pub fn into_non_nullable(self) -> Option<Validity> {
323 match self {
324 Self::NonNullable => Some(Self::NonNullable),
325 Self::AllValid => Some(Self::NonNullable),
326 Self::AllInvalid => None,
327 Self::Array(is_valid) => {
328 is_valid
329 .statistics()
330 .compute_min::<bool>()
331 .vortex_expect("validity array must support min")
332 .then(|| {
333 Self::NonNullable
335 })
336 }
337 }
338 }
339
340 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
342 match nullability {
343 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
344 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
345 }),
346 Nullability::Nullable => Ok(self.into_nullable()),
347 }
348 }
349
350 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
352 Ok(Validity::from_mask(
353 array.validity_mask()?,
354 array.dtype().nullability(),
355 ))
356 }
357
358 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
363 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
364 vortex_panic!("Expected a non-nullable boolean array")
365 }
366 match nullability {
367 Nullability::NonNullable => Self::NonNullable,
368 Nullability::Nullable => Self::Array(value),
369 }
370 }
371
372 pub fn maybe_len(&self) -> Option<usize> {
374 match self {
375 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
376 Self::Array(a) => Some(a.len()),
377 }
378 }
379
380 pub fn uncompressed_size(&self) -> usize {
381 if let Validity::Array(a) = self {
382 a.len().div_ceil(8)
383 } else {
384 0
385 }
386 }
387
388 pub fn is_array(&self) -> bool {
389 matches!(self, Validity::Array(_))
390 }
391}
392
393impl PartialEq for Validity {
394 fn eq(&self, other: &Self) -> bool {
395 match (self, other) {
396 (Self::NonNullable, Self::NonNullable) => true,
397 (Self::AllValid, Self::AllValid) => true,
398 (Self::AllInvalid, Self::AllInvalid) => true,
399 (Self::Array(a), Self::Array(b)) => {
400 let a = a
401 .to_bool()
402 .vortex_expect("Failed to get Validity Array as BoolArray");
403 let b = b
404 .to_bool()
405 .vortex_expect("Failed to get Validity Array as BoolArray");
406 a.boolean_buffer() == b.boolean_buffer()
407 }
408 _ => false,
409 }
410 }
411}
412
413impl From<BooleanBuffer> for Validity {
414 fn from(value: BooleanBuffer) -> Self {
415 if value.count_set_bits() == value.len() {
416 Self::AllValid
417 } else if value.count_set_bits() == 0 {
418 Self::AllInvalid
419 } else {
420 Self::Array(BoolArray::from(value).into_array())
421 }
422 }
423}
424
425impl From<NullBuffer> for Validity {
426 fn from(value: NullBuffer) -> Self {
427 value.into_inner().into()
428 }
429}
430
431impl FromIterator<Mask> for Validity {
432 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
433 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
434 }
435}
436
437impl FromIterator<bool> for Validity {
438 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
439 Validity::from(BooleanBuffer::from_iter(iter))
440 }
441}
442
443impl From<Nullability> for Validity {
444 fn from(value: Nullability) -> Self {
445 match value {
446 Nullability::NonNullable => Validity::NonNullable,
447 Nullability::Nullable => Validity::AllValid,
448 }
449 }
450}
451
452impl Validity {
453 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
454 assert!(
455 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
456 "NonNullable validity must be AllValid",
457 );
458 match mask {
459 Mask::AllTrue(_) => match nullability {
460 Nullability::NonNullable => Validity::NonNullable,
461 Nullability::Nullable => Validity::AllValid,
462 },
463 Mask::AllFalse(_) => Validity::AllInvalid,
464 Mask::Values(values) => Validity::Array(values.into_array()),
465 }
466 }
467}
468
469impl IntoArray for Mask {
470 fn into_array(self) -> ArrayRef {
471 match self {
472 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
473 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
474 Self::Values(a) => a.into_array(),
475 }
476 }
477}
478
479impl IntoArray for &MaskValues {
480 fn into_array(self) -> ArrayRef {
481 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
482 }
483}
484
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    // Each case: (source validity, array length, patch positions, patch validity, expected).
    // Cases cover every combination of AllValid / AllInvalid / Array for source and patches,
    // always patching positions 2 and 4 of a 5-element array.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Patch positions are passed as a non-nullable u64 index array with offset 0.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(
            validity.patch(len, 0, &indices, &patches).unwrap(),
            expected
        );
    }

    // Patching a non-nullable validity with nullable patches is rejected with an error; the
    // `unwrap` turns that error into the expected panic.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    // A non-nullable validity cannot be built from a mask containing invalid entries.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // Same as above, but with a mixed (explicit values) mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case: (source validity, indices, expected). For NonNullable / AllValid sources the
    // result's validity comes from the indices' own validity; NonNullable survives only when
    // the indices are themselves non-nullable.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}