vortex_array/execution/
mask.rs1use vortex_dtype::DType;
5use vortex_dtype::Nullability::NonNullable;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::Mask;
10
11use crate::ArrayRef;
12use crate::execution::BindCtx;
13
14pub enum MaskExecution {
15 AllTrue(usize),
16 AllFalse(usize),
17 Lazy(Box<dyn FnOnce() -> VortexResult<Mask> + Send + 'static>),
18}
19
20impl MaskExecution {
21 pub fn lazy<F: FnOnce() -> VortexResult<Mask> + Send + 'static>(f: F) -> MaskExecution {
22 MaskExecution::Lazy(Box::new(f))
23 }
24
25 pub fn execute(self) -> VortexResult<Mask> {
26 match self {
27 MaskExecution::AllTrue(len) => Ok(Mask::new_true(len)),
28 MaskExecution::AllFalse(len) => Ok(Mask::new_false(len)),
29 MaskExecution::Lazy(f) => f(),
30 }
31 }
32}
33
34impl dyn BindCtx + '_ {
35 pub fn bind_selection(
39 &mut self,
40 mask_len: usize,
41 mask: Option<&ArrayRef>,
42 ) -> VortexResult<MaskExecution> {
43 match mask {
44 Some(mask) => {
45 assert_eq!(mask.len(), mask_len);
46 self.bind_mask(mask)
47 }
48 None => Ok(MaskExecution::AllTrue(mask_len)),
49 }
50 }
51
52 pub fn bind_mask(&mut self, mask: &ArrayRef) -> VortexResult<MaskExecution> {
57 if !matches!(mask.dtype(), DType::Bool(NonNullable)) {
58 vortex_bail!(
59 "Expected non-nullable boolean array for mask binding, got {}",
60 mask.dtype()
61 );
62 }
63
64 if let Some(scalar) = mask.as_constant() {
66 let constant = scalar
67 .as_bool()
68 .value()
69 .vortex_expect("checked non-nullable");
70 let len = mask.len();
71 if constant {
72 return Ok(MaskExecution::AllTrue(len));
73 } else {
74 return Ok(MaskExecution::AllFalse(len));
75 }
76 }
77
78 let execution = self.bind(mask, None)?;
83 Ok(MaskExecution::lazy(move || {
84 let mask = execution.execute()?.into_bool();
85 Ok(Mask::from(mask.bits().clone()))
86 }))
87 }
88}