Skip to main content

vortex_array/scalar_fn/fns/fill_null/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6pub use kernel::*;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_session::VortexSession;
12use vortex_session::registry::CachedId;
13
14use crate::AnyColumnar;
15use crate::ArrayRef;
16use crate::CanonicalView;
17use crate::ColumnarView;
18use crate::ExecutionCtx;
19use crate::arrays::Bool;
20use crate::arrays::Decimal;
21use crate::arrays::Primitive;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::DType;
24use crate::expr::Expression;
25use crate::scalar::Scalar;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::EmptyOptions;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32
33/// An expression that replaces null values in the input with a fill value.
34#[derive(Clone)]
35pub struct FillNull;
36
37impl ScalarFnVTable for FillNull {
38    type Options = EmptyOptions;
39
40    fn id(&self) -> ScalarFnId {
41        static ID: CachedId = CachedId::new("vortex.fill_null");
42        *ID
43    }
44
45    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
46        Ok(Some(vec![]))
47    }
48
49    fn deserialize(
50        &self,
51        _metadata: &[u8],
52        _session: &VortexSession,
53    ) -> VortexResult<Self::Options> {
54        Ok(EmptyOptions)
55    }
56
57    fn arity(&self, _options: &Self::Options) -> Arity {
58        Arity::Exact(2)
59    }
60
61    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
62        match child_idx {
63            0 => ChildName::from("input"),
64            1 => ChildName::from("fill_value"),
65            _ => unreachable!("Invalid child index {} for FillNull expression", child_idx),
66        }
67    }
68
69    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
70        vortex_ensure!(
71            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
72            "fill_null requires input and fill value to have the same base type, got {} and {}",
73            arg_dtypes[0],
74            arg_dtypes[1]
75        );
76        // The result dtype takes the nullability of the fill value.
77        Ok(arg_dtypes[0]
78            .clone()
79            .with_nullability(arg_dtypes[1].nullability()))
80    }
81
82    fn execute(
83        &self,
84        _options: &Self::Options,
85        args: &dyn ExecutionArgs,
86        ctx: &mut ExecutionCtx,
87    ) -> VortexResult<ArrayRef> {
88        let input = args.get(0)?;
89        let fill_value = args.get(1)?;
90
91        let fill_scalar = fill_value
92            .as_constant()
93            .ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;
94
95        vortex_ensure!(
96            !fill_scalar.is_null(),
97            "fill_null requires a non-null fill value"
98        );
99
100        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
101            return input.execute::<ArrayRef>(ctx)?.fill_null(fill_scalar);
102        };
103
104        match columnar {
105            ColumnarView::Canonical(canonical) => fill_null_canonical(canonical, &fill_scalar, ctx),
106            ColumnarView::Constant(constant) => fill_null_constant(constant, &fill_scalar),
107        }
108    }
109
110    fn simplify(
111        &self,
112        _options: &Self::Options,
113        expr: &Expression,
114        ctx: &dyn crate::scalar_fn::SimplifyCtx,
115    ) -> VortexResult<Option<Expression>> {
116        let input_dtype = ctx.return_dtype(expr.child(0))?;
117
118        if !input_dtype.is_nullable() {
119            return Ok(Some(expr.child(0).clone()));
120        }
121
122        Ok(None)
123    }
124
125    fn validity(
126        &self,
127        _options: &Self::Options,
128        expression: &Expression,
129    ) -> VortexResult<Option<Expression>> {
130        // After fill_null, the result validity depends on the fill value's nullability.
131        // If fill_value is non-nullable, the result is always valid.
132        Ok(Some(expression.child(1).validity()?))
133    }
134
135    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
136        true
137    }
138
139    fn is_fallible(&self, _options: &Self::Options) -> bool {
140        false
141    }
142}
143
144/// Fill nulls on a canonical array by directly dispatching to the appropriate kernel.
145///
146/// Returns the filled array, or bails if no kernel is registered for the canonical type.
147fn fill_null_canonical(
148    canonical: CanonicalView<'_>,
149    fill_value: &Scalar,
150    ctx: &mut ExecutionCtx,
151) -> VortexResult<ArrayRef> {
152    let arr = canonical.to_array_ref();
153    if let Some(result) = precondition(&arr, fill_value)? {
154        // The result of precondition may return another ScalarFn, in which case we should
155        // apply it immediately.
156        // TODO(aduffy): Remove this once we have better driver check. We're also implicitly
157        //  relying on the fact that Cast execution will do an optimize on its result.
158        return result.execute::<ArrayRef>(ctx);
159    }
160    match canonical {
161        CanonicalView::Bool(a) => <Bool as FillNullKernel>::fill_null(a, fill_value, ctx)?
162            .ok_or_else(|| vortex_err!("FillNullKernel for BoolArray returned None")),
163        CanonicalView::Primitive(a) => {
164            <Primitive as FillNullKernel>::fill_null(a, fill_value, ctx)?
165                .ok_or_else(|| vortex_err!("FillNullKernel for PrimitiveArray returned None"))
166        }
167        CanonicalView::Decimal(a) => <Decimal as FillNullKernel>::fill_null(a, fill_value, ctx)?
168            .ok_or_else(|| vortex_err!("FillNullKernel for DecimalArray returned None")),
169        other => vortex_bail!(
170            "No FillNullKernel for canonical array {}",
171            other.to_array_ref().encoding_id()
172        ),
173    }
174}
175
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::fill_null;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::root;

    /// The dtype every non-nullable `i32` fill in these tests is expected to produce.
    fn non_nullable_i32() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    #[test]
    fn dtype() {
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let resolved = fill_null(root(), lit(0i32))
            .return_dtype(&nullable_i32)
            .unwrap();
        assert_eq!(resolved, non_nullable_i32());
    }

    #[test]
    fn replace_children() {
        let original = fill_null(root(), lit(0i32));
        let _replaced = original
            .with_children(vec![root(), lit(0i32)])
            .vortex_expect("operation should succeed in test");
    }

    #[test]
    fn evaluate() {
        let input = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();

        let filled = input.apply(&fill_null(root(), lit(42i32))).unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 42, 3, 42, 5]));
    }

    #[test]
    fn evaluate_struct_field() {
        let column_a = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        let input = StructArray::from_fields(&[("a", column_a)])
            .unwrap()
            .into_array();

        let filled = input
            .apply(&fill_null(get_item("a", root()), lit(0i32)))
            .unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 0, 3]));
    }

    #[test]
    fn evaluate_non_nullable_input() {
        let input = buffer![1i32, 2, 3].into_array();
        let filled = input.apply(&fill_null(root(), lit(0i32))).unwrap();
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 2, 3]));
    }

    #[test]
    fn test_display() {
        let rendered = fill_null(get_item("value", root()), lit(0i32)).to_string();
        assert_eq!(rendered, "vortex.fill_null($.value, 0i32)");
    }
}