vortex_array/execution/
mask.rs1use vortex_dtype::DType;
5use vortex_dtype::Nullability::NonNullable;
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7use vortex_mask::Mask;
8
9use crate::ArrayRef;
10use crate::execution::BindCtx;
11
12pub enum MaskExecution {
13 AllTrue(usize),
14 AllFalse(usize),
15 Lazy(Box<dyn FnOnce() -> VortexResult<Mask> + Send + 'static>),
16}
17
18impl MaskExecution {
19 pub fn lazy<F: FnOnce() -> VortexResult<Mask> + Send + 'static>(f: F) -> MaskExecution {
20 MaskExecution::Lazy(Box::new(f))
21 }
22
23 pub fn execute(self) -> VortexResult<Mask> {
24 match self {
25 MaskExecution::AllTrue(len) => Ok(Mask::new_true(len)),
26 MaskExecution::AllFalse(len) => Ok(Mask::new_false(len)),
27 MaskExecution::Lazy(f) => f(),
28 }
29 }
30}
31
32impl dyn BindCtx + '_ {
33 pub fn bind_selection(
37 &mut self,
38 mask_len: usize,
39 mask: Option<&ArrayRef>,
40 ) -> VortexResult<MaskExecution> {
41 match mask {
42 Some(mask) => {
43 assert_eq!(mask.len(), mask_len);
44 self.bind_mask(mask)
45 }
46 None => Ok(MaskExecution::AllTrue(mask_len)),
47 }
48 }
49
50 pub fn bind_mask(&mut self, mask: &ArrayRef) -> VortexResult<MaskExecution> {
55 if !matches!(mask.dtype(), DType::Bool(NonNullable)) {
56 vortex_bail!(
57 "Expected non-nullable boolean array for mask binding, got {}",
58 mask.dtype()
59 );
60 }
61
62 if let Some(scalar) = mask.as_constant() {
64 let constant = scalar
65 .as_bool()
66 .value()
67 .vortex_expect("checked non-nullable");
68 let len = mask.len();
69 if constant {
70 return Ok(MaskExecution::AllTrue(len));
71 } else {
72 return Ok(MaskExecution::AllFalse(len));
73 }
74 }
75
76 let execution = self.bind(mask, None)?;
81 Ok(MaskExecution::lazy(move || {
82 let mask = execution.execute()?.into_bool();
83 Ok(Mask::from(mask.bits().clone()))
84 }))
85 }
86}