Skip to main content

vortex_array/scalar_fn/fns/fill_null/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_session::VortexSession;
14
15use crate::AnyColumnar;
16use crate::ArrayRef;
17use crate::CanonicalView;
18use crate::ColumnarView;
19use crate::ExecutionCtx;
20use crate::arrays::Bool;
21use crate::arrays::Decimal;
22use crate::arrays::Primitive;
23use crate::builtins::ArrayBuiltins;
24use crate::dtype::DType;
25use crate::expr::Expression;
26use crate::scalar::Scalar;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33
/// An expression that replaces null values in the input with a fill value.
///
/// The function takes exactly two children, named `input` and `fill_value`, carries no options,
/// and is identified by the id `vortex.fill_null`.
#[derive(Clone)]
pub struct FillNull;
37
38impl ScalarFnVTable for FillNull {
39    type Options = EmptyOptions;
40
41    fn id(&self) -> ScalarFnId {
42        ScalarFnId::from("vortex.fill_null")
43    }
44
45    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
46        Ok(Some(vec![]))
47    }
48
49    fn deserialize(
50        &self,
51        _metadata: &[u8],
52        _session: &VortexSession,
53    ) -> VortexResult<Self::Options> {
54        Ok(EmptyOptions)
55    }
56
57    fn arity(&self, _options: &Self::Options) -> Arity {
58        Arity::Exact(2)
59    }
60
61    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
62        match child_idx {
63            0 => ChildName::from("input"),
64            1 => ChildName::from("fill_value"),
65            _ => unreachable!("Invalid child index {} for FillNull expression", child_idx),
66        }
67    }
68
69    fn fmt_sql(
70        &self,
71        _options: &Self::Options,
72        expr: &Expression,
73        f: &mut Formatter<'_>,
74    ) -> std::fmt::Result {
75        write!(f, "fill_null(")?;
76        expr.child(0).fmt_sql(f)?;
77        write!(f, ", ")?;
78        expr.child(1).fmt_sql(f)?;
79        write!(f, ")")
80    }
81
82    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
83        vortex_ensure!(
84            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
85            "fill_null requires input and fill value to have the same base type, got {} and {}",
86            arg_dtypes[0],
87            arg_dtypes[1]
88        );
89        // The result dtype takes the nullability of the fill value.
90        Ok(arg_dtypes[0]
91            .clone()
92            .with_nullability(arg_dtypes[1].nullability()))
93    }
94
95    fn execute(
96        &self,
97        _options: &Self::Options,
98        args: &dyn ExecutionArgs,
99        ctx: &mut ExecutionCtx,
100    ) -> VortexResult<ArrayRef> {
101        let input = args.get(0)?;
102        let fill_value = args.get(1)?;
103
104        let fill_scalar = fill_value
105            .as_constant()
106            .ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;
107
108        vortex_ensure!(
109            !fill_scalar.is_null(),
110            "fill_null requires a non-null fill value"
111        );
112
113        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
114            return input.execute::<ArrayRef>(ctx)?.fill_null(fill_scalar);
115        };
116
117        match columnar {
118            ColumnarView::Canonical(canonical) => fill_null_canonical(canonical, &fill_scalar, ctx),
119            ColumnarView::Constant(constant) => fill_null_constant(constant, &fill_scalar),
120        }
121    }
122
123    fn simplify(
124        &self,
125        _options: &Self::Options,
126        expr: &Expression,
127        ctx: &dyn crate::scalar_fn::SimplifyCtx,
128    ) -> VortexResult<Option<Expression>> {
129        let input_dtype = ctx.return_dtype(expr.child(0))?;
130
131        if !input_dtype.is_nullable() {
132            return Ok(Some(expr.child(0).clone()));
133        }
134
135        Ok(None)
136    }
137
138    fn validity(
139        &self,
140        _options: &Self::Options,
141        expression: &Expression,
142    ) -> VortexResult<Option<Expression>> {
143        // After fill_null, the result validity depends on the fill value's nullability.
144        // If fill_value is non-nullable, the result is always valid.
145        Ok(Some(expression.child(1).validity()?))
146    }
147
148    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
149        true
150    }
151
152    fn is_fallible(&self, _options: &Self::Options) -> bool {
153        false
154    }
155}
156
157/// Fill nulls on a canonical array by directly dispatching to the appropriate kernel.
158///
159/// Returns the filled array, or bails if no kernel is registered for the canonical type.
160fn fill_null_canonical(
161    canonical: CanonicalView<'_>,
162    fill_value: &Scalar,
163    ctx: &mut ExecutionCtx,
164) -> VortexResult<ArrayRef> {
165    let arr = canonical.as_ref().to_array();
166    if let Some(result) = precondition(&arr, fill_value)? {
167        // The result of precondition may return another ScalarFn, in which case we should
168        // apply it immediately.
169        // TODO(aduffy): Remove this once we have better driver check. We're also implicitly
170        //  relying on the fact that Cast execution will do an optimize on its result.
171        return result.execute::<ArrayRef>(ctx);
172    }
173    match canonical {
174        CanonicalView::Bool(a) => <Bool as FillNullKernel>::fill_null(a, fill_value, ctx)?
175            .ok_or_else(|| vortex_err!("FillNullKernel for BoolArray returned None")),
176        CanonicalView::Primitive(a) => {
177            <Primitive as FillNullKernel>::fill_null(a, fill_value, ctx)?
178                .ok_or_else(|| vortex_err!("FillNullKernel for PrimitiveArray returned None"))
179        }
180        CanonicalView::Decimal(a) => <Decimal as FillNullKernel>::fill_null(a, fill_value, ctx)?
181            .ok_or_else(|| vortex_err!("FillNullKernel for DecimalArray returned None")),
182        other => vortex_bail!(
183            "No FillNullKernel for canonical array {}",
184            other.as_ref().encoding_id()
185        ),
186    }
187}
188
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::fill_null;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;

    /// The dtype these tests expect after filling nullable `i32` data with a non-nullable literal.
    fn non_nullable_i32() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// A nullable input combined with a non-nullable fill literal resolves to a non-nullable
    /// dtype.
    #[test]
    fn dtype() {
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let resolved = fill_null(root(), lit(0i32))
            .return_dtype(&nullable_i32)
            .unwrap();
        assert_eq!(resolved, non_nullable_i32());
    }

    /// Swapping in a fresh pair of children succeeds.
    #[test]
    fn replace_children() {
        let replacement = vec![root(), lit(0i32)];
        let original = fill_null(root(), lit(0i32));
        original
            .with_children(replacement)
            .vortex_expect("operation should succeed in test");
    }

    /// Nulls in a primitive array are replaced by the literal and the result is non-nullable.
    #[test]
    fn evaluate() {
        let values = [Some(1i32), None, Some(3), None, Some(5)];
        let input = PrimitiveArray::from_option_iter(values).into_array();

        let filled = input.apply(&fill_null(root(), lit(42i32))).unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 42, 3, 42, 5]));
    }

    /// The input may be a struct field selected with `get_item`.
    #[test]
    fn evaluate_struct_field() {
        let field = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        let input = StructArray::from_fields(&[("a", field)])
            .unwrap()
            .into_array();

        let filled = input
            .apply(&fill_null(get_item("a", root()), lit(0i32)))
            .unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 0, 3]));
    }

    /// A non-nullable input comes back with its values unchanged.
    #[test]
    fn evaluate_non_nullable_input() {
        let input = buffer![1i32, 2, 3].into_array();
        let filled = input.apply(&fill_null(root(), lit(0i32))).unwrap();
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 2, 3]));
    }

    /// The display form spells out both children.
    #[test]
    fn test_display() {
        let rendered = fill_null(get_item("value", root()), lit(0i32)).to_string();
        assert_eq!(rendered, "fill_null($.value, 0i32)");
    }
}