1#![cfg(feature = "rhai-runtime")]
15
16use std::sync::Arc;
17
18use uni_plugin::{
19 Capability, CapabilitySet, HttpEgress, KmsProvider, PluginError, PluginId, PluginRegistrar,
20 QName,
21};
22
23use arrow_schema::Field;
24
25use uni_plugin::capability::SideEffects;
26use uni_plugin::secrets::SecretStore;
27use uni_plugin::traits::procedure::{NamedArgType, ProcedureMode, ProcedureSignature};
28
29use crate::adapter::RhaiScalarFn;
30use crate::adapter_aggregate::{RhaiAggregateFn, build_agg_signature};
31use crate::adapter_procedure::RhaiProcedure;
32use crate::engine::build_engine;
33use crate::error::RhaiError;
34use crate::host_fns::RhaiHostFnRegistry;
35use crate::manifest::{ProcedureEntry, RhaiManifest, compile, parse_manifest};
36use crate::runtime::RhaiPluginRuntime;
37use crate::wire_translate::{build_fn_signature, type_name_to_argtype, type_name_to_datatype};
38
39#[derive(Debug)]
41pub struct LoadOutcome {
42 pub plugin_id: PluginId,
44 pub version: String,
46 pub effective_capabilities: CapabilitySet,
49 pub denied_capabilities: Vec<Capability>,
51 pub scalars_registered: Vec<String>,
53 pub aggregates_registered: Vec<String>,
55 pub procedures_registered: Vec<String>,
57 pub runtime: Arc<RhaiPluginRuntime>,
61}
62
63#[derive(Default, Clone)]
68pub struct RhaiLoader {
69 host_fns: RhaiHostFnRegistry,
70 kms: Option<Arc<dyn KmsProvider>>,
73 secrets: Option<Arc<SecretStore>>,
75 http: Option<Arc<dyn HttpEgress>>,
77}
78
79impl std::fmt::Debug for RhaiLoader {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("RhaiLoader")
82 .field("host_fn_count", &self.host_fns.len())
83 .finish()
84 }
85}
86
87impl RhaiLoader {
88 #[must_use]
90 pub fn new() -> Self {
91 Self::default()
92 }
93
94 pub fn host_fns_mut(&mut self) -> &mut RhaiHostFnRegistry {
97 &mut self.host_fns
98 }
99
100 #[must_use]
102 pub fn host_fns(&self) -> &RhaiHostFnRegistry {
103 &self.host_fns
104 }
105
106 #[must_use]
108 pub fn host_fn_count(&self) -> usize {
109 self.host_fns.len()
110 }
111
112 #[must_use]
114 pub fn with_kms(mut self, kms: Arc<dyn KmsProvider>) -> Self {
115 self.kms = Some(kms);
116 self
117 }
118
119 #[must_use]
121 pub fn with_secret_store(mut self, store: Arc<SecretStore>) -> Self {
122 self.secrets = Some(store);
123 self
124 }
125
126 #[must_use]
128 pub fn with_http(mut self, http: Arc<dyn HttpEgress>) -> Self {
129 self.http = Some(http);
130 self
131 }
132
133 #[must_use]
135 pub fn kms(&self) -> Option<Arc<dyn KmsProvider>> {
136 self.kms.clone()
137 }
138
139 #[must_use]
141 pub fn secret_store(&self) -> Option<Arc<SecretStore>> {
142 self.secrets.clone()
143 }
144
145 #[must_use]
147 pub fn http(&self) -> Option<Arc<dyn HttpEgress>> {
148 self.http.clone()
149 }
150
151 pub fn load(
163 &self,
164 script: &str,
165 registrar: &mut PluginRegistrar<'_>,
166 registrar_caps: &CapabilitySet,
167 ) -> Result<LoadOutcome, RhaiError> {
168 let engine = build_engine(registrar_caps, &self.host_fns);
187 let ast = compile(&engine, script)?;
188 let manifest = parse_manifest(&engine, &ast)?;
189
190 let plugin_id = PluginId::new(manifest.id.clone());
191
192 let declared = derive_declared_capabilities(&manifest);
199 let (effective, denied) = intersect_caps(&declared, registrar_caps);
200
201 let runtime = RhaiPluginRuntime::new(plugin_id.clone(), engine, ast);
202
203 registrar.set_plugin_id(plugin_id.clone());
205
206 let mut scalars_registered = Vec::with_capacity(manifest.scalar_fns.len());
211 if effective.contains(&Capability::ScalarFn) {
212 for entry in &manifest.scalar_fns {
213 let sig = build_fn_signature(&entry.args, &entry.returns, &manifest.determinism)?;
214 let qname = QName::new(plugin_id.as_str(), entry.name.clone());
215 let adapter = if entry.vectorized {
216 RhaiScalarFn::new_vectorized(
217 Arc::clone(&runtime),
218 entry.name.clone(),
219 sig.clone(),
220 )
221 } else {
222 RhaiScalarFn::new(Arc::clone(&runtime), entry.name.clone(), sig.clone())
223 };
224 registrar
225 .scalar_fn(qname.clone(), sig, Arc::new(adapter))
226 .map_err(plugin_to_rhai_err)?;
227 scalars_registered.push(qname.to_string());
228 }
229 }
230
231 let mut aggregates_registered = Vec::with_capacity(manifest.aggregate_fns.len());
232 if effective.contains(&Capability::AggregateFn) {
233 for entry in &manifest.aggregate_fns {
234 let sig = build_agg_signature(&entry.args, &entry.returns, &manifest.determinism)?;
235 let qname = QName::new(plugin_id.as_str(), entry.name.clone());
236 let adapter =
237 RhaiAggregateFn::new(Arc::clone(&runtime), entry.name.clone(), sig.clone());
238 registrar
239 .aggregate_fn(qname.clone(), sig, Arc::new(adapter))
240 .map_err(plugin_to_rhai_err)?;
241 aggregates_registered.push(qname.to_string());
242 }
243 }
244
245 let mut procedures_registered = Vec::with_capacity(manifest.procedures.len());
246 if effective.contains(&Capability::Procedure) {
247 for entry in &manifest.procedures {
248 let sig = build_procedure_signature(entry)?;
249 let qname = QName::new(plugin_id.as_str(), entry.name.clone());
250 let adapter =
251 RhaiProcedure::new(Arc::clone(&runtime), entry.name.clone(), sig.clone());
252 registrar
253 .procedure(qname.clone(), sig, Arc::new(adapter))
254 .map_err(plugin_to_rhai_err)?;
255 procedures_registered.push(qname.to_string());
256 }
257 }
258
259 Ok(LoadOutcome {
260 plugin_id,
261 version: manifest.version,
262 effective_capabilities: effective,
263 denied_capabilities: denied,
264 scalars_registered,
265 aggregates_registered,
266 procedures_registered,
267 runtime,
268 })
269 }
270}
271
272fn build_procedure_signature(entry: &ProcedureEntry) -> Result<ProcedureSignature, RhaiError> {
273 let args: Vec<NamedArgType> = entry
274 .args
275 .iter()
276 .enumerate()
277 .map(|(i, t)| {
278 let ty = type_name_to_argtype(t)?;
279 Ok(NamedArgType {
280 name: format!("arg{i}").into(),
281 ty,
282 default: None,
283 doc: String::new(),
284 })
285 })
286 .collect::<Result<_, RhaiError>>()?;
287
288 let yields: Vec<Field> = entry
289 .yields
290 .iter()
291 .enumerate()
292 .map(|(i, t)| {
293 let dt = type_name_to_datatype(t)?;
294 Ok(Field::new(format!("col{i}"), dt, true))
295 })
296 .collect::<Result<_, RhaiError>>()?;
297
298 let mode = match entry.mode.trim().to_ascii_lowercase().as_str() {
299 "write" => ProcedureMode::Write,
300 "schema" => ProcedureMode::Schema,
301 "dbms" => ProcedureMode::Dbms,
302 _ => ProcedureMode::Read,
303 };
304 let side_effects = match mode {
305 ProcedureMode::Read => SideEffects::ReadOnly,
306 _ => SideEffects::Writes,
307 };
308
309 Ok(ProcedureSignature {
310 args,
311 yields,
312 mode,
313 side_effects,
314 retry_contract: None,
315 batch_input: None,
316 docs: String::new(),
317 })
318}
319
320fn derive_declared_capabilities(m: &RhaiManifest) -> CapabilitySet {
321 let mut set = CapabilitySet::new();
322 if !m.scalar_fns.is_empty() {
323 set.insert(Capability::ScalarFn);
324 }
325 if !m.aggregate_fns.is_empty() {
326 set.insert(Capability::AggregateFn);
327 }
328 if !m.procedures.is_empty() {
329 set.insert(Capability::Procedure);
330 }
331 set
332}
333
334fn intersect_caps(
335 declared: &CapabilitySet,
336 granted: &CapabilitySet,
337) -> (CapabilitySet, Vec<Capability>) {
338 let effective = declared.intersect(granted);
339 let denied: Vec<Capability> = declared
340 .iter()
341 .filter(|c| !granted.contains(c))
342 .cloned()
343 .collect();
344 (effective, denied)
345}
346
347fn plugin_to_rhai_err(e: PluginError) -> RhaiError {
348 match e {
349 PluginError::DuplicateRegistration(q) => {
350 RhaiError::ManifestInvalid(format!("duplicate registration: {q}"))
351 }
352 PluginError::CapabilityRequired(c) => {
353 RhaiError::ManifestInvalid(format!("registrar caps missing: {c:?}"))
354 }
355 other => RhaiError::Internal(format!("registrar: {other}")),
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use uni_plugin::PluginRegistry;

    /// Default loader plus a capability set granting scalar, aggregate, and
    /// procedure registration.
    fn loader_with_caps() -> (RhaiLoader, CapabilitySet) {
        let loader = RhaiLoader::new();
        let caps = CapabilitySet::from_iter_of([
            Capability::ScalarFn,
            Capability::AggregateFn,
            Capability::Procedure,
        ]);
        (loader, caps)
    }

    /// A manifest with one scalar function loads, registers it under the
    /// manifest id, and is visible in the registry after commit.
    #[test]
    fn loads_minimal_scalar_plugin() {
        let script = r#"
            fn uni_manifest() {
                #{
                    id: "ai.test.scalar",
                    version: "0.1.0",
                    scalar_fns: [
                        #{ name: "double", args: ["float"], returns: "float" },
                    ],
                }
            }
            fn double(x) { x * 2.0 }
        "#;
        let (loader, caps) = loader_with_caps();
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let outcome = loader.load(script, &mut r, &caps).expect("loads");
        assert_eq!(outcome.plugin_id.as_str(), "ai.test.scalar");
        assert_eq!(outcome.scalars_registered.len(), 1);
        assert!(outcome.denied_capabilities.is_empty());
        r.commit_to_registry().expect("commits");
        // The function is registered under the manifest id, not the
        // registrar's initial "rhai.loading" id.
        let q = QName::new("ai.test.scalar", "double");
        assert!(registry.scalar_fn(&q).is_some());
    }

    /// Capabilities the manifest implies but the caller did not grant are
    /// reported as denied, while granted ones still register.
    #[test]
    fn declared_but_not_granted_caps_show_as_denied() {
        let script = r#"
            fn uni_manifest() {
                #{
                    id: "ai.test.denied",
                    version: "0.1.0",
                    scalar_fns: [
                        #{ name: "noop", args: [], returns: "int" },
                    ],
                    aggregate_fns: [
                        #{ name: "agg", args: ["float"], returns: "float", state: "map" },
                    ],
                }
            }
            fn noop() { 0 }
        "#;
        let loader = RhaiLoader::new();
        // Only ScalarFn is granted; AggregateFn must end up denied.
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let outcome = loader.load(script, &mut r, &caps).expect("loads");
        assert!(
            outcome
                .denied_capabilities
                .contains(&Capability::AggregateFn)
        );
        assert_eq!(outcome.scalars_registered.len(), 1);
    }

    /// A script that is not valid Rhai surfaces as `RhaiError::ParseFailed`.
    #[test]
    fn parse_failure_returns_parse_error() {
        let script = r#"this is not valid rhai @@@"#;
        let (loader, caps) = loader_with_caps();
        let registry = PluginRegistry::new();
        let mut r = PluginRegistrar::new(PluginId::new("rhai.loading"), &caps, &registry);
        let err = loader.load(script, &mut r, &caps).unwrap_err();
        assert!(matches!(err, RhaiError::ParseFailed(_)));
    }
}