Skip to main content

vortex_array/expr/exprs/fill_null/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_error::vortex_err;
14use vortex_session::VortexSession;
15
16use crate::AnyColumnar;
17use crate::ArrayRef;
18use crate::CanonicalView;
19use crate::ColumnarView;
20use crate::ExecutionCtx;
21use crate::arrays::BoolVTable;
22use crate::arrays::DecimalVTable;
23use crate::arrays::PrimitiveVTable;
24use crate::builtins::ArrayBuiltins;
25use crate::expr::Arity;
26use crate::expr::ChildName;
27use crate::expr::EmptyOptions;
28use crate::expr::ExecutionArgs;
29use crate::expr::ExprId;
30use crate::expr::Expression;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::scalar::Scalar;
34
35/// An expression that replaces null values in the input with a fill value.
36pub struct FillNull;
37
38impl VTable for FillNull {
39    type Options = EmptyOptions;
40
41    fn id(&self) -> ExprId {
42        ExprId::from("vortex.fill_null")
43    }
44
45    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
46        Ok(Some(vec![]))
47    }
48
49    fn deserialize(
50        &self,
51        _metadata: &[u8],
52        _session: &VortexSession,
53    ) -> VortexResult<Self::Options> {
54        Ok(EmptyOptions)
55    }
56
57    fn arity(&self, _options: &Self::Options) -> Arity {
58        Arity::Exact(2)
59    }
60
61    fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
62        match child_idx {
63            0 => ChildName::from("input"),
64            1 => ChildName::from("fill_value"),
65            _ => unreachable!("Invalid child index {} for FillNull expression", child_idx),
66        }
67    }
68
69    fn fmt_sql(
70        &self,
71        _options: &Self::Options,
72        expr: &Expression,
73        f: &mut Formatter<'_>,
74    ) -> std::fmt::Result {
75        write!(f, "fill_null(")?;
76        expr.child(0).fmt_sql(f)?;
77        write!(f, ", ")?;
78        expr.child(1).fmt_sql(f)?;
79        write!(f, ")")
80    }
81
82    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
83        vortex_ensure!(
84            arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
85            "fill_null requires input and fill value to have the same base type, got {} and {}",
86            arg_dtypes[0],
87            arg_dtypes[1]
88        );
89        // The result dtype takes the nullability of the fill value.
90        Ok(arg_dtypes[0]
91            .clone()
92            .with_nullability(arg_dtypes[1].nullability()))
93    }
94
95    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
96        let [input, fill_value]: [ArrayRef; _] = args
97            .inputs
98            .try_into()
99            .map_err(|_| vortex_err!("Wrong arg count"))?;
100
101        let fill_scalar = fill_value
102            .as_constant()
103            .ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;
104
105        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
106            return input.execute::<ArrayRef>(args.ctx)?.fill_null(fill_scalar);
107        };
108
109        match columnar {
110            ColumnarView::Canonical(canonical) => {
111                fill_null_canonical(canonical, &fill_scalar, args.ctx)
112            }
113            ColumnarView::Constant(constant) => fill_null_constant(constant, &fill_scalar),
114        }
115    }
116
117    fn simplify(
118        &self,
119        _options: &Self::Options,
120        expr: &Expression,
121        ctx: &dyn crate::expr::SimplifyCtx,
122    ) -> VortexResult<Option<Expression>> {
123        let input_dtype = ctx.return_dtype(expr.child(0))?;
124
125        if !input_dtype.is_nullable() {
126            return Ok(Some(expr.child(0).clone()));
127        }
128
129        Ok(None)
130    }
131
132    fn validity(
133        &self,
134        _options: &Self::Options,
135        expression: &Expression,
136    ) -> VortexResult<Option<Expression>> {
137        // After fill_null, the result validity depends on the fill value's nullability.
138        // If fill_value is non-nullable, the result is always valid.
139        Ok(Some(expression.child(1).validity()?))
140    }
141
142    fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
143        true
144    }
145
146    fn is_fallible(&self, _options: &Self::Options) -> bool {
147        false
148    }
149}
150
151/// Fill nulls on a canonical array by directly dispatching to the appropriate kernel.
152///
153/// Returns the filled array, or bails if no kernel is registered for the canonical type.
154fn fill_null_canonical(
155    canonical: CanonicalView<'_>,
156    fill_value: &Scalar,
157    ctx: &mut ExecutionCtx,
158) -> VortexResult<ArrayRef> {
159    if let Some(result) = precondition(canonical.as_ref(), fill_value)? {
160        // The result of precondition may return another ScalarFn, in which case we should
161        // apply it immediately.
162        // TODO(aduffy): Remove this once we have better driver check. We're also implicitly
163        //  relying on the fact that Cast execution will do an optimize on its result.
164        return result.execute::<ArrayRef>(ctx);
165    }
166    match canonical {
167        CanonicalView::Bool(a) => <BoolVTable as FillNullKernel>::fill_null(a, fill_value, ctx)?
168            .ok_or_else(|| vortex_err!("FillNullKernel for BoolArray returned None")),
169        CanonicalView::Primitive(a) => {
170            <PrimitiveVTable as FillNullKernel>::fill_null(a, fill_value, ctx)?
171                .ok_or_else(|| vortex_err!("FillNullKernel for PrimitiveArray returned None"))
172        }
173        CanonicalView::Decimal(a) => {
174            <DecimalVTable as FillNullKernel>::fill_null(a, fill_value, ctx)?
175                .ok_or_else(|| vortex_err!("FillNullKernel for DecimalArray returned None"))
176        }
177        other => vortex_bail!(
178            "No FillNullKernel for canonical array {}",
179            other.as_ref().encoding_id()
180        ),
181    }
182}
183
184/// Creates an expression that replaces null values with a fill value.
185///
186/// ```rust
187/// # use vortex_array::expr::{fill_null, root, lit};
188/// let expr = fill_null(root(), lit(0i32));
189/// ```
190pub fn fill_null(child: Expression, fill_value: Expression) -> Expression {
191    FillNull.new_expr(EmptyOptions, [child, fill_value])
192}
193
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use super::fill_null;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;

    /// The dtype produced when nulls are filled with a non-nullable `i32` literal.
    fn non_nullable_i32() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// A non-nullable fill value strips nullability from a nullable input dtype.
    #[test]
    fn dtype() {
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let expr = fill_null(root(), lit(0i32));
        let returned = expr.return_dtype(&nullable_i32).unwrap();
        assert_eq!(returned, non_nullable_i32());
    }

    /// Rebuilding the expression with a fresh pair of children succeeds.
    #[test]
    fn replace_children() {
        let children = vec![root(), lit(0i32)];
        fill_null(root(), lit(0i32))
            .with_children(children)
            .vortex_expect("operation should succeed in test");
    }

    /// Nulls in a top-level primitive array are replaced by the literal.
    #[test]
    fn evaluate() {
        let input = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let expr = fill_null(root(), lit(42i32));
        let filled = input.apply(&expr).unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 42, 3, 42, 5]));
    }

    /// `fill_null` composes with `get_item` to fill a struct field.
    #[test]
    fn evaluate_struct_field() {
        let field_a = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        let input = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();
        let expr = fill_null(get_item("a", root()), lit(0i32));
        let filled = input.apply(&expr).unwrap();

        assert_eq!(filled.dtype(), &non_nullable_i32());
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 0, 3]));
    }

    /// A non-nullable input comes back with the same values.
    #[test]
    fn evaluate_non_nullable_input() {
        let input = buffer![1i32, 2, 3].into_array();
        let expr = fill_null(root(), lit(0i32));
        let filled = input.apply(&expr).unwrap();
        assert_arrays_eq!(filled, PrimitiveArray::from_iter([1i32, 2, 3]));
    }

    /// `Display` renders the expression in `fill_null(input, fill_value)` form.
    #[test]
    fn test_display() {
        let rendered = fill_null(get_item("value", root()), lit(0i32)).to_string();
        assert_eq!(rendered, "fill_null($.value, 0i32)");
    }
}