vortex_array/expr/exprs/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Not;
6
7use vortex_dtype::DType;
8use vortex_dtype::Nullability;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_vector::BoolDatum;
14use vortex_vector::Datum;
15use vortex_vector::ScalarOps;
16use vortex_vector::VectorMutOps;
17use vortex_vector::VectorOps;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::expr::Arity;
22use crate::expr::ChildName;
23use crate::expr::EmptyOptions;
24use crate::expr::ExecutionArgs;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::VTable;
28use crate::expr::VTableExt;
29
/// Expression that nulls out entries of an input according to a boolean mask.
///
/// The expression takes two children, `input` and `mask`. Positions where the mask is `true`
/// keep the input value and positions where it is `false` become null, so the validity of the
/// result is the input's validity intersected with the mask.
pub struct Mask;
35
36impl VTable for Mask {
37    type Options = EmptyOptions;
38
39    fn id(&self) -> ExprId {
40        ExprId::from("vortex.mask")
41    }
42
43    fn arity(&self, _options: &Self::Options) -> Arity {
44        Arity::Exact(2)
45    }
46
47    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
48        match child_idx {
49            0 => ChildName::from("input"),
50            1 => ChildName::from("mask"),
51            _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
52        }
53    }
54
55    fn fmt_sql(
56        &self,
57        _options: &Self::Options,
58        expr: &Expression,
59        f: &mut Formatter<'_>,
60    ) -> std::fmt::Result {
61        write!(f, "mask(")?;
62        expr.child(0).fmt_sql(f)?;
63        write!(f, ", ")?;
64        expr.child(1).fmt_sql(f)?;
65        write!(f, ")")
66    }
67
68    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
69        vortex_ensure!(
70            arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
71            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
72            arg_dtypes[1]
73        );
74        Ok(arg_dtypes[0].as_nullable())
75    }
76
77    fn evaluate(
78        &self,
79        _options: &Self::Options,
80        expr: &Expression,
81        scope: &ArrayRef,
82    ) -> VortexResult<ArrayRef> {
83        let child = expr.child(0).evaluate(scope)?;
84
85        // Invert the validity mask - we want to set values to null where validity is false.
86        let inverted_mask = child.validity_mask().not();
87
88        crate::compute::mask(&child, &inverted_mask)
89    }
90
91    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
92        let [input, mask]: [Datum; _] = args
93            .datums
94            .try_into()
95            .map_err(|_| vortex_err!("Wrong arg count"))?;
96        let mask = mask.into_bool();
97
98        match (input, mask) {
99            (Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
100                let mut result = input;
101                result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
102                Ok(Datum::Scalar(result))
103            }
104            (Datum::Scalar(input), BoolDatum::Vector(mask)) => {
105                let mut result = input.repeat(args.row_count).freeze();
106                result.mask_validity(&vortex_mask::Mask::from(mask.into_bits()));
107                Ok(Datum::Vector(result))
108            }
109            (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
110                let mut result = input_array;
111                result.mask_validity(&vortex_mask::Mask::new(
112                    args.row_count,
113                    mask.value().vortex_expect("mask is non-nullable"),
114                ));
115                Ok(Datum::Vector(result))
116            }
117            (Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
118                let mut result = input_array;
119                result.mask_validity(&vortex_mask::Mask::from(mask.into_bits()));
120                Ok(Datum::Vector(result))
121            }
122        }
123    }
124}
125
126/// Creates a mask expression that applies the given boolean mask to the input array.
127pub fn mask(array: Expression, mask: Expression) -> Expression {
128    Mask.new_expr(EmptyOptions, [array, mask])
129}