Skip to main content

uni_plugin_extism/
adapter_aggregate.rs

1//! Aggregate adapter — bridges Extism aggregate plugins to
2//! [`AggregatePluginFn`] / [`PluginAccumulator`].
3//!
4//! ## Wire contract (per qname `q`)
5//!
6//! - `agg_<q>_new` — input empty; output is the initial state bytes
7//!   (opaque, plugin-defined; carried as Arrow `Binary` on the host).
8//! - `agg_<q>_update` — input is `[state_len: u32 LE][state_bytes]
9//!   [arrow_ipc_stream]` where the stream contains one batch whose
10//!   columns match `agg.signature().args`. Output is the updated state
11//!   bytes.
12//! - `agg_<q>_merge` — input is `[state_len: u32 LE][state_bytes]
13//!   [arrow_ipc_stream]` where the stream contains one batch with a
14//!   single `Binary` column of `M` partial states. Output is the
15//!   merged state bytes.
16//! - `agg_<q>_evaluate` — input is the raw state bytes; output is an
17//!   Arrow IPC stream with one 1-row batch whose single column has the
18//!   declared `returns` type.
19//!
20//! The length-prefixed envelope is used because Extism's host↔plugin
21//! call boundary takes a single byte buffer; the prefix lets the
22//! plugin recover the opaque state without parsing the IPC bytes.
23
24// Rust guideline compliant
25
26use std::sync::Arc;
27
28use arrow::array::RecordBatch;
29use arrow_array::ArrayRef;
30use arrow_schema::{Field, Schema, SchemaRef};
31use datafusion::scalar::ScalarValue;
32use uni_plugin::QName;
33use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
34use uni_plugin::errors::FnError;
35use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
36
37use crate::adapter_common::{acquire, extism_err_to_fn_err, sanitize_qname};
38use crate::ipc::{decode_batch, encode_batch};
39use crate::pool::ExtismInstancePool;
40
41/// Plugin-side aggregate-`new` export name from a qname.
42#[must_use]
43pub(crate) fn agg_new_export_name(qname: &QName) -> String {
44    format!("agg_{}_new", sanitize_qname(qname))
45}
46
47/// Plugin-side aggregate-`update` export name from a qname.
48#[must_use]
49pub(crate) fn agg_update_export_name(qname: &QName) -> String {
50    format!("agg_{}_update", sanitize_qname(qname))
51}
52
53/// Plugin-side aggregate-`merge` export name from a qname.
54#[must_use]
55pub(crate) fn agg_merge_export_name(qname: &QName) -> String {
56    format!("agg_{}_merge", sanitize_qname(qname))
57}
58
59/// Plugin-side aggregate-`evaluate` export name from a qname.
60#[must_use]
61pub(crate) fn agg_evaluate_export_name(qname: &QName) -> String {
62    format!("agg_{}_evaluate", sanitize_qname(qname))
63}
64
65/// `AggregatePluginFn` adapter wrapping an Extism plugin pool.
66pub struct ExtismAggregateFn {
67    pool: Arc<ExtismInstancePool<extism::Plugin>>,
68    qname: QName,
69    sig: AggSignature,
70    new_export: String,
71    update_export: String,
72    merge_export: String,
73    evaluate_export: String,
74}
75
76impl std::fmt::Debug for ExtismAggregateFn {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("ExtismAggregateFn")
79            .field("qname", &self.qname)
80            .field("signature", &self.sig)
81            .finish_non_exhaustive()
82    }
83}
84
85impl ExtismAggregateFn {
86    /// Construct a new adapter against the supplied pool.
87    #[must_use]
88    pub fn new(
89        pool: Arc<ExtismInstancePool<extism::Plugin>>,
90        qname: QName,
91        sig: AggSignature,
92    ) -> Self {
93        let new_export = agg_new_export_name(&qname);
94        let update_export = agg_update_export_name(&qname);
95        let merge_export = agg_merge_export_name(&qname);
96        let evaluate_export = agg_evaluate_export_name(&qname);
97        Self {
98            pool,
99            qname,
100            sig,
101            new_export,
102            update_export,
103            merge_export,
104            evaluate_export,
105        }
106    }
107
108    fn call_new(&self) -> Result<Vec<u8>, FnError> {
109        let mut leased = acquire(&self.pool)?;
110        let bytes: Vec<u8> = leased
111            .get_mut()
112            .call::<&[u8], &[u8]>(&self.new_export, &[])
113            .map_err(|e| {
114                FnError::new(
115                    FnError::CODE_UNEXPECTED_NULL,
116                    format!("extism call `{}` failed: {e}", self.new_export),
117                )
118            })?
119            .to_vec();
120        drop(leased);
121        Ok(bytes)
122    }
123}
124
125impl AggregatePluginFn for ExtismAggregateFn {
126    fn signature(&self) -> &AggSignature {
127        &self.sig
128    }
129
130    fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
131        // `create_accumulator` returns a Box without a Result; if the
132        // plugin's `_new` export fails, we surface that on the first
133        // update/evaluate call by carrying an empty state and a
134        // remembered init error. Two-phase init keeps the trait shape
135        // (DataFusion expects an infallible accumulator factory).
136        let (state, init_err) = match self.call_new() {
137            Ok(s) => (s, None),
138            Err(e) => (Vec::new(), Some(e)),
139        };
140        Box::new(ExtismAggregateAccumulator {
141            state,
142            init_err,
143            pool: Arc::clone(&self.pool),
144            update_export: self.update_export.clone(),
145            merge_export: self.merge_export.clone(),
146            evaluate_export: self.evaluate_export.clone(),
147            args_schema: build_args_schema(&self.sig),
148            returns_field: build_returns_field(&self.sig),
149        })
150    }
151}
152
153/// Per-group state machine.
154struct ExtismAggregateAccumulator {
155    state: Vec<u8>,
156    init_err: Option<FnError>,
157    pool: Arc<ExtismInstancePool<extism::Plugin>>,
158    update_export: String,
159    merge_export: String,
160    evaluate_export: String,
161    args_schema: SchemaRef,
162    returns_field: Field,
163}
164
165impl ExtismAggregateAccumulator {
166    fn surface_init_err(&self) -> Result<(), FnError> {
167        if let Some(e) = &self.init_err {
168            return Err(FnError::new(
169                e.code,
170                format!("aggregate init failed: {}", e.message),
171            ));
172        }
173        Ok(())
174    }
175
176    fn call_with_envelope(&self, export: &str, batch: RecordBatch) -> Result<Vec<u8>, FnError> {
177        let ipc = encode_batch(&batch).map_err(extism_err_to_fn_err)?;
178        // Reject states wider than the u32 length prefix before building the
179        // envelope (`build_envelope` would otherwise silently clamp to
180        // `u32::MAX`, corrupting the wire framing).
181        if u32::try_from(self.state.len()).is_err() {
182            return Err(FnError::new(
183                FnError::CODE_RESOURCE_LIMIT,
184                "aggregate state exceeds u32::MAX bytes",
185            ));
186        }
187        let buf = build_envelope(&self.state, &ipc);
188
189        let mut leased = acquire(&self.pool)?;
190        let out: Vec<u8> = leased
191            .get_mut()
192            .call::<&[u8], &[u8]>(export, &buf)
193            .map_err(|e| {
194                FnError::new(
195                    FnError::CODE_UNEXPECTED_NULL,
196                    format!("extism call `{export}` failed: {e}"),
197                )
198            })?
199            .to_vec();
200        drop(leased);
201        Ok(out)
202    }
203}
204
205impl PluginAccumulator for ExtismAggregateAccumulator {
206    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
207        self.surface_init_err()?;
208        let batch =
209            RecordBatch::try_new(Arc::clone(&self.args_schema), values.to_vec()).map_err(|e| {
210                FnError::new(
211                    FnError::CODE_TYPE_COERCION,
212                    format!("update_batch: RecordBatch assembly: {e}"),
213                )
214            })?;
215        let new_state = self.call_with_envelope(&self.update_export, batch)?;
216        self.state = new_state;
217        Ok(())
218    }
219
220    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
221        self.surface_init_err()?;
222        if states.len() != 1 {
223            return Err(FnError::new(
224                FnError::CODE_TYPE_COERCION,
225                format!(
226                    "merge_batch expects exactly 1 state column (opaque Binary); got {}",
227                    states.len()
228                ),
229            ));
230        }
231        let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
232            "partial_state",
233            states[0].data_type().clone(),
234            true,
235        )]));
236        let batch = RecordBatch::try_new(schema, states.to_vec()).map_err(|e| {
237            FnError::new(
238                FnError::CODE_TYPE_COERCION,
239                format!("merge_batch: RecordBatch assembly: {e}"),
240            )
241        })?;
242        let new_state = self.call_with_envelope(&self.merge_export, batch)?;
243        self.state = new_state;
244        Ok(())
245    }
246
247    fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
248        self.surface_init_err()?;
249        Ok(vec![ScalarValue::Binary(Some(self.state.clone()))])
250    }
251
252    fn evaluate(&self) -> Result<ScalarValue, FnError> {
253        self.surface_init_err()?;
254        let mut leased = acquire(&self.pool)?;
255        let out_bytes: Vec<u8> = leased
256            .get_mut()
257            .call::<&[u8], &[u8]>(&self.evaluate_export, &self.state)
258            .map_err(|e| {
259                FnError::new(
260                    FnError::CODE_UNEXPECTED_NULL,
261                    format!("extism call `{}` failed: {e}", self.evaluate_export),
262                )
263            })?
264            .to_vec();
265        drop(leased);
266
267        let batch = decode_batch(&out_bytes)
268            .map_err(extism_err_to_fn_err)?
269            .ok_or_else(|| {
270                FnError::new(
271                    FnError::CODE_UNEXPECTED_NULL,
272                    format!(
273                        "plugin `{}` returned an empty IPC stream",
274                        self.evaluate_export
275                    ),
276                )
277            })?;
278        if batch.num_columns() != 1 || batch.num_rows() != 1 {
279            return Err(FnError::new(
280                FnError::CODE_TYPE_COERCION,
281                format!(
282                    "plugin `{}` must return a 1-row × 1-col batch; got {} rows × {} cols",
283                    self.evaluate_export,
284                    batch.num_rows(),
285                    batch.num_columns()
286                ),
287            ));
288        }
289        // Sanity-check declared returns type matches.
290        if batch.column(0).data_type() != self.returns_field.data_type() {
291            return Err(FnError::new(
292                FnError::CODE_TYPE_COERCION,
293                format!(
294                    "plugin `{}` returned column type {:?}, expected {:?}",
295                    self.evaluate_export,
296                    batch.column(0).data_type(),
297                    self.returns_field.data_type()
298                ),
299            ));
300        }
301        ScalarValue::try_from_array(batch.column(0), 0).map_err(|e| {
302            FnError::new(
303                FnError::CODE_TYPE_COERCION,
304                format!("evaluate: ScalarValue::try_from_array: {e}"),
305            )
306        })
307    }
308
309    fn size(&self) -> usize {
310        std::mem::size_of::<Self>() + self.state.capacity()
311    }
312}
313
314fn build_args_schema(sig: &AggSignature) -> SchemaRef {
315    let fields: Vec<Field> = sig
316        .args
317        .iter()
318        .enumerate()
319        .map(|(i, t)| Field::new(format!("arg{i}"), argtype_to_arrow(t), true))
320        .collect();
321    Arc::new(Schema::new(fields))
322}
323
324fn build_returns_field(sig: &AggSignature) -> Field {
325    Field::new("returns", argtype_to_arrow(&sig.returns), true)
326}
327
/// Build the length-prefixed `update`/`merge` envelope
/// `[state_len: u32 LE][state_bytes][ipc_stream_bytes]`.
///
/// The host-side accumulator uses this for every `update`/`merge` call,
/// and [`parse_envelope`] is its inverse. It is also exposed for envelope
/// round-trip tests.
///
/// The length prefix saturates at `u32::MAX`. A state longer than that
/// produces a corrupt envelope, so callers that cannot tolerate a clamped
/// length must reject `state.len() > u32::MAX` before calling.
#[doc(hidden)]
#[must_use]
pub fn build_envelope(state: &[u8], ipc: &[u8]) -> Vec<u8> {
    let prefix = u32::try_from(state.len())
        .unwrap_or(u32::MAX)
        .to_le_bytes();
    let mut envelope = Vec::with_capacity(prefix.len() + state.len() + ipc.len());
    for segment in [prefix.as_slice(), state, ipc] {
        envelope.extend_from_slice(segment);
    }
    envelope
}
345
/// Helper exposed for tests / plugin authors: split an envelope of the
/// shape `[state_len: u32 LE][state_bytes][ipc_stream_bytes]`.
///
/// Returns `(state, ipc)` as sub-slices of `buf`. The plugin performs the
/// equivalent split on its side to recover state + values for every
/// `update`/`merge` call.
///
/// # Errors
///
/// Returns a string error if the buffer is shorter than the 4-byte prefix
/// or if the declared state length runs past the end of the buffer.
pub fn parse_envelope(buf: &[u8]) -> Result<(&[u8], &[u8]), String> {
    let Some(prefix) = buf.get(..4) else {
        return Err(format!("envelope too short: {} bytes < 4", buf.len()));
    };
    let mut prefix_bytes = [0u8; 4];
    prefix_bytes.copy_from_slice(prefix);
    let declared = u32::from_le_bytes(prefix_bytes) as usize;
    // Cannot overflow on 64-bit targets, but keeps narrower `usize` honest.
    let Some(state_end) = declared.checked_add(4) else {
        return Err("state length overflow".to_owned());
    };
    if state_end > buf.len() {
        return Err(format!(
            "declared state_len {declared} overruns buffer of {} bytes",
            buf.len()
        ));
    }
    let (framed_state, ipc) = buf.split_at(state_end);
    Ok((&framed_state[4..], ipc))
}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn export_name_format() {
        let q = QName::parse("stats.weighted_mean").expect("valid");
        assert_eq!(agg_new_export_name(&q), "agg_stats_weighted_mean_new");
        assert_eq!(agg_update_export_name(&q), "agg_stats_weighted_mean_update");
        assert_eq!(agg_merge_export_name(&q), "agg_stats_weighted_mean_merge");
        assert_eq!(
            agg_evaluate_export_name(&q),
            "agg_stats_weighted_mean_evaluate"
        );
    }

    #[test]
    fn envelope_roundtrip_preserves_state_and_ipc() {
        let state = b"opaque-state-blob".as_slice();
        let ipc = b"\x01\x02\x03not-real-but-distinct".as_slice();
        let env = build_envelope(state, ipc);
        let (got_state, got_ipc) = parse_envelope(&env).expect("parse");
        assert_eq!(got_state, state);
        assert_eq!(got_ipc, ipc);
    }

    #[test]
    fn envelope_length_prefix_is_little_endian_u32() {
        // A round-trip alone cannot catch an endianness swap (build and parse
        // would agree with each other), so pin the raw wire bytes.
        // 0x0102 = 258 → LE bytes [0x02, 0x01, 0x00, 0x00].
        let state = vec![0xAB_u8; 0x0102];
        let env = build_envelope(&state, b"z");
        assert_eq!(&env[..4], &[0x02_u8, 0x01, 0x00, 0x00]);
        assert_eq!(env.len(), 4 + state.len() + 1);
        assert_eq!(env.last(), Some(&b'z'));
    }

    #[test]
    fn envelope_with_empty_state() {
        let env = build_envelope(&[], b"ipc");
        let (state, ipc) = parse_envelope(&env).unwrap();
        assert!(state.is_empty());
        assert_eq!(ipc, b"ipc");
    }

    #[test]
    fn envelope_with_empty_ipc() {
        let env = build_envelope(b"state-only", &[]);
        let (state, ipc) = parse_envelope(&env).unwrap();
        assert_eq!(state, b"state-only");
        assert!(ipc.is_empty());
    }

    #[test]
    fn parse_envelope_rejects_short_buffer() {
        assert!(parse_envelope(&[1u8, 2]).is_err());
    }

    #[test]
    fn parse_envelope_rejects_overrun() {
        // state_len declared = 0xFFFFFFFF but buffer is only 4 bytes.
        let buf = vec![0xFFu8, 0xFF, 0xFF, 0xFF];
        assert!(parse_envelope(&buf).is_err());
    }

    #[test]
    fn parse_envelope_rejects_partial_overrun() {
        // Declares 5 state bytes but only 3 follow the prefix.
        let buf = [5u8, 0, 0, 0, 1, 2, 3];
        assert!(parse_envelope(&buf).is_err());
    }

    #[test]
    fn parse_envelope_accepts_state_that_exactly_fills_buffer() {
        let buf = [2u8, 0, 0, 0, 7, 9];
        let (state, ipc) = parse_envelope(&buf).expect("parse");
        assert_eq!(state, &[7u8, 9]);
        assert!(ipc.is_empty());
    }

    #[test]
    fn args_schema_matches_signature_args() {
        use arrow_schema::DataType;
        use datafusion::logical_expr::Volatility;
        use uni_plugin::traits::scalar::ArgType;
        let sig = AggSignature::new(
            vec![ArgType::Primitive(DataType::Float64), ArgType::CypherValue],
            ArgType::Primitive(DataType::Float64),
            vec![Field::new("state", DataType::Binary, true)],
            Volatility::Immutable,
        );
        let schema = build_args_schema(&sig);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "arg0");
        assert_eq!(schema.field(0).data_type(), &DataType::Float64);
        assert_eq!(schema.field(1).name(), "arg1");
        assert_eq!(schema.field(1).data_type(), &DataType::LargeBinary);
    }

    #[test]
    fn build_returns_field_uses_signature_returns() {
        use arrow_schema::DataType;
        use datafusion::logical_expr::Volatility;
        use uni_plugin::traits::scalar::ArgType;
        let sig = AggSignature::new(
            vec![],
            ArgType::Primitive(DataType::Int64),
            vec![Field::new("state", DataType::Binary, true)],
            Volatility::Immutable,
        );
        let f = build_returns_field(&sig);
        assert_eq!(f.data_type(), &DataType::Int64);
    }
}