Skip to main content

shape_runtime/plugins/
module_capability.rs

//! Base module capability wrapper (`shape.module`).
//!
//! Exposes a plugin's module functions, and any Shape artifacts it bundles,
//! as runtime `ModuleExports` entries.
4
5use serde_json::Value;
6use shape_abi_v1::{
7    ModuleInvokeResult, ModuleInvokeResultKind, ModuleSchema as AbiModuleSchema, ModuleVTable,
8    PluginError,
9};
10use shape_ast::error::{Result, ShapeError};
11use shape_value::ValueWord;
12use shape_wire::{WireValue, render_any_error_plain};
13use std::collections::HashSet;
14use std::ffi::c_void;
15use std::sync::Arc;
16
17#[derive(Debug, Clone, serde::Deserialize)]
18struct ArtifactPayload {
19    module_path: String,
20    #[serde(default)]
21    source: Option<String>,
22    #[serde(default)]
23    compiled: Option<Vec<u8>>,
24}
25
/// Parsed schema for one module export.
#[derive(Clone, Debug)]
pub struct ParsedModuleFunction {
    /// Export name used for dispatch.
    pub name: String,
    /// Human-readable description reported by the plugin.
    pub description: String,
    /// Parameter type names, in positional order.
    pub params: Vec<String>,
    /// Return type name, when the plugin declares one.
    pub return_type: Option<String>,
}
34
/// Parsed bundled module artifact.
#[derive(Clone, Debug)]
pub struct ParsedModuleArtifact {
    /// Module path the artifact is registered under.
    pub module_path: String,
    /// Shape source text, if bundled.
    pub source: Option<String>,
    /// Compiled bytes, if bundled.
    pub compiled: Option<Vec<u8>>,
}
42
43/// Parsed `shape.module` schema.
44#[derive(Debug, Clone)]
45pub struct ParsedModuleSchema {
46    pub module_name: String,
47    pub functions: Vec<ParsedModuleFunction>,
48    pub artifacts: Vec<ParsedModuleArtifact>,
49}
50
51/// Wrapper around the `shape.module` capability.
52pub struct PluginModule {
53    name: String,
54    vtable: &'static ModuleVTable,
55    instance: *mut c_void,
56    schema: ParsedModuleSchema,
57}
58
59impl PluginModule {
60    /// Create a new module-capability wrapper from a plugin vtable.
61    pub fn new(name: String, vtable: &'static ModuleVTable, config: &Value) -> Result<Self> {
62        let config_bytes = rmp_serde::to_vec(config).map_err(|e| ShapeError::RuntimeError {
63            message: format!("Failed to serialize module config for '{}': {}", name, e),
64            location: None,
65        })?;
66
67        let init_fn = vtable.init.ok_or_else(|| ShapeError::RuntimeError {
68            message: format!("Plugin '{}' module capability has no init function", name),
69            location: None,
70        })?;
71
72        let instance = unsafe { init_fn(config_bytes.as_ptr(), config_bytes.len()) };
73        if instance.is_null() {
74            return Err(ShapeError::RuntimeError {
75                message: format!("Plugin '{}' module init returned null", name),
76                location: None,
77            });
78        }
79
80        let schema = parse_module_schema(vtable, instance, &name)?;
81
82        Ok(Self {
83            name,
84            vtable,
85            instance,
86            schema,
87        })
88    }
89
90    /// Plugin/module name.
91    pub fn name(&self) -> &str {
92        &self.name
93    }
94
95    /// Parsed module schema.
96    pub fn schema(&self) -> &ParsedModuleSchema {
97        &self.schema
98    }
99
100    /// Build a runtime `ModuleExports` wrapper for VM module dispatch.
101    pub fn to_module_exports(&self) -> crate::module_exports::ModuleExports {
102        use crate::module_exports::{ModuleExports, ModuleFunction, ModuleParam};
103
104        let mut module = ModuleExports::new(self.schema.module_name.clone());
105        module.description = format!("Plugin module exported by '{}'", self.name);
106
107        let invoker = Arc::new(ModuleInvoker {
108            name: self.name.clone(),
109            vtable: self.vtable,
110            instance: self.instance,
111        });
112
113        for function in &self.schema.functions {
114            let fn_name = function.name.clone();
115            let invoker_ref = Arc::clone(&invoker);
116
117            let schema = ModuleFunction {
118                description: function.description.clone(),
119                params: function
120                    .params
121                    .iter()
122                    .enumerate()
123                    .map(|(idx, ty)| ModuleParam {
124                        name: format!("arg{}", idx),
125                        type_name: ty.clone(),
126                        required: true,
127                        description: String::new(),
128                        ..Default::default()
129                    })
130                    .collect(),
131                return_type: function.return_type.clone(),
132            };
133
134            let fn_name_for_closure = fn_name.clone();
135            module.add_function_with_schema(
136                fn_name,
137                move |args: &[ValueWord], _ctx: &crate::module_exports::ModuleContext| {
138                    invoker_ref.invoke_nb(&fn_name_for_closure, args)
139                },
140                schema,
141            );
142        }
143
144        for artifact in &self.schema.artifacts {
145            module.add_shape_artifact(
146                artifact.module_path.clone(),
147                artifact.source.clone(),
148                artifact.compiled.clone(),
149            );
150        }
151
152        module
153    }
154
155    /// Invoke one module export with shape-wire arguments/results.
156    pub fn invoke_wire(&self, function: &str, args: &[WireValue]) -> Result<WireValue> {
157        let invoker = ModuleInvoker {
158            name: self.name.clone(),
159            vtable: self.vtable,
160            instance: self.instance,
161        };
162        invoker
163            .invoke_wire(function, args)
164            .map_err(|message| ShapeError::RuntimeError {
165                message,
166                location: None,
167            })
168    }
169
170    /// Invoke one module export with ValueWord arguments/results.
171    ///
172    /// This is the primary host-side call path for runtime/LSP internals and
173    /// uses `shape-wire` payloads end-to-end.
174    pub fn invoke_nb(&self, function: &str, args: &[ValueWord]) -> Result<ValueWord> {
175        let invoker = ModuleInvoker {
176            name: self.name.clone(),
177            vtable: self.vtable,
178            instance: self.instance,
179        };
180        invoker
181            .invoke_nb(function, args)
182            .map_err(|message| ShapeError::RuntimeError {
183                message,
184                location: None,
185            })
186    }
187}
188
189impl Drop for PluginModule {
190    fn drop(&mut self) {
191        if let Some(drop_fn) = self.vtable.drop {
192            unsafe { drop_fn(self.instance) };
193        }
194    }
195}
196
197// SAFETY: access goes through plugin vtable calls that are required to be thread-safe.
198unsafe impl Send for PluginModule {}
199unsafe impl Sync for PluginModule {}
200
201struct ModuleInvoker {
202    name: String,
203    vtable: &'static ModuleVTable,
204    instance: *mut c_void,
205}
206
207impl ModuleInvoker {
208    fn invoke_nb(
209        &self,
210        function: &str,
211        args: &[ValueWord],
212    ) -> std::result::Result<ValueWord, String> {
213        let ctx = crate::Context::new_empty();
214        let wire_args: Vec<WireValue> = args
215            .iter()
216            .map(|nb| crate::wire_conversion::nb_to_wire(nb, &ctx))
217            .collect();
218
219        let wire_bytes = rmp_serde::to_vec(&wire_args).map_err(|e| {
220            format!(
221                "Failed to serialize wire args for '{}.{}': {}",
222                self.name, function, e
223            )
224        })?;
225
226        match self
227            .invoke_with_args(function, &wire_bytes)
228            .map_err(|err| err.message)?
229        {
230            ModuleInvokePayload::Wire(bytes) => {
231                let payload = decode_payload_to_wire(&bytes).map_err(|e| {
232                    format!(
233                        "Failed to decode module result for '{}.{}': {}",
234                        self.name, function, e
235                    )
236                })?;
237                let normalized = normalize_invoke_result(payload, &self.name, function)?;
238                Ok(crate::wire_conversion::wire_to_nb(&normalized))
239            }
240            ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
241                let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
242                    .map_err(|e| {
243                    format!(
244                        "Failed to decode table payload for '{}.{}': {}",
245                        self.name, function, e
246                    )
247                })?;
248                Ok(ValueWord::from_datatable(Arc::new(dt)))
249            }
250        }
251    }
252
253    fn invoke_wire(
254        &self,
255        function: &str,
256        args: &[WireValue],
257    ) -> std::result::Result<WireValue, String> {
258        let wire_bytes = rmp_serde::to_vec(args).map_err(|e| {
259            format!(
260                "Failed to serialize wire args for '{}.{}': {}",
261                self.name, function, e
262            )
263        })?;
264
265        match self
266            .invoke_with_args(function, &wire_bytes)
267            .map_err(|err| err.message)?
268        {
269            ModuleInvokePayload::Wire(bytes) => {
270                let payload = decode_payload_to_wire(&bytes).map_err(|e| {
271                    format!(
272                        "Failed to decode module result for '{}.{}': {}",
273                        self.name, function, e
274                    )
275                })?;
276                normalize_invoke_result(payload, &self.name, function)
277            }
278            ModuleInvokePayload::TableArrowIpc(ipc_bytes) => {
279                let dt = crate::wire_conversion::datatable_from_ipc_bytes(&ipc_bytes, None, None)
280                    .map_err(|e| {
281                    format!(
282                        "Failed to decode table payload for '{}.{}': {}",
283                        self.name, function, e
284                    )
285                })?;
286                let nb = ValueWord::from_datatable(Arc::new(dt));
287                let ctx = crate::Context::new_empty();
288                Ok(crate::wire_conversion::nb_to_wire(&nb, &ctx))
289            }
290        }
291    }
292
293    fn invoke_with_args(
294        &self,
295        function: &str,
296        args_bytes: &[u8],
297    ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
298        if let Some(invoke_ex_fn) = self.vtable.invoke_ex {
299            let mut out = ModuleInvokeResult::empty();
300            let status = unsafe {
301                invoke_ex_fn(
302                    self.instance,
303                    function.as_ptr(),
304                    function.len(),
305                    args_bytes.as_ptr(),
306                    args_bytes.len(),
307                    &mut out,
308                )
309            };
310
311            if status != PluginError::Success as i32 {
312                return Err(ModuleInvokeFailure {
313                    message: format!(
314                        "Plugin '{}' module invoke_ex failed for '{}': status {}",
315                        self.name, function, status
316                    ),
317                });
318            }
319
320            let payload = self.take_payload_bytes(out.payload_ptr, out.payload_len);
321            return match out.kind {
322                ModuleInvokeResultKind::WireValueMsgpack => Ok(ModuleInvokePayload::Wire(payload)),
323                ModuleInvokeResultKind::TableArrowIpc => {
324                    Ok(ModuleInvokePayload::TableArrowIpc(payload))
325                }
326            };
327        }
328
329        self.invoke_with_args_legacy(function, args_bytes)
330    }
331
332    fn invoke_with_args_legacy(
333        &self,
334        function: &str,
335        args_bytes: &[u8],
336    ) -> std::result::Result<ModuleInvokePayload, ModuleInvokeFailure> {
337        let invoke_fn = self.vtable.invoke.ok_or_else(|| ModuleInvokeFailure {
338            message: format!(
339                "Plugin '{}' module capability does not implement invoke()",
340                self.name
341            ),
342        })?;
343
344        let mut out_ptr: *mut u8 = std::ptr::null_mut();
345        let mut out_len: usize = 0;
346        let status = unsafe {
347            invoke_fn(
348                self.instance,
349                function.as_ptr(),
350                function.len(),
351                args_bytes.as_ptr(),
352                args_bytes.len(),
353                &mut out_ptr,
354                &mut out_len,
355            )
356        };
357
358        if status != PluginError::Success as i32 {
359            return Err(ModuleInvokeFailure {
360                message: format!(
361                    "Plugin '{}' module invoke failed for '{}': status {}",
362                    self.name, function, status
363                ),
364            });
365        }
366
367        Ok(ModuleInvokePayload::Wire(
368            self.take_payload_bytes(out_ptr, out_len),
369        ))
370    }
371
372    fn take_payload_bytes(&self, ptr: *mut u8, len: usize) -> Vec<u8> {
373        if ptr.is_null() {
374            return Vec::new();
375        }
376
377        let bytes = if len == 0 {
378            Vec::new()
379        } else {
380            unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
381        };
382
383        if let Some(free_fn) = self.vtable.free_buffer {
384            unsafe { free_fn(ptr, len) };
385        }
386        bytes
387    }
388}
389
390// SAFETY: access goes through plugin vtable calls that are required to be thread-safe.
391unsafe impl Send for ModuleInvoker {}
392unsafe impl Sync for ModuleInvoker {}
393
/// Failure from a plugin invoke call, carrying a ready-to-report message.
#[derive(Debug)]
struct ModuleInvokeFailure {
    /// Human-readable description of the failure.
    message: String,
}
398
/// Raw bytes returned by a plugin invoke call, tagged by encoding.
#[derive(Debug)]
enum ModuleInvokePayload {
    /// MessagePack-encoded `WireValue`.
    Wire(Vec<u8>),
    /// Arrow IPC-encoded table.
    TableArrowIpc(Vec<u8>),
}
404
405fn parse_module_schema(
406    vtable: &'static ModuleVTable,
407    instance: *mut c_void,
408    plugin_name: &str,
409) -> Result<ParsedModuleSchema> {
410    let get_schema_fn = vtable
411        .get_module_schema
412        .ok_or_else(|| ShapeError::RuntimeError {
413            message: format!(
414                "Plugin '{}' module capability has no get_module_schema()",
415                plugin_name
416            ),
417            location: None,
418        })?;
419
420    let mut out_ptr: *mut u8 = std::ptr::null_mut();
421    let mut out_len: usize = 0;
422    let status = unsafe { get_schema_fn(instance, &mut out_ptr, &mut out_len) };
423    if status != PluginError::Success as i32 {
424        return Err(ShapeError::RuntimeError {
425            message: format!(
426                "Plugin '{}' get_module_schema failed with status {}",
427                plugin_name, status
428            ),
429            location: None,
430        });
431    }
432
433    if out_ptr.is_null() || out_len == 0 {
434        return Err(ShapeError::RuntimeError {
435            message: format!(
436                "Plugin '{}' returned empty module schema payload",
437                plugin_name
438            ),
439            location: None,
440        });
441    }
442
443    let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
444    if let Some(free_fn) = vtable.free_buffer {
445        unsafe { free_fn(out_ptr, out_len) };
446    }
447    let schema: AbiModuleSchema =
448        rmp_serde::from_slice(&bytes).map_err(|e| ShapeError::RuntimeError {
449            message: format!(
450                "Failed to decode module schema from '{}': {}",
451                plugin_name, e
452            ),
453            location: None,
454        })?;
455
456    let module_name = if schema.module_name.is_empty() {
457        plugin_name.to_string()
458    } else {
459        schema.module_name
460    };
461
462    let mut seen = HashSet::new();
463    let mut functions = Vec::new();
464    for f in schema.functions {
465        if f.name.is_empty() {
466            return Err(ShapeError::RuntimeError {
467                message: format!(
468                    "Plugin '{}' module schema contains empty function name",
469                    plugin_name
470                ),
471                location: None,
472            });
473        }
474        if !seen.insert(f.name.clone()) {
475            return Err(ShapeError::RuntimeError {
476                message: format!(
477                    "Plugin '{}' module schema contains duplicate function '{}'",
478                    plugin_name, f.name
479                ),
480                location: None,
481            });
482        }
483        functions.push(ParsedModuleFunction {
484            name: f.name,
485            description: f.description,
486            params: f.params,
487            return_type: f.return_type,
488        });
489    }
490
491    let artifacts = parse_module_artifacts(vtable, instance, plugin_name)?;
492
493    Ok(ParsedModuleSchema {
494        module_name,
495        functions,
496        artifacts,
497    })
498}
499
500fn parse_module_artifacts(
501    vtable: &'static ModuleVTable,
502    instance: *mut c_void,
503    plugin_name: &str,
504) -> Result<Vec<ParsedModuleArtifact>> {
505    let Some(get_artifacts_fn) = vtable.get_module_artifacts else {
506        return Ok(Vec::new());
507    };
508
509    let mut out_ptr: *mut u8 = std::ptr::null_mut();
510    let mut out_len: usize = 0;
511    let status = unsafe { get_artifacts_fn(instance, &mut out_ptr, &mut out_len) };
512    if status != PluginError::Success as i32 {
513        return Err(ShapeError::RuntimeError {
514            message: format!(
515                "Plugin '{}' get_module_artifacts failed with status {}",
516                plugin_name, status
517            ),
518            location: None,
519        });
520    }
521
522    if out_ptr.is_null() || out_len == 0 {
523        return Ok(Vec::new());
524    }
525
526    let bytes = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
527    if let Some(free_fn) = vtable.free_buffer {
528        unsafe { free_fn(out_ptr, out_len) };
529    }
530
531    let parsed = rmp_serde::from_slice::<Vec<ArtifactPayload>>(&bytes).map_err(|e| {
532        ShapeError::RuntimeError {
533            message: format!(
534                "Failed to decode module artifacts from '{}': {}",
535                plugin_name, e
536            ),
537            location: None,
538        }
539    })?;
540
541    let mut seen_paths = HashSet::new();
542    let mut artifacts = Vec::new();
543    for item in parsed {
544        if item.module_path.is_empty() {
545            return Err(ShapeError::RuntimeError {
546                message: format!(
547                    "Plugin '{}' module artifacts contain empty module_path",
548                    plugin_name
549                ),
550                location: None,
551            });
552        }
553        if !seen_paths.insert(item.module_path.clone()) {
554            return Err(ShapeError::RuntimeError {
555                message: format!(
556                    "Plugin '{}' module artifacts contain duplicate module_path '{}'",
557                    plugin_name, item.module_path
558                ),
559                location: None,
560            });
561        }
562        artifacts.push(ParsedModuleArtifact {
563            module_path: item.module_path,
564            source: item.source,
565            compiled: item.compiled,
566        });
567    }
568
569    Ok(artifacts)
570}
571
572fn decode_payload_to_wire(bytes: &[u8]) -> std::result::Result<WireValue, String> {
573    if bytes.is_empty() {
574        return Ok(WireValue::Null);
575    }
576    rmp_serde::from_slice::<WireValue>(bytes).map_err(|e| format!("invalid wire payload: {}", e))
577}
578
579fn normalize_invoke_result(
580    payload: WireValue,
581    module_name: &str,
582    function: &str,
583) -> std::result::Result<WireValue, String> {
584    match payload {
585        WireValue::Result { ok, value } => {
586            if ok {
587                Ok(*value)
588            } else {
589                Err(format!(
590                    "Plugin '{}.{}' failed: {}",
591                    module_name,
592                    function,
593                    format_wire_error_message(&value)
594                ))
595            }
596        }
597        other => Ok(other),
598    }
599}
600
601fn format_wire_error_message(value: &WireValue) -> String {
602    if let Some(rendered) = render_any_error_plain(value) {
603        return rendered;
604    }
605
606    match value {
607        WireValue::String(s) => s.clone(),
608        WireValue::Object(map) => {
609            if let Some(WireValue::String(message)) = map.get("message") {
610                message.clone()
611            } else {
612                format!("{value:?}")
613            }
614        }
615        _ => format!("{value:?}"),
616    }
617}