Skip to main content

vortex_array/aggregate_fn/fns/nan_count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod primitive;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_err;
10
11use self::primitive::accumulate_primitive;
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::Columnar;
15use crate::ExecutionCtx;
16use crate::aggregate_fn::Accumulator;
17use crate::aggregate_fn::AggregateFnId;
18use crate::aggregate_fn::AggregateFnVTable;
19use crate::aggregate_fn::DynAccumulator;
20use crate::aggregate_fn::EmptyOptions;
21use crate::dtype::DType;
22use crate::dtype::Nullability::NonNullable;
23use crate::dtype::PType;
24use crate::expr::stats::Precision;
25use crate::expr::stats::Stat;
26use crate::expr::stats::StatsProvider;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29
30/// Return the number of NaN values in an array.
31///
32/// Null values are not NaN and are not counted.
33///
34/// See [`NanCount`] for details.
35pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
36    // Short-circuit using cached array statistics.
37    if let Precision::Exact(nan_count_scalar) = array.statistics().get(Stat::NaNCount) {
38        return usize::try_from(&nan_count_scalar)
39            .map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
40    }
41
42    // Short-circuit for non-float types.
43    if NanCount
44        .return_dtype(&EmptyOptions, array.dtype())
45        .is_none()
46    {
47        return Ok(0);
48    }
49
50    // Short-circuit for empty arrays or all-null arrays.
51    if array.is_empty() || array.valid_count(ctx)? == 0 {
52        return Ok(0);
53    }
54
55    // Compute using Accumulator<NanCount>.
56    let mut acc = Accumulator::try_new(NanCount, EmptyOptions, array.dtype().clone())?;
57    acc.accumulate(array, ctx)?;
58    let result = acc.finish()?;
59
60    let count = result
61        .as_primitive()
62        .typed_value::<u64>()
63        .vortex_expect("nan_count result should not be null");
64    let count_usize = usize::try_from(count).vortex_expect("Cannot be more nans than usize::MAX");
65
66    // Cache the computed NaN count as a statistic.
67    array
68        .statistics()
69        .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(count)));
70
71    Ok(count_usize)
72}
73
/// Aggregate function that counts NaN values.
///
/// It is defined only for floating-point primitive inputs and produces a `u64` count.
/// A null is not a NaN, so an all-invalid array has a NaN count of zero.
#[derive(Debug, Clone)]
pub struct NanCount;
80
81impl AggregateFnVTable for NanCount {
82    type Options = EmptyOptions;
83    type Partial = u64;
84
85    fn id(&self) -> AggregateFnId {
86        AggregateFnId::new("vortex.nan_count")
87    }
88
89    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
90        unimplemented!("NanCount is not yet serializable");
91    }
92
93    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
94        if let DType::Primitive(ptype, ..) = input_dtype
95            && ptype.is_float()
96        {
97            Some(DType::Primitive(PType::U64, NonNullable))
98        } else {
99            None
100        }
101    }
102
103    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
104        self.return_dtype(options, input_dtype)
105    }
106
107    fn empty_partial(
108        &self,
109        _options: &Self::Options,
110        _input_dtype: &DType,
111    ) -> VortexResult<Self::Partial> {
112        Ok(0u64)
113    }
114
115    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
116        let val = other
117            .as_primitive()
118            .typed_value::<u64>()
119            .vortex_expect("nan_count partial should not be null");
120        *partial += val;
121        Ok(())
122    }
123
124    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
125        Ok(Scalar::primitive(*partial, NonNullable))
126    }
127
128    fn reset(&self, partial: &mut Self::Partial) {
129        *partial = 0;
130    }
131
132    #[inline]
133    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
134        false
135    }
136
137    fn accumulate(
138        &self,
139        partial: &mut Self::Partial,
140        batch: &Columnar,
141        ctx: &mut ExecutionCtx,
142    ) -> VortexResult<()> {
143        match batch {
144            Columnar::Constant(c) => {
145                if c.scalar().is_null() {
146                    // Null values are not NaN.
147                    return Ok(());
148                }
149                if c.scalar().as_primitive().is_nan() {
150                    *partial += c.len() as u64;
151                }
152                Ok(())
153            }
154            Columnar::Canonical(c) => match c {
155                Canonical::Primitive(p) => accumulate_primitive(partial, p, ctx),
156                _ => vortex_bail!(
157                    "Unsupported canonical type for nan_count: {}",
158                    batch.dtype()
159                ),
160            },
161        }
162    }
163
164    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
165        Ok(partials)
166    }
167
168    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
169        self.to_scalar(partial)
170    }
171}
172
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::nan_count::NanCount;
    use crate::aggregate_fn::fns::nan_count::nan_count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// The non-nullable `f64` dtype shared by the accumulator tests.
    fn f64_dtype() -> DType {
        DType::Primitive(PType::F64, Nullability::NonNullable)
    }

    /// Counts from consecutive batches add up before `finish` is called.
    #[test]
    fn nan_count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64, f64::NAN], Validity::NonNullable)
                .into_array();
        accumulator.accumulate(&first, &mut ctx)?;

        let second =
            PrimitiveArray::new(buffer![2.0f64, f64::NAN], Validity::NonNullable).into_array();
        accumulator.accumulate(&second, &mut ctx)?;

        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts counting from zero again.
    #[test]
    fn nan_count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;

        let first =
            PrimitiveArray::new(buffer![f64::NAN, 1.0f64], Validity::NonNullable).into_array();
        accumulator.accumulate(&first, &mut ctx)?;
        let first_total = accumulator.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::new(buffer![f64::NAN, f64::NAN, 2.0], Validity::NonNullable)
            .into_array();
        accumulator.accumulate(&second, &mut ctx)?;
        let second_total = accumulator.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial scalars sums their counts.
    #[test]
    fn nan_count_state_merge() -> VortexResult<()> {
        let mut state = NanCount.empty_partial(&EmptyOptions, &f64_dtype())?;

        for part in [5u64, 3u64] {
            let incoming = Scalar::primitive(part, Nullability::NonNullable);
            NanCount.combine_partials(&mut state, incoming)?;
        }

        let merged = NanCount.to_scalar(&state)?;
        NanCount.reset(&mut state);
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    /// A constant NaN array counts every element.
    #[test]
    fn nan_count_constant_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let constant = ConstantArray::new(f64::NAN, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 10);
        Ok(())
    }

    /// A constant non-NaN array counts nothing.
    #[test]
    fn nan_count_constant_non_nan() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let constant = ConstantArray::new(1.0f64, 10).into_array();
        assert_eq!(nan_count(&constant, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that saw no batches yields zero.
    #[test]
    fn nan_count_empty() -> VortexResult<()> {
        let mut accumulator = Accumulator::try_new(NanCount, EmptyOptions, f64_dtype())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// NaNs are counted across chunks, and nulls in either chunk are skipped.
    #[test]
    fn nan_count_chunked() -> VortexResult<()> {
        let head = PrimitiveArray::from_option_iter([Some(f64::NAN), None, Some(1.0)]);
        let tail = PrimitiveArray::from_option_iter([Some(f64::NAN), Some(f64::NAN), None]);
        let chunk_dtype = head.dtype().clone();
        let chunks = vec![head.into_array(), tail.into_array()];
        let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?.into_array();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&chunked, &mut ctx)?, 3);
        Ok(())
    }

    /// An all-null float array has no NaNs.
    #[test]
    fn nan_count_all_null() -> VortexResult<()> {
        let all_null = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(nan_count(&all_null, &mut ctx)?, 0);
        Ok(())
    }
}
296}