1use std::fmt::Debug;
7use std::ops::{BitAnd, Not};
8
9use arrow_buffer::{BooleanBuffer, NullBuffer};
10use vortex_dtype::{DType, Nullability};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic};
12use vortex_mask::{AllOr, Mask, MaskValues};
13use vortex_scalar::Scalar;
14
15use crate::arrays::{BoolArray, ConstantArray};
16use crate::compute::{fill_null, filter, sum, take};
17use crate::patches::Patches;
18use crate::{Array, ArrayRef, IntoArray, ToCanonical};
19
/// Describes which positions of an array hold valid (non-null) values.
#[derive(Clone, Debug)]
pub enum Validity {
    /// The owning array is non-nullable; reported as [`Nullability::NonNullable`] and treated as
    /// having no nulls.
    NonNullable,
    /// The owning array is nullable, but every position is valid.
    AllValid,
    /// Every position is invalid (null).
    AllInvalid,
    /// Per-position validity stored as a boolean array, where `true` marks a valid position.
    Array(ArrayRef),
}
32
33impl Validity {
    /// The non-nullable boolean [`DType`] associated with validity arrays.
    pub const DTYPE: DType = DType::Bool(Nullability::NonNullable);
36
37 pub fn null_count(&self, length: usize) -> VortexResult<usize> {
38 match self {
39 Self::NonNullable | Self::AllValid => Ok(0),
40 Self::AllInvalid => Ok(length),
41 Self::Array(a) => {
42 let validity_len = a.len();
43 if validity_len != length {
44 vortex_bail!(
45 "Validity array length {} doesn't match array length {}",
46 validity_len,
47 length
48 )
49 }
50 let true_count = sum(a)?
51 .as_primitive()
52 .as_::<usize>()?
53 .ok_or_else(|| vortex_err!("Failed to compute true count"))?;
54 Ok(length - true_count)
55 }
56 }
57 }
58
59 pub fn into_array(self) -> Option<ArrayRef> {
61 match self {
62 Self::Array(a) => Some(a),
63 _ => None,
64 }
65 }
66
67 pub fn as_array(&self) -> Option<&ArrayRef> {
69 match self {
70 Self::Array(a) => Some(a),
71 _ => None,
72 }
73 }
74
75 pub fn nullability(&self) -> Nullability {
76 match self {
77 Self::NonNullable => Nullability::NonNullable,
78 _ => Nullability::Nullable,
79 }
80 }
81
82 pub fn union_nullability(self, nullability: Nullability) -> Self {
84 match nullability {
85 Nullability::NonNullable => self,
86 Nullability::Nullable => self.into_nullable(),
87 }
88 }
89
90 pub fn all_valid(&self) -> VortexResult<bool> {
91 Ok(match self {
92 Validity::NonNullable | Validity::AllValid => true,
93 Validity::AllInvalid => false,
94 Validity::Array(array) => {
95 array.to_bool()?.boolean_buffer().count_set_bits() == array.len()
97 }
98 })
99 }
100
101 pub fn all_invalid(&self) -> VortexResult<bool> {
102 Ok(match self {
103 Validity::NonNullable | Validity::AllValid => false,
104 Validity::AllInvalid => true,
105 Validity::Array(array) => {
106 array.to_bool()?.boolean_buffer().count_set_bits() == 0
108 }
109 })
110 }
111
112 #[inline]
114 pub fn is_valid(&self, index: usize) -> VortexResult<bool> {
115 Ok(match self {
116 Self::NonNullable | Self::AllValid => true,
117 Self::AllInvalid => false,
118 Self::Array(a) => {
119 let scalar = a.scalar_at(index)?;
120 scalar
121 .as_bool()
122 .value()
123 .vortex_expect("Validity must be non-nullable")
124 }
125 })
126 }
127
128 #[inline]
129 pub fn is_null(&self, index: usize) -> VortexResult<bool> {
130 Ok(!self.is_valid(index)?)
131 }
132
133 pub fn slice(&self, start: usize, stop: usize) -> VortexResult<Self> {
134 match self {
135 Self::Array(a) => Ok(Self::Array(a.slice(start, stop)?)),
136 _ => Ok(self.clone()),
137 }
138 }
139
140 pub fn take(&self, indices: &dyn Array) -> VortexResult<Self> {
141 match self {
142 Self::NonNullable => match indices.validity_mask()?.boolean_buffer() {
143 AllOr::All => {
144 if indices.dtype().is_nullable() {
145 Ok(Self::AllValid)
146 } else {
147 Ok(Self::NonNullable)
148 }
149 }
150 AllOr::None => Ok(Self::AllInvalid),
151 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
152 },
153 Self::AllValid => match indices.validity_mask()?.boolean_buffer() {
154 AllOr::All => Ok(Self::AllValid),
155 AllOr::None => Ok(Self::AllInvalid),
156 AllOr::Some(buf) => Ok(Validity::from(buf.clone())),
157 },
158 Self::AllInvalid => Ok(Self::AllInvalid),
159 Self::Array(is_valid) => {
160 let maybe_is_valid = take(is_valid, indices)?;
161 let is_valid = fill_null(&maybe_is_valid, &Scalar::from(false))?;
163 Ok(Self::Array(is_valid))
164 }
165 }
166 }
167
168 pub fn filter(&self, mask: &Mask) -> VortexResult<Self> {
172 match self {
175 v @ (Validity::NonNullable | Validity::AllValid | Validity::AllInvalid) => {
176 Ok(v.clone())
177 }
178 Validity::Array(arr) => Ok(Validity::Array(filter(arr, mask)?)),
179 }
180 }
181
182 pub fn mask(&self, mask: &Mask) -> VortexResult<Self> {
186 match mask.boolean_buffer() {
187 AllOr::All => Ok(Validity::AllInvalid),
188 AllOr::None => Ok(self.clone()),
189 AllOr::Some(make_invalid) => Ok(match self {
190 Validity::NonNullable | Validity::AllValid => {
191 Validity::Array(BoolArray::from(make_invalid.not()).into_array())
192 }
193 Validity::AllInvalid => Validity::AllInvalid,
194 Validity::Array(is_valid) => {
195 let is_valid = is_valid.to_bool()?;
196 let keep_valid = make_invalid.not();
197 Validity::from(is_valid.boolean_buffer().bitand(&keep_valid))
198 }
199 }),
200 }
201 }
202
203 pub fn to_mask(&self, length: usize) -> VortexResult<Mask> {
204 Ok(match self {
205 Self::NonNullable | Self::AllValid => Mask::AllTrue(length),
206 Self::AllInvalid => Mask::AllFalse(length),
207 Self::Array(is_valid) => {
208 assert_eq!(
209 is_valid.len(),
210 length,
211 "Validity::Array length must equal to_logical's argument: {}, {}.",
212 is_valid.len(),
213 length,
214 );
215 Mask::try_from(&is_valid.to_bool()?)?
216 }
217 })
218 }
219
220 pub fn and(self, rhs: Validity) -> VortexResult<Validity> {
222 let validity = match (self, rhs) {
223 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
225 (Validity::AllInvalid, _) | (_, Validity::AllInvalid) => Validity::AllInvalid,
227 (Validity::Array(a), Validity::AllValid)
229 | (Validity::Array(a), Validity::NonNullable)
230 | (Validity::NonNullable, Validity::Array(a))
231 | (Validity::AllValid, Validity::Array(a)) => Validity::Array(a),
232 (Validity::NonNullable, Validity::AllValid)
234 | (Validity::AllValid, Validity::NonNullable)
235 | (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
236 (Validity::Array(lhs), Validity::Array(rhs)) => {
238 let lhs = lhs.to_bool()?;
239 let rhs = rhs.to_bool()?;
240
241 let lhs = lhs.boolean_buffer();
242 let rhs = rhs.boolean_buffer();
243
244 Validity::from(lhs.bitand(rhs))
245 }
246 };
247
248 Ok(validity)
249 }
250
251 pub fn patch(
252 self,
253 len: usize,
254 indices_offset: usize,
255 indices: &dyn Array,
256 patches: &Validity,
257 ) -> VortexResult<Self> {
258 match (&self, patches) {
259 (Validity::NonNullable, Validity::NonNullable) => return Ok(Validity::NonNullable),
260 (Validity::NonNullable, _) => {
261 vortex_bail!("Can't patch a non-nullable validity with nullable validity")
262 }
263 (_, Validity::NonNullable) => {
264 vortex_bail!("Can't patch a nullable validity with non-nullable validity")
265 }
266 (Validity::AllValid, Validity::AllValid) => return Ok(Validity::AllValid),
267 (Validity::AllInvalid, Validity::AllInvalid) => return Ok(Validity::AllInvalid),
268 _ => {}
269 };
270
271 let own_nullability = if self == Validity::NonNullable {
272 Nullability::NonNullable
273 } else {
274 Nullability::Nullable
275 };
276
277 let source = match self {
278 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(len)),
279 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(len)),
280 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(len)),
281 Validity::Array(a) => a.to_bool()?,
282 };
283
284 let patch_values = match patches {
285 Validity::NonNullable => BoolArray::from(BooleanBuffer::new_set(indices.len())),
286 Validity::AllValid => BoolArray::from(BooleanBuffer::new_set(indices.len())),
287 Validity::AllInvalid => BoolArray::from(BooleanBuffer::new_unset(indices.len())),
288 Validity::Array(a) => a.to_bool()?,
289 };
290
291 let patches = Patches::new(
292 len,
293 indices_offset,
294 indices.to_array(),
295 patch_values.into_array(),
296 );
297
298 Ok(Self::from_array(
299 source.patch(&patches)?.into_array(),
300 own_nullability,
301 ))
302 }
303
304 pub fn into_nullable(self) -> Validity {
306 match self {
307 Self::NonNullable => Self::AllValid,
308 _ => self,
309 }
310 }
311
312 pub fn into_non_nullable(self) -> Option<Validity> {
314 match self {
315 Self::NonNullable => Some(Self::NonNullable),
316 Self::AllValid => Some(Self::NonNullable),
317 Self::AllInvalid => None,
318 Self::Array(is_valid) => {
319 is_valid
320 .statistics()
321 .compute_min::<bool>()
322 .vortex_expect("validity array must support min")
323 .then(|| {
324 Self::NonNullable
326 })
327 }
328 }
329 }
330
331 pub fn cast_nullability(self, nullability: Nullability) -> VortexResult<Validity> {
333 match nullability {
334 Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| {
335 vortex_err!("Cannot cast array with invalid values to non-nullable type.")
336 }),
337 Nullability::Nullable => Ok(self.into_nullable()),
338 }
339 }
340
341 pub fn copy_from_array(array: &dyn Array) -> VortexResult<Self> {
343 Ok(Validity::from_mask(
344 array.validity_mask()?,
345 array.dtype().nullability(),
346 ))
347 }
348
349 fn from_array(value: ArrayRef, nullability: Nullability) -> Self {
354 if !matches!(value.dtype(), DType::Bool(Nullability::NonNullable)) {
355 vortex_panic!("Expected a non-nullable boolean array")
356 }
357 match nullability {
358 Nullability::NonNullable => Self::NonNullable,
359 Nullability::Nullable => Self::Array(value),
360 }
361 }
362
363 pub fn maybe_len(&self) -> Option<usize> {
365 match self {
366 Self::NonNullable | Self::AllValid | Self::AllInvalid => None,
367 Self::Array(a) => Some(a.len()),
368 }
369 }
370
371 pub fn uncompressed_size(&self) -> usize {
372 if let Validity::Array(a) = self {
373 a.len().div_ceil(8)
374 } else {
375 0
376 }
377 }
378
379 pub fn is_array(&self) -> bool {
380 matches!(self, Validity::Array(_))
381 }
382}
383
384impl PartialEq for Validity {
385 fn eq(&self, other: &Self) -> bool {
386 match (self, other) {
387 (Self::NonNullable, Self::NonNullable) => true,
388 (Self::AllValid, Self::AllValid) => true,
389 (Self::AllInvalid, Self::AllInvalid) => true,
390 (Self::Array(a), Self::Array(b)) => {
391 let a = a
392 .to_bool()
393 .vortex_expect("Failed to get Validity Array as BoolArray");
394 let b = b
395 .to_bool()
396 .vortex_expect("Failed to get Validity Array as BoolArray");
397 a.boolean_buffer() == b.boolean_buffer()
398 }
399 _ => false,
400 }
401 }
402}
403
404impl From<BooleanBuffer> for Validity {
405 fn from(value: BooleanBuffer) -> Self {
406 if value.count_set_bits() == value.len() {
407 Self::AllValid
408 } else if value.count_set_bits() == 0 {
409 Self::AllInvalid
410 } else {
411 Self::Array(BoolArray::from(value).into_array())
412 }
413 }
414}
415
416impl From<NullBuffer> for Validity {
417 fn from(value: NullBuffer) -> Self {
418 value.into_inner().into()
419 }
420}
421
422impl FromIterator<Mask> for Validity {
423 fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
424 Validity::from_mask(iter.into_iter().collect(), Nullability::Nullable)
425 }
426}
427
428impl FromIterator<bool> for Validity {
429 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
430 Validity::from(BooleanBuffer::from_iter(iter))
431 }
432}
433
434impl From<Nullability> for Validity {
435 fn from(value: Nullability) -> Self {
436 match value {
437 Nullability::NonNullable => Validity::NonNullable,
438 Nullability::Nullable => Validity::AllValid,
439 }
440 }
441}
442
443impl Validity {
444 pub fn from_mask(mask: Mask, nullability: Nullability) -> Self {
445 assert!(
446 nullability == Nullability::Nullable || matches!(mask, Mask::AllTrue(_)),
447 "NonNullable validity must be AllValid",
448 );
449 match mask {
450 Mask::AllTrue(_) => match nullability {
451 Nullability::NonNullable => Validity::NonNullable,
452 Nullability::Nullable => Validity::AllValid,
453 },
454 Mask::AllFalse(_) => Validity::AllInvalid,
455 Mask::Values(values) => Validity::Array(values.into_array()),
456 }
457 }
458}
459
460impl IntoArray for Mask {
461 fn into_array(self) -> ArrayRef {
462 match self {
463 Self::AllTrue(len) => ConstantArray::new(true, len).into_array(),
464 Self::AllFalse(len) => ConstantArray::new(false, len).into_array(),
465 Self::Values(a) => a.into_array(),
466 }
467 }
468}
469
470impl IntoArray for &MaskValues {
471 fn into_array(self) -> ArrayRef {
472 BoolArray::new(self.boolean_buffer().clone(), Validity::NonNullable).into_array()
473 }
474}
475
// Unit tests for `Validity::patch`, `Validity::from_mask` and `Validity::take`.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::{Buffer, buffer};
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    // Each case is (initial validity, length, patch positions, patch validity, expected result).
    // The cases cover every pairing of AllValid / AllInvalid / Array on both sides.
    #[rstest]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
    )]
    #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
    )]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
    #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
    )]
    #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
    )]
    fn patch_validity(
        #[case] validity: Validity,
        #[case] len: usize,
        #[case] positions: &[u64],
        #[case] patches: Validity,
        #[case] expected: Validity,
    ) {
        // The positions are wrapped in a non-nullable primitive array to serve as patch indices.
        let indices =
            PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array();
        assert_eq!(
            validity.patch(len, 0, &indices, &patches).unwrap(),
            expected
        );
    }

    // Patching at index 4 of a length-2 validity must not succeed.
    // NOTE(review): with a `NonNullable` receiver and `AllInvalid` patches, `patch` returns its
    // nullability-mismatch error before any index is inspected, so the panic here comes from
    // `unwrap()` on that error rather than from a bounds check — confirm this is intended.
    #[test]
    #[should_panic]
    fn out_of_bounds_patch() {
        Validity::NonNullable
            .patch(2, 0, &buffer![4].into_array(), &Validity::AllInvalid)
            .unwrap();
    }

    // A non-nullable validity cannot be built from an all-false mask.
    #[test]
    #[should_panic]
    fn into_validity_nullable() {
        Validity::from_mask(Mask::AllFalse(10), Nullability::NonNullable);
    }

    // A non-nullable validity cannot be built from a mask with explicit values.
    #[test]
    #[should_panic]
    fn into_validity_nullable_array() {
        Validity::from_mask(Mask::from_iter(vec![true, false]), Nullability::NonNullable);
    }

    // Each case is (source validity, indices, expected validity after `take`).
    // Index arrays with partially valid, non-nullable and all-invalid validity are each tried
    // against both an AllValid and a NonNullable source.
    #[rstest]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::AllValid, buffer![0, 1].into_array(), Validity::AllValid)]
    #[case(Validity::AllValid, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::from_iter(vec![true, false])).into_array(), Validity::from_iter(vec![true, false]))]
    #[case(Validity::NonNullable, buffer![0, 1].into_array(), Validity::NonNullable)]
    #[case(Validity::NonNullable, PrimitiveArray::new(buffer![0, 1], Validity::AllInvalid).into_array(), Validity::AllInvalid)]
    fn validity_take(
        #[case] validity: Validity,
        #[case] indices: ArrayRef,
        #[case] expected: Validity,
    ) {
        assert_eq!(validity.take(&indices).unwrap(), expected);
    }
}