Skip to main content

vortex_array/aggregate_fn/fns/nan_count/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10
11use self::primitive::accumulate_primitive;
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::Columnar;
15use crate::ExecutionCtx;
16use crate::aggregate_fn::Accumulator;
17use crate::aggregate_fn::AggregateFnId;
18use crate::aggregate_fn::AggregateFnVTable;
19use crate::aggregate_fn::DynAccumulator;
20use crate::aggregate_fn::EmptyOptions;
21use crate::dtype::DType;
22use crate::dtype::Nullability::NonNullable;
23use crate::dtype::PType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30/// Return the number of NaN values in an array.
31///
32/// See [`NanCount`] for details.
33pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
34    // Short-circuit using cached array statistics.
35    if let Some(Precision::Exact(nan_count_scalar)) = array.statistics().get(Stat::NaNCount) {
36        return usize::try_from(&nan_count_scalar)
37            .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
38    }
39
40    // Short-circuit for non-float types.
41    if NanCount
42        .return_dtype(&EmptyOptions, array.dtype())
43        .is_none()
44    {
45        return Ok(0);
46    }
47
48    // Short-circuit for empty arrays or all-null arrays.
49    if array.is_empty() || array.valid_count()? == 0 {
50        return Ok(0);
51    }
52
53    // Compute using Accumulator<NanCount>.
54    let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
55    acc.accumulate(array, ctx)?;
56    let result = acc.finish()?;
57
58    let count = result
59        .as_primitive()
60        .typed_value::<u64>()
61        .vortex_expect("nan_count result should not be null");
62    let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
63
64    // Cache the computed NaN count as a statistic.
65    array
66        .statistics()
67        .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
68
69    Ok(count_usize)
70}
71
/// Aggregate function that counts NaN values in an array.
///
/// Only floating-point primitive inputs are supported, and the result is a non-nullable `u64`.
/// Null entries never count as NaN, so an all-invalid array has a NaN count of zero.
#[derive(Debug, Clone)]
pub struct NanCount;
78
79impl AggregateFnVTable for NanCount {
80    type Options = EmptyOptions;
81    type Partial = u64;
82
83    fn id(&self) -> AggregateFnId {
84        AggregateFnId::new_ref("vortex.nan_count")
85    }
86
87    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
88        Ok(Some(vec![]))
89    }
90
91    fn deserialize(
92        &self,
93        _metadata: &[u8],
94        _session: &vortex_session::VortexSession,
95    ) -> VortexResult<Self::Options> {
96        Ok(EmptyOptions)
97    }
98
99    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
100        if let DType::Primitive(ptype, ..) = input_dtype
101            && ptype.is_float()
102        {
103            Some(DType::Primitive(PType::U64, NonNullable))
104        } else {
105            None
106        }
107    }
108
109    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
110        self.return_dtype(options, input_dtype)
111    }
112
113    fn empty_partial(
114        &self,
115        _options: &Self::Options,
116        _input_dtype: &DType,
117    ) -> VortexResult<Self::Partial> {
118        Ok(0u64)
119    }
120
121    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
122        let val = other
123            .as_primitive()
124            .typed_value::<u64>()
125            .vortex_expect("nan_count partial should not be null");
126        *partial += val;
127        Ok(())
128    }
129
130    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
131        Ok(Scalar::primitive(*partial, NonNullable))
132    }
133
134    fn reset(&self, partial: &mut Self::Partial) {
135        *partial = 0;
136    }
137
138    #[inline]
139    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
140        false
141    }
142
143    fn accumulate(
144        &self,
145        partial: &mut Self::Partial,
146        batch: &Columnar,
147        _ctx: &mut ExecutionCtx,
148    ) -> VortexResult<()> {
149        match batch {
150            Columnar::Constant(c) => {
151                if c.scalar().is_null() {
152                    // Null values are not NaN.
153                    return Ok(());
154                }
155                if c.scalar().as_primitive().is_nan() {
156                    *partial += c.len() as u64;
157                }
158                Ok(())
159            }
160            Columnar::Canonical(c) => match c {
161                Canonical::Primitive(p) => accumulate_primitive(partial, p),
162                _ => vortex_bail!(
163                    "Unsupported canonical type for nan_count: {}",
164                    batch.dtype()
165                ),
166            },
167        }
168    }
169
170    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
171        Ok(partials)
172    }
173
174    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
175        self.to_scalar(partial)
176    }
177}
178
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::expr::stats::StatsProvider;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;

        let batch1 =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut state = NanCount.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        NanCount.combine_partials(&mut state, scalar2)?;

        let result = NanCount.to_scalar(&state)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));

        // `reset` must return the partial to its empty state.
        NanCount.reset(&mut state);
        assert_eq!(state, 0);
        Ok(())
    }

    #[test]
    fn nan_count_return_dtype() {
        let float = DType::Primitive(PType::F32, Nullability::Nullable);
        assert_eq!(
            NanCount.return_dtype(&EmptyOptions, &float),
            Some(DType::Primitive(PType::U64, Nullability::NonNullable))
        );

        let int = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert!(NanCount.return_dtype(&EmptyOptions, &int).is_none());
    }

    #[test]
    fn nan_count_non_float_is_zero() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_caches_statistic() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        assert!(matches!(
            array.statistics().get(Stat::NaNCount),
            Some(Precision::Exact(_))
        ));
        // The second call is served from the cached statistic and must agree.
        assert_eq!(nan_count(&array, &mut ctx)?, 2);
        Ok(())
    }

    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let array = ConstantArray::new(f64::NAN, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let array = ConstantArray::new(1.0f64, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(NanCount, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let chunk2 = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let p = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&p.into_array(), &mut ctx)?, 0);
        Ok(())
    }
}