vortex_array/compute/
nan_count.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_scalar::Scalar;
13use vortex_scalar::ScalarValue;
14
15use crate::Array;
16use crate::compute::ComputeFn;
17use crate::compute::ComputeFnVTable;
18use crate::compute::InvocationArgs;
19use crate::compute::Kernel;
20use crate::compute::Output;
21use crate::compute::UnaryArgs;
22use crate::expr::stats::Precision;
23use crate::expr::stats::Stat;
24use crate::expr::stats::StatsProviderExt;
25use crate::vtable::VTable;
26
27static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
28    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
29    for kernel in inventory::iter::<NaNCountKernelRef> {
30        compute.register_kernel(kernel.0.clone());
31    }
32    compute
33});
34
35pub(crate) fn warm_up_vtable() -> usize {
36    NAN_COUNT_FN.kernels().len()
37}
38
39/// Computes the number of NaN values in the array.
40pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
41    Ok(NAN_COUNT_FN
42        .invoke(&InvocationArgs {
43            inputs: &[array.into()],
44            options: &(),
45        })?
46        .unwrap_scalar()?
47        .as_primitive()
48        .as_::<usize>()
49        .vortex_expect("NaN count should not return null"))
50}
51
52struct NaNCount;
53
54impl ComputeFnVTable for NaNCount {
55    fn invoke(
56        &self,
57        args: &InvocationArgs,
58        kernels: &[ArcRef<dyn Kernel>],
59    ) -> VortexResult<Output> {
60        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
61
62        let nan_count = nan_count_impl(array, kernels)?;
63
64        // Update the stats set with the computed NaN count
65        array.statistics().set(
66            Stat::NaNCount,
67            Precision::Exact(ScalarValue::from(nan_count as u64)),
68        );
69
70        Ok(Scalar::from(nan_count as u64).into())
71    }
72
73    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
74        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
75        Stat::NaNCount
76            .dtype(array.dtype())
77            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
78    }
79
80    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
81        Ok(1)
82    }
83
84    fn is_elementwise(&self) -> bool {
85        false
86    }
87}
88
89/// Computes the min and max of an array, returning the (min, max) values
90pub trait NaNCountKernel: VTable {
91    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
92}
93
94pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
95inventory::collect!(NaNCountKernelRef);
96
97#[derive(Debug)]
98pub struct NaNCountKernelAdapter<V: VTable>(pub V);
99
100impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
101    pub const fn lift(&'static self) -> NaNCountKernelRef {
102        NaNCountKernelRef(ArcRef::new_ref(self))
103    }
104}
105
106impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
107    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
108        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
109        let Some(array) = array.as_opt::<V>() else {
110            return Ok(None);
111        };
112        let nan_count = V::nan_count(&self.0, array)?;
113        Ok(Some(Scalar::from(nan_count as u64).into()))
114    }
115}
116
117fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
118    if array.is_empty() || array.valid_count() == 0 {
119        return Ok(0);
120    }
121
122    if let Some(nan_count) = array
123        .statistics()
124        .get_as::<usize>(Stat::NaNCount)
125        .and_then(Precision::as_exact)
126    {
127        // If the NaN count is already computed, return it
128        return Ok(nan_count);
129    }
130
131    let args = InvocationArgs {
132        inputs: &[array.into()],
133        options: &(),
134    };
135
136    for kernel in kernels {
137        if let Some(output) = kernel.invoke(&args)? {
138            return output
139                .unwrap_scalar()?
140                .as_primitive()
141                .as_::<usize>()
142                .ok_or_else(|| vortex_err!("NaN count should not return null"));
143        }
144    }
145    if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
146        return output
147            .unwrap_scalar()?
148            .as_primitive()
149            .as_::<usize>()
150            .ok_or_else(|| vortex_err!("NaN count should not return null"));
151    }
152
153    if !array.is_canonical() {
154        let canonical = array.to_canonical();
155        return nan_count(canonical.as_ref());
156    }
157
158    vortex_bail!(
159        "No NaN count kernel found for array type: {}",
160        array.dtype()
161    )
162}