1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::ToCanonical;
25use crate::arrays::BoolArray;
26use crate::arrays::ConstantArray;
27use crate::arrays::ScalarFnArrayExt;
28use crate::builtins::ArrayBuiltins;
29use crate::compute::sum;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::optimizer::ArrayOptimizer;
33use crate::patches::Patches;
34use crate::scalar::Scalar;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::Operator;
37
38#[derive(Clone, Debug)]
40pub enum Validity {
41 NonNullable,
43 AllValid,
45 AllInvalid,
47 Array(ArrayRef),
51}
52
53impl Validity {
54 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
56 match self {
57 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
58 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
59 }
60 }
61}
62
63impl Validity {
64 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
66
67 pub fn to_array(&self, len: usize) -> ArrayRef {
69 match self {
70 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
71 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
72 Self::Array(a) => a.clone(),
73 }
74 }
75
76 #[inline]
78 pub fn into_array(self) -> Option<ArrayRef> {
79 if let Self::Array(a) = self {
80 Some(a)
81 } else {
82 None
83 }
84 }
85
86 #[inline]
88 pub fn as_array(&self) -> Option<&ArrayRef> {
89 if let Self::Array(a) = self {
90 Some(a)
91 } else {
92 None
93 }
94 }
95
96 #[inline]
97 pub fn nullability(&self) -> Nullability {
98 if matches!(self, Self::NonNullable) {
99 Nullability::NonNullable
100 } else {
101 Nullability::Nullable
102 }
103 }
104
105 #[inline]
107 pub fn union_nullability(self, nullability: Nullability) -> Self {
108 match nullability {
109 Nullability::NonNullable => self,
110 Nullability::Nullable => self.into_nullable(),
111 }
112 }
113
114 #[inline]
115 pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
116 Ok(match self {
117 _ if len == 0 => true,
118 Validity::NonNullable | Validity::AllValid => true,
119 Validity::AllInvalid => false,
120 Validity::Array(array) => {
121 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122 .vortex_expect("sum must be a usize")
123 == array.len()
124 }
125 })
126 }
127
128 #[inline]
129 pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
130 Ok(match self {
131 _ if len == 0 => true,
132 Validity::NonNullable | Validity::AllValid => false,
133 Validity::AllInvalid => true,
134 Validity::Array(array) => {
135 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
136 .vortex_expect("sum must be a usize")
137 == 0
138 }
139 })
140 }
141
142 #[inline]
144 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
145 Ok(match self {
146 Self::NonNullable | Self::AllValid => true,
147 Self::AllInvalid => false,
148 Self::Array(a) => a
149 .scalar_at(index)
150 .vortex_expect("Validity array must support scalar_at")
151 .as_bool()
152 .value()
153 .vortex_expect("Validity must be non-nullable"),
154 })
155 }
156
157 #[inline]
158 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
159 Ok(!self.is_valid(index)?)
160 }
161
162 #[inline]
163 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
164 match self {
165 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
166 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
167 }
168 }
169
170 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
171 match self {
172 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
173 AllOr::All => {
174 if indices.dtype().is_nullable() {
175 Ok(Self::AllValid)
176 } else {
177 Ok(Self::NonNullable)
178 }
179 }
180 AllOr::None => Ok(Self::AllInvalid),
181 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
182 },
183 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
184 AllOr::All => Ok(Self::AllValid),
185 AllOr::None => Ok(Self::AllInvalid),
186 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
187 },
188 Self::AllInvalid => Ok(Self::AllInvalid),
189 Self::Array(is_valid) => {
190 let maybe_is_valid = is_valid.take(indices.to_array())?;
191 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
193 Ok(Self::Array(is_valid))
194 }
195 }
196 }
197
198 pub fn not(&self) -> VortexResult<Self> {
200 match self {
201 Validity::NonNullable => Ok(Validity::NonNullable),
202 Validity::AllValid => Ok(Validity::AllInvalid),
203 Validity::AllInvalid => Ok(Validity::AllValid),
204 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
205 }
206 }
207
208 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
216 match self {
219 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
220 Ok(v.clone())
221 }
222 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
223 }
224 }
225
226 #[inline]
227 pub fn to_mask(&self, length: usize) -> Mask {
228 match self {
229 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
230 Self::AllInvalid => Mask::AllFalse(length),
231 Self::Array(is_valid) => {
232 assert_eq!(
233 is_valid.len(),
234 length,
235 "Validity::Array length must equal to_logical's argument: {}, {}.",
236 is_valid.len(),
237 length,
238 );
239 is_valid.to_bool().to_mask()
240 }
241 }
242 }
243
244 #[inline]
246 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
247 Ok(match (self, rhs) {
248 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
250 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
252 (Validity::Array(a), Validity::AllValid)
254 | (Validity::Array(a), Validity::NonNullable)
255 | (Validity::NonNullable, Validity::Array(a))
256 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
257 (Validity::NonNullable, Validity::AllValid)
259 | (Validity::AllValid, Validity::NonNullable)
260 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
261 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
263 Binary
264 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
265 .optimize()?,
266 ),
267 })
268 }
269
270 pub fn patch(
271 self,
272 len: usize,
273 indices_offset: usize,
274 indices: &ArrayRef,
275 patches: &Validity,
276 ctx: &mut ExecutionCtx,
277 ) -> VortexResult<Self> {
278 match (&self, patches) {
279 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
280 (Validity::NonNullable, _) => {
281 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
282 }
283 (_, Validity::NonNullable) => {
284 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
285 }
286 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
287 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
288 _ => {}
289 };
290
291 let own_nullability = if self == Validity::NonNullable {
292 Nullability::NonNullable
293 } else {
294 Nullability::Nullable
295 };
296
297 let source = match self {
298 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
299 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
300 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
301 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
302 };
303
304 let patch_values = match patches {
305 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
306 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
307 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
308 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
309 };
310
311 let patches = Patches::new(
312 len,
313 indices_offset,
314 indices.to_array(),
315 patch_values.into_array(),
316 None,
318 )?;
319
320 Ok(Self::from_array(
321 source.patch(&patches, ctx)?.into_array(),
322 own_nullability,
323 ))
324 }
325
326 #[inline]
328 pub fn into_nullable(self) -> Validity {
329 match self {
330 Self::NonNullable => Self::AllValid,
331 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
332 }
333 }
334
335 #[inline]
337 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
338 match self {
339 _ if len == 0 => Some(Validity::NonNullable),
340 Self::NonNullable => Some(Self::NonNullable),
341 Self::AllValid => Some(Self::NonNullable),
342 Self::AllInvalid => None,
343 Self::Array(is_valid) => {
344 is_valid
345 .statistics()
346 .compute_min::<bool>()
347 .vortex_expect("validity array must support min")
348 .then(|| {
349 Self::NonNullable
351 })
352 }
353 }
354 }
355
356 #[inline]
358 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
359 match nullability {
360 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
361 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
362 }),
363 Nullability::Nullable => Ok(self.into_nullable()),
364 }
365 }
366
367 #[inline]
369 pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
370 Ok(Validity::from_mask(
371 array.validity_mask()?,
372 array.dtype().nullability(),
373 ))
374 }
375
376 #[inline]
381 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
382 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
383 vortex_panic!("Expected a non-nullable boolean array")
384 }
385 match nullability {
386 Nullability::NonNullable => Self::NonNullable,
387 Nullability::Nullable => Self::Array(value),
388 }
389 }
390
391 #[inline]
393 pub fn maybe_len(&self) -> Option<usize> {
394 match self {
395 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
396 Self::Array(a) => Some(a.len()),
397 }
398 }
399
400 #[inline]
401 pub fn uncompressed_size(&self) -> usize {
402 if let Validity::Array(a) = self {
403 a.len().div_ceil(8)
404 } else {
405 0
406 }
407 }
408}
409
410impl PartialEq for Validity {
411 #[inline]
412 fn eq(&self, other: &Self) -> bool {
413 match (self, other) {
414 (Self::NonNullable, Self::NonNullable) => true,
415 (Self::AllValid, Self::AllValid) => true,
416 (Self::AllInvalid, Self::AllInvalid) => true,
417 (Self::Array(a), Self::Array(b)) => {
418 let a = a.to_bool();
419 let b = b.to_bool();
420 a.to_bit_buffer() == b.to_bit_buffer()
421 }
422 _ => false,
423 }
424 }
425}
426
427impl From<BitBuffer> for Validity {
428 #[inline]
429 fn from(value: BitBuffer) -> Self {
430 let true_count = value.true_count();
431 if true_count == value.len() {
432 Self::AllValid
433 } else if true_count == 0 {
434 Self::AllInvalid
435 } else {
436 Self::Array(BoolArray::from(value).into_array())
437 }
438 }
439}
440
441impl FromIterator<Mask> for Validity {
442 #[inline]
443 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
444 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
445 }
446}
447
448impl FromIterator<bool> for Validity {
449 #[inline]
450 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
451 Validity::from(BitBuffer::from_iter(iter))
452 }
453}
454
455impl From<Nullability> for Validity {
456 #[inline]
457 fn from(value: Nullability) -> Self {
458 Validity::from(&value)
459 }
460}
461
462impl From<&Nullability> for Validity {
463 #[inline]
464 fn from(value: &Nullability) -> Self {
465 match *value {
466 Nullability::NonNullable => Validity::NonNullable,
467 Nullability::Nullable => Validity::AllValid,
468 }
469 }
470}
471
472impl Validity {
473 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
474 if buffer.true_count() == buffer.len() {
475 nullability.into()
476 } else if buffer.true_count() == 0 {
477 Validity::AllInvalid
478 } else {
479 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
480 }
481 }
482
483 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
484 assert!(
485 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
486 "NonNullable validity must be AllValid",
487 );
488 match mask {
489 Mask::AllTrue(_) => match nullability {
490 Nullability::NonNullable => Validity::NonNullable,
491 Nullability::Nullable => Validity::AllValid,
492 },
493 Mask::AllFalse(_) => Validity::AllInvalid,
494 Mask::Values(values) => Validity::Array(values.into_array()),
495 }
496 }
497}
498
499impl IntoArray for Mask {
500 #[inline]
501 fn into_array(self) -> ArrayRef {
502 match self {
503 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
504 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
505 Self::Values(a) => a.into_array(),
506 }
507 }
508}
509
510impl IntoArray for &MaskValues {
511 #[inline]
512 fn into_array(self) -> ArrayRef {
513 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
514 }
515}
516
517#[cfg(test)]
518mod tests {
519 use rstest::rstest;
520 use vortex_buffer::Buffer;
521 use vortex_buffer::buffer;
522 use vortex_mask::Mask;
523
524 use crate::ArrayRef;
525 use crate::IntoArray;
526 use crate::LEGACY_SESSION;
527 use crate::VortexSessionExecute;
528 use crate::arrays::BoolArray;
529 use crate::arrays::PrimitiveArray;
530 use crate::dtype::Nullability;
531 use crate::validity::Validity;
532
533 #[rstest]
534 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
535 #[case(
536 Validity::AllValid,
537 5,
538 &[2, 4],
539 Validity::AllInvalid,
540 Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
541 )]
542 #[case(
543 Validity::AllValid,
544 5,
545 &[2, 4],
546 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
547 Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
548 )]
549 #[case(
550 Validity::AllInvalid,
551 5,
552 &[2, 4],
553 Validity::AllValid,
554 Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
555 )]
556 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
557 #[case(
558 Validity::AllInvalid,
559 5,
560 &[2, 4],
561 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
562 Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
563 )]
564 #[case(
565 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
566 5,
567 &[2, 4],
568 Validity::AllValid,
569 Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
570 )]
571 #[case(
572 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
573 5,
574 &[2, 4],
575 Validity::AllInvalid,
576 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
577 )]
578 #[case(
579 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
580 5,
581 &[2, 4],
582 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
583 Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
584 )]
585
586 fn patch_validity(
587 #[case] validity: Validity,
588 #[case] len: usize,
589 #[case] positions: &[u64],
590 #[case] patches: Validity,
591 #[case] expected: Validity,
592 ) {
593 let indices =
594 PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
595 assert_eq!(
596 validity
597 .patch(
598 len,
599 0,
600 &indices,
601 &patches,
602 &mut LEGACY_SESSION.create_execution_ctx()
603 )
604 .unwrap(),
605 expected
606 );
607 }
608
// Patching index 4 of a two-element validity must fail.
//
// The base is nullable and the indices are unsigned so that the call gets past the
// nullability guard in `patch`. With a `NonNullable` base the call bails on the
// nullability mismatch before any index is looked at, so the test would pass without
// exercising bounds handling.
#[test]
#[should_panic]
fn out_of_bounds_patch() {
    let indices = buffer![4u64].into_array();
    let mut ctx = LEGACY_SESSION.create_execution_ctx();

    Validity::AllValid
        .patch(2, 0, &indices, &Validity::AllInvalid, &mut ctx)
        .unwrap();
}
622
// A non-nullable validity cannot be built from an all-false mask.
// The expected message is pinned so that an unrelated panic cannot satisfy the test.
#[test]
#[should_panic(expected = "NonNullable validity must be AllValid")]
fn into_validity_nullable() {
    Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
}
628
// A non-nullable validity cannot be built from a mask with mixed values.
// The expected message is pinned so that an unrelated panic cannot satisfy the test.
#[test]
#[should_panic(expected = "NonNullable validity must be AllValid")]
fn into_validity_nullable_array() {
    Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
}
634
635 #[rstest]
636 #[case(
637 Validity::AllValid,
638 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
639 Validity::from_iter(vec![true, false])
640 )]
641 #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
642 #[case(
643 Validity::AllValid,
644 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
645 Validity::AllInvalid
646 )]
647 #[case(
648 Validity::NonNullable,
649 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
650 Validity::from_iter(vec![true, false])
651 )]
652 #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
653 #[case(
654 Validity::NonNullable,
655 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
656 Validity::AllInvalid
657 )]
658 fn validity_take(
659 #[case] validity: Validity,
660 #[case] indices: ArrayRef,
661 #[case] expected: Validity,
662 ) {
663 assert_eq!(validity.take(&indices).unwrap(), expected);
664 }
665}