vortex_array/aggregate_fn/fns/nan_count/
mod.rs1mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10
11use self::primitive::accumulate_primitive;
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::Columnar;
15use crate::ExecutionCtx;
16use crate::aggregate_fn::Accumulator;
17use crate::aggregate_fn::AggregateFnId;
18use crate::aggregate_fn::AggregateFnVTable;
19use crate::aggregate_fn::DynAccumulator;
20use crate::aggregate_fn::EmptyOptions;
21use crate::dtype::DType;
22use crate::dtype::Nullability::NonNullable;
23use crate::dtype::PType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
34 if let Some(Precision::Exact(nan_count_scalar)) = array.statistics().get(Stat::NaNCount) {
36 return usize::try_from(&nan_count_scalar)
37 .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
38 }
39
40 if NanCount
42 .return_dtype(&EmptyOptions, array.dtype())
43 .is_none()
44 {
45 return Ok(0);
46 }
47
48 if array.is_empty() || array.valid_count()? == 0 {
50 return Ok(0);
51 }
52
53 let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
55 acc.accumulate(array, ctx)?;
56 let result = acc.finish()?;
57
58 let count = result
59 .as_primitive()
60 .typed_value::<u64>()
61 .vortex_expect("nan_count result should not be null");
62 let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
63
64 array
66 .statistics()
67 .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
68
69 Ok(count_usize)
70}
71
/// Aggregate function that counts NaN values in floating-point primitive inputs.
///
/// The aggregate yields a non-nullable `u64` count. Inputs that are not floating-point
/// primitives are rejected by its `return_dtype`, which returns `None` for them.
#[derive(Debug, Clone)]
pub struct NanCount;
78
79impl AggregateFnVTable for NanCount {
80 type Options = EmptyOptions;
81 type Partial = u64;
82
83 fn id(&self) -> AggregateFnId {
84 AggregateFnId::new_ref("vortex.nan_count")
85 }
86
87 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
88 Ok(Some(vec![]))
89 }
90
91 fn deserialize(
92 &self,
93 _metadata: &[u8],
94 _session: &vortex_session::VortexSession,
95 ) -> VortexResult<Self::Options> {
96 Ok(EmptyOptions)
97 }
98
99 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
100 if let DType::Primitive(ptype, ..) = input_dtype
101 && ptype.is_float()
102 {
103 Some(DType::Primitive(PType::U64, NonNullable))
104 } else {
105 None
106 }
107 }
108
109 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
110 self.return_dtype(options, input_dtype)
111 }
112
113 fn empty_partial(
114 &self,
115 _options: &Self::Options,
116 _input_dtype: &DType,
117 ) -> VortexResult<Self::Partial> {
118 Ok(0u64)
119 }
120
121 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
122 let val = other
123 .as_primitive()
124 .typed_value::<u64>()
125 .vortex_expect("nan_count partial should not be null");
126 *partial += val;
127 Ok(())
128 }
129
130 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
131 Ok(Scalar::primitive(*partial, NonNullable))
132 }
133
134 fn reset(&self, partial: &mut Self::Partial) {
135 *partial = 0;
136 }
137
138 #[inline]
139 fn is_saturated(&self, _partial: &Self::Partial) -> bool {
140 false
141 }
142
143 fn accumulate(
144 &self,
145 partial: &mut Self::Partial,
146 batch: &Columnar,
147 _ctx: &mut ExecutionCtx,
148 ) -> VortexResult<()> {
149 match batch {
150 Columnar::Constant(c) => {
151 if c.scalar().is_null() {
152 return Ok(());
154 }
155 if c.scalar().as_primitive().is_nan() {
156 *partial += c.len() as u64;
157 }
158 Ok(())
159 }
160 Columnar::Canonical(c) => match c {
161 Canonical::Primitive(p) => accumulate_primitive(partial, p),
162 _ => vortex_bail!(
163 "Unsupported canonical type for nan_count: {}",
164 batch.dtype()
165 ),
166 },
167 }
168 }
169
170 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
171 Ok(partials)
172 }
173
174 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
175 self.to_scalar(partial)
176 }
177}
178
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut state = NanCount.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar2)?;

        let result = NanCount.to_scalar(&state)?;
        NanCount.reset(&mut state);
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let array = ConstantArray::new(f64::NAN, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let array = ConstantArray::new(1.0f64, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let chunk2 = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let p = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&p.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_return_dtype_only_floats() {
        let float = DType::Primitive(PType::F32, Nullability::NonNullable);
        let int = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert!(NanCount.return_dtype(&EmptyOptions, &float).is_some());
        assert!(NanCount.return_dtype(&EmptyOptions, &int).is_none());
    }

    #[test]
    fn nan_count_f32() -> VortexResult<()> {
        let array = PrimitiveArray::new(
            buffer![f32::NAN, 1.0f32, f32::NAN, 2.0f32],
            Validity::NonNullable,
        )
        .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        Ok(())
    }

    #[test]
    fn nan_count_non_float_is_zero() -> VortexResult<()> {
        // Integers cannot hold NaN, so the aggregate is skipped and zero is reported.
        let array = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_cached_result_is_stable() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![f64::NAN, 0.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // The first call computes and caches the statistic; the second reads it back.
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        Ok(())
    }
}