vortex_array/compute/
fill_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::Scalar;
10
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
12use crate::vtable::VTable;
13use crate::{Array, ArrayRef, IntoArray};
14
15pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult<ArrayRef> {
16    FILL_NULL_FN
17        .invoke(&InvocationArgs {
18            inputs: &[array.into(), fill_value.into()],
19            options: &(),
20        })?
21        .unwrap_array()
22}
23
24pub trait FillNullKernel: VTable {
25    fn fill_null(&self, array: &Self::Array, fill_value: &Scalar) -> VortexResult<ArrayRef>;
26}
27
28pub struct FillNullKernelRef(ArcRef<dyn Kernel>);
29inventory::collect!(FillNullKernelRef);
30
31#[derive(Debug)]
32pub struct FillNullKernelAdapter<V: VTable>(pub V);
33
34impl<V: VTable + FillNullKernel> FillNullKernelAdapter<V> {
35    pub const fn lift(&'static self) -> FillNullKernelRef {
36        FillNullKernelRef(ArcRef::new_ref(self))
37    }
38}
39
40impl<V: VTable + FillNullKernel> Kernel for FillNullKernelAdapter<V> {
41    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
42        let inputs = FillNullArgs::try_from(args)?;
43        let Some(array) = inputs.array.as_opt::<V>() else {
44            return Ok(None);
45        };
46        Ok(Some(
47            V::fill_null(&self.0, array, inputs.fill_value)?.into(),
48        ))
49    }
50}
51
52pub static FILL_NULL_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
53    let compute = ComputeFn::new("fill_null".into(), ArcRef::new_ref(&FillNull));
54    for kernel in inventory::iter::<FillNullKernelRef> {
55        compute.register_kernel(kernel.0.clone());
56    }
57    compute
58});
59
60struct FillNull;
61
62impl ComputeFnVTable for FillNull {
63    fn invoke(
64        &self,
65        args: &InvocationArgs,
66        kernels: &[ArcRef<dyn Kernel>],
67    ) -> VortexResult<Output> {
68        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
69
70        if !array.dtype().is_nullable() {
71            return Ok(array.to_array().into());
72        }
73
74        if array.invalid_count()? == 0 {
75            return Ok(cast(array, fill_value.dtype())?.into());
76        }
77
78        if fill_value.is_null() {
79            vortex_bail!("Cannot fill_null with a null value")
80        }
81
82        for kernel in kernels {
83            if let Some(output) = kernel.invoke(args)? {
84                return Ok(output);
85            }
86        }
87        if let Some(output) = array.invoke(&FILL_NULL_FN, args)? {
88            return Ok(output);
89        }
90
91        log::debug!("FillNullFn not implemented for {}", array.encoding_id());
92        if !array.is_canonical() {
93            let canonical_arr = array.to_canonical()?.into_array();
94            return Ok(fill_null(canonical_arr.as_ref(), fill_value)?.into());
95        }
96
97        vortex_bail!("fill null not implemented for DType {}", array.dtype())
98    }
99
100    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
101        let FillNullArgs { array, fill_value } = FillNullArgs::try_from(args)?;
102        if !array.dtype().eq_ignore_nullability(fill_value.dtype()) {
103            vortex_bail!("FillNull value must match array type (ignoring nullability)");
104        }
105        Ok(fill_value.dtype().clone())
106    }
107
108    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
109        let FillNullArgs { array, .. } = FillNullArgs::try_from(args)?;
110        Ok(array.len())
111    }
112
113    fn is_elementwise(&self) -> bool {
114        true
115    }
116}
117
118struct FillNullArgs<'a> {
119    array: &'a dyn Array,
120    fill_value: &'a Scalar,
121}
122
123impl<'a> TryFrom<&InvocationArgs<'a>> for FillNullArgs<'a> {
124    type Error = VortexError;
125
126    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
127        if value.inputs.len() != 2 {
128            vortex_bail!("FillNull requires 2 arguments");
129        }
130
131        let array = value.inputs[0]
132            .array()
133            .ok_or_else(|| vortex_err!("FillNull requires an array"))?;
134        let fill_value = value.inputs[1]
135            .scalar()
136            .ok_or_else(|| vortex_err!("FillNull requires a scalar"))?;
137
138        Ok(FillNullArgs { array, fill_value })
139    }
140}