uni_plugin_extism/
adapter_procedure.rs1use std::sync::Arc;
17
18use arrow::array::RecordBatch;
19use arrow_schema::{Field, Schema, SchemaRef};
20use datafusion::execution::SendableRecordBatchStream;
21use datafusion::logical_expr::ColumnarValue;
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use futures::stream;
24use uni_plugin::QName;
25use uni_plugin::adapter_common::arrow_types::argtype_to_arrow;
26use uni_plugin::errors::FnError;
27use uni_plugin::traits::procedure::{ProcedureContext, ProcedurePlugin, ProcedureSignature};
28
29use crate::adapter_common::{acquire, extism_err_to_fn_err, sanitize_qname};
30use crate::ipc::{decode_batches, encode_batch};
31use crate::pool::ExtismInstancePool;
32
33#[must_use]
39pub(crate) fn proc_invoke_export_name(qname: &QName) -> String {
40 format!("proc_{}_invoke", sanitize_qname(qname))
41}
42
43pub struct ExtismProcedure {
45 pool: Arc<ExtismInstancePool<extism::Plugin>>,
46 qname: QName,
47 invoke_export: String,
48 sig: ProcedureSignature,
49 args_schema: SchemaRef,
50 yields_schema: SchemaRef,
51}
52
53impl std::fmt::Debug for ExtismProcedure {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("ExtismProcedure")
56 .field("qname", &self.qname)
57 .field("signature", &self.sig)
58 .finish_non_exhaustive()
59 }
60}
61
62impl ExtismProcedure {
63 #[must_use]
65 pub fn new(
66 pool: Arc<ExtismInstancePool<extism::Plugin>>,
67 qname: QName,
68 sig: ProcedureSignature,
69 ) -> Self {
70 let invoke_export = proc_invoke_export_name(&qname);
71 let args_schema = build_args_schema(&sig);
72 let yields_schema = Arc::new(Schema::new(sig.yields.clone()));
73 Self {
74 pool,
75 qname,
76 invoke_export,
77 sig,
78 args_schema,
79 yields_schema,
80 }
81 }
82}
83
84impl ProcedurePlugin for ExtismProcedure {
85 fn signature(&self) -> &ProcedureSignature {
86 &self.sig
87 }
88
89 fn invoke(
90 &self,
91 _ctx: ProcedureContext<'_>,
92 args: &[ColumnarValue],
93 ) -> Result<SendableRecordBatchStream, FnError> {
94 let arrays: Vec<arrow::array::ArrayRef> = args
98 .iter()
99 .map(|c| {
100 c.clone().into_array(1).map_err(|e| {
101 FnError::new(
102 FnError::CODE_TYPE_COERCION,
103 format!("ColumnarValue::into_array: {e}"),
104 )
105 })
106 })
107 .collect::<Result<_, _>>()?;
108 if arrays.len() != self.args_schema.fields().len() {
109 return Err(FnError::new(
110 FnError::CODE_TYPE_COERCION,
111 format!(
112 "procedure `{}` expected {} args; got {}",
113 self.qname,
114 self.args_schema.fields().len(),
115 arrays.len()
116 ),
117 ));
118 }
119 let args_batch =
120 RecordBatch::try_new(Arc::clone(&self.args_schema), arrays).map_err(|e| {
121 FnError::new(
122 FnError::CODE_TYPE_COERCION,
123 format!("procedure `{}` args RecordBatch: {e}", self.qname),
124 )
125 })?;
126 let ipc = encode_batch(&args_batch).map_err(extism_err_to_fn_err)?;
127
128 let mut leased = acquire(&self.pool)?;
129 let out_bytes: Vec<u8> = leased
130 .get_mut()
131 .call::<&[u8], &[u8]>(&self.invoke_export, &ipc)
132 .map_err(|e| {
133 FnError::new(
134 FnError::CODE_UNEXPECTED_NULL,
135 format!("extism call `{}` failed: {e}", self.invoke_export),
136 )
137 })?
138 .to_vec();
139 drop(leased);
140
141 let batches = decode_batches(&out_bytes).map_err(extism_err_to_fn_err)?;
142
143 for (i, b) in batches.iter().enumerate() {
145 if b.schema().fields() != self.yields_schema.fields() {
146 return Err(FnError::new(
147 FnError::CODE_TYPE_COERCION,
148 format!(
149 "procedure `{}` batch[{i}] schema mismatch: got {:?}, expected {:?}",
150 self.qname,
151 b.schema().fields(),
152 self.yields_schema.fields()
153 ),
154 ));
155 }
156 }
157
158 let schema = Arc::clone(&self.yields_schema);
159 let stream = stream::iter(batches.into_iter().map(Ok));
160 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
161 }
162}
163
164fn build_args_schema(sig: &ProcedureSignature) -> SchemaRef {
165 let fields: Vec<Field> = sig
166 .args
167 .iter()
168 .enumerate()
169 .map(|(i, a)| Field::new(format!("arg{i}"), argtype_to_arrow(&a.ty), true))
170 .collect();
171 Arc::new(Schema::new(fields))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::DataType;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::{NamedArgType, ProcedureMode};
    use uni_plugin::traits::scalar::ArgType;

    /// Signature fixture: one `Utf8` argument and two nullable yield columns
    /// (`Int64` and `Utf8`).
    fn sample_sig() -> ProcedureSignature {
        let only_arg = NamedArgType {
            name: "arg0".into(),
            ty: ArgType::Primitive(DataType::Utf8),
            default: None,
            doc: String::new(),
        };
        let yields = vec![
            Field::new("yield0", DataType::Int64, true),
            Field::new("yield1", DataType::Utf8, true),
        ];
        ProcedureSignature {
            args: vec![only_arg],
            yields,
            mode: ProcedureMode::Read,
            side_effects: SideEffects::default(),
            retry_contract: None,
            batch_input: None,
            docs: String::new(),
        }
    }

    /// The export name is `proc_<sanitized qname>_invoke`.
    #[test]
    fn export_name_format() {
        let qname = QName::parse("myorg.scan").expect("valid");
        let export = proc_invoke_export_name(&qname);
        assert_eq!(export, "proc_myorg_scan_invoke");
    }

    /// The argument schema has one field per signature argument with the mapped type.
    #[test]
    fn build_args_schema_matches_named_args() {
        let schema = build_args_schema(&sample_sig());
        assert_eq!(schema.fields().len(), 1);

        let first = schema.field(0);
        assert_eq!(first.name(), "arg0");
        assert_eq!(first.data_type(), &DataType::Utf8);
    }
}