1use crate::executor::objects::object_creation::read_slot_nb;
4use rust_decimal::prelude::ToPrimitive;
5use shape_runtime::type_schema::TypeSchemaRegistry;
6use shape_value::heap_value::HeapValue;
7use shape_value::{VMError, ValueSlot, ValueWord};
8use std::sync::Arc;
9
10pub fn marshal_args(args: &[ValueWord], schemas: &TypeSchemaRegistry) -> Result<Vec<u8>, VMError> {
12 let values: Vec<rmpv::Value> = args
13 .iter()
14 .map(|nb| nanboxed_to_msgpack_value(nb, schemas))
15 .collect();
16 let array = rmpv::Value::Array(values);
17 rmp_serde::to_vec(&array).map_err(|e| {
18 VMError::RuntimeError(format!("Failed to marshal foreign function args: {}", e))
19 })
20}
21
22pub fn unmarshal_result(
29 bytes: &[u8],
30 return_type: &str,
31 schema_id: Option<u32>,
32 schemas: &TypeSchemaRegistry,
33) -> Result<ValueWord, VMError> {
34 if bytes.is_empty() {
35 return Ok(ValueWord::none());
36 }
37 let value: rmpv::Value = rmp_serde::from_slice(bytes).map_err(|e| {
38 VMError::RuntimeError(format!(
39 "Failed to unmarshal foreign function result: {}",
40 e
41 ))
42 })?;
43
44 let inner_type = strip_result_wrapper(return_type);
45 typed_msgpack_to_nanboxed(&value, inner_type, schema_id, schemas)
46}
47
48fn typed_msgpack_to_nanboxed(
54 val: &rmpv::Value,
55 target: &str,
56 schema_id: Option<u32>,
57 schemas: &TypeSchemaRegistry,
58) -> Result<ValueWord, VMError> {
59 if matches!(val, rmpv::Value::Nil) {
61 if target == "none" {
62 return Ok(ValueWord::none());
63 }
64 return Err(marshal_error(format!("expected {}, got None", target)));
65 }
66
67 match target {
68 "int" => match val {
69 rmpv::Value::Integer(i) => {
70 if let Some(n) = i.as_i64() {
71 Ok(ValueWord::from_i64(n))
72 } else if let Some(n) = i.as_u64() {
73 Ok(ValueWord::from_i64(n as i64))
74 } else {
75 Err(marshal_error("integer out of range"))
76 }
77 }
78 _ => Err(marshal_error(format!(
79 "expected int, got {}",
80 msgpack_type_name(val)
81 ))),
82 },
83
84 "float" | "number" => match val {
85 rmpv::Value::F64(f) => Ok(ValueWord::from_f64(*f)),
86 rmpv::Value::F32(f) => Ok(ValueWord::from_f64(*f as f64)),
87 rmpv::Value::Integer(i) => {
88 if let Some(n) = i.as_i64() {
90 Ok(ValueWord::from_f64(n as f64))
91 } else if let Some(n) = i.as_u64() {
92 Ok(ValueWord::from_f64(n as f64))
93 } else {
94 Err(marshal_error("integer out of range for float coercion"))
95 }
96 }
97 _ => Err(marshal_error(format!(
98 "expected {}, got {}",
99 target,
100 msgpack_type_name(val)
101 ))),
102 },
103
104 "string" => match val {
105 rmpv::Value::String(s) => {
106 if let Some(s) = s.as_str() {
107 Ok(ValueWord::from_string(Arc::new(s.to_string())))
108 } else {
109 Err(marshal_error("string contains invalid UTF-8"))
110 }
111 }
112 _ => Err(marshal_error(format!(
113 "expected string, got {}",
114 msgpack_type_name(val)
115 ))),
116 },
117
118 "bool" => match val {
119 rmpv::Value::Boolean(b) => Ok(ValueWord::from_bool(*b)),
120 _ => Err(marshal_error(format!(
121 "expected bool, got {}",
122 msgpack_type_name(val)
123 ))),
124 },
125
126 "none" => Err(marshal_error(format!(
127 "expected none, got {}",
128 msgpack_type_name(val)
129 ))),
130
131 s if s.starts_with("Vec<") && s.ends_with('>') => {
133 let elem_type = &s[4..s.len() - 1];
134 match val {
135 rmpv::Value::Array(arr) => {
136 let items: Result<Vec<ValueWord>, VMError> = arr
137 .iter()
138 .enumerate()
139 .map(|(i, item)| {
140 typed_msgpack_to_nanboxed(item, elem_type, schema_id, schemas).map_err(
142 |e| VMError::RuntimeError(format!("Vec element [{}]: {}", i, e)),
143 )
144 })
145 .collect();
146 Ok(ValueWord::from_array(Arc::new(items?)))
147 }
148 _ => Err(marshal_error(format!(
149 "expected Vec, got {}",
150 msgpack_type_name(val)
151 ))),
152 }
153 }
154
155 s if s.starts_with('{') && s.ends_with('}') => {
157 match val {
158 rmpv::Value::Map(entries) => {
159 if let Some(sid) = schema_id {
160 marshal_typed_object(entries, sid, schemas)
161 } else {
162 Ok(untyped_msgpack_to_nanboxed(val))
164 }
165 }
166 _ => Err(marshal_error(format!(
167 "expected object, got {}",
168 msgpack_type_name(val)
169 ))),
170 }
171 }
172
173 _ if schema_id.is_some() => match val {
175 rmpv::Value::Map(entries) => marshal_typed_object(entries, schema_id.unwrap(), schemas),
176 _ => Err(marshal_error(format!(
177 "expected object for type '{}', got {}",
178 target,
179 msgpack_type_name(val)
180 ))),
181 },
182
183 _ => Ok(untyped_msgpack_to_nanboxed(val)),
185 }
186}
187
188fn marshal_typed_object(
190 entries: &[(rmpv::Value, rmpv::Value)],
191 schema_id: u32,
192 schemas: &TypeSchemaRegistry,
193) -> Result<ValueWord, VMError> {
194 let schema = schemas.get_by_id(schema_id).ok_or_else(|| {
195 VMError::RuntimeError(format!(
196 "FFI marshal: schema ID {} not found in registry",
197 schema_id
198 ))
199 })?;
200
201 let mut name_to_value: std::collections::HashMap<&str, &rmpv::Value> =
203 std::collections::HashMap::with_capacity(entries.len());
204 for (k, v) in entries {
205 if let rmpv::Value::String(s) = k {
206 if let Some(name) = s.as_str() {
207 name_to_value.insert(name, v);
208 }
209 }
210 }
211
212 let field_count = schema.fields.len();
213 let mut slots = Vec::with_capacity(field_count);
214 let mut heap_mask: u64 = 0;
215
216 for field in &schema.fields {
217 let val = name_to_value.get(field.wire_name());
218 use shape_runtime::type_schema::FieldType;
219
220 match &field.field_type {
221 FieldType::I64 => {
222 let n = val
223 .and_then(|v| match v {
224 rmpv::Value::Integer(i) => i.as_i64(),
225 _ => None,
226 })
227 .unwrap_or(0);
228 slots.push(ValueSlot::from_int(n));
229 }
230 FieldType::F64 => {
231 let f = val
232 .and_then(|v| match v {
233 rmpv::Value::F64(f) => Some(*f),
234 rmpv::Value::F32(f) => Some(*f as f64),
235 rmpv::Value::Integer(i) => i.as_i64().map(|n| n as f64),
236 _ => None,
237 })
238 .unwrap_or(0.0);
239 slots.push(ValueSlot::from_number(f));
240 }
241 FieldType::Bool => {
242 let b = val
243 .and_then(|v| match v {
244 rmpv::Value::Boolean(b) => Some(*b),
245 _ => None,
246 })
247 .unwrap_or(false);
248 slots.push(ValueSlot::from_bool(b));
249 }
250 FieldType::String => {
251 let s = val
252 .and_then(|v| match v {
253 rmpv::Value::String(s) => s.as_str().map(|s| s.to_string()),
254 _ => None,
255 })
256 .unwrap_or_default();
257 slots.push(ValueSlot::from_heap(HeapValue::String(Arc::new(s))));
258 heap_mask |= 1u64 << (slots.len() - 1);
259 }
260 FieldType::Array(_) => {
261 let arr_nb = val
262 .map(|v| untyped_msgpack_to_nanboxed(v))
263 .unwrap_or_else(|| ValueWord::from_array(Arc::new(Vec::new())));
264 let heap_val = match arr_nb.as_heap_ref() {
266 Some(hv) => hv.clone(),
267 None => HeapValue::Array(Arc::new(Vec::new())),
268 };
269 slots.push(ValueSlot::from_heap(heap_val));
270 heap_mask |= 1u64 << (slots.len() - 1);
271 }
272 FieldType::Object(_) => {
273 let obj_nb = val
274 .map(|v| untyped_msgpack_to_nanboxed(v))
275 .unwrap_or_else(ValueWord::none);
276 if let Some(hv) = obj_nb.as_heap_ref() {
277 slots.push(ValueSlot::from_heap(hv.clone()));
278 heap_mask |= 1u64 << (slots.len() - 1);
279 } else {
280 slots.push(ValueSlot::none());
281 }
282 }
283 _ => {
285 let nb = val
286 .map(|v| untyped_msgpack_to_nanboxed(v))
287 .unwrap_or_else(ValueWord::none);
288 if let Some(hv) = nb.as_heap_ref() {
289 slots.push(ValueSlot::from_heap(hv.clone()));
290 heap_mask |= 1u64 << (slots.len() - 1);
291 } else if let Some(f) = nb.as_f64() {
292 slots.push(ValueSlot::from_number(f));
293 } else if let Some(i) = nb.as_i64() {
294 slots.push(ValueSlot::from_number(i as f64));
295 } else if let Some(b) = nb.as_bool() {
296 slots.push(ValueSlot::from_bool(b));
297 } else {
298 slots.push(ValueSlot::none());
299 }
300 }
301 }
302 }
303
304 Ok(ValueWord::from_heap_value(HeapValue::TypedObject {
305 schema_id: schema_id as u64,
306 slots: slots.into_boxed_slice(),
307 heap_mask,
308 }))
309}
310
311fn nanboxed_to_msgpack_value(nb: &ValueWord, schemas: &TypeSchemaRegistry) -> rmpv::Value {
317 use shape_value::NanTag;
318 match nb.tag() {
319 NanTag::F64 => {
320 if let Some(f) = nb.as_f64() {
321 rmpv::Value::F64(f)
322 } else {
323 rmpv::Value::Nil
324 }
325 }
326 NanTag::I48 => {
327 if let Some(i) = nb.as_i64() {
328 rmpv::Value::Integer(rmpv::Integer::from(i))
329 } else {
330 rmpv::Value::Nil
331 }
332 }
333 NanTag::Bool => rmpv::Value::Boolean(nb.as_bool().unwrap_or(false)),
334 NanTag::None => rmpv::Value::Nil,
335 NanTag::Heap => nb
336 .as_heap_ref()
337 .map(|hv| heap_to_msgpack_value(hv, schemas))
338 .unwrap_or(rmpv::Value::Nil),
339 _ => rmpv::Value::Nil,
340 }
341}
342
343fn heap_to_msgpack_value(hv: &HeapValue, schemas: &TypeSchemaRegistry) -> rmpv::Value {
344 match hv {
345 HeapValue::String(s) => rmpv::Value::String(rmpv::Utf8String::from(s.as_str())),
346 HeapValue::Array(arr) => rmpv::Value::Array(
347 arr.iter()
348 .map(|item| nanboxed_to_msgpack_value(item, schemas))
349 .collect(),
350 ),
351 HeapValue::HashMap(map) => {
352 let entries: Vec<(rmpv::Value, rmpv::Value)> = map
353 .keys
354 .iter()
355 .zip(map.values.iter())
356 .map(|(key, value)| {
357 (
358 nanboxed_to_msgpack_value(key, schemas),
359 nanboxed_to_msgpack_value(value, schemas),
360 )
361 })
362 .collect();
363 rmpv::Value::Map(entries)
364 }
365 HeapValue::TypedObject {
366 schema_id,
367 slots,
368 heap_mask,
369 } => {
370 if let Some(schema) = schemas.get_by_id(*schema_id as u32) {
371 let mut entries = Vec::with_capacity(schema.fields.len());
372 for field in &schema.fields {
373 let value = read_slot_nb(
374 slots,
375 field.index as usize,
376 *heap_mask,
377 Some(&field.field_type),
378 );
379 entries.push((
380 rmpv::Value::String(rmpv::Utf8String::from(field.wire_name().to_string())),
381 nanboxed_to_msgpack_value(&value, schemas),
382 ));
383 }
384 return rmpv::Value::Map(entries);
385 }
386
387 let entries: Vec<(rmpv::Value, rmpv::Value)> = slots
389 .iter()
390 .enumerate()
391 .map(|(index, slot)| {
392 let is_heap = *heap_mask & (1u64 << index) != 0;
393 let value = slot.as_value_word(is_heap);
394 (
395 rmpv::Value::String(rmpv::Utf8String::from(index.to_string())),
396 nanboxed_to_msgpack_value(&value, schemas),
397 )
398 })
399 .collect();
400 rmpv::Value::Map(entries)
401 }
402 HeapValue::TypeAnnotatedValue { value, .. } => nanboxed_to_msgpack_value(value, schemas),
403 HeapValue::Some(inner) => nanboxed_to_msgpack_value(inner, schemas),
404 HeapValue::Ok(inner) => nanboxed_to_msgpack_value(inner, schemas),
405 HeapValue::Err(inner) => nanboxed_to_msgpack_value(inner, schemas),
406 HeapValue::BigInt(n) => rmpv::Value::Integer(rmpv::Integer::from(*n)),
407 HeapValue::Decimal(d) => d
408 .to_f64()
409 .map(rmpv::Value::F64)
410 .unwrap_or_else(|| rmpv::Value::String(rmpv::Utf8String::from(d.to_string()))),
411 _ => rmpv::Value::Nil,
412 }
413}
414
/// Remove one outer `Result<...>` wrapper from a type name.
///
/// `"Result<int>"` yields `"int"`. Any string that does not both start with `Result<`
/// and end with `>` is returned unchanged.
fn strip_result_wrapper(s: &str) -> &str {
    s.strip_prefix("Result<")
        .and_then(|rest| rest.strip_suffix('>'))
        .unwrap_or(s)
}
427
428fn marshal_error(msg: impl Into<String>) -> VMError {
430 VMError::RuntimeError(msg.into())
431}
432
433fn msgpack_type_name(val: &rmpv::Value) -> &'static str {
435 match val {
436 rmpv::Value::Nil => "nil",
437 rmpv::Value::Boolean(_) => "bool",
438 rmpv::Value::Integer(_) => "int",
439 rmpv::Value::F32(_) | rmpv::Value::F64(_) => "float",
440 rmpv::Value::String(_) => "string",
441 rmpv::Value::Array(_) => "array",
442 rmpv::Value::Map(_) => "map",
443 rmpv::Value::Binary(_) => "binary",
444 rmpv::Value::Ext(_, _) => "ext",
445 }
446}
447
448fn untyped_msgpack_to_nanboxed(val: &rmpv::Value) -> ValueWord {
450 match val {
451 rmpv::Value::Nil => ValueWord::none(),
452 rmpv::Value::Boolean(b) => ValueWord::from_bool(*b),
453 rmpv::Value::Integer(i) => {
454 if let Some(n) = i.as_i64() {
455 ValueWord::from_i64(n)
456 } else if let Some(n) = i.as_u64() {
457 ValueWord::from_i64(n as i64)
458 } else {
459 ValueWord::none()
460 }
461 }
462 rmpv::Value::F32(f) => ValueWord::from_f64(*f as f64),
463 rmpv::Value::F64(f) => ValueWord::from_f64(*f),
464 rmpv::Value::String(s) => {
465 if let Some(s) = s.as_str() {
466 ValueWord::from_string(Arc::new(s.to_string()))
467 } else {
468 ValueWord::none()
469 }
470 }
471 rmpv::Value::Array(arr) => {
472 let items: Vec<ValueWord> = arr.iter().map(untyped_msgpack_to_nanboxed).collect();
473 ValueWord::from_array(Arc::new(items))
474 }
475 rmpv::Value::Map(entries) => {
476 let mut keys = Vec::with_capacity(entries.len());
477 let mut values = Vec::with_capacity(entries.len());
478 for (k, v) in entries.iter() {
479 keys.push(untyped_msgpack_to_nanboxed(k));
480 values.push(untyped_msgpack_to_nanboxed(v));
481 }
482 ValueWord::from_hashmap_pairs(keys, values)
483 }
484 rmpv::Value::Ext(_, _) | rmpv::Value::Binary(_) => ValueWord::none(),
485 }
486}
487
// Unit tests for argument marshalling (VM -> msgpack) and typed result
// unmarshalling (msgpack -> VM).
#[cfg(test)]
mod tests {
    use super::*;
    use shape_runtime::type_schema::{FieldType, TypeSchemaRegistry};
    use shape_value::{ValueSlot, heap_value::HeapValue};

    // Builds a TypedObject with slots [timestamp: String, value: F64, sensor_id: String].
    // heap_mask 0b101 marks slots 0 and 2 (the strings) as heap-backed.
    fn measurement_value(
        schema_id: u32,
        timestamp: &str,
        value: f64,
        sensor_id: &str,
    ) -> ValueWord {
        ValueWord::from_heap_value(HeapValue::TypedObject {
            schema_id: schema_id as u64,
            slots: vec![
                ValueSlot::from_heap(HeapValue::String(Arc::new(timestamp.to_string()))),
                ValueSlot::from_number(value),
                ValueSlot::from_heap(HeapValue::String(Arc::new(sensor_id.to_string()))),
            ]
            .into_boxed_slice(),
            heap_mask: 0b101,
        })
    }

    // A typed object nested in an array argument must be encoded as a msgpack map keyed
    // by field name, not as a positional slot list.
    #[test]
    fn marshal_args_preserves_typed_object_fields_as_msgpack_map() {
        let mut schemas = TypeSchemaRegistry::new();
        let measurement_schema_id = schemas.register_type(
            "Measurement",
            vec![
                ("timestamp".to_string(), FieldType::String),
                ("value".to_string(), FieldType::F64),
                ("sensor_id".to_string(), FieldType::String),
            ],
        );

        let readings = ValueWord::from_array(Arc::new(vec![
            measurement_value(measurement_schema_id, "2026-02-22T10:00:00Z", 10.0, "A"),
            measurement_value(measurement_schema_id, "2026-02-22T10:01:00Z", 10.5, "A"),
        ]));

        let bytes = marshal_args(&[readings], &schemas).expect("marshal should succeed");
        let decoded: rmpv::Value = rmp_serde::from_slice(&bytes).expect("valid msgpack");

        // marshal_args wraps all arguments in one outer array.
        let outer = decoded.as_array().expect("expected outer arg array");
        let reading_items = outer[0].as_array().expect("expected readings array");
        let first = reading_items[0]
            .as_map()
            .expect("expected typed object map");

        // Collect string-keyed entries so assertions do not depend on map order.
        let mut fields = std::collections::HashMap::new();
        for (k, v) in first {
            if let rmpv::Value::String(s) = k
                && let Some(name) = s.as_str()
            {
                fields.insert(name.to_string(), v.clone());
            }
        }

        assert_eq!(
            fields.get("timestamp").and_then(|v| v.as_str()),
            Some("2026-02-22T10:00:00Z")
        );
        assert_eq!(fields.get("value").and_then(|v| v.as_f64()), Some(10.0));
        assert_eq!(fields.get("sensor_id").and_then(|v| v.as_str()), Some("A"));
    }

    // The outer Result<...> wrapper is stripped and an int target accepts an integer.
    #[test]
    fn unmarshal_result_typed_int() {
        let schemas = TypeSchemaRegistry::new();
        let val = rmpv::Value::Integer(rmpv::Integer::from(42));
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(&bytes, "Result<int>", None, &schemas).unwrap();
        assert_eq!(result.as_i64(), Some(42));
    }

    // A declared string type must reject an integer payload.
    #[test]
    fn unmarshal_result_typed_string_rejects_int() {
        let schemas = TypeSchemaRegistry::new();
        let val = rmpv::Value::Integer(rmpv::Integer::from(42));
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(&bytes, "Result<string>", None, &schemas);
        assert!(result.is_err());
    }

    #[test]
    fn unmarshal_result_typed_bool() {
        let schemas = TypeSchemaRegistry::new();
        let val = rmpv::Value::Boolean(true);
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(&bytes, "Result<bool>", None, &schemas).unwrap();
        assert_eq!(result.as_bool(), Some(true));
    }

    // Vec<int> decodes each element with the int rules into a heap array.
    #[test]
    fn unmarshal_result_typed_array_of_ints() {
        let schemas = TypeSchemaRegistry::new();
        let val = rmpv::Value::Array(vec![
            rmpv::Value::Integer(rmpv::Integer::from(1)),
            rmpv::Value::Integer(rmpv::Integer::from(2)),
            rmpv::Value::Integer(rmpv::Integer::from(3)),
        ]);
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(&bytes, "Result<Vec<int>>", None, &schemas).unwrap();
        let arr = result.as_heap_ref().unwrap();
        match arr {
            HeapValue::Array(items) => {
                assert_eq!(items.len(), 3);
                assert_eq!(items[0].as_i64(), Some(1));
            }
            _ => panic!("expected array"),
        }
    }

    // An object return type with a schema id produces a TypedObject whose slots follow
    // the schema's field order (id, name).
    #[test]
    fn unmarshal_result_typed_object() {
        let mut schemas = TypeSchemaRegistry::new();
        let sid = schemas.register_type(
            "__ffi_test_return",
            vec![
                ("id".to_string(), FieldType::I64),
                ("name".to_string(), FieldType::String),
            ],
        );

        let val = rmpv::Value::Map(vec![
            (
                rmpv::Value::String(rmpv::Utf8String::from("id")),
                rmpv::Value::Integer(rmpv::Integer::from(42)),
            ),
            (
                rmpv::Value::String(rmpv::Utf8String::from("name")),
                rmpv::Value::String(rmpv::Utf8String::from("hello")),
            ),
        ]);
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(
            &bytes,
            "Result<{id: int, name: string}>",
            Some(sid as u32),
            &schemas,
        )
        .unwrap();

        match result.as_heap_ref() {
            Some(HeapValue::TypedObject {
                schema_id, slots, ..
            }) => {
                assert_eq!(*schema_id, sid as u64);
                assert_eq!(slots[0].as_i64(), 42);
                match slots[1].as_heap_value() {
                    HeapValue::String(s) => assert_eq!(s.as_str(), "hello"),
                    other => panic!("expected string, got {:?}", other),
                }
            }
            _ => panic!("expected TypedObject"),
        }
    }

    // An unrecognized type name with no schema falls back to the untyped conversion.
    #[test]
    fn unmarshal_result_any_fallback() {
        let schemas = TypeSchemaRegistry::new();
        let val = rmpv::Value::String(rmpv::Utf8String::from("anything"));
        let bytes = rmp_serde::to_vec(&val).unwrap();
        let result = unmarshal_result(&bytes, "Result<any>", None, &schemas).unwrap();
        match result.as_heap_ref() {
            Some(HeapValue::String(s)) => assert_eq!(s.as_str(), "anything"),
            _ => panic!("expected string"),
        }
    }
}
659}