Skip to main content

uni_plugin_wasm/
adapter.rs

1//! `ComponentScalarFn` — bridges a CM `scalar-plugin` instance to
2//! [`ScalarPluginFn`].
3//!
4//! Mirrors `uni-plugin-extism`'s `ExtismScalarFn`: encode args as
5//! Arrow IPC, call the plugin's typed `invoke-scalar` export, decode
6//! the returned IPC bytes back into a `ColumnarValue`. The pool's
7//! cold-start factory rebuilds the wasmtime `Store<HostState>` + the
8//! linked `ScalarPlugin` typed wrapper for each new instance.
9
use std::sync::Arc;

use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_schema::{Field, Schema, SchemaRef};
use datafusion::logical_expr::ColumnarValue;
use uni_plugin::QName;
use uni_plugin::errors::FnError;
use uni_plugin::traits::scalar::{FnSignature, ScalarPluginFn};
use uni_plugin_wasm_rt::ipc::{decode_batch, encode_batch};

use crate::adapter_common::{acquire, ipc_to_fn_err};
use crate::loader::ScalarPluginInstance;
use crate::pool::WasmInstancePool;
23
24/// Adapter that registers as `ScalarPluginFn` on the host's
25/// `PluginRegistrar`. Holds an `Arc` to the pool so multiple
26/// concurrent Cypher calls each acquire their own warm instance.
27pub struct ComponentScalarFn {
28    pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
29    qname: QName,
30    sig: FnSignature,
31}
32
33impl std::fmt::Debug for ComponentScalarFn {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ComponentScalarFn")
36            .field("qname", &self.qname)
37            .field("signature", &self.sig)
38            .finish_non_exhaustive()
39    }
40}
41
42impl ComponentScalarFn {
43    /// Construct a new adapter against the supplied pool.
44    #[must_use]
45    pub fn new(
46        pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
47        qname: QName,
48        sig: FnSignature,
49    ) -> Self {
50        Self { pool, qname, sig }
51    }
52
53    fn args_to_batch(&self, args: &[ColumnarValue], rows: usize) -> Result<RecordBatch, FnError> {
54        let arrays: Vec<ArrayRef> = args
55            .iter()
56            .map(|c| {
57                c.clone().into_array(rows).map_err(|e| {
58                    FnError::new(
59                        FnError::CODE_TYPE_COERCION,
60                        format!("ColumnarValue::into_array: {e}"),
61                    )
62                })
63            })
64            .collect::<Result<_, _>>()?;
65        let fields: Vec<Field> = arrays
66            .iter()
67            .enumerate()
68            .map(|(i, a)| Field::new(format!("arg{i}"), a.data_type().clone(), true))
69            .collect();
70        let schema: SchemaRef = Arc::new(Schema::new(fields));
71        RecordBatch::try_new(schema, arrays).map_err(|e| {
72            FnError::new(
73                FnError::CODE_TYPE_COERCION,
74                format!("RecordBatch assembly: {e}"),
75            )
76        })
77    }
78}
79
80impl ScalarPluginFn for ComponentScalarFn {
81    fn signature(&self) -> &FnSignature {
82        &self.sig
83    }
84
85    fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
86        let batch = self.args_to_batch(args, rows)?;
87        let bytes = encode_batch(&batch).map_err(ipc_to_fn_err)?;
88
89        let mut leased = acquire(&self.pool, "plugin")?;
90        let qname_str = self.qname.to_string();
91        let out_bytes: Vec<u8> =
92            leased
93                .get_mut()
94                .invoke_scalar(&qname_str, &bytes)
95                .map_err(|e| {
96                    FnError::new(
97                        FnError::CODE_UNEXPECTED_NULL,
98                        format!("wasm component invoke_scalar `{qname_str}`: {e}"),
99                    )
100                })?;
101        drop(leased);
102
103        let out_batch = decode_batch(&out_bytes)
104            .map_err(ipc_to_fn_err)?
105            .ok_or_else(|| {
106                FnError::new(
107                    FnError::CODE_UNEXPECTED_NULL,
108                    format!("wasm component `{qname_str}` returned empty IPC stream"),
109                )
110            })?;
111
112        if out_batch.num_columns() != 1 {
113            return Err(FnError::new(
114                FnError::CODE_TYPE_COERCION,
115                format!(
116                    "wasm component `{qname_str}` returned {} columns; scalar fns must return 1",
117                    out_batch.num_columns()
118                ),
119            ));
120        }
121        Ok(ColumnarValue::Array(out_batch.column(0).clone()))
122    }
123}