vortex_array/scalar_fns/mask/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_ensure;
9use vortex_mask::Mask;
10use vortex_vector::BoolDatum;
11use vortex_vector::Datum;
12use vortex_vector::ScalarOps;
13use vortex_vector::VectorMutOps;
14use vortex_vector::VectorOps;
15
16use crate::expr::functions::ArgName;
17use crate::expr::functions::Arity;
18use crate::expr::functions::EmptyOptions;
19use crate::expr::functions::ExecutionArgs;
20use crate::expr::functions::FunctionId;
21use crate::expr::functions::VTable;
22
/// A function that intersects the validity of an array using another array as a mask.
///
/// Where the `mask` array is true, the corresponding value of the `input` array keeps
/// its existing validity; where the `mask` array is false, the corresponding value
/// becomes null. The `mask` argument must be a non-nullable boolean, and the result
/// type is always the nullable variant of the `input` type.
pub struct MaskFn;
27impl VTable for MaskFn {
28    type Options = EmptyOptions;
29
30    fn id(&self) -> FunctionId {
31        FunctionId::from("vortex.mask")
32    }
33
34    fn arity(&self, _options: &Self::Options) -> Arity {
35        Arity::Exact(2)
36    }
37
38    fn arg_name(&self, _options: &Self::Options, arg_idx: usize) -> ArgName {
39        match arg_idx {
40            0 => ArgName::from("input"),
41            1 => ArgName::from("mask"),
42            _ => unreachable!("unknown"),
43        }
44    }
45
46    fn return_dtype(&self, _options: &Self::Options, arg_types: &[DType]) -> VortexResult<DType> {
47        vortex_ensure!(
48            arg_types[1] == DType::Bool(Nullability::NonNullable),
49            "The mask argument to 'mask' must be a non-nullable boolean array, got {}",
50            arg_types[1]
51        );
52        Ok(arg_types[0].as_nullable())
53    }
54
55    fn execute(&self, _options: &Self::Options, args: &ExecutionArgs) -> VortexResult<Datum> {
56        let input = args.input_datums(0).clone();
57        let mask = args.input_datums(1).clone().into_bool();
58        match (input, mask) {
59            (Datum::Scalar(input), BoolDatum::Scalar(mask)) => {
60                let mut result = input;
61                result.mask_validity(mask.value().vortex_expect("mask is non-nullable"));
62                Ok(Datum::Scalar(result))
63            }
64            (Datum::Scalar(input), BoolDatum::Vector(mask)) => {
65                let mut result = input.repeat(args.row_count()).freeze();
66                result.mask_validity(&Mask::from(mask.into_bits()));
67                Ok(Datum::Vector(result))
68            }
69            (Datum::Vector(input_array), BoolDatum::Scalar(mask)) => {
70                let mut result = input_array;
71                result.mask_validity(&Mask::new(
72                    args.row_count(),
73                    mask.value().vortex_expect("mask is non-nullable"),
74                ));
75                Ok(Datum::Vector(result))
76            }
77            (Datum::Vector(input_array), BoolDatum::Vector(mask)) => {
78                let mut result = input_array;
79                result.mask_validity(&Mask::from(mask.into_bits()));
80                Ok(Datum::Vector(result))
81            }
82        }
83    }
84}