Skip to main content

uni_plugin_extism/
adapter_procedure.rs

1//! Procedure adapter — bridges Extism procedure plugins to
2//! [`ProcedurePlugin`].
3//!
4//! ## Wire contract (per qname `q`)
5//!
6//! - `proc_<q>_invoke` — input is an Arrow IPC stream with one 1-row
7//!   batch whose columns match `proc.signature().args`. Output is an
8//!   Arrow IPC stream containing zero or more batches, each matching
9//!   the declared `yields` schema. M6a.2 collects every output batch
10//!   eagerly and serves them from an in-memory stream; true streaming
11//!   via a `host_yield` callback lands with M6b (host imports under
12//!   the Component Model).
13
14// Rust guideline compliant
15
16use std::sync::Arc;
17
18use arrow::array::RecordBatch;
19use arrow_schema::{Field, Schema, SchemaRef};
20use datafusion::execution::SendableRecordBatchStream;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use futures::stream;
24use uni_plugin::QName;
25use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
26use uni_plugin::errors::FnError;
27use uni_plugin::traits::procedure::{ProcedureContext, ProcedurePlugin, ProcedureSignature};
28
29use crate::adapter_common::{acquire, extism_err_to_fn_err, sanitize_qname};
30use crate::ipc::{decode_batches, encode_batch};
31use crate::pool::ExtismInstancePool;
32
33/// Plugin-side procedure-invoke export name from a qname.
34///
35/// `.` in qnames is replaced with `_` so plugin authors can use
36/// idiomatic Rust function names (Rust identifiers can't contain
37/// `.`). Matches the scalar / aggregate sanitization.
38#[must_use]
39pub(crate) fn proc_invoke_export_name(qname: &QName) -> String {
40    format!("proc_{}_invoke", sanitize_qname(qname))
41}
42
43/// `ProcedurePlugin` adapter wrapping an Extism plugin pool.
44pub struct ExtismProcedure {
45    pool: Arc<ExtismInstancePool<extism::Plugin>>,
46    qname: QName,
47    invoke_export: String,
48    sig: ProcedureSignature,
49    args_schema: SchemaRef,
50    yields_schema: SchemaRef,
51}
52
53impl std::fmt::Debug for ExtismProcedure {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("ExtismProcedure")
56            .field("qname", &self.qname)
57            .field("signature", &self.sig)
58            .finish_non_exhaustive()
59    }
60}
61
62impl ExtismProcedure {
63    /// Construct a new adapter against the supplied pool.
64    #[must_use]
65    pub fn new(
66        pool: Arc<ExtismInstancePool<extism::Plugin>>,
67        qname: QName,
68        sig: ProcedureSignature,
69    ) -> Self {
70        let invoke_export = proc_invoke_export_name(&qname);
71        let args_schema = build_args_schema(&sig);
72        let yields_schema = Arc::new(Schema::new(sig.yields.clone()));
73        Self {
74            pool,
75            qname,
76            invoke_export,
77            sig,
78            args_schema,
79            yields_schema,
80        }
81    }
82}
83
84impl ProcedurePlugin for ExtismProcedure {
85    fn signature(&self) -> &ProcedureSignature {
86        &self.sig
87    }
88
89    fn invoke(
90        &self,
91        _ctx: ProcedureContext<'_>,
92        args: &[ColumnarValue],
93    ) -> Result<SendableRecordBatchStream, FnError> {
94        // M6a.2: procedures receive scalar args, packed into a 1-row
95        // RecordBatch (one column per arg). Plugins decode and produce
96        // a stream of `yields`-shaped batches eagerly.
97        let arrays: Vec<arrow::array::ArrayRef> = args
98            .iter()
99            .map(|c| {
100                c.clone().into_array(1).map_err(|e| {
101                    FnError::new(
102                        FnError::CODE_TYPE_COERCION,
103                        format!("ColumnarValue::into_array: {e}"),
104                    )
105                })
106            })
107            .collect::<Result<_, _>>()?;
108        if arrays.len() != self.args_schema.fields().len() {
109            return Err(FnError::new(
110                FnError::CODE_TYPE_COERCION,
111                format!(
112                    "procedure `{}` expected {} args; got {}",
113                    self.qname,
114                    self.args_schema.fields().len(),
115                    arrays.len()
116                ),
117            ));
118        }
119        let args_batch =
120            RecordBatch::try_new(Arc::clone(&self.args_schema), arrays).map_err(|e| {
121                FnError::new(
122                    FnError::CODE_TYPE_COERCION,
123                    format!("procedure `{}` args RecordBatch: {e}", self.qname),
124                )
125            })?;
126        let ipc = encode_batch(&args_batch).map_err(extism_err_to_fn_err)?;
127
128        let mut leased = acquire(&self.pool)?;
129        let out_bytes: Vec<u8> = leased
130            .get_mut()
131            .call::<&[u8], &[u8]>(&self.invoke_export, &ipc)
132            .map_err(|e| {
133                FnError::new(
134                    FnError::CODE_UNEXPECTED_NULL,
135                    format!("extism call `{}` failed: {e}", self.invoke_export),
136                )
137            })?
138            .to_vec();
139        drop(leased);
140
141        let batches = decode_batches(&out_bytes).map_err(extism_err_to_fn_err)?;
142
143        // Validate every batch matches the declared yields schema.
144        for (i, b) in batches.iter().enumerate() {
145            if b.schema().fields() != self.yields_schema.fields() {
146                return Err(FnError::new(
147                    FnError::CODE_TYPE_COERCION,
148                    format!(
149                        "procedure `{}` batch[{i}] schema mismatch: got {:?}, expected {:?}",
150                        self.qname,
151                        b.schema().fields(),
152                        self.yields_schema.fields()
153                    ),
154                ));
155            }
156        }
157
158        let schema = Arc::clone(&self.yields_schema);
159        let stream = stream::iter(batches.into_iter().map(Ok));
160        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
161    }
162}
163
164fn build_args_schema(sig: &ProcedureSignature) -> SchemaRef {
165    let fields: Vec<Field> = sig
166        .args
167        .iter()
168        .enumerate()
169        .map(|(i, a)| Field::new(format!("arg{i}"), argtype_to_arrow(&a.ty), true))
170        .collect();
171    Arc::new(Schema::new(fields))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::DataType;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::{NamedArgType, ProcedureMode};
    use uni_plugin::traits::scalar::ArgType;

    /// Fixture: a read-mode signature with one `Utf8` argument that
    /// yields an `Int64` column followed by a `Utf8` column.
    fn sample_sig() -> ProcedureSignature {
        let only_arg = NamedArgType {
            name: "arg0".into(),
            ty: ArgType::Primitive(DataType::Utf8),
            default: None,
            doc: String::new(),
        };
        let yields = vec![
            Field::new("yield0", DataType::Int64, true),
            Field::new("yield1", DataType::Utf8, true),
        ];
        ProcedureSignature {
            args: vec![only_arg],
            yields,
            mode: ProcedureMode::Read,
            side_effects: SideEffects::default(),
            retry_contract: None,
            batch_input: None,
            docs: String::new(),
        }
    }

    /// The export name wraps the sanitized qname in `proc_` / `_invoke`.
    #[test]
    fn export_name_format() {
        let qname = QName::parse("myorg.scan").expect("valid");
        let export = proc_invoke_export_name(&qname);
        assert_eq!(export, "proc_myorg_scan_invoke");
    }

    /// The argument schema has one column per declared argument, with
    /// the expected name and Arrow type.
    #[test]
    fn build_args_schema_matches_named_args() {
        let schema = build_args_schema(&sample_sig());
        assert_eq!(schema.fields().len(), 1);

        let first = schema.field(0);
        assert_eq!(first.name(), "arg0");
        assert_eq!(first.data_type(), &DataType::Utf8);
    }
}
216}