1use std::sync::Arc;
23
24use serde::Deserialize;
25use wasmtime::component::{Component, Linker};
26use wasmtime::{Config, Engine, Store};
27
28use crate::adapter::ComponentScalarFn;
29use crate::adapter_aggregate::ComponentAggregateFn;
30use crate::adapter_procedure::ComponentProcedure;
31use crate::bindings::aggregate::AggregatePlugin;
32use crate::bindings::procedure::ProcedurePlugin as ProcedurePluginBindings;
33use crate::bindings::scalar::ScalarPlugin;
34use crate::error::WasmError;
35use crate::host_state::HostState;
36use crate::pool::WasmInstancePool;
37
38#[derive(Debug, Clone, Deserialize)]
42#[serde(deny_unknown_fields)]
43pub struct ComponentManifest {
44 pub id: String,
46 pub version: String,
48 #[serde(default)]
50 pub abi: Option<String>,
51 #[serde(default)]
55 pub capabilities: Vec<uni_plugin::ManifestCapability>,
56 #[serde(default)]
58 pub determinism: Option<String>,
59 #[serde(default)]
61 pub description: Option<String>,
62 #[serde(default)]
64 pub fuel_per_call: Option<u64>,
65 #[serde(default)]
67 pub memory_max_pages: Option<u32>,
68 #[serde(default)]
70 pub timeout_ms: Option<u64>,
71}
72
73impl ComponentManifest {
74 #[must_use]
76 pub fn declared_capability_set(&self) -> uni_plugin::CapabilitySet {
77 uni_plugin::CapabilitySet::from_manifest(self.capabilities.iter().cloned())
78 }
79}
80
81#[derive(Debug, Clone, Deserialize)]
83#[serde(deny_unknown_fields)]
84pub struct WireFnSignature {
85 pub args: Vec<WireArgType>,
87 pub returns: WireArgType,
89 #[serde(default = "default_volatility")]
91 pub volatility: String,
92 #[serde(default = "default_null_handling")]
94 pub null_handling: String,
95}
96
/// Serde default for `WireFnSignature::volatility`: `"immutable"`.
fn default_volatility() -> String {
    String::from("immutable")
}
/// Serde default for `WireFnSignature::null_handling`: `"propagate"`.
fn default_null_handling() -> String {
    String::from("propagate")
}
/// Serde default for `RegistrationEntry::Procedure::mode`: `"read"`.
fn default_proc_mode() -> String {
    String::from("read")
}
106
107#[derive(Debug, Clone, Deserialize)]
109#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
110pub enum WireArgType {
111 Primitive {
113 arrow: String,
115 },
116 CypherValue,
118}
119
120#[derive(Debug, Clone, Deserialize)]
122#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
123pub enum RegistrationEntry {
124 Scalar {
126 qname: String,
128 signature: WireFnSignature,
130 },
131 Aggregate {
133 qname: String,
135 signature: WireFnSignature,
137 state: WireArgType,
140 },
141 Procedure {
143 qname: String,
145 args: Vec<WireArgType>,
147 yields: Vec<WireArgType>,
149 #[serde(default = "default_proc_mode")]
151 mode: String,
152 },
153}
154
155#[derive(Debug, Clone, Deserialize)]
157#[serde(deny_unknown_fields)]
158pub struct RegistrationManifest {
159 pub entries: Vec<RegistrationEntry>,
161}
162
163#[derive(Clone)]
166pub struct PreparedComponent {
167 pub manifest: ComponentManifest,
169 pub effective: uni_plugin::CapabilitySet,
173 pub denied_capabilities: Vec<String>,
175 pub http: Option<Arc<dyn uni_plugin::HttpEgress>>,
178}
179
180impl std::fmt::Debug for PreparedComponent {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("PreparedComponent")
183 .field("manifest", &self.manifest)
184 .field("effective", &self.effective)
185 .field("denied_capabilities", &self.denied_capabilities)
186 .field("http", &self.http.is_some())
187 .finish()
188 }
189}
190
191pub struct ScalarPluginInstance {
196 store: Store<HostState>,
197 bindings: ScalarPlugin,
198}
199
200impl std::fmt::Debug for ScalarPluginInstance {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("ScalarPluginInstance")
203 .finish_non_exhaustive()
204 }
205}
206
/// Uniform read-only view over the per-world `FnError` binding types, so
/// `map_call` can format any of them the same way.
trait WasmCallErr {
    /// Numeric error code reported by the guest.
    fn code(&self) -> u32;
    /// Error message reported by the guest.
    fn message(&self) -> &str;
    /// Whether the guest flagged the error as retryable.
    fn retryable(&self) -> bool;
}
215
216macro_rules! impl_wasm_call_err {
217 ($ty:ty) => {
218 impl WasmCallErr for $ty {
219 fn code(&self) -> u32 {
220 self.code
221 }
222 fn message(&self) -> &str {
223 &self.message
224 }
225 fn retryable(&self) -> bool {
226 self.retryable
227 }
228 }
229 };
230}
231impl_wasm_call_err!(crate::bindings::scalar::FnError);
232impl_wasm_call_err!(crate::bindings::aggregate::FnError);
233impl_wasm_call_err!(crate::bindings::procedure::FnError);
234
235fn map_call<E: WasmCallErr>(
241 label: &str,
242 result: Result<Result<Vec<u8>, E>, wasmtime::Error>,
243) -> Result<Vec<u8>, WasmError> {
244 match result {
245 Ok(Ok(bytes)) => Ok(bytes),
246 Ok(Err(fn_err)) => Err(WasmError::Invoke(format!(
247 "{label} fn-error code={} retryable={}: {}",
248 fn_err.code(),
249 fn_err.retryable(),
250 fn_err.message()
251 ))),
252 Err(e) => Err(WasmError::Invoke(format!("{label} trap: {e}"))),
253 }
254}
255
256impl ScalarPluginInstance {
257 pub fn invoke_scalar(&mut self, qname: &str, ipc: &[u8]) -> Result<Vec<u8>, WasmError> {
264 let result = self
265 .bindings
266 .call_invoke_scalar(&mut self.store, qname, ipc);
267 map_call("invoke-scalar", result)
268 }
269
270 fn read_manifest(&mut self) -> Result<ComponentManifest, WasmError> {
272 let s = self
273 .bindings
274 .call_manifest(&mut self.store)
275 .map_err(|e| WasmError::Instantiate(format!("call manifest: {e}")))?;
276 serde_json::from_str(&s)
277 .map_err(|e| WasmError::InvalidWasm(format!("manifest json parse: {e}")))
278 }
279
280 fn read_register(&mut self) -> Result<RegistrationManifest, WasmError> {
282 let s = self
283 .bindings
284 .call_register(&mut self.store)
285 .map_err(|e| WasmError::Instantiate(format!("call register: {e}")))?;
286 serde_json::from_str(&s)
287 .map_err(|e| WasmError::InvalidWasm(format!("register json parse: {e}")))
288 }
289}
290
291pub struct AggregatePluginInstance {
293 store: Store<HostState>,
294 bindings: AggregatePlugin,
295}
296
297impl std::fmt::Debug for AggregatePluginInstance {
298 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299 f.debug_struct("AggregatePluginInstance")
300 .finish_non_exhaustive()
301 }
302}
303
304impl AggregatePluginInstance {
305 pub fn agg_new(&mut self, qname: &str) -> Result<Vec<u8>, WasmError> {
307 map_call(
308 "agg-new",
309 self.bindings.call_agg_new(&mut self.store, qname),
310 )
311 }
312
313 pub fn agg_update(
315 &mut self,
316 qname: &str,
317 state: &[u8],
318 values_ipc: &[u8],
319 ) -> Result<Vec<u8>, WasmError> {
320 map_call(
321 "agg-update",
322 self.bindings
323 .call_agg_update(&mut self.store, qname, state, values_ipc),
324 )
325 }
326
327 pub fn agg_merge(
329 &mut self,
330 qname: &str,
331 state: &[u8],
332 other_states_ipc: &[u8],
333 ) -> Result<Vec<u8>, WasmError> {
334 map_call(
335 "agg-merge",
336 self.bindings
337 .call_agg_merge(&mut self.store, qname, state, other_states_ipc),
338 )
339 }
340
341 pub fn agg_evaluate(&mut self, qname: &str, state: &[u8]) -> Result<Vec<u8>, WasmError> {
343 map_call(
344 "agg-evaluate",
345 self.bindings
346 .call_agg_evaluate(&mut self.store, qname, state),
347 )
348 }
349}
350
351pub struct ProcedurePluginInstance {
353 store: Store<HostState>,
354 bindings: ProcedurePluginBindings,
355}
356
357impl std::fmt::Debug for ProcedurePluginInstance {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 f.debug_struct("ProcedurePluginInstance")
360 .finish_non_exhaustive()
361 }
362}
363
364impl ProcedurePluginInstance {
365 pub fn invoke_procedure(&mut self, qname: &str, args_ipc: &[u8]) -> Result<Vec<u8>, WasmError> {
367 map_call(
368 "invoke-procedure",
369 self.bindings
370 .call_invoke_procedure(&mut self.store, qname, args_ipc),
371 )
372 }
373}
374
375#[derive(Default)]
377pub struct WasmLoader {
378 http: Option<Arc<dyn uni_plugin::HttpEgress>>,
382}
383
384impl std::fmt::Debug for WasmLoader {
385 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 f.debug_struct("WasmLoader")
387 .field("http", &self.http.is_some())
388 .finish()
389 }
390}
391
392impl WasmLoader {
393 #[must_use]
395 pub fn new() -> Self {
396 Self::default()
397 }
398
399 #[must_use]
401 pub fn with_http(mut self, http: Arc<dyn uni_plugin::HttpEgress>) -> Self {
402 self.http = Some(http);
403 self
404 }
405
406 fn bootstrap_prepared(&self, host_grants: &uni_plugin::CapabilitySet) -> PreparedComponent {
417 PreparedComponent {
418 manifest: ComponentManifest {
419 id: String::new(),
420 version: String::new(),
421 abi: None,
422 capabilities: Vec::new(),
423 determinism: None,
424 description: None,
425 fuel_per_call: None,
426 memory_max_pages: None,
427 timeout_ms: None,
428 },
429 effective: host_grants.clone(),
430 denied_capabilities: Vec::new(),
431 http: self.http.clone(),
432 }
433 }
434
435 pub fn prepare(
442 &self,
443 manifest_json: &[u8],
444 grants: &uni_plugin::CapabilitySet,
445 ) -> Result<PreparedComponent, WasmError> {
446 let manifest: ComponentManifest = serde_json::from_slice(manifest_json)
447 .map_err(|e| WasmError::InvalidWasm(format!("manifest json parse: {e}")))?;
448 Ok(self.prepare_parsed(manifest, grants))
449 }
450
451 pub fn prepare_parsed(
462 &self,
463 manifest: ComponentManifest,
464 grants: &uni_plugin::CapabilitySet,
465 ) -> PreparedComponent {
466 let declared = manifest.declared_capability_set();
467 let effective = declared.intersect(grants);
469 let denied: Vec<String> = declared
471 .iter()
472 .filter(|c| !effective.contains_variant(c))
473 .map(|c| format!("{c:?}"))
474 .collect();
475 PreparedComponent {
476 manifest,
477 effective,
478 denied_capabilities: denied,
479 http: self.http.clone(),
480 }
481 }
482
483 pub fn instantiate(
493 &self,
494 bytes: &[u8],
495 prepared: &PreparedComponent,
496 ) -> Result<ScalarPluginInstance, WasmError> {
497 let engine = build_engine(&prepared.manifest)?;
498 let component = Component::from_binary(&engine, bytes)
499 .map_err(|e| WasmError::InvalidWasm(format!("component compile: {e}")))?;
500 let linker: Linker<HostState> =
501 select_linker_for_manifest(&engine, &prepared.manifest, &prepared.effective)?;
502 let mut store = Store::new(
503 &engine,
504 HostState::new(prepared.effective.clone(), prepared.http.clone()),
505 );
506 apply_resource_limits(&mut store, &prepared.manifest);
507 let bindings = ScalarPlugin::instantiate(&mut store, &component, &linker)
508 .map_err(|e| WasmError::Instantiate(format!("scalar-plugin instantiate: {e}")))?;
509 Ok(ScalarPluginInstance { store, bindings })
510 }
511
512 pub fn load(
520 &self,
521 bytes: &[u8],
522 host_grants: &uni_plugin::CapabilitySet,
523 registrar: &mut uni_plugin::PluginRegistrar<'_>,
524 ) -> Result<LoadOutcome, WasmError> {
525 let bootstrap = self.bootstrap_prepared(host_grants);
528 let mut bootstrap_inst = self.instantiate(bytes, &bootstrap)?;
529 let parsed_manifest = bootstrap_inst.read_manifest()?;
530 drop(bootstrap_inst);
531
532 registrar.set_plugin_id(uni_plugin::PluginId::new(parsed_manifest.id.clone()));
536
537 let prepared = self.prepare_parsed(parsed_manifest, host_grants);
542
543 let pool = build_scalar_pool(bytes, &prepared)?;
545
546 let registration = {
548 let mut leased = crate::pool::PooledInstance::acquire(Arc::clone(&pool))
549 .map_err(|e| WasmError::Instantiate(format!("acquire warm instance: {e}")))?;
550 let r = leased.get_mut().read_register()?;
551 drop(leased);
552 r
553 };
554
555 let names = apply_registration(bytes, &prepared, &pool, registration, registrar)?;
556
557 Ok(LoadOutcome {
558 plugin_id: prepared.manifest.id.clone(),
559 version: prepared.manifest.version.clone(),
560 effective_capabilities: capability_names(&prepared.effective),
561 denied_capabilities: prepared.denied_capabilities,
562 scalars_registered: names.scalars,
563 aggregates_registered: names.aggregates,
564 procedures_registered: names.procedures,
565 pool,
566 })
567 }
568
569 pub fn load_as_plugin(
590 &self,
591 bytes: &[u8],
592 host_grants: &uni_plugin::CapabilitySet,
593 ) -> Result<Box<dyn uni_plugin::Plugin + Send + Sync>, WasmError> {
594 let bootstrap = self.bootstrap_prepared(host_grants);
597 let mut bootstrap_inst = self.instantiate(bytes, &bootstrap)?;
598 let parsed_manifest = bootstrap_inst.read_manifest()?;
599 drop(bootstrap_inst);
600
601 let prepared = self.prepare_parsed(parsed_manifest, host_grants);
603 let scalar_pool = build_scalar_pool(bytes, &prepared)?;
604 let registration = {
605 let mut leased = crate::pool::PooledInstance::acquire(Arc::clone(&scalar_pool))
606 .map_err(|e| WasmError::Instantiate(format!("acquire warm instance: {e}")))?;
607 let r = leased.get_mut().read_register()?;
608 drop(leased);
609 r
610 };
611 let manifest = synthesize_plugin_manifest(&prepared.manifest, ®istration)?;
612 Ok(Box::new(ComponentPlugin {
613 manifest,
614 bytes: bytes.to_vec(),
615 prepared,
616 scalar_pool,
617 registration,
618 }))
619 }
620}
621
622pub struct LoadOutcome {
624 pub plugin_id: String,
626 pub version: String,
628 pub effective_capabilities: Vec<String>,
630 pub denied_capabilities: Vec<String>,
632 pub scalars_registered: Vec<String>,
634 pub aggregates_registered: Vec<String>,
636 pub procedures_registered: Vec<String>,
638 pub pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
640}
641
642impl std::fmt::Debug for LoadOutcome {
643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644 f.debug_struct("LoadOutcome")
645 .field("plugin_id", &self.plugin_id)
646 .field("version", &self.version)
647 .field("effective_capabilities", &self.effective_capabilities)
648 .field("denied_capabilities", &self.denied_capabilities)
649 .field("scalars_registered", &self.scalars_registered)
650 .field("aggregates_registered", &self.aggregates_registered)
651 .field("procedures_registered", &self.procedures_registered)
652 .finish_non_exhaustive()
653 }
654}
655
656fn capability_names(caps: &uni_plugin::CapabilitySet) -> Vec<String> {
659 caps.iter().map(|c| format!("{c:?}")).collect()
660}
661
662fn select_linker_for_manifest(
669 engine: &Engine,
670 manifest: &ComponentManifest,
671 effective_caps: &uni_plugin::CapabilitySet,
672) -> Result<Linker<HostState>, WasmError> {
673 use crate::linker::{build_scalar_linker_v1, build_scalar_linker_v2};
674 use crate::multi_version::{SUPPORTED_MAJORS, major_for_abi};
675
676 let Some(abi_str) = manifest.abi.as_deref() else {
677 return build_scalar_linker_v1(engine, effective_caps);
678 };
679 let abi = uni_plugin::AbiRange::parse(abi_str)
680 .map_err(|e| WasmError::InvalidWasm(format!("manifest abi parse: {e}")))?;
681 match major_for_abi(&abi)? {
682 1 => build_scalar_linker_v1(engine, effective_caps),
683 2 => build_scalar_linker_v2(engine, effective_caps),
684 _ => Err(WasmError::AbiUnsupported {
685 requested: abi_str.to_owned(),
686 supported: SUPPORTED_MAJORS.to_vec(),
687 }),
688 }
689}
690
691fn build_engine(manifest: &ComponentManifest) -> Result<Engine, WasmError> {
692 let mut cfg = Config::new();
693 cfg.wasm_component_model(true);
694 if manifest.fuel_per_call.is_some() {
695 cfg.consume_fuel(true);
696 }
697 if manifest.timeout_ms.is_some() {
698 cfg.epoch_interruption(true);
699 }
700 Engine::new(&cfg).map_err(|e| WasmError::Instantiate(format!("engine config: {e}")))
701}
702
703fn apply_resource_limits(store: &mut Store<HostState>, manifest: &ComponentManifest) {
704 if let Some(fuel) = manifest.fuel_per_call {
705 let _ = store.set_fuel(fuel);
708 }
709 if manifest.timeout_ms.is_some() {
710 store.set_epoch_deadline(1);
716 }
717}
718
719fn build_pool<I, F>(
731 bytes: &[u8],
732 prepared: &PreparedComponent,
733 build_instance: F,
734) -> Result<Arc<WasmInstancePool<I>>, WasmError>
735where
736 I: Send + 'static,
737 F: Fn(Store<HostState>, &Component, &Linker<HostState>) -> Result<I, WasmError>
738 + Send
739 + Sync
740 + 'static,
741{
742 let bytes_owned: Arc<Vec<u8>> = Arc::new(bytes.to_vec());
743 let prepared_owned: Arc<PreparedComponent> = Arc::new(prepared.clone());
744 let build_instance = Arc::new(build_instance);
745
746 let factory = {
747 let bytes = Arc::clone(&bytes_owned);
748 let prepared = Arc::clone(&prepared_owned);
749 let build_instance = Arc::clone(&build_instance);
750 move || -> Result<I, WasmError> {
751 let engine = build_engine(&prepared.manifest)?;
752 let component = Component::from_binary(&engine, &bytes)
753 .map_err(|e| WasmError::InvalidWasm(format!("component compile: {e}")))?;
754 let linker: Linker<HostState> =
755 select_linker_for_manifest(&engine, &prepared.manifest, &prepared.effective)?;
756 let mut store = Store::new(
757 &engine,
758 HostState::new(prepared.effective.clone(), prepared.http.clone()),
759 );
760 apply_resource_limits(&mut store, &prepared.manifest);
761 build_instance(store, &component, &linker)
762 }
763 };
764
765 let pool = WasmInstancePool::new(crate::pool::PoolConfig::default(), factory)?;
766 Ok(Arc::new(pool))
767}
768
/// Qualified names recorded by `apply_registration`, grouped by entry kind.
struct RegisteredQNames {
    /// Names of the scalar functions registered.
    scalars: Vec<String>,
    /// Names of the aggregate functions registered.
    aggregates: Vec<String>,
    /// Names of the procedures registered.
    procedures: Vec<String>,
}
775
776fn apply_registration(
783 bytes: &[u8],
784 prepared: &PreparedComponent,
785 scalar_pool: &Arc<WasmInstancePool<ScalarPluginInstance>>,
786 registration: RegistrationManifest,
787 registrar: &mut uni_plugin::PluginRegistrar<'_>,
788) -> Result<RegisteredQNames, WasmError> {
789 let mut scalars = Vec::new();
790 let mut aggregates = Vec::new();
791 let mut procedures = Vec::new();
792 let mut agg_pool: Option<Arc<WasmInstancePool<AggregatePluginInstance>>> = None;
793 let mut proc_pool: Option<Arc<WasmInstancePool<ProcedurePluginInstance>>> = None;
794
795 for entry in registration.entries {
796 match entry {
797 RegistrationEntry::Scalar { qname, signature } => {
798 let parsed_qname = uni_plugin::QName::parse(&qname)
799 .map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
800 let sig = wire_fn_sig_to_internal(&signature)?;
801 let adapter = Arc::new(ComponentScalarFn::new(
802 Arc::clone(scalar_pool),
803 parsed_qname.clone(),
804 sig.clone(),
805 ));
806 registrar
807 .scalar_fn(parsed_qname, sig, adapter)
808 .map_err(|e| {
809 WasmError::Internal(format!("registrar.scalar_fn `{qname}`: {e}"))
810 })?;
811 scalars.push(qname);
812 }
813 RegistrationEntry::Aggregate {
814 qname,
815 signature,
816 state,
817 } => {
818 let parsed_qname = uni_plugin::QName::parse(&qname)
819 .map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
820 let sig = wire_agg_sig_to_internal(&signature, &state)?;
821 let pool_ref = match &agg_pool {
822 Some(p) => Arc::clone(p),
823 None => {
824 let p = build_aggregate_pool(bytes, prepared)?;
825 agg_pool = Some(Arc::clone(&p));
826 p
827 }
828 };
829 let adapter = Arc::new(ComponentAggregateFn::new(
830 pool_ref,
831 parsed_qname.clone(),
832 sig.clone(),
833 ));
834 registrar
835 .aggregate_fn(parsed_qname, sig, adapter)
836 .map_err(|e| {
837 WasmError::Internal(format!("registrar.aggregate_fn `{qname}`: {e}"))
838 })?;
839 aggregates.push(qname);
840 }
841 RegistrationEntry::Procedure {
842 qname,
843 args,
844 yields,
845 mode,
846 } => {
847 let parsed_qname = uni_plugin::QName::parse(&qname)
848 .map_err(|e| WasmError::InvalidWasm(format!("invalid qname `{qname}`: {e}")))?;
849 let sig = wire_proc_sig_to_internal(&args, &yields, &mode)?;
850 let pool_ref = match &proc_pool {
851 Some(p) => Arc::clone(p),
852 None => {
853 let p = build_procedure_pool(bytes, prepared)?;
854 proc_pool = Some(Arc::clone(&p));
855 p
856 }
857 };
858 let adapter = Arc::new(ComponentProcedure::new(
859 pool_ref,
860 parsed_qname.clone(),
861 sig.clone(),
862 ));
863 registrar
864 .procedure(parsed_qname, sig, adapter)
865 .map_err(|e| {
866 WasmError::Internal(format!("registrar.procedure `{qname}`: {e}"))
867 })?;
868 procedures.push(qname);
869 }
870 }
871 }
872
873 Ok(RegisteredQNames {
874 scalars,
875 aggregates,
876 procedures,
877 })
878}
879
880fn synthesize_plugin_manifest(
893 component: &ComponentManifest,
894 registration: &RegistrationManifest,
895) -> Result<uni_plugin::PluginManifest, WasmError> {
896 use uni_plugin::{
897 AbiRange, Capability, CapabilitySet, Determinism, PluginId, ProvidedSurfaces, Scope,
898 SideEffects,
899 };
900
901 let version = semver::Version::parse(&component.version).map_err(|e| {
902 WasmError::InvalidWasm(format!("manifest version `{}`: {e}", component.version))
903 })?;
904 let abi = AbiRange::parse(component.abi.as_deref().unwrap_or("^1"))
905 .map_err(|e| WasmError::InvalidWasm(format!("manifest abi: {e}")))?;
906
907 let mut capabilities = CapabilitySet::new();
908 let mut side_effects = SideEffects::ReadOnly;
909 for entry in ®istration.entries {
910 match entry {
911 RegistrationEntry::Scalar { .. } => {
912 capabilities.insert(Capability::ScalarFn);
913 }
914 RegistrationEntry::Aggregate { .. } => {
915 capabilities.insert(Capability::AggregateFn);
916 }
917 RegistrationEntry::Procedure { mode, .. } => {
918 capabilities.insert(Capability::Procedure);
919 match mode.as_str() {
920 "write" => {
921 capabilities.insert(Capability::ProcedureWrites);
922 side_effects = SideEffects::Writes;
923 }
924 "schema" => {
925 capabilities.insert(Capability::ProcedureSchema);
926 side_effects = SideEffects::Writes;
927 }
928 "dbms" => {
929 capabilities.insert(Capability::ProcedureDbms);
930 }
931 _ => {}
932 }
933 }
934 }
935 }
936
937 let determinism = match component.determinism.as_deref() {
938 Some("pure") => Determinism::Pure,
939 Some("session-scoped" | "session_scoped") => Determinism::SessionScoped,
940 _ => Determinism::Nondeterministic,
941 };
942
943 Ok(uni_plugin::PluginManifest {
944 id: PluginId::new(component.id.clone()),
945 version,
946 abi,
947 depends_on: Vec::new(),
948 capabilities,
949 determinism,
950 side_effects,
951 scope: Scope::Instance,
952 hash: None,
953 signature: None,
954 provides: ProvidedSurfaces::default(),
955 docs: component.description.clone().unwrap_or_default(),
956 metadata: std::collections::BTreeMap::new(),
957 })
958}
959
960pub struct ComponentPlugin {
966 manifest: uni_plugin::PluginManifest,
967 bytes: Vec<u8>,
968 prepared: PreparedComponent,
969 scalar_pool: Arc<WasmInstancePool<ScalarPluginInstance>>,
970 registration: RegistrationManifest,
971}
972
973impl std::fmt::Debug for ComponentPlugin {
974 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
975 f.debug_struct("ComponentPlugin")
976 .field("id", &self.manifest.id.as_str())
977 .field("scalars", &self.registration.entries.len())
978 .finish()
979 }
980}
981
982impl uni_plugin::Plugin for ComponentPlugin {
983 fn manifest(&self) -> &uni_plugin::PluginManifest {
984 &self.manifest
985 }
986
987 fn register(
988 &self,
989 r: &mut uni_plugin::PluginRegistrar<'_>,
990 ) -> Result<(), uni_plugin::PluginError> {
991 apply_registration(
992 &self.bytes,
993 &self.prepared,
994 &self.scalar_pool,
995 self.registration.clone(),
996 r,
997 )
998 .map_err(|e| {
999 uni_plugin::PluginError::WasmInstantiate(format!("component register: {e}"))
1000 })?;
1001 Ok(())
1002 }
1003}
1004
1005fn build_scalar_pool(
1006 bytes: &[u8],
1007 prepared: &PreparedComponent,
1008) -> Result<Arc<WasmInstancePool<ScalarPluginInstance>>, WasmError> {
1009 build_pool(bytes, prepared, |mut store, component, linker| {
1010 let bindings = ScalarPlugin::instantiate(&mut store, component, linker)
1011 .map_err(|e| WasmError::Instantiate(format!("scalar-plugin instantiate: {e}")))?;
1012 Ok(ScalarPluginInstance { store, bindings })
1013 })
1014}
1015
1016fn build_aggregate_pool(
1017 bytes: &[u8],
1018 prepared: &PreparedComponent,
1019) -> Result<Arc<WasmInstancePool<AggregatePluginInstance>>, WasmError> {
1020 build_pool(bytes, prepared, |mut store, component, linker| {
1021 let bindings = AggregatePlugin::instantiate(&mut store, component, linker)
1022 .map_err(|e| WasmError::Instantiate(format!("aggregate-plugin instantiate: {e}")))?;
1023 Ok(AggregatePluginInstance { store, bindings })
1024 })
1025}
1026
1027fn build_procedure_pool(
1028 bytes: &[u8],
1029 prepared: &PreparedComponent,
1030) -> Result<Arc<WasmInstancePool<ProcedurePluginInstance>>, WasmError> {
1031 build_pool(bytes, prepared, |mut store, component, linker| {
1032 let bindings = ProcedurePluginBindings::instantiate(&mut store, component, linker)
1033 .map_err(|e| WasmError::Instantiate(format!("procedure-plugin instantiate: {e}")))?;
1034 Ok(ProcedurePluginInstance { store, bindings })
1035 })
1036}
1037
1038fn wire_arg(w: &WireArgType) -> Result<uni_plugin::traits::scalar::ArgType, WasmError> {
1040 use uni_plugin::traits::scalar::ArgType;
1041 Ok(match w {
1042 WireArgType::Primitive { arrow } => ArgType::Primitive(arrow_name_to_dt(arrow)?),
1043 WireArgType::CypherValue => ArgType::CypherValue,
1044 })
1045}
1046
1047fn parse_volatility(s: &str) -> Result<datafusion::logical_expr::Volatility, WasmError> {
1049 use datafusion::logical_expr::Volatility;
1050 Ok(match s {
1051 "immutable" => Volatility::Immutable,
1052 "stable" => Volatility::Stable,
1053 "volatile" => Volatility::Volatile,
1054 other => {
1055 return Err(WasmError::InvalidWasm(format!(
1056 "unsupported volatility: `{other}`"
1057 )));
1058 }
1059 })
1060}
1061
1062fn parse_null_handling(s: &str) -> Result<uni_plugin::traits::scalar::NullHandling, WasmError> {
1064 use uni_plugin::traits::scalar::NullHandling;
1065 Ok(match s {
1066 "propagate" => NullHandling::PropagateNulls,
1067 "user_handled" => NullHandling::UserHandled,
1068 other => {
1069 return Err(WasmError::InvalidWasm(format!(
1070 "unsupported null_handling: `{other}`"
1071 )));
1072 }
1073 })
1074}
1075
1076fn parse_proc_mode(s: &str) -> Result<uni_plugin::traits::procedure::ProcedureMode, WasmError> {
1078 use uni_plugin::traits::procedure::ProcedureMode;
1079 Ok(match s {
1080 "read" => ProcedureMode::Read,
1081 "write" => ProcedureMode::Write,
1082 "schema" => ProcedureMode::Schema,
1083 "dbms" => ProcedureMode::Dbms,
1084 other => {
1085 return Err(WasmError::InvalidWasm(format!(
1086 "unsupported procedure mode: `{other}`"
1087 )));
1088 }
1089 })
1090}
1091
1092fn wire_agg_sig_to_internal(
1093 wire_sig: &WireFnSignature,
1094 wire_state: &WireArgType,
1095) -> Result<uni_plugin::traits::aggregate::AggSignature, WasmError> {
1096 use arrow_schema::Field;
1097 use uni_plugin::traits::aggregate::AggSignature;
1098
1099 let internal = wire_fn_sig_to_internal(wire_sig)?;
1100 let state_field = match wire_state {
1101 WireArgType::Primitive { arrow } => {
1102 let dt = arrow_name_to_dt(arrow)?;
1103 Field::new("state", dt, true)
1104 }
1105 _ => {
1106 return Err(WasmError::InvalidWasm(
1107 "aggregate state must be a Primitive Arrow type".to_owned(),
1108 ));
1109 }
1110 };
1111 Ok(AggSignature {
1112 volatility: internal.volatility,
1113 args: internal.args,
1114 returns: internal.returns,
1115 state_fields: vec![state_field],
1116 supports_partial: true,
1117 })
1118}
1119
1120fn wire_proc_sig_to_internal(
1121 args: &[WireArgType],
1122 yields: &[WireArgType],
1123 mode: &str,
1124) -> Result<uni_plugin::traits::procedure::ProcedureSignature, WasmError> {
1125 use arrow_schema::Field;
1126 use uni_plugin::capability::SideEffects;
1127 use uni_plugin::traits::procedure::{NamedArgType, ProcedureSignature};
1128 use uni_plugin::traits::scalar::ArgType;
1129
1130 let named_args: Vec<NamedArgType> = args
1131 .iter()
1132 .enumerate()
1133 .map(|(i, w)| {
1134 let ty = wire_arg(w)?;
1135 Ok::<NamedArgType, WasmError>(NamedArgType {
1136 name: format!("arg{i}").into(),
1137 ty,
1138 default: None,
1139 doc: String::new(),
1140 })
1141 })
1142 .collect::<Result<_, _>>()?;
1143 let yield_fields: Vec<Field> = yields
1144 .iter()
1145 .enumerate()
1146 .map(|(i, w)| {
1147 let ty = wire_arg(w)?;
1148 let dt = match ty {
1149 ArgType::Primitive(d) => d,
1150 ArgType::CypherValue | ArgType::Variadic(_) => arrow_schema::DataType::LargeBinary,
1151 ArgType::Vector { element, .. } => element,
1152 };
1153 Ok::<Field, WasmError>(Field::new(format!("yield{i}"), dt, true))
1154 })
1155 .collect::<Result<_, _>>()?;
1156 Ok(ProcedureSignature {
1157 args: named_args,
1158 yields: yield_fields,
1159 mode: parse_proc_mode(mode)?,
1160 side_effects: SideEffects::default(),
1161 retry_contract: None,
1162 batch_input: None,
1163 docs: String::new(),
1164 })
1165}
1166
1167fn arrow_name_to_dt(name: &str) -> Result<arrow_schema::DataType, WasmError> {
1168 uni_plugin::adapter_common::arrow_types::arrow_name_to_datatype(name)
1169 .ok_or_else(|| WasmError::InvalidWasm(format!("unsupported arrow primitive: `{name}`")))
1170}
1171
1172fn wire_fn_sig_to_internal(
1173 wire: &WireFnSignature,
1174) -> Result<uni_plugin::traits::scalar::FnSignature, WasmError> {
1175 use uni_plugin::traits::scalar::{ArgType, FnSignature};
1176
1177 let args: Vec<ArgType> = wire.args.iter().map(wire_arg).collect::<Result<_, _>>()?;
1178 Ok(FnSignature {
1179 args,
1180 returns: wire_arg(&wire.returns)?,
1181 volatility: parse_volatility(&wire.volatility)?,
1182 null_handling: parse_null_handling(&wire.null_handling)?,
1183 })
1184}
1185
#[cfg(test)]
mod tests {
    use super::*;

    use uni_plugin::{Capability, CapabilitySet};

    /// Builds a minimal manifest JSON whose `capabilities` array holds the
    /// given names as bare strings.
    fn manifest_json(caps: &[&str]) -> String {
        let quoted: Vec<String> = caps.iter().map(|c| format!("\"{c}\"")).collect();
        format!(
            r#"{{ "id": "ai.example.test", "version": "1.0.0", "capabilities": [{}] }}"#,
            quoted.join(", ")
        )
    }

    /// The loader can be constructed with no configuration.
    #[test]
    fn loader_constructs() {
        let _ = WasmLoader::new();
    }

    /// A manifest with no capabilities yields an empty effective set.
    #[test]
    fn prepare_parses_minimal_manifest() {
        let loader = WasmLoader::new();
        let json = manifest_json(&[]);
        let prepared = loader
            .prepare(json.as_bytes(), &CapabilitySet::new())
            .unwrap();
        assert_eq!(prepared.manifest.id, "ai.example.test");
        assert!(prepared.effective.is_empty());
    }

    /// Only capabilities that are both declared and granted end up in the
    /// effective set.
    #[test]
    fn prepare_intersects_capabilities() {
        let loader = WasmLoader::new();
        let json = manifest_json(&["filesystem", "network", "kms"]);
        // The host grants filesystem and network but not kms.
        let grants = CapabilitySet::from_iter_of([
            Capability::Filesystem {
                read: vec![],
                write: vec![],
            },
            Capability::Network { allow: vec![] },
        ]);
        let prepared = loader.prepare(json.as_bytes(), &grants).unwrap();
        assert_eq!(prepared.effective.len(), 2);
        assert!(
            prepared
                .effective
                .contains_variant(&Capability::Network { allow: vec![] })
        );
        assert!(
            !prepared
                .effective
                .contains_variant(&Capability::Kms { key_ids: vec![] })
        );
    }

    /// A structured network capability keeps its allow-list through
    /// preparation.
    #[test]
    fn prepare_carries_structured_network_allowlist() {
        let loader = WasmLoader::new();
        let json = r#"{ "id": "a.b", "version": "1.0.0",
            "capabilities": [{"kind":"network","allow":["https://api.example/**"]}] }"#;
        let grants = CapabilitySet::from_iter_of([Capability::Network {
            allow: vec!["https://api.example/**".into()],
        }]);
        let prepared = loader.prepare(json.as_bytes(), &grants).unwrap();
        assert!(
            prepared
                .effective
                .iter()
                .any(|c| c.network_allows("https://api.example/v1/x"))
        );
        assert!(
            !prepared
                .effective
                .iter()
                .any(|c| c.network_allows("https://evil.example/x"))
        );
    }

    /// Non-JSON input is rejected as `InvalidWasm`.
    #[test]
    fn prepare_rejects_malformed_manifest() {
        let loader = WasmLoader::new();
        let err = loader
            .prepare(b"not json", &CapabilitySet::new())
            .unwrap_err();
        assert!(matches!(err, WasmError::InvalidWasm(_)));
    }

    /// Bytes that are not a component fail at compile time with
    /// `InvalidWasm`.
    #[test]
    fn instantiate_rejects_garbage_bytes() {
        let loader = WasmLoader::new();
        let prepared = loader
            .prepare(
                b"{\"id\":\"a.b\",\"version\":\"0.0.0\"}",
                &CapabilitySet::new(),
            )
            .unwrap();
        let err = loader.instantiate(b"not real wasm", &prepared).unwrap_err();
        assert!(matches!(err, WasmError::InvalidWasm(_)));
    }
}