1use serde::{Deserialize, Serialize};
20
21use crate::error::ExtismError;
22use crate::loader::ExtismPluginManifest;
23
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(deny_unknown_fields)]
33pub struct WireFnSignature {
34 pub args: Vec<WireArgType>,
36 pub returns: WireArgType,
38 #[serde(default = "default_volatility")]
41 pub volatility: String,
42 #[serde(default = "default_null_handling")]
44 pub null_handling: String,
45}
46
/// Serde default for [`WireFnSignature::volatility`] when the field is absent.
fn default_volatility() -> String {
    String::from("immutable")
}
50
/// Serde default for [`WireFnSignature::null_handling`] when the field is absent.
fn default_null_handling() -> String {
    String::from("propagate")
}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
62#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
63pub enum WireArgType {
64 Primitive {
66 arrow: String,
68 },
69 CypherValue,
71 Vector {
73 len: usize,
75 element: String,
77 },
78 Variadic {
80 inner: Box<WireArgType>,
82 },
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
87#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
88pub enum RegistrationEntry {
89 Scalar {
91 qname: String,
93 signature: WireFnSignature,
95 },
96 Aggregate {
99 qname: String,
101 signature: WireFnSignature,
103 state: WireArgType,
106 },
107 Procedure {
109 qname: String,
111 args: Vec<WireArgType>,
113 yields: Vec<WireArgType>,
115 #[serde(default = "default_proc_mode")]
117 mode: String,
118 },
119}
120
/// Serde default for the `mode` field of [`RegistrationEntry::Procedure`].
fn default_proc_mode() -> String {
    String::from("read")
}
124
125impl RegistrationEntry {
126 #[must_use]
128 pub fn qname(&self) -> &str {
129 match self {
130 Self::Scalar { qname, .. }
131 | Self::Aggregate { qname, .. }
132 | Self::Procedure { qname, .. } => qname,
133 }
134 }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
139#[serde(deny_unknown_fields)]
140pub struct RegistrationManifest {
141 pub entries: Vec<RegistrationEntry>,
143}
144
145pub fn parse_manifest_json(bytes: &[u8]) -> Result<ExtismPluginManifest, ExtismError> {
153 serde_json::from_slice(bytes)
154 .map_err(|e| ExtismError::ManifestInvalid(format!("json parse: {e}")))
155}
156
157pub fn parse_registration_json(bytes: &[u8]) -> Result<RegistrationManifest, ExtismError> {
165 serde_json::from_slice(bytes)
166 .map_err(|e| ExtismError::OutputDecode(format!("register json parse: {e}")))
167}
168
169pub fn read_manifest_export(
181 plugin: &mut extism::Plugin,
182) -> Result<ExtismPluginManifest, ExtismError> {
183 if !plugin.function_exists("manifest") {
184 return Err(ExtismError::InvalidPlugin(
185 "plugin does not export required `manifest` function".to_owned(),
186 ));
187 }
188 let bytes: &[u8] = plugin
189 .call("manifest", "")
190 .map_err(|e| ExtismError::InvalidPlugin(format!("call manifest: {e}")))?;
191 parse_manifest_json(bytes)
192}
193
194pub fn read_register_export(
207 plugin: &mut extism::Plugin,
208) -> Result<RegistrationManifest, ExtismError> {
209 if !plugin.function_exists("register") {
210 return Err(ExtismError::InvalidPlugin(
211 "plugin does not export required `register` function".to_owned(),
212 ));
213 }
214 let bytes: &[u8] = plugin
215 .call("register", "")
216 .map_err(|e| ExtismError::InvalidPlugin(format!("call register: {e}")))?;
217 parse_registration_json(bytes)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json` is rejected by [`parse_registration_json`] with an
    /// `OutputDecode` error, reporting what actually happened otherwise.
    fn assert_registration_rejected(json: &[u8]) {
        match parse_registration_json(json) {
            Err(ExtismError::OutputDecode(_)) => {}
            Err(other) => panic!("expected OutputDecode, got: {other:?}"),
            Ok(parsed) => panic!("expected rejection, parsed: {parsed:?}"),
        }
    }

    #[test]
    fn parses_minimal_manifest() {
        let json = br#"{"id":"a.b","version":"0.0.1"}"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.id, "a.b");
        assert_eq!(m.version, "0.0.1");
        assert!(m.capabilities.is_empty());
        assert!(m.fuel_per_call.is_none());
    }

    #[test]
    fn parses_manifest_with_resource_limits() {
        let json = br#"{
            "id": "a.b",
            "version": "0.0.1",
            "capabilities": ["filesystem"],
            "fuel_per_call": 1000,
            "memory_max_pages": 4,
            "timeout_ms": 500
        }"#;
        let m = parse_manifest_json(json).unwrap();
        assert_eq!(m.fuel_per_call, Some(1000));
        assert_eq!(m.memory_max_pages, Some(4));
        assert_eq!(m.timeout_ms, Some(500));
        assert!(m.declared_capability_set().contains_variant(
            &uni_plugin::Capability::Filesystem {
                read: vec![],
                write: vec![],
            }
        ));
    }

    #[test]
    fn rejects_unknown_manifest_field() {
        let json = br#"{"id":"a.b","version":"0.0.1","mystery":"surprise"}"#;
        let err = parse_manifest_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    #[test]
    fn rejects_malformed_manifest_json() {
        let err = parse_manifest_json(b"{not json").unwrap_err();
        assert!(matches!(err, ExtismError::ManifestInvalid(_)));
    }

    #[test]
    fn parses_empty_registration() {
        let json = br#"{"entries":[]}"#;
        let r = parse_registration_json(json).unwrap();
        assert!(r.entries.is_empty());
    }

    #[test]
    fn parses_scalar_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "geo.haversine",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"}
                }
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        assert_eq!(r.entries.len(), 1);
        match &r.entries[0] {
            RegistrationEntry::Scalar { qname, signature } => {
                assert_eq!(qname, "geo.haversine");
                assert_eq!(signature.args.len(), 4);
                assert_eq!(signature.volatility, "immutable");
                assert_eq!(signature.null_handling, "propagate");
                assert!(matches!(
                    signature.returns,
                    WireArgType::Primitive { ref arrow } if arrow == "float64"
                ));
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    #[test]
    fn parses_aggregate_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "aggregate",
                "qname": "stats.weighted_mean",
                "signature": {
                    "args": [
                        {"kind":"primitive","arrow":"float64"},
                        {"kind":"primitive","arrow":"float64"}
                    ],
                    "returns": {"kind":"primitive","arrow":"float64"},
                    "volatility": "stable"
                },
                "state": {"kind":"primitive","arrow":"binary"}
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Aggregate {
                qname,
                signature,
                state,
            } => {
                assert_eq!(qname, "stats.weighted_mean");
                assert_eq!(signature.volatility, "stable");
                assert!(matches!(state, WireArgType::Primitive { arrow } if arrow == "binary"));
            }
            other => panic!("expected Aggregate, got: {other:?}"),
        }
    }

    #[test]
    fn parses_procedure_registration_entry() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [{"kind":"primitive","arrow":"utf8"}],
                "yields": [
                    {"kind":"primitive","arrow":"int64"},
                    {"kind":"cypher_value"}
                ],
                "mode": "write"
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Procedure {
                qname,
                args,
                yields,
                mode,
            } => {
                assert_eq!(qname, "myorg.scan");
                assert_eq!(args.len(), 1);
                assert_eq!(yields.len(), 2);
                assert_eq!(mode, "write");
                assert!(matches!(yields[1], WireArgType::CypherValue));
            }
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    #[test]
    fn procedure_mode_defaults_to_read() {
        let json = br#"{
            "entries": [{
                "kind": "procedure",
                "qname": "myorg.scan",
                "args": [],
                "yields": []
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Procedure { mode, .. } => assert_eq!(mode, "read"),
            // Reachable if parsing picks the wrong variant; report what we got.
            other => panic!("expected Procedure, got: {other:?}"),
        }
    }

    #[test]
    fn registration_entry_exposes_qname() {
        let signature = WireFnSignature {
            args: vec![],
            returns: WireArgType::CypherValue,
            volatility: "immutable".to_owned(),
            null_handling: "propagate".to_owned(),
        };
        let scalar = RegistrationEntry::Scalar {
            qname: "x.y".to_owned(),
            signature: signature.clone(),
        };
        let aggregate = RegistrationEntry::Aggregate {
            qname: "agg.sum".to_owned(),
            signature,
            state: WireArgType::CypherValue,
        };
        let procedure = RegistrationEntry::Procedure {
            qname: "proc.scan".to_owned(),
            args: vec![],
            yields: vec![],
            mode: "read".to_owned(),
        };
        assert_eq!(scalar.qname(), "x.y");
        assert_eq!(aggregate.qname(), "agg.sum");
        assert_eq!(procedure.qname(), "proc.scan");
    }

    #[test]
    fn rejects_unknown_registration_kind() {
        let json = br#"{"entries":[{"kind":"telegraphic","qname":"x"}]}"#;
        let err = parse_registration_json(json).unwrap_err();
        assert!(matches!(err, ExtismError::OutputDecode(_)));
    }

    #[test]
    fn rejects_registration_entry_without_kind() {
        assert_registration_rejected(br#"{"entries":[{"qname":"x.y"}]}"#);
    }

    #[test]
    fn rejects_unknown_registration_manifest_field() {
        assert_registration_rejected(br#"{"entries":[],"version":2}"#);
    }

    #[test]
    fn rejects_unknown_signature_field() {
        assert_registration_rejected(
            br#"{"entries":[{
                "kind":"scalar",
                "qname":"x.y",
                "signature":{
                    "args":[],
                    "returns":{"kind":"cypher_value"},
                    "deterministic":true
                }
            }]}"#,
        );
    }

    #[test]
    fn rejects_unknown_argtype_field() {
        assert_registration_rejected(
            br#"{"entries":[{
                "kind":"scalar",
                "qname":"x.y",
                "signature":{
                    "args":[],
                    "returns":{"kind":"primitive","arrow":"int64","nullable":true}
                }
            }]}"#,
        );
    }

    #[test]
    fn rejects_unknown_procedure_field() {
        assert_registration_rejected(
            br#"{"entries":[{
                "kind":"procedure",
                "qname":"x.y",
                "args":[],
                "yields":[],
                "parallel":true
            }]}"#,
        );
    }

    #[test]
    fn rejects_malformed_registration_json() {
        let err = parse_registration_json(br#"{"entries":"#).unwrap_err();
        assert!(matches!(
            err,
            ExtismError::OutputDecode(ref msg) if msg.starts_with("register json parse: ")
        ));
    }

    #[test]
    fn parses_vector_and_variadic_argtypes() {
        let json = br#"{
            "entries": [{
                "kind": "scalar",
                "qname": "vec.norm",
                "signature": {
                    "args": [
                        {"kind":"vector","len":128,"element":"float32"},
                        {"kind":"variadic","inner":{"kind":"primitive","arrow":"int64"}}
                    ],
                    "returns": {"kind":"primitive","arrow":"float32"}
                }
            }]
        }"#;
        let r = parse_registration_json(json).unwrap();
        match &r.entries[0] {
            RegistrationEntry::Scalar { signature, .. } => {
                assert!(matches!(
                    signature.args[0],
                    WireArgType::Vector { len: 128, ref element } if element == "float32"
                ));
                assert!(matches!(
                    signature.args[1],
                    WireArgType::Variadic { ref inner } if matches!(
                        **inner,
                        WireArgType::Primitive { ref arrow } if arrow == "int64"
                    )
                ));
            }
            other => panic!("expected Scalar, got: {other:?}"),
        }
    }

    #[test]
    fn registration_manifest_round_trips_through_json() {
        let original = RegistrationManifest {
            entries: vec![
                RegistrationEntry::Scalar {
                    qname: "vec.norm".to_owned(),
                    signature: WireFnSignature {
                        args: vec![
                            WireArgType::Vector {
                                len: 3,
                                element: "float32".to_owned(),
                            },
                            WireArgType::Variadic {
                                inner: Box::new(WireArgType::Primitive {
                                    arrow: "int64".to_owned(),
                                }),
                            },
                        ],
                        returns: WireArgType::Primitive {
                            arrow: "float32".to_owned(),
                        },
                        volatility: "stable".to_owned(),
                        null_handling: "propagate".to_owned(),
                    },
                },
                RegistrationEntry::Aggregate {
                    qname: "stats.count".to_owned(),
                    signature: WireFnSignature {
                        args: vec![WireArgType::CypherValue],
                        returns: WireArgType::Primitive {
                            arrow: "int64".to_owned(),
                        },
                        volatility: "immutable".to_owned(),
                        null_handling: "propagate".to_owned(),
                    },
                    state: WireArgType::CypherValue,
                },
                RegistrationEntry::Procedure {
                    qname: "myorg.scan".to_owned(),
                    args: vec![],
                    yields: vec![WireArgType::CypherValue],
                    mode: "write".to_owned(),
                },
            ],
        };
        let bytes = serde_json::to_vec(&original).unwrap();
        let parsed = parse_registration_json(&bytes).unwrap();
        assert_eq!(parsed, original);
    }
}