vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::ExecutionCtx;
12use crate::aggregate_fn::AggregateFn;
13use crate::aggregate_fn::AggregateFnRef;
14use crate::aggregate_fn::AggregateFnVTable;
15use crate::aggregate_fn::session::AggregateFnSessionExt;
16use crate::dtype::DType;
17use crate::executor::MAX_ITERATIONS;
18use crate::scalar::Scalar;
19
20pub type AccumulatorRef = Box<dyn DynAccumulator>;
22
23pub struct Accumulator<V: AggregateFnVTable> {
25 vtable: V,
27 aggregate_fn: AggregateFnRef,
29 dtype: DType,
31 return_dtype: DType,
33 partial_dtype: DType,
35 partial: V::Partial,
37}
38
39impl<V: AggregateFnVTable> Accumulator<V> {
40 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
41 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
42 vortex_err!(
43 "Aggregate function {} cannot be applied to dtype {}",
44 vtable.id(),
45 dtype
46 )
47 })?;
48 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
49 vortex_err!(
50 "Aggregate function {} cannot be applied to dtype {}",
51 vtable.id(),
52 dtype
53 )
54 })?;
55 let partial = vtable.empty_partial(&options, &dtype)?;
56 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
57
58 Ok(Self {
59 vtable,
60 aggregate_fn,
61 dtype,
62 return_dtype,
63 partial_dtype,
64 partial,
65 })
66 }
67}
68
69pub trait DynAccumulator: 'static + Send {
72 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
74
75 fn is_saturated(&self) -> bool;
77
78 fn flush(&mut self) -> VortexResult<Scalar>;
82
83 fn finish(&mut self) -> VortexResult<Scalar>;
87}
88
89impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
90 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
91 if self.is_saturated() {
92 return Ok(());
93 }
94
95 vortex_ensure!(
96 batch.dtype() == &self.dtype,
97 "Input DType mismatch: expected {}, got {}",
98 self.dtype,
99 batch.dtype()
100 );
101
102 if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
104 return Ok(());
105 }
106
107 let session = ctx.session().clone();
108 let kernels = &session.aggregate_fns().kernels;
109
110 let mut batch = batch.clone();
111 for _ in 0..*MAX_ITERATIONS {
112 if batch.is::<AnyCanonical>() {
113 break;
114 }
115
116 let kernels_r = kernels.read();
117 let batch_id = batch.encoding_id();
118 if let Some(result) = kernels_r
119 .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
120 .or_else(|| kernels_r.get(&(batch_id, None)))
121 .and_then(|kernel| {
122 kernel
123 .aggregate(&self.aggregate_fn, &batch, ctx)
124 .transpose()
125 })
126 .transpose()?
127 {
128 vortex_ensure!(
129 result.dtype() == &self.partial_dtype,
130 "Aggregate kernel returned {}, expected {}",
131 result.dtype(),
132 self.partial_dtype,
133 );
134 self.vtable.combine_partials(&mut self.partial, result)?;
135 return Ok(());
136 }
137
138 batch = batch.execute(ctx)?;
140 }
141
142 let columnar = batch.execute::<Columnar>(ctx)?;
144
145 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
146 }
147
148 fn is_saturated(&self) -> bool {
149 self.vtable.is_saturated(&self.partial)
150 }
151
152 fn flush(&mut self) -> VortexResult<Scalar> {
153 let partial = self.vtable.to_scalar(&self.partial)?;
154 self.vtable.reset(&mut self.partial);
155
156 #[cfg(debug_assertions)]
157 {
158 vortex_ensure!(
159 partial.dtype() == &self.partial_dtype,
160 "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
161 self.partial_dtype,
162 partial.dtype(),
163 );
164 }
165
166 Ok(partial)
167 }
168
169 fn finish(&mut self) -> VortexResult<Scalar> {
170 let result = self.vtable.finalize_scalar(&self.partial)?;
171 self.vtable.reset(&mut self.partial);
172
173 vortex_ensure!(
174 result.dtype() == &self.return_dtype,
175 "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
176 self.return_dtype,
177 result.dtype(),
178 );
179
180 Ok(result)
181 }
182}