1use std::fmt::Debug;
7use std::ops::Range;
8
9use itertools::Itertools as _;
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexExpect as _;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_mask::Mask;
16use vortex_mask::MaskValues;
17
18use crate::ArrayRef;
19use crate::Canonical;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::LEGACY_SESSION;
23use crate::VortexSessionExecute;
24use crate::arrays::BoolArray;
25use crate::arrays::ChunkedArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::scalar_fn::ScalarFnFactoryExt;
28use crate::builtins::ArrayBuiltins;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::optimizer::ArrayOptimizer;
32use crate::patches::Patches;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::Operator;
36
/// Describes which elements of an array are valid (non-null).
#[derive(Clone)]
pub enum Validity {
    /// The array's type is non-nullable, so every element is valid.
    NonNullable,
    /// The array's type is nullable, but every element is valid.
    AllValid,
    /// The array's type is nullable and every element is invalid (null).
    AllInvalid,
    /// Per-element validity stored in a boolean array where `true` marks a valid element.
    /// Expected to be a non-nullable boolean array (see `Validity::DTYPE`).
    Array(ArrayRef),
}
51
52impl Debug for Validity {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::NonNullable => write!(f, "NonNullable"),
56 Self::AllValid => write!(f, "AllValid"),
57 Self::AllInvalid => write!(f, "AllInvalid"),
58 Self::Array(arr) => write!(f, "SomeValid({})", arr.display_values()),
59 }
60 }
61}
62
63impl Validity {
64 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
66 match self {
67 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
68 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
69 }
70 }
71}
72
73impl Validity {
74 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
76
77 pub fn to_array(&self, len: usize) -> ArrayRef {
79 match self {
80 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
81 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
82 Self::Array(a) => a.clone(),
83 }
84 }
85
86 #[inline]
88 pub fn into_array(self) -> Option<ArrayRef> {
89 if let Self::Array(a) = self {
90 Some(a)
91 } else {
92 None
93 }
94 }
95
96 #[inline]
98 pub fn as_array(&self) -> Option<&ArrayRef> {
99 if let Self::Array(a) = self {
100 Some(a)
101 } else {
102 None
103 }
104 }
105
106 #[inline]
107 pub fn nullability(&self) -> Nullability {
108 if matches!(self, Self::NonNullable) {
109 Nullability::NonNullable
110 } else {
111 Nullability::Nullable
112 }
113 }
114
115 #[inline]
122 pub fn definitely_no_nulls(&self) -> bool {
123 matches!(self, Self::NonNullable | Self::AllValid)
124 }
125
126 pub fn execute_no_nulls(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
132 match self {
133 Self::NonNullable | Self::AllValid => Ok(true),
134 Self::AllInvalid => Ok(length == 0),
135 Self::Array(_) => Ok(self.execute_mask(length, ctx)?.all_true()),
136 }
137 }
138
139 #[inline]
141 pub fn union_nullability(self, nullability: Nullability) -> Self {
142 match nullability {
143 Nullability::NonNullable => self,
144 Nullability::Nullable => self.into_nullable(),
145 }
146 }
147
148 #[inline]
150 pub fn execute_is_valid(&self, index: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
151 Ok(match self {
152 Self::NonNullable | Self::AllValid => true,
153 Self::AllInvalid => false,
154 Self::Array(a) => a
155 .execute_scalar(index, ctx)?
156 .as_bool()
157 .value()
158 .ok_or_else(|| vortex_err!("validity value at index {index} is null"))?,
159 })
160 }
161
162 #[inline]
164 pub fn execute_is_null(&self, index: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
165 Ok(!self.execute_is_valid(index, ctx)?)
166 }
167
168 #[deprecated(note = "use `execute_is_valid` with an explicit `ExecutionCtx`")]
170 #[inline]
171 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
172 self.execute_is_valid(index, &mut LEGACY_SESSION.create_execution_ctx())
173 }
174
175 #[deprecated(note = "use `execute_is_null` with an explicit `ExecutionCtx`")]
177 #[inline]
178 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
179 self.execute_is_null(index, &mut LEGACY_SESSION.create_execution_ctx())
180 }
181
182 #[inline]
183 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
184 match self {
185 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
186 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
187 }
188 }
189
190 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
191 match self {
192 Self::NonNullable => indices.validity(),
193 Self::AllValid => Ok(match indices.validity()? {
194 Self::NonNullable => Self::AllValid,
195 v => v,
196 }),
197 Self::AllInvalid => Ok(Self::AllInvalid),
198 Self::Array(is_valid) => {
199 let maybe_is_valid = is_valid.take(indices.clone())?;
200 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
202 Ok(Self::Array(is_valid))
203 }
204 }
205 }
206
207 pub fn not(&self) -> VortexResult<Self> {
209 match self {
210 Validity::NonNullable => Ok(Validity::NonNullable),
211 Validity::AllValid => Ok(Validity::AllInvalid),
212 Validity::AllInvalid => Ok(Validity::AllValid),
213 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
214 }
215 }
216
217 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
225 match self {
228 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
229 Ok(v.clone())
230 }
231 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
232 }
233 }
234
235 #[deprecated(note = "Use execute_mask")]
239 pub fn to_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
240 match self {
241 Self::NonNullable | Self::AllValid => Ok(Mask::new_true(length)),
242 Self::AllInvalid => Ok(Mask::new_false(length)),
243 Self::Array(arr) => arr.clone().execute::<Mask>(ctx),
244 }
245 }
246
247 #[inline]
248 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
249 match self {
250 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
251 Self::AllInvalid => Ok(Mask::AllFalse(length)),
252 Self::Array(arr) => {
253 assert_eq!(
254 arr.len(),
255 length,
256 "Validity::Array length must equal to_logical's argument: {}, {}.",
257 arr.len(),
258 length,
259 );
260 arr.clone().execute::<Mask>(ctx)
263 }
264 }
265 }
266
267 pub fn mask_eq(
270 &self,
271 other: &Validity,
272 length: usize,
273 ctx: &mut ExecutionCtx,
274 ) -> VortexResult<bool> {
275 match (self, other) {
276 (
278 Validity::NonNullable | Validity::AllValid,
279 Validity::NonNullable | Validity::AllValid,
280 )
281 | (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
282 _ => Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?),
283 }
284 }
285
286 #[inline]
288 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
289 Ok(match (self, rhs) {
290 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
292 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
294 (Validity::Array(a), Validity::AllValid)
296 | (Validity::Array(a), Validity::NonNullable)
297 | (Validity::NonNullable, Validity::Array(a))
298 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
299 (Validity::NonNullable, Validity::AllValid)
301 | (Validity::AllValid, Validity::NonNullable)
302 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
303 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
305 Binary
306 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
307 .optimize()?,
308 ),
309 })
310 }
311
312 pub fn patch(
313 self,
314 len: usize,
315 indices_offset: usize,
316 indices: &ArrayRef,
317 patches: &Validity,
318 ctx: &mut ExecutionCtx,
319 ) -> VortexResult<Self> {
320 match (&self, patches) {
321 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
322 (Validity::NonNullable, _) => {
323 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
324 }
325 (_, Validity::NonNullable) => {
326 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
327 }
328 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
329 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
330 _ => {}
331 };
332
333 if matches!(self, Validity::NonNullable) {
334 return Ok(Self::NonNullable);
335 }
336
337 let source = match self {
339 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
340 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
341 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
342 Validity::Array(a) => a.execute::<BoolArray>(ctx)?,
343 };
344
345 let patch_values = match patches {
346 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
347 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
348 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
349 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
350 };
351
352 let patches = Patches::new(
353 len,
354 indices_offset,
355 indices.clone(),
356 patch_values.into_array(),
357 None,
359 )?;
360
361 Ok(Self::Array(source.patch(&patches, ctx)?.into_array()))
362 }
363
364 #[inline]
366 pub fn into_nullable(self) -> Validity {
367 match self {
368 Self::NonNullable => Self::AllValid,
369 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
370 }
371 }
372
373 #[inline]
379 pub fn into_non_nullable(self, len: usize, ctx: &mut ExecutionCtx) -> Option<Validity> {
380 match self {
381 _ if len == 0 => Some(Validity::NonNullable),
382 Self::NonNullable => Some(Self::NonNullable),
383 Self::AllValid => Some(Self::NonNullable),
384 Self::AllInvalid => None,
385 Self::Array(is_valid) => {
386 is_valid
387 .statistics()
388 .compute_min::<bool>(ctx)
389 .vortex_expect("validity array must support min")
390 .then(|| {
391 Self::NonNullable
393 })
394 }
395 }
396 }
397
398 #[inline]
409 pub fn trivial_into_non_nullable(self, len: usize) -> VortexResult<Option<Validity>> {
410 match self {
411 _ if len == 0 => Ok(Some(Validity::NonNullable)),
412 Self::NonNullable => Ok(Some(Self::NonNullable)),
413 Self::AllValid => Ok(Some(Self::NonNullable)),
414 Self::AllInvalid => {
415 Err(vortex_err!(InvalidArgument: "Cannot cast AllInvalid to NonNullable"))
416 }
417 Self::Array(_) => Ok(None),
418 }
419 }
420
421 #[inline]
437 pub fn cast_nullability(
438 self,
439 nullability: Nullability,
440 len: usize,
441 ctx: &mut ExecutionCtx,
442 ) -> VortexResult<Validity> {
443 match nullability {
444 Nullability::NonNullable => self.into_non_nullable(len, ctx).ok_or_else(|| {
445 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
446 }),
447 Nullability::Nullable => Ok(self.into_nullable()),
448 }
449 }
450
451 #[inline]
476 pub fn trivially_cast_nullability(
477 self,
478 nullability: Nullability,
479 len: usize,
480 ) -> VortexResult<Option<Validity>> {
481 match nullability {
482 Nullability::NonNullable => self.trivial_into_non_nullable(len),
483 Nullability::Nullable => Ok(Some(self.into_nullable())),
484 }
485 }
486
487 #[inline]
489 pub fn maybe_len(&self) -> Option<usize> {
490 match self {
491 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
492 Self::Array(a) => Some(a.len()),
493 }
494 }
495}
496
497impl From<BitBuffer> for Validity {
498 #[inline]
499 fn from(value: BitBuffer) -> Self {
500 let true_count = value.true_count();
501 if true_count == value.len() {
502 Self::AllValid
503 } else if true_count == 0 {
504 Self::AllInvalid
505 } else {
506 Self::Array(BoolArray::from(value).into_array())
507 }
508 }
509}
510
511impl FromIterator<Mask> for Validity {
512 #[inline]
513 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
514 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
515 }
516}
517
518impl FromIterator<bool> for Validity {
519 #[inline]
520 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
521 Validity::from(BitBuffer::from_iter(iter))
522 }
523}
524
525impl From<Nullability> for Validity {
526 #[inline]
527 fn from(value: Nullability) -> Self {
528 Validity::from(&value)
529 }
530}
531
532impl From<&Nullability> for Validity {
533 #[inline]
534 fn from(value: &Nullability) -> Self {
535 match *value {
536 Nullability::NonNullable => Validity::NonNullable,
537 Nullability::Nullable => Validity::AllValid,
538 }
539 }
540}
541
542impl Validity {
543 pub fn concat(validities: Vec<(Validity, usize)>) -> Option<Self> {
547 let mut validity_kinds = validities
548 .iter()
549 .map(|(v, _)| std::mem::discriminant(v))
550 .unique();
551 let validity_kind = validity_kinds.next()?;
552 if validity_kinds.next().is_none() {
553 if validity_kind == std::mem::discriminant(&Validity::AllValid) {
556 return Some(Validity::AllValid);
557 }
558 if validity_kind == std::mem::discriminant(&Validity::AllInvalid) {
559 return Some(Validity::AllInvalid);
560 }
561 if validity_kind == std::mem::discriminant(&Validity::NonNullable) {
562 return Some(Validity::NonNullable);
563 }
564 }
565
566 Some(Validity::Array(
567 unsafe {
568 ChunkedArray::new_unchecked(
569 validities
570 .into_iter()
571 .map(|(v, len)| v.to_array(len))
572 .collect(),
573 DType::Bool(Nullability::NonNullable),
574 )
575 }
576 .into_array(),
577 ))
578 }
579}
580
581impl Validity {
582 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
583 if buffer.true_count() == buffer.len() {
584 nullability.into()
585 } else if buffer.true_count() == 0 {
586 Validity::AllInvalid
587 } else {
588 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
589 }
590 }
591
592 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
593 assert!(
594 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
595 "NonNullable validity must be AllValid",
596 );
597 match mask {
598 Mask::AllTrue(_) => match nullability {
599 Nullability::NonNullable => Validity::NonNullable,
600 Nullability::Nullable => Validity::AllValid,
601 },
602 Mask::AllFalse(_) => Validity::AllInvalid,
603 Mask::Values(values) => Validity::Array(values.into_array()),
604 }
605 }
606}
607
608impl IntoArray for Mask {
609 #[inline]
610 fn into_array(self) -> ArrayRef {
611 match self {
612 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
613 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
614 Self::Values(a) => a.into_array(),
615 }
616 }
617}
618
619impl IntoArray for &MaskValues {
620 #[inline]
621 fn into_array(self) -> ArrayRef {
622 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
623 }
624}
625
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    // Each case: (base validity, length, patch positions, patch validity, expected result).
    // Every case patches positions 2 and 4 of a length-5 validity.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]

    /// Applies `Validity::patch` and compares the result with `expected` by mask equality,
    /// so representation differences (e.g. `AllValid` vs an all-`true` array) don't matter.
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Patch positions as a non-nullable index array.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        assert!(
            validity
                .patch(len, 0, &indices, &patches, &mut ctx,)
                .unwrap()
                .mask_eq(&expected, len, &mut ctx)
                .unwrap()
        );
    }

    /// Patching a non-nullable validity must fail.
    ///
    /// NOTE(review): despite the name, `patch` bails on the NonNullable-base /
    /// nullable-patches mismatch before index 4 is ever compared with length 2, so this
    /// panic comes from that error (via `unwrap`), not from an out-of-bounds index.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        Validity::NonNullable
            .patch(
                2,
                0,
                &buffer![4].into_array(),
                &Validity::AllInvalid,
                &mut ctx,
            )
            .unwrap();
    }

    /// `from_mask` asserts that a non-nullable validity is built only from an all-true mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    /// Same assertion as above, triggered by a mask with explicit values.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case: (source validity, indices array, expected validity of the taken result).
    // The indices' own nullability flows into the result.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(
            validity
                .take(&indices)
                .unwrap()
                .mask_eq(&expected, indices.len(), &mut ctx)
                .unwrap()
        );
    }

    // Each case: (lhs, rhs, expected `mask_eq` result) over length 3, mixing trivial
    // variants with explicit arrays.
    #[rstest]
    #[case(Validity::NonNullable, Validity::AllValid, true)]
    #[case(Validity::AllValid, Validity::NonNullable, true)]
    #[case(Validity::AllValid, Validity::AllInvalid, false)]
    #[case(Validity::NonNullable, Validity::AllInvalid, false)]
    #[case(
        Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
        Validity::AllValid,
        true
    )]
    #[case(
        Validity::NonNullable,
        Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
        true
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        Validity::AllInvalid,
        true
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        Validity::AllValid,
        false
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        Validity::AllInvalid,
        false
    )]
    fn mask_eq_mixed_variants(
        #[case] lhs: Validity,
        #[case] rhs: Validity,
        #[case] expected: bool,
    ) -> vortex_error::VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(lhs.mask_eq(&rhs, 3, &mut ctx)?, expected);
        Ok(())
    }
}
822}