Skip to main content

shape_vm/executor/
snapshot.rs

1//! VM snapshot and restore for suspending/resuming execution.
2
3use std::collections::HashMap;
4
5use shape_runtime::snapshot::{
6    SerializableCallFrame, SerializableExceptionHandler, SerializableLoopContext, SnapshotStore,
7    VmSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
8};
9use shape_value::{Upvalue, VMError, ValueWord};
10
11use super::{CallFrame, ExceptionHandler, LoopContext, VMConfig, VirtualMachine};
12use crate::bytecode::{Function, FunctionHash};
13
14/// Resolve a function's runtime ID from content-addressed identity.
15///
16/// Priority: `blob_hash` → `function_id` → `function_name`.
17/// Cross-validates when multiple identifiers are present.
18pub(crate) fn resolve_function_identity(
19    function_id_by_hash: &HashMap<FunctionHash, u16>,
20    functions: &[Function],
21    blob_hash: Option<FunctionHash>,
22    function_id: Option<u16>,
23    function_name: Option<&str>,
24) -> Result<u16, VMError> {
25    // 1. Hash-first resolution
26    if let Some(hash) = blob_hash {
27        let resolved = function_id_by_hash.get(&hash).copied().ok_or_else(|| {
28            VMError::RuntimeError(format!("unknown function blob hash: {}", hash))
29        })?;
30        // Cross-validate: if function_id is also present, they must agree
31        if let Some(fid) = function_id {
32            if fid != resolved {
33                return Err(VMError::RuntimeError(format!(
34                    "function_id/hash mismatch: frame id {} does not match hash {} (resolved id {})",
35                    fid, hash, resolved
36                )));
37            }
38        }
39        return Ok(resolved);
40    }
41
42    // 2. Direct function_id (no hash available)
43    if let Some(fid) = function_id {
44        if (fid as usize) < functions.len() {
45            return Ok(fid);
46        }
47        return Err(VMError::RuntimeError(format!(
48            "function_id {} out of range (program has {} functions)",
49            fid,
50            functions.len()
51        )));
52    }
53
54    // 3. Name-based fallback — require exactly one match
55    if let Some(name) = function_name {
56        let matches: Vec<usize> = functions
57            .iter()
58            .enumerate()
59            .filter_map(|(idx, f)| if f.name == name { Some(idx) } else { None })
60            .collect();
61        return match matches.len() {
62            1 => Ok(matches[0] as u16),
63            0 => Err(VMError::RuntimeError(format!(
64                "no function named '{}'",
65                name
66            ))),
67            n => Err(VMError::RuntimeError(format!(
68                "ambiguous function name '{}' ({} matches)",
69                name, n
70            ))),
71        };
72    }
73
74    // 4. No identifiers at all
75    Err(VMError::RuntimeError(
76        "cannot resolve function identity: no hash, id, or name provided".into(),
77    ))
78}
79
80impl VirtualMachine {
81    /// Create a serializable snapshot of VM state.
82    pub fn snapshot(&self, store: &SnapshotStore) -> Result<VmSnapshot, VMError> {
83        let mut stack = Vec::with_capacity(self.sp);
84        for nb in self.stack[..self.sp].iter() {
85            stack.push(
86                nanboxed_to_serializable(nb, store)
87                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
88            );
89        }
90        // Locals are now part of the unified stack; serialize empty vec for backward compat
91        let locals = Vec::new();
92        let mut module_bindings = Vec::with_capacity(self.module_bindings.len());
93        for nb in self.module_bindings.iter() {
94            module_bindings.push(
95                nanboxed_to_serializable(nb, store)
96                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
97            );
98        }
99
100        let mut call_stack = Vec::with_capacity(self.call_stack.len());
101        for frame in self.call_stack.iter() {
102            let upvalues = match &frame.upvalues {
103                Some(values) => {
104                    let mut out = Vec::new();
105                    for up in values.iter() {
106                        let nb = up.get();
107                        out.push(
108                            nanboxed_to_serializable(&nb, store)
109                                .map_err(|e| VMError::RuntimeError(e.to_string()))?,
110                        );
111                    }
112                    Some(out)
113                }
114                None => None,
115            };
116            // Compute content-addressed snapshot fields when blob_hash is available
117            let (blob_hash, local_ip) =
118                if let (Some(hash), Some(fid)) = (frame.blob_hash, frame.function_id) {
119                    let entry_point = self
120                        .function_entry_points
121                        .get(fid as usize)
122                        .copied()
123                        .unwrap_or(0);
124                    let lip = frame.return_ip.saturating_sub(entry_point);
125                    (Some(hash.0), Some(lip))
126                } else {
127                    (None, None)
128                };
129
130            call_stack.push(SerializableCallFrame {
131                return_ip: frame.return_ip,
132                locals_base: frame.base_pointer,
133                locals_count: frame.locals_count,
134                function_id: frame.function_id,
135                upvalues,
136                blob_hash,
137                local_ip,
138            });
139        }
140
141        let loop_stack = self
142            .loop_stack
143            .iter()
144            .map(|l| SerializableLoopContext {
145                start: l.start,
146                end: l.end,
147            })
148            .collect();
149        let exception_handlers = self
150            .exception_handlers
151            .iter()
152            .map(|h| SerializableExceptionHandler {
153                catch_ip: h.catch_ip,
154                stack_size: h.stack_size,
155                call_depth: h.call_depth,
156            })
157            .collect();
158
159        // Compute relocatable top-level IP from the current call frame.
160        // The top-level `ip` corresponds to the innermost frame's function.
161        let (ip_blob_hash, ip_local_offset, ip_function_id) =
162            if let Some(frame) = self.call_stack.last() {
163                let fid = frame.function_id;
164                let blob_hash = fid.and_then(|id| self.blob_hash_for_function(id));
165                let entry_point = fid
166                    .and_then(|id| self.function_entry_points.get(id as usize).copied())
167                    .unwrap_or(0);
168                let local_offset = self.ip.saturating_sub(entry_point);
169                (blob_hash.map(|h| h.0), Some(local_offset), fid)
170            } else {
171                (None, None, None)
172            };
173
174        Ok(VmSnapshot {
175            ip: self.ip,
176            stack,
177            locals,
178            module_bindings,
179            call_stack,
180            loop_stack,
181            timeframe_stack: self.timeframe_stack.clone(),
182            exception_handlers,
183            ip_blob_hash,
184            ip_local_offset,
185            ip_function_id,
186        })
187    }
188
189    /// Restore a VM from a snapshot and bytecode program.
190    pub fn from_snapshot(
191        program: crate::bytecode::BytecodeProgram,
192        snapshot: &VmSnapshot,
193        store: &SnapshotStore,
194    ) -> Result<Self, VMError> {
195        let mut vm = VirtualMachine::new(VMConfig::default());
196        vm.load_program(program);
197
198        // Relocate the top-level IP using content-addressed identity when
199        // available. This handles the case where the program was recompiled
200        // and instruction positions changed.
201        vm.ip = if let (Some(hash_bytes), Some(local_offset)) =
202            (&snapshot.ip_blob_hash, snapshot.ip_local_offset)
203        {
204            let hash = FunctionHash(*hash_bytes);
205            // Look up the function by blob hash in the new program
206            let func_id = resolve_function_identity(
207                &vm.function_id_by_hash,
208                &vm.program.functions,
209                Some(hash),
210                snapshot.ip_function_id,
211                None,
212            )?;
213            let entry_point = vm
214                .function_entry_points
215                .get(func_id as usize)
216                .copied()
217                .unwrap_or(0);
218            entry_point + local_offset
219        } else if let Some(fid) = snapshot.ip_function_id {
220            // Fallback: use function_id to relocate (same program, stable IDs)
221            let entry_point = vm
222                .function_entry_points
223                .get(fid as usize)
224                .copied()
225                .unwrap_or(0);
226            let local_offset = snapshot.ip_local_offset.unwrap_or(0);
227            entry_point + local_offset
228        } else {
229            // Legacy snapshots without relocation info: use absolute IP
230            snapshot.ip
231        };
232
233        let restored_stack: Vec<ValueWord> = snapshot
234            .stack
235            .iter()
236            .map(|v| {
237                serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
238            })
239            .collect::<Result<Vec<_>, _>>()?;
240        let restored_sp = restored_stack.len();
241        // Pre-allocate and copy into the unified stack
242        vm.stack = (0..restored_sp.max(crate::constants::DEFAULT_STACK_CAPACITY))
243            .map(|_| ValueWord::none())
244            .collect();
245        for (i, nb) in restored_stack.into_iter().enumerate() {
246            vm.stack[i] = nb;
247        }
248        vm.sp = restored_sp;
249        // Locals snapshot is ignored — locals now live on the unified stack
250        vm.module_bindings = snapshot
251            .module_bindings
252            .iter()
253            .map(|v| {
254                serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
255            })
256            .collect::<Result<Vec<_>, _>>()?;
257
258        vm.call_stack = snapshot
259            .call_stack
260            .iter()
261            .map(|f| {
262                let upvalues = match &f.upvalues {
263                    Some(values) => {
264                        let mut out = Vec::new();
265                        for v in values.iter() {
266                            out.push(Upvalue::new(
267                                serializable_to_nanboxed(v, store)
268                                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
269                            ));
270                        }
271                        Some(out)
272                    }
273                    None => None,
274                };
275                // Restore blob_hash from the snapshot frame. Use the shared
276                // hash-first resolution helper for strict validation.
277                let blob_hash = f.blob_hash.map(FunctionHash);
278                let resolved_function_id = if blob_hash.is_some() || f.function_id.is_some() {
279                    Some(resolve_function_identity(
280                        &vm.function_id_by_hash,
281                        &vm.program.functions,
282                        blob_hash,
283                        f.function_id,
284                        None,
285                    )?)
286                } else {
287                    None
288                };
289
290                let return_ip = if let (Some(hash), Some(local_ip), Some(fid)) =
291                    (&blob_hash, f.local_ip, resolved_function_id)
292                {
293                    // Validate the blob hash matches the loaded program
294                    let current_hash = vm.blob_hash_for_function(fid);
295                    if let Some(current) = current_hash
296                        && current != *hash
297                    {
298                        return Err(VMError::RuntimeError(format!(
299                            "Snapshot blob hash mismatch for function {}: \
300                             snapshot has {}, program has {}",
301                            fid, hash, current
302                        )));
303                    }
304                    // Reconstruct absolute IP from local_ip + entry_point
305                    let entry_point = vm
306                        .function_entry_points
307                        .get(fid as usize)
308                        .copied()
309                        .unwrap_or(0);
310                    local_ip + entry_point
311                } else {
312                    f.return_ip
313                };
314
315                Ok(CallFrame {
316                    return_ip,
317                    base_pointer: f.locals_base,
318                    locals_count: f.locals_count,
319                    function_id: resolved_function_id,
320                    upvalues,
321                    blob_hash,
322                })
323            })
324            .collect::<Result<Vec<_>, VMError>>()?;
325
326        vm.loop_stack = snapshot
327            .loop_stack
328            .iter()
329            .map(|l| LoopContext {
330                start: l.start,
331                end: l.end,
332            })
333            .collect();
334        vm.timeframe_stack = snapshot.timeframe_stack.clone();
335        vm.exception_handlers = snapshot
336            .exception_handlers
337            .iter()
338            .map(|h| ExceptionHandler {
339                catch_ip: h.catch_ip,
340                stack_size: h.stack_size,
341                call_depth: h.call_depth,
342            })
343            .collect();
344
345        Ok(vm)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    // --- fixtures ---

    /// Build a `Function` that carries only the given name; every other
    /// field is zero, false, `None` or empty.
    fn make_function(name: &str) -> Function {
        Function {
            name: String::from(name),
            arity: 0,
            param_names: vec![],
            locals_count: 0,
            entry_point: 0,
            body_length: 0,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: None,
            osr_entry_points: vec![],
        }
    }

    /// A hash whose 32 bytes all equal `seed`.
    fn make_hash(seed: u8) -> FunctionHash {
        FunctionHash([seed; 32])
    }

    /// One function per name, in the given order.
    fn functions_named(names: &[&str]) -> Vec<Function> {
        names.iter().copied().map(make_function).collect()
    }

    /// A hash → function-id table holding exactly one entry.
    fn single_entry_table(hash: FunctionHash, id: u16) -> HashMap<FunctionHash, u16> {
        let mut table = HashMap::new();
        table.insert(hash, id);
        table
    }

    /// Assert that resolution failed and hand back the error text.
    fn failure_text(outcome: Result<u16, VMError>) -> String {
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    /// A snapshot with empty collections, `ip` 0 and no relocation data.
    fn blank_snapshot() -> VmSnapshot {
        VmSnapshot {
            ip: 0,
            stack: vec![],
            locals: vec![],
            module_bindings: vec![],
            call_stack: vec![],
            loop_stack: vec![],
            timeframe_stack: vec![],
            exception_handlers: vec![],
            ip_blob_hash: None,
            ip_local_offset: None,
            ip_function_id: None,
        }
    }

    // --- resolve_function_identity: hash path ---

    #[test]
    fn test_resolve_by_hash() {
        let hash = make_hash(0xAB);
        let table = single_entry_table(hash, 3u16);
        let funcs = functions_named(&["a", "b", "c", "d"]);

        let outcome = resolve_function_identity(&table, &funcs, Some(hash), None, None);
        assert_eq!(outcome.unwrap(), 3);
    }

    #[test]
    fn test_resolve_hash_not_found_is_error() {
        // Nothing is registered, so the hash cannot be resolved.
        let table = HashMap::new();
        let funcs = functions_named(&["a"]);

        let msg = failure_text(resolve_function_identity(
            &table,
            &funcs,
            Some(make_hash(0xAB)),
            None,
            None,
        ));
        assert!(msg.contains("unknown function blob hash"), "got: {}", msg);
    }

    #[test]
    fn test_resolve_hash_function_id_mismatch_is_error() {
        let hash = make_hash(0xCD);
        // The hash points at id 2 ...
        let table = single_entry_table(hash, 2u16);
        let funcs = functions_named(&["a", "b", "c"]);

        // ... while the caller claims id 5.
        let msg = failure_text(resolve_function_identity(
            &table,
            &funcs,
            Some(hash),
            Some(5),
            None,
        ));
        assert!(msg.contains("mismatch"), "got: {}", msg);
    }

    #[test]
    fn test_resolve_hash_function_id_agree() {
        let hash = make_hash(0xEF);
        let table = single_entry_table(hash, 1u16);
        let funcs = functions_named(&["a", "b"]);

        // Hash and explicit id both say 1.
        let outcome = resolve_function_identity(&table, &funcs, Some(hash), Some(1), None);
        assert_eq!(outcome.unwrap(), 1);
    }

    // --- resolve_function_identity: id path ---

    #[test]
    fn test_resolve_by_function_id() {
        let table = HashMap::new();
        let funcs = functions_named(&["a", "b", "c"]);

        let outcome = resolve_function_identity(&table, &funcs, None, Some(2), None);
        assert_eq!(outcome.unwrap(), 2);
    }

    #[test]
    fn test_resolve_function_id_out_of_range() {
        let table = HashMap::new();
        let funcs = functions_named(&["a"]);

        let msg = failure_text(resolve_function_identity(
            &table,
            &funcs,
            None,
            Some(99),
            None,
        ));
        assert!(msg.contains("out of range"), "got: {}", msg);
    }

    // --- resolve_function_identity: name path ---

    #[test]
    fn test_resolve_unique_name_fallback() {
        let table = HashMap::new();
        let funcs = functions_named(&["alpha", "beta", "gamma"]);

        let outcome = resolve_function_identity(&table, &funcs, None, None, Some("beta"));
        assert_eq!(outcome.unwrap(), 1);
    }

    #[test]
    fn test_resolve_ambiguous_name_is_error() {
        let table = HashMap::new();
        let funcs = functions_named(&["dup", "other", "dup"]);

        let msg = failure_text(resolve_function_identity(
            &table,
            &funcs,
            None,
            None,
            Some("dup"),
        ));
        assert!(msg.contains("ambiguous"), "got: {}", msg);
    }

    #[test]
    fn test_resolve_name_not_found() {
        let table = HashMap::new();
        let funcs = functions_named(&["a"]);

        let msg = failure_text(resolve_function_identity(
            &table,
            &funcs,
            None,
            None,
            Some("missing"),
        ));
        assert!(msg.contains("no function named"), "got: {}", msg);
    }

    #[test]
    fn test_resolve_no_identifiers_is_error() {
        let table = HashMap::new();
        let funcs = functions_named(&["a"]);

        let msg = failure_text(resolve_function_identity(&table, &funcs, None, None, None));
        assert!(msg.contains("no hash, id, or name"), "got: {}", msg);
    }

    // --- VmSnapshot IP relocation tests ---

    #[test]
    fn test_snapshot_ip_relocation_fields_present() {
        // The relocation fields can be populated and read back.
        let snapshot = VmSnapshot {
            ip: 42,
            ip_blob_hash: Some([0xAB; 32]),
            ip_local_offset: Some(10),
            ip_function_id: Some(1),
            ..blank_snapshot()
        };
        assert_eq!(snapshot.ip, 42);
        assert_eq!(snapshot.ip_blob_hash, Some([0xAB; 32]));
        assert_eq!(snapshot.ip_local_offset, Some(10));
        assert_eq!(snapshot.ip_function_id, Some(1));
    }

    #[test]
    fn test_snapshot_legacy_without_relocation_fields() {
        // A legacy-shaped snapshot: absolute IP only, no relocation data.
        let snapshot = VmSnapshot {
            ip: 100,
            ..blank_snapshot()
        };
        // With all three fields unset, only the absolute IP is available.
        assert!(snapshot.ip_blob_hash.is_none());
        assert!(snapshot.ip_local_offset.is_none());
        assert!(snapshot.ip_function_id.is_none());
    }

    #[test]
    fn test_snapshot_serialization_roundtrip_with_relocation() {
        let original = VmSnapshot {
            ip: 42,
            ip_blob_hash: Some([0xCD; 32]),
            ip_local_offset: Some(7),
            ip_function_id: Some(2),
            ..blank_snapshot()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: VmSnapshot = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.ip_blob_hash, Some([0xCD; 32]));
        assert_eq!(decoded.ip_local_offset, Some(7));
        assert_eq!(decoded.ip_function_id, Some(2));
    }

    #[test]
    fn test_snapshot_deserialization_without_relocation_fields() {
        // JSON as written before the relocation fields existed.
        let legacy_json = r#"{
            "ip": 50,
            "stack": [],
            "locals": [],
            "module_bindings": [],
            "call_stack": [],
            "loop_stack": [],
            "timeframe_stack": [],
            "exception_handlers": []
        }"#;
        let decoded: VmSnapshot = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(decoded.ip, 50);
        assert!(decoded.ip_blob_hash.is_none());
        assert!(decoded.ip_local_offset.is_none());
        assert!(decoded.ip_function_id.is_none());
    }
}
589}