vortex_array/aggregate_fn/fns/count/
mod.rs1use vortex_error::VortexExpect;
5use vortex_error::VortexResult;
6
7use crate::ArrayRef;
8use crate::Columnar;
9use crate::ExecutionCtx;
10use crate::aggregate_fn::AggregateFnId;
11use crate::aggregate_fn::AggregateFnVTable;
12use crate::aggregate_fn::EmptyOptions;
13use crate::dtype::DType;
14use crate::dtype::Nullability;
15use crate::dtype::PType;
16use crate::scalar::Scalar;
17
/// The `vortex.count` aggregate function: counts the valid (non-null) elements of its input.
///
/// The result is always a non-nullable `u64`, whatever the input dtype. An empty or all-null
/// input counts as `0` rather than producing a null.
#[derive(Clone, Debug)]
pub struct Count;
24
25impl AggregateFnVTable for Count {
26 type Options = EmptyOptions;
27 type Partial = u64;
28
29 fn id(&self) -> AggregateFnId {
30 AggregateFnId::new_ref("vortex.count")
31 }
32
33 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
34 unimplemented!("Count is not yet serializable");
35 }
36
37 fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
38 Some(DType::Primitive(PType::U64, Nullability::NonNullable))
39 }
40
41 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
42 self.return_dtype(options, input_dtype)
43 }
44
45 fn empty_partial(
46 &self,
47 _options: &Self::Options,
48 _input_dtype: &DType,
49 ) -> VortexResult<Self::Partial> {
50 Ok(0u64)
51 }
52
53 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
54 let val = other
55 .as_primitive()
56 .typed_value::<u64>()
57 .vortex_expect("count partial should not be null");
58 *partial += val;
59 Ok(())
60 }
61
62 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
63 Ok(Scalar::primitive(*partial, Nullability::NonNullable))
64 }
65
66 fn reset(&self, partial: &mut Self::Partial) {
67 *partial = 0;
68 }
69
70 #[inline]
71 fn is_saturated(&self, _partial: &Self::Partial) -> bool {
72 false
73 }
74
75 fn try_accumulate(
76 &self,
77 state: &mut Self::Partial,
78 batch: &ArrayRef,
79 _ctx: &mut ExecutionCtx,
80 ) -> VortexResult<bool> {
81 *state += batch.valid_count()? as u64;
82 Ok(true)
83 }
84
85 fn accumulate(
86 &self,
87 _partial: &mut Self::Partial,
88 _batch: &Columnar,
89 _ctx: &mut ExecutionCtx,
90 ) -> VortexResult<()> {
91 unreachable!("Count::try_accumulate handles all arrays")
92 }
93
94 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
95 Ok(partials)
96 }
97
98 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
99 self.to_scalar(partial)
100 }
101}
102
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::count::Count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Runs `Count` over `array` in a single batch and returns the result as a `usize`.
    fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
        let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?;
        acc.accumulate(array, ctx)?;
        let result = acc.finish()?;

        Ok(usize::try_from(
            result
                .as_primitive()
                .typed_value::<u64>()
                .vortex_expect("count result should not be null"),
        )?)
    }

    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 5);
        Ok(())
    }

    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn count_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;

        let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&batch1, &mut ctx)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(1));

        let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&batch2, &mut ctx)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, scalar2)?;

        let result = Count.to_scalar(&state)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));

        // `reset` must bring the partial back to the empty state.
        Count.reset(&mut state);
        assert_eq!(state, 0);
        Ok(())
    }

    #[test]
    fn count_return_dtype_is_non_nullable_u64() {
        let expected = Some(DType::Primitive(PType::U64, Nullability::NonNullable));
        let nullable_input = DType::Primitive(PType::I32, Nullability::Nullable);
        assert_eq!(Count.return_dtype(&EmptyOptions, &nullable_input), expected);
        assert_eq!(Count.partial_dtype(&EmptyOptions, &nullable_input), expected);
    }

    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(42i32, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            10,
        );
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }
}