// vortex_array/expr/exprs/mask.rs

use std::fmt::Formatter;
use std::ops::Not;

use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_vector::BoolDatum;
use vortex_vector::Datum;
use vortex_vector::ScalarOps;
use vortex_vector::VectorMutOps;
use vortex_vector::VectorOps;

use crate::Array;
use crate::ArrayRef;
use crate::ToCanonical;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::EmptyOptions;
use crate::expr::ExecutionArgs;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::VTable;
use crate::expr::VTableExt;
/// Expression vtable for `mask(input, mask)`.
///
/// The expression takes two children: an `input` of any dtype and a `mask` that must be a
/// non-nullable boolean. The result has the input's dtype made nullable, and rows where the
/// mask is `false` are masked out (made null).
#[derive(Debug, Clone, Copy)]
pub struct Mask;
35
36impl VTable for Mask {
37 type Options = EmptyOptions;
38
39 fn id(&self) -> ExprId {
40 ExprId::from("vortex.mask")
41 }
42
43 fn arity(&self, _options: &Self::Options) -> Arity {
44 Arity::Exact(2)
45 }
46
47 fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
48 match child_idx {
49 0 => ChildName::from("input"),
50 1 => ChildName::from("mask"),
51 _ => unreachable!("Invalid child index {} for Mask expression", child_idx),
52 }
53 }
54
55 fn fmt_sql(
56 &self,
57 _options: &Self::Options,
58 expr: &Expression,
59 f: &mut Formatter<'_>,
60 ) -> std::fmt::Result {
61 write!(f, "mask(")?;
62 expr.child(0).fmt_sql(f)?;
63 write!(f, ", ")?;
64 expr.child(1).fmt_sql(f)?;
65 write!(f, ")")
66 }
67
68 fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
69 vortex_ensure!(
70 arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
71 "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
72 arg_dtypes[1]
73 );
74 Ok(arg_dtypes[0].as_nullable())
75 }
76
77 fn evaluate(
78 &self,
79 _options: &Self::Options,
80 expr: &Expression,
81 scope: &ArrayRef,
82 ) -> VortexResult<ArrayRef> {
83 let child = expr.child(0).evaluate(scope)?;
84
85 let inverted_mask = child.validity_mask().not();
87
88 crate::compute::mask(&child, &inverted_mask)
89 }
90
91 fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
92 let [input, mask]: [Datum; _] = args
93 .datums
94 .try_into()
95 .map_err(|_| vortex_err!("Wrong arg count"))?;
96 let mask = mask.into_bool();
97
98 match (input, mask) {
99 (Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
100 let mut result = input;
101 result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
102 Ok(Datum::Scalar(result))
103 }
104 (Datum::Scalar(input), BoolDatum::Vector(mask)) => {
105 let mut result = input.repeat(args.row_count).freeze();
106 result.mask_validity(&vortex_mask::Mask::from(mask.into_bits()));
107 Ok(Datum::Vector(result))
108 }
109 (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
110 let mut result = input_array;
111 result.mask_validity(&vortex_mask::Mask::new(
112 args.row_count,
113 mask.value().vortex_expect("mask is non-nullable"),
114 ));
115 Ok(Datum::Vector(result))
116 }
117 (Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
118 let mut result = input_array;
119 result.mask_validity(&vortex_mask::Mask::from(mask.into_bits()));
120 Ok(Datum::Vector(result))
121 }
122 }
123 }
124}
125
126pub fn mask(array: Expression, mask: Expression) -> Expression {
128 Mask.new_expr(EmptyOptions, [array, mask])
129}