vortex_array/compute/nan_count.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::{Scalar, ScalarValue};
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProviderExt};
14use crate::vtable::VTable;
15
16static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
18    for kernel in inventory::iter::<NaNCountKernelRef> {
19        compute.register_kernel(kernel.0.clone());
20    }
21    compute
22});
23
24/// Computes the number of NaN values in the array.
25pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
26    Ok(NAN_COUNT_FN
27        .invoke(&InvocationArgs {
28            inputs: &[array.into()],
29            options: &(),
30        })?
31        .unwrap_scalar()?
32        .as_primitive()
33        .as_::<usize>()
34        .vortex_expect("NaN count should not return null"))
35}
36
37struct NaNCount;
38
39impl ComputeFnVTable for NaNCount {
40    fn invoke(
41        &self,
42        args: &InvocationArgs,
43        kernels: &[ArcRef<dyn Kernel>],
44    ) -> VortexResult<Output> {
45        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
46
47        let nan_count = nan_count_impl(array, kernels)?;
48
49        // Update the stats set with the computed NaN count
50        array.statistics().set(
51            Stat::NaNCount,
52            Precision::Exact(ScalarValue::from(nan_count as u64)),
53        );
54
55        Ok(Scalar::from(nan_count as u64).into())
56    }
57
58    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
59        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
60        Stat::NaNCount
61            .dtype(array.dtype())
62            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
63    }
64
65    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
66        Ok(1)
67    }
68
69    fn is_elementwise(&self) -> bool {
70        false
71    }
72}
73
74/// Computes the min and max of an array, returning the (min, max) values
75pub trait NaNCountKernel: VTable {
76    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
77}
78
79pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
80inventory::collect!(NaNCountKernelRef);
81
82#[derive(Debug)]
83pub struct NaNCountKernelAdapter<V: VTable>(pub V);
84
85impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
86    pub const fn lift(&'static self) -> NaNCountKernelRef {
87        NaNCountKernelRef(ArcRef::new_ref(self))
88    }
89}
90
91impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
92    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
93        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
94        let Some(array) = array.as_opt::<V>() else {
95            return Ok(None);
96        };
97        let nan_count = V::nan_count(&self.0, array)?;
98        Ok(Some(Scalar::from(nan_count as u64).into()))
99    }
100}
101
102fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
103    if array.is_empty() || array.valid_count() == 0 {
104        return Ok(0);
105    }
106
107    if let Some(nan_count) = array
108        .statistics()
109        .get_as::<usize>(Stat::NaNCount)
110        .and_then(Precision::as_exact)
111    {
112        // If the NaN count is already computed, return it
113        return Ok(nan_count);
114    }
115
116    let args = InvocationArgs {
117        inputs: &[array.into()],
118        options: &(),
119    };
120
121    for kernel in kernels {
122        if let Some(output) = kernel.invoke(&args)? {
123            return output
124                .unwrap_scalar()?
125                .as_primitive()
126                .as_::<usize>()
127                .ok_or_else(|| vortex_err!("NaN count should not return null"));
128        }
129    }
130    if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
131        return output
132            .unwrap_scalar()?
133            .as_primitive()
134            .as_::<usize>()
135            .ok_or_else(|| vortex_err!("NaN count should not return null"));
136    }
137
138    if !array.is_canonical() {
139        let canonical = array.to_canonical()?;
140        return nan_count(canonical.as_ref());
141    }
142
143    vortex_bail!(
144        "No NaN count kernel found for array type: {}",
145        array.dtype()
146    )
147}