vortex_array/execution/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_dtype::Nullability::NonNullable;
6use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7use vortex_mask::Mask;
8
9use crate::ArrayRef;
10use crate::execution::BindCtx;
11
12pub enum MaskExecution {
13    AllTrue(usize),
14    AllFalse(usize),
15    Lazy(Box<dyn FnOnce() -> VortexResult<Mask> + Send + 'static>),
16}
17
18impl MaskExecution {
19    pub fn lazy<F: FnOnce() -> VortexResult<Mask> + Send + 'static>(f: F) -> MaskExecution {
20        MaskExecution::Lazy(Box::new(f))
21    }
22
23    pub fn execute(self) -> VortexResult<Mask> {
24        match self {
25            MaskExecution::AllTrue(len) => Ok(Mask::new_true(len)),
26            MaskExecution::AllFalse(len) => Ok(Mask::new_false(len)),
27            MaskExecution::Lazy(f) => f(),
28        }
29    }
30}
31
32impl dyn BindCtx + '_ {
33    /// Bind an optional selection mask into a `MaskExecution`.
34    ///
35    /// The caller must provide a mask length to handle the case where no mask is provided.
36    pub fn bind_selection(
37        &mut self,
38        mask_len: usize,
39        mask: Option<&ArrayRef>,
40    ) -> VortexResult<MaskExecution> {
41        match mask {
42            Some(mask) => {
43                assert_eq!(mask.len(), mask_len);
44                self.bind_mask(mask)
45            }
46            None => Ok(MaskExecution::AllTrue(mask_len)),
47        }
48    }
49
50    /// Bind a non-nullable boolean array into a `MaskExecution`.
51    ///
52    /// This binding will optimize for constant arrays or other array types that can be more
53    /// efficiently converted into a `Mask`.
54    pub fn bind_mask(&mut self, mask: &ArrayRef) -> VortexResult<MaskExecution> {
55        if !matches!(mask.dtype(), DType::Bool(NonNullable)) {
56            vortex_bail!(
57                "Expected non-nullable boolean array for mask binding, got {}",
58                mask.dtype()
59            );
60        }
61
62        // Check for a constant mask
63        if let Some(scalar) = mask.as_constant() {
64            let constant = scalar
65                .as_bool()
66                .value()
67                .vortex_expect("checked non-nullable");
68            let len = mask.len();
69            if constant {
70                return Ok(MaskExecution::AllTrue(len));
71            } else {
72                return Ok(MaskExecution::AllFalse(len));
73            }
74        }
75
76        // TODO(ngates): we may want to support creating masks from iterator of slices, in which
77        //  case we could check for run-end encoding here?
78
79        // If none of the above patterns match, we fall back to canonicalizing.
80        let execution = self.bind(mask, None)?;
81        Ok(MaskExecution::lazy(move || {
82            let mask = execution.execute()?.into_bool();
83            Ok(Mask::from(mask.bits().clone()))
84        }))
85    }
86}