Skip to main content

uni_plugin_rhai/
adapter_procedure.rs

1//! Procedure adapter — Rhai-side procedure returning a stream of yield rows.
2//!
3//! A Rhai procedure exports a single function `${name}` returning a
4//! `rhai::Array` of `rhai::Map`s (rows). Each map's keys correspond to
5//! the yield-schema field names declared in the manifest. The adapter
6//! converts the returned array into a `RecordBatch` matching the yield
7//! schema, then wraps it in a `SendableRecordBatchStream` for the host
8//! to attach to the surrounding query plan.
9
10#![cfg(feature = "rhai-runtime")]
11
12// Rust guideline compliant
13
14use std::sync::Arc;
15
16use arrow_array::{ArrayRef, RecordBatch};
17use arrow_schema::{DataType, Schema, SchemaRef};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::logical_expr::ColumnarValue;
20use rhai::{Dynamic, Map, Scope};
21use smol_str::SmolStr;
22
23use uni_plugin::adapter_common::batch_builder::batch_into_stream;
24use uni_plugin::errors::FnError;
25use uni_plugin::traits::procedure::{ProcedureContext, ProcedurePlugin, ProcedureSignature};
26
27use crate::dynamic_bridge::{OutBuilder, scalar_to_dynamic};
28use crate::runtime::RhaiPluginRuntime;
29
/// Per-procedure Rhai callable adapter.
///
/// Binds one Rhai function, looked up by `name` in the shared runtime's
/// compiled AST, to a [`ProcedureSignature`] so the host can invoke it through
/// [`ProcedurePlugin`].
#[derive(Debug)]
pub struct RhaiProcedure {
    /// Shared runtime supplying the Rhai engine and compiled AST used for calls.
    runtime: Arc<RhaiPluginRuntime>,
    /// Name of the Rhai function called on each invocation.
    name: SmolStr,
    /// Declared signature; its `yields` fields define the output batch schema.
    signature: ProcedureSignature,
}
37
38impl RhaiProcedure {
39    /// Construct a procedure adapter binding `name` against the shared
40    /// runtime.
41    #[must_use]
42    pub fn new(
43        runtime: Arc<RhaiPluginRuntime>,
44        name: impl Into<SmolStr>,
45        signature: ProcedureSignature,
46    ) -> Self {
47        Self {
48            runtime,
49            name: name.into(),
50            signature,
51        }
52    }
53}
54
55impl ProcedurePlugin for RhaiProcedure {
56    fn signature(&self) -> &ProcedureSignature {
57        &self.signature
58    }
59
60    fn invoke(
61        &self,
62        _ctx: ProcedureContext<'_>,
63        args: &[ColumnarValue],
64    ) -> Result<SendableRecordBatchStream, FnError> {
65        // Convert each ColumnarValue::Scalar arg to a single Dynamic.
66        // Array args are unsupported for procedure invocation in v1 —
67        // procedures take scalar inputs, not batched columns.
68        let mut dyn_args: Vec<Dynamic> = Vec::with_capacity(args.len());
69        for (i, arg) in args.iter().enumerate() {
70            match arg {
71                ColumnarValue::Scalar(s) => {
72                    let d = scalar_to_dynamic(s)
73                        .map_err(|e| FnError::new(0x12, format!("procedure arg {i}: {e}")))?;
74                    dyn_args.push(d);
75                }
76                ColumnarValue::Array(_) => {
77                    return Err(FnError::new(
78                        0x10,
79                        format!("procedure arg {i} must be a scalar"),
80                    ));
81                }
82            }
83        }
84
85        // Call the Rhai fn; expect an Array of Maps (rows).
86        let mut scope = Scope::new();
87        let result: Dynamic = self
88            .runtime
89            .engine
90            .call_fn(&mut scope, &self.runtime.ast, self.name.as_str(), dyn_args)
91            .map_err(|e| FnError::new(0x730, format!("Rhai procedure `{}`: {e}", self.name)))?;
92
93        let yield_schema = Arc::new(Schema::new(self.signature.yields.clone()));
94        let batch = dynamic_to_record_batch(result, &yield_schema)?;
95        Ok(batch_into_stream(batch))
96    }
97}
98
99fn dynamic_to_record_batch(d: Dynamic, schema: &SchemaRef) -> Result<RecordBatch, FnError> {
100    let rows: rhai::Array = d.try_cast().ok_or_else(|| {
101        FnError::new(
102            0x12,
103            String::from("Rhai procedure must return an array of row maps"),
104        )
105    })?;
106    let row_count = rows.len();
107
108    // Pre-build one builder per yield field.
109    let mut builders: Vec<OutBuilder> = schema
110        .fields()
111        .iter()
112        .map(|f| OutBuilder::new(f.data_type(), row_count))
113        .collect::<Result<_, _>>()
114        .map_err(|e| FnError::new(0x11, e.to_string()))?;
115
116    for (i, row) in rows.into_iter().enumerate() {
117        let m: Map = row
118            .try_cast()
119            .ok_or_else(|| FnError::new(0x12, format!("procedure row {i} must be a map")))?;
120        for (field_idx, field) in schema.fields().iter().enumerate() {
121            let key = field.name();
122            let value = m.get(key.as_str()).cloned().unwrap_or(Dynamic::UNIT);
123            // Coerce numeric types — Rhai often returns INT for fields
124            // declared as Float (and vice versa for cross-int sizes).
125            let value = coerce_for(field.data_type(), value)?;
126            builders[field_idx]
127                .push(value)
128                .map_err(|e| FnError::new(0x14, e.to_string()))?;
129        }
130    }
131
132    let columns: Vec<ArrayRef> = builders.into_iter().map(|b| b.finish()).collect();
133    RecordBatch::try_new(schema.clone(), columns)
134        .map_err(|e| FnError::new(0x15, format!("procedure batch: {e}")))
135}
136
137fn coerce_for(target: &DataType, value: Dynamic) -> Result<Dynamic, FnError> {
138    if value.is_unit() {
139        return Ok(value);
140    }
141    match target {
142        // Rhai often returns INT for a Float-declared field (and vice versa
143        // for Int-declared fields); coerce the mismatched numeric type and
144        // pass everything else through unchanged.
145        DataType::Float64 => match value.as_int() {
146            Ok(i) => Ok(Dynamic::from(i as f64)),
147            Err(_) => Ok(value),
148        },
149        DataType::Int64 => match value.as_float() {
150            Ok(f) => Ok(Dynamic::from(f as i64)),
151            Err(_) => Ok(value),
152        },
153        _ => Ok(value),
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::build_engine;
    use crate::host_fns::RhaiHostFnRegistry;
    use crate::manifest::compile;
    use arrow_schema::Field;
    use futures::StreamExt;
    use uni_plugin::capability::SideEffects;
    use uni_plugin::traits::procedure::ProcedureMode;
    use uni_plugin::{CapabilitySet, PluginId};

    /// Compiles `script` with an empty capability set and an empty host-fn
    /// registry, wrapping it in a runtime identified as `test.proc`.
    fn build_runtime(script: &str) -> Arc<RhaiPluginRuntime> {
        let capabilities = CapabilitySet::new();
        let registry = RhaiHostFnRegistry::new();
        let engine = build_engine(&capabilities, &registry);
        let ast = compile(&engine, script).unwrap();
        RhaiPluginRuntime::new(PluginId::new("test.proc"), engine, ast)
    }

    /// A no-argument, read-only signature yielding `yields`.
    fn read_only_signature(yields: Vec<Field>) -> ProcedureSignature {
        ProcedureSignature {
            args: Vec::new(),
            yields,
            mode: ProcedureMode::Read,
            side_effects: SideEffects::ReadOnly,
            retry_contract: None,
            batch_input: None,
            docs: String::new(),
        }
    }

    #[tokio::test]
    async fn procedure_emits_rows() {
        let script = r#"
            fn rows() {
                [
                    #{ id: 1, name: "alice" },
                    #{ id: 2, name: "bob" },
                    #{ id: 3, name: "carol" },
                ]
            }
        "#;
        let yields = vec![
            Field::new("id", DataType::Int64, true),
            Field::new("name", DataType::Utf8, true),
        ];
        let proc = RhaiProcedure::new(build_runtime(script), "rows", read_only_signature(yields));

        let mut stream = proc.invoke(ProcedureContext::new(), &[]).unwrap();
        let first = stream.next().await.unwrap();
        let batch = first.unwrap();

        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 2);
    }
}