Skip to main content

vortex_array/compute/
nan_count.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray as _;
15use crate::compute::ComputeFn;
16use crate::compute::ComputeFnVTable;
17use crate::compute::InvocationArgs;
18use crate::compute::Kernel;
19use crate::compute::Output;
20use crate::compute::UnaryArgs;
21use crate::dtype::DType;
22use crate::expr::stats::Precision;
23use crate::expr::stats::Stat;
24use crate::expr::stats::StatsProviderExt;
25use crate::scalar::Scalar;
26use crate::scalar::ScalarValue;
27use crate::vtable::VTable;
28
29static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
30    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
31    for kernel in inventory::iter::<NaNCountKernelRef> {
32        compute.register_kernel(kernel.0.clone());
33    }
34    compute
35});
36
37pub(crate) fn warm_up_vtable() -> usize {
38    NAN_COUNT_FN.kernels().len()
39}
40
41/// Computes the number of NaN values in the array.
42pub fn nan_count(array: &ArrayRef) -> VortexResult<usize> {
43    Ok(NAN_COUNT_FN
44        .invoke(&InvocationArgs {
45            inputs: &[array.into()],
46            options: &(),
47        })?
48        .unwrap_scalar()?
49        .as_primitive()
50        .as_::<usize>()
51        .vortex_expect("NaN count should not return null"))
52}
53
54struct NaNCount;
55
56impl ComputeFnVTable for NaNCount {
57    fn invoke(
58        &self,
59        args: &InvocationArgs,
60        kernels: &[ArcRef<dyn Kernel>],
61    ) -> VortexResult<Output> {
62        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
63        let array = array.to_array();
64
65        let nan_count = nan_count_impl(&array, kernels)?;
66
67        // Update the stats set with the computed NaN count
68        array.statistics().set(
69            Stat::NaNCount,
70            Precision::Exact(ScalarValue::from(nan_count as u64)),
71        );
72
73        Ok(Scalar::from(nan_count as u64).into())
74    }
75
76    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
77        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
78        Stat::NaNCount
79            .dtype(array.dtype())
80            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
81    }
82
83    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
84        Ok(1)
85    }
86
87    fn is_elementwise(&self) -> bool {
88        false
89    }
90}
91
92/// Computes the min and max of an array, returning the (min, max) values
93pub trait NaNCountKernel: VTable {
94    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
95}
96
97pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
98inventory::collect!(NaNCountKernelRef);
99
100#[derive(Debug)]
101pub struct NaNCountKernelAdapter<V: VTable>(pub V);
102
103impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
104    pub const fn lift(&'static self) -> NaNCountKernelRef {
105        NaNCountKernelRef(ArcRef::new_ref(self))
106    }
107}
108
109impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
110    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
111        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
112        let Some(array) = array.as_opt::<V>() else {
113            return Ok(None);
114        };
115        let nan_count = V::nan_count(&self.0, array)?;
116        Ok(Some(Scalar::from(nan_count as u64).into()))
117    }
118}
119
120fn nan_count_impl(array: &ArrayRef, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
121    if array.is_empty() || array.valid_count()? == 0 {
122        return Ok(0);
123    }
124
125    if let Some(nan_count) = array
126        .statistics()
127        .get_as::<usize>(Stat::NaNCount)
128        .and_then(Precision::as_exact)
129    {
130        // If the NaN count is already computed, return it
131        return Ok(nan_count);
132    }
133
134    let args = InvocationArgs {
135        inputs: &[array.into()],
136        options: &(),
137    };
138
139    for kernel in kernels {
140        if let Some(output) = kernel.invoke(&args)? {
141            return output
142                .unwrap_scalar()?
143                .as_primitive()
144                .as_::<usize>()
145                .ok_or_else(|| vortex_err!("NaN count should not return null"));
146        }
147    }
148
149    if !array.is_canonical() {
150        let canonical = array.to_canonical()?.into_array();
151        return nan_count(&canonical);
152    }
153
154    vortex_bail!(
155        "No NaN count kernel found for array type: {}",
156        array.dtype()
157    )
158}