vortex_array/scalar_fn/fns/mask/
mod.rs1mod kernel;
5
6pub use kernel::*;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_ensure;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::Constant;
18use crate::arrays::ConstantArray;
19use crate::arrays::masked::mask_validity_canonical;
20use crate::builtins::ArrayBuiltins;
21use crate::child_to_validity;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::expr::Expression;
25use crate::expr::and;
26use crate::expr::lit;
27use crate::scalar::Scalar;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::EmptyOptions;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::SimplifyCtx;
35use crate::scalar_fn::fns::literal::Literal;
36
/// Scalar function registered as `vortex.mask`.
///
/// Takes two children, `input` and `mask`. The result has the input's dtype made nullable, and
/// its validity is the input's validity AND-ed with the boolean mask, so positions where the
/// mask is `false` come out null.
#[derive(Clone)]
pub struct Mask;
43
44impl ScalarFnVTable for Mask {
45 type Options = EmptyOptions;
46
47 fn id(&self) -> ScalarFnId {
48 static ID: CachedId = CachedId::new("vortex.mask");
49 *ID
50 }
51
52 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 Ok(Some(vec![]))
54 }
55
56 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 Ok(EmptyOptions)
62 }
63
64 fn arity(&self, _options: &Self::Options) -> Arity {
65 Arity::Exact(2)
66 }
67
68 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
69 match child_idx {
70 0 => ChildName::from("input"),
71 1 => ChildName::from("mask"),
72 _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
73 }
74 }
75
76 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
77 vortex_ensure!(
78 arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
79 "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
80 arg_dtypes[1]
81 );
82 Ok(arg_dtypes[0].as_nullable())
83 }
84
85 fn execute(
86 &self,
87 _options: &Self::Options,
88 args: &dyn ExecutionArgs,
89 ctx: &mut ExecutionCtx,
90 ) -> VortexResult<ArrayRef> {
91 let input = args.get(0)?;
92 let mask_array = args.get(1)?;
93
94 if let Some(result) = execute_constant(&input, &mask_array)? {
95 return Ok(result);
96 }
97
98 execute_canonical(input, mask_array, ctx)
99 }
100
101 fn simplify(
102 &self,
103 _options: &Self::Options,
104 expr: &Expression,
105 ctx: &dyn SimplifyCtx,
106 ) -> VortexResult<Option<Expression>> {
107 let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
108 return Ok(None);
109 };
110
111 let mask_lit = mask_lit
112 .as_bool()
113 .value()
114 .vortex_expect("Mask must be non-nullable");
115
116 if mask_lit {
117 Ok(Some(expr.child(0).clone()))
119 } else {
120 let input_dtype = ctx.return_dtype(expr.child(0))?;
122 Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
123 }
124 }
125
126 fn validity(
127 &self,
128 _options: &Self::Options,
129 expression: &Expression,
130 ) -> VortexResult<Option<Expression>> {
131 Ok(Some(and(
132 expression.child(0).validity()?,
133 expression.child(1).clone(),
134 )))
135 }
136}
137
138fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
142 let len = input.len();
143
144 if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
145 let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
146 return if mask_value {
147 input.cast(input.dtype().as_nullable()).map(Some)
148 } else {
149 Ok(Some(
150 ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
151 ))
152 };
153 }
154
155 if let Some(constant_input) = input.as_opt::<Constant>()
156 && constant_input.scalar().is_null()
157 {
158 return Ok(Some(
159 ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
160 ));
161 }
162
163 Ok(None)
164}
165
166fn execute_canonical(
168 input: ArrayRef,
169 mask_array: ArrayRef,
170 ctx: &mut ExecutionCtx,
171) -> VortexResult<ArrayRef> {
172 let validity = child_to_validity(Some(&mask_array), Nullability::Nullable);
173 let canonical = input.execute::<Canonical>(ctx)?;
174 Ok(mask_validity_canonical(canonical, validity, ctx)?.into_array())
175}
176
#[cfg(test)]
mod test {
    use vortex_error::VortexExpect;

    use crate::dtype::DType;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType;
    use crate::expr::lit;
    use crate::expr::mask;
    use crate::scalar::Scalar;

    /// A literal mask folds away: `true` leaves the input, `false` yields a typed null literal.
    #[test]
    fn test_simplify() {
        let input = lit(42u32);

        // Literal `true` mask: the expression collapses to its input.
        let kept = mask(input.clone(), lit(true))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        assert_eq!(&kept, &input);

        // Literal `false` mask: the expression collapses to a nullable-u32 null literal.
        let nulled = mask(input, lit(false))
            .optimize(&DType::Null)
            .vortex_expect("Simplification");
        let expected = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
        assert_eq!(&nulled, &expected);
    }
}