1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::ToCanonical;
25use crate::arrays::BoolArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::scalar_fn::ScalarFnArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::compute::sum;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::optimizer::ArrayOptimizer;
33use crate::patches::Patches;
34use crate::scalar::Scalar;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::Operator;
37
38#[derive(Clone)]
40pub enum Validity {
41 NonNullable,
43 AllValid,
45 AllInvalid,
47 Array(ArrayRef),
51}
52
53impl Debug for Validity {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 Self::NonNullable => write!(f, "NonNullable"),
57 Self::AllValid => write!(f, "AllValid"),
58 Self::AllInvalid => write!(f, "AllInvalid"),
59 Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
60 }
61 }
62}
63
64impl Validity {
65 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
67 match self {
68 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
69 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
70 }
71 }
72}
73
74impl Validity {
75 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
77
78 pub fn to_array(&self, len: usize) -> ArrayRef {
80 match self {
81 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
82 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
83 Self::Array(a) => a.clone(),
84 }
85 }
86
87 #[inline]
89 pub fn into_array(self) -> Option<ArrayRef> {
90 if let Self::Array(a) = self {
91 Some(a)
92 } else {
93 None
94 }
95 }
96
97 #[inline]
99 pub fn as_array(&self) -> Option<&ArrayRef> {
100 if let Self::Array(a) = self {
101 Some(a)
102 } else {
103 None
104 }
105 }
106
107 #[inline]
108 pub fn nullability(&self) -> Nullability {
109 if matches!(self, Self::NonNullable) {
110 Nullability::NonNullable
111 } else {
112 Nullability::Nullable
113 }
114 }
115
116 #[inline]
118 pub fn union_nullability(self, nullability: Nullability) -> Self {
119 match nullability {
120 Nullability::NonNullable => self,
121 Nullability::Nullable => self.into_nullable(),
122 }
123 }
124
125 #[inline]
126 pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
127 Ok(match self {
128 _ if len == 0 => true,
129 Validity::NonNullable | Validity::AllValid => true,
130 Validity::AllInvalid => false,
131 Validity::Array(array) => {
132 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
133 .vortex_expect("sum must be a usize")
134 == array.len()
135 }
136 })
137 }
138
139 #[inline]
140 pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
141 Ok(match self {
142 _ if len == 0 => true,
143 Validity::NonNullable | Validity::AllValid => false,
144 Validity::AllInvalid => true,
145 Validity::Array(array) => {
146 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
147 .vortex_expect("sum must be a usize")
148 == 0
149 }
150 })
151 }
152
153 #[inline]
155 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
156 Ok(match self {
157 Self::NonNullable | Self::AllValid => true,
158 Self::AllInvalid => false,
159 Self::Array(a) => a
160 .scalar_at(index)
161 .vortex_expect("Validity array must support scalar_at")
162 .as_bool()
163 .value()
164 .vortex_expect("Validity must be non-nullable"),
165 })
166 }
167
168 #[inline]
169 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
170 Ok(!self.is_valid(index)?)
171 }
172
173 #[inline]
174 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
175 match self {
176 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
177 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
178 }
179 }
180
181 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
182 match self {
183 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
184 AllOr::All => {
185 if indices.dtype().is_nullable() {
186 Ok(Self::AllValid)
187 } else {
188 Ok(Self::NonNullable)
189 }
190 }
191 AllOr::None => Ok(Self::AllInvalid),
192 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
193 },
194 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
195 AllOr::All => Ok(Self::AllValid),
196 AllOr::None => Ok(Self::AllInvalid),
197 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
198 },
199 Self::AllInvalid => Ok(Self::AllInvalid),
200 Self::Array(is_valid) => {
201 let maybe_is_valid = is_valid.take(indices.to_array())?;
202 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
204 Ok(Self::Array(is_valid))
205 }
206 }
207 }
208
209 pub fn not(&self) -> VortexResult<Self> {
211 match self {
212 Validity::NonNullable => Ok(Validity::NonNullable),
213 Validity::AllValid => Ok(Validity::AllInvalid),
214 Validity::AllInvalid => Ok(Validity::AllValid),
215 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
216 }
217 }
218
219 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
227 match self {
230 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
231 Ok(v.clone())
232 }
233 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
234 }
235 }
236
237 #[inline]
238 pub fn to_mask(&self, length: usize) -> Mask {
239 match self {
240 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
241 Self::AllInvalid => Mask::AllFalse(length),
242 Self::Array(is_valid) => {
243 assert_eq!(
244 is_valid.len(),
245 length,
246 "Validity::Array length must equal to_logical's argument: {}, {}.",
247 is_valid.len(),
248 length,
249 );
250 is_valid.to_bool().to_mask()
251 }
252 }
253 }
254
255 #[inline]
257 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
258 Ok(match (self, rhs) {
259 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
261 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
263 (Validity::Array(a), Validity::AllValid)
265 | (Validity::Array(a), Validity::NonNullable)
266 | (Validity::NonNullable, Validity::Array(a))
267 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
268 (Validity::NonNullable, Validity::AllValid)
270 | (Validity::AllValid, Validity::NonNullable)
271 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
272 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
274 Binary
275 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
276 .optimize()?,
277 ),
278 })
279 }
280
281 pub fn patch(
282 self,
283 len: usize,
284 indices_offset: usize,
285 indices: &ArrayRef,
286 patches: &Validity,
287 ctx: &mut ExecutionCtx,
288 ) -> VortexResult<Self> {
289 match (&self, patches) {
290 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
291 (Validity::NonNullable, _) => {
292 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
293 }
294 (_, Validity::NonNullable) => {
295 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
296 }
297 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
298 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
299 _ => {}
300 };
301
302 let own_nullability = if self == Validity::NonNullable {
303 Nullability::NonNullable
304 } else {
305 Nullability::Nullable
306 };
307
308 let source = match self {
309 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
310 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
311 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
312 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
313 };
314
315 let patch_values = match patches {
316 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
317 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
318 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
319 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
320 };
321
322 let patches = Patches::new(
323 len,
324 indices_offset,
325 indices.to_array(),
326 patch_values.into_array(),
327 None,
329 )?;
330
331 Ok(Self::from_array(
332 source.patch(&patches, ctx)?.into_array(),
333 own_nullability,
334 ))
335 }
336
337 #[inline]
339 pub fn into_nullable(self) -> Validity {
340 match self {
341 Self::NonNullable => Self::AllValid,
342 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
343 }
344 }
345
346 #[inline]
348 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
349 match self {
350 _ if len == 0 => Some(Validity::NonNullable),
351 Self::NonNullable => Some(Self::NonNullable),
352 Self::AllValid => Some(Self::NonNullable),
353 Self::AllInvalid => None,
354 Self::Array(is_valid) => {
355 is_valid
356 .statistics()
357 .compute_min::<bool>()
358 .vortex_expect("validity array must support min")
359 .then(|| {
360 Self::NonNullable
362 })
363 }
364 }
365 }
366
367 #[inline]
369 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
370 match nullability {
371 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
372 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
373 }),
374 Nullability::Nullable => Ok(self.into_nullable()),
375 }
376 }
377
378 #[inline]
380 pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
381 Ok(Validity::from_mask(
382 array.validity_mask()?,
383 array.dtype().nullability(),
384 ))
385 }
386
387 #[inline]
392 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
393 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
394 vortex_panic!("Expected a non-nullable boolean array")
395 }
396 match nullability {
397 Nullability::NonNullable => Self::NonNullable,
398 Nullability::Nullable => Self::Array(value),
399 }
400 }
401
402 #[inline]
404 pub fn maybe_len(&self) -> Option<usize> {
405 match self {
406 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
407 Self::Array(a) => Some(a.len()),
408 }
409 }
410
411 #[inline]
412 pub fn uncompressed_size(&self) -> usize {
413 if let Validity::Array(a) = self {
414 a.len().div_ceil(8)
415 } else {
416 0
417 }
418 }
419}
420
421impl PartialEq for Validity {
422 #[inline]
423 fn eq(&self, other: &Self) -> bool {
424 match (self, other) {
425 (Self::NonNullable, Self::NonNullable) => true,
426 (Self::AllValid, Self::AllValid) => true,
427 (Self::AllInvalid, Self::AllInvalid) => true,
428 (Self::Array(a), Self::Array(b)) => {
429 let a = a.to_bool();
430 let b = b.to_bool();
431 a.to_bit_buffer() == b.to_bit_buffer()
432 }
433 _ => false,
434 }
435 }
436}
437
438impl From<BitBuffer> for Validity {
439 #[inline]
440 fn from(value: BitBuffer) -> Self {
441 let true_count = value.true_count();
442 if true_count == value.len() {
443 Self::AllValid
444 } else if true_count == 0 {
445 Self::AllInvalid
446 } else {
447 Self::Array(BoolArray::from(value).into_array())
448 }
449 }
450}
451
452impl FromIterator<Mask> for Validity {
453 #[inline]
454 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
455 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
456 }
457}
458
459impl FromIterator<bool> for Validity {
460 #[inline]
461 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
462 Validity::from(BitBuffer::from_iter(iter))
463 }
464}
465
466impl From<Nullability> for Validity {
467 #[inline]
468 fn from(value: Nullability) -> Self {
469 Validity::from(&value)
470 }
471}
472
473impl From<&Nullability> for Validity {
474 #[inline]
475 fn from(value: &Nullability) -> Self {
476 match *value {
477 Nullability::NonNullable => Validity::NonNullable,
478 Nullability::Nullable => Validity::AllValid,
479 }
480 }
481}
482
483impl Validity {
484 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
485 if buffer.true_count() == buffer.len() {
486 nullability.into()
487 } else if buffer.true_count() == 0 {
488 Validity::AllInvalid
489 } else {
490 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
491 }
492 }
493
494 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
495 assert!(
496 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
497 "NonNullable validity must be AllValid",
498 );
499 match mask {
500 Mask::AllTrue(_) => match nullability {
501 Nullability::NonNullable => Validity::NonNullable,
502 Nullability::Nullable => Validity::AllValid,
503 },
504 Mask::AllFalse(_) => Validity::AllInvalid,
505 Mask::Values(values) => Validity::Array(values.into_array()),
506 }
507 }
508}
509
510impl IntoArray for Mask {
511 #[inline]
512 fn into_array(self) -> ArrayRef {
513 match self {
514 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
515 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
516 Self::Values(a) => a.into_array(),
517 }
518 }
519}
520
521impl IntoArray for &MaskValues {
522 #[inline]
523 fn into_array(self) -> ArrayRef {
524 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
525 }
526}
527
528#[cfg(test)]
529mod tests {
530 use rstest::rstest;
531 use vortex_buffer::Buffer;
532 use vortex_buffer::buffer;
533 use vortex_mask::Mask;
534
535 use crate::ArrayRef;
536 use crate::IntoArray;
537 use crate::LEGACY_SESSION;
538 use crate::VortexSessionExecute;
539 use crate::arrays::PrimitiveArray;
540 use crate::dtype::Nullability;
541 use crate::validity::BoolArray;
542 use crate::validity::Validity;
543
544 #[rstest]
545 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
546 #[case(
547 Validity::AllValid,
548 5,
549 &[2, 4],
550 Validity::AllInvalid,
551 Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
552 )]
553 #[case(
554 Validity::AllValid,
555 5,
556 &[2, 4],
557 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
558 Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
559 )]
560 #[case(
561 Validity::AllInvalid,
562 5,
563 &[2, 4],
564 Validity::AllValid,
565 Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
566 )]
567 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
568 #[case(
569 Validity::AllInvalid,
570 5,
571 &[2, 4],
572 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
573 Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
574 )]
575 #[case(
576 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
577 5,
578 &[2, 4],
579 Validity::AllValid,
580 Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
581 )]
582 #[case(
583 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
584 5,
585 &[2, 4],
586 Validity::AllInvalid,
587 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
588 )]
589 #[case(
590 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
591 5,
592 &[2, 4],
593 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
594 Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
595 )]
596
597 fn patch_validity(
598 #[case] validity: Validity,
599 #[case] len: usize,
600 #[case] positions: &[u64],
601 #[case] patches: Validity,
602 #[case] expected: Validity,
603 ) {
604 let indices =
605 PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
606 assert_eq!(
607 validity
608 .patch(
609 len,
610 0,
611 &indices,
612 &patches,
613 &mut LEGACY_SESSION.create_execution_ctx()
614 )
615 .unwrap(),
616 expected
617 );
618 }
619
620 #[test]
621 #[should_panic]
622 fn out_of_bounds_patch() {
623 Validity::NonNullable
624 .patch(
625 2,
626 0,
627 &buffer![4].into_array(),
628 &Validity::AllInvalid,
629 &mut LEGACY_SESSION.create_execution_ctx(),
630 )
631 .unwrap();
632 }
633
634 #[test]
635 #[should_panic]
636 fn into_validity_nullable() {
637 Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
638 }
639
640 #[test]
641 #[should_panic]
642 fn into_validity_nullable_array() {
643 Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
644 }
645
646 #[rstest]
647 #[case(
648 Validity::AllValid,
649 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
650 Validity::from_iter(vec![true, false])
651 )]
652 #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
653 #[case(
654 Validity::AllValid,
655 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
656 Validity::AllInvalid
657 )]
658 #[case(
659 Validity::NonNullable,
660 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
661 Validity::from_iter(vec![true, false])
662 )]
663 #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
664 #[case(
665 Validity::NonNullable,
666 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
667 Validity::AllInvalid
668 )]
669 fn validity_take(
670 #[case] validity: Validity,
671 #[case] indices: ArrayRef,
672 #[case] expected: Validity,
673 ) {
674 assert_eq!(validity.take(&indices).unwrap(), expected);
675 }
676}