vortex_array/compute/
nan_count.rs1use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray as _;
15use crate::compute::ComputeFn;
16use crate::compute::ComputeFnVTable;
17use crate::compute::InvocationArgs;
18use crate::compute::Kernel;
19use crate::compute::Output;
20use crate::compute::UnaryArgs;
21use crate::dtype::DType;
22use crate::expr::stats::Precision;
23use crate::expr::stats::Stat;
24use crate::expr::stats::StatsProviderExt;
25use crate::scalar::Scalar;
26use crate::scalar::ScalarValue;
27use crate::vtable::VTable;
28
29static NAN_COUNT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
30 let compute = ComputeFn::new("nan_count".into(), ArcRef::new_ref(&NaNCount));
31 for kernel in inventory::iter::<NaNCountKernelRef> {
32 compute.register_kernel(kernel.0.clone());
33 }
34 compute
35});
36
37pub(crate) fn warm_up_vtable() -> usize {
38 NAN_COUNT_FN.kernels().len()
39}
40
41pub fn nan_count(array: &ArrayRef) -> VortexResult<usize> {
43 Ok(NAN_COUNT_FN
44 .invoke(&InvocationArgs {
45 inputs: &[array.into()],
46 options: &(),
47 })?
48 .unwrap_scalar()?
49 .as_primitive()
50 .as_::<usize>()
51 .vortex_expect("NaN count should not return null"))
52}
53
54struct NaNCount;
55
56impl ComputeFnVTable for NaNCount {
57 fn invoke(
58 &self,
59 args: &InvocationArgs,
60 kernels: &[ArcRef<dyn Kernel>],
61 ) -> VortexResult<Output> {
62 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
63 let array = array.to_array();
64
65 let nan_count = nan_count_impl(&array, kernels)?;
66
67 array.statistics().set(
69 Stat::NaNCount,
70 Precision::Exact(ScalarValue::from(nan_count as u64)),
71 );
72
73 Ok(Scalar::from(nan_count as u64).into())
74 }
75
76 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
77 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
78 Stat::NaNCount
79 .dtype(array.dtype())
80 .ok_or_else(|| vortex_err!("Cannot compute NaN count for dtype {}", array.dtype()))
81 }
82
83 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
84 Ok(1)
85 }
86
87 fn is_elementwise(&self) -> bool {
88 false
89 }
90}
91
92pub trait NaNCountKernel: VTable {
94 fn nan_count(&self, array: &Self::Array) -> VortexResult<usize>;
95}
96
97pub struct NaNCountKernelRef(ArcRef<dyn Kernel>);
98inventory::collect!(NaNCountKernelRef);
99
100#[derive(Debug)]
101pub struct NaNCountKernelAdapter<V: VTable>(pub V);
102
103impl<V: VTable + NaNCountKernel> NaNCountKernelAdapter<V> {
104 pub const fn lift(&'static self) -> NaNCountKernelRef {
105 NaNCountKernelRef(ArcRef::new_ref(self))
106 }
107}
108
109impl<V: VTable + NaNCountKernel> Kernel for NaNCountKernelAdapter<V> {
110 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
111 let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
112 let Some(array) = array.as_opt::<V>() else {
113 return Ok(None);
114 };
115 let nan_count = V::nan_count(&self.0, array)?;
116 Ok(Some(Scalar::from(nan_count as u64).into()))
117 }
118}
119
120fn nan_count_impl(array: &ArrayRef, kernels: &[ArcRef<dyn Kernel>]) -> VortexResult<usize> {
121 if array.is_empty() || array.valid_count()? == 0 {
122 return Ok(0);
123 }
124
125 if let Some(nan_count) = array
126 .statistics()
127 .get_as::<usize>(Stat::NaNCount)
128 .and_then(Precision::as_exact)
129 {
130 return Ok(nan_count);
132 }
133
134 let args = InvocationArgs {
135 inputs: &[array.into()],
136 options: &(),
137 };
138
139 for kernel in kernels {
140 if let Some(output) = kernel.invoke(&args)? {
141 return output
142 .unwrap_scalar()?
143 .as_primitive()
144 .as_::<usize>()
145 .ok_or_else(|| vortex_err!("NaN count should not return null"));
146 }
147 }
148
149 if !array.is_canonical() {
150 let canonical = array.to_canonical()?.into_array();
151 return nan_count(&canonical);
152 }
153
154 vortex_bail!(
155 "No NaN count kernel found for array type: {}",
156 array.dtype()
157 )
158}