1use serde_json::Value;
6use shape_abi_v1::{
7 ModuleInvokeResult, ModuleInvokeResultKind, ModuleSchema as AbiModuleSchema, ModuleVTable,
8 PluginError,
9};
10use shape_ast::error::{Result, ShapeError};
11use shape_wire::{WireValue, render_any_error_plain};
12use std::collections::HashSet;
13use std::ffi::c_void;
14
/// Raw artifact entry as decoded from the plugin's `get_module_artifacts`
/// msgpack buffer, before validation.
///
/// NOTE(review): if plugins encode this with rmp_serde's default (array-style)
/// struct layout, field order is part of the wire format — do not reorder.
#[derive(Debug, Clone, serde::Deserialize)]
struct ArtifactPayload {
    /// Module path the artifact belongs to; must be non-empty and unique
    /// (checked in `parse_module_artifacts`).
    module_path: String,
    /// Optional source text; absent fields deserialize to `None`.
    #[serde(default)]
    source: Option<String>,
    /// Optional compiled bytes; absent fields deserialize to `None`.
    #[serde(default)]
    compiled: Option<Vec<u8>>,
}
23
/// A function exported by a plugin module, as reported by its module schema.
#[derive(Debug, Clone)]
pub struct ParsedModuleFunction {
    /// Function name; guaranteed non-empty and unique within its schema.
    pub name: String,
    /// Human-readable description supplied by the plugin.
    pub description: String,
    /// Parameter descriptors as reported by the plugin (presumably names or
    /// type annotations — TODO confirm against the ABI schema definition).
    pub params: Vec<String>,
    /// Declared return type, if the plugin provided one.
    pub return_type: Option<String>,
}
32
/// A validated module artifact shipped by a plugin.
#[derive(Debug, Clone)]
pub struct ParsedModuleArtifact {
    /// Module path; guaranteed non-empty and unique among a plugin's artifacts.
    pub module_path: String,
    /// Optional source text for the module.
    pub source: Option<String>,
    /// Optional compiled representation of the module. Its format is not
    /// interpreted here.
    pub compiled: Option<Vec<u8>>,
}
40
/// Validated description of a plugin module: its name, exported functions,
/// and bundled artifacts.
#[derive(Debug, Clone)]
pub struct ParsedModuleSchema {
    /// Module name from the plugin schema, or the plugin name if the schema
    /// left it empty.
    pub module_name: String,
    /// Exported functions (names non-empty and unique).
    pub functions: Vec<ParsedModuleFunction>,
    /// Artifacts reported by `get_module_artifacts`; empty if the plugin does
    /// not implement it.
    pub artifacts: Vec<ParsedModuleArtifact>,
}
48
/// An initialized plugin module instance together with its parsed schema.
///
/// The instance is created by the vtable's `init` in [`PluginModule::new`]
/// and released via the vtable's `drop` (if present) when this value is dropped.
pub struct PluginModule {
    /// Plugin name, used in error messages.
    name: String,
    /// Plugin-provided entry points.
    vtable: &'static ModuleVTable,
    /// Opaque, non-null plugin instance handle passed to every vtable call.
    instance: *mut c_void,
    /// Schema parsed once at construction time.
    schema: ParsedModuleSchema,
}
56
57impl PluginModule {
58 pub fn new(name: String, vtable: &'static ModuleVTable, config: &Value) -> Result<Self> {
60 let config_bytes = rmp_serde::to_vec(config).map_err(|e| ShapeError::RuntimeError {
61 message: format!("Failed to serialize module config for '{}': {}", name, e),
62 location: None,
63 })?;
64
65 let init_fn = vtable.init.ok_or_else(|| ShapeError::RuntimeError {
66 message: format!("Plugin '{}' module capability has no init function", name),
67 location: None,
68 })?;
69
70 let instance = unsafe { init_fn(config_bytes.as_ptr(), config_bytes.len()) };
71 if instance.is_null() {
72 return Err(ShapeError::RuntimeError {
73 message: format!("Plugin '{}' module init returned null", name),
74 location: None,
75 });
76 }
77
78 let schema = parse_module_schema(vtable, instance, &name)?;
79
80 Ok(Self {
81 name,
82 vtable,
83 instance,
84 schema,
85 })
86 }
87
88 pub fn name(&self) -> &str {
90 &self.name
91 }
92
93 pub fn schema(&self) -> &ParsedModuleSchema {
95 &self.schema
96 }
97
98 pub fn invoke_wire(&self, function: &str, args: &[WireValue]) -> Result<WireValue> {
100 let invoker = ModuleInvoker {
101 name: self.name.clone(),
102 vtable: self.vtable,
103 instance: self.instance,
104 };
105 invoker
106 .invoke_wire(function, args)
107 .map_err(|message| ShapeError::RuntimeError {
108 message,
109 location: None,
110 })
111 }
112}
113
114impl Drop for PluginModule {
115 fn drop(&mut self) {
116 if let Some(drop_fn) = self.vtable.drop {
117 unsafe { drop_fn(self.instance) };
118 }
119 }
120}
121
// SAFETY: NOTE(review): the raw `*mut c_void` instance pointer is what blocks
// the auto-derived impl. Moving a `PluginModule` to another thread assumes the
// plugin instance is not bound to the thread that created it — confirm against
// the plugin ABI contract.
unsafe impl Send for PluginModule {}
// SAFETY: NOTE(review): `invoke_wire` takes `&self`, so with `Sync` several
// threads may call into the same plugin instance concurrently. This is only
// sound if plugins tolerate concurrent calls on one instance — confirm.
unsafe impl Sync for PluginModule {}
125
/// Handle used to perform a call into a plugin module instance.
///
/// It copies the vtable and instance pointer from a [`PluginModule`]. No
/// `Drop` impl is visible for it here, so it does not release `instance`; the
/// owning `PluginModule` must outlive any invoker built from it.
struct ModuleInvoker {
    /// Plugin name, used in error messages.
    name: String,
    /// Plugin-provided entry points.
    vtable: &'static ModuleVTable,
    /// Opaque plugin instance handle (not owned).
    instance: *mut c_void,
}
131
132impl ModuleInvoker {
133 fn invoke_wire(
134 &self,
135 function: &str,
136 args: &[WireValue],
137 ) -> std::result::Result<WireValue, String> {
138 let wire_bytes = rmp_serde::to_vec(args).map_err(|e| {
139 format!(
140 "Failed to serialize wire args for '{}.{}': {}",
141 self.name, function, e
142 )
143 })?;
144
145 match self
146 .invoke_with_args(function, &wire_bytes)
147 .map_err(|err| err.message)?
148 {
149 ModuleInvokePayload::Wire(bytes) => {
150 let payload = decode_payload_to_wire(&bytes).map_err(|e| {
151 format!(
152 "Failed to decode module result for '{}.{}': {}",
153 self.name, function, e
154 )
155 })?;
156 normalize_invoke_result(payload, &self.name, function)
157 }
158 ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
159 let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
160 .map_err(|e| {
161 format!(
162 "Failed to decode table payload for '{}.{}': {}",
163 self.name, function, e
164 )
165 })?;
166 Ok(crate::wire_conversion::datatable_to_wire(&dt))
167 }
168 }
169 }
170
171 fn invoke_with_args(
172 &self,
173 function: &str,
174 args_bytes: &[u8],
175 ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
176 if let Some(invoke_ex_fn) = self.vtable.invoke_ex {
177 let mut out = ModuleInvokeResult::empty();
178 let status = unsafe {
179 invoke_ex_fn(
180 self.instance,
181 function.as_ptr(),
182 function.len(),
183 args_bytes.as_ptr(),
184 args_bytes.len(),
185 &mut out,
186 )
187 };
188
189 if status != PluginError::Success as i32 {
190 return Err(ModuleInvokeFailure {
191 message: format!(
192 "Plugin '{}' module invoke_ex failed for '{}': status {}",
193 self.name, function, status
194 ),
195 });
196 }
197
198 let payload = self.take_payload_bytes(out.payload_ptr, out.payload_len);
199 return match out.kind {
200 ModuleInvokeResultKind::WireValueMsgpack => Ok(ModuleInvokePayload::Wire(payload)),
201 ModuleInvokeResultKind::TableArrowIpc => {
202 Ok(ModuleInvokePayload::TableArrowIpc(payload))
203 }
204 };
205 }
206
207 self.invoke_with_args_legacy(function, args_bytes)
208 }
209
210 fn invoke_with_args_legacy(
211 &self,
212 function: &str,
213 args_bytes: &[u8],
214 ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
215 let invoke_fn = self.vtable.invoke.ok_or_else(|| ModuleInvokeFailure {
216 message: format!(
217 "Plugin '{}' module capability does not implement invoke()",
218 self.name
219 ),
220 })?;
221
222 let mut out_ptr: *mut u8 = std::ptr::null_mut();
223 let mut out_len: usize = 0;
224 let status = unsafe {
225 invoke_fn(
226 self.instance,
227 function.as_ptr(),
228 function.len(),
229 args_bytes.as_ptr(),
230 args_bytes.len(),
231 &mut out_ptr,
232 &mut out_len,
233 )
234 };
235
236 if status != PluginError::Success as i32 {
237 return Err(ModuleInvokeFailure {
238 message: format!(
239 "Plugin '{}' module invoke failed for '{}': status {}",
240 self.name, function, status
241 ),
242 });
243 }
244
245 Ok(ModuleInvokePayload::Wire(
246 self.take_payload_bytes(out_ptr, out_len),
247 ))
248 }
249
250 fn take_payload_bytes(&self, ptr: *mut u8, len: usize) -> Vec<u8> {
251 if ptr.is_null() {
252 return Vec::new();
253 }
254
255 let bytes = if len == 0 {
256 Vec::new()
257 } else {
258 unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
259 };
260
261 if let Some(free_fn) = self.vtable.free_buffer {
262 unsafe { free_fn(ptr, len) };
263 }
264 bytes
265 }
266}
267
// SAFETY: NOTE(review): same rationale as for `PluginModule` — the raw instance
// pointer blocks the auto impls. Soundness assumes the plugin instance may be
// used from threads other than its creator — confirm against the ABI contract.
unsafe impl Send for ModuleInvoker {}
// SAFETY: NOTE(review): `invoke_wire` takes `&self`, so concurrent calls on one
// instance are possible. This assumes plugins tolerate that — confirm.
unsafe impl Sync for ModuleInvoker {}
271
/// Error from a raw plugin invocation, carrying a ready-to-display message.
#[derive(Debug)]
struct ModuleInvokeFailure {
    /// Human-readable description of the failure.
    message: String,
}
276
/// Raw reply bytes from a plugin invocation, tagged by how they must be decoded.
#[derive(Debug)]
enum ModuleInvokePayload {
    /// Msgpack-encoded `WireValue` (decoded by `decode_payload_to_wire`).
    Wire(Vec<u8>),
    /// Arrow IPC table bytes (decoded by `datatable_from_ipc_bytes`).
    TableArrowIpc(Vec<u8>),
}
282
283fn parse_module_schema(
284 vtable: &'static ModuleVTable,
285 instance: *mut c_void,
286 plugin_name: &str,
287) -> Result<ParsedModuleSchema> {
288 let get_schema_fn = vtable
289 .get_module_schema
290 .ok_or_else(|| ShapeError::RuntimeError {
291 message: format!(
292 "Plugin '{}' module capability has no get_module_schema()",
293 plugin_name
294 ),
295 location: None,
296 })?;
297
298 let mut out_ptr: *mut u8 = std::ptr::null_mut();
299 let mut out_len: usize = 0;
300 let status = unsafe { get_schema_fn(instance, &mut out_ptr, &mut out_len) };
301 if status != PluginError::Success as i32 {
302 return Err(ShapeError::RuntimeError {
303 message: format!(
304 "Plugin '{}' get_module_schema failed with status {}",
305 plugin_name, status
306 ),
307 location: None,
308 });
309 }
310
311 if out_ptr.is_null() || out_len == 0 {
312 return Err(ShapeError::RuntimeError {
313 message: format!(
314 "Plugin '{}' returned empty module schema payload",
315 plugin_name
316 ),
317 location: None,
318 });
319 }
320
321 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
322 if let Some(free_fn) = vtable.free_buffer {
323 unsafe { free_fn(out_ptr, out_len) };
324 }
325 let schema: AbiModuleSchema =
326 rmp_serde::from_slice(&bytes).map_err(|e| ShapeError::RuntimeError {
327 message: format!(
328 "Failed to decode module schema from '{}': {}",
329 plugin_name, e
330 ),
331 location: None,
332 })?;
333
334 let module_name = if schema.module_name.is_empty() {
335 plugin_name.to_string()
336 } else {
337 schema.module_name
338 };
339
340 let mut seen = HashSet::new();
341 let mut functions = Vec::new();
342 for f in schema.functions {
343 if f.name.is_empty() {
344 return Err(ShapeError::RuntimeError {
345 message: format!(
346 "Plugin '{}' module schema contains empty function name",
347 plugin_name
348 ),
349 location: None,
350 });
351 }
352 if !seen.insert(f.name.clone()) {
353 return Err(ShapeError::RuntimeError {
354 message: format!(
355 "Plugin '{}' module schema contains duplicate function '{}'",
356 plugin_name, f.name
357 ),
358 location: None,
359 });
360 }
361 functions.push(ParsedModuleFunction {
362 name: f.name,
363 description: f.description,
364 params: f.params,
365 return_type: f.return_type,
366 });
367 }
368
369 let artifacts = parse_module_artifacts(vtable, instance, plugin_name)?;
370
371 Ok(ParsedModuleSchema {
372 module_name,
373 functions,
374 artifacts,
375 })
376}
377
378fn parse_module_artifacts(
379 vtable: &'static ModuleVTable,
380 instance: *mut c_void,
381 plugin_name: &str,
382) -> Result<Vec<ParsedModuleArtifact>> {
383 let Some(get_artifacts_fn) = vtable.get_module_artifacts else {
384 return Ok(Vec::new());
385 };
386
387 let mut out_ptr: *mut u8 = std::ptr::null_mut();
388 let mut out_len: usize = 0;
389 let status = unsafe { get_artifacts_fn(instance, &mut out_ptr, &mut out_len) };
390 if status != PluginError::Success as i32 {
391 return Err(ShapeError::RuntimeError {
392 message: format!(
393 "Plugin '{}' get_module_artifacts failed with status {}",
394 plugin_name, status
395 ),
396 location: None,
397 });
398 }
399
400 if out_ptr.is_null() || out_len == 0 {
401 return Ok(Vec::new());
402 }
403
404 let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
405 if let Some(free_fn) = vtable.free_buffer {
406 unsafe { free_fn(out_ptr, out_len) };
407 }
408
409 let parsed = rmp_serde::from_slice::<Vec<ArtifactPayload>>(&bytes).map_err(|e| {
410 ShapeError::RuntimeError {
411 message: format!(
412 "Failed to decode module artifacts from '{}': {}",
413 plugin_name, e
414 ),
415 location: None,
416 }
417 })?;
418
419 let mut seen_paths = HashSet::new();
420 let mut artifacts = Vec::new();
421 for item in parsed {
422 if item.module_path.is_empty() {
423 return Err(ShapeError::RuntimeError {
424 message: format!(
425 "Plugin '{}' module artifacts contain empty module_path",
426 plugin_name
427 ),
428 location: None,
429 });
430 }
431 if !seen_paths.insert(item.module_path.clone()) {
432 return Err(ShapeError::RuntimeError {
433 message: format!(
434 "Plugin '{}' module artifacts contain duplicate module_path '{}'",
435 plugin_name, item.module_path
436 ),
437 location: None,
438 });
439 }
440 artifacts.push(ParsedModuleArtifact {
441 module_path: item.module_path,
442 source: item.source,
443 compiled: item.compiled,
444 });
445 }
446
447 Ok(artifacts)
448}
449
450fn decode_payload_to_wire(bytes: &[u8]) -> std::result::Result<WireValue, String> {
451 if bytes.is_empty() {
452 return Ok(WireValue::Null);
453 }
454 rmp_serde::from_slice::<WireValue>(bytes).map_err(|e| format!("invalid wire payload: {}", e))
455}
456
457fn normalize_invoke_result(
458 payload: WireValue,
459 module_name: &str,
460 function: &str,
461) -> std::result::Result<WireValue, String> {
462 match payload {
463 WireValue::Result { ok, value } => {
464 if ok {
465 Ok(*value)
466 } else {
467 Err(format!(
468 "Plugin '{}.{}' failed: {}",
469 module_name,
470 function,
471 format_wire_error_message(&value)
472 ))
473 }
474 }
475 other => Ok(other),
476 }
477}
478
479fn format_wire_error_message(value: &WireValue) -> String {
480 if let Some(rendered) = render_any_error_plain(value) {
481 return rendered;
482 }
483
484 match value {
485 WireValue::String(s) => s.clone(),
486 WireValue::Object(map) => {
487 if let Some(WireValue::String(message)) = map.get("message") {
488 message.clone()
489 } else {
490 format!("{value:?}")
491 }
492 }
493 _ => format!("{value:?}"),
494 }
495}