Skip to main content

uni_plugin_wasm/
adapter_aggregate.rs

1//! Aggregate adapter — bridges a CM `aggregate-plugin` instance to
2//! [`AggregatePluginFn`] / [`PluginAccumulator`].
3//!
//! Port of `uni_plugin_extism::adapter_aggregate`. Same state-passing
//! shape, but without the envelope — the CM ABI carries `state: list<u8>`
//! as a typed parameter rather than packing it into the IPC bytes, so the
//! envelope helper from the extism side is unnecessary here.
8
9// Rust guideline compliant
10
11use std::sync::Arc;
12
13use arrow::array::RecordBatch;
14use arrow_array::ArrayRef;
15use arrow_schema::{Field, Schema, SchemaRef};
16use datafusion::scalar::ScalarValue;
17use uni_plugin::QName;
18use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
19use uni_plugin::errors::FnError;
20use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
21use uni_plugin_wasm_rt::ipc::{decode_batch, encode_batch};
22
23use crate::adapter_common::{acquire, ipc_to_fn_err};
24use crate::loader::AggregatePluginInstance;
25use crate::pool::WasmInstancePool;
26
27/// `AggregatePluginFn` adapter wrapping a CM aggregate-plugin pool.
28pub struct ComponentAggregateFn {
29    pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
30    qname: QName,
31    sig: AggSignature,
32}
33
34impl std::fmt::Debug for ComponentAggregateFn {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ComponentAggregateFn")
37            .field("qname", &self.qname)
38            .field("signature", &self.sig)
39            .finish_non_exhaustive()
40    }
41}
42
43impl ComponentAggregateFn {
44    /// Construct a new adapter against the supplied pool.
45    #[must_use]
46    pub fn new(
47        pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
48        qname: QName,
49        sig: AggSignature,
50    ) -> Self {
51        Self { pool, qname, sig }
52    }
53
54    fn call_new(&self) -> Result<Vec<u8>, FnError> {
55        let mut leased = acquire(&self.pool, "aggregate")?;
56        let qname_str = self.qname.to_string();
57        let state = leased.get_mut().agg_new(&qname_str).map_err(|e| {
58            FnError::new(
59                FnError::CODE_UNEXPECTED_NULL,
60                format!("wasm agg_new `{qname_str}`: {e}"),
61            )
62        })?;
63        drop(leased);
64        Ok(state)
65    }
66}
67
68impl AggregatePluginFn for ComponentAggregateFn {
69    fn signature(&self) -> &AggSignature {
70        &self.sig
71    }
72
73    fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
74        let (state, init_err) = match self.call_new() {
75            Ok(s) => (s, None),
76            Err(e) => (Vec::new(), Some(e)),
77        };
78        Box::new(ComponentAggregateAccumulator {
79            state,
80            init_err,
81            pool: Arc::clone(&self.pool),
82            qname: self.qname.to_string(),
83            args_schema: build_args_schema(&self.sig),
84            returns_field: build_returns_field(&self.sig),
85        })
86    }
87}
88
89struct ComponentAggregateAccumulator {
90    state: Vec<u8>,
91    init_err: Option<FnError>,
92    pool: Arc<WasmInstancePool<AggregatePluginInstance>>,
93    qname: String,
94    args_schema: SchemaRef,
95    returns_field: Field,
96}
97
98impl ComponentAggregateAccumulator {
99    fn surface_init_err(&self) -> Result<(), FnError> {
100        if let Some(e) = &self.init_err {
101            return Err(FnError::new(
102                e.code,
103                format!("aggregate init failed: {}", e.message),
104            ));
105        }
106        Ok(())
107    }
108}
109
110impl PluginAccumulator for ComponentAggregateAccumulator {
111    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
112        self.surface_init_err()?;
113        let batch =
114            RecordBatch::try_new(Arc::clone(&self.args_schema), values.to_vec()).map_err(|e| {
115                FnError::new(
116                    FnError::CODE_TYPE_COERCION,
117                    format!("update_batch RecordBatch: {e}"),
118                )
119            })?;
120        let ipc = encode_batch(&batch).map_err(ipc_to_fn_err)?;
121        let mut leased = acquire(&self.pool, "aggregate")?;
122        let new_state = leased
123            .get_mut()
124            .agg_update(&self.qname, &self.state, &ipc)
125            .map_err(|e| {
126                FnError::new(
127                    FnError::CODE_UNEXPECTED_NULL,
128                    format!("wasm agg_update `{}`: {e}", self.qname),
129                )
130            })?;
131        drop(leased);
132        self.state = new_state;
133        Ok(())
134    }
135
136    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError> {
137        self.surface_init_err()?;
138        if states.len() != 1 {
139            return Err(FnError::new(
140                FnError::CODE_TYPE_COERCION,
141                format!(
142                    "merge_batch expects 1 state column (opaque Binary); got {}",
143                    states.len()
144                ),
145            ));
146        }
147        let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
148            "partial_state",
149            states[0].data_type().clone(),
150            true,
151        )]));
152        let batch = RecordBatch::try_new(schema, states.to_vec()).map_err(|e| {
153            FnError::new(
154                FnError::CODE_TYPE_COERCION,
155                format!("merge_batch RecordBatch: {e}"),
156            )
157        })?;
158        let ipc = encode_batch(&batch).map_err(ipc_to_fn_err)?;
159        let mut leased = acquire(&self.pool, "aggregate")?;
160        let new_state = leased
161            .get_mut()
162            .agg_merge(&self.qname, &self.state, &ipc)
163            .map_err(|e| {
164                FnError::new(
165                    FnError::CODE_UNEXPECTED_NULL,
166                    format!("wasm agg_merge `{}`: {e}", self.qname),
167                )
168            })?;
169        drop(leased);
170        self.state = new_state;
171        Ok(())
172    }
173
174    fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
175        self.surface_init_err()?;
176        Ok(vec![ScalarValue::Binary(Some(self.state.clone()))])
177    }
178
179    fn evaluate(&self) -> Result<ScalarValue, FnError> {
180        self.surface_init_err()?;
181        let mut leased = acquire(&self.pool, "aggregate")?;
182        let out_bytes = leased
183            .get_mut()
184            .agg_evaluate(&self.qname, &self.state)
185            .map_err(|e| {
186                FnError::new(
187                    FnError::CODE_UNEXPECTED_NULL,
188                    format!("wasm agg_evaluate `{}`: {e}", self.qname),
189                )
190            })?;
191        drop(leased);
192        let batch = decode_batch(&out_bytes)
193            .map_err(ipc_to_fn_err)?
194            .ok_or_else(|| {
195                FnError::new(
196                    FnError::CODE_UNEXPECTED_NULL,
197                    format!("plugin agg_evaluate `{}` empty IPC", self.qname),
198                )
199            })?;
200        if batch.num_columns() != 1 || batch.num_rows() != 1 {
201            return Err(FnError::new(
202                FnError::CODE_TYPE_COERCION,
203                format!(
204                    "plugin agg_evaluate `{}` must return 1×1; got {}×{}",
205                    self.qname,
206                    batch.num_rows(),
207                    batch.num_columns()
208                ),
209            ));
210        }
211        if batch.column(0).data_type() != self.returns_field.data_type() {
212            return Err(FnError::new(
213                FnError::CODE_TYPE_COERCION,
214                format!(
215                    "plugin agg_evaluate `{}` returned {:?}, expected {:?}",
216                    self.qname,
217                    batch.column(0).data_type(),
218                    self.returns_field.data_type()
219                ),
220            ));
221        }
222        ScalarValue::try_from_array(batch.column(0), 0).map_err(|e| {
223            FnError::new(
224                FnError::CODE_TYPE_COERCION,
225                format!("agg_evaluate ScalarValue: {e}"),
226            )
227        })
228    }
229
230    fn size(&self) -> usize {
231        std::mem::size_of::<Self>() + self.state.capacity()
232    }
233}
234
235fn build_args_schema(sig: &AggSignature) -> SchemaRef {
236    let fields: Vec<Field> = sig
237        .args
238        .iter()
239        .enumerate()
240        .map(|(i, t)| Field::new(format!("arg{i}"), argtype_to_arrow(t), true))
241        .collect();
242    Arc::new(Schema::new(fields))
243}
244
245fn build_returns_field(sig: &AggSignature) -> Field {
246    Field::new("returns", argtype_to_arrow(&sig.returns), true)
247}