// uni_plugin_rhai/adapter_procedure.rs
#![cfg(feature = "rhai-runtime")]
11
12use std::sync::Arc;
15
16use arrow_array::{ArrayRef, RecordBatch};
17use arrow_schema::{DataType, Schema, SchemaRef};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::logical_expr::ColumnarValue;
20use rhai::{Dynamic, Map, Scope};
21use smol_str::SmolStr;
22
23use uni_plugin::adapter_common::batch_builder::batch_into_stream;
24use uni_plugin::errors::FnError;
25use uni_plugin::traits::procedure::{ProcedureContext, ProcedurePlugin, ProcedureSignature};
26
27use crate::dynamic_bridge::{OutBuilder, scalar_to_dynamic};
28use crate::runtime::RhaiPluginRuntime;
29
30#[derive(Debug)]
32pub struct RhaiProcedure {
33 runtime: Arc<RhaiPluginRuntime>,
34 name: SmolStr,
35 signature: ProcedureSignature,
36}
37
38impl RhaiProcedure {
39 #[must_use]
42 pub fn new(
43 runtime: Arc<RhaiPluginRuntime>,
44 name: impl Into<SmolStr>,
45 signature: ProcedureSignature,
46 ) -> Self {
47 Self {
48 runtime,
49 name: name.into(),
50 signature,
51 }
52 }
53}
54
55impl ProcedurePlugin for RhaiProcedure {
56 fn signature(&self) -> &ProcedureSignature {
57 &self.signature
58 }
59
60 fn invoke(
61 &self,
62 _ctx: ProcedureContext<'_>,
63 args: &[ColumnarValue],
64 ) -> Result<SendableRecordBatchStream, FnError> {
65 let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(args.len());
69 for (i, arg) in args.iter().enumerate() {
70 match arg {
71 ColumnarValue::Scalar(s) => {
72 let d = scalar_to_dynamic(s)
73 .map_err(|e| FnError::new(0x12, format!("procedure arg {i}: {e}")))?;
74 dyn_args.push(d);
75 }
76 ColumnarValue::Array(_) => {
77 return Err(FnError::new(
78 0x10,
79 format!("procedure arg {i} must be a scalar"),
80 ));
81 }
82 }
83 }
84
85 let mut scope = Scope::new();
87 let result: Dynamic = self
88 .runtime
89 .engine
90 .call_fn(&mut scope, &self.runtime.ast, self.name.as_str(), dyn_args)
91 .map_err(|e| FnError::new(0x730, format!("Rhai procedure `{}`: {e}", self.name)))?;
92
93 let yield_schema = Arc::new(Schema::new(self.signature.yields.clone()));
94 let batch = dynamic_to_record_batch(result, &yield_schema)?;
95 Ok(batch_into_stream(batch))
96 }
97}
98
99fn dynamic_to_record_batch(d: Dynamic, schema: &SchemaRef) -> Result<RecordBatch, FnError> {
100 let rows: rhai::Array = d.try_cast().ok_or_else(|| {
101 FnError::new(
102 0x12,
103 String::from("Rhai procedure must return an array of row maps"),
104 )
105 })?;
106 let row_count = rows.len();
107
108 let mut builders: Vec<OutBuilder> = schema
110 .fields()
111 .iter()
112 .map(|f| OutBuilder::new(f.data_type(), row_count))
113 .collect::<Result<_, _>>()
114 .map_err(|e| FnError::new(0x11, e.to_string()))?;
115
116 for (i, row) in rows.into_iter().enumerate() {
117 let m: Map = row
118 .try_cast()
119 .ok_or_else(|| FnError::new(0x12, format!("procedure row {i} must be a map")))?;
120 for (field_idx, field) in schema.fields().iter().enumerate() {
121 let key = field.name();
122 let value = m.get(key.as_str()).cloned().unwrap_or(Dynamic::UNIT);
123 let value = coerce_for(field.data_type(), value)?;
126 builders[field_idx]
127 .push(value)
128 .map_err(|e| FnError::new(0x14, e.to_string()))?;
129 }
130 }
131
132 let columns: Vec<ArrayRef> = builders.into_iter().map(|b| b.finish()).collect();
133 RecordBatch::try_new(schema.clone(), columns)
134 .map_err(|e| FnError::new(0x15, format!("procedure batch: {e}")))
135}
136
137fn coerce_for(target: &DataType, value: Dynamic) -> Result<Dynamic, FnError> {
138 if value.is_unit() {
139 return Ok(value);
140 }
141 match target {
142 DataType::Float64 => match value.as_int() {
146 Ok(i) => Ok(Dynamic::from(i as f64)),
147 Err(_) => Ok(value),
148 },
149 DataType::Int64 => match value.as_float() {
150 Ok(f) => Ok(Dynamic::from(f as i64)),
151 Err(_) => Ok(value),
152 },
153 _ => Ok(value),
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_schema::Field;
    use futures::StreamExt;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::ProcedureMode;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Compiles `script` with an empty capability set and no host functions,
    /// and wraps the result in a runtime for the `test.proc` plugin id.
    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let capabilities = CapabilitySet::new();
        let host_fns = RhaiHostFnRegistry::new();
        let engine = build_engine(&capabilities, &host_fns);
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.proc"), engine, ast)
    }

    /// Read-only signature with no arguments that yields `id` and `name`.
    fn rows_signature() -> ProcedureSignature {
        ProcedureSignature {
            args: vec![],
            yields: vec![
                Field::new("id", DataType::Int64, true),
                Field::new("name", DataType::Utf8, true),
            ],
            mode: ProcedureMode::Read,
            side_effects: SideEffects::ReadOnly,
            retry_contract: None,
            batch_input: None,
            docs: String::new(),
        }
    }

    /// A script returning three row maps produces one batch of 3 rows and 2 columns.
    #[tokio::test]
    async fn procedure_emits_rows() {
        let script = r#"
            fn rows() {
                [
                    #{ id: 1, name: "alice" },
                    #{ id: 2, name: "bob" },
                    #{ id: 3, name: "carol" },
                ]
            }
        "#;
        let procedure = RhaiProcedure::new(build_runtime(script), "rows", rows_signature());

        let mut stream = procedure.invoke(ProcedureContext::new(), &[]).unwrap();
        let batch = stream.next().await.unwrap().unwrap();

        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 2);
    }
}