Skip to main content

vortex_array/expr/exprs/mask/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5use std::fmt::Formatter;
6
7pub use kernel::*;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_err;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::IntoArray;
19use crate::arrays::BoolArray;
20use crate::arrays::ConstantArray;
21use crate::arrays::ConstantVTable;
22use crate::arrays::mask_validity_canonical;
23use crate::builtins::ArrayBuiltins;
24use crate::expr::Arity;
25use crate::expr::ChildName;
26use crate::expr::EmptyOptions;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Expression;
30use crate::expr::Literal;
31use crate::expr::SimplifyCtx;
32use crate::expr::VTable;
33use crate::expr::VTableExt;
34use crate::expr::and;
35use crate::expr::lit;
36use crate::scalar::Scalar;
37
38/// An expression that masks an input based on a boolean mask.
39///
40/// Where the mask is true, the input value is retained; where the mask is false, the output is
41/// null. In other words, this performs an intersection of the input's validity with the mask.
42pub struct Mask;
43
44impl VTable for Mask {
45    type Options = EmptyOptions;
46
47    fn id(&self) -> ExprId {
48        ExprId::from("vortex.mask")
49    }
50
51    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52        Ok(Some(vec![]))
53    }
54
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        _session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        Ok(EmptyOptions)
61    }
62
63    fn arity(&self, _options: &Self::Options) -> Arity {
64        Arity::Exact(2)
65    }
66
67    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
68        match child_idx {
69            0 => ChildName::from("input"),
70            1 => ChildName::from("mask"),
71            _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
72        }
73    }
74
75    fn fmt_sql(
76        &self,
77        _options: &Self::Options,
78        expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> std::fmt::Result {
81        write!(f, "mask(")?;
82        expr.child(0).fmt_sql(f)?;
83        write!(f, ", ")?;
84        expr.child(1).fmt_sql(f)?;
85        write!(f, ")")
86    }
87
88    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
89        vortex_ensure!(
90            arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
91            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
92            arg_dtypes[1]
93        );
94        Ok(arg_dtypes[0].as_nullable())
95    }
96
97    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
98        let [input, mask_array]: [ArrayRef; _] = args
99            .inputs
100            .try_into()
101            .map_err(|_| vortex_err!("Wrong arg count"))?;
102
103        if let Some(result) = execute_constant(&input, &mask_array)? {
104            return Ok(result);
105        }
106
107        execute_canonical(input, mask_array, args.ctx)
108    }
109
110    fn simplify(
111        &self,
112        _options: &Self::Options,
113        expr: &Expression,
114        ctx: &dyn SimplifyCtx,
115    ) -> VortexResult<Option<Expression>> {
116        let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
117            return Ok(None);
118        };
119
120        let mask_lit = mask_lit
121            .as_bool()
122            .value()
123            .vortex_expect("Mask must be non-nullable");
124
125        if mask_lit {
126            // Mask is all true, so the output is just the input.
127            Ok(Some(expr.child(0).clone()))
128        } else {
129            // Mask is all false, so the output is all nulls.
130            let input_dtype = ctx.return_dtype(expr.child(0))?;
131            Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
132        }
133    }
134
135    fn validity(
136        &self,
137        _options: &Self::Options,
138        expression: &Expression,
139    ) -> VortexResult<Option<Expression>> {
140        Ok(Some(and(
141            expression.child(0).validity()?,
142            expression.child(1).clone(),
143        )))
144    }
145}
146
147/// Try to handle masking when at least one of the input or mask is a constant array.
148///
149/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not.
150fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
151    let len = input.len();
152
153    if let Some(constant_mask) = mask_array.as_opt::<ConstantVTable>() {
154        let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
155        return if mask_value {
156            input.cast(input.dtype().as_nullable()).map(Some)
157        } else {
158            Ok(Some(
159                ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
160            ))
161        };
162    }
163
164    if let Some(constant_input) = input.as_opt::<ConstantVTable>()
165        && constant_input.scalar().is_null()
166    {
167        return Ok(Some(
168            ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
169        ));
170    }
171
172    Ok(None)
173}
174
175/// Execute the mask by materializing both inputs to their canonical forms.
176fn execute_canonical(
177    input: ArrayRef,
178    mask_array: ArrayRef,
179    ctx: &mut crate::executor::ExecutionCtx,
180) -> VortexResult<ArrayRef> {
181    let mask_bool = mask_array.execute::<BoolArray>(ctx)?;
182    let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer());
183
184    let canonical = input.execute::<Canonical>(ctx)?;
185    Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array())
186}
187
188/// Creates a mask expression that applies the given boolean mask to the input array.
189pub fn mask(array: Expression, mask: Expression) -> Expression {
190    Mask.new_expr(EmptyOptions, [array, mask])
191}
192
#[cfg(test)]
mod test {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::mask::mask;
    use crate::scalar::Scalar;

    /// A literal `true` mask simplifies to the input itself. A literal `false` mask simplifies
    /// to a null literal of the input's nullable dtype.
    #[test]
    fn test_simplify() {
        let value = lit(42u32);

        let kept = mask(value.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &value);

        let nulled = mask(value, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(
            &nulled,
            &lit(Scalar::null(DType::Primitive(PType::U32, Nullable)))
        );
    }
}