vortex_array/compute/
nan_count.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
9use vortex_scalar::{Scalar, ScalarValue};
10
11use crate::Array;
12use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
13use crate::stats::{Precision, Stat, StatsProviderExt};
14use crate::vtable::VTable;
15
16static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
17    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
18    for kernel in inventory::iter::<NaNCountKernelRef> {
19        compute.register_kernel(kernel.0.clone());
20    }
21    compute
22});
23
24pub(crate) fn warm_up_vtable() -> usize {
25    NAN_COUNT_FN.kernels().len()
26}
27
28/// Computes the number of NaN values in the array.
29pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
30    Ok(NAN_COUNT_FN
31        .invoke(&InvocationArgs {
32            inputs: &[array.into()],
33            options: &(),
34        })?
35        .unwrap_scalar()?
36        .as_primitive()
37        .as_::<usize>()
38        .vortex_expect("NaN count should not return null"))
39}
40
/// Marker type implementing [`ComputeFnVTable`] for the `nan_count` compute function.
struct NaNCount;
42
43impl ComputeFnVTable for NaNCount {
44    fn invoke(
45        &self,
46        args: &InvocationArgs,
47        kernels: &[ArcRef<dyn Kernel>],
48    ) -> VortexResult<Output> {
49        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
50
51        let nan_count = nan_count_impl(array, kernels)?;
52
53        // Update the stats set with the computed NaN count
54        array.statistics().set(
55            Stat::NaNCount,
56            Precision::Exact(ScalarValue::from(nan_count as u64)),
57        );
58
59        Ok(Scalar::from(nan_count as u64).into())
60    }
61
62    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
63        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
64        Stat::NaNCount
65            .dtype(array.dtype())
66            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
67    }
68
69    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
70        Ok(1)
71    }
72
73    fn is_elementwise(&self) -> bool {
74        false
75    }
76}
77
/// Kernel interface for computing the number of NaN values in an array of the implementing
/// [`VTable`]'s array type.
pub trait NaNCountKernel: VTable {
    /// Returns the number of NaN values in `array`.
    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
}
82
/// A type-erased NaN count kernel.
///
/// Values of this type are collected through `inventory` and registered on the `nan_count`
/// compute function when it is first initialized.
pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(NaNCountKernelRef);
85
/// Adapter exposing a [`NaNCountKernel`] implementation through the dynamic [`Kernel`] interface.
#[derive(Debug)]
pub struct NaNCountKernelAdapter<V: VTable>(pub V);
88
89impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
90    pub const fn lift(&'static self) -> NaNCountKernelRef {
91        NaNCountKernelRef(ArcRef::new_ref(self))
92    }
93}
94
95impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
96    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
97        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
98        let Some(array) = array.as_opt::<V>() else {
99            return Ok(None);
100        };
101        let nan_count = V::nan_count(&self.0, array)?;
102        Ok(Some(Scalar::from(nan_count as u64).into()))
103    }
104}
105
106fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
107    if array.is_empty() || array.valid_count() == 0 {
108        return Ok(0);
109    }
110
111    if let Some(nan_count) = array
112        .statistics()
113        .get_as::<usize>(Stat::NaNCount)
114        .and_then(Precision::as_exact)
115    {
116        // If the NaN count is already computed, return it
117        return Ok(nan_count);
118    }
119
120    let args = InvocationArgs {
121        inputs: &[array.into()],
122        options: &(),
123    };
124
125    for kernel in kernels {
126        if let Some(output) = kernel.invoke(&args)? {
127            return output
128                .unwrap_scalar()?
129                .as_primitive()
130                .as_::<usize>()
131                .ok_or_else(|| vortex_err!("NaN count should not return null"));
132        }
133    }
134    if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
135        return output
136            .unwrap_scalar()?
137            .as_primitive()
138            .as_::<usize>()
139            .ok_or_else(|| vortex_err!("NaN count should not return null"));
140    }
141
142    if !array.is_canonical() {
143        let canonical = array.to_canonical();
144        return nan_count(canonical.as_ref());
145    }
146
147    vortex_bail!(
148        "No NaN count kernel found for array type: {}",
149        array.dtype()
150    )
151}