Skip to main content

uni_plugin_extism/
exports.rs

1//! Plugin-export readers — `manifest` and `register`.
2//!
3//! Every Extism plugin exposes two canonical-JSON control-surface exports:
4//!
5//! - **`manifest`** — returns the plugin's [`ExtismPluginManifest`]
6//!   (id, version, declared capabilities, resource limits, …). Read once
7//!   at load time to drive the capability intersection.
8//! - **`register`** — returns a [`RegistrationManifest`] enumerating every
9//!   qname the plugin provides plus its wire-level signature. Read after
//!   capability negotiation; each [`RegistrationEntry`] is converted to a
11//!   `ScalarPluginFn` / `AggregatePluginFn` / `ProcedurePlugin` adapter
12//!   downstream (M6a.1.5).
13//!
14//! This module splits parsing (pure, byte-slice in / value out) from the
15//! Extism-call wrapper (`read_*_export`). The split lets us unit-test JSON
16//! contracts without standing up a wasm plugin; the call-wrapper is
17//! exercised end-to-end by the M6a.1.7 example plugin.
18
19use serde::{Deserialize, Serialize};
20
21use crate::error::ExtismError;
22use crate::loader::ExtismPluginManifest;
23
24/// Wire-level scalar / aggregate / procedure signature shipped by a
25/// plugin's `register` export.
26///
27/// String-based for wire stability — plugins shouldn't have to encode
28/// `arrow_schema::DataType` JSON. Translation to internal `FnSignature`
29/// / `AggSignature` / `ProcedureSignature` happens at adapter
30/// construction time (M6a.1.5 / M6a.2).
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(deny_unknown_fields)]
33pub struct WireFnSignature {
34    /// Argument types in `WireArgType` form.
35    pub args: Vec<WireArgType>,
36    /// Return type.
37    pub returns: WireArgType,
38    /// Volatility — `"immutable"`, `"stable"`, or `"volatile"`. Default
39    /// `"immutable"`.
40    #[serde(default = "default_volatility")]
41    pub volatility: String,
42    /// Null handling — `"propagate"` (default) or `"user_handled"`.
43    #[serde(default = "default_null_handling")]
44    pub null_handling: String,
45}
46
/// Serde fallback for `WireFnSignature::volatility` when a plugin omits it.
fn default_volatility() -> String {
    String::from("immutable")
}
50
/// Serde fallback for `WireFnSignature::null_handling` when a plugin
/// omits it.
fn default_null_handling() -> String {
    String::from("propagate")
}
54
55/// Wire-level argument type shipped by a plugin.
56///
57/// Each variant maps to the corresponding `uni_plugin::traits::scalar::ArgType`
58/// at adapter time. Primitive types use the lowercase Arrow names
59/// (`"int64"`, `"float64"`, `"utf8"`, `"boolean"`, `"date64"`,
60/// `"timestamp_ms"`, `"binary"`, `"largebinary"`).
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
62#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
63pub enum WireArgType {
64    /// A native Arrow primitive — `kind: "primitive", arrow: "<name>"`.
65    Primitive {
66        /// Arrow primitive name.
67        arrow: String,
68    },
69    /// A `CypherValue` shipped via `LargeBinary` opaque transport.
70    CypherValue,
71    /// A fixed-size vector — `kind: "vector", len: N, element: "<arrow>"`.
72    Vector {
73        /// Number of elements per row.
74        len: usize,
75        /// Element type.
76        element: String,
77    },
78    /// Variadic — repeats `inner` zero or more times.
79    Variadic {
80        /// Inner element type.
81        inner: Box<WireArgType>,
82    },
83}
84
85/// One registration entry — a single qname plus its kind + signature.
86#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
87#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
88pub enum RegistrationEntry {
89    /// A Cypher scalar function.
90    Scalar {
91        /// Fully-qualified name (`"ns.fn"`).
92        qname: String,
93        /// Signature.
94        signature: WireFnSignature,
95    },
96    /// A Cypher aggregate function. Wire shape mirrors Scalar with the
97    /// state type carried as a separate `WireArgType`.
98    Aggregate {
99        /// Fully-qualified name.
100        qname: String,
101        /// Per-row input + return types.
102        signature: WireFnSignature,
103        /// State type — opaque to the wire; Adapter side wraps as
104        /// Arrow Binary.
105        state: WireArgType,
106    },
107    /// A Cypher procedure.
108    Procedure {
109        /// Fully-qualified name.
110        qname: String,
111        /// Argument signatures.
112        args: Vec<WireArgType>,
113        /// Yielded column types, in declared order.
114        yields: Vec<WireArgType>,
115        /// Mode — `"read"`, `"write"`, `"schema"`, or `"dbms"`. Default `"read"`.
116        #[serde(default = "default_proc_mode")]
117        mode: String,
118    },
119}
120
/// Serde fallback for a procedure entry's `mode` when a plugin omits it.
fn default_proc_mode() -> String {
    String::from("read")
}
124
125impl RegistrationEntry {
126    /// Fully-qualified name of this entry.
127    #[must_use]
128    pub fn qname(&self) -> &str {
129        match self {
130            Self::Scalar { qname, .. }
131            | Self::Aggregate { qname, .. }
132            | Self::Procedure { qname, .. } => qname,
133        }
134    }
135}
136
137/// Top-level `register` export payload.
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
139#[serde(deny_unknown_fields)]
140pub struct RegistrationManifest {
141    /// One entry per qname provided by the plugin.
142    pub entries: Vec<RegistrationEntry>,
143}
144
145/// Parse the bytes returned by a plugin's `manifest` export into an
146/// [`ExtismPluginManifest`].
147///
148/// # Errors
149///
150/// - [`ExtismError::ManifestInvalid`] if the JSON doesn't parse or
151///   doesn't match the expected shape.
152pub fn parse_manifest_json(bytes: &[u8]) -> Result<ExtismPluginManifest, ExtismError> {
153    serde_json::from_slice(bytes)
154        .map_err(|e| ExtismError::ManifestInvalid(format!("json parse: {e}")))
155}
156
157/// Parse the bytes returned by a plugin's `register` export into a
158/// [`RegistrationManifest`].
159///
160/// # Errors
161///
162/// - [`ExtismError::OutputDecode`] if the JSON doesn't parse or doesn't
163///   match the expected shape.
164pub fn parse_registration_json(bytes: &[u8]) -> Result<RegistrationManifest, ExtismError> {
165    serde_json::from_slice(bytes)
166        .map_err(|e| ExtismError::OutputDecode(format!("register json parse: {e}")))
167}
168
169/// Call a live plugin's `manifest` export and parse the response.
170///
171/// The `manifest` export takes no input and returns canonical-JSON
172/// matching [`ExtismPluginManifest`]. The plugin produces this once and
173/// caches internally; the host reads it once at load and never again.
174///
175/// # Errors
176///
177/// - [`ExtismError::InvalidPlugin`] if the export doesn't exist or the
178///   underlying Extism call fails.
179/// - [`ExtismError::ManifestInvalid`] if the returned JSON is malformed.
180pub fn read_manifest_export(
181    plugin: &mut extism::Plugin,
182) -> Result<ExtismPluginManifest, ExtismError> {
183    if !plugin.function_exists("manifest") {
184        return Err(ExtismError::InvalidPlugin(
185            "plugin does not export required `manifest` function".to_owned(),
186        ));
187    }
188    let bytes: &[u8] = plugin
189        .call("manifest", "")
190        .map_err(|e| ExtismError::InvalidPlugin(format!("call manifest: {e}")))?;
191    parse_manifest_json(bytes)
192}
193
194/// Call a live plugin's `register` export and parse the response.
195///
196/// The `register` export takes no input and returns canonical-JSON
197/// matching [`RegistrationManifest`]. The host reads this after
198/// capability negotiation and converts each entry into an adapter
199/// implementing the corresponding capability trait.
200///
201/// # Errors
202///
203/// - [`ExtismError::InvalidPlugin`] if the export doesn't exist or the
204///   underlying Extism call fails.
205/// - [`ExtismError::OutputDecode`] if the returned JSON is malformed.
206pub fn read_register_export(
207    plugin: &mut extism::Plugin,
208) -> Result<RegistrationManifest, ExtismError> {
209    if !plugin.function_exists("register") {
210        return Err(ExtismError::InvalidPlugin(
211            "plugin does not export required `register` function".to_owned(),
212        ));
213    }
214    let bytes: &[u8] = plugin
215        .call("register", "")
216        .map_err(|e| ExtismError::InvalidPlugin(format!("call register: {e}")))?;
217    parse_registration_json(bytes)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a `register` payload that must hold exactly one entry and
    /// returns that entry; fails the test with the full list otherwise.
    fn single_entry(json: &[u8]) -> RegistrationEntry {
        let mut entries = parse_registration_json(json).unwrap().entries;
        assert_eq!(entries.len(), 1, "expected exactly one entry, got: {entries:?}");
        entries.pop().expect("length checked above")
    }

    #[test]
    fn parses_minimal_manifest() {
        let json = br#"{"id":"a.b","version":"0.0.1"}"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.id, "a.b");
        assert_eq!(m.version, "0.0.1");
        assert!(m.capabilities.is_empty());
        assert!(m.fuel_per_call.is_none());
    }

    #[test]
    fn parses_manifest_with_resource_limits() {
        let json = br#"{
            "id": "a.b",
            "version": "0.0.1",
            "capabilities": ["filesystem"],
            "fuel_per_call": 1000,
            "memory_max_pages": 4,
            "timeout_ms": 500
        }"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.fuel_per_call, Some(1000));
        assert_eq!(m.memory_max_pages, Some(4));
        assert_eq!(m.timeout_ms, Some(500));
        assert!(m.declared_capability_set().contains_variant(
            &uni_plugin::Capability::Filesystem {
                read: vec![],
                write: vec![],
            }
        ));
    }

    #[test]
    fn rejects_unknown_manifest_field() {
        let json = br#"{"id":"a.b","version":"0.0.1","mystery":"surprise"}"#;
        let err = parse_manifest_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    #[test]
    fn rejects_malformed_manifest_json() {
        let err = parse_manifest_json(br#"{"id":"#).unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    #[test]
    fn parses_empty_registration() {
        let json = br#"{"entries":[]}"#;
        let r = parse_registration_json(json).unwrap();
        assert!(r.entries.is_empty());
    }

    #[test]
    fn rejects_registration_without_entries() {
        let err = parse_registration_json(b"{}").unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    #[test]
    fn parses_scalar_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "geo.haversine",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"}
                }
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Scalar { qname, signature } => {
                assert_eq!(qname, "geo.haversine");
                assert_eq!(signature.args.len(), 4);
                assert_eq!(signature.volatility, "immutable");
                assert_eq!(signature.null_handling, "propagate");
                assert!(matches!(
                    signature.returns,
                    WireArgType::Primitive { ref arrow } if arrow == "float64"
                ));
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    #[test]
    fn parses_explicit_volatility_and_null_handling() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "util.coalesce",
                "signature": {
                    "args": [{"kind":"cypher_value"}],
                    "returns": {"kind":"cypher_value"},
                    "volatility": "volatile",
                    "null_handling": "user_handled"
                }
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Scalar { signature, .. } => {
                assert_eq!(signature.volatility, "volatile");
                assert_eq!(signature.null_handling, "user_handled");
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    #[test]
    fn parses_aggregate_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "aggregate",
                "qname": "stats.weighted_mean",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"},
                    "volatility": "stable"
                },
                "state": {"kind":"primitive","arrow":"binary"}
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Aggregate {
                qname,
                signature,
                state,
            } => {
                assert_eq!(qname, "stats.weighted_mean");
                assert_eq!(signature.volatility, "stable");
                assert!(matches!(&state, WireArgType::Primitive { arrow } if arrow == "binary"));
            }
            other => panic!("expected Aggregate, got: {other:?}"),
        }
    }

    #[test]
    fn parses_procedure_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [{"kind":"primitive","arrow":"utf8"}],
                "yields": [
                    {"kind":"primitive","arrow":"int64"},
                    {"kind":"cypher_value"}
                ],
                "mode": "write"
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Procedure {
                qname,
                args,
                yields,
                mode,
            } => {
                assert_eq!(qname, "myorg.scan");
                assert_eq!(args.len(), 1);
                assert_eq!(yields.len(), 2);
                assert_eq!(mode, "write");
                assert!(matches!(yields[1], WireArgType::CypherValue));
            }
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    #[test]
    fn procedure_mode_defaults_to_read() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [],
                "yields": []
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Procedure { mode, .. } => assert_eq!(mode, "read"),
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    #[test]
    fn registration_entry_exposes_qname() {
        let signature = WireFnSignature {
            args: vec![],
            returns: WireArgType::CypherValue,
            volatility: "immutable".to_owned(),
            null_handling: "propagate".to_owned(),
        };
        let scalar = RegistrationEntry::Scalar {
            qname: "x.y".to_owned(),
            signature: signature.clone(),
        };
        let aggregate = RegistrationEntry::Aggregate {
            qname: "x.agg".to_owned(),
            signature,
            state: WireArgType::Primitive {
                arrow: "binary".to_owned(),
            },
        };
        let procedure = RegistrationEntry::Procedure {
            qname: "x.proc".to_owned(),
            args: vec![],
            yields: vec![],
            mode: "read".to_owned(),
        };
        assert_eq!(scalar.qname(), "x.y");
        assert_eq!(aggregate.qname(), "x.agg");
        assert_eq!(procedure.qname(), "x.proc");
    }

    #[test]
    fn rejects_unknown_registration_kind() {
        let json = br#"{"entries":[{"kind":"telegraphic","qname":"x"}]}"#;
        let err = parse_registration_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    #[test]
    fn rejects_unknown_signature_field() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "x.y",
                "signature": {
                    "args": [],
                    "returns": {"kind":"cypher_value"},
                    "determinism": "strict"
                }
            }]
        }"#;
        let err = parse_registration_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    #[test]
    fn rejects_unknown_argtype_field() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "x.y",
                "signature": {
                    "args": [{"kind":"primitive","arrow":"int64","nullable":true}],
                    "returns": {"kind":"cypher_value"}
                }
            }]
        }"#;
        let err = parse_registration_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    #[test]
    fn parses_vector_and_variadic_argtypes() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "vec.norm",
                "signature": {
                    "args": [
                        {"kind":"vector","len":128,"element":"float32"},
                        {"kind":"variadic","inner":{"kind":"primitive","arrow":"int64"}}
                    ],
                    "returns": {"kind":"primitive","arrow":"float32"}
                }
            }]
        }"#;
        match single_entry(json) {
            RegistrationEntry::Scalar { signature, .. } => {
                assert!(matches!(
                    signature.args[0],
                    WireArgType::Vector { len: 128, ref element } if element == "float32"
                ));
                assert!(matches!(
                    signature.args[1],
                    WireArgType::Variadic { ref inner } if matches!(
                        **inner,
                        WireArgType::Primitive { ref arrow } if arrow == "int64"
                    )
                ));
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    #[test]
    fn registration_round_trips_through_json() {
        let original = RegistrationManifest {
            entries: vec![
                RegistrationEntry::Scalar {
                    qname: "vec.norm".to_owned(),
                    signature: WireFnSignature {
                        args: vec![
                            WireArgType::Vector {
                                len: 3,
                                element: "float32".to_owned(),
                            },
                            WireArgType::Variadic {
                                inner: Box::new(WireArgType::CypherValue),
                            },
                        ],
                        returns: WireArgType::Primitive {
                            arrow: "float32".to_owned(),
                        },
                        volatility: "stable".to_owned(),
                        null_handling: "user_handled".to_owned(),
                    },
                },
                RegistrationEntry::Procedure {
                    qname: "myorg.scan".to_owned(),
                    args: vec![],
                    yields: vec![WireArgType::CypherValue],
                    mode: "write".to_owned(),
                },
            ],
        };
        let bytes = serde_json::to_vec(&original).unwrap();
        assert_eq!(parse_registration_json(&bytes).unwrap(), original);
    }
}