uni_plugin_wasm/
adapter.rs1use std::sync::Arc;
11
12use arrow::array::{ArrayRef, RecordBatch};
13use arrow_schema::{Field, Schema, SchemaRef};
14use datafusion::logical_expr::ColumnarValue;
15use uni_plugin::QName;
16use uni_plugin::errors::FnError;
17use uni_plugin::traits::scalar::{FnSignature, ScalarPluginFn};
18use uni_plugin_wasm_rt::ipc::{decode_batch, encode_batch};
19
20use crate::adapter_common::{acquire, ipc_to_fn_err};
21use crate::loader::ScalarPluginInstance;
22use crate::pool::WasmInstancePool;
23
24pub struct ComponentScalarFn {
28 pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
29 qname: QName,
30 sig: FnSignature,
31}
32
33impl std::fmt::Debug for ComponentScalarFn {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ComponentScalarFn")
36 .field("qname", &self.qname)
37 .field("signature", &self.sig)
38 .finish_non_exhaustive()
39 }
40}
41
42impl ComponentScalarFn {
43 #[must_use]
45 pub fn new(
46 pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
47 qname: QName,
48 sig: FnSignature,
49 ) -> Self {
50 Self { pool, qname, sig }
51 }
52
53 fn args_to_batch(&self, args: &[ColumnarValue], rows: usize) -> Result<RecordBatch, FnError> {
54 let arrays: Vec<ArrayRef> = args
55 .iter()
56 .map(|c| {
57 c.clone().into_array(rows).map_err(|e| {
58 FnError::new(
59 FnError::CODE_TYPE_COERCION,
60 format!("ColumnarValue::into_array: {e}"),
61 )
62 })
63 })
64 .collect::<Result<_, _>>()?;
65 let fields: Vec<Field> = arrays
66 .iter()
67 .enumerate()
68 .map(|(i, a)| Field::new(format!("arg{i}"), a.data_type().clone(), true))
69 .collect();
70 let schema: SchemaRef = Arc::new(Schema::new(fields));
71 RecordBatch::try_new(schema, arrays).map_err(|e| {
72 FnError::new(
73 FnError::CODE_TYPE_COERCION,
74 format!("RecordBatch assembly: {e}"),
75 )
76 })
77 }
78}
79
80impl ScalarPluginFn for ComponentScalarFn {
81 fn signature(&self) -> &FnSignature {
82 &self.sig
83 }
84
85 fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
86 let batch = self.args_to_batch(args, rows)?;
87 let bytes = encode_batch(&batch).map_err(ipc_to_fn_err)?;
88
89 let mut leased = acquire(&self.pool, "plugin")?;
90 let qname_str = self.qname.to_string();
91 let out_bytes: Vec<u8> =
92 leased
93 .get_mut()
94 .invoke_scalar(&qname_str, &bytes)
95 .map_err(|e| {
96 FnError::new(
97 FnError::CODE_UNEXPECTED_NULL,
98 format!("wasm component invoke_scalar `{qname_str}`: {e}"),
99 )
100 })?;
101 drop(leased);
102
103 let out_batch = decode_batch(&out_bytes)
104 .map_err(ipc_to_fn_err)?
105 .ok_or_else(|| {
106 FnError::new(
107 FnError::CODE_UNEXPECTED_NULL,
108 format!("wasm component `{qname_str}` returned empty IPC stream"),
109 )
110 })?;
111
112 if out_batch.num_columns() != 1 {
113 return Err(FnError::new(
114 FnError::CODE_TYPE_COERCION,
115 format!(
116 "wasm component `{qname_str}` returned {} columns; scalar fns must return 1",
117 out_batch.num_columns()
118 ),
119 ));
120 }
121 Ok(ColumnarValue::Array(out_batch.column(0).clone()))
122 }
123}