vortex_array/aggregate_fn/fns/nan_count/
mod.rs1mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10
11use self::primitive::accumulate_primitive;
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::Columnar;
15use crate::ExecutionCtx;
16use crate::aggregate_fn::Accumulator;
17use crate::aggregate_fn::AggregateFnId;
18use crate::aggregate_fn::AggregateFnVTable;
19use crate::aggregate_fn::DynAccumulator;
20use crate::aggregate_fn::EmptyOptions;
21use crate::dtype::DType;
22use crate::dtype::Nullability::NonNullable;
23use crate::dtype::PType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
36 if let Precision::Exact(nan_count_scalar) = array.statistics().get(Stat::NaNCount) {
38 return usize::try_from(&nan_count_scalar)
39 .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
40 }
41
42 if NanCount
44 .return_dtype(&EmptyOptions, array.dtype())
45 .is_none()
46 {
47 return Ok(0);
48 }
49
50 if array.is_empty() || array.valid_count(ctx)? == 0 {
52 return Ok(0);
53 }
54
55 let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
57 acc.accumulate(array, ctx)?;
58 let result = acc.finish()?;
59
60 let count = result
61 .as_primitive()
62 .typed_value::<u64>()
63 .vortex_expect("nan_count result should not be null");
64 let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
65
66 array
68 .statistics()
69 .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
70
71 Ok(count_usize)
72}
73
/// Aggregate function that counts NaN values.
///
/// This is a stateless marker type; its behavior lives in its `AggregateFnVTable`
/// implementation, which registers it under the id `vortex.nan_count`, accepts only
/// floating-point primitive inputs, and produces a non-nullable `u64` count.
#[derive(Debug, Clone)]
pub struct NanCount;
80
81impl AggregateFnVTable for NanCount {
82 type Options = EmptyOptions;
83 type Partial = u64;
84
85 fn id(&self) -> AggregateFnId {
86 AggregateFnId::new("vortex.nan_count")
87 }
88
89 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
90 unimplemented!("NanCount is not yet serializable");
91 }
92
93 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
94 if let DType::Primitive(ptype, ..) = input_dtype
95 && ptype.is_float()
96 {
97 Some(DType::Primitive(PType::U64, NonNullable))
98 } else {
99 None
100 }
101 }
102
103 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
104 self.return_dtype(options, input_dtype)
105 }
106
107 fn empty_partial(
108 &self,
109 _options: &Self::Options,
110 _input_dtype: &DType,
111 ) -> VortexResult<Self::Partial> {
112 Ok(0u64)
113 }
114
115 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
116 let val = other
117 .as_primitive()
118 .typed_value::<u64>()
119 .vortex_expect("nan_count partial should not be null");
120 *partial += val;
121 Ok(())
122 }
123
124 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
125 Ok(Scalar::primitive(*partial, NonNullable))
126 }
127
128 fn reset(&self, partial: &mut Self::Partial) {
129 *partial = 0;
130 }
131
132 #[inline]
133 fn is_saturated(&self, _partial: &Self::Partial) -> bool {
134 false
135 }
136
137 fn accumulate(
138 &self,
139 partial: &mut Self::Partial,
140 batch: &Columnar,
141 ctx: &mut ExecutionCtx,
142 ) -> VortexResult<()> {
143 match batch {
144 Columnar::Constant(c) => {
145 if c.scalar().is_null() {
146 return Ok(());
148 }
149 if c.scalar().as_primitive().is_nan() {
150 *partial += c.len() as u64;
151 }
152 Ok(())
153 }
154 Columnar::Canonical(c) => match c {
155 Canonical::Primitive(p) => accumulate_primitive(partial, p, ctx),
156 _ => vortex_bail!(
157 "Unsupported canonical type for nan_count: {}",
158 batch.dtype()
159 ),
160 },
161 }
162 }
163
164 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
165 Ok(partials)
166 }
167
168 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
169 self.to_scalar(partial)
170 }
171}
172
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Non-nullable `f64` dtype shared by the accumulator-level tests.
    fn f64_dtype() -> DType {
        DType::Primitive(PType::F64, Nullability::NonNullable)
    }

    /// NaNs seen across several batches add up in a single accumulator.
    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        accumulator.accumulate(&first, &mut ctx)?;

        let second =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        accumulator.accumulate(&second, &mut ctx)?;

        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts counting from zero again.
    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        accumulator.accumulate(&first, &mut ctx)?;
        let first_total = accumulator.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        accumulator.accumulate(&second, &mut ctx)?;
        let second_total = accumulator.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial scalars sums their counts.
    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let dtype = f64_dtype();
        let mut partial = NanCount.empty_partial(&EmptyOptions, &dtype)?;

        let five = Scalar::primitive(5u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut partial, five)?;

        let three = Scalar::primitive(3u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut partial, three)?;

        let merged = NanCount.to_scalar(&partial)?;
        NanCount.reset(&mut partial);
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    /// A constant NaN array counts every element.
    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let array = ConstantArray::new(f64::NAN, 10).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 10);
        Ok(())
    }

    /// A constant non-NaN array counts nothing.
    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let array = ConstantArray::new(1.0f64, 10).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that saw no batches yields zero.
    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// NaNs are counted across chunks, and null slots are not counted.
    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let first = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let second = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        // Capture the dtype before the chunks are consumed by `into_array`.
        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype)?.into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&chunked, &mut ctx)?, 3);
        Ok(())
    }

    /// An array with no valid elements reports zero NaNs.
    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let all_null = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&all_null, &mut ctx)?, 0);
        Ok(())
    }
}