vortex_array/execution/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_dtype::Nullability::NonNullable;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::Mask;
10
11use crate::ArrayRef;
12use crate::execution::BindCtx;
13
14pub enum MaskExecution {
15    AllTrue(usize),
16    AllFalse(usize),
17    Lazy(Box<dyn FnOnce() -> VortexResult<Mask> + Send + 'static>),
18}
19
20impl MaskExecution {
21    pub fn lazy<F: FnOnce() -> VortexResult<Mask> + Send + 'static>(f: F) -> MaskExecution {
22        MaskExecution::Lazy(Box::new(f))
23    }
24
25    pub fn execute(self) -> VortexResult<Mask> {
26        match self {
27            MaskExecution::AllTrue(len) => Ok(Mask::new_true(len)),
28            MaskExecution::AllFalse(len) => Ok(Mask::new_false(len)),
29            MaskExecution::Lazy(f) => f(),
30        }
31    }
32}
33
34impl dyn BindCtx + '_ {
35    /// Bind an optional selection mask into a `MaskExecution`.
36    ///
37    /// The caller must provide a mask length to handle the case where no mask is provided.
38    pub fn bind_selection(
39        &mut self,
40        mask_len: usize,
41        mask: Option<&ArrayRef>,
42    ) -> VortexResult<MaskExecution> {
43        match mask {
44            Some(mask) => {
45                assert_eq!(mask.len(), mask_len);
46                self.bind_mask(mask)
47            }
48            None => Ok(MaskExecution::AllTrue(mask_len)),
49        }
50    }
51
52    /// Bind a non-nullable boolean array into a `MaskExecution`.
53    ///
54    /// This binding will optimize for constant arrays or other array types that can be more
55    /// efficiently converted into a `Mask`.
56    pub fn bind_mask(&mut self, mask: &ArrayRef) -> VortexResult<MaskExecution> {
57        if !matches!(mask.dtype(), DType::Bool(NonNullable)) {
58            vortex_bail!(
59                "Expected non-nullable boolean array for mask binding, got {}",
60                mask.dtype()
61            );
62        }
63
64        // Check for a constant mask
65        if let Some(scalar) = mask.as_constant() {
66            let constant = scalar
67                .as_bool()
68                .value()
69                .vortex_expect("checked non-nullable");
70            let len = mask.len();
71            if constant {
72                return Ok(MaskExecution::AllTrue(len));
73            } else {
74                return Ok(MaskExecution::AllFalse(len));
75            }
76        }
77
78        // TODO(ngates): we may want to support creating masks from iterator of slices, in which
79        //  case we could check for run-end encoding here?
80
81        // If none of the above patterns match, we fall back to canonicalizing.
82        let execution = self.bind(mask, None)?;
83        Ok(MaskExecution::lazy(move || {
84            let mask = execution.execute()?.into_bool();
85            Ok(Mask::from(mask.bits().clone()))
86        }))
87    }
88}