1use std::sync::Arc;
12
13use arrow::array::RecordBatch;
14use arrow_array::ArrayRef;
15use arrow_schema::{Field, Schema, SchemaRef};
16use datafusion::scalar::ScalarValue;
17use uni_plugin::QName;
18use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
19use uni_plugin::errors::FnError;
20use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
21use uni_plugin_wasm_rt::ipc::{decode_batch, encode_batch};
22
23use crate::adapter_common::{acquire, ipc_to_fn_err};
24use crate::loader::AggregatePluginInstance;
25use crate::pool::WasmInstancePool;
26
27pub struct ComponentAggregateFn {
29 pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
30 qname: QName,
31 sig: AggSignature,
32}
33
34impl std::fmt::Debug for ComponentAggregateFn {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ComponentAggregateFn")
37 .field("qname", &self.qname)
38 .field("signature", &self.sig)
39 .finish_non_exhaustive()
40 }
41}
42
43impl ComponentAggregateFn {
44 #[must_use]
46 pub fn new(
47 pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
48 qname: QName,
49 sig: AggSignature,
50 ) -> Self {
51 Self { pool, qname, sig }
52 }
53
54 fn call_new(&self) -> Result<Vec<u8>, FnError> {
55 let mut leased = acquire(&self.pool, "aggregate")?;
56 let qname_str = self.qname.to_string();
57 let state = leased.get_mut().agg_new(&qname_str).map_err(|e| {
58 FnError::new(
59 FnError::CODE_UNEXPECTED_NULL,
60 format!("wasm agg_new `{qname_str}`: {e}"),
61 )
62 })?;
63 drop(leased);
64 Ok(state)
65 }
66}
67
68impl AggregatePluginFn for ComponentAggregateFn {
69 fn signature(&self) -> &AggSignature {
70 &self.sig
71 }
72
73 fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
74 let (state, init_err) = match self.call_new() {
75 Ok(s) => (s, None),
76 Err(e) => (Vec::new(), Some(e)),
77 };
78 Box::new(ComponentAggregateAccumulator {
79 state,
80 init_err,
81 pool: Arc::clone(&self.pool),
82 qname: self.qname.to_string(),
83 args_schema: build_args_schema(&self.sig),
84 returns_field: build_returns_field(&self.sig),
85 })
86 }
87}
88
89struct ComponentAggregateAccumulator {
90 state: Vec<u8>,
91 init_err: Option<FnError>,
92 pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
93 qname: String,
94 args_schema: SchemaRef,
95 returns_field: Field,
96}
97
98impl ComponentAggregateAccumulator {
99 fn surface_init_err(&self) -> Result<(), FnError> {
100 if let Some(e) = &self.init_err {
101 return Err(FnError::new(
102 e.code,
103 format!("aggregate init failed: {}", e.message),
104 ));
105 }
106 Ok(())
107 }
108}
109
110impl PluginAccumulator for ComponentAggregateAccumulator {
111 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
112 self.surface_init_err()?;
113 let batch =
114 RecordBatch::try_new(Arc::clone(&self.args_schema), values.to_vec()).map_err(|e| {
115 FnError::new(
116 FnError::CODE_TYPE_COERCION,
117 format!("update_batch RecordBatch: {e}"),
118 )
119 })?;
120 let ipc = encode_batch(&batch).map_err(ipc_to_fn_err)?;
121 let mut leased = acquire(&self.pool, "aggregate")?;
122 let new_state = leased
123 .get_mut()
124 .agg_update(&self.qname, &self.state, &ipc)
125 .map_err(|e| {
126 FnError::new(
127 FnError::CODE_UNEXPECTED_NULL,
128 format!("wasm agg_update `{}`: {e}", self.qname),
129 )
130 })?;
131 drop(leased);
132 self.state = new_state;
133 Ok(())
134 }
135
136 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
137 self.surface_init_err()?;
138 if states.len() != 1 {
139 return Err(FnError::new(
140 FnError::CODE_TYPE_COERCION,
141 format!(
142 "merge_batch expects 1 state column (opaque Binary); got {}",
143 states.len()
144 ),
145 ));
146 }
147 let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
148 "partial_state",
149 states[0].data_type().clone(),
150 true,
151 )]));
152 let batch = RecordBatch::try_new(schema, states.to_vec()).map_err(|e| {
153 FnError::new(
154 FnError::CODE_TYPE_COERCION,
155 format!("merge_batch RecordBatch: {e}"),
156 )
157 })?;
158 let ipc = encode_batch(&batch).map_err(ipc_to_fn_err)?;
159 let mut leased = acquire(&self.pool, "aggregate")?;
160 let new_state = leased
161 .get_mut()
162 .agg_merge(&self.qname, &self.state, &ipc)
163 .map_err(|e| {
164 FnError::new(
165 FnError::CODE_UNEXPECTED_NULL,
166 format!("wasm agg_merge `{}`: {e}", self.qname),
167 )
168 })?;
169 drop(leased);
170 self.state = new_state;
171 Ok(())
172 }
173
174 fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
175 self.surface_init_err()?;
176 Ok(vec![ScalarValue::Binary(Some(self.state.clone()))])
177 }
178
179 fn evaluate(&self) -> Result<ScalarValue, FnError> {
180 self.surface_init_err()?;
181 let mut leased = acquire(&self.pool, "aggregate")?;
182 let out_bytes = leased
183 .get_mut()
184 .agg_evaluate(&self.qname, &self.state)
185 .map_err(|e| {
186 FnError::new(
187 FnError::CODE_UNEXPECTED_NULL,
188 format!("wasm agg_evaluate `{}`: {e}", self.qname),
189 )
190 })?;
191 drop(leased);
192 let batch = decode_batch(&out_bytes)
193 .map_err(ipc_to_fn_err)?
194 .ok_or_else(|| {
195 FnError::new(
196 FnError::CODE_UNEXPECTED_NULL,
197 format!("plugin agg_evaluate `{}` empty IPC", self.qname),
198 )
199 })?;
200 if batch.num_columns() != 1 || batch.num_rows() != 1 {
201 return Err(FnError::new(
202 FnError::CODE_TYPE_COERCION,
203 format!(
204 "plugin agg_evaluate `{}` must return 1×1; got {}×{}",
205 self.qname,
206 batch.num_rows(),
207 batch.num_columns()
208 ),
209 ));
210 }
211 if batch.column(0).data_type() != self.returns_field.data_type() {
212 return Err(FnError::new(
213 FnError::CODE_TYPE_COERCION,
214 format!(
215 "plugin agg_evaluate `{}` returned {:?}, expected {:?}",
216 self.qname,
217 batch.column(0).data_type(),
218 self.returns_field.data_type()
219 ),
220 ));
221 }
222 ScalarValue::try_from_array(batch.column(0), 0).map_err(|e| {
223 FnError::new(
224 FnError::CODE_TYPE_COERCION,
225 format!("agg_evaluate ScalarValue: {e}"),
226 )
227 })
228 }
229
230 fn size(&self) -> usize {
231 std::mem::size_of::<Self>() + self.state.capacity()
232 }
233}
234
235fn build_args_schema(sig: &AggSignature) -> SchemaRef {
236 let fields: Vec<Field> = sig
237 .args
238 .iter()
239 .enumerate()
240 .map(|(i, t)| Field::new(format!("arg{i}"), argtype_to_arrow(t), true))
241 .collect();
242 Arc::new(Schema::new(fields))
243}
244
245fn build_returns_field(sig: &AggSignature) -> Field {
246 Field::new("returns", argtype_to_arrow(&sig.returns), true)
247}