Skip to main content

vortex_array/scalar_fn/fns/mask/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6pub use kernel::*;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::Constant;
17use crate::arrays::ConstantArray;
18use crate::arrays::masked::mask_validity_canonical;
19use crate::builtins::ArrayBuiltins;
20use crate::child_to_validity;
21use crate::dtype::DType;
22use crate::dtype::Nullability;
23use crate::expr::Expression;
24use crate::expr::and;
25use crate::expr::lit;
26use crate::scalar::Scalar;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::literal::Literal;
35
/// Scalar function that nulls out elements of an input according to a boolean mask.
///
/// Positions where the mask is `true` keep the input value; positions where the mask is `false`
/// become null. Put differently, the output validity is the intersection of the input's validity
/// with the mask.
#[derive(Clone)]
pub struct Mask;
42
43impl ScalarFnVTable for Mask {
44    type Options = EmptyOptions;
45
46    fn id(&self) -> ScalarFnId {
47        ScalarFnId::new("vortex.mask")
48    }
49
50    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        Ok(Some(vec![]))
52    }
53
54    fn deserialize(
55        &self,
56        _metadata: &[u8],
57        _session: &VortexSession,
58    ) -> VortexResult<Self::Options> {
59        Ok(EmptyOptions)
60    }
61
62    fn arity(&self, _options: &Self::Options) -> Arity {
63        Arity::Exact(2)
64    }
65
66    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
67        match child_idx {
68            0 => ChildName::from("input"),
69            1 => ChildName::from("mask"),
70            _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
71        }
72    }
73
74    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
75        vortex_ensure!(
76            arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
77            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
78            arg_dtypes[1]
79        );
80        Ok(arg_dtypes[0].as_nullable())
81    }
82
83    fn execute(
84        &self,
85        _options: &Self::Options,
86        args: &dyn ExecutionArgs,
87        ctx: &mut ExecutionCtx,
88    ) -> VortexResult<ArrayRef> {
89        let input = args.get(0)?;
90        let mask_array = args.get(1)?;
91
92        if let Some(result) = execute_constant(&input, &mask_array)? {
93            return Ok(result);
94        }
95
96        execute_canonical(input, mask_array, ctx)
97    }
98
99    fn simplify(
100        &self,
101        _options: &Self::Options,
102        expr: &Expression,
103        ctx: &dyn SimplifyCtx,
104    ) -> VortexResult<Option<Expression>> {
105        let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
106            return Ok(None);
107        };
108
109        let mask_lit = mask_lit
110            .as_bool()
111            .value()
112            .vortex_expect("Mask must be non-nullable");
113
114        if mask_lit {
115            // Mask is all true, so the output is just the input.
116            Ok(Some(expr.child(0).clone()))
117        } else {
118            // Mask is all false, so the output is all nulls.
119            let input_dtype = ctx.return_dtype(expr.child(0))?;
120            Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
121        }
122    }
123
124    fn validity(
125        &self,
126        _options: &Self::Options,
127        expression: &Expression,
128    ) -> VortexResult<Option<Expression>> {
129        Ok(Some(and(
130            expression.child(0).validity()?,
131            expression.child(1).clone(),
132        )))
133    }
134}
135
136/// Try to handle masking when at least one of the input or mask is a constant array.
137///
138/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not.
139fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
140    let len = input.len();
141
142    if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
143        let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
144        return if mask_value {
145            input.cast(input.dtype().as_nullable()).map(Some)
146        } else {
147            Ok(Some(
148                ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
149            ))
150        };
151    }
152
153    if let Some(constant_input) = input.as_opt::<Constant>()
154        && constant_input.scalar().is_null()
155    {
156        return Ok(Some(
157            ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
158        ));
159    }
160
161    Ok(None)
162}
163
164/// Execute the mask by materializing both inputs to their canonical forms.
165fn execute_canonical(
166    input: ArrayRef,
167    mask_array: ArrayRef,
168    ctx: &mut ExecutionCtx,
169) -> VortexResult<ArrayRef> {
170    let validity = child_to_validity(Some(&mask_array), Nullability::Nullable);
171    let canonical = input.execute::<Canonical>(ctx)?;
172    Ok(mask_validity_canonical(canonical, validity, ctx)?.into_array())
173}
174
#[cfg(test)]
mod test {
    use vortex_error::VortexExpect;

    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::mask;
    use crate::scalar::Scalar;

    /// Checks that a literal mask is folded away by the optimizer.
    #[test]
    fn test_simplify() {
        let input = lit(42u32);

        // A literal `true` mask simplifies to the input expression itself.
        let kept = mask(input.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &input);

        // A literal `false` mask simplifies to a null literal of the nullable input dtype.
        let nulled = mask(input, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        let expected = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
        assert_eq!(&nulled, &expected);
    }
}