vortex_array/compute/
nan_count.rs1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::{Scalar, ScalarValue};
7
8use crate::Array;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
10use crate::stats::{Precision, Stat};
11use crate::vtable::VTable;
12
13pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
15    Ok(NAN_COUNT_FN
16        .invoke(&InvocationArgs {
17            inputs: &[array.into()],
18            options: &(),
19        })?
20        .unwrap_scalar()?
21        .as_primitive()
22        .as_::<usize>()?
23        .vortex_expect("NaN count should not return null"))
24}
25
26struct NaNCount;
27
28impl ComputeFnVTable for NaNCount {
29    fn invoke(
30        &self,
31        args: &InvocationArgs,
32        kernels: &[ArcRef<dyn Kernel>],
33    ) -> VortexResult<Output> {
34        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
35
36        let nan_count = nan_count_impl(array, kernels)?;
37
38        array.statistics().set(
40            Stat::NaNCount,
41            Precision::Exact(ScalarValue::from(nan_count as u64)),
42        );
43
44        Ok(Scalar::from(nan_count as u64).into())
45    }
46
47    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
48        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
49        Stat::NaNCount
50            .dtype(array.dtype())
51            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
52    }
53
54    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
55        Ok(1)
56    }
57
58    fn is_elementwise(&self) -> bool {
59        false
60    }
61}
62
63pub trait NaNCountKernel: VTable {
65    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
66}
67
68pub static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
69    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
70    for kernel in inventory::iter::<NaNCountKernelRef> {
71        compute.register_kernel(kernel.0.clone());
72    }
73    compute
74});
75
76pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
77inventory::collect!(NaNCountKernelRef);
78
79#[derive(Debug)]
80pub struct NaNCountKernelAdapter<V: VTable>(pub V);
81
82impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
83    pub const fn lift(&'static self) -> NaNCountKernelRef {
84        NaNCountKernelRef(ArcRef::new_ref(self))
85    }
86}
87
88impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
89    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
90        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
91        let Some(array) = array.as_opt::<V>() else {
92            return Ok(None);
93        };
94        let nan_count = V::nan_count(&self.0, array)?;
95        Ok(Some(Scalar::from(nan_count as u64).into()))
96    }
97}
98
99fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
100    if array.is_empty() || array.valid_count()? == 0 {
101        return Ok(0);
102    }
103
104    if let Some(nan_count) = array
105        .statistics()
106        .get_as::<usize>(Stat::NaNCount)
107        .and_then(Precision::as_exact)
108    {
109        return Ok(nan_count);
111    }
112
113    let args = InvocationArgs {
114        inputs: &[array.into()],
115        options: &(),
116    };
117
118    for kernel in kernels {
119        if let Some(output) = kernel.invoke(&args)? {
120            return output
121                .unwrap_scalar()?
122                .as_primitive()
123                .as_::<usize>()?
124                .ok_or_else(|| vortex_err!("NaN count should not return null"));
125        }
126    }
127    if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
128        return output
129            .unwrap_scalar()?
130            .as_primitive()
131            .as_::<usize>()?
132            .ok_or_else(|| vortex_err!("NaN count should not return null"));
133    }
134
135    if !array.is_canonical() {
136        let canonical = array.to_canonical()?;
137        return nan_count(canonical.as_ref());
138    }
139
140    vortex_bail!(
141        "No NaN count kernel found for array type: {}",
142        array.dtype()
143    )
144}