Skip to main content

vortex_array/aggregate_fn/fns/nan_count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10
11use self::primitive::accumulate_primitive;
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::Columnar;
15use crate::ExecutionCtx;
16use crate::aggregate_fn::Accumulator;
17use crate::aggregate_fn::AggregateFnId;
18use crate::aggregate_fn::AggregateFnVTable;
19use crate::aggregate_fn::DynAccumulator;
20use crate::aggregate_fn::EmptyOptions;
21use crate::dtype::DType;
22use crate::dtype::Nullability::NonNullable;
23use crate::dtype::PType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30/// Return the number of NaN values in an array.
31///
32/// See [`NanCount`] for details.
33pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
34    // Short-circuit using cached array statistics.
35    if let Some(Precision::Exact(nan_count_scalar)) = array.statistics().get(Stat::NaNCount) {
36        return usize::try_from(&nan_count_scalar)
37            .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
38    }
39
40    // Short-circuit for non-float types.
41    if NanCount
42        .return_dtype(&EmptyOptions, array.dtype())
43        .is_none()
44    {
45        return Ok(0);
46    }
47
48    // Short-circuit for empty arrays or all-null arrays.
49    if array.is_empty() || array.valid_count()? == 0 {
50        return Ok(0);
51    }
52
53    // Compute using Accumulator<NanCount>.
54    let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
55    acc.accumulate(array, ctx)?;
56    let result = acc.finish()?;
57
58    let count = result
59        .as_primitive()
60        .typed_value::<u64>()
61        .vortex_expect("nan_count result should not be null");
62    let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
63
64    // Cache the computed NaN count as a statistic.
65    array
66        .statistics()
67        .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
68
69    Ok(count_usize)
70}
71
72/// Count the number of NaN values in an array.
73///
74/// Only applies to floating-point primitive types. Returns a `u64` count.
75/// If the array is all-invalid, the NaN count is zero.
76#[derive(Clone, Debug)]
77pub struct NanCount;
78
79impl AggregateFnVTable for NanCount {
80    type Options = EmptyOptions;
81    type Partial = u64;
82
83    fn id(&self) -> AggregateFnId {
84        AggregateFnId::new_ref("vortex.nan_count")
85    }
86
87    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
88        unimplemented!("NanCount is not yet serializable");
89    }
90
91    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
92        if let DType::Primitive(ptype, ..) = input_dtype
93            && ptype.is_float()
94        {
95            Some(DType::Primitive(PType::U64, NonNullable))
96        } else {
97            None
98        }
99    }
100
101    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
102        self.return_dtype(options, input_dtype)
103    }
104
105    fn empty_partial(
106        &self,
107        _options: &Self::Options,
108        _input_dtype: &DType,
109    ) -> VortexResult<Self::Partial> {
110        Ok(0u64)
111    }
112
113    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
114        let val = other
115            .as_primitive()
116            .typed_value::<u64>()
117            .vortex_expect("nan_count partial should not be null");
118        *partial += val;
119        Ok(())
120    }
121
122    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
123        Ok(Scalar::primitive(*partial, NonNullable))
124    }
125
126    fn reset(&self, partial: &mut Self::Partial) {
127        *partial = 0;
128    }
129
130    #[inline]
131    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
132        false
133    }
134
135    fn accumulate(
136        &self,
137        partial: &mut Self::Partial,
138        batch: &Columnar,
139        _ctx: &mut ExecutionCtx,
140    ) -> VortexResult<()> {
141        match batch {
142            Columnar::Constant(c) => {
143                if c.scalar().is_null() {
144                    // Null values are not NaN.
145                    return Ok(());
146                }
147                if c.scalar().as_primitive().is_nan() {
148                    *partial += c.len() as u64;
149                }
150                Ok(())
151            }
152            Columnar::Canonical(c) => match c {
153                Canonical::Primitive(p) => accumulate_primitive(partial, p),
154                _ => vortex_bail!(
155                    "Unsupported canonical type for nan_count: {}",
156                    batch.dtype()
157                ),
158            },
159        }
160    }
161
162    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
163        Ok(partials)
164    }
165
166    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
167        self.to_scalar(partial)
168    }
169}
170
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The non-nullable `f64` dtype shared by the accumulator-level tests.
    fn non_null_f64() -> DType {
        DType::Primitive(PType::F64, Nullability::NonNullable)
    }

    /// Counts accumulated over several batches add up.
    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, non_null_f64())?;

        // Two NaNs in the first batch.
        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        accumulator.accumulate(&first, &mut ctx)?;

        // One more NaN in the second batch.
        let second =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        accumulator.accumulate(&second, &mut ctx)?;

        let finished = accumulator.finish()?;
        let counted = finished.as_primitive().typed_value::<u64>();
        assert_eq!(counted, Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts counting from zero again.
    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, non_null_f64())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        accumulator.accumulate(&first, &mut ctx)?;
        let first_result = accumulator.finish()?;
        assert_eq!(first_result.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        accumulator.accumulate(&second, &mut ctx)?;
        let second_result = accumulator.finish()?;
        assert_eq!(second_result.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial scalars sums their counts.
    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let dtype = non_null_f64();
        let mut state = NanCount.empty_partial(&EmptyOptions, &dtype)?;

        for partial_count in [5u64, 3u64] {
            let partial_scalar = Scalar::primitive(partial_count, Nullability::NonNullable);
            NanCount.combine_partials(&mut state, partial_scalar)?;
        }

        let merged = NanCount.to_scalar(&state)?;
        NanCount.reset(&mut state);
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    /// A constant NaN array counts every element.
    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let constant = ConstantArray::new(f64::NAN, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 10);
        Ok(())
    }

    /// A constant non-NaN array counts nothing.
    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let constant = ConstantArray::new(1.0f64, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that saw no batches yields zero.
    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, non_null_f64())?;
        let finished = accumulator.finish()?;
        assert_eq!(finished.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// NaNs are counted across chunks while nulls are ignored.
    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let head = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let tail = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let chunk_dtype = head.dtype().clone();
        let chunks = vec![head.into_array(), tail.into_array()];
        let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?.into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&chunked, &mut ctx)?, 3);
        Ok(())
    }

    /// An all-null float array contains no NaNs.
    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let all_null = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&all_null, &mut ctx)?, 0);
        Ok(())
    }
}