Skip to main content

vortex_array/scalar_fn/fns/fill_null/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6pub use kernel::*;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_session::VortexSession;
12
13use crate::AnyColumnar;
14use crate::ArrayRef;
15use crate::CanonicalView;
16use crate::ColumnarView;
17use crate::ExecutionCtx;
18use crate::arrays::Bool;
19use crate::arrays::Decimal;
20use crate::arrays::Primitive;
21use crate::builtins::ArrayBuiltins;
22use crate::dtype::DType;
23use crate::expr::Expression;
24use crate::scalar::Scalar;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::EmptyOptions;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ScalarFnId;
30use crate::scalar_fn::ScalarFnVTable;
31
32/// An expression that replaces null values in the input with a fill value.
33#[derive(Clone)]
34pub struct FillNull;
35
36impl ScalarFnVTable for FillNull {
37    type Options = EmptyOptions;
38
39    fn id(&self) -> ScalarFnId {
40        ScalarFnId::new("vortex.fill_null")
41    }
42
43    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
44        Ok(Some(vec![]))
45    }
46
47    fn deserialize(
48        &self,
49        _metadata: &[u8],
50        _session: &VortexSession,
51    ) -> VortexResult<Self::Options> {
52        Ok(EmptyOptions)
53    }
54
55    fn arity(&self, _options: &Self::Options) -> Arity {
56        Arity::Exact(2)
57    }
58
59    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
60        match child_idx {
61            0 => ChildName::from("input"),
62            1 => ChildName::from("fill_value"),
63            _ => unreachable!("Invalid child index {} for FillNull expression", child_idx),
64        }
65    }
66
67    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
68        vortex_ensure!(
69            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
70            "fill_null requires input and fill value to have the same base type, got {} and {}",
71            arg_dtypes[0],
72            arg_dtypes[1]
73        );
74        // The result dtype takes the nullability of the fill value.
75        Ok(arg_dtypes[0]
76            .clone()
77            .with_nullability(arg_dtypes[1].nullability()))
78    }
79
80    fn execute(
81        &self,
82        _options: &Self::Options,
83        args: &dyn ExecutionArgs,
84        ctx: &mut ExecutionCtx,
85    ) -> VortexResult<ArrayRef> {
86        let input = args.get(0)?;
87        let fill_value = args.get(1)?;
88
89        let fill_scalar = fill_value
90            .as_constant()
91            .ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;
92
93        vortex_ensure!(
94            !fill_scalar.is_null(),
95            "fill_null requires a non-null fill value"
96        );
97
98        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
99            return input.execute::<ArrayRef>(ctx)?.fill_null(fill_scalar);
100        };
101
102        match columnar {
103            ColumnarView::Canonical(canonical) => fill_null_canonical(canonical, &fill_scalar, ctx),
104            ColumnarView::Constant(constant) => fill_null_constant(constant, &fill_scalar),
105        }
106    }
107
108    fn simplify(
109        &self,
110        _options: &Self::Options,
111        expr: &Expression,
112        ctx: &dyn crate::scalar_fn::SimplifyCtx,
113    ) -> VortexResult<Option<Expression>> {
114        let input_dtype = ctx.return_dtype(expr.child(0))?;
115
116        if !input_dtype.is_nullable() {
117            return Ok(Some(expr.child(0).clone()));
118        }
119
120        Ok(None)
121    }
122
123    fn validity(
124        &self,
125        _options: &Self::Options,
126        expression: &Expression,
127    ) -> VortexResult<Option<Expression>> {
128        // After fill_null, the result validity depends on the fill value's nullability.
129        // If fill_value is non-nullable, the result is always valid.
130        Ok(Some(expression.child(1).validity()?))
131    }
132
133    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
134        true
135    }
136
137    fn is_fallible(&self, _options: &Self::Options) -> bool {
138        false
139    }
140}
141
142/// Fill nulls on a canonical array by directly dispatching to the appropriate kernel.
143///
144/// Returns the filled array, or bails if no kernel is registered for the canonical type.
145fn fill_null_canonical(
146    canonical: CanonicalView<'_>,
147    fill_value: &Scalar,
148    ctx: &mut ExecutionCtx,
149) -> VortexResult<ArrayRef> {
150    let arr = canonical.to_array_ref();
151    if let Some(result) = precondition(&arr, fill_value)? {
152        // The result of precondition may return another ScalarFn, in which case we should
153        // apply it immediately.
154        // TODO(aduffy): Remove this once we have better driver check. We're also implicitly
155        //  relying on the fact that Cast execution will do an optimize on its result.
156        return result.execute::<ArrayRef>(ctx);
157    }
158    match canonical {
159        CanonicalView::Bool(a) => <Bool as FillNullKernel>::fill_null(a, fill_value, ctx)?
160            .ok_or_else(|| vortex_err!("FillNullKernel for BoolArray returned None")),
161        CanonicalView::Primitive(a) => {
162            <Primitive as FillNullKernel>::fill_null(a, fill_value, ctx)?
163                .ok_or_else(|| vortex_err!("FillNullKernel for PrimitiveArray returned None"))
164        }
165        CanonicalView::Decimal(a) => <Decimal as FillNullKernel>::fill_null(a, fill_value, ctx)?
166            .ok_or_else(|| vortex_err!("FillNullKernel for DecimalArray returned None")),
167        other => vortex_bail!(
168            "No FillNullKernel for canonical array {}",
169            other.to_array_ref().encoding_id()
170        ),
171    }
172}
173
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::fill_null;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;

    /// The dtype produced by filling an `i32` column with a non-nullable `i32` literal.
    fn non_nullable_i32() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    #[test]
    fn dtype() {
        let input_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let expression = fill_null(root(), lit(0i32));
        let resolved = expression.return_dtype(&input_dtype).unwrap();
        assert_eq!(resolved, non_nullable_i32());
    }

    #[test]
    fn replace_children() {
        let original = fill_null(root(), lit(0i32));
        let replacement = vec![root(), lit(0i32)];
        original
            .with_children(replacement)
            .vortex_expect("operation should succeed in test");
    }

    #[test]
    fn evaluate() {
        let values = [Some(1i32), None, Some(3), None, Some(5)];
        let input = PrimitiveArray::from_option_iter(values).into_array();

        let filled = input.apply(&fill_null(root(), lit(42i32))).unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 42, 3, 42, 5]));
    }

    #[test]
    fn evaluate_struct_field() {
        let column = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        let input = StructArray::from_fields(&[("a", column)])
            .unwrap()
            .into_array();

        let filled = input
            .apply(&fill_null(get_item("a", root()), lit(0i32)))
            .unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 0, 3]));
    }

    #[test]
    fn evaluate_non_nullable_input() {
        let input = buffer![1i32, 2, 3].into_array();
        let filled = input.apply(&fill_null(root(), lit(0i32))).unwrap();
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 2, 3]));
    }

    #[test]
    fn test_display() {
        let rendered = fill_null(get_item("value", root()), lit(0i32)).to_string();
        assert_eq!(rendered, "vortex.fill_null($.value, 0i32)");
    }
}