1use serde_json::Value;
6use shape_abi_v1::{
7 ModuleInvokeResult, ModuleInvokeResultKind, ModuleSchema as AbiModuleSchema, ModuleVTable,
8 PluginError,
9};
10use shape_ast::error::{Result, ShapeError};
11use shape_value::ValueWord;
12use shape_wire::{WireValue, render_any_error_plain};
13use std::collections::HashSet;
14use std::ffi::c_void;
15use std::sync::Arc;
16
/// One entry of the MessagePack payload returned by a plugin's `get_module_artifacts`
/// vtable function; decoded (as a `Vec<ArtifactPayload>`) in `parse_module_artifacts`.
17#[derive(Debug, Clone, serde::Deserialize)]
18struct ArtifactPayload {
 /// Module path the artifact is registered under; `parse_module_artifacts` rejects
 /// empty and duplicate values.
19 module_path: String,
 /// Optional source text; `None` when the field is absent from the payload.
20 #[serde(default)]
21 source: Option<String>,
 /// Optional compiled form; `None` when the field is absent from the payload.
 /// NOTE(review): a plain `Vec<u8>` is decoded as a MessagePack array of integers;
 /// confirm plugins do not emit this field as a msgpack `bin` value.
22 #[serde(default)]
23 compiled: Option<Vec<u8>>,
24}
25
/// A validated function entry from a plugin's module schema (built in `parse_module_schema`).
26#[derive(Debug, Clone)]
28pub struct ParsedModuleFunction {
 /// Function name; non-empty and unique within its schema.
29 pub name: String,
 /// Human-readable description copied from the plugin schema.
30 pub description: String,
 /// Parameter type names, in positional order (exported as `arg0`, `arg1`, ...).
31 pub params: Vec<String>,
 /// Declared return type name, if the plugin reported one.
32 pub return_type: Option<String>,
33}
34
/// A validated Shape artifact bundled with a plugin module (built in `parse_module_artifacts`).
35#[derive(Debug, Clone)]
37pub struct ParsedModuleArtifact {
 /// Module path the artifact is registered under; non-empty and unique per plugin.
38 pub module_path: String,
 /// Optional source text for the artifact.
39 pub source: Option<String>,
 /// Optional compiled bytes for the artifact.
40 pub compiled: Option<Vec<u8>>,
41}
42
/// The complete, validated schema of a plugin module: its name, functions and artifacts.
43#[derive(Debug, Clone)]
45pub struct ParsedModuleSchema {
 /// Module name reported by the plugin, or the plugin name when the plugin reports "".
46 pub module_name: String,
 /// Exported functions, in the order the plugin listed them.
47 pub functions: Vec<ParsedModuleFunction>,
 /// Bundled Shape artifacts; empty when the plugin has no `get_module_artifacts` hook.
48 pub artifacts: Vec<ParsedModuleArtifact>,
49}
50
51pub struct PluginModule {
53 name: String,
54 vtable: &'static ModuleVTable,
55 instance: *mut c_void,
56 schema: ParsedModuleSchema,
57}
58
59impl PluginModule {
60 pub fn new(name: String, vtable: &'static ModuleVTable, config: &Value) -> Result<Self> {
62 let config_bytes = rmp_serde::to_vec(config).map_err(|e| ShapeError::RuntimeError {
63 message: format!("Failed to serialize module config for '{}': {}", name, e),
64 location: None,
65 })?;
66
67 let init_fn = vtable.init.ok_or_else(|| ShapeError::RuntimeError {
68 message: format!("Plugin '{}' module capability has no init function", name),
69 location: None,
70 })?;
71
72 let instance = unsafe { init_fn(config_bytes.as_ptr(), config_bytes.len()) };
73 if instance.is_null() {
74 return Err(ShapeError::RuntimeError {
75 message: format!("Plugin '{}' module init returned null", name),
76 location: None,
77 });
78 }
79
80 let schema = parse_module_schema(vtable, instance, &name)?;
81
82 Ok(Self {
83 name,
84 vtable,
85 instance,
86 schema,
87 })
88 }
89
90 pub fn name(&self) -> &str {
92 &self.name
93 }
94
95 pub fn schema(&self) -> &ParsedModuleSchema {
97 &self.schema
98 }
99
100 pub fn to_module_exports(&self) -> crate::module_exports::ModuleExports {
102 use crate::module_exports::{ModuleExports, ModuleFunction, ModuleParam};
103
104 let mut module = ModuleExports::new(self.schema.module_name.clone());
105 module.description = format!("Plugin module exported by '{}'", self.name);
106
107 let invoker = Arc::new(ModuleInvoker {
108 name: self.name.clone(),
109 vtable: self.vtable,
110 instance: self.instance,
111 });
112
113 for function in &self.schema.functions {
114 let fn_name = function.name.clone();
115 let invoker_ref = Arc::clone(&invoker);
116
117 let schema = ModuleFunction {
118 description: function.description.clone(),
119 params: function
120 .params
121 .iter()
122 .enumerate()
123 .map(|(idx, ty)| ModuleParam {
124 name: format!("arg{}", idx),
125 type_name: ty.clone(),
126 required: true,
127 description: String::new(),
128 ..Default::default()
129 })
130 .collect(),
131 return_type: function.return_type.clone(),
132 };
133
134 let fn_name_for_closure = fn_name.clone();
135 module.add_function_with_schema(
136 fn_name,
137 move |args: &[ValueWord], _ctx: &crate::module_exports::ModuleContext| {
138 invoker_ref.invoke_nb(&fn_name_for_closure, args)
139 },
140 schema,
141 );
142 }
143
144 for artifact in &self.schema.artifacts {
145 module.add_shape_artifact(
146 artifact.module_path.clone(),
147 artifact.source.clone(),
148 artifact.compiled.clone(),
149 );
150 }
151
152 module
153 }
154
155 pub fn invoke_wire(&self, function: &str, args: &[WireValue]) -> Result<WireValue> {
157 let invoker = ModuleInvoker {
158 name: self.name.clone(),
159 vtable: self.vtable,
160 instance: self.instance,
161 };
162 invoker
163 .invoke_wire(function, args)
164 .map_err(|message| ShapeError::RuntimeError {
165 message,
166 location: None,
167 })
168 }
169
170 pub fn invoke_nb(&self, function: &str, args: &[ValueWord]) -> Result<ValueWord> {
175 let invoker = ModuleInvoker {
176 name: self.name.clone(),
177 vtable: self.vtable,
178 instance: self.instance,
179 };
180 invoker
181 .invoke_nb(function, args)
182 .map_err(|message| ShapeError::RuntimeError {
183 message,
184 location: None,
185 })
186 }
187}
188
189impl Drop for PluginModule {
190 fn drop(&mut self) {
191 if let Some(drop_fn) = self.vtable.drop {
192 unsafe { drop_fn(self.instance) };
193 }
194 }
195}
196
197unsafe impl Send for PluginModule {}
199unsafe impl Sync for PluginModule {}
200
201struct ModuleInvoker {
202 name: String,
203 vtable: &'static ModuleVTable,
204 instance: *mut c_void,
205}
206
207impl ModuleInvoker {
208 fn invoke_nb(
209 &self,
210 function: &str,
211 args: &[ValueWord],
212 ) -> std::result::Result<ValueWord, String> {
213 let ctx = crate::Context::new_empty();
214 let wire_args: Vec<WireValue> = args
215 .iter()
216 .map(|nb| crate::wire_conversion::nb_to_wire(nb, &ctx))
217 .collect();
218
219 let wire_bytes = rmp_serde::to_vec(&wire_args).map_err(|e| {
220 format!(
221 "Failed to serialize wire args for '{}.{}': {}",
222 self.name, function, e
223 )
224 })?;
225
226 match self
227 .invoke_with_args(function, &wire_bytes)
228 .map_err(|err| err.message)?
229 {
230 ModuleInvokePayload::Wire(bytes) => {
231 let payload = decode_payload_to_wire(&bytes).map_err(|e| {
232 format!(
233 "Failed to decode module result for '{}.{}': {}",
234 self.name, function, e
235 )
236 })?;
237 let normalized = normalize_invoke_result(payload, &self.name, function)?;
238 Ok(crate::wire_conversion::wire_to_nb(&normalized))
239 }
240 ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
241 let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
242 .map_err(|e| {
243 format!(
244 "Failed to decode table payload for '{}.{}': {}",
245 self.name, function, e
246 )
247 })?;
248 Ok(ValueWord::from_datatable(Arc::new(dt)))
249 }
250 }
251 }
252
253 fn invoke_wire(
254 &self,
255 function: &str,
256 args: &[WireValue],
257 ) -> std::result::Result<WireValue, String> {
258 let wire_bytes = rmp_serde::to_vec(args).map_err(|e| {
259 format!(
260 "Failed to serialize wire args for '{}.{}': {}",
261 self.name, function, e
262 )
263 })?;
264
265 match self
266 .invoke_with_args(function, &wire_bytes)
267 .map_err(|err| err.message)?
268 {
269 ModuleInvokePayload::Wire(bytes) => {
270 let payload = decode_payload_to_wire(&bytes).map_err(|e| {
271 format!(
272 "Failed to decode module result for '{}.{}': {}",
273 self.name, function, e
274 )
275 })?;
276 normalize_invoke_result(payload, &self.name, function)
277 }
278 ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
279 let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
280 .map_err(|e| {
281 format!(
282 "Failed to decode table payload for '{}.{}': {}",
283 self.name, function, e
284 )
285 })?;
286 let nb = ValueWord::from_datatable(Arc::new(dt));
287 let ctx = crate::Context::new_empty();
288 Ok(crate::wire_conversion::nb_to_wire(&nb, &ctx))
289 }
290 }
291 }
292
293 fn invoke_with_args(
294 &self,
295 function: &str,
296 args_bytes: &[u8],
297 ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
298 if let Some(invoke_ex_fn) = self.vtable.invoke_ex {
299 let mut out = ModuleInvokeResult::empty();
300 let status = unsafe {
301 invoke_ex_fn(
302 self.instance,
303 function.as_ptr(),
304 function.len(),
305 args_bytes.as_ptr(),
306 args_bytes.len(),
307 &mut out,
308 )
309 };
310
311 if status != PluginError::Success as i32 {
312 return Err(ModuleInvokeFailure {
313 message: format!(
314 "Plugin '{}' module invoke_ex failed for '{}': status {}",
315 self.name, function, status
316 ),
317 });
318 }
319
320 let payload = self.take_payload_bytes(out.payload_ptr, out.payload_len);
321 return match out.kind {
322 ModuleInvokeResultKind::WireValueMsgpack => Ok(ModuleInvokePayload::Wire(payload)),
323 ModuleInvokeResultKind::TableArrowIpc => {
324 Ok(ModuleInvokePayload::TableArrowIpc(payload))
325 }
326 };
327 }
328
329 self.invoke_with_args_legacy(function, args_bytes)
330 }
331
332 fn invoke_with_args_legacy(
333 &self,
334 function: &str,
335 args_bytes: &[u8],
336 ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
337 let invoke_fn = self.vtable.invoke.ok_or_else(|| ModuleInvokeFailure {
338 message: format!(
339 "Plugin '{}' module capability does not implement invoke()",
340 self.name
341 ),
342 })?;
343
344 let mut out_ptr: *mut u8 = std::ptr::null_mut();
345 let mut out_len: usize = 0;
346 let status = unsafe {
347 invoke_fn(
348 self.instance,
349 function.as_ptr(),
350 function.len(),
351 args_bytes.as_ptr(),
352 args_bytes.len(),
353 &mut out_ptr,
354 &mut out_len,
355 )
356 };
357
358 if status != PluginError::Success as i32 {
359 return Err(ModuleInvokeFailure {
360 message: format!(
361 "Plugin '{}' module invoke failed for '{}': status {}",
362 self.name, function, status
363 ),
364 });
365 }
366
367 Ok(ModuleInvokePayload::Wire(
368 self.take_payload_bytes(out_ptr, out_len),
369 ))
370 }
371
372 fn take_payload_bytes(&self, ptr: *mut u8, len: usize) -> Vec<u8> {
373 if ptr.is_null() {
374 return Vec::new();
375 }
376
377 let bytes = if len == 0 {
378 Vec::new()
379 } else {
380 unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
381 };
382
383 if let Some(free_fn) = self.vtable.free_buffer {
384 unsafe { free_fn(ptr, len) };
385 }
386 bytes
387 }
388}
389
390unsafe impl Send for ModuleInvoker {}
392unsafe impl Sync for ModuleInvoker {}
393
394#[derive(Debug)]
395struct ModuleInvokeFailure {
396 message: String,
397}
398
399#[derive(Debug)]
400enum ModuleInvokePayload {
401 Wire(Vec<u8>),
402 TableArrowIpc(Vec<u8>),
403}
404
405fn parse_module_schema(
406 vtable: &'static ModuleVTable,
407 instance: *mut c_void,
408 plugin_name: &str,
409) -> Result<ParsedModuleSchema> {
410 let get_schema_fn = vtable
411 .get_module_schema
412 .ok_or_else(|| ShapeError::RuntimeError {
413 message: format!(
414 "Plugin '{}' module capability has no get_module_schema()",
415 plugin_name
416 ),
417 location: None,
418 })?;
419
420 let mut out_ptr: *mut u8 = std::ptr::null_mut();
421 let mut out_len: usize = 0;
422 let status = unsafe { get_schema_fn(instance, &mut out_ptr, &mut out_len) };
423 if status != PluginError::Success as i32 {
424 return Err(ShapeError::RuntimeError {
425 message: format!(
426 "Plugin '{}' get_module_schema failed with status {}",
427 plugin_name, status
428 ),
429 location: None,
430 });
431 }
432
433 if out_ptr.is_null() || out_len == 0 {
434 return Err(ShapeError::RuntimeError {
435 message: format!(
436 "Plugin '{}' returned empty module schema payload",
437 plugin_name
438 ),
439 location: None,
440 });
441 }
442
443 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
444 if let Some(free_fn) = vtable.free_buffer {
445 unsafe { free_fn(out_ptr, out_len) };
446 }
447 let schema: AbiModuleSchema =
448 rmp_serde::from_slice(&bytes).map_err(|e| ShapeError::RuntimeError {
449 message: format!(
450 "Failed to decode module schema from '{}': {}",
451 plugin_name, e
452 ),
453 location: None,
454 })?;
455
456 let module_name = if schema.module_name.is_empty() {
457 plugin_name.to_string()
458 } else {
459 schema.module_name
460 };
461
462 let mut seen = HashSet::new();
463 let mut functions = Vec::new();
464 for f in schema.functions {
465 if f.name.is_empty() {
466 return Err(ShapeError::RuntimeError {
467 message: format!(
468 "Plugin '{}' module schema contains empty function name",
469 plugin_name
470 ),
471 location: None,
472 });
473 }
474 if !seen.insert(f.name.clone()) {
475 return Err(ShapeError::RuntimeError {
476 message: format!(
477 "Plugin '{}' module schema contains duplicate function '{}'",
478 plugin_name, f.name
479 ),
480 location: None,
481 });
482 }
483 functions.push(ParsedModuleFunction {
484 name: f.name,
485 description: f.description,
486 params: f.params,
487 return_type: f.return_type,
488 });
489 }
490
491 let artifacts = parse_module_artifacts(vtable, instance, plugin_name)?;
492
493 Ok(ParsedModuleSchema {
494 module_name,
495 functions,
496 artifacts,
497 })
498}
499
500fn parse_module_artifacts(
501 vtable: &'static ModuleVTable,
502 instance: *mut c_void,
503 plugin_name: &str,
504) -> Result<Vec<ParsedModuleArtifact>> {
505 let Some(get_artifacts_fn) = vtable.get_module_artifacts else {
506 return Ok(Vec::new());
507 };
508
509 let mut out_ptr: *mut u8 = std::ptr::null_mut();
510 let mut out_len: usize = 0;
511 let status = unsafe { get_artifacts_fn(instance, &mut out_ptr, &mut out_len) };
512 if status != PluginError::Success as i32 {
513 return Err(ShapeError::RuntimeError {
514 message: format!(
515 "Plugin '{}' get_module_artifacts failed with status {}",
516 plugin_name, status
517 ),
518 location: None,
519 });
520 }
521
522 if out_ptr.is_null() || out_len == 0 {
523 return Ok(Vec::new());
524 }
525
526 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
527 if let Some(free_fn) = vtable.free_buffer {
528 unsafe { free_fn(out_ptr, out_len) };
529 }
530
531 let parsed = rmp_serde::from_slice::<Vec<ArtifactPayload>>(&bytes).map_err(|e| {
532 ShapeError::RuntimeError {
533 message: format!(
534 "Failed to decode module artifacts from '{}': {}",
535 plugin_name, e
536 ),
537 location: None,
538 }
539 })?;
540
541 let mut seen_paths = HashSet::new();
542 let mut artifacts = Vec::new();
543 for item in parsed {
544 if item.module_path.is_empty() {
545 return Err(ShapeError::RuntimeError {
546 message: format!(
547 "Plugin '{}' module artifacts contain empty module_path",
548 plugin_name
549 ),
550 location: None,
551 });
552 }
553 if !seen_paths.insert(item.module_path.clone()) {
554 return Err(ShapeError::RuntimeError {
555 message: format!(
556 "Plugin '{}' module artifacts contain duplicate module_path '{}'",
557 plugin_name, item.module_path
558 ),
559 location: None,
560 });
561 }
562 artifacts.push(ParsedModuleArtifact {
563 module_path: item.module_path,
564 source: item.source,
565 compiled: item.compiled,
566 });
567 }
568
569 Ok(artifacts)
570}
571
572fn decode_payload_to_wire(bytes: &[u8]) -> std::result::Result<WireValue, String> {
573 if bytes.is_empty() {
574 return Ok(WireValue::Null);
575 }
576 rmp_serde::from_slice::<WireValue>(bytes).map_err(|e| format!("invalid wire payload: {}", e))
577}
578
579fn normalize_invoke_result(
580 payload: WireValue,
581 module_name: &str,
582 function: &str,
583) -> std::result::Result<WireValue, String> {
584 match payload {
585 WireValue::Result { ok, value } => {
586 if ok {
587 Ok(*value)
588 } else {
589 Err(format!(
590 "Plugin '{}.{}' failed: {}",
591 module_name,
592 function,
593 format_wire_error_message(&value)
594 ))
595 }
596 }
597 other => Ok(other),
598 }
599}
600
601fn format_wire_error_message(value: &WireValue) -> String {
602 if let Some(rendered) = render_any_error_plain(value) {
603 return rendered;
604 }
605
606 match value {
607 WireValue::String(s) => s.clone(),
608 WireValue::Object(map) => {
609 if let Some(WireValue::String(message)) = map.get("message") {
610 message.clone()
611 } else {
612 format!("{value:?}")
613 }
614 }
615 _ => format!("{value:?}"),
616 }
617}