vortex_array/compute/
fill_null.rs

1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::Scalar;
7
8use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
9use crate::vtable::VTable;
10use crate::{Array, ArrayRef, IntoArray};
11
12pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult<ArrayRef> {
13    FILL_NULL_FN
14        .invoke(&InvocationArgs {
15            inputs: &[array.into(), fill_value.into()],
16            options: &(),
17        })?
18        .unwrap_array()
19}
20
21pub trait FillNullKernel: VTable {
22    fn fill_null(&self, array: &Self::Array, fill_value: &Scalar) -> VortexResult<ArrayRef>;
23}
24
25pub struct FillNullKernelRef(ArcRef<dyn Kernel>);
26inventory::collect!(FillNullKernelRef);
27
28#[derive(Debug)]
29pub struct FillNullKernelAdapter<V: VTable>(pub V);
30
31impl<V: VTable + FillNullKernel> FillNullKernelAdapter<V> {
32    pub const fn lift(&'static self) -> FillNullKernelRef {
33        FillNullKernelRef(ArcRef::new_ref(self))
34    }
35}
36
37impl<V: VTable + FillNullKernel> Kernel for FillNullKernelAdapter<V> {
38    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
39        let inputs = FillNullArgs::try_from(args)?;
40        let Some(array) = inputs.array.as_opt::<V>() else {
41            return Ok(None);
42        };
43        Ok(Some(
44            V::fill_null(&self.0, array, inputs.fill_value)?.into(),
45        ))
46    }
47}
48
49pub static FILL_NULL_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
50    let compute = ComputeFn::new("fill_null".into(), ArcRef::new_ref(&FillNull));
51    for kernel in inventory::iter::<FillNullKernelRef> {
52        compute.register_kernel(kernel.0.clone());
53    }
54    compute
55});
56
57struct FillNull;
58
59impl ComputeFnVTable for FillNull {
60    fn invoke(
61        &self,
62        args: &InvocationArgs,
63        kernels: &[ArcRef<dyn Kernel>],
64    ) -> VortexResult<Output> {
65        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
66
67        if !array.dtype().is_nullable() {
68            return Ok(array.to_array().into());
69        }
70
71        if array.invalid_count()? == 0 {
72            return Ok(cast(array, fill_value.dtype())?.into());
73        }
74
75        if fill_value.is_null() {
76            vortex_bail!("Cannot fill_null with a null value")
77        }
78
79        for kernel in kernels {
80            if let Some(output) = kernel.invoke(args)? {
81                return Ok(output);
82            }
83        }
84        if let Some(output) = array.invoke(&FILL_NULL_FN, args)? {
85            return Ok(output);
86        }
87
88        log::debug!("FillNullFn not implemented for {}", array.encoding_id());
89        if !array.is_canonical() {
90            let canonical_arr = array.to_canonical()?.into_array();
91            return Ok(fill_null(canonical_arr.as_ref(), fill_value)?.into());
92        }
93
94        vortex_bail!("fill null not implemented for DType {}", array.dtype())
95    }
96
97    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
98        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
99        if !array.dtype().eq_ignore_nullability(fill_value.dtype()) {
100            vortex_bail!("FillNull value must match array type (ignoring nullability)");
101        }
102        Ok(fill_value.dtype().clone())
103    }
104
105    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
106        let FillNullArgs { array, .. } = FillNullArgs::try_from(args)?;
107        Ok(array.len())
108    }
109
110    fn is_elementwise(&self) -> bool {
111        true
112    }
113}
114
115struct FillNullArgs<'a> {
116    array: &'a dyn Array,
117    fill_value: &'a Scalar,
118}
119
120impl<'a> TryFrom<&InvocationArgs<'a>> for FillNullArgs<'a> {
121    type Error = VortexError;
122
123    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
124        if value.inputs.len() != 2 {
125            vortex_bail!("FillNull requires 2 arguments");
126        }
127
128        let array = value.inputs[0]
129            .array()
130            .ok_or_else(|| vortex_err!("FillNull requires an array"))?;
131        let fill_value = value.inputs[1]
132            .scalar()
133            .ok_or_else(|| vortex_err!("FillNull requires a scalar"))?;
134
135        Ok(FillNullArgs { array, fill_value })
136    }
137}