vortex_array/compute/
nan_count.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_scalar::Scalar;
13use vortex_scalar::ScalarValue;
14
15use crate::Array;
16use crate::compute::ComputeFn;
17use crate::compute::ComputeFnVTable;
18use crate::compute::InvocationArgs;
19use crate::compute::Kernel;
20use crate::compute::Output;
21use crate::compute::UnaryArgs;
22use crate::expr::stats::Precision;
23use crate::expr::stats::Stat;
24use crate::expr::stats::StatsProviderExt;
25use crate::vtable::VTable;
26
27static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
28 let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
29 for kernel in inventory::iter::<NaNCountKernelRef> {
30 compute.register_kernel(kernel.0.clone());
31 }
32 compute
33});
34
35pub(crate) fn warm_up_vtable() -> usize {
36 NAN_COUNT_FN.kernels().len()
37}
38
39pub fn nan_count(array: &dyn Array) -> VortexResult<usize> {
41 Ok(NAN_COUNT_FN
42 .invoke(&InvocationArgs {
43 inputs: &[array.into()],
44 options: &(),
45 })?
46 .unwrap_scalar()?
47 .as_primitive()
48 .as_::<usize>()
49 .vortex_expect("NaN count should not return null"))
50}
51
52struct NaNCount;
53
54impl ComputeFnVTable for NaNCount {
55 fn invoke(
56 &self,
57 args: &InvocationArgs,
58 kernels: &[ArcRef<dyn Kernel>],
59 ) -> VortexResult<Output> {
60 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
61
62 let nan_count = nan_count_impl(array, kernels)?;
63
64 array.statistics().set(
66 Stat::NaNCount,
67 Precision::Exact(ScalarValue::from(nan_count as u64)),
68 );
69
70 Ok(Scalar::from(nan_count as u64).into())
71 }
72
73 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
74 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
75 Stat::NaNCount
76 .dtype(array.dtype())
77 .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
78 }
79
80 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
81 Ok(1)
82 }
83
84 fn is_elementwise(&self) -> bool {
85 false
86 }
87}
88
89pub trait NaNCountKernel: VTable {
91 fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
92}
93
94pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
95inventory::collect!(NaNCountKernelRef);
96
97#[derive(Debug)]
98pub struct NaNCountKernelAdapter<V: VTable>(pub V);
99
100impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
101 pub const fn lift(&'static self) -> NaNCountKernelRef {
102 NaNCountKernelRef(ArcRef::new_ref(self))
103 }
104}
105
106impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
107 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
108 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
109 let Some(array) = array.as_opt::<V>() else {
110 return Ok(None);
111 };
112 let nan_count = V::nan_count(&self.0, array)?;
113 Ok(Some(Scalar::from(nan_count as u64).into()))
114 }
115}
116
117fn nan_count_impl(array: &dyn Array, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
118 if array.is_empty() || array.valid_count() == 0 {
119 return Ok(0);
120 }
121
122 if let Some(nan_count) = array
123 .statistics()
124 .get_as::<usize>(Stat::NaNCount)
125 .and_then(Precision::as_exact)
126 {
127 return Ok(nan_count);
129 }
130
131 let args = InvocationArgs {
132 inputs: &[array.into()],
133 options: &(),
134 };
135
136 for kernel in kernels {
137 if let Some(output) = kernel.invoke(&args)? {
138 return output
139 .unwrap_scalar()?
140 .as_primitive()
141 .as_::<usize>()
142 .ok_or_else(|| vortex_err!("NaN count should not return null"));
143 }
144 }
145 if let Some(output) = array.invoke(&NAN_COUNT_FN, &args)? {
146 return output
147 .unwrap_scalar()?
148 .as_primitive()
149 .as_::<usize>()
150 .ok_or_else(|| vortex_err!("NaN count should not return null"));
151 }
152
153 if !array.is_canonical() {
154 let canonical = array.to_canonical();
155 return nan_count(canonical.as_ref());
156 }
157
158 vortex_bail!(
159 "No NaN count kernel found for array type: {}",
160 array.dtype()
161 )
162}