Skip to main content

vortex_array/scalar_fn/fns/mask/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6pub use kernel::*;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::Constant;
18use crate::arrays::ConstantArray;
19use crate::arrays::masked::mask_validity_canonical;
20use crate::builtins::ArrayBuiltins;
21use crate::child_to_validity;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::expr::Expression;
25use crate::expr::and;
26use crate::expr::lit;
27use crate::scalar::Scalar;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::EmptyOptions;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::SimplifyCtx;
35use crate::scalar_fn::fns::literal::Literal;
36
/// Scalar function that nulls out positions of an input according to a boolean mask.
///
/// Positions where the mask is `true` keep the input value; positions where the mask is
/// `false` become null. Equivalently, the output validity is the intersection of the input's
/// validity and the mask.
///
/// The function takes two children, named `input` and `mask`; the mask must be a
/// non-nullable boolean, and the result always has the nullable form of the input's dtype.
#[derive(Clone)]
pub struct Mask;
43
44impl ScalarFnVTable for Mask {
45    type Options = EmptyOptions;
46
47    fn id(&self) -> ScalarFnId {
48        static ID: CachedId = CachedId::new("vortex.mask");
49        *ID
50    }
51
52    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(vec![]))
54    }
55
56    fn deserialize(
57        &self,
58        _metadata: &[u8],
59        _session: &VortexSession,
60    ) -> VortexResult<Self::Options> {
61        Ok(EmptyOptions)
62    }
63
64    fn arity(&self, _options: &Self::Options) -> Arity {
65        Arity::Exact(2)
66    }
67
68    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
69        match child_idx {
70            0 => ChildName::from("input"),
71            1 => ChildName::from("mask"),
72            _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
73        }
74    }
75
76    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
77        vortex_ensure!(
78            arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
79            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
80            arg_dtypes[1]
81        );
82        Ok(arg_dtypes[0].as_nullable())
83    }
84
85    fn execute(
86        &self,
87        _options: &Self::Options,
88        args: &dyn ExecutionArgs,
89        ctx: &mut ExecutionCtx,
90    ) -> VortexResult<ArrayRef> {
91        let input = args.get(0)?;
92        let mask_array = args.get(1)?;
93
94        if let Some(result) = execute_constant(&input, &mask_array)? {
95            return Ok(result);
96        }
97
98        execute_canonical(input, mask_array, ctx)
99    }
100
101    fn simplify(
102        &self,
103        _options: &Self::Options,
104        expr: &Expression,
105        ctx: &dyn SimplifyCtx,
106    ) -> VortexResult<Option<Expression>> {
107        let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
108            return Ok(None);
109        };
110
111        let mask_lit = mask_lit
112            .as_bool()
113            .value()
114            .vortex_expect("Mask must be non-nullable");
115
116        if mask_lit {
117            // Mask is all true, so the output is just the input.
118            Ok(Some(expr.child(0).clone()))
119        } else {
120            // Mask is all false, so the output is all nulls.
121            let input_dtype = ctx.return_dtype(expr.child(0))?;
122            Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
123        }
124    }
125
126    fn validity(
127        &self,
128        _options: &Self::Options,
129        expression: &Expression,
130    ) -> VortexResult<Option<Expression>> {
131        Ok(Some(and(
132            expression.child(0).validity()?,
133            expression.child(1).clone(),
134        )))
135    }
136}
137
138/// Try to handle masking when at least one of the input or mask is a constant array.
139///
140/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not.
141fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
142    let len = input.len();
143
144    if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
145        let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
146        return if mask_value {
147            input.cast(input.dtype().as_nullable()).map(Some)
148        } else {
149            Ok(Some(
150                ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
151            ))
152        };
153    }
154
155    if let Some(constant_input) = input.as_opt::<Constant>()
156        && constant_input.scalar().is_null()
157    {
158        return Ok(Some(
159            ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
160        ));
161    }
162
163    Ok(None)
164}
165
166/// Execute the mask by materializing both inputs to their canonical forms.
167fn execute_canonical(
168    input: ArrayRef,
169    mask_array: ArrayRef,
170    ctx: &mut ExecutionCtx,
171) -> VortexResult<ArrayRef> {
172    let validity = child_to_validity(Some(&mask_array), Nullability::Nullable);
173    let canonical = input.execute::<Canonical>(ctx)?;
174    Ok(mask_validity_canonical(canonical, validity, ctx)?.into_array())
175}
176
#[cfg(test)]
mod test {
    use vortex_error::VortexExpect;

    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::mask;
    use crate::scalar::Scalar;

    /// A literal mask must be folded away by the optimizer: `true` leaves the input
    /// expression, `false` leaves a null literal of the nullable input dtype.
    #[test]
    fn test_simplify() {
        let input = lit(42u32);

        // All-true literal mask: the expression collapses to its input.
        let kept = mask(input.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &input);

        // All-false literal mask: the expression collapses to a nullable u32 null literal.
        let nulled = mask(input, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        let all_null = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
        assert_eq!(&nulled, &all_null);
    }
}