1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::ExecutionCtx;
22use crate::IntoArray;
23use crate::ToCanonical;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::bool::BoolArrayExt;
27use crate::arrays::scalar_fn::ScalarFnFactoryExt;
28use crate::builtins::ArrayBuiltins;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::optimizer::ArrayOptimizer;
32use crate::patches::Patches;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::Operator;
36
/// Validity (null-ness) information for the positions of an array.
///
/// A position is "valid" when it holds a non-null value.
#[derive(Clone)]
pub enum Validity {
    /// The values cannot be null; every position is valid and `nullability()` reports
    /// `Nullability::NonNullable`.
    NonNullable,
    /// The values are nullable, but every position is valid.
    AllValid,
    /// Every position is invalid (null).
    AllInvalid,
    /// Per-position validity stored as a boolean array in which `true` marks a valid
    /// position. The array is expected to be a non-nullable boolean array.
    Array(ArrayRef),
}
51
52impl Debug for Validity {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::NonNullable => write!(f, "NonNullable"),
56 Self::AllValid => write!(f, "AllValid"),
57 Self::AllInvalid => write!(f, "AllInvalid"),
58 Self::Array(arr) => write!(f, "SomeValid({})", arr.display_values()),
59 }
60 }
61}
62
63impl Validity {
64 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
66 match self {
67 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
68 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
69 }
70 }
71}
72
73impl Validity {
/// The non-nullable boolean dtype. `from_array` rejects validity arrays whose dtype is
/// not `Bool(NonNullable)`.
pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
76
77 pub fn to_array(&self, len: usize) -> ArrayRef {
79 match self {
80 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
81 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
82 Self::Array(a) => a.clone(),
83 }
84 }
85
86 #[inline]
88 pub fn into_array(self) -> Option<ArrayRef> {
89 if let Self::Array(a) = self {
90 Some(a)
91 } else {
92 None
93 }
94 }
95
96 #[inline]
98 pub fn as_array(&self) -> Option<&ArrayRef> {
99 if let Self::Array(a) = self {
100 Some(a)
101 } else {
102 None
103 }
104 }
105
106 #[inline]
107 pub fn nullability(&self) -> Nullability {
108 if matches!(self, Self::NonNullable) {
109 Nullability::NonNullable
110 } else {
111 Nullability::Nullable
112 }
113 }
114
115 #[inline]
117 pub fn union_nullability(self, nullability: Nullability) -> Self {
118 match nullability {
119 Nullability::NonNullable => self,
120 Nullability::Nullable => self.into_nullable(),
121 }
122 }
123
124 #[inline]
126 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
127 Ok(match self {
128 Self::NonNullable | Self::AllValid => true,
129 Self::AllInvalid => false,
130 Self::Array(a) => a
131 .scalar_at(index)
132 .vortex_expect("Validity array must support scalar_at")
133 .as_bool()
134 .value()
135 .vortex_expect("Validity must be non-nullable"),
136 })
137 }
138
139 #[inline]
140 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
141 Ok(!self.is_valid(index)?)
142 }
143
144 #[inline]
145 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
146 match self {
147 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
148 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
149 }
150 }
151
152 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
153 match self {
154 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
155 AllOr::All => {
156 if indices.dtype().is_nullable() {
157 Ok(Self::AllValid)
158 } else {
159 Ok(Self::NonNullable)
160 }
161 }
162 AllOr::None => Ok(Self::AllInvalid),
163 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
164 },
165 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
166 AllOr::All => Ok(Self::AllValid),
167 AllOr::None => Ok(Self::AllInvalid),
168 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
169 },
170 Self::AllInvalid => Ok(Self::AllInvalid),
171 Self::Array(is_valid) => {
172 let maybe_is_valid = is_valid.take(indices.clone())?;
173 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
175 Ok(Self::Array(is_valid))
176 }
177 }
178 }
179
180 pub fn not(&self) -> VortexResult<Self> {
182 match self {
183 Validity::NonNullable => Ok(Validity::NonNullable),
184 Validity::AllValid => Ok(Validity::AllInvalid),
185 Validity::AllInvalid => Ok(Validity::AllValid),
186 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
187 }
188 }
189
190 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
198 match self {
201 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
202 Ok(v.clone())
203 }
204 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
205 }
206 }
207
208 pub fn to_mask(&self, length: usize) -> Mask {
212 match self {
213 Self::NonNullable | Self::AllValid => Mask::new_true(length),
214 Self::AllInvalid => Mask::new_false(length),
215 Self::Array(a) => a.to_bool().to_mask(),
216 }
217 }
218
219 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
220 match self {
221 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
222 Self::AllInvalid => Ok(Mask::AllFalse(length)),
223 Self::Array(arr) => {
224 assert_eq!(
225 arr.len(),
226 length,
227 "Validity::Array length must equal to_logical's argument: {}, {}.",
228 arr.len(),
229 length,
230 );
231 arr.clone().execute::<Mask>(ctx)
234 }
235 }
236 }
237
238 pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
240 match (self, other) {
241 (Validity::NonNullable, Validity::NonNullable) => Ok(true),
242 (Validity::AllValid, Validity::AllValid) => Ok(true),
243 (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
244 (Validity::Array(a), Validity::Array(b)) => {
245 let a = a.clone().execute::<Mask>(ctx)?;
246 let b = b.clone().execute::<Mask>(ctx)?;
247 Ok(a == b)
248 }
249 _ => Ok(false),
250 }
251 }
252
253 #[inline]
255 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
256 Ok(match (self, rhs) {
257 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
259 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
261 (Validity::Array(a), Validity::AllValid)
263 | (Validity::Array(a), Validity::NonNullable)
264 | (Validity::NonNullable, Validity::Array(a))
265 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
266 (Validity::NonNullable, Validity::AllValid)
268 | (Validity::AllValid, Validity::NonNullable)
269 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
270 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
272 Binary
273 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
274 .optimize()?,
275 ),
276 })
277 }
278
279 pub fn patch(
280 self,
281 len: usize,
282 indices_offset: usize,
283 indices: &ArrayRef,
284 patches: &Validity,
285 ctx: &mut ExecutionCtx,
286 ) -> VortexResult<Self> {
287 match (&self, patches) {
288 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
289 (Validity::NonNullable, _) => {
290 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
291 }
292 (_, Validity::NonNullable) => {
293 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
294 }
295 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
296 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
297 _ => {}
298 };
299
300 let own_nullability = if matches!(self, Validity::NonNullable) {
301 Nullability::NonNullable
302 } else {
303 Nullability::Nullable
304 };
305
306 let source = match self {
307 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
308 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
309 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
310 Validity::Array(a) => a.execute::<BoolArray>(ctx)?,
311 };
312
313 let patch_values = match patches {
314 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
315 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
316 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
317 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
318 };
319
320 let patches = Patches::new(
321 len,
322 indices_offset,
323 indices.clone(),
324 patch_values.into_array(),
325 None,
327 )?;
328
329 Ok(Self::from_array(
330 source.patch(&patches, ctx)?.into_array(),
331 own_nullability,
332 ))
333 }
334
335 #[inline]
337 pub fn into_nullable(self) -> Validity {
338 match self {
339 Self::NonNullable => Self::AllValid,
340 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
341 }
342 }
343
344 #[inline]
346 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
347 match self {
348 _ if len == 0 => Some(Validity::NonNullable),
349 Self::NonNullable => Some(Self::NonNullable),
350 Self::AllValid => Some(Self::NonNullable),
351 Self::AllInvalid => None,
352 Self::Array(is_valid) => {
353 is_valid
354 .statistics()
355 .compute_min::<bool>()
356 .vortex_expect("validity array must support min")
357 .then(|| {
358 Self::NonNullable
360 })
361 }
362 }
363 }
364
365 #[inline]
367 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
368 match nullability {
369 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
370 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
371 }),
372 Nullability::Nullable => Ok(self.into_nullable()),
373 }
374 }
375
376 #[inline]
378 pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
379 Ok(Validity::from_mask(
380 array.validity_mask()?,
381 array.dtype().nullability(),
382 ))
383 }
384
385 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
390 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
391 vortex_panic!("Expected a non-nullable boolean array")
392 }
393 match nullability {
394 Nullability::NonNullable => Self::NonNullable,
395 Nullability::Nullable => Self::Array(value),
396 }
397 }
398
399 #[inline]
401 pub fn maybe_len(&self) -> Option<usize> {
402 match self {
403 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
404 Self::Array(a) => Some(a.len()),
405 }
406 }
407
408 #[inline]
409 pub fn uncompressed_size(&self) -> usize {
410 if let Validity::Array(a) = self {
411 a.len().div_ceil(8)
412 } else {
413 0
414 }
415 }
416}
417
418impl From<BitBuffer> for Validity {
419 #[inline]
420 fn from(value: BitBuffer) -> Self {
421 let true_count = value.true_count();
422 if true_count == value.len() {
423 Self::AllValid
424 } else if true_count == 0 {
425 Self::AllInvalid
426 } else {
427 Self::Array(BoolArray::from(value).into_array())
428 }
429 }
430}
431
432impl FromIterator<Mask> for Validity {
433 #[inline]
434 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
435 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
436 }
437}
438
439impl FromIterator<bool> for Validity {
440 #[inline]
441 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
442 Validity::from(BitBuffer::from_iter(iter))
443 }
444}
445
446impl From<Nullability> for Validity {
447 #[inline]
448 fn from(value: Nullability) -> Self {
449 Validity::from(&value)
450 }
451}
452
453impl From<&Nullability> for Validity {
454 #[inline]
455 fn from(value: &Nullability) -> Self {
456 match *value {
457 Nullability::NonNullable => Validity::NonNullable,
458 Nullability::Nullable => Validity::AllValid,
459 }
460 }
461}
462
463impl Validity {
464 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
465 if buffer.true_count() == buffer.len() {
466 nullability.into()
467 } else if buffer.true_count() == 0 {
468 Validity::AllInvalid
469 } else {
470 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
471 }
472 }
473
474 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
475 assert!(
476 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
477 "NonNullable validity must be AllValid",
478 );
479 match mask {
480 Mask::AllTrue(_) => match nullability {
481 Nullability::NonNullable => Validity::NonNullable,
482 Nullability::Nullable => Validity::AllValid,
483 },
484 Mask::AllFalse(_) => Validity::AllInvalid,
485 Mask::Values(values) => Validity::Array(values.into_array()),
486 }
487 }
488}
489
490impl IntoArray for Mask {
491 #[inline]
492 fn into_array(self) -> ArrayRef {
493 match self {
494 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
495 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
496 Self::Values(a) => a.into_array(),
497 }
498 }
499}
500
501impl IntoArray for &MaskValues {
502 #[inline]
503 fn into_array(self) -> ArrayRef {
504 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
505 }
506}
507
508#[cfg(test)]
509mod tests {
510 use rstest::rstest;
511 use vortex_buffer::Buffer;
512 use vortex_buffer::buffer;
513 use vortex_mask::Mask;
514
515 use crate::ArrayRef;
516 use crate::IntoArray;
517 use crate::LEGACY_SESSION;
518 use crate::VortexSessionExecute;
519 use crate::arrays::PrimitiveArray;
520 use crate::dtype::Nullability;
521 use crate::validity::BoolArray;
522 use crate::validity::Validity;
523
524 #[rstest]
525 #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
526 #[case(
527 Validity::AllValid,
528 5,
529 &[2, 4],
530 Validity::AllInvalid,
531 Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
532 )]
533 #[case(
534 Validity::AllValid,
535 5,
536 &[2, 4],
537 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
538 Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
539 )]
540 #[case(
541 Validity::AllInvalid,
542 5,
543 &[2, 4],
544 Validity::AllValid,
545 Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
546 )]
547 #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
548 #[case(
549 Validity::AllInvalid,
550 5,
551 &[2, 4],
552 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
553 Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
554 )]
555 #[case(
556 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
557 5,
558 &[2, 4],
559 Validity::AllValid,
560 Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
561 )]
562 #[case(
563 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
564 5,
565 &[2, 4],
566 Validity::AllInvalid,
567 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
568 )]
569 #[case(
570 Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
571 5,
572 &[2, 4],
573 Validity::Array(BoolArray::from_iter([true, false]).into_array()),
574 Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
575 )]
576
577 fn patch_validity(
578 #[case] validity: Validity,
579 #[case] len: usize,
580 #[case] positions: &[u64],
581 #[case] patches: Validity,
582 #[case] expected: Validity,
583 ) {
584 let indices =
585 PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
586
587 let mut ctx = LEGACY_SESSION.create_execution_ctx();
588
589 assert!(
590 validity
591 .patch(
592 len,
593 0,
594 &indices,
595 &patches,
596 &mut LEGACY_SESSION.create_execution_ctx(),
597 )
598 .unwrap()
599 .mask_eq(&expected, &mut ctx)
600 .unwrap()
601 );
602 }
603
/// Patching a two-element validity at index 4 must not succeed.
///
/// NOTE(review): the receiver is `NonNullable` and the patches are `AllInvalid`, so
/// `patch` already bails on the nullability mismatch. Confirm whether the out-of-range
/// index is what this test is meant to exercise.
#[test]
#[should_panic]
fn out_of_bounds_patch() {
    let indices = buffer![4].into_array();
    let mut ctx = LEGACY_SESSION.create_execution_ctx();

    let outcome = Validity::NonNullable.patch(2, 0, &indices, &Validity::AllInvalid, &mut ctx);
    outcome.unwrap();
}
617
/// An all-false mask cannot be turned into a non-nullable validity; `from_mask` must
/// panic.
#[test]
#[should_panic]
fn into_validity_nullable() {
    let all_false = Mask::AllFalse(10);
    let _ = Validity::from_mask(all_false, Nullability::NonNullable);
}
623
/// A mask with mixed values cannot be turned into a non-nullable validity; `from_mask`
/// must panic.
#[test]
#[should_panic]
fn into_validity_nullable_array() {
    let mixed = Mask::from_iter(vec![true, false]);
    let _ = Validity::from_mask(mixed, Nullability::NonNullable);
}
629
630 #[rstest]
631 #[case(
632 Validity::AllValid,
633 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
634 Validity::from_iter(vec![true, false])
635 )]
636 #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
637 #[case(
638 Validity::AllValid,
639 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
640 Validity::AllInvalid
641 )]
642 #[case(
643 Validity::NonNullable,
644 PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
645 Validity::from_iter(vec![true, false])
646 )]
647 #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
648 #[case(
649 Validity::NonNullable,
650 PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
651 Validity::AllInvalid
652 )]
653 fn validity_take(
654 #[case] validity: Validity,
655 #[case] indices: ArrayRef,
656 #[case] expected: Validity,
657 ) {
658 let mut ctx = LEGACY_SESSION.create_execution_ctx();
659 assert!(
660 validity
661 .take(&indices)
662 .unwrap()
663 .mask_eq(&expected, &mut ctx)
664 .unwrap()
665 );
666 }
667}