Skip to main content

vortex_array/aggregate_fn/fns/count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod grouped;
5pub(crate) use grouped::CountGroupedKernel;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_session::registry::CachedId;
9
10use crate::ArrayRef;
11use crate::Columnar;
12use crate::ExecutionCtx;
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnVTable;
15use crate::aggregate_fn::NumericalAggregateOpts;
16use crate::aggregate_fn::fns::nan_count::nan_count;
17use crate::dtype::DType;
18use crate::dtype::Nullability;
19use crate::dtype::PType;
20use crate::scalar::Scalar;
21
/// Aggregate that counts the non-null elements of an array.
///
/// Accepts input of any type and always produces a `u64`; with no input at all the count is
/// zero, which is also the identity value for merging partial counts.
///
/// Float inputs additionally consult [`NumericalAggregateOpts`]: when `skip_nans` is set (the
/// default) NaN values are treated like missing values and left out of the count; otherwise a
/// NaN is counted like any other non-null value.
#[derive(Debug, Clone)]
pub struct Count;
32
/// Partial accumulator state for the count aggregate.
///
/// Holds the number of values counted so far, plus a flag chosen when the partial is created
/// that says whether NaNs are to be left out of the count.
#[derive(Debug)]
pub struct CountPartial {
    /// Number of qualifying (non-null, and non-NaN when `exclude_nans` is set) values seen so far.
    count: u64,
    /// Whether NaN values must be excluded from the count (float input with `skip_nans`).
    exclude_nans: bool,
}
39
40impl AggregateFnVTable for Count {
41    type Options = NumericalAggregateOpts;
42    type Partial = CountPartial;
43
44    fn id(&self) -> AggregateFnId {
45        static ID: CachedId = CachedId::new("vortex.count");
46        *ID
47    }
48
49    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50        unimplemented!("Count is not yet serializable");
51    }
52
53    fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
54        Some(DType::Primitive(PType::U64, Nullability::NonNullable))
55    }
56
57    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
58        self.return_dtype(options, input_dtype)
59    }
60
61    fn empty_partial(
62        &self,
63        options: &Self::Options,
64        input_dtype: &DType,
65    ) -> VortexResult<Self::Partial> {
66        Ok(CountPartial {
67            count: 0,
68            exclude_nans: options.skip_nans && input_dtype.is_float(),
69        })
70    }
71
72    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
73        let val = other
74            .as_primitive()
75            .typed_value::<u64>()
76            .vortex_expect("count partial should not be null");
77        partial.count += val;
78        Ok(())
79    }
80
81    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
82        Ok(Scalar::primitive(partial.count, Nullability::NonNullable))
83    }
84
85    fn reset(&self, partial: &mut Self::Partial) {
86        partial.count = 0;
87    }
88
89    #[inline]
90    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
91        false
92    }
93
94    fn try_accumulate(
95        &self,
96        state: &mut Self::Partial,
97        batch: &ArrayRef,
98        ctx: &mut ExecutionCtx,
99    ) -> VortexResult<bool> {
100        let mut count = batch.valid_count(ctx)? as u64;
101        if state.exclude_nans {
102            // `nan_count` shortcircuits on an exact `Stat::NaNCount` before scanning the batch.
103            count = count.saturating_sub(nan_count(batch, ctx)? as u64);
104        }
105        state.count += count;
106        Ok(true)
107    }
108
109    fn accumulate(
110        &self,
111        _partial: &mut Self::Partial,
112        _batch: &Columnar,
113        _ctx: &mut ExecutionCtx,
114    ) -> VortexResult<()> {
115        unreachable!("Count::try_accumulate handles all arrays")
116    }
117
118    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
119        Ok(partials)
120    }
121
122    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
123        self.to_scalar(partial)
124    }
125}
126
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::count::Count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    static SESSION: LazyLock<VortexSession> = LazyLock::new(vortex_array::array_session);

    /// Counts `array` with the default options (NaNs skipped for float inputs).
    pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
        let mut acc = Accumulator::try_new(
            Count,
            NumericalAggregateOpts::default(),
            array.dtype().clone(),
        )?;
        acc.accumulate(array, ctx)?;
        let result = acc.finish()?;

        Ok(usize::try_from(
            result
                .as_primitive()
                .typed_value::<u64>()
                .vortex_expect("count result should not be null"),
        )?)
    }

    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 5);
        Ok(())
    }

    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn count_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Count, NumericalAggregateOpts::default(), dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, NumericalAggregateOpts::default(), dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, NumericalAggregateOpts::default(), dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&NumericalAggregateOpts::default(), &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar2)?;

        let result = Count.to_scalar(&state)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));

        // `reset` must bring the partial back to the identity value.
        Count.reset(&mut state);
        let after_reset = Count.to_scalar(&state)?;
        assert_eq!(after_reset.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn count_empty_partial_excludes_nans_only_for_skipping_floats() -> VortexResult<()> {
        let int_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let float_dtype = DType::Primitive(PType::F64, Nullability::Nullable);

        let skipping = NumericalAggregateOpts::default();
        assert!(!Count.empty_partial(&skipping, &int_dtype)?.exclude_nans);
        assert!(Count.empty_partial(&skipping, &float_dtype)?.exclude_nans);

        let including = NumericalAggregateOpts::include_nans();
        assert!(!Count.empty_partial(&including, &float_dtype)?.exclude_nans);
        Ok(())
    }

    /// Counts `array` with explicit `options`.
    fn count_with_options(
        array: &ArrayRef,
        ctx: &mut ExecutionCtx,
        options: NumericalAggregateOpts,
    ) -> VortexResult<u64> {
        let mut acc = Accumulator::try_new(Count, options, array.dtype().clone())?;
        acc.accumulate(array, ctx)?;
        Ok(acc
            .finish()?
            .as_primitive()
            .typed_value::<u64>()
            .vortex_expect("count result should not be null"))
    }

    #[test]
    fn count_float_excludes_nans_by_default() -> VortexResult<()> {
        let array =
            PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(3.0)])
                .into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 2);
        Ok(())
    }

    #[test]
    fn count_float_includes_nans_when_not_skipping() -> VortexResult<()> {
        let array =
            PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(3.0)])
                .into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(
            count_with_options(&array, &mut ctx, NumericalAggregateOpts::include_nans())?,
            3
        );
        Ok(())
    }

    #[test]
    fn count_float_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        assert_eq!(
            count_with_options(&array, &mut ctx, NumericalAggregateOpts::include_nans())?,
            0
        );
        Ok(())
    }

    #[test]
    fn count_float_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> {
        // The array has no NaNs; a planted exact NaNCount stat proves the count is derived from
        // the stat rather than a scan.
        let array =
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0], Validity::NonNullable).into_array();
        array
            .statistics()
            .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(3u64)));
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 1);
        Ok(())
    }

    #[test]
    fn count_constant_nan() -> VortexResult<()> {
        let array = ConstantArray::new(f64::NAN, 5).into_array();
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        assert_eq!(
            count_with_options(&array, &mut ctx, NumericalAggregateOpts::include_nans())?,
            5
        );
        Ok(())
    }

    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(42i32, 10);
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            10,
        );
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = SESSION.create_execution_ctx();
        assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }
}