Skip to main content

shape_runtime/plugins/
module_capability.rs

1//! Base module capability wrapper (`shape.module`).
2//!
3//! Exposes plugin module functions as runtime `ModuleExports` exports.
4
5use serde_json::Value;
6use shape_abi_v1::{
7    ModuleInvokeResult, ModuleInvokeResultKind, ModuleSchema as AbiModuleSchema, ModuleVTable,
8    PluginError,
9};
10use shape_ast::error::{Result, ShapeError};
11use shape_wire::{WireValue, render_any_error_plain};
12use std::collections::HashSet;
13use std::ffi::c_void;
14
15#[derive(Debug, Clone, serde::Deserialize)]
16struct ArtifactPayload {
17    module_path: String,
18    #[serde(default)]
19    source: Option<String>,
20    #[serde(default)]
21    compiled: Option<Vec<u8>>,
22}
23
24/// Parsed schema for one module export.
25#[derive(Debug, Clone)]
26pub struct ParsedModuleFunction {
27    pub name: String,
28    pub description: String,
29    pub params: Vec<String>,
30    pub return_type: Option<String>,
31}
32
33/// Parsed bundled module artifact.
34#[derive(Debug, Clone)]
35pub struct ParsedModuleArtifact {
36    pub module_path: String,
37    pub source: Option<String>,
38    pub compiled: Option<Vec<u8>>,
39}
40
41/// Parsed `shape.module` schema.
42#[derive(Debug, Clone)]
43pub struct ParsedModuleSchema {
44    pub module_name: String,
45    pub functions: Vec<ParsedModuleFunction>,
46    pub artifacts: Vec<ParsedModuleArtifact>,
47}
48
49/// Wrapper around the `shape.module` capability.
50pub struct PluginModule {
51    name: String,
52    vtable: &'static ModuleVTable,
53    instance: *mut c_void,
54    schema: ParsedModuleSchema,
55}
56
57impl PluginModule {
58    /// Create a new module-capability wrapper from a plugin vtable.
59    pub fn new(name: String, vtable: &'static ModuleVTable, config: &Value) -> Result<Self> {
60        let config_bytes = rmp_serde::to_vec(config).map_err(|e| ShapeError::RuntimeError {
61            message: format!("Failed to serialize module config for '{}': {}", name, e),
62            location: None,
63        })?;
64
65        let init_fn = vtable.init.ok_or_else(|| ShapeError::RuntimeError {
66            message: format!("Plugin '{}' module capability has no init function", name),
67            location: None,
68        })?;
69
70        let instance = unsafe { init_fn(config_bytes.as_ptr(), config_bytes.len()) };
71        if instance.is_null() {
72            return Err(ShapeError::RuntimeError {
73                message: format!("Plugin '{}' module init returned null", name),
74                location: None,
75            });
76        }
77
78        let schema = parse_module_schema(vtable, instance, &name)?;
79
80        Ok(Self {
81            name,
82            vtable,
83            instance,
84            schema,
85        })
86    }
87
88    /// Plugin/module name.
89    pub fn name(&self) -> &str {
90        &self.name
91    }
92
93    /// Parsed module schema.
94    pub fn schema(&self) -> &ParsedModuleSchema {
95        &self.schema
96    }
97
98    /// Invoke one module export with shape-wire arguments/results.
99    pub fn invoke_wire(&self, function: &str, args: &[WireValue]) -> Result<WireValue> {
100        let invoker = ModuleInvoker {
101            name: self.name.clone(),
102            vtable: self.vtable,
103            instance: self.instance,
104        };
105        invoker
106            .invoke_wire(function, args)
107            .map_err(|message| ShapeError::RuntimeError {
108                message,
109                location: None,
110            })
111    }
112}
113
114impl Drop for PluginModule {
115    fn drop(&mut self) {
116        if let Some(drop_fn) = self.vtable.drop {
117            unsafe { drop_fn(self.instance) };
118        }
119    }
120}
121
122// SAFETY: access goes through plugin vtable calls that are required to be thread-safe.
123unsafe impl Send for PluginModule {}
124unsafe impl Sync for PluginModule {}
125
126struct ModuleInvoker {
127    name: String,
128    vtable: &'static ModuleVTable,
129    instance: *mut c_void,
130}
131
132impl ModuleInvoker {
133    fn invoke_wire(
134        &self,
135        function: &str,
136        args: &[WireValue],
137    ) -> std::result::Result<WireValue, String> {
138        let wire_bytes = rmp_serde::to_vec(args).map_err(|e| {
139            format!(
140                "Failed to serialize wire args for '{}.{}': {}",
141                self.name, function, e
142            )
143        })?;
144
145        match self
146            .invoke_with_args(function, &wire_bytes)
147            .map_err(|err| err.message)?
148        {
149            ModuleInvokePayload::Wire(bytes) => {
150                let payload = decode_payload_to_wire(&bytes).map_err(|e| {
151                    format!(
152                        "Failed to decode module result for '{}.{}': {}",
153                        self.name, function, e
154                    )
155                })?;
156                normalize_invoke_result(payload, &self.name, function)
157            }
158            ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
159                let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
160                    .map_err(|e| {
161                    format!(
162                        "Failed to decode table payload for '{}.{}': {}",
163                        self.name, function, e
164                    )
165                })?;
166                Ok(crate::wire_conversion::datatable_to_wire(&dt))
167            }
168        }
169    }
170
171    fn invoke_with_args(
172        &self,
173        function: &str,
174        args_bytes: &[u8],
175    ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
176        if let Some(invoke_ex_fn) = self.vtable.invoke_ex {
177            let mut out = ModuleInvokeResult::empty();
178            let status = unsafe {
179                invoke_ex_fn(
180                    self.instance,
181                    function.as_ptr(),
182                    function.len(),
183                    args_bytes.as_ptr(),
184                    args_bytes.len(),
185                    &mut out,
186                )
187            };
188
189            if status != PluginError::Success as i32 {
190                return Err(ModuleInvokeFailure {
191                    message: format!(
192                        "Plugin '{}' module invoke_ex failed for '{}': status {}",
193                        self.name, function, status
194                    ),
195                });
196            }
197
198            let payload = self.take_payload_bytes(out.payload_ptr, out.payload_len);
199            return match out.kind {
200                ModuleInvokeResultKind::WireValueMsgpack => Ok(ModuleInvokePayload::Wire(payload)),
201                ModuleInvokeResultKind::TableArrowIpc => {
202                    Ok(ModuleInvokePayload::TableArrowIpc(payload))
203                }
204            };
205        }
206
207        self.invoke_with_args_legacy(function, args_bytes)
208    }
209
210    fn invoke_with_args_legacy(
211        &self,
212        function: &str,
213        args_bytes: &[u8],
214    ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
215        let invoke_fn = self.vtable.invoke.ok_or_else(|| ModuleInvokeFailure {
216            message: format!(
217                "Plugin '{}' module capability does not implement invoke()",
218                self.name
219            ),
220        })?;
221
222        let mut out_ptr: *mut u8 = std::ptr::null_mut();
223        let mut out_len: usize = 0;
224        let status = unsafe {
225            invoke_fn(
226                self.instance,
227                function.as_ptr(),
228                function.len(),
229                args_bytes.as_ptr(),
230                args_bytes.len(),
231                &mut out_ptr,
232                &mut out_len,
233            )
234        };
235
236        if status != PluginError::Success as i32 {
237            return Err(ModuleInvokeFailure {
238                message: format!(
239                    "Plugin '{}' module invoke failed for '{}': status {}",
240                    self.name, function, status
241                ),
242            });
243        }
244
245        Ok(ModuleInvokePayload::Wire(
246            self.take_payload_bytes(out_ptr, out_len),
247        ))
248    }
249
250    fn take_payload_bytes(&self, ptr: *mut u8, len: usize) -> Vec<u8> {
251        if ptr.is_null() {
252            return Vec::new();
253        }
254
255        let bytes = if len == 0 {
256            Vec::new()
257        } else {
258            unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
259        };
260
261        if let Some(free_fn) = self.vtable.free_buffer {
262            unsafe { free_fn(ptr, len) };
263        }
264        bytes
265    }
266}
267
268// SAFETY: access goes through plugin vtable calls that are required to be thread-safe.
269unsafe impl Send for ModuleInvoker {}
270unsafe impl Sync for ModuleInvoker {}
271
272#[derive(Debug)]
273struct ModuleInvokeFailure {
274    message: String,
275}
276
277#[derive(Debug)]
278enum ModuleInvokePayload {
279    Wire(Vec<u8>),
280    TableArrowIpc(Vec<u8>),
281}
282
283fn parse_module_schema(
284    vtable: &'static ModuleVTable,
285    instance: *mut c_void,
286    plugin_name: &str,
287) -> Result<ParsedModuleSchema> {
288    let get_schema_fn = vtable
289        .get_module_schema
290        .ok_or_else(|| ShapeError::RuntimeError {
291            message: format!(
292                "Plugin '{}' module capability has no get_module_schema()",
293                plugin_name
294            ),
295            location: None,
296        })?;
297
298    let mut out_ptr: *mut u8 = std::ptr::null_mut();
299    let mut out_len: usize = 0;
300    let status = unsafe { get_schema_fn(instance, &mut out_ptr, &mut out_len) };
301    if status != PluginError::Success as i32 {
302        return Err(ShapeError::RuntimeError {
303            message: format!(
304                "Plugin '{}' get_module_schema failed with status {}",
305                plugin_name, status
306            ),
307            location: None,
308        });
309    }
310
311    if out_ptr.is_null() || out_len == 0 {
312        return Err(ShapeError::RuntimeError {
313            message: format!(
314                "Plugin '{}' returned empty module schema payload",
315                plugin_name
316            ),
317            location: None,
318        });
319    }
320
321    let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
322    if let Some(free_fn) = vtable.free_buffer {
323        unsafe { free_fn(out_ptr, out_len) };
324    }
325    let schema: AbiModuleSchema =
326        rmp_serde::from_slice(&bytes).map_err(|e| ShapeError::RuntimeError {
327            message: format!(
328                "Failed to decode module schema from '{}': {}",
329                plugin_name, e
330            ),
331            location: None,
332        })?;
333
334    let module_name = if schema.module_name.is_empty() {
335        plugin_name.to_string()
336    } else {
337        schema.module_name
338    };
339
340    let mut seen = HashSet::new();
341    let mut functions = Vec::new();
342    for f in schema.functions {
343        if f.name.is_empty() {
344            return Err(ShapeError::RuntimeError {
345                message: format!(
346                    "Plugin '{}' module schema contains empty function name",
347                    plugin_name
348                ),
349                location: None,
350            });
351        }
352        if !seen.insert(f.name.clone()) {
353            return Err(ShapeError::RuntimeError {
354                message: format!(
355                    "Plugin '{}' module schema contains duplicate function '{}'",
356                    plugin_name, f.name
357                ),
358                location: None,
359            });
360        }
361        functions.push(ParsedModuleFunction {
362            name: f.name,
363            description: f.description,
364            params: f.params,
365            return_type: f.return_type,
366        });
367    }
368
369    let artifacts = parse_module_artifacts(vtable, instance, plugin_name)?;
370
371    Ok(ParsedModuleSchema {
372        module_name,
373        functions,
374        artifacts,
375    })
376}
377
378fn parse_module_artifacts(
379    vtable: &'static ModuleVTable,
380    instance: *mut c_void,
381    plugin_name: &str,
382) -> Result<Vec<ParsedModuleArtifact>> {
383    let Some(get_artifacts_fn) = vtable.get_module_artifacts else {
384        return Ok(Vec::new());
385    };
386
387    let mut out_ptr: *mut u8 = std::ptr::null_mut();
388    let mut out_len: usize = 0;
389    let status = unsafe { get_artifacts_fn(instance, &mut out_ptr, &mut out_len) };
390    if status != PluginError::Success as i32 {
391        return Err(ShapeError::RuntimeError {
392            message: format!(
393                "Plugin '{}' get_module_artifacts failed with status {}",
394                plugin_name, status
395            ),
396            location: None,
397        });
398    }
399
400    if out_ptr.is_null() || out_len == 0 {
401        return Ok(Vec::new());
402    }
403
404    let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
405    if let Some(free_fn) = vtable.free_buffer {
406        unsafe { free_fn(out_ptr, out_len) };
407    }
408
409    let parsed = rmp_serde::from_slice::<Vec<ArtifactPayload>>(&bytes).map_err(|e| {
410        ShapeError::RuntimeError {
411            message: format!(
412                "Failed to decode module artifacts from '{}': {}",
413                plugin_name, e
414            ),
415            location: None,
416        }
417    })?;
418
419    let mut seen_paths = HashSet::new();
420    let mut artifacts = Vec::new();
421    for item in parsed {
422        if item.module_path.is_empty() {
423            return Err(ShapeError::RuntimeError {
424                message: format!(
425                    "Plugin '{}' module artifacts contain empty module_path",
426                    plugin_name
427                ),
428                location: None,
429            });
430        }
431        if !seen_paths.insert(item.module_path.clone()) {
432            return Err(ShapeError::RuntimeError {
433                message: format!(
434                    "Plugin '{}' module artifacts contain duplicate module_path '{}'",
435                    plugin_name, item.module_path
436                ),
437                location: None,
438            });
439        }
440        artifacts.push(ParsedModuleArtifact {
441            module_path: item.module_path,
442            source: item.source,
443            compiled: item.compiled,
444        });
445    }
446
447    Ok(artifacts)
448}
449
450fn decode_payload_to_wire(bytes: &[u8]) -> std::result::Result<WireValue, String> {
451    if bytes.is_empty() {
452        return Ok(WireValue::Null);
453    }
454    rmp_serde::from_slice::<WireValue>(bytes).map_err(|e| format!("invalid wire payload: {}", e))
455}
456
457fn normalize_invoke_result(
458    payload: WireValue,
459    module_name: &str,
460    function: &str,
461) -> std::result::Result<WireValue, String> {
462    match payload {
463        WireValue::Result { ok, value } => {
464            if ok {
465                Ok(*value)
466            } else {
467                Err(format!(
468                    "Plugin '{}.{}' failed: {}",
469                    module_name,
470                    function,
471                    format_wire_error_message(&value)
472                ))
473            }
474        }
475        other => Ok(other),
476    }
477}
478
479fn format_wire_error_message(value: &WireValue) -> String {
480    if let Some(rendered) = render_any_error_plain(value) {
481        return rendered;
482    }
483
484    match value {
485        WireValue::String(s) => s.clone(),
486        WireValue::Object(map) => {
487            if let Some(WireValue::String(message)) = map.get("message") {
488                message.clone()
489            } else {
490                format!("{value:?}")
491            }
492        }
493        _ => format!("{value:?}"),
494    }
495}