// vortex_array/compute/nan_count.rs

use std::sync::LazyLock;

use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_scalar::{Scalar, ScalarValue};

use crate::Array;
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
use crate::stats::{Precision, Stat};
use crate::vtable::VTable;

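/// Computes the number of NaN values in `array`.
///
/// Dispatches through the `nan_count` compute function, so an encoding with a registered
/// [`NaNCountKernel`] can answer directly without the array being canonicalized first.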
pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
    Ok(NAN_COUNT_FN
        .invoke(&InvocationArgs {
            inputs: &[array.into()],
            options: &(),
        })?
        .unwrap_scalar()?
        .as_primitive()
        .as_::<usize>()?
        .vortex_expect("NaN count should not return null"))
}

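/// [`ComputeFnVTable`] implementation backing the `nan_count` compute function. The computed
/// count is cached on the array as an exact [`Stat::NaNCount`] statistic.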
struct NaNCount;

impl ComputeFnVTable for NaNCount {
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;

        let nan_count = nan_count_impl(array, kernels)?;

        array.statistics().set(
            Stat::NaNCount,
            Precision::Exact(ScalarValue::from(nan_count as u64)),
        );

        Ok(Scalar::from(nan_count as u64).into())
    }

    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
        Stat::NaNCount
            .dtype(array.dtype())
            .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
    }

    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
        Ok(1)
    }

    fn is_elementwise(&self) -> bool {
        false
    }
}

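/// Kernel trait for encodings that can count NaN values directly in their own representation.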
pub trait NaNCountKernel: VTable {
    fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
}

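/// The `nan_count` compute function, assembled lazily from every [`NaNCountKernelRef`]
/// submitted to the `inventory` registry.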
pub static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
    for kernel in inventory::iter::<NaNCountKernelRef> {
        compute.register_kernel(kernel.0.clone());
    }
    compute
});

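/// Type-erased kernel reference collected via `inventory`, allowing encodings in other crates
/// to register NaN count kernels.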
pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(NaNCountKernelRef);

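/// Adapts a typed [`NaNCountKernel`] implementation into the type-erased [`Kernel`] interface
/// used by the compute machinery.
///
/// A rough usage sketch is shown below. `MyVTable` is a hypothetical encoding vtable, the
/// kernel body is a placeholder, and the exact registration mechanics may differ in real
/// encodings.
///
/// ```ignore
/// impl NaNCountKernel for MyVTable {
///     fn nan_count(&self, array: &Self::Array) -> VortexResult<usize> {
///         // Count NaN elements using the encoding's native layout.
///         todo!()
///     }
/// }
///
/// // Lift the adapter into a NaNCountKernelRef and submit it to the inventory that
/// // NAN_COUNT_FN iterates over when it is first used.
/// static ADAPTER: NaNCountKernelAdapter<MyVTable> = NaNCountKernelAdapter(MyVTable);
/// inventory::submit!(ADAPTER.lift());
/// ```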
#[derive(Debug)]
pub struct NaNCountKernelAdapter<V: VTable>(pub V);

impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
    pub const fn lift(&'static self) -> NaNCountKernelRef {
        NaNCountKernelRef(ArcRef::new_ref(self))
    }
}

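// Bridges the typed kernel to the dynamic `Kernel` interface: returns `Ok(None)` when the
// input array does not belong to this vtable's encoding, letting dispatch fall through.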
impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
        let Some(array) = array.as_opt::<V>() else {
            return Ok(None);
        };
        let nan_count = V::nan_count(&self.0, array)?;
        Ok(Some(Scalar::from(nan_count as u64).into()))
    }
}

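/// Resolution order: empty or all-null arrays short-circuit to zero; otherwise a cached exact
/// `Stat::NaNCount` is used if present, then kernels registered against the compute function,
/// then the array's own `invoke` hook, and finally the canonical form of the array.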
fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
    if array.is_empty() || array.valid_count()? == 0 {
        return Ok(0);
    }

    if let Some(nan_count) = array
        .statistics()
        .get_as::<usize>(Stat::NaNCount)
        .and_then(Precision::as_exact)
    {
        return Ok(nan_count);
    }

    let args = InvocationArgs {
        inputs: &[array.into()],
        options: &(),
    };

    for kernel in kernels {
        if let Some(output) = kernel.invoke(&args)? {
            return output
                .unwrap_scalar()?
                .as_primitive()
                .as_::<usize>()?
                .ok_or_else(|| vortex_err!("NaN count should not return null"));
        }
    }
    if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
        return output
            .unwrap_scalar()?
            .as_primitive()
            .as_::<usize>()?
            .ok_or_else(|| vortex_err!("NaN count should not return null"));
    }

    if !array.is_canonical() {
        let canonical = array.to_canonical()?;
        return nan_count(canonical.as_ref());
    }

    vortex_bail!(
        "No NaN count kernel found for array type: {}",
        array.dtype()
    )
}