Skip to main content

vortex_array/aggregate_fn/fns/count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod grouped;
5pub(crate) use grouped::CountGroupedKernel;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::ExecutionCtx;
12use crate::aggregate_fn::AggregateFnId;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::EmptyOptions;
15use crate::dtype::DType;
16use crate::dtype::Nullability;
17use crate::dtype::PType;
18use crate::scalar::Scalar;
19
/// Aggregate function that tallies how many elements of an array are non-null.
///
/// It is defined for every input type and always yields a `u64`.
/// With nothing accumulated the count is zero.
#[derive(Debug, Clone)]
pub struct Count;
26
27impl AggregateFnVTable for Count {
28    type Options = EmptyOptions;
29    type Partial = u64;
30
31    fn id(&self) -> AggregateFnId {
32        AggregateFnId::new("vortex.count")
33    }
34
35    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
36        unimplemented!("Count is not yet serializable");
37    }
38
39    fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
40        Some(DType::Primitive(PType::U64, Nullability::NonNullable))
41    }
42
43    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
44        self.return_dtype(options, input_dtype)
45    }
46
47    fn empty_partial(
48        &self,
49        _options: &Self::Options,
50        _input_dtype: &DType,
51    ) -> VortexResult<Self::Partial> {
52        Ok(0u64)
53    }
54
55    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
56        let val = other
57            .as_primitive()
58            .typed_value::<u64>()
59            .vortex_expect("count partial should not be null");
60        *partial += val;
61        Ok(())
62    }
63
64    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
65        Ok(Scalar::primitive(*partial, Nullability::NonNullable))
66    }
67
68    fn reset(&self, partial: &mut Self::Partial) {
69        *partial = 0;
70    }
71
72    #[inline]
73    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
74        false
75    }
76
77    fn try_accumulate(
78        &self,
79        state: &mut Self::Partial,
80        batch: &ArrayRef,
81        ctx: &mut ExecutionCtx,
82    ) -> VortexResult<bool> {
83        *state += batch.valid_count(ctx)? as u64;
84        Ok(true)
85    }
86
87    fn accumulate(
88        &self,
89        _partial: &mut Self::Partial,
90        _batch: &Columnar,
91        _ctx: &mut ExecutionCtx,
92    ) -> VortexResult<()> {
93        unreachable!("Count::try_accumulate handles all arrays")
94    }
95
96    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
97        Ok(partials)
98    }
99
100    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
101        self.to_scalar(partial)
102    }
103}
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::count::Count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Runs `Count` over a single array through an `Accumulator` and returns the result as a
    /// `usize`.
    pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
        let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?;
        acc.accumulate(array, ctx)?;
        let result = acc.finish()?;

        Ok(usize::try_from(
            result
                .as_primitive()
                .typed_value::<u64>()
                .vortex_expect("count result should not be null"),
        )?)
    }

    /// A non-nullable array counts every element.
    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 5);
        Ok(())
    }

    /// Null entries are excluded from the count.
    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 3);
        Ok(())
    }

    /// An array made only of nulls counts as zero.
    #[test]
    fn count_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that never saw a batch yields zero.
    #[test]
    fn count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// Counts from successive batches add up.
    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts again from zero.
    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial scalars sums them, and `reset` returns the state to zero.
    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar2)?;

        let result = Count.to_scalar(&state)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));

        // Check the effect of `reset` too; calling it without an assertion verifies nothing.
        Count.reset(&mut state);
        assert_eq!(state, 0);
        Ok(())
    }

    /// A constant array of a non-null value counts its full length.
    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(42i32, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    /// A constant array of the null scalar counts as zero.
    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            10,
        );
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    /// Non-null elements are counted across all chunks of a chunked array.
    #[test]
    fn count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }
}