Skip to main content

vortex_array/scalar_fn/fns/mask/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5use std::fmt::Formatter;
6
7pub use kernel::*;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_session::VortexSession;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::BoolArray;
18use crate::arrays::Constant;
19use crate::arrays::ConstantArray;
20use crate::arrays::bool::BoolArrayExt;
21use crate::arrays::masked::mask_validity_canonical;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::dtype::Nullability;
25use crate::expr::Expression;
26use crate::expr::and;
27use crate::expr::lit;
28use crate::scalar::Scalar;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::EmptyOptions;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::SimplifyCtx;
36use crate::scalar_fn::fns::literal::Literal;
37
/// Scalar function that nulls out input positions wherever a boolean mask is `false`.
///
/// Positions where the mask is `true` keep the input value (including its existing validity);
/// positions where the mask is `false` become null. Put differently, the output validity is the
/// logical AND of the input's validity and the mask.
#[derive(Clone)]
pub struct Mask;
44
45impl ScalarFnVTable for Mask {
46    type Options = EmptyOptions;
47
48    fn id(&self) -> ScalarFnId {
49        ScalarFnId::from("vortex.mask")
50    }
51
52    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(vec![]))
54    }
55
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        Ok(EmptyOptions)
62    }
63
64    fn arity(&self, _options: &Self::Options) -> Arity {
65        Arity::Exact(2)
66    }
67
68    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
69        match child_idx {
70            0 => ChildName::from("input"),
71            1 => ChildName::from("mask"),
72            _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
73        }
74    }
75
76    fn fmt_sql(
77        &self,
78        _options: &Self::Options,
79        expr: &Expression,
80        f: &mut Formatter<'_>,
81    ) -> std::fmt::Result {
82        write!(f, "mask(")?;
83        expr.child(0).fmt_sql(f)?;
84        write!(f, ", ")?;
85        expr.child(1).fmt_sql(f)?;
86        write!(f, ")")
87    }
88
89    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
90        vortex_ensure!(
91            arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
92            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
93            arg_dtypes[1]
94        );
95        Ok(arg_dtypes[0].as_nullable())
96    }
97
98    fn execute(
99        &self,
100        _options: &Self::Options,
101        args: &dyn ExecutionArgs,
102        ctx: &mut ExecutionCtx,
103    ) -> VortexResult<ArrayRef> {
104        let input = args.get(0)?;
105        let mask_array = args.get(1)?;
106
107        if let Some(result) = execute_constant(&input, &mask_array)? {
108            return Ok(result);
109        }
110
111        execute_canonical(input, mask_array, ctx)
112    }
113
114    fn simplify(
115        &self,
116        _options: &Self::Options,
117        expr: &Expression,
118        ctx: &dyn SimplifyCtx,
119    ) -> VortexResult<Option<Expression>> {
120        let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
121            return Ok(None);
122        };
123
124        let mask_lit = mask_lit
125            .as_bool()
126            .value()
127            .vortex_expect("Mask must be non-nullable");
128
129        if mask_lit {
130            // Mask is all true, so the output is just the input.
131            Ok(Some(expr.child(0).clone()))
132        } else {
133            // Mask is all false, so the output is all nulls.
134            let input_dtype = ctx.return_dtype(expr.child(0))?;
135            Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
136        }
137    }
138
139    fn validity(
140        &self,
141        _options: &Self::Options,
142        expression: &Expression,
143    ) -> VortexResult<Option<Expression>> {
144        Ok(Some(and(
145            expression.child(0).validity()?,
146            expression.child(1).clone(),
147        )))
148    }
149}
150
151/// Try to handle masking when at least one of the input or mask is a constant array.
152///
153/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not.
154fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
155    let len = input.len();
156
157    if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
158        let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
159        return if mask_value {
160            input.cast(input.dtype().as_nullable()).map(Some)
161        } else {
162            Ok(Some(
163                ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
164            ))
165        };
166    }
167
168    if let Some(constant_input) = input.as_opt::<Constant>()
169        && constant_input.scalar().is_null()
170    {
171        return Ok(Some(
172            ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
173        ));
174    }
175
176    Ok(None)
177}
178
179/// Execute the mask by materializing both inputs to their canonical forms.
180fn execute_canonical(
181    input: ArrayRef,
182    mask_array: ArrayRef,
183    ctx: &mut ExecutionCtx,
184) -> VortexResult<ArrayRef> {
185    let mask_bool = mask_array.execute::<BoolArray>(ctx)?;
186    let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer());
187
188    let canonical = input.execute::<Canonical>(ctx)?;
189    Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array())
190}
191
#[cfg(test)]
mod test {
    use vortex_error::VortexExpect;

    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::mask;
    use crate::scalar::Scalar;

    #[test]
    fn test_simplify() {
        let input = lit(42u32);

        // A literal `true` mask keeps every row, so optimization reduces to the input itself.
        let kept = mask(input.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &input);

        // A literal `false` mask drops every row, leaving a null literal of the nullable input type.
        let nulled = mask(input, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        let expected = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
        assert_eq!(&nulled, &expected);
    }
}