vortex_array/scalar_fn/fns/mask/
mod.rs1mod kernel;
5use std::fmt::Formatter;
6
7pub use kernel::*;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_err;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::IntoArray;
17use crate::arrays::BoolArray;
18use crate::arrays::ConstantArray;
19use crate::arrays::ConstantVTable;
20use crate::arrays::mask_validity_canonical;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::expr::Expression;
25use crate::expr::and;
26use crate::expr::lit;
27use crate::scalar::Scalar;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::EmptyOptions;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::SimplifyCtx;
35use crate::scalar_fn::fns::literal::Literal;
36
/// The `vortex.mask` scalar function.
///
/// Takes two children, `input` and `mask`. The `mask` child must be a non-nullable boolean;
/// the result has the input's dtype made nullable, and is valid only where the input is
/// valid and the mask is `true`.
#[derive(Clone)]
pub struct Mask;
43
44impl ScalarFnVTable for Mask {
45 type Options = EmptyOptions;
46
47 fn id(&self) -> ScalarFnId {
48 ScalarFnId::from("vortex.mask")
49 }
50
51 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52 Ok(Some(vec![]))
53 }
54
55 fn deserialize(
56 &self,
57 _metadata: &[u8],
58 _session: &VortexSession,
59 ) -> VortexResult<Self::Options> {
60 Ok(EmptyOptions)
61 }
62
63 fn arity(&self, _options: &Self::Options) -> Arity {
64 Arity::Exact(2)
65 }
66
67 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
68 match child_idx {
69 0 => ChildName::from("input"),
70 1 => ChildName::from("mask"),
71 _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
72 }
73 }
74
75 fn fmt_sql(
76 &self,
77 _options: &Self::Options,
78 expr: &Expression,
79 f: &mut Formatter<'_>,
80 ) -> std::fmt::Result {
81 write!(f, "mask(")?;
82 expr.child(0).fmt_sql(f)?;
83 write!(f, ", ")?;
84 expr.child(1).fmt_sql(f)?;
85 write!(f, ")")
86 }
87
88 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
89 vortex_ensure!(
90 arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
91 "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
92 arg_dtypes[1]
93 );
94 Ok(arg_dtypes[0].as_nullable())
95 }
96
97 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
98 let [input, mask_array]: [ArrayRef; _] = args
99 .inputs
100 .try_into()
101 .map_err(|_| vortex_err!("Wrong arg count"))?;
102
103 if let Some(result) = execute_constant(&input, &mask_array)? {
104 return Ok(result);
105 }
106
107 execute_canonical(input, mask_array, args.ctx)
108 }
109
110 fn simplify(
111 &self,
112 _options: &Self::Options,
113 expr: &Expression,
114 ctx: &dyn SimplifyCtx,
115 ) -> VortexResult<Option<Expression>> {
116 let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
117 return Ok(None);
118 };
119
120 let mask_lit = mask_lit
121 .as_bool()
122 .value()
123 .vortex_expect("Mask must be non-nullable");
124
125 if mask_lit {
126 Ok(Some(expr.child(0).clone()))
128 } else {
129 let input_dtype = ctx.return_dtype(expr.child(0))?;
131 Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
132 }
133 }
134
135 fn validity(
136 &self,
137 _options: &Self::Options,
138 expression: &Expression,
139 ) -> VortexResult<Option<Expression>> {
140 Ok(Some(and(
141 expression.child(0).validity()?,
142 expression.child(1).clone(),
143 )))
144 }
145}
146
147fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
151 let len = input.len();
152
153 if let Some(constant_mask) = mask_array.as_opt::<ConstantVTable>() {
154 let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
155 return if mask_value {
156 input.cast(input.dtype().as_nullable()).map(Some)
157 } else {
158 Ok(Some(
159 ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
160 ))
161 };
162 }
163
164 if let Some(constant_input) = input.as_opt::<ConstantVTable>()
165 && constant_input.scalar().is_null()
166 {
167 return Ok(Some(
168 ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
169 ));
170 }
171
172 Ok(None)
173}
174
175fn execute_canonical(
177 input: ArrayRef,
178 mask_array: ArrayRef,
179 ctx: &mut crate::executor::ExecutionCtx,
180) -> VortexResult<ArrayRef> {
181 let mask_bool = mask_array.execute::<BoolArray>(ctx)?;
182 let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer());
183
184 let canonical = input.execute::<Canonical>(ctx)?;
185 Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array())
186}
187
#[cfg(test)]
mod test {
    use vortex_error::VortexExpect;

    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::mask;
    use crate::scalar::Scalar;

    /// A literal mask is folded away by the optimizer: `true` leaves the input
    /// expression, `false` leaves a null literal of the nullable input dtype.
    #[test]
    fn test_simplify() {
        let value = lit(42u32);

        // Literal `true` mask: the expression collapses to its input.
        let kept = mask(value.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &value);

        // Literal `false` mask: the expression collapses to a typed null literal.
        let nulled = mask(value, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        let typed_null = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
        assert_eq!(&nulled, &typed_null);
    }
}