1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::scalar_fn::ScalarFnArrayExt;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::Nullability;
30use crate::optimizer::ArrayOptimizer;
31use crate::patches::Patches;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::Operator;
35
/// Describes which elements of an array are valid (non-null).
#[derive(Clone)]
pub enum Validity {
    /// The array's dtype is non-nullable, so every element is valid.
    NonNullable,
    /// The array is nullable, but every element is valid.
    AllValid,
    /// The array is nullable and every element is invalid (null).
    AllInvalid,
    /// Per-element validity held in a boolean array, where `true` marks a valid element.
    /// The array is expected to be a non-nullable boolean array.
    Array(ArrayRef),
}
50
51impl Debug for Validity {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 Self::NonNullable => write!(f, "NonNullable"),
55 Self::AllValid => write!(f, "AllValid"),
56 Self::AllInvalid => write!(f, "AllInvalid"),
57 Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
58 }
59 }
60}
61
62impl Validity {
63 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
65 match self {
66 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
67 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
68 }
69 }
70}
71
72impl Validity {
73 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
75
76 pub fn to_array(&self, len: usize) -> ArrayRef {
78 match self {
79 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
80 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
81 Self::Array(a) => a.clone(),
82 }
83 }
84
85 #[inline]
87 pub fn into_array(self) -> Option<ArrayRef> {
88 if let Self::Array(a) = self {
89 Some(a)
90 } else {
91 None
92 }
93 }
94
95 #[inline]
97 pub fn as_array(&self) -> Option<&ArrayRef> {
98 if let Self::Array(a) = self {
99 Some(a)
100 } else {
101 None
102 }
103 }
104
105 #[inline]
106 pub fn nullability(&self) -> Nullability {
107 if matches!(self, Self::NonNullable) {
108 Nullability::NonNullable
109 } else {
110 Nullability::Nullable
111 }
112 }
113
114 #[inline]
116 pub fn union_nullability(self, nullability: Nullability) -> Self {
117 match nullability {
118 Nullability::NonNullable => self,
119 Nullability::Nullable => self.into_nullable(),
120 }
121 }
122
123 #[inline]
125 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
126 Ok(match self {
127 Self::NonNullable | Self::AllValid => true,
128 Self::AllInvalid => false,
129 Self::Array(a) => a
130 .scalar_at(index)
131 .vortex_expect("Validity array must support scalar_at")
132 .as_bool()
133 .value()
134 .vortex_expect("Validity must be non-nullable"),
135 })
136 }
137
138 #[inline]
139 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
140 Ok(!self.is_valid(index)?)
141 }
142
143 #[inline]
144 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
145 match self {
146 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
147 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
148 }
149 }
150
151 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
152 match self {
153 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
154 AllOr::All => {
155 if indices.dtype().is_nullable() {
156 Ok(Self::AllValid)
157 } else {
158 Ok(Self::NonNullable)
159 }
160 }
161 AllOr::None => Ok(Self::AllInvalid),
162 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163 },
164 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
165 AllOr::All => Ok(Self::AllValid),
166 AllOr::None => Ok(Self::AllInvalid),
167 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
168 },
169 Self::AllInvalid => Ok(Self::AllInvalid),
170 Self::Array(is_valid) => {
171 let maybe_is_valid = is_valid.take(indices.to_array())?;
172 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
174 Ok(Self::Array(is_valid))
175 }
176 }
177 }
178
179 pub fn not(&self) -> VortexResult<Self> {
181 match self {
182 Validity::NonNullable => Ok(Validity::NonNullable),
183 Validity::AllValid => Ok(Validity::AllInvalid),
184 Validity::AllInvalid => Ok(Validity::AllValid),
185 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
186 }
187 }
188
189 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
197 match self {
200 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
201 Ok(v.clone())
202 }
203 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
204 }
205 }
206
207 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
208 match self {
209 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
210 Self::AllInvalid => Ok(Mask::AllFalse(length)),
211 Self::Array(arr) => {
212 assert_eq!(
213 arr.len(),
214 length,
215 "Validity::Array length must equal to_logical's argument: {}, {}.",
216 arr.len(),
217 length,
218 );
219 arr.clone().execute::<Mask>(ctx)
222 }
223 }
224 }
225
226 pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
228 match (self, other) {
229 (Validity::NonNullable, Validity::NonNullable) => Ok(true),
230 (Validity::AllValid, Validity::AllValid) => Ok(true),
231 (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
232 (Validity::Array(a), Validity::Array(b)) => {
233 let a = a.clone().execute::<Mask>(ctx)?;
234 let b = b.clone().execute::<Mask>(ctx)?;
235 Ok(a == b)
236 }
237 _ => Ok(false),
238 }
239 }
240
241 #[inline]
243 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
244 Ok(match (self, rhs) {
245 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
247 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
249 (Validity::Array(a), Validity::AllValid)
251 | (Validity::Array(a), Validity::NonNullable)
252 | (Validity::NonNullable, Validity::Array(a))
253 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
254 (Validity::NonNullable, Validity::AllValid)
256 | (Validity::AllValid, Validity::NonNullable)
257 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
258 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
260 Binary
261 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
262 .optimize()?,
263 ),
264 })
265 }
266
267 pub fn patch(
268 self,
269 len: usize,
270 indices_offset: usize,
271 indices: &ArrayRef,
272 patches: &Validity,
273 ctx: &mut ExecutionCtx,
274 ) -> VortexResult<Self> {
275 match (&self, patches) {
276 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
277 (Validity::NonNullable, _) => {
278 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
279 }
280 (_, Validity::NonNullable) => {
281 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
282 }
283 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
284 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
285 _ => {}
286 };
287
288 let own_nullability = if matches!(self, Validity::NonNullable) {
289 Nullability::NonNullable
290 } else {
291 Nullability::Nullable
292 };
293
294 let source = match self {
295 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
299 };
300
301 let patch_values = match patches {
302 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
306 };
307
308 let patches = Patches::new(
309 len,
310 indices_offset,
311 indices.to_array(),
312 patch_values.into_array(),
313 None,
315 )?;
316
317 Ok(Self::from_array(
318 source.patch(&patches, ctx)?.into_array(),
319 own_nullability,
320 ))
321 }
322
323 #[inline]
325 pub fn into_nullable(self) -> Validity {
326 match self {
327 Self::NonNullable => Self::AllValid,
328 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
329 }
330 }
331
332 #[inline]
334 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
335 match self {
336 _ if len == 0 => Some(Validity::NonNullable),
337 Self::NonNullable => Some(Self::NonNullable),
338 Self::AllValid => Some(Self::NonNullable),
339 Self::AllInvalid => None,
340 Self::Array(is_valid) => {
341 is_valid
342 .statistics()
343 .compute_min::<bool>()
344 .vortex_expect("validity array must support min")
345 .then(|| {
346 Self::NonNullable
348 })
349 }
350 }
351 }
352
353 #[inline]
355 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
356 match nullability {
357 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
358 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
359 }),
360 Nullability::Nullable => Ok(self.into_nullable()),
361 }
362 }
363
364 #[inline]
366 pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
367 Ok(Validity::from_mask(
368 array.validity_mask()?,
369 array.dtype().nullability(),
370 ))
371 }
372
373 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
378 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
379 vortex_panic!("Expected a non-nullable boolean array")
380 }
381 match nullability {
382 Nullability::NonNullable => Self::NonNullable,
383 Nullability::Nullable => Self::Array(value),
384 }
385 }
386
387 #[inline]
389 pub fn maybe_len(&self) -> Option<usize> {
390 match self {
391 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
392 Self::Array(a) => Some(a.len()),
393 }
394 }
395
396 #[inline]
397 pub fn uncompressed_size(&self) -> usize {
398 if let Validity::Array(a) = self {
399 a.len().div_ceil(8)
400 } else {
401 0
402 }
403 }
404}
405
406impl From<BitBuffer> for Validity {
407 #[inline]
408 fn from(value: BitBuffer) -> Self {
409 let true_count = value.true_count();
410 if true_count == value.len() {
411 Self::AllValid
412 } else if true_count == 0 {
413 Self::AllInvalid
414 } else {
415 Self::Array(BoolArray::from(value).into_array())
416 }
417 }
418}
419
420impl FromIterator<Mask> for Validity {
421 #[inline]
422 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
423 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
424 }
425}
426
427impl FromIterator<bool> for Validity {
428 #[inline]
429 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
430 Validity::from(BitBuffer::from_iter(iter))
431 }
432}
433
434impl From<Nullability> for Validity {
435 #[inline]
436 fn from(value: Nullability) -> Self {
437 Validity::from(&value)
438 }
439}
440
441impl From<&Nullability> for Validity {
442 #[inline]
443 fn from(value: &Nullability) -> Self {
444 match *value {
445 Nullability::NonNullable => Validity::NonNullable,
446 Nullability::Nullable => Validity::AllValid,
447 }
448 }
449}
450
451impl Validity {
452 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
453 if buffer.true_count() == buffer.len() {
454 nullability.into()
455 } else if buffer.true_count() == 0 {
456 Validity::AllInvalid
457 } else {
458 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
459 }
460 }
461
462 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
463 assert!(
464 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
465 "NonNullable validity must be AllValid",
466 );
467 match mask {
468 Mask::AllTrue(_) => match nullability {
469 Nullability::NonNullable => Validity::NonNullable,
470 Nullability::Nullable => Validity::AllValid,
471 },
472 Mask::AllFalse(_) => Validity::AllInvalid,
473 Mask::Values(values) => Validity::Array(values.into_array()),
474 }
475 }
476}
477
478impl IntoArray for Mask {
479 #[inline]
480 fn into_array(self) -> ArrayRef {
481 match self {
482 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
483 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
484 Self::Values(a) => a.into_array(),
485 }
486 }
487}
488
489impl IntoArray for &MaskValues {
490 #[inline]
491 fn into_array(self) -> ArrayRef {
492 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
493 }
494}
495
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    // Each case: (base validity, length, patch positions, patch validity, expected result).
    // Covers every pairing of AllValid / AllInvalid / Array for base and patch.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]

    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Patch positions as a non-nullable u64 index array.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // NOTE(review): `patch` receives a freshly created execution context while `ctx` is
        // only used for the comparison; a single context would likely suffice.
        assert!(
            validity
                .patch(
                    len,
                    0,
                    &indices,
                    &patches,
                    &mut LEGACY_SESSION.create_execution_ctx(),
                )
                .unwrap()
                .mask_eq(&expected, &mut ctx)
                .unwrap()
        );
    }

    // Index 4 is out of bounds for length 2, and the `unwrap` is expected to panic.
    // NOTE(review): `patch` bails on the NonNullable-vs-AllInvalid mismatch before any bounds
    // check, so this likely exercises the nullability guard rather than the bounds path.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(
                2,
                0,
                &buffer![4].into_array(),
                &Validity::AllInvalid,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
    }

    // `from_mask` must reject a non-AllTrue mask when the requested nullability is
    // non-nullable.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // Same as above, but with a mixed (values-backed) mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case: (source validity, take indices, expected validity). The index arrays vary
    // between partially-null, non-nullable, and all-null.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(
            validity
                .take(&indices)
                .unwrap()
                .mask_eq(&expected, &mut ctx)
                .unwrap()
        );
    }
}
655}