vortex_array/scalar_fns/mask/
mod.rs1use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_ensure;
9use vortex_mask::Mask;
10use vortex_vector::BoolDatum;
11use vortex_vector::Datum;
12use vortex_vector::ScalarOps;
13use vortex_vector::VectorMutOps;
14use vortex_vector::VectorOps;
15
16use crate::expr::functions::ArgName;
17use crate::expr::functions::Arity;
18use crate::expr::functions::EmptyOptions;
19use crate::expr::functions::ExecutionArgs;
20use crate::expr::functions::FunctionId;
21use crate::expr::functions::VTable;
22
23pub struct MaskFn;
27impl VTable for MaskFn {
28 type Options = EmptyOptions;
29
30 fn id(&self) -> FunctionId {
31 FunctionId::from("vortex.mask")
32 }
33
34 fn arity(&self, _options: &Self::Options) -> Arity {
35 Arity::Exact(2)
36 }
37
38 fn arg_name(&self, _options: &Self::Options, arg_idx: usize) -> ArgName {
39 match arg_idx {
40 0 => ArgName::from("input"),
41 1 => ArgName::from("mask"),
42 _ => unreachable!("unknown"),
43 }
44 }
45
46 fn return_dtype(&self, _options: &Self::Options, arg_types: &[DType]) -> VortexResult<DType> {
47 vortex_ensure!(
48 arg_types[1] == DType::Bool(Nullability::NonNullable),
49 "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
50 arg_types[1]
51 );
52 Ok(arg_types[0].as_nullable())
53 }
54
55 fn execute(&self, _options: &Self::Options, args: &ExecutionArgs) -> VortexResult<Datum> {
56 let input = args.input_datums(0).clone();
57 let mask = args.input_datums(1).clone().into_bool();
58 match (input, mask) {
59 (Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
60 let mut result = input;
61 result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
62 Ok(Datum::Scalar(result))
63 }
64 (Datum::Scalar(input), BoolDatum::Vector(mask)) => {
65 let mut result = input.repeat(args.row_count()).freeze();
66 result.mask_validity(&Mask::from(mask.into_bits()));
67 Ok(Datum::Vector(result))
68 }
69 (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
70 let mut result = input_array;
71 result.mask_validity(&Mask::new(
72 args.row_count(),
73 mask.value().vortex_expect("mask is non-nullable"),
74 ));
75 Ok(Datum::Vector(result))
76 }
77 (Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
78 let mut result = input_array;
79 result.mask_validity(&Mask::from(mask.into_bits()));
80 Ok(Datum::Vector(result))
81 }
82 }
83 }
84}