1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_dtype::DType;
11use vortex_dtype::Nullability;
12use vortex_error::VortexExpect as _;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20
21use crate::Array;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::ExecutionCtx;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::arrays::BoolArray;
28use crate::arrays::ConstantArray;
29use crate::arrays::ScalarFnArrayExt;
30use crate::builtins::ArrayBuiltins;
31use crate::compute::sum;
32use crate::expr::Binary;
33use crate::expr::Operator;
34use crate::optimizer::ArrayOptimizer;
35use crate::patches::Patches;
36use crate::scalar::Scalar;
37
/// Describes which elements of an array are valid (non-null).
///
/// The three constant variants avoid materializing a bitmap; [`Validity::Array`] stores an
/// explicit per-element boolean array in which `true` marks a valid element.
#[derive(Clone, Debug)]
pub enum Validity {
    /// The array's dtype is non-nullable: every element is valid.
    NonNullable,
    /// The array's dtype is nullable, but every element is valid.
    AllValid,
    /// The array's dtype is nullable and every element is invalid (null).
    AllInvalid,
    /// Explicit per-element validity. The array is expected to be a non-nullable boolean
    /// array (see `Validity::DTYPE`); `true` means the element at that position is valid.
    Array(ArrayRef),
}
52
53impl Validity {
54 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
56 match self {
57 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
58 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
59 }
60 }
61}
62
63impl Validity {
64 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
66
67 pub fn to_array(&self, len: usize) -> ArrayRef {
69 match self {
70 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
71 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
72 Self::Array(a) => a.clone(),
73 }
74 }
75
76 #[inline]
78 pub fn into_array(self) -> Option<ArrayRef> {
79 if let Self::Array(a) = self {
80 Some(a)
81 } else {
82 None
83 }
84 }
85
86 #[inline]
88 pub fn as_array(&self) -> Option<&ArrayRef> {
89 if let Self::Array(a) = self {
90 Some(a)
91 } else {
92 None
93 }
94 }
95
96 #[inline]
97 pub fn nullability(&self) -> Nullability {
98 if matches!(self, Self::NonNullable) {
99 Nullability::NonNullable
100 } else {
101 Nullability::Nullable
102 }
103 }
104
105 #[inline]
107 pub fn union_nullability(self, nullability: Nullability) -> Self {
108 match nullability {
109 Nullability::NonNullable => self,
110 Nullability::Nullable => self.into_nullable(),
111 }
112 }
113
114 #[inline]
115 pub fn all_valid(&self, len: usize) -> VortexResult<bool> {
116 Ok(match self {
117 _ if len == 0 => true,
118 Validity::NonNullable | Validity::AllValid => true,
119 Validity::AllInvalid => false,
120 Validity::Array(array) => {
121 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122 .vortex_expect("sum must be a usize")
123 == array.len()
124 }
125 })
126 }
127
128 #[inline]
129 pub fn all_invalid(&self, len: usize) -> VortexResult<bool> {
130 Ok(match self {
131 _ if len == 0 => true,
132 Validity::NonNullable | Validity::AllValid => false,
133 Validity::AllInvalid => true,
134 Validity::Array(array) => {
135 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
136 .vortex_expect("sum must be a usize")
137 == 0
138 }
139 })
140 }
141
142 #[inline]
144 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
145 Ok(match self {
146 Self::NonNullable | Self::AllValid => true,
147 Self::AllInvalid => false,
148 Self::Array(a) => a
149 .scalar_at(index)
150 .vortex_expect("Validity array must support scalar_at")
151 .as_bool()
152 .value()
153 .vortex_expect("Validity must be non-nullable"),
154 })
155 }
156
157 #[inline]
158 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
159 Ok(!self.is_valid(index)?)
160 }
161
162 #[inline]
163 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
164 match self {
165 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
166 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
167 }
168 }
169
170 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
171 match self {
172 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
173 AllOr::All => {
174 if indices.dtype().is_nullable() {
175 Ok(Self::AllValid)
176 } else {
177 Ok(Self::NonNullable)
178 }
179 }
180 AllOr::None => Ok(Self::AllInvalid),
181 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
182 },
183 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
184 AllOr::All => Ok(Self::AllValid),
185 AllOr::None => Ok(Self::AllInvalid),
186 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
187 },
188 Self::AllInvalid => Ok(Self::AllInvalid),
189 Self::Array(is_valid) => {
190 let maybe_is_valid = is_valid
191 .take(indices.to_array())?
192 .to_canonical()?
193 .into_array();
194 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
196 Ok(Self::Array(is_valid))
197 }
198 }
199 }
200
201 pub fn not(&self) -> VortexResult<Self> {
203 match self {
204 Validity::NonNullable => Ok(Validity::NonNullable),
205 Validity::AllValid => Ok(Validity::AllInvalid),
206 Validity::AllInvalid => Ok(Validity::AllValid),
207 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
208 }
209 }
210
211 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
219 match self {
222 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
223 Ok(v.clone())
224 }
225 Validity::Array(arr) => Ok(Validity::Array(
226 arr.filter(mask.clone())?
227 .to_canonical()?
230 .into_array(),
231 )),
232 }
233 }
234
235 #[inline]
236 pub fn to_mask(&self, length: usize) -> Mask {
237 match self {
238 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
239 Self::AllInvalid => Mask::AllFalse(length),
240 Self::Array(is_valid) => {
241 assert_eq!(
242 is_valid.len(),
243 length,
244 "Validity::Array length must equal to_logical's argument: {}, {}.",
245 is_valid.len(),
246 length,
247 );
248 is_valid.to_bool().to_mask()
249 }
250 }
251 }
252
253 #[inline]
255 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
256 Ok(match (self, rhs) {
257 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
259 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
261 (Validity::Array(a), Validity::AllValid)
263 | (Validity::Array(a), Validity::NonNullable)
264 | (Validity::NonNullable, Validity::Array(a))
265 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
266 (Validity::NonNullable, Validity::AllValid)
268 | (Validity::AllValid, Validity::NonNullable)
269 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
270 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
272 Binary
273 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
274 .optimize()?,
275 ),
276 })
277 }
278
279 pub fn patch(
280 self,
281 len: usize,
282 indices_offset: usize,
283 indices: &dyn Array,
284 patches: &Validity,
285 ) -> VortexResult<Self> {
286 match (&self, patches) {
287 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
288 (Validity::NonNullable, _) => {
289 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
290 }
291 (_, Validity::NonNullable) => {
292 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
293 }
294 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
295 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
296 _ => {}
297 };
298
299 let own_nullability = if self == Validity::NonNullable {
300 Nullability::NonNullable
301 } else {
302 Nullability::Nullable
303 };
304
305 let source = match self {
306 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
307 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
308 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
309 Validity::Array(a) => a.to_bool(),
310 };
311
312 let patch_values = match patches {
313 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
314 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
315 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
316 Validity::Array(a) => a.to_bool(),
317 };
318
319 let patches = Patches::new(
320 len,
321 indices_offset,
322 indices.to_array(),
323 patch_values.into_array(),
324 None,
326 )?;
327
328 Ok(Self::from_array(
329 source.patch(&patches)?.into_array(),
330 own_nullability,
331 ))
332 }
333
334 #[inline]
336 pub fn into_nullable(self) -> Validity {
337 match self {
338 Self::NonNullable => Self::AllValid,
339 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
340 }
341 }
342
343 #[inline]
345 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
346 match self {
347 _ if len == 0 => Some(Validity::NonNullable),
348 Self::NonNullable => Some(Self::NonNullable),
349 Self::AllValid => Some(Self::NonNullable),
350 Self::AllInvalid => None,
351 Self::Array(is_valid) => {
352 is_valid
353 .statistics()
354 .compute_min::<bool>()
355 .vortex_expect("validity array must support min")
356 .then(|| {
357 Self::NonNullable
359 })
360 }
361 }
362 }
363
364 #[inline]
366 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
367 match nullability {
368 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
369 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
370 }),
371 Nullability::Nullable => Ok(self.into_nullable()),
372 }
373 }
374
375 #[inline]
377 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
378 Ok(Validity::from_mask(
379 array.validity_mask()?,
380 array.dtype().nullability(),
381 ))
382 }
383
384 #[inline]
389 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
390 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
391 vortex_panic!("Expected a non-nullable boolean array")
392 }
393 match nullability {
394 Nullability::NonNullable => Self::NonNullable,
395 Nullability::Nullable => Self::Array(value),
396 }
397 }
398
399 #[inline]
401 pub fn maybe_len(&self) -> Option<usize> {
402 match self {
403 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
404 Self::Array(a) => Some(a.len()),
405 }
406 }
407
408 #[inline]
409 pub fn uncompressed_size(&self) -> usize {
410 if let Validity::Array(a) = self {
411 a.len().div_ceil(8)
412 } else {
413 0
414 }
415 }
416}
417
418impl PartialEq for Validity {
419 #[inline]
420 fn eq(&self, other: &Self) -> bool {
421 match (self, other) {
422 (Self::NonNullable, Self::NonNullable) => true,
423 (Self::AllValid, Self::AllValid) => true,
424 (Self::AllInvalid, Self::AllInvalid) => true,
425 (Self::Array(a), Self::Array(b)) => {
426 let a = a.to_bool();
427 let b = b.to_bool();
428 a.to_bit_buffer() == b.to_bit_buffer()
429 }
430 _ => false,
431 }
432 }
433}
434
435impl From<BitBuffer> for Validity {
436 #[inline]
437 fn from(value: BitBuffer) -> Self {
438 let true_count = value.true_count();
439 if true_count == value.len() {
440 Self::AllValid
441 } else if true_count == 0 {
442 Self::AllInvalid
443 } else {
444 Self::Array(BoolArray::from(value).into_array())
445 }
446 }
447}
448
449impl FromIterator<Mask> for Validity {
450 #[inline]
451 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
452 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
453 }
454}
455
456impl FromIterator<bool> for Validity {
457 #[inline]
458 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
459 Validity::from(BitBuffer::from_iter(iter))
460 }
461}
462
463impl From<Nullability> for Validity {
464 #[inline]
465 fn from(value: Nullability) -> Self {
466 Validity::from(&value)
467 }
468}
469
470impl From<&Nullability> for Validity {
471 #[inline]
472 fn from(value: &Nullability) -> Self {
473 match *value {
474 Nullability::NonNullable => Validity::NonNullable,
475 Nullability::Nullable => Validity::AllValid,
476 }
477 }
478}
479
480impl Validity {
481 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
482 if buffer.true_count() == buffer.len() {
483 nullability.into()
484 } else if buffer.true_count() == 0 {
485 Validity::AllInvalid
486 } else {
487 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
488 }
489 }
490
491 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
492 assert!(
493 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
494 "NonNullable validity must be AllValid",
495 );
496 match mask {
497 Mask::AllTrue(_) => match nullability {
498 Nullability::NonNullable => Validity::NonNullable,
499 Nullability::Nullable => Validity::AllValid,
500 },
501 Mask::AllFalse(_) => Validity::AllInvalid,
502 Mask::Values(values) => Validity::Array(values.into_array()),
503 }
504 }
505}
506
507impl IntoArray for Mask {
508 #[inline]
509 fn into_array(self) -> ArrayRef {
510 match self {
511 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
512 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
513 Self::Values(a) => a.into_array(),
514 }
515 }
516}
517
518impl IntoArray for &MaskValues {
519 #[inline]
520 fn into_array(self) -> ArrayRef {
521 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
522 }
523}
524
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    // Each case is (base validity, length, patch positions, patch validity, expected result).
    // Every base/patch combination of AllValid, AllInvalid and Array is covered.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]

    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Patch indices are a non-nullable u64 array built from `positions`, with offset 0.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(
            validity.patch(len, 0, &indices, &patches).unwrap(),
            expected
        );
    }

    // NOTE(review): `Validity::patch` rejects a NonNullable base combined with a nullable patch
    // before ever looking at the indices. This test therefore panics on that nullability
    // mismatch (via `unwrap`), not on the out-of-range index 4 that its name suggests.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    // `from_mask` asserts that a NonNullable validity can only come from an all-true mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // Same assertion, exercised with a mixed (explicit values) mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case is (source validity, indices array, expected result). These check how the
    // nullability and nulls of the *indices* propagate into the taken validity.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}