1use std::fmt::Debug;
7use std::ops::Range;
8
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect as _;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_mask::AllOr;
16use vortex_mask::Mask;
17use vortex_mask::MaskValues;
18
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::DynArray;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::arrays::BoolArray;
25use crate::arrays::ConstantArray;
26use crate::arrays::scalar_fn::ScalarFnArrayExt;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::Nullability;
30use crate::optimizer::ArrayOptimizer;
31use crate::patches::Patches;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::Operator;
35
/// Validity (null-ness) information for an array.
///
/// A position is *valid* when it holds a non-null value and *invalid* when it is null.
#[derive(Clone)]
pub enum Validity {
    /// The array's type is non-nullable, so every position is valid.
    NonNullable,
    /// The array's type is nullable, but every position is valid.
    AllValid,
    /// Every position is invalid (null).
    AllInvalid,
    /// Per-position validity held in a boolean array where `true` marks a valid position.
    ///
    /// The array is expected to be non-nullable (see [`Validity::DTYPE`]).
    Array(ArrayRef),
}
50
51impl Debug for Validity {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 match self {
54 Self::NonNullable => write!(f, "NonNullable"),
55 Self::AllValid => write!(f, "AllValid"),
56 Self::AllInvalid => write!(f, "AllInvalid"),
57 Self::Array(arr) => write!(f, "SomeValid({})", arr.as_ref().display_values()),
58 }
59 }
60}
61
62impl Validity {
63 pub fn execute(self, ctx: &mut ExecutionCtx) -> VortexResult<Validity> {
65 match self {
66 v @ Validity::NonNullable | v @ Validity::AllValid | v @ Validity::AllInvalid => Ok(v),
67 Validity::Array(a) => Ok(Validity::Array(a.execute::<Canonical>(ctx)?.into_array())),
68 }
69 }
70}
71
72impl Validity {
73 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
75
76 pub fn to_array(&self, len: usize) -> ArrayRef {
78 match self {
79 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
80 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
81 Self::Array(a) => a.clone(),
82 }
83 }
84
85 #[inline]
87 pub fn into_array(self) -> Option<ArrayRef> {
88 if let Self::Array(a) = self {
89 Some(a)
90 } else {
91 None
92 }
93 }
94
95 #[inline]
97 pub fn as_array(&self) -> Option<&ArrayRef> {
98 if let Self::Array(a) = self {
99 Some(a)
100 } else {
101 None
102 }
103 }
104
105 #[inline]
106 pub fn nullability(&self) -> Nullability {
107 if matches!(self, Self::NonNullable) {
108 Nullability::NonNullable
109 } else {
110 Nullability::Nullable
111 }
112 }
113
114 #[inline]
116 pub fn union_nullability(self, nullability: Nullability) -> Self {
117 match nullability {
118 Nullability::NonNullable => self,
119 Nullability::Nullable => self.into_nullable(),
120 }
121 }
122
123 #[inline]
125 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
126 Ok(match self {
127 Self::NonNullable | Self::AllValid => true,
128 Self::AllInvalid => false,
129 Self::Array(a) => a
130 .scalar_at(index)
131 .vortex_expect("Validity array must support scalar_at")
132 .as_bool()
133 .value()
134 .vortex_expect("Validity must be non-nullable"),
135 })
136 }
137
138 #[inline]
139 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
140 Ok(!self.is_valid(index)?)
141 }
142
143 #[inline]
144 pub fn slice(&self, range: Range<usize>) -> VortexResult<Self> {
145 match self {
146 Self::Array(a) => Ok(Self::Array(a.slice(range)?)),
147 Self::NonNullable | Self::AllValid | Self::AllInvalid => Ok(self.clone()),
148 }
149 }
150
151 pub fn take(&self, indices: &ArrayRef) -> VortexResult<Self> {
152 match self {
153 Self::NonNullable => match indices.validity_mask()?.bit_buffer() {
154 AllOr::All => {
155 if indices.dtype().is_nullable() {
156 Ok(Self::AllValid)
157 } else {
158 Ok(Self::NonNullable)
159 }
160 }
161 AllOr::None => Ok(Self::AllInvalid),
162 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
163 },
164 Self::AllValid => match indices.validity_mask()?.bit_buffer() {
165 AllOr::All => Ok(Self::AllValid),
166 AllOr::None => Ok(Self::AllInvalid),
167 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
168 },
169 Self::AllInvalid => Ok(Self::AllInvalid),
170 Self::Array(is_valid) => {
171 let maybe_is_valid = is_valid.take(indices.to_array())?;
172 let is_valid = maybe_is_valid.fill_null(Scalar::from(false))?;
174 Ok(Self::Array(is_valid))
175 }
176 }
177 }
178
179 pub fn not(&self) -> VortexResult<Self> {
181 match self {
182 Validity::NonNullable => Ok(Validity::NonNullable),
183 Validity::AllValid => Ok(Validity::AllInvalid),
184 Validity::AllInvalid => Ok(Validity::AllValid),
185 Validity::Array(arr) => Ok(Validity::Array(arr.not()?)),
186 }
187 }
188
189 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
197 match self {
200 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
201 Ok(v.clone())
202 }
203 Validity::Array(arr) => Ok(Validity::Array(arr.filter(mask.clone())?)),
204 }
205 }
206
207 pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
208 match self {
209 Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
210 Self::AllInvalid => Ok(Mask::AllFalse(length)),
211 Self::Array(arr) => {
212 assert_eq!(
213 arr.len(),
214 length,
215 "Validity::Array length must equal to_logical's argument: {}, {}.",
216 arr.len(),
217 length,
218 );
219 arr.clone().execute::<Mask>(ctx)
222 }
223 }
224 }
225
226 pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
228 match (self, other) {
229 (Validity::NonNullable, Validity::NonNullable) => Ok(true),
230 (Validity::AllValid, Validity::AllValid) => Ok(true),
231 (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
232 (Validity::Array(a), Validity::Array(b)) => {
233 let a = a.clone().execute::<Mask>(ctx)?;
234 let b = b.clone().execute::<Mask>(ctx)?;
235 Ok(a == b)
236 }
237 _ => Ok(false),
238 }
239 }
240
241 #[inline]
243 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
244 Ok(match (self, rhs) {
245 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
247 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
249 (Validity::Array(a), Validity::AllValid)
251 | (Validity::Array(a), Validity::NonNullable)
252 | (Validity::NonNullable, Validity::Array(a))
253 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
254 (Validity::NonNullable, Validity::AllValid)
256 | (Validity::AllValid, Validity::NonNullable)
257 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
258 (Validity::Array(lhs), Validity::Array(rhs)) => Validity::Array(
260 Binary
261 .try_new_array(lhs.len(), Operator::And, [lhs, rhs])?
262 .optimize()?,
263 ),
264 })
265 }
266
267 pub fn patch(
268 self,
269 len: usize,
270 indices_offset: usize,
271 indices: &ArrayRef,
272 patches: &Validity,
273 ctx: &mut ExecutionCtx,
274 ) -> VortexResult<Self> {
275 match (&self, patches) {
276 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
277 (Validity::NonNullable, _) => {
278 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
279 }
280 (_, Validity::NonNullable) => {
281 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
282 }
283 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
284 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
285 _ => {}
286 };
287
288 let own_nullability = if matches!(self, Validity::NonNullable) {
289 Nullability::NonNullable
290 } else {
291 Nullability::Nullable
292 };
293
294 let source = match self {
295 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
299 };
300
301 let patch_values = match patches {
302 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305 Validity::Array(a) => a.clone().execute::<BoolArray>(ctx)?,
306 };
307
308 let patches = Patches::new(
309 len,
310 indices_offset,
311 indices.to_array(),
312 patch_values.into_array(),
313 None,
315 )?;
316
317 Ok(Self::from_array(
318 source.patch(&patches, ctx)?.into_array(),
319 own_nullability,
320 ))
321 }
322
323 #[inline]
325 pub fn into_nullable(self) -> Validity {
326 match self {
327 Self::NonNullable => Self::AllValid,
328 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
329 }
330 }
331
332 #[inline]
334 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
335 match self {
336 _ if len == 0 => Some(Validity::NonNullable),
337 Self::NonNullable => Some(Self::NonNullable),
338 Self::AllValid => Some(Self::NonNullable),
339 Self::AllInvalid => None,
340 Self::Array(is_valid) => {
341 is_valid
342 .statistics()
343 .compute_min::<bool>()
344 .vortex_expect("validity array must support min")
345 .then(|| {
346 Self::NonNullable
348 })
349 }
350 }
351 }
352
353 #[inline]
355 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
356 match nullability {
357 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
358 vortex_err!(InvalidArgument: "Cannot cast array with invalid values to non-nullable type.")
359 }),
360 Nullability::Nullable => Ok(self.into_nullable()),
361 }
362 }
363
364 #[inline]
366 pub fn copy_from_array(array: &ArrayRef) -> VortexResult<Self> {
367 Ok(Validity::from_mask(
368 array.validity_mask()?,
369 array.dtype().nullability(),
370 ))
371 }
372
373 #[inline]
378 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
379 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
380 vortex_panic!("Expected a non-nullable boolean array")
381 }
382 match nullability {
383 Nullability::NonNullable => Self::NonNullable,
384 Nullability::Nullable => Self::Array(value),
385 }
386 }
387
388 #[inline]
390 pub fn maybe_len(&self) -> Option<usize> {
391 match self {
392 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
393 Self::Array(a) => Some(a.len()),
394 }
395 }
396
397 #[inline]
398 pub fn uncompressed_size(&self) -> usize {
399 if let Validity::Array(a) = self {
400 a.len().div_ceil(8)
401 } else {
402 0
403 }
404 }
405}
406
407impl From<BitBuffer> for Validity {
408 #[inline]
409 fn from(value: BitBuffer) -> Self {
410 let true_count = value.true_count();
411 if true_count == value.len() {
412 Self::AllValid
413 } else if true_count == 0 {
414 Self::AllInvalid
415 } else {
416 Self::Array(BoolArray::from(value).into_array())
417 }
418 }
419}
420
421impl FromIterator<Mask> for Validity {
422 #[inline]
423 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
424 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
425 }
426}
427
428impl FromIterator<bool> for Validity {
429 #[inline]
430 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
431 Validity::from(BitBuffer::from_iter(iter))
432 }
433}
434
435impl From<Nullability> for Validity {
436 #[inline]
437 fn from(value: Nullability) -> Self {
438 Validity::from(&value)
439 }
440}
441
442impl From<&Nullability> for Validity {
443 #[inline]
444 fn from(value: &Nullability) -> Self {
445 match *value {
446 Nullability::NonNullable => Validity::NonNullable,
447 Nullability::Nullable => Validity::AllValid,
448 }
449 }
450}
451
452impl Validity {
453 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
454 if buffer.true_count() == buffer.len() {
455 nullability.into()
456 } else if buffer.true_count() == 0 {
457 Validity::AllInvalid
458 } else {
459 Validity::Array(BoolArray::new(buffer, Validity::NonNullable).into_array())
460 }
461 }
462
463 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
464 assert!(
465 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
466 "NonNullable validity must be AllValid",
467 );
468 match mask {
469 Mask::AllTrue(_) => match nullability {
470 Nullability::NonNullable => Validity::NonNullable,
471 Nullability::Nullable => Validity::AllValid,
472 },
473 Mask::AllFalse(_) => Validity::AllInvalid,
474 Mask::Values(values) => Validity::Array(values.into_array()),
475 }
476 }
477}
478
479impl IntoArray for Mask {
480 #[inline]
481 fn into_array(self) -> ArrayRef {
482 match self {
483 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
484 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
485 Self::Values(a) => a.into_array(),
486 }
487 }
488}
489
490impl IntoArray for &MaskValues {
491 #[inline]
492 fn into_array(self) -> ArrayRef {
493 BoolArray::new(self.bit_buffer().clone(), Validity::NonNullable).into_array()
494 }
495}
496
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::Nullability;
    use crate::validity::BoolArray;
    use crate::validity::Validity;

    /// Patching positions 2 and 4 of a 5-element validity with every combination of
    /// constant and array-backed source and patch validities.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllValid,
        Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::AllInvalid,
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(
        Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
        5,
        &[2, 4],
        Validity::Array(BoolArray::from_iter([true, false]).into_array()),
        Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();

        // A single execution context serves both the patch and the comparison.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let patched = validity
            .patch(len, 0, &indices, &patches, &mut ctx)
            .unwrap();
        assert!(patched.mask_eq(&expected, &mut ctx).unwrap());
    }

    /// NOTE(review): despite the name, this panics because `patch` bails when a
    /// non-nullable validity is patched with a nullable one. That check runs before the
    /// out-of-range index 4 (for `len == 2`) is ever examined.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(
                2,
                0,
                &buffer![4].into_array(),
                &Validity::AllInvalid,
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
    }

    /// A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    /// A non-nullable validity cannot be built from a mask containing false values.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    /// Taking from a constant validity inherits the validity of the indices.
    #[rstest]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(
        Validity::AllValid,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(
        Validity::NonNullable,
        PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(),
        Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert!(
            validity
                .take(&indices)
                .unwrap()
                .mask_eq(&expected, &mut ctx)
                .unwrap()
        );
    }
}