1use std::fmt::Debug;
7use std::ops::Range;
8
9use itertools::Itertools as _;
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexExpect as _;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_mask::Mask;
16use vortex_mask::MaskValues;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::LEGACY_SESSION;
23use crate::VortexSessionExecute;
24use crate::arrays::BoolArray;
25use crate::arrays::ChunkedArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::scalar_fn::ScalarFnFactoryExt;
28use crate::builtins::ArrayBuiltins;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::optimizer::ArrayOptimizer;
32use crate::patches::Patches;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::Operator;
36
/// Describes which positions of an array hold valid (non-null) values.
#[derive(Clone)]
pub enum Validity {
    /// The owning array cannot contain nulls; reported as `Nullability::NonNullable`.
    NonNullable,
    /// The owning array is nullable, but every position is valid.
    AllValid,
    /// The owning array is nullable and every position is invalid (null).
    AllInvalid,
    /// Per-position validity held in a boolean array, where `true` marks a valid position.
    Array(ArrayRef),
}
51
52impl Debug for Validity {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::NonNullable => write!(f, "NonNullable"),
56 Self::AllValid => write!(f, "AllValid"),
57 Self::AllInvalid => write!(f, "AllInvalid"),
58 Self::Array(arr) => write!(f, "SomeValid({})", arr.display_values()),
59 }
60 }
61}
62
63impl Validity {
64 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
66 match self {
67 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
68 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
69 }
70 }
71}
72
73impl Validity {
    /// Non-nullable boolean dtype.
    ///
    /// NOTE(review): presumably the dtype every validity array is expected to carry — confirm.
    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
76
77 pub fn to_array(&self, len: usize) -> ArrayRef {
79 match self {
80 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
81 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
82 Self::Array(a) => a.clone(),
83 }
84 }
85
86 #[inline]
88 pub fn into_array(self) -> Option<ArrayRef> {
89 if let Self::Array(a) = self {
90 Some(a)
91 } else {
92 None
93 }
94 }
95
96 #[inline]
98 pub fn as_array(&self) -> Option<&ArrayRef> {
99 if let Self::Array(a) = self {
100 Some(a)
101 } else {
102 None
103 }
104 }
105
106 #[inline]
107 pub fn nullability(&self) -> Nullability {
108 if matches!(self, Self::NonNullable) {
109 Nullability::NonNullable
110 } else {
111 Nullability::Nullable
112 }
113 }
114
115 #[inline]
118 pub fn no_nulls(&self) -> bool {
119 matches!(self, Self::NonNullable | Self::AllValid)
120 }
121
122 #[inline]
124 pub fn union_nullability(self, nullability: Nullability) -> Self {
125 match nullability {
126 Nullability::NonNullable => self,
127 Nullability::Nullable => self.into_nullable(),
128 }
129 }
130
131 #[inline]
133 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
134 Ok(match self {
135 Self::NonNullable | Self::AllValid => true,
136 Self::AllInvalid => false,
137 Self::Array(a) => a
138 .execute_scalar(index, &mut LEGACY_SESSION.create_execution_ctx())
139 .vortex_expect("Validity array must support execute_scalar")
140 .as_bool()
141 .value()
142 .vortex_expect("Validity must be non-nullable"),
143 })
144 }
145
146 #[inline]
147 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
148 Ok(!self.is_valid(index)?)
149 }
150
151 #[inline]
152 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
153 match self {
154 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
155 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
156 }
157 }
158
159 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
160 match self {
161 Self::NonNullable => indices.validity(),
162 Self::AllValid => Ok(match indices.validity()? {
163 Self::NonNullable => Self::AllValid,
164 v => v,
165 }),
166 Self::AllInvalid => Ok(Self::AllInvalid),
167 Self::Array(is_valid) => {
168 let maybe_is_valid = is_valid.take(indices.clone())?;
169 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
171 Ok(Self::Array(is_valid))
172 }
173 }
174 }
175
176 pub fn not(&self) -> VortexResult<Self> {
178 match self {
179 Validity::NonNullable => Ok(Validity::NonNullable),
180 Validity::AllValid => Ok(Validity::AllInvalid),
181 Validity::AllInvalid => Ok(Validity::AllValid),
182 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
183 }
184 }
185
186 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
194 match self {
197 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
198 Ok(v.clone())
199 }
200 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
201 }
202 }
203
204 #[deprecated(note = "Use execute_mask")]
208 pub fn to_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
209 match self {
210 Self::NonNullable | Self::AllValid => Ok(Mask::new_true(length)),
211 Self::AllInvalid => Ok(Mask::new_false(length)),
212 Self::Array(arr) => arr.clone().execute::<Mask>(ctx),
213 }
214 }
215
216 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
217 match self {
218 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
219 Self::AllInvalid => Ok(Mask::AllFalse(length)),
220 Self::Array(arr) => {
221 assert_eq!(
222 arr.len(),
223 length,
224 "Validity::Array length must equal to_logical's argument: {}, {}.",
225 arr.len(),
226 length,
227 );
228 arr.clone().execute::<Mask>(ctx)
231 }
232 }
233 }
234
235 pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
237 match (self, other) {
238 (Validity::NonNullable, Validity::NonNullable) => Ok(true),
239 (Validity::AllValid, Validity::AllValid) => Ok(true),
240 (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
241 (Validity::Array(a), Validity::Array(b)) => {
242 let a = a.clone().execute::<Mask>(ctx)?;
243 let b = b.clone().execute::<Mask>(ctx)?;
244 Ok(a == b)
245 }
246 _ => Ok(false),
247 }
248 }
249
250 #[inline]
252 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
253 Ok(match (self, rhs) {
254 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
256 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
258 (Validity::Array(a), Validity::AllValid)
260 | (Validity::Array(a), Validity::NonNullable)
261 | (Validity::NonNullable, Validity::Array(a))
262 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
263 (Validity::NonNullable, Validity::AllValid)
265 | (Validity::AllValid, Validity::NonNullable)
266 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
267 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
269 Binary
270 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
271 .optimize()?,
272 ),
273 })
274 }
275
276 pub fn patch(
277 self,
278 len: usize,
279 indices_offset: usize,
280 indices: &ArrayRef,
281 patches: &Validity,
282 ctx: &mut ExecutionCtx,
283 ) -> VortexResult<Self> {
284 match (&self, patches) {
285 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
286 (Validity::NonNullable, _) => {
287 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
288 }
289 (_, Validity::NonNullable) => {
290 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
291 }
292 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
293 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
294 _ => {}
295 };
296
297 if matches!(self, Validity::NonNullable) {
298 return Ok(Self::NonNullable);
299 }
300
301 let source = match self {
303 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
304 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
305 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
306 Validity::Array(a) => a.execute::<BoolArray>(ctx)?,
307 };
308
309 let patch_values = match patches {
310 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
311 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
312 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
313 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
314 };
315
316 let patches = Patches::new(
317 len,
318 indices_offset,
319 indices.clone(),
320 patch_values.into_array(),
321 None,
323 )?;
324
325 Ok(Self::Array(source.patch(&patches, ctx)?.into_array()))
326 }
327
328 #[inline]
330 pub fn into_nullable(self) -> Validity {
331 match self {
332 Self::NonNullable => Self::AllValid,
333 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
334 }
335 }
336
337 #[inline]
343 pub fn into_non_nullable(self, len: usize, ctx: &mut ExecutionCtx) -> Option<Validity> {
344 match self {
345 _ if len == 0 => Some(Validity::NonNullable),
346 Self::NonNullable => Some(Self::NonNullable),
347 Self::AllValid => Some(Self::NonNullable),
348 Self::AllInvalid => None,
349 Self::Array(is_valid) => {
350 is_valid
351 .statistics()
352 .compute_min::<bool>(ctx)
353 .vortex_expect("validity array must support min")
354 .then(|| {
355 Self::NonNullable
357 })
358 }
359 }
360 }
361
362 #[inline]
373 pub fn trivial_into_non_nullable(self, len: usize) -> VortexResult<Option<Validity>> {
374 match self {
375 _ if len == 0 => Ok(Some(Validity::NonNullable)),
376 Self::NonNullable => Ok(Some(Self::NonNullable)),
377 Self::AllValid => Ok(Some(Self::NonNullable)),
378 Self::AllInvalid => {
379 Err(vortex_err!(InvalidArgument: "Cannot cast AllInvalid to NonNullable"))
380 }
381 Self::Array(_) => Ok(None),
382 }
383 }
384
385 #[inline]
401 pub fn cast_nullability(
402 self,
403 nullability: Nullability,
404 len: usize,
405 ctx: &mut ExecutionCtx,
406 ) -> VortexResult<Validity> {
407 match nullability {
408 Nullability::NonNullable => self.into_non_nullable(len, ctx).ok_or_else(|| {
409 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
410 }),
411 Nullability::Nullable => Ok(self.into_nullable()),
412 }
413 }
414
415 #[inline]
440 pub fn trivial_cast_nullability(
441 self,
442 nullability: Nullability,
443 len: usize,
444 ) -> VortexResult<Option<Validity>> {
445 match nullability {
446 Nullability::NonNullable => self.trivial_into_non_nullable(len),
447 Nullability::Nullable => Ok(Some(self.into_nullable())),
448 }
449 }
450
451 #[inline]
453 pub fn maybe_len(&self) -> Option<usize> {
454 match self {
455 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
456 Self::Array(a) => Some(a.len()),
457 }
458 }
459}
460
461impl From<BitBuffer> for Validity {
462 #[inline]
463 fn from(value: BitBuffer) -> Self {
464 let true_count = value.true_count();
465 if true_count == value.len() {
466 Self::AllValid
467 } else if true_count == 0 {
468 Self::AllInvalid
469 } else {
470 Self::Array(BoolArray::from(value).into_array())
471 }
472 }
473}
474
475impl FromIterator<Mask> for Validity {
476 #[inline]
477 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
478 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
479 }
480}
481
482impl FromIterator<bool> for Validity {
483 #[inline]
484 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
485 Validity::from(BitBuffer::from_iter(iter))
486 }
487}
488
489impl From<Nullability> for Validity {
490 #[inline]
491 fn from(value: Nullability) -> Self {
492 Validity::from(&value)
493 }
494}
495
496impl From<&Nullability> for Validity {
497 #[inline]
498 fn from(value: &Nullability) -> Self {
499 match *value {
500 Nullability::NonNullable => Validity::NonNullable,
501 Nullability::Nullable => Validity::AllValid,
502 }
503 }
504}
505
506impl Validity {
507 pub fn concat(validities: Vec<(Validity, usize)>) -> Option<Self> {
511 let mut validity_kinds = validities
512 .iter()
513 .map(|(v, _)| std::mem::discriminant(v))
514 .unique();
515 let validity_kind = validity_kinds.next()?;
516 if validity_kinds.next().is_none() {
517 if validity_kind == std::mem::discriminant(&Validity::AllValid) {
520 return Some(Validity::AllValid);
521 }
522 if validity_kind == std::mem::discriminant(&Validity::AllInvalid) {
523 return Some(Validity::AllInvalid);
524 }
525 if validity_kind == std::mem::discriminant(&Validity::NonNullable) {
526 return Some(Validity::NonNullable);
527 }
528 }
529
530 Some(Validity::Array(
531 unsafe {
532 ChunkedArray::new_unchecked(
533 validities
534 .into_iter()
535 .map(|(v, len)| v.to_array(len))
536 .collect(),
537 DType::Bool(Nullability::NonNullable),
538 )
539 }
540 .into_array(),
541 ))
542 }
543}
544
545impl Validity {
546 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
547 if buffer.true_count() == buffer.len() {
548 nullability.into()
549 } else if buffer.true_count() == 0 {
550 Validity::AllInvalid
551 } else {
552 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
553 }
554 }
555
556 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
557 assert!(
558 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
559 "NonNullable validity must be AllValid",
560 );
561 match mask {
562 Mask::AllTrue(_) => match nullability {
563 Nullability::NonNullable => Validity::NonNullable,
564 Nullability::Nullable => Validity::AllValid,
565 },
566 Mask::AllFalse(_) => Validity::AllInvalid,
567 Mask::Values(values) => Validity::Array(values.into_array()),
568 }
569 }
570}
571
572impl IntoArray for Mask {
573 #[inline]
574 fn into_array(self) -> ArrayRef {
575 match self {
576 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
577 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
578 Self::Values(a) => a.into_array(),
579 }
580 }
581}
582
583impl IntoArray for &MaskValues {
584 #[inline]
585 fn into_array(self) -> ArrayRef {
586 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
587 }
588}
589
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    /// Patching positions 2 and 4 of a five-element validity, across every pairing of
    /// nullable receiver and patch variants.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        let patched = validity
            .patch(len, 0, &indices, &patches, &mut ctx)
            .unwrap();

        assert!(patched.mask_eq(&expected, &mut ctx).unwrap());
    }

    /// Patching must fail (the `unwrap` panics) for this input.
    ///
    /// NOTE(review): `patch` bails when a `NonNullable` receiver meets a nullable patch
    /// validity before any index is inspected, so the panic here presumably comes from
    /// that check rather than from the out-of-range index — confirm the intended coverage.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid, &mut ctx)
            .unwrap();
    }

    /// An all-false mask cannot be turned into a non-nullable validity.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        let mask = Mask::AllFalse(10);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    /// A mask with mixed values cannot be turned into a non-nullable validity.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    /// Taking from an array-free validity follows the validity of the indices.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let taken = validity.take(&indices).unwrap();
        assert!(taken.mask_eq(&expected, &mut ctx).unwrap());
    }
}