vortex_array/aggregate_fn/fns/nan_count/
mod.rs1mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use self::primitive::accumulate_primitive;
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::Columnar;
17use crate::ExecutionCtx;
18use crate::aggregate_fn::Accumulator;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnVTable;
21use crate::aggregate_fn::DynAccumulator;
22use crate::aggregate_fn::EmptyOptions;
23use crate::dtype::DType;
24use crate::dtype::Nullability::NonNullable;
25use crate::dtype::PType;
26use crate::expr::stats::Precision;
27use crate::expr::stats::Stat;
28use crate::expr::stats::StatsProvider;
29use crate::scalar::Scalar;
30use crate::scalar::ScalarValue;
31
32pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
38 if let Precision::Exact(nan_count_scalar) = array.statistics().get(Stat::NaNCount) {
40 return usize::try_from(&nan_count_scalar)
41 .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
42 }
43
44 if NanCount
46 .return_dtype(&EmptyOptions, array.dtype())
47 .is_none()
48 {
49 return Ok(0);
50 }
51
52 if array.is_empty() || array.valid_count(ctx)? == 0 {
54 return Ok(0);
55 }
56
57 let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
59 acc.accumulate(array, ctx)?;
60 let result = acc.finish()?;
61
62 let count = result
63 .as_primitive()
64 .typed_value::<u64>()
65 .vortex_expect("nan_count result should not be null");
66 let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
67
68 array
70 .statistics()
71 .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
72
73 Ok(count_usize)
74}
75
/// Aggregate function that counts NaN values.
///
/// Its `AggregateFnVTable` implementation in this file is registered under the id
/// `"vortex.nan_count"`, applies to floating-point primitive inputs, and produces a
/// non-nullable `u64`.
#[derive(Debug, Clone)]
pub struct NanCount;
82
83impl AggregateFnVTable for NanCount {
84 type Options = EmptyOptions;
85 type Partial = u64;
86
87 fn id(&self) -> AggregateFnId {
88 static ID: CachedId = CachedId::new("vortex.nan_count");
89 *ID
90 }
91
92 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
93 Ok(Some(vec![]))
94 }
95
96 fn deserialize(
97 &self,
98 _metadata: &[u8],
99 _session: &VortexSession,
100 ) -> VortexResult<Self::Options> {
101 Ok(EmptyOptions)
102 }
103
104 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
105 if let DType::Primitive(ptype, ..) = input_dtype
106 && ptype.is_float()
107 {
108 Some(DType::Primitive(PType::U64, NonNullable))
109 } else {
110 None
111 }
112 }
113
114 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
115 self.return_dtype(options, input_dtype)
116 }
117
118 fn empty_partial(
119 &self,
120 _options: &Self::Options,
121 _input_dtype: &DType,
122 ) -> VortexResult<Self::Partial> {
123 Ok(0u64)
124 }
125
126 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
127 let val = other
128 .as_primitive()
129 .typed_value::<u64>()
130 .vortex_expect("nan_count partial should not be null");
131 *partial += val;
132 Ok(())
133 }
134
135 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
136 Ok(Scalar::primitive(*partial, NonNullable))
137 }
138
139 fn reset(&self, partial: &mut Self::Partial) {
140 *partial = 0;
141 }
142
143 #[inline]
144 fn is_saturated(&self, _partial: &Self::Partial) -> bool {
145 false
146 }
147
148 fn accumulate(
149 &self,
150 partial: &mut Self::Partial,
151 batch: &Columnar,
152 ctx: &mut ExecutionCtx,
153 ) -> VortexResult<()> {
154 match batch {
155 Columnar::Constant(c) => {
156 if c.scalar().is_null() {
157 return Ok(());
159 }
160 if c.scalar().as_primitive().is_nan() {
161 *partial += c.len() as u64;
162 }
163 Ok(())
164 }
165 Columnar::Canonical(c) => match c {
166 Canonical::Primitive(p) => accumulate_primitive(partial, p, ctx),
167 _ => vortex_bail!(
168 "Unsupported canonical type for nan_count: {}",
169 batch.dtype()
170 ),
171 },
172 }
173 }
174
175 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
176 Ok(partials)
177 }
178
179 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
180 self.to_scalar(partial)
181 }
182}
183
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::array_session;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Non-nullable `f64` dtype used as the accumulator input type throughout these tests.
    fn f64_dtype() -> DType {
        DType::Primitive(PType::F64, Nullability::NonNullable)
    }

    /// NaNs from several accumulated batches are summed into a single result.
    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        accumulator.accumulate(&first, &mut ctx)?;

        let second =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        accumulator.accumulate(&second, &mut ctx)?;

        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts counting from zero again.
    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        accumulator.accumulate(&first, &mut ctx)?;
        let first_total = accumulator.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        accumulator.accumulate(&second, &mut ctx)?;
        let second_total = accumulator.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial counts adds them together.
    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let mut partial = NanCount.empty_partial(&EmptyOptions, &f64_dtype())?;

        for piece in [5u64, 3u64] {
            let scalar = Scalar::primitive(piece, Nullability::NonNullable);
            NanCount.combine_partials(&mut partial, scalar)?;
        }

        let merged = NanCount.to_scalar(&partial)?;
        NanCount.reset(&mut partial);
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    /// A constant NaN array counts one NaN per element.
    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let constant = ConstantArray::new(f64::NAN, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 10);
        Ok(())
    }

    /// A constant array of an ordinary float contains no NaNs.
    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let constant = ConstantArray::new(1.0f64, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that saw no batches yields zero.
    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// NaNs are counted across chunks, and null slots are not counted.
    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let head = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let tail = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let chunk_dtype = head.dtype().clone();
        let chunks = vec![head.into_array(), tail.into_array()];
        let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?.into_array();

        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&chunked, &mut ctx)?, 3);
        Ok(())
    }

    /// An array with only null entries has a NaN count of zero.
    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let all_null = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        assert_eq!(nan_count(&all_null, &mut ctx)?, 0);
        Ok(())
    }
}