Skip to main content

shape_vm/executor/
snapshot.rs

1//! VM snapshot and restore for suspending/resuming execution.
2
3use std::collections::HashMap;
4
5use shape_runtime::snapshot::{
6    SerializableCallFrame, SerializableExceptionHandler, SerializableLoopContext, SnapshotStore,
7    VmSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
8};
9use shape_value::{Upvalue, VMError, ValueWord};
10
11use super::{CallFrame, ExceptionHandler, LoopContext, VMConfig, VirtualMachine};
12use crate::bytecode::{Function, FunctionHash};
13
14/// Resolve a function's runtime ID from content-addressed identity.
15///
16/// Priority: `blob_hash` → `function_id` → `function_name`.
17/// Cross-validates when multiple identifiers are present.
18pub(crate) fn resolve_function_identity(
19    function_id_by_hash: &HashMap<FunctionHash, u16>,
20    functions: &[Function],
21    blob_hash: Option<FunctionHash>,
22    function_id: Option<u16>,
23    function_name: Option<&str>,
24) -> Result<u16, VMError> {
25    // 1. Hash-first resolution
26    if let Some(hash) = blob_hash {
27        let resolved = function_id_by_hash.get(&hash).copied().ok_or_else(|| {
28            VMError::RuntimeError(format!("unknown function blob hash: {}", hash))
29        })?;
30        // Cross-validate: if function_id is also present, they must agree
31        if let Some(fid) = function_id {
32            if fid != resolved {
33                return Err(VMError::RuntimeError(format!(
34                    "function_id/hash mismatch: frame id {} does not match hash {} (resolved id {})",
35                    fid, hash, resolved
36                )));
37            }
38        }
39        return Ok(resolved);
40    }
41
42    // 2. Direct function_id (no hash available)
43    if let Some(fid) = function_id {
44        if (fid as usize) < functions.len() {
45            return Ok(fid);
46        }
47        return Err(VMError::RuntimeError(format!(
48            "function_id {} out of range (program has {} functions)",
49            fid,
50            functions.len()
51        )));
52    }
53
54    // 3. Name-based fallback — require exactly one match
55    if let Some(name) = function_name {
56        let matches: Vec<usize> = functions
57            .iter()
58            .enumerate()
59            .filter_map(|(idx, f)| if f.name == name { Some(idx) } else { None })
60            .collect();
61        return match matches.len() {
62            1 => Ok(matches[0] as u16),
63            0 => Err(VMError::RuntimeError(format!(
64                "no function named '{}'",
65                name
66            ))),
67            n => Err(VMError::RuntimeError(format!(
68                "ambiguous function name '{}' ({} matches)",
69                name, n
70            ))),
71        };
72    }
73
74    // 4. No identifiers at all
75    Err(VMError::RuntimeError(
76        "cannot resolve function identity: no hash, id, or name provided".into(),
77    ))
78}
79
80impl VirtualMachine {
81    /// Create a serializable snapshot of VM state.
82    pub fn snapshot(&self, store: &SnapshotStore) -> Result<VmSnapshot, VMError> {
83        let mut stack = Vec::with_capacity(self.sp);
84        for nb in self.stack[..self.sp].iter() {
85            stack.push(
86                nanboxed_to_serializable(nb, store)
87                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
88            );
89        }
90        // Locals are now part of the unified stack; serialize empty vec for backward compat
91        let locals = Vec::new();
92        let mut module_bindings = Vec::with_capacity(self.module_bindings.len());
93        for nb in self.module_bindings.iter() {
94            module_bindings.push(
95                nanboxed_to_serializable(nb, store)
96                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
97            );
98        }
99
100        let mut call_stack = Vec::with_capacity(self.call_stack.len());
101        for frame in self.call_stack.iter() {
102            let upvalues = match &frame.upvalues {
103                Some(values) => {
104                    let mut out = Vec::new();
105                    for up in values.iter() {
106                        let nb = up.get();
107                        out.push(
108                            nanboxed_to_serializable(&nb, store)
109                                .map_err(|e| VMError::RuntimeError(e.to_string()))?,
110                        );
111                    }
112                    Some(out)
113                }
114                None => None,
115            };
116            // Compute content-addressed snapshot fields when blob_hash is available
117            let (blob_hash, local_ip) =
118                if let (Some(hash), Some(fid)) = (frame.blob_hash, frame.function_id) {
119                    let entry_point = self
120                        .function_entry_points
121                        .get(fid as usize)
122                        .copied()
123                        .unwrap_or(0);
124                    let lip = frame.return_ip.saturating_sub(entry_point);
125                    (Some(hash.0), Some(lip))
126                } else {
127                    (None, None)
128                };
129
130            call_stack.push(SerializableCallFrame {
131                return_ip: frame.return_ip,
132                locals_base: frame.base_pointer,
133                locals_count: frame.locals_count,
134                function_id: frame.function_id,
135                upvalues,
136                blob_hash,
137                local_ip,
138            });
139        }
140
141        let loop_stack = self
142            .loop_stack
143            .iter()
144            .map(|l| SerializableLoopContext {
145                start: l.start,
146                end: l.end,
147            })
148            .collect();
149        let exception_handlers = self
150            .exception_handlers
151            .iter()
152            .map(|h| SerializableExceptionHandler {
153                catch_ip: h.catch_ip,
154                stack_size: h.stack_size,
155                call_depth: h.call_depth,
156            })
157            .collect();
158
159        Ok(VmSnapshot {
160            ip: self.ip,
161            stack,
162            locals,
163            module_bindings,
164            call_stack,
165            loop_stack,
166            timeframe_stack: self.timeframe_stack.clone(),
167            exception_handlers,
168        })
169    }
170
171    /// Restore a VM from a snapshot and bytecode program.
172    pub fn from_snapshot(
173        program: crate::bytecode::BytecodeProgram,
174        snapshot: &VmSnapshot,
175        store: &SnapshotStore,
176    ) -> Result<Self, VMError> {
177        let mut vm = VirtualMachine::new(VMConfig::default());
178        vm.load_program(program);
179        vm.ip = snapshot.ip;
180
181        let restored_stack: Vec<ValueWord> = snapshot
182            .stack
183            .iter()
184            .map(|v| {
185                serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
186            })
187            .collect::<Result<Vec<_>, _>>()?;
188        let restored_sp = restored_stack.len();
189        // Pre-allocate and copy into the unified stack
190        vm.stack = (0..restored_sp.max(crate::constants::DEFAULT_STACK_CAPACITY))
191            .map(|_| ValueWord::none())
192            .collect();
193        for (i, nb) in restored_stack.into_iter().enumerate() {
194            vm.stack[i] = nb;
195        }
196        vm.sp = restored_sp;
197        // Locals snapshot is ignored — locals now live on the unified stack
198        vm.module_bindings = snapshot
199            .module_bindings
200            .iter()
201            .map(|v| {
202                serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
203            })
204            .collect::<Result<Vec<_>, _>>()?;
205
206        vm.call_stack = snapshot
207            .call_stack
208            .iter()
209            .map(|f| {
210                let upvalues = match &f.upvalues {
211                    Some(values) => {
212                        let mut out = Vec::new();
213                        for v in values.iter() {
214                            out.push(Upvalue::new(
215                                serializable_to_nanboxed(v, store)
216                                    .map_err(|e| VMError::RuntimeError(e.to_string()))?,
217                            ));
218                        }
219                        Some(out)
220                    }
221                    None => None,
222                };
223                // Restore blob_hash from the snapshot frame. Use the shared
224                // hash-first resolution helper for strict validation.
225                let blob_hash = f.blob_hash.map(FunctionHash);
226                let resolved_function_id = if blob_hash.is_some() || f.function_id.is_some() {
227                    Some(resolve_function_identity(
228                        &vm.function_id_by_hash,
229                        &vm.program.functions,
230                        blob_hash,
231                        f.function_id,
232                        None,
233                    )?)
234                } else {
235                    None
236                };
237
238                let return_ip = if let (Some(hash), Some(local_ip), Some(fid)) =
239                    (&blob_hash, f.local_ip, resolved_function_id)
240                {
241                    // Validate the blob hash matches the loaded program
242                    let current_hash = vm.blob_hash_for_function(fid);
243                    if let Some(current) = current_hash
244                        && current != *hash
245                    {
246                        return Err(VMError::RuntimeError(format!(
247                            "Snapshot blob hash mismatch for function {}: \
248                             snapshot has {}, program has {}",
249                            fid, hash, current
250                        )));
251                    }
252                    // Reconstruct absolute IP from local_ip + entry_point
253                    let entry_point = vm
254                        .function_entry_points
255                        .get(fid as usize)
256                        .copied()
257                        .unwrap_or(0);
258                    local_ip + entry_point
259                } else {
260                    f.return_ip
261                };
262
263                Ok(CallFrame {
264                    return_ip,
265                    base_pointer: f.locals_base,
266                    locals_count: f.locals_count,
267                    function_id: resolved_function_id,
268                    upvalues,
269                    blob_hash,
270                })
271            })
272            .collect::<Result<Vec<_>, VMError>>()?;
273
274        vm.loop_stack = snapshot
275            .loop_stack
276            .iter()
277            .map(|l| LoopContext {
278                start: l.start,
279                end: l.end,
280            })
281            .collect();
282        vm.timeframe_stack = snapshot.timeframe_stack.clone();
283        vm.exception_handlers = snapshot
284            .exception_handlers
285            .iter()
286            .map(|h| ExceptionHandler {
287                catch_ip: h.catch_ip,
288                stack_size: h.stack_size,
289                call_depth: h.call_depth,
290            })
291            .collect();
292
293        Ok(vm)
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Function` that only carries `name`; all other fields are
    /// zero, `false`, empty or `None`.
    fn make_function(name: &str) -> Function {
        Function {
            name: String::from(name),
            arity: 0,
            param_names: Vec::new(),
            locals_count: 0,
            entry_point: 0,
            body_length: 0,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: Vec::new(),
            ref_mutates: Vec::new(),
            mutable_captures: Vec::new(),
            frame_descriptor: None,
            osr_entry_points: Vec::new(),
        }
    }

    /// Build one minimal `Function` per name, in order.
    fn make_functions(names: &[&str]) -> Vec<Function> {
        names.iter().map(|name| make_function(name)).collect()
    }

    /// A hash whose 32 bytes are all `seed`.
    fn make_hash(seed: u8) -> FunctionHash {
        FunctionHash([seed; 32])
    }

    /// A hash table containing exactly one `hash -> id` entry.
    fn single_entry(hash: FunctionHash, id: u16) -> HashMap<FunctionHash, u16> {
        let mut table = HashMap::new();
        table.insert(hash, id);
        table
    }

    /// A hash table with no registered hashes.
    fn no_hashes() -> HashMap<FunctionHash, u16> {
        HashMap::new()
    }

    /// Assert that `outcome` is an error whose message contains `needle`.
    fn assert_error_contains(outcome: Result<u16, VMError>, needle: &str) {
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains(needle), "got: {}", msg);
    }

    #[test]
    fn test_resolve_by_hash() {
        let hash = make_hash(0xAB);
        let table = single_entry(hash, 3u16);
        let program = make_functions(&["a", "b", "c", "d"]);

        let id = resolve_function_identity(&table, &program, Some(hash), None, None).unwrap();
        assert_eq!(id, 3);
    }

    #[test]
    fn test_resolve_hash_not_found_is_error() {
        // The table is empty, so the hash is not registered.
        let program = make_functions(&["a"]);

        assert_error_contains(
            resolve_function_identity(&no_hashes(), &program, Some(make_hash(0xAB)), None, None),
            "unknown function blob hash",
        );
    }

    #[test]
    fn test_resolve_hash_function_id_mismatch_is_error() {
        let hash = make_hash(0xCD);
        let table = single_entry(hash, 2u16); // hash resolves to 2
        let program = make_functions(&["a", "b", "c"]);

        // function_id = 5 disagrees with the hash-resolved id = 2.
        assert_error_contains(
            resolve_function_identity(&table, &program, Some(hash), Some(5), None),
            "mismatch",
        );
    }

    #[test]
    fn test_resolve_hash_function_id_agree() {
        let hash = make_hash(0xEF);
        let table = single_entry(hash, 1u16);
        let program = make_functions(&["a", "b"]);

        // Hash and explicit id both point at 1.
        let id = resolve_function_identity(&table, &program, Some(hash), Some(1), None).unwrap();
        assert_eq!(id, 1);
    }

    #[test]
    fn test_resolve_by_function_id() {
        let program = make_functions(&["a", "b", "c"]);

        let id = resolve_function_identity(&no_hashes(), &program, None, Some(2), None).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_resolve_function_id_out_of_range() {
        let program = make_functions(&["a"]);

        assert_error_contains(
            resolve_function_identity(&no_hashes(), &program, None, Some(99), None),
            "out of range",
        );
    }

    #[test]
    fn test_resolve_unique_name_fallback() {
        let program = make_functions(&["alpha", "beta", "gamma"]);

        let id =
            resolve_function_identity(&no_hashes(), &program, None, None, Some("beta")).unwrap();
        assert_eq!(id, 1);
    }

    #[test]
    fn test_resolve_ambiguous_name_is_error() {
        let program = make_functions(&["dup", "other", "dup"]);

        assert_error_contains(
            resolve_function_identity(&no_hashes(), &program, None, None, Some("dup")),
            "ambiguous",
        );
    }

    #[test]
    fn test_resolve_name_not_found() {
        let program = make_functions(&["a"]);

        assert_error_contains(
            resolve_function_identity(&no_hashes(), &program, None, None, Some("missing")),
            "no function named",
        );
    }

    #[test]
    fn test_resolve_no_identifiers_is_error() {
        let program = make_functions(&["a"]);

        assert_error_contains(
            resolve_function_identity(&no_hashes(), &program, None, None, None),
            "no hash, id, or name",
        );
    }
}
448}