// vortex_array/compute/fill_null.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
13use crate::vtable::VTable;
14use crate::{Array, ArrayRef, IntoArray};
15
16static FILL_NULL_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17    let compute = ComputeFn::new("fill_null".into(), ArcRef::new_ref(&FillNull));
18    for kernel in inventory::iter::<FillNullKernelRef> {
19        compute.register_kernel(kernel.0.clone());
20    }
21    compute
22});
23
24pub(crate) fn warm_up_vtable() -> usize {
25    FILL_NULL_FN.kernels().len()
26}
27
28/// Replace nulls in the array with another value.
29///
30/// # Examples
31///
32/// ```
33/// use vortex_array::arrays::{PrimitiveArray};
34/// use vortex_array::compute::{fill_null};
35/// use vortex_scalar::Scalar;
36///
37/// let array =
38///     PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]);
39/// let array = fill_null(array.as_ref(), &Scalar::from(42i32)).unwrap();
40/// assert_eq!(array.display_values().to_string(), "[0i32, 42i32, 1i32, 42i32, 2i32]");
41/// ```
42pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult<ArrayRef> {
43    FILL_NULL_FN
44        .invoke(&InvocationArgs {
45            inputs: &[array.into(), fill_value.into()],
46            options: &(),
47        })?
48        .unwrap_array()
49}
50
51pub trait FillNullKernel: VTable {
52    fn fill_null(&self, array: &Self::Array, fill_value: &Scalar) -> VortexResult<ArrayRef>;
53}
54
55pub struct FillNullKernelRef(ArcRef<dyn Kernel>);
56inventory::collect!(FillNullKernelRef);
57
58#[derive(Debug)]
59pub struct FillNullKernelAdapter<V: VTable>(pub V);
60
61impl<V: VTable + FillNullKernel> FillNullKernelAdapter<V> {
62    pub const fn lift(&'static self) -> FillNullKernelRef {
63        FillNullKernelRef(ArcRef::new_ref(self))
64    }
65}
66
67impl<V: VTable + FillNullKernel> Kernel for FillNullKernelAdapter<V> {
68    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
69        let inputs = FillNullArgs::try_from(args)?;
70        let Some(array) = inputs.array.as_opt::<V>() else {
71            return Ok(None);
72        };
73        Ok(Some(
74            V::fill_null(&self.0, array, inputs.fill_value)?.into(),
75        ))
76    }
77}
78
/// Vtable behind the `fill_null` compute function.
///
/// It handles argument decoding, the fast paths, kernel dispatch and the canonical fallback.
struct FillNull;
80
81impl ComputeFnVTable for FillNull {
82    fn invoke(
83        &self,
84        args: &InvocationArgs,
85        kernels: &[ArcRef<dyn Kernel>],
86    ) -> VortexResult<Output> {
87        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
88
89        if !array.dtype().is_nullable() || array.all_valid() {
90            return Ok(cast(array, fill_value.dtype())?.into());
91        }
92
93        if array.all_invalid() {
94            return Ok(ConstantArray::new(fill_value.clone(), array.len())
95                .into_array()
96                .into());
97        }
98
99        if fill_value.is_null() {
100            vortex_bail!("Cannot fill_null with a null value")
101        }
102
103        for kernel in kernels {
104            if let Some(output) = kernel.invoke(args)? {
105                return Ok(output);
106            }
107        }
108        if let Some(output) = array.invoke(&FILL_NULL_FN, args)? {
109            return Ok(output);
110        }
111
112        log::debug!("FillNullFn not implemented for {}", array.encoding_id());
113        if !array.is_canonical() {
114            let canonical_arr = array.to_canonical().into_array();
115            return Ok(fill_null(canonical_arr.as_ref(), fill_value)?.into());
116        }
117
118        // TODO(joe): update fuzzer when fixed
119        vortex_bail!("fill null not implemented for DType {}", array.dtype())
120    }
121
122    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
123        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
124        if !array.dtype().eq_ignore_nullability(fill_value.dtype()) {
125            vortex_bail!("FillNull value must match array type (ignoring nullability)");
126        }
127        Ok(fill_value.dtype().clone())
128    }
129
130    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
131        let FillNullArgs { array, .. } = FillNullArgs::try_from(args)?;
132        Ok(array.len())
133    }
134
135    fn is_elementwise(&self) -> bool {
136        true
137    }
138}
139
140struct FillNullArgs<'a> {
141    array: &'a dyn Array,
142    fill_value: &'a Scalar,
143}
144
145impl<'a> TryFrom<&InvocationArgs<'a>> for FillNullArgs<'a> {
146    type Error = VortexError;
147
148    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
149        if value.inputs.len() != 2 {
150            vortex_bail!("FillNull requires 2 arguments");
151        }
152
153        let array = value.inputs[0]
154            .array()
155            .ok_or_else(|| vortex_err!("FillNull requires an array"))?;
156        let fill_value = value.inputs[1]
157            .scalar()
158            .ok_or_else(|| vortex_err!("FillNull requires a scalar"))?;
159
160        Ok(FillNullArgs { array, fill_value })
161    }
162}