1use std::fmt::Debug;
7use std::ops::BitAnd;
8use std::ops::Range;
9
10use vortex_buffer::BitBuffer;
11use vortex_dtype::DType;
12use vortex_dtype::Nullability;
13use vortex_error::VortexExpect as _;
14use vortex_error::VortexResult;
15use vortex_error::vortex_err;
16use vortex_error::vortex_panic;
17use vortex_mask::AllOr;
18use vortex_mask::Mask;
19use vortex_mask::MaskValues;
20use vortex_scalar::Scalar;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::ToCanonical;
26use crate::arrays::BoolArray;
27use crate::arrays::ConstantArray;
28use crate::compute::fill_null;
29use crate::compute::filter;
30use crate::compute::sum;
31use crate::compute::take;
32use crate::patches::Patches;
33
#[derive(Clone, Debug)]
/// Describes which elements of an array are valid (non-null).
pub enum Validity {
    /// The array is non-nullable: every element is valid and nulls cannot occur.
    NonNullable,
    /// The array is nullable, but every element is valid.
    AllValid,
    /// The array is nullable and every element is invalid (null).
    AllInvalid,
    /// The array is nullable; per-element validity is given by a non-nullable boolean array
    /// in which `true` marks a valid element.
    Array(ArrayRef),
}
48
49impl Validity {
50 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
52
53 pub fn to_array(&self, len: usize) -> ArrayRef {
55 match self {
56 Self::NonNullable | Self::AllValid => ConstantArray::new(true, len).into_array(),
57 Self::AllInvalid => ConstantArray::new(false, len).into_array(),
58 Self::Array(a) => a.clone(),
59 }
60 }
61
62 #[inline]
64 pub fn into_array(self) -> Option<ArrayRef> {
65 if let Self::Array(a) = self {
66 Some(a)
67 } else {
68 None
69 }
70 }
71
72 #[inline]
74 pub fn as_array(&self) -> Option<&ArrayRef> {
75 if let Self::Array(a) = self {
76 Some(a)
77 } else {
78 None
79 }
80 }
81
82 #[inline]
83 pub fn nullability(&self) -> Nullability {
84 if matches!(self, Self::NonNullable) {
85 Nullability::NonNullable
86 } else {
87 Nullability::Nullable
88 }
89 }
90
91 #[inline]
93 pub fn union_nullability(self, nullability: Nullability) -> Self {
94 match nullability {
95 Nullability::NonNullable => self,
96 Nullability::Nullable => self.into_nullable(),
97 }
98 }
99
100 #[inline]
101 pub fn all_valid(&self, len: usize) -> bool {
102 match self {
103 _ if len == 0 => true,
104 Validity::NonNullable | Validity::AllValid => true,
105 Validity::AllInvalid => false,
106 Validity::Array(array) => {
107 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
108 .vortex_expect("sum must be a usize")
109 == array.len()
110 }
111 }
112 }
113
114 #[inline]
115 pub fn all_invalid(&self, len: usize) -> bool {
116 match self {
117 _ if len == 0 => true,
118 Validity::NonNullable | Validity::AllValid => false,
119 Validity::AllInvalid => true,
120 Validity::Array(array) => {
121 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
122 .vortex_expect("sum must be a usize")
123 == 0
124 }
125 }
126 }
127
128 #[inline]
130 pub fn is_valid(&self, index: usize) -> bool {
131 match self {
132 Self::NonNullable | Self::AllValid => true,
133 Self::AllInvalid => false,
134 Self::Array(a) => {
135 let scalar = a.scalar_at(index);
136 scalar
137 .as_bool()
138 .value()
139 .vortex_expect("Validity must be non-nullable")
140 }
141 }
142 }
143
144 #[inline]
145 pub fn is_null(&self, index: usize) -> bool {
146 !self.is_valid(index)
147 }
148
149 #[inline]
150 pub fn slice(&self, range: Range<usize>) -> Self {
151 match self {
152 Self::Array(a) => Self::Array(a.slice(range)),
153 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
154 }
155 }
156
157 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
158 match self {
159 Self::NonNullable => match indices.validity_mask().bit_buffer() {
160 AllOr::All => {
161 if indices.dtype().is_nullable() {
162 Ok(Self::AllValid)
163 } else {
164 Ok(Self::NonNullable)
165 }
166 }
167 AllOr::None => Ok(Self::AllInvalid),
168 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
169 },
170 Self::AllValid => match indices.validity_mask().bit_buffer() {
171 AllOr::All => Ok(Self::AllValid),
172 AllOr::None => Ok(Self::AllInvalid),
173 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
174 },
175 Self::AllInvalid => Ok(Self::AllInvalid),
176 Self::Array(is_valid) => {
177 let maybe_is_valid = take(is_valid, indices)?;
178 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
180 Ok(Self::Array(is_valid))
181 }
182 }
183 }
184
185 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
189 match self {
192 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
193 Ok(v.clone())
194 }
195 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
196 }
197 }
198
199 #[inline]
203 pub fn mask(&self, mask: &Mask) -> Self {
204 match mask.bit_buffer() {
205 AllOr::All => Validity::AllInvalid,
206 AllOr::None => self.clone().into_nullable(),
207 AllOr::Some(make_invalid) => match self {
208 Validity::NonNullable | Validity::AllValid => {
209 Validity::Array(BoolArray::from(!make_invalid).into_array())
210 }
211 Validity::AllInvalid => Validity::AllInvalid,
212 Validity::Array(is_valid) => {
213 let is_valid = is_valid.to_bool();
214 Validity::from(is_valid.bit_buffer() & !make_invalid)
215 }
216 },
217 }
218 }
219
220 #[inline]
221 pub fn to_mask(&self, length: usize) -> Mask {
222 match self {
223 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
224 Self::AllInvalid => Mask::AllFalse(length),
225 Self::Array(is_valid) => {
226 assert_eq!(
227 is_valid.len(),
228 length,
229 "Validity::Array length must equal to_logical's argument: {}, {}.",
230 is_valid.len(),
231 length,
232 );
233 is_valid.to_bool().to_mask()
234 }
235 }
236 }
237
238 #[inline]
240 pub fn and(self, rhs: Validity) -> Validity {
241 match (self, rhs) {
242 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
244 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
246 (Validity::Array(a), Validity::AllValid)
248 | (Validity::Array(a), Validity::NonNullable)
249 | (Validity::NonNullable, Validity::Array(a))
250 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
251 (Validity::NonNullable, Validity::AllValid)
253 | (Validity::AllValid, Validity::NonNullable)
254 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
255 (Validity::Array(lhs), Validity::Array(rhs)) => {
257 let lhs = lhs.to_bool();
258 let rhs = rhs.to_bool();
259
260 let lhs = lhs.bit_buffer();
261 let rhs = rhs.bit_buffer();
262
263 Validity::from(lhs.bitand(rhs))
264 }
265 }
266 }
267
268 pub fn patch(
269 self,
270 len: usize,
271 indices_offset: usize,
272 indices: &dyn Array,
273 patches: &Validity,
274 ) -> Self {
275 match (&self, patches) {
276 (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
277 (Validity::NonNullable, _) => {
278 vortex_panic!("Can't patch a non-nullable validity with nullable validity")
279 }
280 (_, Validity::NonNullable) => {
281 vortex_panic!("Can't patch a nullable validity with non-nullable validity")
282 }
283 (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
284 (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
285 _ => {}
286 };
287
288 let own_nullability = if self == Validity::NonNullable {
289 Nullability::NonNullable
290 } else {
291 Nullability::Nullable
292 };
293
294 let source = match self {
295 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
296 Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
297 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
298 Validity::Array(a) => a.to_bool(),
299 };
300
301 let patch_values = match patches {
302 Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
303 Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
304 Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
305 Validity::Array(a) => a.to_bool(),
306 };
307
308 let patches = Patches::new(
309 len,
310 indices_offset,
311 indices.to_array(),
312 patch_values.into_array(),
313 None,
315 );
316
317 Self::from_array(source.patch(&patches).into_array(), own_nullability)
318 }
319
320 #[inline]
322 pub fn into_nullable(self) -> Validity {
323 match self {
324 Self::NonNullable => Self::AllValid,
325 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
326 }
327 }
328
329 #[inline]
331 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
332 match self {
333 _ if len == 0 => Some(Validity::NonNullable),
334 Self::NonNullable => Some(Self::NonNullable),
335 Self::AllValid => Some(Self::NonNullable),
336 Self::AllInvalid => None,
337 Self::Array(is_valid) => {
338 is_valid
339 .statistics()
340 .compute_min::<bool>()
341 .vortex_expect("validity array must support min")
342 .then(|| {
343 Self::NonNullable
345 })
346 }
347 }
348 }
349
350 #[inline]
352 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
353 match nullability {
354 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
355 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
356 }),
357 Nullability::Nullable => Ok(self.into_nullable()),
358 }
359 }
360
361 #[inline]
363 pub fn copy_from_array(array: &dyn Array) -> Self {
364 Validity::from_mask(array.validity_mask(), array.dtype().nullability())
365 }
366
367 #[inline]
372 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
373 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
374 vortex_panic!("Expected a non-nullable boolean array")
375 }
376 match nullability {
377 Nullability::NonNullable => Self::NonNullable,
378 Nullability::Nullable => Self::Array(value),
379 }
380 }
381
382 #[inline]
384 pub fn maybe_len(&self) -> Option<usize> {
385 match self {
386 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
387 Self::Array(a) => Some(a.len()),
388 }
389 }
390
391 #[inline]
392 pub fn uncompressed_size(&self) -> usize {
393 if let Validity::Array(a) = self {
394 a.len().div_ceil(8)
395 } else {
396 0
397 }
398 }
399}
400
401impl PartialEq for Validity {
402 #[inline]
403 fn eq(&self, other: &Self) -> bool {
404 match (self, other) {
405 (Self::NonNullable, Self::NonNullable) => true,
406 (Self::AllValid, Self::AllValid) => true,
407 (Self::AllInvalid, Self::AllInvalid) => true,
408 (Self::Array(a), Self::Array(b)) => {
409 let a = a.to_bool();
410 let b = b.to_bool();
411 a.bit_buffer() == b.bit_buffer()
412 }
413 _ => false,
414 }
415 }
416}
417
418impl From<BitBuffer> for Validity {
419 #[inline]
420 fn from(value: BitBuffer) -> Self {
421 let true_count = value.true_count();
422 if true_count == value.len() {
423 Self::AllValid
424 } else if true_count == 0 {
425 Self::AllInvalid
426 } else {
427 Self::Array(BoolArray::from(value).into_array())
428 }
429 }
430}
431
432impl FromIterator<Mask> for Validity {
433 #[inline]
434 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
435 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
436 }
437}
438
439impl FromIterator<bool> for Validity {
440 #[inline]
441 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
442 Validity::from(BitBuffer::from_iter(iter))
443 }
444}
445
446impl From<Nullability> for Validity {
447 #[inline]
448 fn from(value: Nullability) -> Self {
449 match value {
450 Nullability::NonNullable => Validity::NonNullable,
451 Nullability::Nullable => Validity::AllValid,
452 }
453 }
454}
455
456impl Validity {
457 pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
458 if buffer.true_count() == buffer.len() {
459 nullability.into()
460 } else if buffer.true_count() == 0 {
461 Validity::AllInvalid
462 } else {
463 Validity::Array(BoolArray::from_bit_buffer(buffer, Validity::NonNullable).into_array())
464 }
465 }
466
467 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
468 assert!(
469 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
470 "NonNullable validity must be AllValid",
471 );
472 match mask {
473 Mask::AllTrue(_) => match nullability {
474 Nullability::NonNullable => Validity::NonNullable,
475 Nullability::Nullable => Validity::AllValid,
476 },
477 Mask::AllFalse(_) => Validity::AllInvalid,
478 Mask::Values(values) => Validity::Array(values.into_array()),
479 }
480 }
481}
482
483impl IntoArray for Mask {
484 #[inline]
485 fn into_array(self) -> ArrayRef {
486 match self {
487 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
488 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
489 Self::Values(a) => a.into_array(),
490 }
491 }
492}
493
494impl IntoArray for &MaskValues {
495 #[inline]
496 fn into_array(self) -> ArrayRef {
497 BoolArray::from_bit_buffer(self.bit_buffer().clone(), Validity::NonNullable).into_array()
498 }
499}
500
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    /// Shorthand for a `Validity::Array` backed by the given per-element flags.
    fn bits<const N: usize>(flags: [bool; N]) -> Validity {
        Validity::Array(BoolArray::from_iter(flags).into_array())
    }

    /// A two-element `[0, 1]` primitive array carrying `validity`, used as `take` indices.
    fn indices_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![0, 1], validity).into_array()
    }

    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        Validity::AllInvalid,
        bits([true, true, false, true, false])
    )]
    #[case(
        Validity::AllValid,
        5,
        &[2, 4],
        bits([true, false]),
        bits([true, true, true, true, false])
    )]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        Validity::AllValid,
        bits([false, false, true, false, true])
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(
        Validity::AllInvalid,
        5,
        &[2, 4],
        bits([true, false]),
        bits([false, false, true, false, false])
    )]
    #[case(
        bits([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllValid,
        bits([false, true, true, true, true])
    )]
    #[case(
        bits([false, true, false, true, false]),
        5,
        &[2, 4],
        Validity::AllInvalid,
        bits([false, true, false, true, false])
    )]
    #[case(
        bits([false, true, false, true, false]),
        5,
        &[2, 4],
        bits([true, false]),
        bits([false, true, true, true, false])
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        let patch_indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        let patched = validity.patch(len, 0, &patch_indices, &patches);
        assert_eq!(patched, expected);
    }

    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    #[rstest]
    #[case(
        Validity::AllValid,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    #[case(
        Validity::NonNullable,
        indices_with(Validity::from_iter(vec![true, false])),
        Validity::from_iter(vec![true, false])
    )]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, indices_with(Validity::AllInvalid), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        let taken = validity.take(&indices).unwrap();
        assert_eq!(taken, expected);
    }

    #[test]
    fn mask_non_nullable() {
        // Masking nothing out of a non-nullable validity still widens it to nullable.
        let masked = Validity::NonNullable.mask(&Mask::AllFalse(2));
        assert_eq!(Validity::AllValid, masked)
    }
}