vortex_array/aggregate_fn/fns/count/
mod.rs1mod grouped;
5pub(crate) use grouped::CountGroupedKernel;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_session::registry::CachedId;
9
10use crate::ArrayRef;
11use crate::Columnar;
12use crate::ExecutionCtx;
13use crate::aggregate_fn::AggregateFnId;
14use crate::aggregate_fn::AggregateFnVTable;
15use crate::aggregate_fn::NumericalAggregateOpts;
16use crate::aggregate_fn::fns::nan_count::nan_count;
17use crate::dtype::DType;
18use crate::dtype::Nullability;
19use crate::dtype::PType;
20use crate::scalar::Scalar;
21
/// Aggregate function that counts the valid (non-null) elements of its input.
///
/// The result is a non-nullable `u64`. When the options ask for NaNs to be skipped and the
/// input is a floating-point type, NaN values are left out of the count as well.
#[derive(Debug, Clone)]
pub struct Count;
32
/// Running state of a count aggregation.
///
/// Derives `Debug` so the public type can be logged and used in assertion messages.
#[derive(Debug)]
pub struct CountPartial {
    /// Number of elements counted so far.
    count: u64,
    /// Whether NaN values are subtracted from each batch's valid count.
    ///
    /// Fixed when the partial is created; resetting the partial only clears `count`.
    exclude_nans: bool,
}
39
40impl AggregateFnVTable for Count {
41 type Options = NumericalAggregateOpts;
42 type Partial = CountPartial;
43
44 fn id(&self) -> AggregateFnId {
45 static ID: CachedId = CachedId::new("vortex.count");
46 *ID
47 }
48
49 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50 unimplemented!("Count is not yet serializable");
51 }
52
53 fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
54 Some(DType::Primitive(PType::U64, Nullability::NonNullable))
55 }
56
57 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
58 self.return_dtype(options, input_dtype)
59 }
60
61 fn empty_partial(
62 &self,
63 options: &Self::Options,
64 input_dtype: &DType,
65 ) -> VortexResult<Self::Partial> {
66 Ok(CountPartial {
67 count: 0,
68 exclude_nans: options.skip_nans && input_dtype.is_float(),
69 })
70 }
71
72 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
73 let val = other
74 .as_primitive()
75 .typed_value::<u64>()
76 .vortex_expect("count partial should not be null");
77 partial.count += val;
78 Ok(())
79 }
80
81 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
82 Ok(Scalar::primitive(partial.count, Nullability::NonNullable))
83 }
84
85 fn reset(&self, partial: &mut Self::Partial) {
86 partial.count = 0;
87 }
88
89 #[inline]
90 fn is_saturated(&self, _partial: &Self::Partial) -> bool {
91 false
92 }
93
94 fn try_accumulate(
95 &self,
96 state: &mut Self::Partial,
97 batch: &ArrayRef,
98 ctx: &mut ExecutionCtx,
99 ) -> VortexResult<bool> {
100 let mut count = batch.valid_count(ctx)? as u64;
101 if state.exclude_nans {
102 count = count.saturating_sub(nan_count(batch, ctx)? as u64);
104 }
105 state.count += count;
106 Ok(true)
107 }
108
109 fn accumulate(
110 &self,
111 _partial: &mut Self::Partial,
112 _batch: &Columnar,
113 _ctx: &mut ExecutionCtx,
114 ) -> VortexResult<()> {
115 unreachable!("Count::try_accumulate handles all arrays")
116 }
117
118 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
119 Ok(partials)
120 }
121
122 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
123 self.to_scalar(partial)
124 }
125}
126
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::fns::count::Count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::stats::Precision;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;
    use crate::validity::Validity;

    /// Session shared by every test to create execution contexts.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(vortex_array::array_session);

    /// Counts `array` with the default options and returns the result as a `usize`.
    ///
    /// NOTE(review): `pub` is kept from the original, presumably so other test modules
    /// can reuse this helper — verify before narrowing the visibility.
    pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
        let total = count_with_options(array, ctx, NumericalAggregateOpts::default())?;
        Ok(usize::try_from(total)?)
    }

    /// Runs a fresh `Count` accumulator with `options` over `array` and returns the raw
    /// `u64` result.
    fn count_with_options(
        array: &ArrayRef,
        ctx: &mut ExecutionCtx,
        options: NumericalAggregateOpts,
    ) -> VortexResult<u64> {
        let mut acc = Accumulator::try_new(Count, options, array.dtype().clone())?;
        acc.accumulate(array, ctx)?;
        let finished = acc.finish()?;
        Ok(finished
            .as_primitive()
            .typed_value::<u64>()
            .vortex_expect("count result should not be null"))
    }

    /// Float fixture: two ordinary values, one NaN and one null.
    fn floats_with_nan_and_null() -> ArrayRef {
        PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(3.0)])
            .into_array()
    }

    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable);
        assert_eq!(count(&values.into_array(), &mut ctx)?, 5);
        Ok(())
    }

    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]);
        assert_eq!(count(&values.into_array(), &mut ctx)?, 3);
        Ok(())
    }

    #[test]
    fn count_all_null() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        assert_eq!(count(&values.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    /// An accumulator that never saw a batch finishes with zero.
    #[test]
    fn count_empty() -> VortexResult<()> {
        let mut acc = Accumulator::try_new(
            Count,
            NumericalAggregateOpts::default(),
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )?;
        let finished = acc.finish()?;
        assert_eq!(finished.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// Counts from successive batches add up within one accumulator.
    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Count,
            NumericalAggregateOpts::default(),
            DType::Primitive(PType::I32, Nullability::Nullable),
        )?;

        let batches = [
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]),
            PrimitiveArray::from_option_iter([None, Some(5i32)]),
        ];
        for batch in batches {
            acc.accumulate(&batch.into_array(), &mut ctx)?;
        }

        let finished = acc.finish()?;
        assert_eq!(finished.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts again from zero.
    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Count,
            NumericalAggregateOpts::default(),
            DType::Primitive(PType::I32, Nullability::Nullable),
        )?;

        let first = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&first, &mut ctx)?;
        let first_total = acc.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&second, &mut ctx)?;
        let second_total = acc.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partial scalars sums their counts.
    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&NumericalAggregateOpts::default(), &dtype)?;

        for part in [5u64, 3] {
            let scalar = Scalar::primitive(part, Nullability::NonNullable);
            Count.combine_partials(&mut state, scalar)?;
        }

        let merged = Count.to_scalar(&state)?;
        Count.reset(&mut state);
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));
        Ok(())
    }

    #[test]
    fn count_float_excludes_nans_by_default() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let array = floats_with_nan_and_null();
        assert_eq!(count(&array, &mut ctx)?, 2);
        Ok(())
    }

    #[test]
    fn count_float_includes_nans_when_not_skipping() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let array = floats_with_nan_and_null();
        let counted = count_with_options(&array, &mut ctx, NumericalAggregateOpts::include_nans())?;
        assert_eq!(counted, 3);
        Ok(())
    }

    /// The data holds no NaNs, but an exact `NaNCount` stat of 3 is attached. The expected
    /// result of 1 (4 valid values minus 3) shows that the stat is what gets subtracted.
    #[test]
    fn count_float_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let array =
            PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0], Validity::NonNullable).into_array();
        let claimed_nans = Precision::Exact(ScalarValue::from(3u64));
        array.statistics().set(Stat::NaNCount, claimed_nans);
        assert_eq!(count(&array, &mut ctx)?, 1);
        Ok(())
    }

    #[test]
    fn count_constant_nan() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let array = ConstantArray::new(f64::NAN, 5).into_array();

        let skipping_nans = count(&array, &mut ctx)?;
        assert_eq!(skipping_nans, 0);

        let including_nans =
            count_with_options(&array, &mut ctx, NumericalAggregateOpts::include_nans())?;
        assert_eq!(including_nans, 5);
        Ok(())
    }

    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let array = ConstantArray::new(42i32, 10).into_array();
        assert_eq!(count(&array, &mut ctx)?, 10);
        Ok(())
    }

    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let array = ConstantArray::new(null_i32, 10).into_array();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    #[test]
    fn count_chunked() -> VortexResult<()> {
        let mut ctx = SESSION.create_execution_ctx();
        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let second = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        // Take the dtype before the chunks are consumed by `into_array`.
        let dtype = first.dtype().clone();
        let chunks = vec![first.into_array(), second.into_array()];
        let chunked = ChunkedArray::try_new(chunks, dtype)?.into_array();
        assert_eq!(count(&chunked, &mut ctx)?, 3);
        Ok(())
    }
}