1use std::fmt::Debug;
12use std::ops::Range;
13
14use itertools::Itertools as _;
15use vortex_buffer::BitBuffer;
16use vortex_error::VortexExpect as _;
17use vortex_error::VortexResult;
18use vortex_error::vortex_bail;
19use vortex_error::vortex_err;
20use vortex_mask::Mask;
21use vortex_mask::MaskValues;
22
23use crate::ArrayRef;
24use crate::Canonical;
25use crate::ExecutionCtx;
26use crate::IntoArray;
27use crate::LEGACY_SESSION;
28use crate::VortexSessionExecute;
29use crate::arrays::BoolArray;
30use crate::arrays::ChunkedArray;
31use crate::arrays::ConstantArray;
32use crate::arrays::scalar_fn::ScalarFnFactoryExt;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::dtype::Nullability;
36use crate::optimizer::ArrayOptimizer;
37use crate::patches::Patches;
38use crate::scalar::Scalar;
39use crate::scalar_fn::fns::binary::Binary;
40use crate::scalar_fn::fns::operators::Operator;
41
42#[derive(Clone)]
44pub enum Validity {
45 NonNullable,
47 AllValid,
49 AllInvalid,
51 Array(ArrayRef),
55}
56
57impl Debug for Validity {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Self::NonNullable => write!(f, "NonNullable"),
61 Self::AllValid => write!(f, "AllValid"),
62 Self::AllInvalid => write!(f, "AllInvalid"),
63 Self::Array(arr) => write!(f, "SomeValid({})", arr.display_values()),
64 }
65 }
66}
67
68impl Validity {
69 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
71 match self {
72 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
73 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
74 }
75 }
76}
77
78impl Validity {
79 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
81
82 pub fn to_array(&self, len: usize) -> ArrayRef {
84 match self {
85 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
86 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
87 Self::Array(a) => a.clone(),
88 }
89 }
90
91 #[inline]
93 pub fn into_array(self) -> Option<ArrayRef> {
94 if let Self::Array(a) = self {
95 Some(a)
96 } else {
97 None
98 }
99 }
100
101 #[inline]
103 pub fn as_array(&self) -> Option<&ArrayRef> {
104 if let Self::Array(a) = self {
105 Some(a)
106 } else {
107 None
108 }
109 }
110
111 #[inline]
112 pub fn nullability(&self) -> Nullability {
113 if matches!(self, Self::NonNullable) {
114 Nullability::NonNullable
115 } else {
116 Nullability::Nullable
117 }
118 }
119
120 #[inline]
127 pub fn definitely_no_nulls(&self) -> bool {
128 matches!(self, Self::NonNullable | Self::AllValid)
129 }
130
131 #[inline]
140 pub fn definitely_all_null(&self) -> bool {
141 matches!(self, Self::AllInvalid)
142 }
143
144 pub fn execute_no_nulls(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
150 match self {
151 Self::NonNullable | Self::AllValid => Ok(true),
152 Self::AllInvalid => Ok(length == 0),
153 Self::Array(_) => Ok(self.execute_mask(length, ctx)?.all_true()),
154 }
155 }
156
157 #[inline]
159 pub fn union_nullability(self, nullability: Nullability) -> Self {
160 match nullability {
161 Nullability::NonNullable => self,
162 Nullability::Nullable => self.into_nullable(),
163 }
164 }
165
166 #[inline]
168 pub fn execute_is_valid(&self, index: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
169 Ok(match self {
170 Self::NonNullable | Self::AllValid => true,
171 Self::AllInvalid => false,
172 Self::Array(a) => a
173 .execute_scalar(index, ctx)?
174 .as_bool()
175 .value()
176 .ok_or_else(|| vortex_err!("validity value at index {index} is null"))?,
177 })
178 }
179
180 #[inline]
182 pub fn execute_is_null(&self, index: usize, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
183 Ok(!self.execute_is_valid(index, ctx)?)
184 }
185
186 #[deprecated(note = "use `execute_is_valid` with an explicit `ExecutionCtx`")]
188 #[inline]
189 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
190 self.execute_is_valid(index, &mut LEGACY_SESSION.create_execution_ctx())
191 }
192
193 #[deprecated(note = "use `execute_is_null` with an explicit `ExecutionCtx`")]
195 #[inline]
196 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
197 self.execute_is_null(index, &mut LEGACY_SESSION.create_execution_ctx())
198 }
199
200 #[inline]
201 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
202 match self {
203 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
204 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
205 }
206 }
207
208 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
209 match self {
210 Self::NonNullable => indices.validity(),
211 Self::AllValid => Ok(match indices.validity()? {
212 Self::NonNullable => Self::AllValid,
213 v => v,
214 }),
215 Self::AllInvalid => Ok(Self::AllInvalid),
216 Self::Array(is_valid) => {
217 let maybe_is_valid = is_valid.take(indices.clone())?;
218 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
220 Ok(Self::Array(is_valid))
221 }
222 }
223 }
224
225 pub fn not(&self) -> VortexResult<Self> {
227 match self {
228 Validity::NonNullable => Ok(Validity::NonNullable),
229 Validity::AllValid => Ok(Validity::AllInvalid),
230 Validity::AllInvalid => Ok(Validity::AllValid),
231 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
232 }
233 }
234
235 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
243 match self {
246 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
247 Ok(v.clone())
248 }
249 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
250 }
251 }
252
253 #[deprecated(note = "Use execute_mask")]
257 pub fn to_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
258 match self {
259 Self::NonNullable | Self::AllValid => Ok(Mask::new_true(length)),
260 Self::AllInvalid => Ok(Mask::new_false(length)),
261 Self::Array(arr) => arr.clone().execute::<Mask>(ctx),
262 }
263 }
264
265 #[inline]
266 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
267 match self {
268 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
269 Self::AllInvalid => Ok(Mask::AllFalse(length)),
270 Self::Array(arr) => {
271 assert_eq!(
272 arr.len(),
273 length,
274 "Validity::Array length must equal to_logical's argument: {}, {}.",
275 arr.len(),
276 length,
277 );
278 arr.clone().execute::<Mask>(ctx)
281 }
282 }
283 }
284
285 pub fn mask_eq(
288 &self,
289 other: &Validity,
290 length: usize,
291 ctx: &mut ExecutionCtx,
292 ) -> VortexResult<bool> {
293 match (self, other) {
294 (
296 Validity::NonNullable | Validity::AllValid,
297 Validity::NonNullable | Validity::AllValid,
298 )
299 | (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
300 _ => Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?),
301 }
302 }
303
304 #[inline]
306 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
307 Ok(match (self, rhs) {
308 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
310 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
312 (Validity::Array(a), Validity::AllValid)
314 | (Validity::Array(a), Validity::NonNullable)
315 | (Validity::NonNullable, Validity::Array(a))
316 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
317 (Validity::NonNullable, Validity::AllValid)
319 | (Validity::AllValid, Validity::NonNullable)
320 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
321 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
323 Binary
324 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
325 .optimize()?,
326 ),
327 })
328 }
329
330 pub fn patch(
331 self,
332 len: usize,
333 indices_offset: usize,
334 indices: &ArrayRef,
335 patches: &Validity,
336 ctx: &mut ExecutionCtx,
337 ) -> VortexResult<Self> {
338 match (&self, patches) {
339 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
340 (Validity::NonNullable, _) => {
341 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
342 }
343 (_, Validity::NonNullable) => {
344 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
345 }
346 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
347 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
348 _ => {}
349 };
350
351 if matches!(self, Validity::NonNullable) {
352 return Ok(Self::NonNullable);
353 }
354
355 let source = match self {
357 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
358 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
359 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
360 Validity::Array(a) => a.execute::<BoolArray>(ctx)?,
361 };
362
363 let patch_values = match patches {
364 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
365 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
366 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
367 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
368 };
369
370 let patches = Patches::new(
371 len,
372 indices_offset,
373 indices.clone(),
374 patch_values.into_array(),
375 None,
377 )?;
378
379 Ok(Self::Array(source.patch(&patches, ctx)?.into_array()))
380 }
381
382 #[inline]
384 pub fn into_nullable(self) -> Validity {
385 match self {
386 Self::NonNullable => Self::AllValid,
387 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
388 }
389 }
390
391 #[inline]
397 pub fn into_non_nullable(self, len: usize, ctx: &mut ExecutionCtx) -> Option<Validity> {
398 match self {
399 _ if len == 0 => Some(Validity::NonNullable),
400 Self::NonNullable => Some(Self::NonNullable),
401 Self::AllValid => Some(Self::NonNullable),
402 Self::AllInvalid => None,
403 Self::Array(is_valid) => {
404 is_valid
405 .statistics()
406 .compute_min::<bool>(ctx)
407 .vortex_expect("validity array must support min")
408 .then(|| {
409 Self::NonNullable
411 })
412 }
413 }
414 }
415
416 #[inline]
427 pub fn trivial_into_non_nullable(self, len: usize) -> VortexResult<Option<Validity>> {
428 match self {
429 _ if len == 0 => Ok(Some(Validity::NonNullable)),
430 Self::NonNullable => Ok(Some(Self::NonNullable)),
431 Self::AllValid => Ok(Some(Self::NonNullable)),
432 Self::AllInvalid => {
433 Err(vortex_err!(InvalidArgument: "Cannot cast AllInvalid to NonNullable"))
434 }
435 Self::Array(_) => Ok(None),
436 }
437 }
438
439 #[inline]
455 pub fn cast_nullability(
456 self,
457 nullability: Nullability,
458 len: usize,
459 ctx: &mut ExecutionCtx,
460 ) -> VortexResult<Validity> {
461 match nullability {
462 Nullability::NonNullable => self.into_non_nullable(len, ctx).ok_or_else(|| {
463 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
464 }),
465 Nullability::Nullable => Ok(self.into_nullable()),
466 }
467 }
468
469 #[inline]
494 pub fn trivially_cast_nullability(
495 self,
496 nullability: Nullability,
497 len: usize,
498 ) -> VortexResult<Option<Validity>> {
499 match nullability {
500 Nullability::NonNullable => self.trivial_into_non_nullable(len),
501 Nullability::Nullable => Ok(Some(self.into_nullable())),
502 }
503 }
504
505 #[inline]
507 pub fn maybe_len(&self) -> Option<usize> {
508 match self {
509 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
510 Self::Array(a) => Some(a.len()),
511 }
512 }
513}
514
515impl From<BitBuffer> for Validity {
516 #[inline]
517 fn from(value: BitBuffer) -> Self {
518 let true_count = value.true_count();
519 if true_count == value.len() {
520 Self::AllValid
521 } else if true_count == 0 {
522 Self::AllInvalid
523 } else {
524 Self::Array(BoolArray::from(value).into_array())
525 }
526 }
527}
528
529impl FromIterator<Mask> for Validity {
530 #[inline]
531 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
532 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
533 }
534}
535
536impl FromIterator<bool> for Validity {
537 #[inline]
538 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
539 Validity::from(BitBuffer::from_iter(iter))
540 }
541}
542
543impl From<Nullability> for Validity {
544 #[inline]
545 fn from(value: Nullability) -> Self {
546 Validity::from(&value)
547 }
548}
549
550impl From<&Nullability> for Validity {
551 #[inline]
552 fn from(value: &Nullability) -> Self {
553 match *value {
554 Nullability::NonNullable => Validity::NonNullable,
555 Nullability::Nullable => Validity::AllValid,
556 }
557 }
558}
559
560impl Validity {
561 pub fn concat(validities: Vec<(Validity, usize)>) -> Option<Self> {
565 let mut validity_kinds = validities
566 .iter()
567 .map(|(v, _)| std::mem::discriminant(v))
568 .unique();
569 let validity_kind = validity_kinds.next()?;
570 if validity_kinds.next().is_none() {
571 if validity_kind == std::mem::discriminant(&Validity::AllValid) {
574 return Some(Validity::AllValid);
575 }
576 if validity_kind == std::mem::discriminant(&Validity::AllInvalid) {
577 return Some(Validity::AllInvalid);
578 }
579 if validity_kind == std::mem::discriminant(&Validity::NonNullable) {
580 return Some(Validity::NonNullable);
581 }
582 }
583
584 Some(Validity::Array(
585 unsafe {
586 ChunkedArray::new_unchecked(
587 validities
588 .into_iter()
589 .map(|(v, len)| v.to_array(len))
590 .collect(),
591 DType::Bool(Nullability::NonNullable),
592 )
593 }
594 .into_array(),
595 ))
596 }
597}
598
599impl Validity {
600 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
601 if buffer.true_count() == buffer.len() {
602 nullability.into()
603 } else if buffer.true_count() == 0 {
604 Validity::AllInvalid
605 } else {
606 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
607 }
608 }
609
610 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
611 assert!(
612 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
613 "NonNullable validity must be AllValid",
614 );
615 match mask {
616 Mask::AllTrue(_) => match nullability {
617 Nullability::NonNullable => Validity::NonNullable,
618 Nullability::Nullable => Validity::AllValid,
619 },
620 Mask::AllFalse(_) => Validity::AllInvalid,
621 Mask::Values(values) => Validity::Array(values.into_array()),
622 }
623 }
624}
625
626impl IntoArray for Mask {
627 #[inline]
628 fn into_array(self) -> ArrayRef {
629 match self {
630 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
631 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
632 Self::Values(a) => a.into_array(),
633 }
634 }
635}
636
637impl IntoArray for &MaskValues {
638 #[inline]
639 fn into_array(self) -> ArrayRef {
640 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
641 }
642}
643
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    /// Array-backed validity holding exactly `values` (`true` = valid).
    fn bool_validity(values: &[bool]) -> Validity {
        Validity::Array(BoolArray::from_iter(values.iter().copied()).into_array())
    }

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        bool_validity(&[true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        bool_validity(&[true, false]),
        bool_validity(&[true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        bool_validity(&[false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        bool_validity(&[true, false]),
        bool_validity(&[false, false, true, false, false])
    )]
    #[case(
        bool_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        bool_validity(&[false, true, true, true, true])
    )]
    #[case(
        bool_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        bool_validity(&[false, true, false, true, false])
    )]
    #[case(
        bool_validity(&[false, true, false, true, false]),
        5,
        &[2, 4],
        bool_validity(&[true, false]),
        bool_validity(&[false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let mut ctx = array_session().create_execution_ctx();

        let patched = validity
            .patch(len, 0, &indices, &patches, &mut ctx)
            .unwrap();
        assert!(patched.mask_eq(&expected, len, &mut ctx).unwrap());
    }

    /// Patching a non-nullable validity with nullable patches is rejected.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        let mut ctx = array_session().create_execution_ctx();
        let indices = buffer![4].into_array();
        Validity::NonNullable
            .patch(2, 0, &indices, &Validity::AllInvalid, &mut ctx)
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        let mask = Mask::from_iter(vec![true, false]);
        Validity::from_mask(mask, Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let len = indices.len();
        let mut ctx = array_session().create_execution_ctx();

        let taken = validity.take(&indices).unwrap();
        assert!(taken.mask_eq(&expected, len, &mut ctx).unwrap());
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::AllValid, true)]
    #[case(Validity::AllValid, Validity::NonNullable, true)]
    #[case(Validity::AllValid, Validity::AllInvalid, false)]
    #[case(Validity::NonNullable, Validity::AllInvalid, false)]
    #[case(bool_validity(&[true, true, true]), Validity::AllValid, true)]
    #[case(Validity::NonNullable, bool_validity(&[true, true, true]), true)]
    #[case(bool_validity(&[false, false, false]), Validity::AllInvalid, true)]
    #[case(bool_validity(&[true, false, true]), Validity::AllValid, false)]
    #[case(bool_validity(&[true, false, true]), Validity::AllInvalid, false)]
    fn mask_eq_mixed_variants(
        #[case] lhs: Validity,
        #[case] rhs: Validity,
        #[case] expected: bool,
    ) -> vortex_error::VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let actual = lhs.mask_eq(&rhs, 3, &mut ctx)?;
        assert_eq!(actual, expected);
        Ok(())
    }
}