Skip to main content

vortex_array/aggregate_fn/fns/nan_count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use self::primitive::accumulate_primitive;
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::Columnar;
17use crate::ExecutionCtx;
18use crate::aggregate_fn::Accumulator;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnVTable;
21use crate::aggregate_fn::DynAccumulator;
22use crate::aggregate_fn::EmptyOptions;
23use crate::dtype::DType;
24use crate::dtype::Nullability::NonNullable;
25use crate::dtype::PType;
26use crate::expr::stats::Precision;
27use crate::expr::stats::Stat;
28use crate::expr::stats::StatsProvider;
29use crate::scalar::Scalar;
30use crate::scalar::ScalarValue;
31
32/// Return the number of NaN values in an array.
33///
34/// Null values are not NaN and are not counted.
35///
36/// See [`NanCount`] for details.
37pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
38    // Short-circuit using cached array statistics.
39    if let Precision::Exact(nan_count_scalar) = array.statistics().get(Stat::NaNCount) {
40        return usize::try_from(&nan_count_scalar)
41            .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
42    }
43
44    // Short-circuit for non-float types.
45    if NanCount
46        .return_dtype(&EmptyOptions, array.dtype())
47        .is_none()
48    {
49        return Ok(0);
50    }
51
52    // Short-circuit for empty arrays or all-null arrays.
53    if array.is_empty() || array.valid_count(ctx)? == 0 {
54        return Ok(0);
55    }
56
57    // Compute using Accumulator<NanCount>.
58    let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
59    acc.accumulate(array, ctx)?;
60    let result = acc.finish()?;
61
62    let count = result
63        .as_primitive()
64        .typed_value::<u64>()
65        .vortex_expect("nan_count result should not be null");
66    let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
67
68    // Cache the computed NaN count as a statistic.
69    array
70        .statistics()
71        .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
72
73    Ok(count_usize)
74}
75
76/// Count the number of NaN values in an array.
77///
78/// Only applies to floating-point primitive types. Returns a `u64` count.
79/// Null values are not NaN, so if the array is all-invalid, the NaN count is zero.
80#[derive(Clone, Debug)]
81pub struct NanCount;
82
83impl AggregateFnVTable for NanCount {
84    type Options = EmptyOptions;
85    type Partial = u64;
86
87    fn id(&self) -> AggregateFnId {
88        static ID: CachedId = CachedId::new("vortex.nan_count");
89        *ID
90    }
91
92    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
93        Ok(Some(vec![]))
94    }
95
96    fn deserialize(
97        &self,
98        _metadata: &[u8],
99        _session: &VortexSession,
100    ) -> VortexResult<Self::Options> {
101        Ok(EmptyOptions)
102    }
103
104    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
105        if let DType::Primitive(ptype, ..) = input_dtype
106            && ptype.is_float()
107        {
108            Some(DType::Primitive(PType::U64, NonNullable))
109        } else {
110            None
111        }
112    }
113
114    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
115        self.return_dtype(options, input_dtype)
116    }
117
118    fn empty_partial(
119        &self,
120        _options: &Self::Options,
121        _input_dtype: &DType,
122    ) -> VortexResult<Self::Partial> {
123        Ok(0u64)
124    }
125
126    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
127        let val = other
128            .as_primitive()
129            .typed_value::<u64>()
130            .vortex_expect("nan_count partial should not be null");
131        *partial += val;
132        Ok(())
133    }
134
135    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
136        Ok(Scalar::primitive(*partial, NonNullable))
137    }
138
139    fn reset(&self, partial: &mut Self::Partial) {
140        *partial = 0;
141    }
142
143    #[inline]
144    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
145        false
146    }
147
148    fn accumulate(
149        &self,
150        partial: &mut Self::Partial,
151        batch: &Columnar,
152        ctx: &mut ExecutionCtx,
153    ) -> VortexResult<()> {
154        match batch {
155            Columnar::Constant(c) => {
156                if c.scalar().is_null() {
157                    // Null values are not NaN.
158                    return Ok(());
159                }
160                if c.scalar().as_primitive().is_nan() {
161                    *partial += c.len() as u64;
162                }
163                Ok(())
164            }
165            Columnar::Canonical(c) => match c {
166                Canonical::Primitive(p) => accumulate_primitive(partial, p, ctx),
167                _ => vortex_bail!(
168                    "Unsupported canonical type for nan_count: {}",
169                    batch.dtype()
170                ),
171            },
172        }
173    }
174
175    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
176        Ok(partials)
177    }
178
179    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
180        self.to_scalar(partial)
181    }
182}
183
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::array_session;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::expr::stats::StatsProvider;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut state = NanCount.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar2)?;

        let result = NanCount.to_scalar(&state)?;
        NanCount.reset(&mut state);
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));
        // `reset` must return the partial to its empty state.
        assert_eq!(state, 0);
        Ok(())
    }

    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let array = ConstantArray::new(f64::NAN, 10);
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let array = ConstantArray::new(1.0f64, 10);
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let chunk2 = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let p = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]);
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&p.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    /// Non-float dtypes take the early-return path and always report zero.
    #[test]
    fn nan_count_non_float_is_zero() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 0);
        Ok(())
    }

    /// A computed count is cached on the array as an exact statistic.
    #[test]
    fn nan_count_caches_statistic() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        assert!(matches!(array.statistics().get(Stat::NaNCount), Precision::Exact(_)));
        // A second call must agree with the first.
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        Ok(())
    }

    /// An exact cached statistic short-circuits the scan entirely.
    #[test]
    fn nan_count_uses_cached_statistic() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        // Seed a deliberately different value so the test can tell a scan from a cache hit.
        array
            .statistics()
            .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(7u64)));
        let mut ctx = array_session().create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 7);
        Ok(())
    }
}