vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_session::VortexSession;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::DynArray;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::AggregateFn;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::AggregateFnVTable;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17use crate::dtype::DType;
18use crate::executor::MAX_ITERATIONS;
19use crate::scalar::Scalar;
20
21pub type AccumulatorRef = Box<dyn DynAccumulator>;
23
24pub struct Accumulator<V: AggregateFnVTable> {
26 vtable: V,
28 aggregate_fn: AggregateFnRef,
30 dtype: DType,
32 return_dtype: DType,
34 partial_dtype: DType,
36 partial: V::Partial,
38 session: VortexSession,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43 pub fn try_new(
44 vtable: V,
45 options: V::Options,
46 dtype: DType,
47 session: VortexSession,
48 ) -> VortexResult<Self> {
49 let return_dtype = vtable.return_dtype(&options, &dtype)?;
50 let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
51 let partial = vtable.empty_partial(&options, &dtype)?;
52 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
53
54 Ok(Self {
55 vtable,
56 aggregate_fn,
57 dtype,
58 return_dtype,
59 partial_dtype,
60 partial,
61 session,
62 })
63 }
64}
65
66pub trait DynAccumulator: 'static + Send {
69 fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;
71
72 fn is_saturated(&self) -> bool;
74
75 fn flush(&mut self) -> VortexResult<Scalar>;
79
80 fn finish(&mut self) -> VortexResult<Scalar>;
84}
85
86impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
87 fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
88 if self.is_saturated() {
89 return Ok(());
90 }
91
92 vortex_ensure!(
93 batch.dtype() == &self.dtype,
94 "Input DType mismatch: expected {}, got {}",
95 self.dtype,
96 batch.dtype()
97 );
98
99 let kernels = &self.session.aggregate_fns().kernels;
100
101 let mut ctx = self.session.create_execution_ctx();
102 let mut batch = batch.clone();
103 for _ in 0..*MAX_ITERATIONS {
104 if batch.is::<AnyCanonical>() {
105 break;
106 }
107
108 let kernels_r = kernels.read();
109 let batch_id = batch.encoding_id();
110 if let Some(result) = kernels_r
111 .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
112 .or_else(|| kernels_r.get(&(batch_id, None)))
113 .and_then(|kernel| {
114 kernel
115 .aggregate(&self.aggregate_fn, &batch, &mut ctx)
116 .transpose()
117 })
118 .transpose()?
119 {
120 vortex_ensure!(
121 result.dtype() == &self.partial_dtype,
122 "Aggregate kernel returned {}, expected {}",
123 result.dtype(),
124 self.partial_dtype,
125 );
126 self.vtable.combine_partials(&mut self.partial, result)?;
127 return Ok(());
128 }
129
130 batch = batch.execute(&mut ctx)?;
132 }
133
134 let canonical = batch.execute::<Canonical>(&mut ctx)?;
136
137 self.vtable
138 .accumulate(&mut self.partial, &canonical, &mut ctx)
139 }
140
141 fn is_saturated(&self) -> bool {
142 self.vtable.is_saturated(&self.partial)
143 }
144
145 fn flush(&mut self) -> VortexResult<Scalar> {
146 let partial = self.vtable.flush(&mut self.partial)?;
147
148 #[cfg(debug_assertions)]
149 {
150 vortex_ensure!(
151 partial.dtype() == &self.partial_dtype,
152 "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
153 self.partial_dtype,
154 partial.dtype(),
155 );
156 }
157
158 Ok(partial)
159 }
160
161 fn finish(&mut self) -> VortexResult<Scalar> {
162 let partial = self.flush()?;
163 let result = self.vtable.finalize_scalar(partial)?;
164
165 vortex_ensure!(
166 result.dtype() == &self.return_dtype,
167 "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
168 self.return_dtype,
169 result.dtype(),
170 );
171
172 Ok(result)
173 }
174}