1use std::sync::Arc;
27
28use arrow::array::RecordBatch;
29use arrow_array::ArrayRef;
30use arrow_schema::{Field, Schema, SchemaRef};
31use datafusion::scalar::ScalarValue;
32use uni_plugin::QName;
33use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
34use uni_plugin::errors::FnError;
35use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
36
37use crate::adapter_common::{acquire, extism_err_to_fn_err, sanitize_qname};
38use crate::ipc::{decode_batch, encode_batch};
39use crate::pool::ExtismInstancePool;
40
41#[must_use]
43pub(crate) fn agg_new_export_name(qname: &QName) -> String {
44 format!("agg_{}_new", sanitize_qname(qname))
45}
46
47#[must_use]
49pub(crate) fn agg_update_export_name(qname: &QName) -> String {
50 format!("agg_{}_update", sanitize_qname(qname))
51}
52
53#[must_use]
55pub(crate) fn agg_merge_export_name(qname: &QName) -> String {
56 format!("agg_{}_merge", sanitize_qname(qname))
57}
58
59#[must_use]
61pub(crate) fn agg_evaluate_export_name(qname: &QName) -> String {
62 format!("agg_{}_evaluate", sanitize_qname(qname))
63}
64
65pub struct ExtismAggregateFn {
67 pool: Arc<ExtismInstancePool<extism::Plugin>>,
68 qname: QName,
69 sig: AggSignature,
70 new_export: String,
71 update_export: String,
72 merge_export: String,
73 evaluate_export: String,
74}
75
76impl std::fmt::Debug for ExtismAggregateFn {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("ExtismAggregateFn")
79 .field("qname", &self.qname)
80 .field("signature", &self.sig)
81 .finish_non_exhaustive()
82 }
83}
84
85impl ExtismAggregateFn {
86 #[must_use]
88 pub fn new(
89 pool: Arc<ExtismInstancePool<extism::Plugin>>,
90 qname: QName,
91 sig: AggSignature,
92 ) -> Self {
93 let new_export = agg_new_export_name(&qname);
94 let update_export = agg_update_export_name(&qname);
95 let merge_export = agg_merge_export_name(&qname);
96 let evaluate_export = agg_evaluate_export_name(&qname);
97 Self {
98 pool,
99 qname,
100 sig,
101 new_export,
102 update_export,
103 merge_export,
104 evaluate_export,
105 }
106 }
107
108 fn call_new(&self) -> Result<Vec<u8>, FnError> {
109 let mut leased = acquire(&self.pool)?;
110 let bytes: Vec<u8> = leased
111 .get_mut()
112 .call::<&[u8], &[u8]>(&self.new_export, &[])
113 .map_err(|e| {
114 FnError::new(
115 FnError::CODE_UNEXPECTED_NULL,
116 format!("extism call `{}` failed: {e}", self.new_export),
117 )
118 })?
119 .to_vec();
120 drop(leased);
121 Ok(bytes)
122 }
123}
124
125impl AggregatePluginFn for ExtismAggregateFn {
126 fn signature(&self) -> &AggSignature {
127 &self.sig
128 }
129
130 fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
131 let (state, init_err) = match self.call_new() {
137 Ok(s) => (s, None),
138 Err(e) => (Vec::new(), Some(e)),
139 };
140 Box::new(ExtismAggregateAccumulator {
141 state,
142 init_err,
143 pool: Arc::clone(&self.pool),
144 update_export: self.update_export.clone(),
145 merge_export: self.merge_export.clone(),
146 evaluate_export: self.evaluate_export.clone(),
147 args_schema: build_args_schema(&self.sig),
148 returns_field: build_returns_field(&self.sig),
149 })
150 }
151}
152
153struct ExtismAggregateAccumulator {
155 state: Vec<u8>,
156 init_err: Option<FnError>,
157 pool: Arc<ExtismInstancePool<extism::Plugin>>,
158 update_export: String,
159 merge_export: String,
160 evaluate_export: String,
161 args_schema: SchemaRef,
162 returns_field: Field,
163}
164
165impl ExtismAggregateAccumulator {
166 fn surface_init_err(&self) -> Result<(), FnError> {
167 if let Some(e) = &self.init_err {
168 return Err(FnError::new(
169 e.code,
170 format!("aggregate init failed: {}", e.message),
171 ));
172 }
173 Ok(())
174 }
175
176 fn call_with_envelope(&self, export: &str, batch: RecordBatch) -> Result<Vec<u8>, FnError> {
177 let ipc = encode_batch(&batch).map_err(extism_err_to_fn_err)?;
178 if u32::try_from(self.state.len()).is_err() {
182 return Err(FnError::new(
183 FnError::CODE_RESOURCE_LIMIT,
184 "aggregate state exceeds u32::MAX bytes",
185 ));
186 }
187 let buf = build_envelope(&self.state, &ipc);
188
189 let mut leased = acquire(&self.pool)?;
190 let out: Vec<u8> = leased
191 .get_mut()
192 .call::<&[u8], &[u8]>(export, &buf)
193 .map_err(|e| {
194 FnError::new(
195 FnError::CODE_UNEXPECTED_NULL,
196 format!("extism call `{export}` failed: {e}"),
197 )
198 })?
199 .to_vec();
200 drop(leased);
201 Ok(out)
202 }
203}
204
205impl PluginAccumulator for ExtismAggregateAccumulator {
206 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
207 self.surface_init_err()?;
208 let batch =
209 RecordBatch::try_new(Arc::clone(&self.args_schema), values.to_vec()).map_err(|e| {
210 FnError::new(
211 FnError::CODE_TYPE_COERCION,
212 format!("update_batch: RecordBatch assembly: {e}"),
213 )
214 })?;
215 let new_state = self.call_with_envelope(&self.update_export, batch)?;
216 self.state = new_state;
217 Ok(())
218 }
219
220 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
221 self.surface_init_err()?;
222 if states.len() != 1 {
223 return Err(FnError::new(
224 FnError::CODE_TYPE_COERCION,
225 format!(
226 "merge_batch expects exactly 1 state column (opaque Binary); got {}",
227 states.len()
228 ),
229 ));
230 }
231 let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
232 "partial_state",
233 states[0].data_type().clone(),
234 true,
235 )]));
236 let batch = RecordBatch::try_new(schema, states.to_vec()).map_err(|e| {
237 FnError::new(
238 FnError::CODE_TYPE_COERCION,
239 format!("merge_batch: RecordBatch assembly: {e}"),
240 )
241 })?;
242 let new_state = self.call_with_envelope(&self.merge_export, batch)?;
243 self.state = new_state;
244 Ok(())
245 }
246
247 fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
248 self.surface_init_err()?;
249 Ok(vec![ScalarValue::Binary(Some(self.state.clone()))])
250 }
251
252 fn evaluate(&self) -> Result<ScalarValue, FnError> {
253 self.surface_init_err()?;
254 let mut leased = acquire(&self.pool)?;
255 let out_bytes: Vec<u8> = leased
256 .get_mut()
257 .call::<&[u8], &[u8]>(&self.evaluate_export, &self.state)
258 .map_err(|e| {
259 FnError::new(
260 FnError::CODE_UNEXPECTED_NULL,
261 format!("extism call `{}` failed: {e}", self.evaluate_export),
262 )
263 })?
264 .to_vec();
265 drop(leased);
266
267 let batch = decode_batch(&out_bytes)
268 .map_err(extism_err_to_fn_err)?
269 .ok_or_else(|| {
270 FnError::new(
271 FnError::CODE_UNEXPECTED_NULL,
272 format!(
273 "plugin `{}` returned an empty IPC stream",
274 self.evaluate_export
275 ),
276 )
277 })?;
278 if batch.num_columns() != 1 || batch.num_rows() != 1 {
279 return Err(FnError::new(
280 FnError::CODE_TYPE_COERCION,
281 format!(
282 "plugin `{}` must return a 1-row × 1-col batch; got {} rows × {} cols",
283 self.evaluate_export,
284 batch.num_rows(),
285 batch.num_columns()
286 ),
287 ));
288 }
289 if batch.column(0).data_type() != self.returns_field.data_type() {
291 return Err(FnError::new(
292 FnError::CODE_TYPE_COERCION,
293 format!(
294 "plugin `{}` returned column type {:?}, expected {:?}",
295 self.evaluate_export,
296 batch.column(0).data_type(),
297 self.returns_field.data_type()
298 ),
299 ));
300 }
301 ScalarValue::try_from_array(batch.column(0), 0).map_err(|e| {
302 FnError::new(
303 FnError::CODE_TYPE_COERCION,
304 format!("evaluate: ScalarValue::try_from_array: {e}"),
305 )
306 })
307 }
308
309 fn size(&self) -> usize {
310 std::mem::size_of::<Self>() + self.state.capacity()
311 }
312}
313
314fn build_args_schema(sig: &AggSignature) -> SchemaRef {
315 let fields: Vec<Field> = sig
316 .args
317 .iter()
318 .enumerate()
319 .map(|(i, t)| Field::new(format!("arg{i}"), argtype_to_arrow(t), true))
320 .collect();
321 Arc::new(Schema::new(fields))
322}
323
324fn build_returns_field(sig: &AggSignature) -> Field {
325 Field::new("returns", argtype_to_arrow(&sig.returns), true)
326}
327
/// Frames `state` and `ipc` into a single buffer.
///
/// Layout, in order:
///
/// 1. the length of `state` as a little-endian `u32` (4 bytes),
/// 2. the bytes of `state`,
/// 3. the bytes of `ipc` (running to the end of the buffer).
///
/// A `state` longer than `u32::MAX` bytes does not fail here: the prefix is
/// saturated to `u32::MAX` and no longer matches the payload, so callers must
/// reject such a state before calling this function.
#[doc(hidden)]
#[must_use]
pub fn build_envelope(state: &[u8], ipc: &[u8]) -> Vec<u8> {
    let prefix = u32::try_from(state.len())
        .unwrap_or(u32::MAX)
        .to_le_bytes();
    // `concat` allocates once, sized to the three parts together.
    [&prefix[..], state, ipc].concat()
}
345
/// Splits an envelope into its `(state, ipc)` parts without copying.
///
/// The first four bytes are read as a little-endian `u32` giving the length
/// of the state; the state follows immediately, and everything after it is
/// returned as the IPC payload (possibly empty).
///
/// # Errors
///
/// Returns a descriptive message when `buf` is shorter than the 4-byte
/// prefix, when `4 + state_len` overflows `usize`, or when the declared
/// state length runs past the end of `buf`.
pub fn parse_envelope(buf: &[u8]) -> Result<(&[u8], &[u8]), String> {
    let mut prefix = [0u8; 4];
    match buf.get(..4) {
        Some(head) => prefix.copy_from_slice(head),
        None => return Err(format!("envelope too short: {} bytes < 4", buf.len())),
    }
    let state_len = u32::from_le_bytes(prefix) as usize;

    let end = match 4usize.checked_add(state_len) {
        Some(end) => end,
        None => return Err("state length overflow".to_owned()),
    };

    // `get` yields `None` exactly when `end` lies beyond the buffer.
    match buf.get(4..end) {
        Some(state) => Ok((state, &buf[end..])),
        None => Err(format!(
            "declared state_len {} overruns buffer of {} bytes",
            state_len,
            buf.len()
        )),
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Every export name is `agg_<sanitized qname>_<operation>`.
    #[test]
    fn export_name_format() {
        let qname = QName::parse("stats.weighted_mean").expect("valid");
        let cases = [
            (agg_new_export_name(&qname), "agg_stats_weighted_mean_new"),
            (
                agg_update_export_name(&qname),
                "agg_stats_weighted_mean_update",
            ),
            (
                agg_merge_export_name(&qname),
                "agg_stats_weighted_mean_merge",
            ),
            (
                agg_evaluate_export_name(&qname),
                "agg_stats_weighted_mean_evaluate",
            ),
        ];
        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }

    /// Building and then parsing an envelope gives back both parts unchanged.
    #[test]
    fn envelope_roundtrip_preserves_state_and_ipc() {
        let state: &[u8] = b"opaque-state-blob";
        let ipc: &[u8] = b"\x01\x02\x03not-real-but-distinct";
        let envelope = build_envelope(state, ipc);
        let parsed = parse_envelope(&envelope).expect("parse");
        assert_eq!(parsed, (state, ipc));
    }

    /// A zero-length state is framed and recovered correctly.
    #[test]
    fn envelope_with_empty_state() {
        let envelope = build_envelope(&[], b"ipc");
        let (state, ipc) = parse_envelope(&envelope).unwrap();
        assert_eq!(state.len(), 0);
        assert_eq!(ipc, b"ipc");
    }

    /// A zero-length IPC payload is framed and recovered correctly.
    #[test]
    fn envelope_with_empty_ipc() {
        let envelope = build_envelope(b"state-only", &[]);
        let (state, ipc) = parse_envelope(&envelope).unwrap();
        assert_eq!(state, b"state-only");
        assert_eq!(ipc.len(), 0);
    }

    /// Fewer than four bytes cannot hold the length prefix.
    #[test]
    fn parse_envelope_rejects_short_buffer() {
        let too_short = [1u8, 2];
        assert!(parse_envelope(&too_short).is_err());
    }

    /// The prefix declares `u32::MAX` state bytes but nothing follows it.
    #[test]
    fn parse_envelope_rejects_overrun() {
        let prefix_only = [0xFFu8; 4];
        assert!(parse_envelope(&prefix_only).is_err());
    }

    /// The argument schema has one `argN` field per signature argument, with
    /// the Arrow type produced by `argtype_to_arrow`.
    #[test]
    fn args_schema_matches_signature_args() {
        use arrow_schema::DataType;
        use datafusion::logical_expr::Volatility;
        use uni_plugin::traits::scalar::ArgType;

        let args = vec![ArgType::Primitive(DataType::Float64), ArgType::CypherValue];
        let returns = ArgType::Primitive(DataType::Float64);
        let state_fields = vec![Field::new("state", DataType::Binary, true)];
        let sig = AggSignature::new(args, returns, state_fields, Volatility::Immutable);

        let schema = build_args_schema(&sig);
        let expected = [
            ("arg0", DataType::Float64),
            ("arg1", DataType::LargeBinary),
        ];
        assert_eq!(schema.fields().len(), expected.len());
        for (idx, (name, data_type)) in expected.iter().enumerate() {
            let field = schema.field(idx);
            assert_eq!(field.name(), *name);
            assert_eq!(field.data_type(), data_type);
        }
    }

    /// The returns field carries the Arrow type of the signature's return type.
    #[test]
    fn build_returns_field_uses_signature_returns() {
        use arrow_schema::DataType;
        use datafusion::logical_expr::Volatility;
        use uni_plugin::traits::scalar::ArgType;

        let returns = ArgType::Primitive(DataType::Int64);
        let state_fields = vec![Field::new("state", DataType::Binary, true)];
        let sig = AggSignature::new(vec![], returns, state_fields, Volatility::Immutable);

        let field = build_returns_field(&sig);
        assert_eq!(field.data_type(), &DataType::Int64);
    }
}