1use std::fmt::Debug;
7use std::ops::{BitAnd, Not, Range};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
20#[derive(Clone, Debug)]
22pub enum Validity {
23 NonNullable,
25 AllValid,
27 AllInvalid,
29 Array(ArrayRef),
31}
32
33impl Validity {
34 pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37 #[inline]
39 pub fn into_array(self) -> Option<ArrayRef> {
40 if let Self::Array(a) = self {
41 Some(a)
42 } else {
43 None
44 }
45 }
46
47 #[inline]
49 pub fn as_array(&self) -> Option<&ArrayRef> {
50 if let Self::Array(a) = self {
51 Some(a)
52 } else {
53 None
54 }
55 }
56
57 #[inline]
58 pub fn nullability(&self) -> Nullability {
59 if matches!(self, Self::NonNullable) {
60 Nullability::NonNullable
61 } else {
62 Nullability::Nullable
63 }
64 }
65
66 #[inline]
68 pub fn union_nullability(self, nullability: Nullability) -> Self {
69 match nullability {
70 Nullability::NonNullable => self,
71 Nullability::Nullable => self.into_nullable(),
72 }
73 }
74
75 #[inline]
76 pub fn all_valid(&self, len: usize) -> bool {
77 match self {
78 _ if len == 0 => true,
79 Validity::NonNullable | Validity::AllValid => true,
80 Validity::AllInvalid => false,
81 Validity::Array(array) => {
82 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
83 .vortex_expect("sum must be a usize")
84 == array.len()
85 }
86 }
87 }
88
89 #[inline]
90 pub fn all_invalid(&self, len: usize) -> bool {
91 match self {
92 _ if len == 0 => true,
93 Validity::NonNullable | Validity::AllValid => false,
94 Validity::AllInvalid => true,
95 Validity::Array(array) => {
96 usize::try_from(&sum(array).vortex_expect("must have sum for bool array"))
97 .vortex_expect("sum must be a usize")
98 == 0
99 }
100 }
101 }
102
103 #[inline]
105 pub fn is_valid(&self, index: usize) -> bool {
106 match self {
107 Self::NonNullable | Self::AllValid => true,
108 Self::AllInvalid => false,
109 Self::Array(a) => {
110 let scalar = a.scalar_at(index);
111 scalar
112 .as_bool()
113 .value()
114 .vortex_expect("Validity must be non-nullable")
115 }
116 }
117 }
118
119 #[inline]
120 pub fn is_null(&self, index: usize) -> bool {
121 !self.is_valid(index)
122 }
123
124 #[inline]
125 pub fn slice(&self, range: Range<usize>) -> Self {
126 match self {
127 Self::Array(a) => Self::Array(a.slice(range)),
128 Self::NonNullable | Self::AllValid | Self::AllInvalid => self.clone(),
129 }
130 }
131
132 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
133 match self {
134 Self::NonNullable => match indices.validity_mask().boolean_buffer() {
135 AllOr::All => {
136 if indices.dtype().is_nullable() {
137 Ok(Self::AllValid)
138 } else {
139 Ok(Self::NonNullable)
140 }
141 }
142 AllOr::None => Ok(Self::AllInvalid),
143 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
144 },
145 Self::AllValid => match indices.validity_mask().boolean_buffer() {
146 AllOr::All => Ok(Self::AllValid),
147 AllOr::None => Ok(Self::AllInvalid),
148 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
149 },
150 Self::AllInvalid => Ok(Self::AllInvalid),
151 Self::Array(is_valid) => {
152 let maybe_is_valid = take(is_valid, indices)?;
153 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
155 Ok(Self::Array(is_valid))
156 }
157 }
158 }
159
160 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
164 match self {
167 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
168 Ok(v.clone())
169 }
170 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
171 }
172 }
173
174 #[inline]
178 pub fn mask(&self, mask: &Mask) -> Self {
179 match mask.boolean_buffer() {
180 AllOr::All => Validity::AllInvalid,
181 AllOr::None => self.clone(),
182 AllOr::Some(make_invalid) => match self {
183 Validity::NonNullable | Validity::AllValid => {
184 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
185 }
186 Validity::AllInvalid => Validity::AllInvalid,
187 Validity::Array(is_valid) => {
188 let is_valid = is_valid.to_bool();
189 let keep_valid = make_invalid.not();
190 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
191 }
192 },
193 }
194 }
195
196 #[inline]
197 pub fn to_mask(&self, length: usize) -> Mask {
198 match self {
199 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
200 Self::AllInvalid => Mask::AllFalse(length),
201 Self::Array(is_valid) => {
202 assert_eq!(
203 is_valid.len(),
204 length,
205 "Validity::Array length must equal to_logical's argument: {}, {}.",
206 is_valid.len(),
207 length,
208 );
209 is_valid.to_bool().to_mask()
210 }
211 }
212 }
213
214 #[inline]
216 pub fn and(self, rhs: Validity) -> Validity {
217 match (self, rhs) {
218 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
220 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
222 (Validity::Array(a), Validity::AllValid)
224 | (Validity::Array(a), Validity::NonNullable)
225 | (Validity::NonNullable, Validity::Array(a))
226 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
227 (Validity::NonNullable, Validity::AllValid)
229 | (Validity::AllValid, Validity::NonNullable)
230 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
231 (Validity::Array(lhs), Validity::Array(rhs)) => {
233 let lhs = lhs.to_bool();
234 let rhs = rhs.to_bool();
235
236 let lhs = lhs.boolean_buffer();
237 let rhs = rhs.boolean_buffer();
238
239 Validity::from(lhs.bitand(rhs))
240 }
241 }
242 }
243
244 pub fn patch(
245 self,
246 len: usize,
247 indices_offset: usize,
248 indices: &dyn Array,
249 patches: &Validity,
250 ) -> Self {
251 match (&self, patches) {
252 (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
253 (Validity::NonNullable, _) => {
254 vortex_panic!("Can't patch a non-nullable validity with nullable validity")
255 }
256 (_, Validity::NonNullable) => {
257 vortex_panic!("Can't patch a nullable validity with non-nullable validity")
258 }
259 (Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
260 (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
261 _ => {}
262 };
263
264 let own_nullability = if self == Validity::NonNullable {
265 Nullability::NonNullable
266 } else {
267 Nullability::Nullable
268 };
269
270 let source = match self {
271 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
272 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
273 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
274 Validity::Array(a) => a.to_bool(),
275 };
276
277 let patch_values = match patches {
278 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
279 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
280 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
281 Validity::Array(a) => a.to_bool(),
282 };
283
284 let patches = Patches::new(
285 len,
286 indices_offset,
287 indices.to_array(),
288 patch_values.into_array(),
289 );
290
291 Self::from_array(source.patch(&patches).into_array(), own_nullability)
292 }
293
294 #[inline]
296 pub fn into_nullable(self) -> Validity {
297 match self {
298 Self::NonNullable => Self::AllValid,
299 Self::AllValid | Self::AllInvalid | Self::Array(_) => self,
300 }
301 }
302
303 #[inline]
305 pub fn into_non_nullable(self, len: usize) -> Option<Validity> {
306 match self {
307 _ if len == 0 => Some(Validity::NonNullable),
308 Self::NonNullable => Some(Self::NonNullable),
309 Self::AllValid => Some(Self::NonNullable),
310 Self::AllInvalid => None,
311 Self::Array(is_valid) => {
312 is_valid
313 .statistics()
314 .compute_min::<bool>()
315 .vortex_expect("validity array must support min")
316 .then(|| {
317 Self::NonNullable
319 })
320 }
321 }
322 }
323
324 #[inline]
326 pub fn cast_nullability(self, nullability: Nullability, len: usize) -> VortexResult<Validity> {
327 match nullability {
328 Nullability::NonNullable => self.into_non_nullable(len).ok_or_else(|| {
329 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
330 }),
331 Nullability::Nullable => Ok(self.into_nullable()),
332 }
333 }
334
335 #[inline]
337 pub fn copy_from_array(array: &dyn Array) -> Self {
338 Validity::from_mask(array.validity_mask(), array.dtype().nullability())
339 }
340
341 #[inline]
346 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
347 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
348 vortex_panic!("Expected a non-nullable boolean array")
349 }
350 match nullability {
351 Nullability::NonNullable => Self::NonNullable,
352 Nullability::Nullable => Self::Array(value),
353 }
354 }
355
356 #[inline]
358 pub fn maybe_len(&self) -> Option<usize> {
359 match self {
360 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
361 Self::Array(a) => Some(a.len()),
362 }
363 }
364
365 #[inline]
366 pub fn uncompressed_size(&self) -> usize {
367 if let Validity::Array(a) = self {
368 a.len().div_ceil(8)
369 } else {
370 0
371 }
372 }
373}
374
375impl PartialEq for Validity {
376 #[inline]
377 fn eq(&self, other: &Self) -> bool {
378 match (self, other) {
379 (Self::NonNullable, Self::NonNullable) => true,
380 (Self::AllValid, Self::AllValid) => true,
381 (Self::AllInvalid, Self::AllInvalid) => true,
382 (Self::Array(a), Self::Array(b)) => {
383 let a = a.to_bool();
384 let b = b.to_bool();
385 a.boolean_buffer() == b.boolean_buffer()
386 }
387 _ => false,
388 }
389 }
390}
391
392impl From<BooleanBuffer> for Validity {
393 #[inline]
394 fn from(value: BooleanBuffer) -> Self {
395 if value.count_set_bits() == value.len() {
396 Self::AllValid
397 } else if value.count_set_bits() == 0 {
398 Self::AllInvalid
399 } else {
400 Self::Array(BoolArray::from(value).into_array())
401 }
402 }
403}
404
405impl From<NullBuffer> for Validity {
406 #[inline]
407 fn from(value: NullBuffer) -> Self {
408 value.into_inner().into()
409 }
410}
411
412impl FromIterator<Mask> for Validity {
413 #[inline]
414 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
415 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
416 }
417}
418
419impl FromIterator<bool> for Validity {
420 #[inline]
421 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
422 Validity::from(BooleanBuffer::from_iter(iter))
423 }
424}
425
426impl From<Nullability> for Validity {
427 #[inline]
428 fn from(value: Nullability) -> Self {
429 match value {
430 Nullability::NonNullable => Validity::NonNullable,
431 Nullability::Nullable => Validity::AllValid,
432 }
433 }
434}
435
436impl Validity {
437 pub fn from_null_buffer(buffer: Option<NullBuffer>, nullability: Nullability) -> Self {
438 match buffer {
439 None => nullability.into(),
441 Some(nulls) => {
442 if nulls.null_count() == nulls.len() {
443 Validity::AllInvalid
444 } else {
445 Validity::Array(BoolArray::from(nulls.into_inner()).into_array())
446 }
447 }
448 }
449 }
450
451 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
452 assert!(
453 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
454 "NonNullable validity must be AllValid",
455 );
456 match mask {
457 Mask::AllTrue(_) => match nullability {
458 Nullability::NonNullable => Validity::NonNullable,
459 Nullability::Nullable => Validity::AllValid,
460 },
461 Mask::AllFalse(_) => Validity::AllInvalid,
462 Mask::Values(values) => Validity::Array(values.into_array()),
463 }
464 }
465}
466
467impl IntoArray for Mask {
468 #[inline]
469 fn into_array(self) -> ArrayRef {
470 match self {
471 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
472 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
473 Self::Values(a) => a.into_array(),
474 }
475 }
476}
477
478impl IntoArray for &MaskValues {
479 #[inline]
480 fn into_array(self) -> ArrayRef {
481 BoolArray::from_bool_buffer(self.boolean_buffer().clone(), Validity::NonNullable)
482 .into_array()
483 }
484}
485
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    /// `Validity::patch` over every pairing of nullable source and patch variants.
    ///
    /// Each case patches positions 2 and 4 of a length-5 validity. The case columns are:
    /// source validity, length, patch positions, patch validity (one entry per position),
    /// and the expected result.
    #[rstest]
    // Matching constant variants short-circuit and keep the constant variant.
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // Patch positions are passed as a non-nullable u64 primitive array with offset 0.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(validity.patch(len, 0, &indices, &patches), expected);
    }

    /// Expected to panic when patching index 4 of a length-2 validity.
    ///
    /// NOTE(review): the source is `Validity::NonNullable` and the patch is
    /// `Validity::AllInvalid`. `Validity::patch` therefore hits its "non-nullable patched with
    /// nullable" panic before any bounds handling runs. This test does not currently exercise
    /// out-of-bounds behaviour.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable.patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid);
    }

    /// A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    /// A non-nullable validity cannot be built from a mixed mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    /// `Validity::take` with constant validities: the result follows the indices' own
    /// validity, and non-nullable indices preserve `NonNullable`.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}