vortex_array/compute/
nan_count.rs1use std::sync::LazyLock;
2
3use arcref::ArcRef;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::{Scalar, ScalarValue};
7
8use crate::Array;
9use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
10use crate::stats::{Precision, Stat};
11use crate::vtable::VTable;
12
13pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
15 Ok(NAN_COUNT_FN
16 .invoke(&InvocationArgs {
17 inputs: &[array.into()],
18 options: &(),
19 })?
20 .unwrap_scalar()?
21 .as_primitive()
22 .as_::<usize>()?
23 .vortex_expect("NaN count should not return null"))
24}
25
26struct NaNCount;
27
28impl ComputeFnVTable for NaNCount {
29 fn invoke(
30 &self,
31 args: &InvocationArgs,
32 kernels: &[ArcRef<dyn Kernel>],
33 ) -> VortexResult<Output> {
34 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
35
36 let nan_count = nan_count_impl(array, kernels)?;
37
38 array.statistics().set(
40 Stat::NaNCount,
41 Precision::Exact(ScalarValue::from(nan_count as u64)),
42 );
43
44 Ok(Scalar::from(nan_count as u64).into())
45 }
46
47 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
48 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
49 Stat::NaNCount
50 .dtype(array.dtype())
51 .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
52 }
53
54 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
55 Ok(1)
56 }
57
58 fn is_elementwise(&self) -> bool {
59 false
60 }
61}
62
63pub trait NaNCountKernel: VTable {
65 fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
66}
67
68pub static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
69 let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
70 for kernel in inventory::iter::<NaNCountKernelRef> {
71 compute.register_kernel(kernel.0.clone());
72 }
73 compute
74});
75
76pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
77inventory::collect!(NaNCountKernelRef);
78
79#[derive(Debug)]
80pub struct NaNCountKernelAdapter<V: VTable>(pub V);
81
82impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
83 pub const fn lift(&'static self) -> NaNCountKernelRef {
84 NaNCountKernelRef(ArcRef::new_ref(self))
85 }
86}
87
88impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
89 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
90 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
91 let Some(array) = array.as_opt::<V>() else {
92 return Ok(None);
93 };
94 let nan_count = V::nan_count(&self.0, array)?;
95 Ok(Some(Scalar::from(nan_count as u64).into()))
96 }
97}
98
99fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
100 if array.is_empty() || array.valid_count()? == 0 {
101 return Ok(0);
102 }
103
104 if let Some(nan_count) = array
105 .statistics()
106 .get_as::<usize>(Stat::NaNCount)
107 .and_then(Precision::as_exact)
108 {
109 return Ok(nan_count);
111 }
112
113 let args = InvocationArgs {
114 inputs: &[array.into()],
115 options: &(),
116 };
117
118 for kernel in kernels {
119 if let Some(output) = kernel.invoke(&args)? {
120 return output
121 .unwrap_scalar()?
122 .as_primitive()
123 .as_::<usize>()?
124 .ok_or_else(|| vortex_err!("NaN count should not return null"));
125 }
126 }
127 if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
128 return output
129 .unwrap_scalar()?
130 .as_primitive()
131 .as_::<usize>()?
132 .ok_or_else(|| vortex_err!("NaN count should not return null"));
133 }
134
135 if !array.is_canonical() {
136 let canonical = array.to_canonical()?;
137 return nan_count(canonical.as_ref());
138 }
139
140 vortex_bail!(
141 "No NaN count kernel found for array type: {}",
142 array.dtype()
143 )
144}