Skip to main content

vortex_array/aggregate_fn/fns/count/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::Columnar;
9use crate::ExecutionCtx;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnVTable;
12use crate::aggregate_fn::EmptyOptions;
13use crate::dtype::DType;
14use crate::dtype::Nullability;
15use crate::dtype::PType;
16use crate::scalar::Scalar;
17
/// Aggregate function that counts the non-null elements of an array.
///
/// Applies to every input dtype: the result is always a non-nullable `u64`,
/// regardless of the dtype being aggregated.
/// The identity value is zero, so an accumulator that has seen no input
/// reports a count of `0`.
#[derive(Clone, Debug)]
pub struct Count;
24
25impl AggregateFnVTable for Count {
26    type Options = EmptyOptions;
27    type Partial = u64;
28
29    fn id(&self) -> AggregateFnId {
30        AggregateFnId::new_ref("vortex.count")
31    }
32
33    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
34        unimplemented!("Count is not yet serializable");
35    }
36
37    fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
38        Some(DType::Primitive(PType::U64, Nullability::NonNullable))
39    }
40
41    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
42        self.return_dtype(options, input_dtype)
43    }
44
45    fn empty_partial(
46        &self,
47        _options: &Self::Options,
48        _input_dtype: &DType,
49    ) -> VortexResult<Self::Partial> {
50        Ok(0u64)
51    }
52
53    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
54        let val = other
55            .as_primitive()
56            .typed_value::<u64>()
57            .vortex_expect("count partial should not be null");
58        *partial += val;
59        Ok(())
60    }
61
62    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
63        Ok(Scalar::primitive(*partial, Nullability::NonNullable))
64    }
65
66    fn reset(&self, partial: &mut Self::Partial) {
67        *partial = 0;
68    }
69
70    #[inline]
71    fn is_saturated(&self, _partial: &Self::Partial) -> bool {
72        false
73    }
74
75    fn try_accumulate(
76        &self,
77        state: &mut Self::Partial,
78        batch: &ArrayRef,
79        _ctx: &mut ExecutionCtx,
80    ) -> VortexResult<bool> {
81        *state += batch.valid_count()? as u64;
82        Ok(true)
83    }
84
85    fn accumulate(
86        &self,
87        _partial: &mut Self::Partial,
88        _batch: &Columnar,
89        _ctx: &mut ExecutionCtx,
90    ) -> VortexResult<()> {
91        unreachable!("Count::try_accumulate handles all arrays")
92    }
93
94    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
95        Ok(partials)
96    }
97
98    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
99        self.to_scalar(partial)
100    }
101}
102
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::count::Count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Reads the `u64` payload of a count scalar; `None` if the scalar is null.
    fn count_value(scalar: &Scalar) -> Option<u64> {
        scalar.as_primitive().typed_value::<u64>()
    }

    /// Runs `Count` over a single array with a fresh accumulator and returns the
    /// resulting count as a `usize`.
    ///
    /// # Panics
    ///
    /// Panics if the finished scalar is null, which a count should never be.
    pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
        let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?;
        acc.accumulate(array, ctx)?;
        let result = acc.finish()?;

        let total = count_value(&result).vortex_expect("count result should not be null");
        Ok(usize::try_from(total)?)
    }

    /// A non-nullable array counts every element.
    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 5);
        Ok(())
    }

    /// Nulls are excluded from the count.
    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 3);
        Ok(())
    }

    /// An all-null array counts as zero.
    #[test]
    fn count_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that never saw input yields the identity, zero.
    #[test]
    fn count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(count_value(&result), Some(0));
        Ok(())
    }

    /// Counts from successive batches add up.
    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(count_value(&result), Some(3));
        Ok(())
    }

    /// `finish` starts a new group: the second result must not include the first.
    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(count_value(&result1), Some(1));

        let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(count_value(&result2), Some(2));
        Ok(())
    }

    /// Partials combine by addition, and `reset` returns the state to zero.
    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar2)?;

        let merged = Count.to_scalar(&state)?;
        assert_eq!(count_value(&merged), Some(8));

        // Verify the effect of `reset` instead of merely calling it.
        Count.reset(&mut state);
        assert_eq!(state, 0);
        let after_reset = Count.to_scalar(&state)?;
        assert_eq!(count_value(&after_reset), Some(0));
        Ok(())
    }

    /// A non-null constant array counts its full length.
    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(42i32, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    /// A null constant array counts as zero regardless of its length.
    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            10,
        );
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    /// Non-null elements are counted across all chunks of a chunked array.
    #[test]
    fn count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }
}