vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::DynArray;
12use crate::ExecutionCtx;
13use crate::aggregate_fn::AggregateFn;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::AggregateFnVTable;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17use crate::dtype::DType;
18use crate::executor::MAX_ITERATIONS;
19use crate::scalar::Scalar;
20
21pub type AccumulatorRef = Box<dyn DynAccumulator>;
23
24pub struct Accumulator<V: AggregateFnVTable> {
26 vtable: V,
28 aggregate_fn: AggregateFnRef,
30 dtype: DType,
32 return_dtype: DType,
34 partial_dtype: DType,
36 partial: V::Partial,
38}
39
40impl<V: AggregateFnVTable> Accumulator<V> {
41 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
42 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
43 vortex_err!(
44 "Aggregate function {} cannot be applied to dtype {}",
45 vtable.id(),
46 dtype
47 )
48 })?;
49 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
50 vortex_err!(
51 "Aggregate function {} cannot be applied to dtype {}",
52 vtable.id(),
53 dtype
54 )
55 })?;
56 let partial = vtable.empty_partial(&options, &dtype)?;
57 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
58
59 Ok(Self {
60 vtable,
61 aggregate_fn,
62 dtype,
63 return_dtype,
64 partial_dtype,
65 partial,
66 })
67 }
68}
69
70pub trait DynAccumulator: 'static + Send {
73 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
75
76 fn is_saturated(&self) -> bool;
78
79 fn flush(&mut self) -> VortexResult<Scalar>;
83
84 fn finish(&mut self) -> VortexResult<Scalar>;
88}
89
90impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
91 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
92 if self.is_saturated() {
93 return Ok(());
94 }
95
96 vortex_ensure!(
97 batch.dtype() == &self.dtype,
98 "Input DType mismatch: expected {}, got {}",
99 self.dtype,
100 batch.dtype()
101 );
102
103 let session = ctx.session().clone();
104 let kernels = &session.aggregate_fns().kernels;
105
106 let mut batch = batch.clone();
107 for _ in 0..*MAX_ITERATIONS {
108 if batch.is::<AnyCanonical>() {
109 break;
110 }
111
112 let kernels_r = kernels.read();
113 let batch_id = batch.encoding_id();
114 if let Some(result) = kernels_r
115 .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
116 .or_else(|| kernels_r.get(&(batch_id, None)))
117 .and_then(|kernel| {
118 kernel
119 .aggregate(&self.aggregate_fn, &batch, ctx)
120 .transpose()
121 })
122 .transpose()?
123 {
124 vortex_ensure!(
125 result.dtype() == &self.partial_dtype,
126 "Aggregate kernel returned {}, expected {}",
127 result.dtype(),
128 self.partial_dtype,
129 );
130 self.vtable.combine_partials(&mut self.partial, result)?;
131 return Ok(());
132 }
133
134 batch = batch.execute(ctx)?;
136 }
137
138 let columnar = batch.execute::<Columnar>(ctx)?;
140
141 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
142 }
143
144 fn is_saturated(&self) -> bool {
145 self.vtable.is_saturated(&self.partial)
146 }
147
148 fn flush(&mut self) -> VortexResult<Scalar> {
149 let partial = self.vtable.to_scalar(&self.partial)?;
150 self.vtable.reset(&mut self.partial);
151
152 #[cfg(debug_assertions)]
153 {
154 vortex_ensure!(
155 partial.dtype() == &self.partial_dtype,
156 "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
157 self.partial_dtype,
158 partial.dtype(),
159 );
160 }
161
162 Ok(partial)
163 }
164
165 fn finish(&mut self) -> VortexResult<Scalar> {
166 let result = self.vtable.finalize_scalar(&self.partial)?;
167 self.vtable.reset(&mut self.partial);
168
169 vortex_ensure!(
170 result.dtype() == &self.return_dtype,
171 "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
172 self.return_dtype,
173 result.dtype(),
174 );
175
176 Ok(result)
177 }
178}