1use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::Range;
9
10use vortex_buffer::BitBuffer;
11use vortex_dtype::DType;
12use vortex_dtype::Nullability;
13use vortex_error::VortexExpect as _;
14use vortex_error::VortexResult;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20use vortex_scalar::Scalar;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::ToCanonical;
26use crate::arrays::BoolArray;
27use crate::arrays::ConstantArray;
28use crate::compute::fill_null;
29use crate::compute::filter;
30use crate::compute::sum;
31use crate::compute::take;
32use crate::patches::Patches;
33
34#[derive(Clone, Debug)]
36pub enum Validity {
37 NonNullable,
39 AllValid,
41 AllInvalid,
43 Array(ArrayRef),
47}
48
49impl Validity {
50 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
52
53 #[inline]
55 pub fn into_array(self) -> Option<ArrayRef> {
56 if let Self::Array(a) = self {
57 Some(a)
58 } else {
59 None
60 }
61 }
62
63 #[inline]
65 pub fn as_array(&self) -> Option<&ArrayRef> {
66 if let Self::Array(a) = self {
67 Some(a)
68 } else {
69 None
70 }
71 }
72
73 #[inline]
74 pub fn nullability(&self) -> Nullability {
75 if matches!(self, Self::NonNullable) {
76 Nullability::NonNullable
77 } else {
78 Nullability::Nullable
79 }
80 }
81
82 #[inline]
84 pub fn union_nullability(self, nullability: Nullability) -> Self {
85 match nullability {
86 Nullability::NonNullable => self,
87 Nullability::Nullable => self.into_nullable(),
88 }
89 }
90
91 #[inline]
92 pub fn all_valid(&self, len: usize) -> bool {
93 match self {
94 _ if len == 0 => true,
95 Validity::NonNullable | Validity::AllValid => true,
96 Validity::AllInvalid => false,
97 Validity::Array(array) => {
98 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
99 .vortex_expect("sum must be a usize")
100 == array.len()
101 }
102 }
103 }
104
105 #[inline]
106 pub fn all_invalid(&self, len: usize) -> bool {
107 match self {
108 _ if len == 0 => true,
109 Validity::NonNullable | Validity::AllValid => false,
110 Validity::AllInvalid => true,
111 Validity::Array(array) => {
112 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
113 .vortex_expect("sum must be a usize")
114 == 0
115 }
116 }
117 }
118
119 #[inline]
121 pub fn is_valid(&self, index: usize) -> bool {
122 match self {
123 Self::NonNullable | Self::AllValid => true,
124 Self::AllInvalid => false,
125 Self::Array(a) => {
126 let scalar = a.scalar_at(index);
127 scalar
128 .as_bool()
129 .value()
130 .vortex_expect("Validity must be non-nullable")
131 }
132 }
133 }
134
135 #[inline]
136 pub fn is_null(&self, index: usize) -> bool {
137 !self.is_valid(index)
138 }
139
140 #[inline]
141 pub fn slice(&self, range: Range<usize>) -> Self {
142 match self {
143 Self::Array(a) => Self::Array(a.slice(range)),
144 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
145 }
146 }
147
148 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
149 match self {
150 Self::NonNullable => match indices.validity_mask().bit_buffer() {
151 AllOr::All => {
152 if indices.dtype().is_nullable() {
153 Ok(Self::AllValid)
154 } else {
155 Ok(Self::NonNullable)
156 }
157 }
158 AllOr::None => Ok(Self::AllInvalid),
159 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
160 },
161 Self::AllValid => match indices.validity_mask().bit_buffer() {
162 AllOr::All => Ok(Self::AllValid),
163 AllOr::None => Ok(Self::AllInvalid),
164 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
165 },
166 Self::AllInvalid => Ok(Self::AllInvalid),
167 Self::Array(is_valid) => {
168 let maybe_is_valid = take(is_valid, indices)?;
169 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
171 Ok(Self::Array(is_valid))
172 }
173 }
174 }
175
176 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
180 match self {
183 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
184 Ok(v.clone())
185 }
186 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
187 }
188 }
189
190 #[inline]
194 pub fn mask(&self, mask: &Mask) -> Self {
195 match mask.bit_buffer() {
196 AllOr::All => Validity::AllInvalid,
197 AllOr::None => self.clone().into_nullable(),
198 AllOr::Some(make_invalid) => match self {
199 Validity::NonNullable | Validity::AllValid => {
200 Validity::Array(BoolArray::from(!make_invalid).into_array())
201 }
202 Validity::AllInvalid => Validity::AllInvalid,
203 Validity::Array(is_valid) => {
204 let is_valid = is_valid.to_bool();
205 Validity::from(is_valid.bit_buffer() & !make_invalid)
206 }
207 },
208 }
209 }
210
211 #[inline]
212 pub fn to_mask(&self, length: usize) -> Mask {
213 match self {
214 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
215 Self::AllInvalid => Mask::AllFalse(length),
216 Self::Array(is_valid) => {
217 assert_eq!(
218 is_valid.len(),
219 length,
220 "Validity::Array length must equal to_logical's argument: {}, {}.",
221 is_valid.len(),
222 length,
223 );
224 is_valid.to_bool().to_mask()
225 }
226 }
227 }
228
229 #[inline]
231 pub fn and(self, rhs: Validity) -> Validity {
232 match (self, rhs) {
233 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
235 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
237 (Validity::Array(a), Validity::AllValid)
239 | (Validity::Array(a), Validity::NonNullable)
240 | (Validity::NonNullable, Validity::Array(a))
241 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
242 (Validity::NonNullable, Validity::AllValid)
244 | (Validity::AllValid, Validity::NonNullable)
245 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
246 (Validity::Array(lhs), Validity::Array(rhs)) => {
248 let lhs = lhs.to_bool();
249 let rhs = rhs.to_bool();
250
251 let lhs = lhs.bit_buffer();
252 let rhs = rhs.bit_buffer();
253
254 Validity::from(lhs.bitand(rhs))
255 }
256 }
257 }
258
259 pub fn patch(
260 self,
261 len: usize,
262 indices_offset: usize,
263 indices: &dyn Array,
264 patches: &Validity,
265 ) -> Self {
266 match (&self, patches) {
267 (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
268 (Validity::NonNullable, _) => {
269 vortex_panic!("Can't patch a non-nullable validity with nullable validity")
270 }
271 (_, Validity::NonNullable) => {
272 vortex_panic!("Can't patch a nullable validity with non-nullable validity")
273 }
274 (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
275 (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
276 _ => {}
277 };
278
279 let own_nullability = if self == Validity::NonNullable {
280 Nullability::NonNullable
281 } else {
282 Nullability::Nullable
283 };
284
285 let source = match self {
286 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
287 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
288 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
289 Validity::Array(a) => a.to_bool(),
290 };
291
292 let patch_values = match patches {
293 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
294 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
295 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
296 Validity::Array(a) => a.to_bool(),
297 };
298
299 let patches = Patches::new(
300 len,
301 indices_offset,
302 indices.to_array(),
303 patch_values.into_array(),
304 None,
306 );
307
308 Self::from_array(source.patch(&patches).into_array(), own_nullability)
309 }
310
311 #[inline]
313 pub fn into_nullable(self) -> Validity {
314 match self {
315 Self::NonNullable => Self::AllValid,
316 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
317 }
318 }
319
320 #[inline]
322 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
323 match self {
324 _ if len == 0 => Some(Validity::NonNullable),
325 Self::NonNullable => Some(Self::NonNullable),
326 Self::AllValid => Some(Self::NonNullable),
327 Self::AllInvalid => None,
328 Self::Array(is_valid) => {
329 is_valid
330 .statistics()
331 .compute_min::<bool>()
332 .vortex_expect("validity array must support min")
333 .then(|| {
334 Self::NonNullable
336 })
337 }
338 }
339 }
340
341 #[inline]
343 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
344 match nullability {
345 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
346 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
347 }),
348 Nullability::Nullable => Ok(self.into_nullable()),
349 }
350 }
351
352 #[inline]
354 pub fn copy_from_array(array: &dyn Array) -> Self {
355 Validity::from_mask(array.validity_mask(), array.dtype().nullability())
356 }
357
358 #[inline]
363 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
364 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
365 vortex_panic!("Expected a non-nullable boolean array")
366 }
367 match nullability {
368 Nullability::NonNullable => Self::NonNullable,
369 Nullability::Nullable => Self::Array(value),
370 }
371 }
372
373 #[inline]
375 pub fn maybe_len(&self) -> Option<usize> {
376 match self {
377 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
378 Self::Array(a) => Some(a.len()),
379 }
380 }
381
382 #[inline]
383 pub fn uncompressed_size(&self) -> usize {
384 if let Validity::Array(a) = self {
385 a.len().div_ceil(8)
386 } else {
387 0
388 }
389 }
390}
391
392impl PartialEq for Validity {
393 #[inline]
394 fn eq(&self, other: &Self) -> bool {
395 match (self, other) {
396 (Self::NonNullable, Self::NonNullable) => true,
397 (Self::AllValid, Self::AllValid) => true,
398 (Self::AllInvalid, Self::AllInvalid) => true,
399 (Self::Array(a), Self::Array(b)) => {
400 let a = a.to_bool();
401 let b = b.to_bool();
402 a.bit_buffer() == b.bit_buffer()
403 }
404 _ => false,
405 }
406 }
407}
408
409impl From<BitBuffer> for Validity {
410 #[inline]
411 fn from(value: BitBuffer) -> Self {
412 let true_count = value.true_count();
413 if true_count == value.len() {
414 Self::AllValid
415 } else if true_count == 0 {
416 Self::AllInvalid
417 } else {
418 Self::Array(BoolArray::from(value).into_array())
419 }
420 }
421}
422
423impl FromIterator<Mask> for Validity {
424 #[inline]
425 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
426 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
427 }
428}
429
430impl FromIterator<bool> for Validity {
431 #[inline]
432 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
433 Validity::from(BitBuffer::from_iter(iter))
434 }
435}
436
437impl From<Nullability> for Validity {
438 #[inline]
439 fn from(value: Nullability) -> Self {
440 match value {
441 Nullability::NonNullable => Validity::NonNullable,
442 Nullability::Nullable => Validity::AllValid,
443 }
444 }
445}
446
447impl Validity {
448 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
449 if buffer.true_count() == buffer.len() {
450 nullability.into()
451 } else if buffer.true_count() == 0 {
452 Validity::AllInvalid
453 } else {
454 Validity::Array(BoolArray::from_bit_buffer(buffer, Validity::NonNullable).into_array())
455 }
456 }
457
458 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
459 assert!(
460 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
461 "NonNullable validity must be AllValid",
462 );
463 match mask {
464 Mask::AllTrue(_) => match nullability {
465 Nullability::NonNullable => Validity::NonNullable,
466 Nullability::Nullable => Validity::AllValid,
467 },
468 Mask::AllFalse(_) => Validity::AllInvalid,
469 Mask::Values(values) => Validity::Array(values.into_array()),
470 }
471 }
472}
473
474impl IntoArray for Mask {
475 #[inline]
476 fn into_array(self) -> ArrayRef {
477 match self {
478 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
479 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
480 Self::Values(a) => a.into_array(),
481 }
482 }
483}
484
485impl IntoArray for &MaskValues {
486 #[inline]
487 fn into_array(self) -> ArrayRef {
488 BoolArray::from_bit_buffer(self.bit_buffer().clone(), Validity::NonNullable).into_array()
489 }
490}
491
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid
    )]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid
    )]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }

    #[test]
    fn mask_non_nullable() {
        assert_eq!(
            Validity::AllValid,
            Validity::NonNullable.mask(&Mask::AllFalse(2))
        )
    }

    /// A partial mask marks exactly the selected positions invalid.
    #[test]
    fn mask_partial_marks_positions_invalid() {
        let masked = Validity::AllValid.mask(&Mask::from_iter(vec![true, false]));
        assert_eq!(masked, Validity::from_iter(vec![false, true]));
    }

    /// `and` keeps non-nullability only when both sides are non-nullable, lets
    /// `AllInvalid` dominate, and intersects two bitmaps.
    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllValid,
        Validity::from_iter(vec![true, false]),
        Validity::from_iter(vec![true, false])
    )]
    #[case(
        Validity::from_iter(vec![true, false, true]),
        Validity::from_iter(vec![true, true, false]),
        Validity::from_iter(vec![true, false, false])
    )]
    fn validity_and(#[case] lhs: Validity, #[case] rhs: Validity, #[case] expected: Validity) {
        assert_eq!(lhs.and(rhs), expected);
    }

    /// Building from bools collapses uniform inputs to the constant variants.
    #[test]
    fn from_bools_normalizes_constant_buffers() {
        assert_eq!(Validity::from_iter(vec![true, true, true]), Validity::AllValid);
        assert_eq!(Validity::from_iter(vec![false, false]), Validity::AllInvalid);
        // An empty buffer has no invalid positions, so it collapses to `AllValid`.
        assert_eq!(Validity::from_iter(Vec::<bool>::new()), Validity::AllValid);
    }

    /// `to_mask` rejects a length that disagrees with the backing array.
    #[test]
    #[should_panic]
    fn to_mask_length_mismatch() {
        let _mask = Validity::from_iter(vec![true, false]).to_mask(3);
    }

    /// `into_non_nullable` succeeds only when no position is invalid.
    #[test]
    fn into_non_nullable_requires_all_valid() {
        assert_eq!(
            Validity::AllValid.into_non_nullable(3),
            Some(Validity::NonNullable)
        );
        assert_eq!(Validity::AllInvalid.into_non_nullable(3), None);
        assert_eq!(
            Validity::AllInvalid.into_non_nullable(0),
            Some(Validity::NonNullable)
        );
        assert_eq!(
            Validity::from_iter(vec![true, false]).into_non_nullable(2),
            None
        );
    }
}