1use std::collections::HashMap;
4
5use shape_runtime::snapshot::{
6 SerializableCallFrame, SerializableExceptionHandler, SerializableLoopContext, SnapshotStore,
7 VmSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
8};
9use shape_value::{Upvalue, VMError, ValueWord};
10
11use super::{CallFrame, ExceptionHandler, LoopContext, VMConfig, VirtualMachine};
12use crate::bytecode::{Function, FunctionHash};
13
14pub(crate) fn resolve_function_identity(
19 function_id_by_hash: &HashMap<FunctionHash, u16>,
20 functions: &[Function],
21 blob_hash: Option<FunctionHash>,
22 function_id: Option<u16>,
23 function_name: Option<&str>,
24) -> Result<u16, VMError> {
25 if let Some(hash) = blob_hash {
27 let resolved = function_id_by_hash.get(&hash).copied().ok_or_else(|| {
28 VMError::RuntimeError(format!("unknown function blob hash: {}", hash))
29 })?;
30 if let Some(fid) = function_id {
32 if fid != resolved {
33 return Err(VMError::RuntimeError(format!(
34 "function_id/hash mismatch: frame id {} does not match hash {} (resolved id {})",
35 fid, hash, resolved
36 )));
37 }
38 }
39 return Ok(resolved);
40 }
41
42 if let Some(fid) = function_id {
44 if (fid as usize) < functions.len() {
45 return Ok(fid);
46 }
47 return Err(VMError::RuntimeError(format!(
48 "function_id {} out of range (program has {} functions)",
49 fid,
50 functions.len()
51 )));
52 }
53
54 if let Some(name) = function_name {
56 let matches: Vec<usize> = functions
57 .iter()
58 .enumerate()
59 .filter_map(|(idx, f)| if f.name == name { Some(idx) } else { None })
60 .collect();
61 return match matches.len() {
62 1 => Ok(matches[0] as u16),
63 0 => Err(VMError::RuntimeError(format!(
64 "no function named '{}'",
65 name
66 ))),
67 n => Err(VMError::RuntimeError(format!(
68 "ambiguous function name '{}' ({} matches)",
69 name, n
70 ))),
71 };
72 }
73
74 Err(VMError::RuntimeError(
76 "cannot resolve function identity: no hash, id, or name provided".into(),
77 ))
78}
79
80impl VirtualMachine {
81 pub fn snapshot(&self, store: &SnapshotStore) -> Result<VmSnapshot, VMError> {
83 let mut stack = Vec::with_capacity(self.sp);
84 for nb in self.stack[..self.sp].iter() {
85 stack.push(
86 nanboxed_to_serializable(nb, store)
87 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
88 );
89 }
90 let locals = Vec::new();
92 let mut module_bindings = Vec::with_capacity(self.module_bindings.len());
93 for nb in self.module_bindings.iter() {
94 module_bindings.push(
95 nanboxed_to_serializable(nb, store)
96 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
97 );
98 }
99
100 let mut call_stack = Vec::with_capacity(self.call_stack.len());
101 for frame in self.call_stack.iter() {
102 let upvalues = match &frame.upvalues {
103 Some(values) => {
104 let mut out = Vec::new();
105 for up in values.iter() {
106 let nb = up.get();
107 out.push(
108 nanboxed_to_serializable(&nb, store)
109 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
110 );
111 }
112 Some(out)
113 }
114 None => None,
115 };
116 let (blob_hash, local_ip) =
118 if let (Some(hash), Some(fid)) = (frame.blob_hash, frame.function_id) {
119 let entry_point = self
120 .function_entry_points
121 .get(fid as usize)
122 .copied()
123 .unwrap_or(0);
124 let lip = frame.return_ip.saturating_sub(entry_point);
125 (Some(hash.0), Some(lip))
126 } else {
127 (None, None)
128 };
129
130 call_stack.push(SerializableCallFrame {
131 return_ip: frame.return_ip,
132 locals_base: frame.base_pointer,
133 locals_count: frame.locals_count,
134 function_id: frame.function_id,
135 upvalues,
136 blob_hash,
137 local_ip,
138 });
139 }
140
141 let loop_stack = self
142 .loop_stack
143 .iter()
144 .map(|l| SerializableLoopContext {
145 start: l.start,
146 end: l.end,
147 })
148 .collect();
149 let exception_handlers = self
150 .exception_handlers
151 .iter()
152 .map(|h| SerializableExceptionHandler {
153 catch_ip: h.catch_ip,
154 stack_size: h.stack_size,
155 call_depth: h.call_depth,
156 })
157 .collect();
158
159 let (ip_blob_hash, ip_local_offset, ip_function_id) =
162 if let Some(frame) = self.call_stack.last() {
163 let fid = frame.function_id;
164 let blob_hash = fid.and_then(|id| self.blob_hash_for_function(id));
165 let entry_point = fid
166 .and_then(|id| self.function_entry_points.get(id as usize).copied())
167 .unwrap_or(0);
168 let local_offset = self.ip.saturating_sub(entry_point);
169 (blob_hash.map(|h| h.0), Some(local_offset), fid)
170 } else {
171 (None, None, None)
172 };
173
174 Ok(VmSnapshot {
175 ip: self.ip,
176 stack,
177 locals,
178 module_bindings,
179 call_stack,
180 loop_stack,
181 timeframe_stack: self.timeframe_stack.clone(),
182 exception_handlers,
183 ip_blob_hash,
184 ip_local_offset,
185 ip_function_id,
186 })
187 }
188
189 pub fn from_snapshot(
191 program: crate::bytecode::BytecodeProgram,
192 snapshot: &VmSnapshot,
193 store: &SnapshotStore,
194 ) -> Result<Self, VMError> {
195 let mut vm = VirtualMachine::new(VMConfig::default());
196 vm.load_program(program);
197
198 vm.ip = if let (Some(hash_bytes), Some(local_offset)) =
202 (&snapshot.ip_blob_hash, snapshot.ip_local_offset)
203 {
204 let hash = FunctionHash(*hash_bytes);
205 let func_id = resolve_function_identity(
207 &vm.function_id_by_hash,
208 &vm.program.functions,
209 Some(hash),
210 snapshot.ip_function_id,
211 None,
212 )?;
213 let entry_point = vm
214 .function_entry_points
215 .get(func_id as usize)
216 .copied()
217 .unwrap_or(0);
218 entry_point + local_offset
219 } else if let Some(fid) = snapshot.ip_function_id {
220 let entry_point = vm
222 .function_entry_points
223 .get(fid as usize)
224 .copied()
225 .unwrap_or(0);
226 let local_offset = snapshot.ip_local_offset.unwrap_or(0);
227 entry_point + local_offset
228 } else {
229 snapshot.ip
231 };
232
233 let restored_stack: Vec<ValueWord> = snapshot
234 .stack
235 .iter()
236 .map(|v| {
237 serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
238 })
239 .collect::<Result<Vec<_>, _>>()?;
240 let restored_sp = restored_stack.len();
241 vm.stack = (0..restored_sp.max(crate::constants::DEFAULT_STACK_CAPACITY))
243 .map(|_| ValueWord::none())
244 .collect();
245 for (i, nb) in restored_stack.into_iter().enumerate() {
246 vm.stack[i] = nb;
247 }
248 vm.sp = restored_sp;
249 vm.module_bindings = snapshot
251 .module_bindings
252 .iter()
253 .map(|v| {
254 serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
255 })
256 .collect::<Result<Vec<_>, _>>()?;
257
258 vm.call_stack = snapshot
259 .call_stack
260 .iter()
261 .map(|f| {
262 let upvalues = match &f.upvalues {
263 Some(values) => {
264 let mut out = Vec::new();
265 for v in values.iter() {
266 out.push(Upvalue::new(
267 serializable_to_nanboxed(v, store)
268 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
269 ));
270 }
271 Some(out)
272 }
273 None => None,
274 };
275 let blob_hash = f.blob_hash.map(FunctionHash);
278 let resolved_function_id = if blob_hash.is_some() || f.function_id.is_some() {
279 Some(resolve_function_identity(
280 &vm.function_id_by_hash,
281 &vm.program.functions,
282 blob_hash,
283 f.function_id,
284 None,
285 )?)
286 } else {
287 None
288 };
289
290 let return_ip = if let (Some(hash), Some(local_ip), Some(fid)) =
291 (&blob_hash, f.local_ip, resolved_function_id)
292 {
293 let current_hash = vm.blob_hash_for_function(fid);
295 if let Some(current) = current_hash
296 && current != *hash
297 {
298 return Err(VMError::RuntimeError(format!(
299 "Snapshot blob hash mismatch for function {}: \
300 snapshot has {}, program has {}",
301 fid, hash, current
302 )));
303 }
304 let entry_point = vm
306 .function_entry_points
307 .get(fid as usize)
308 .copied()
309 .unwrap_or(0);
310 local_ip + entry_point
311 } else {
312 f.return_ip
313 };
314
315 Ok(CallFrame {
316 return_ip,
317 base_pointer: f.locals_base,
318 locals_count: f.locals_count,
319 function_id: resolved_function_id,
320 upvalues,
321 blob_hash,
322 })
323 })
324 .collect::<Result<Vec<_>, VMError>>()?;
325
326 vm.loop_stack = snapshot
327 .loop_stack
328 .iter()
329 .map(|l| LoopContext {
330 start: l.start,
331 end: l.end,
332 })
333 .collect();
334 vm.timeframe_stack = snapshot.timeframe_stack.clone();
335 vm.exception_handlers = snapshot
336 .exception_handlers
337 .iter()
338 .map(|h| ExceptionHandler {
339 catch_ip: h.catch_ip,
340 stack_size: h.stack_size,
341 call_depth: h.call_depth,
342 })
343 .collect();
344
345 Ok(vm)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal, body-less `Function` with the given name.
    fn make_function(name: &str) -> Function {
        Function {
            name: name.to_string(),
            arity: 0,
            param_names: Vec::new(),
            locals_count: 0,
            entry_point: 0,
            body_length: 0,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: Vec::new(),
            ref_mutates: Vec::new(),
            mutable_captures: Vec::new(),
            frame_descriptor: None,
            osr_entry_points: Vec::new(),
        }
    }

    /// Builds a deterministic hash whose 32 bytes all equal `seed`.
    fn make_hash(seed: u8) -> FunctionHash {
        FunctionHash([seed; 32])
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains(result: Result<u16, VMError>, needle: &str) {
        let msg = result
            .expect_err("expected function resolution to fail")
            .to_string();
        assert!(msg.contains(needle), "got: {}", msg);
    }

    #[test]
    fn test_resolve_by_hash() {
        let hash = make_hash(0xAB);
        let mut by_hash = HashMap::new();
        by_hash.insert(hash, 3u16);
        let funcs = vec![
            make_function("a"),
            make_function("b"),
            make_function("c"),
            make_function("d"),
        ];

        let result = resolve_function_identity(&by_hash, &funcs, Some(hash), None, None);
        assert_eq!(result.unwrap(), 3);
    }

    #[test]
    fn test_resolve_hash_not_found_is_error() {
        let hash = make_hash(0xAB);
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a")];

        let result = resolve_function_identity(&by_hash, &funcs, Some(hash), None, None);
        assert_err_contains(result, "unknown function blob hash");
    }

    #[test]
    fn test_resolve_hash_function_id_mismatch_is_error() {
        let hash = make_hash(0xCD);
        let mut by_hash = HashMap::new();
        by_hash.insert(hash, 2u16);
        let funcs = vec![make_function("a"), make_function("b"), make_function("c")];

        let result = resolve_function_identity(&by_hash, &funcs, Some(hash), Some(5), None);
        assert_err_contains(result, "mismatch");
    }

    #[test]
    fn test_resolve_hash_function_id_agree() {
        let hash = make_hash(0xEF);
        let mut by_hash = HashMap::new();
        by_hash.insert(hash, 1u16);
        let funcs = vec![make_function("a"), make_function("b")];

        let result = resolve_function_identity(&by_hash, &funcs, Some(hash), Some(1), None);
        assert_eq!(result.unwrap(), 1);
    }

    #[test]
    fn test_resolve_hash_takes_precedence_over_name() {
        let hash = make_hash(0x11);
        let mut by_hash = HashMap::new();
        by_hash.insert(hash, 0u16);
        let funcs = vec![make_function("a"), make_function("b")];

        let result = resolve_function_identity(&by_hash, &funcs, Some(hash), None, Some("b"));
        assert_eq!(result.unwrap(), 0);
    }

    #[test]
    fn test_resolve_by_function_id() {
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a"), make_function("b"), make_function("c")];

        let result = resolve_function_identity(&by_hash, &funcs, None, Some(2), None);
        assert_eq!(result.unwrap(), 2);
    }

    #[test]
    fn test_resolve_function_id_takes_precedence_over_name() {
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a"), make_function("b")];

        let result = resolve_function_identity(&by_hash, &funcs, None, Some(0), Some("b"));
        assert_eq!(result.unwrap(), 0);
    }

    #[test]
    fn test_resolve_function_id_out_of_range() {
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a")];

        let result = resolve_function_identity(&by_hash, &funcs, None, Some(99), None);
        assert_err_contains(result, "out of range");
    }

    #[test]
    fn test_resolve_unique_name_fallback() {
        let by_hash = HashMap::new();
        let funcs = vec![
            make_function("alpha"),
            make_function("beta"),
            make_function("gamma"),
        ];

        let result = resolve_function_identity(&by_hash, &funcs, None, None, Some("beta"));
        assert_eq!(result.unwrap(), 1);
    }

    #[test]
    fn test_resolve_ambiguous_name_is_error() {
        let by_hash = HashMap::new();
        let funcs = vec![
            make_function("dup"),
            make_function("other"),
            make_function("dup"),
        ];

        let result = resolve_function_identity(&by_hash, &funcs, None, None, Some("dup"));
        assert_err_contains(result, "ambiguous");
    }

    #[test]
    fn test_resolve_name_not_found() {
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a")];

        let result = resolve_function_identity(&by_hash, &funcs, None, None, Some("missing"));
        assert_err_contains(result, "no function named");
    }

    #[test]
    fn test_resolve_no_identifiers_is_error() {
        let by_hash = HashMap::new();
        let funcs = vec![make_function("a")];

        let result = resolve_function_identity(&by_hash, &funcs, None, None, None);
        assert_err_contains(result, "no hash, id, or name");
    }

    #[test]
    fn test_snapshot_ip_relocation_fields_present() {
        let snapshot = VmSnapshot {
            ip: 42,
            stack: vec![],
            locals: vec![],
            module_bindings: vec![],
            call_stack: vec![],
            loop_stack: vec![],
            timeframe_stack: vec![],
            exception_handlers: vec![],
            ip_blob_hash: Some([0xAB; 32]),
            ip_local_offset: Some(10),
            ip_function_id: Some(1),
        };
        assert_eq!(snapshot.ip, 42);
        assert_eq!(snapshot.ip_blob_hash, Some([0xAB; 32]));
        assert_eq!(snapshot.ip_local_offset, Some(10));
        assert_eq!(snapshot.ip_function_id, Some(1));
    }

    #[test]
    fn test_snapshot_legacy_without_relocation_fields() {
        let snapshot = VmSnapshot {
            ip: 100,
            stack: vec![],
            locals: vec![],
            module_bindings: vec![],
            call_stack: vec![],
            loop_stack: vec![],
            timeframe_stack: vec![],
            exception_handlers: vec![],
            ip_blob_hash: None,
            ip_local_offset: None,
            ip_function_id: None,
        };
        assert!(snapshot.ip_blob_hash.is_none());
        assert!(snapshot.ip_local_offset.is_none());
        assert!(snapshot.ip_function_id.is_none());
    }

    #[test]
    fn test_snapshot_serialization_roundtrip_with_relocation() {
        let snapshot = VmSnapshot {
            ip: 42,
            stack: vec![],
            locals: vec![],
            module_bindings: vec![],
            call_stack: vec![],
            loop_stack: vec![],
            timeframe_stack: vec![],
            exception_handlers: vec![],
            ip_blob_hash: Some([0xCD; 32]),
            ip_local_offset: Some(7),
            ip_function_id: Some(2),
        };
        let json = serde_json::to_string(&snapshot).unwrap();
        let restored: VmSnapshot = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.ip_blob_hash, Some([0xCD; 32]));
        assert_eq!(restored.ip_local_offset, Some(7));
        assert_eq!(restored.ip_function_id, Some(2));
    }

    #[test]
    fn test_snapshot_deserialization_without_relocation_fields() {
        let json = r#"{
            "ip": 50,
            "stack": [],
            "locals": [],
            "module_bindings": [],
            "call_stack": [],
            "loop_stack": [],
            "timeframe_stack": [],
            "exception_handlers": []
        }"#;
        let snapshot: VmSnapshot = serde_json::from_str(json).unwrap();
        assert_eq!(snapshot.ip, 50);
        assert!(snapshot.ip_blob_hash.is_none());
        assert!(snapshot.ip_local_offset.is_none());
        assert!(snapshot.ip_function_id.is_none());
    }
}