1use std::collections::HashMap;
4
5use shape_runtime::snapshot::{
6 SerializableCallFrame, SerializableExceptionHandler, SerializableLoopContext, SnapshotStore,
7 VmSnapshot, nanboxed_to_serializable, serializable_to_nanboxed,
8};
9use shape_value::{Upvalue, VMError, ValueWord};
10
11use super::{CallFrame, ExceptionHandler, LoopContext, VMConfig, VirtualMachine};
12use crate::bytecode::{Function, FunctionHash};
13
14pub(crate) fn resolve_function_identity(
19 function_id_by_hash: &HashMap<FunctionHash, u16>,
20 functions: &[Function],
21 blob_hash: Option<FunctionHash>,
22 function_id: Option<u16>,
23 function_name: Option<&str>,
24) -> Result<u16, VMError> {
25 if let Some(hash) = blob_hash {
27 let resolved = function_id_by_hash.get(&hash).copied().ok_or_else(|| {
28 VMError::RuntimeError(format!("unknown function blob hash: {}", hash))
29 })?;
30 if let Some(fid) = function_id {
32 if fid != resolved {
33 return Err(VMError::RuntimeError(format!(
34 "function_id/hash mismatch: frame id {} does not match hash {} (resolved id {})",
35 fid, hash, resolved
36 )));
37 }
38 }
39 return Ok(resolved);
40 }
41
42 if let Some(fid) = function_id {
44 if (fid as usize) < functions.len() {
45 return Ok(fid);
46 }
47 return Err(VMError::RuntimeError(format!(
48 "function_id {} out of range (program has {} functions)",
49 fid,
50 functions.len()
51 )));
52 }
53
54 if let Some(name) = function_name {
56 let matches: Vec<usize> = functions
57 .iter()
58 .enumerate()
59 .filter_map(|(idx, f)| if f.name == name { Some(idx) } else { None })
60 .collect();
61 return match matches.len() {
62 1 => Ok(matches[0] as u16),
63 0 => Err(VMError::RuntimeError(format!(
64 "no function named '{}'",
65 name
66 ))),
67 n => Err(VMError::RuntimeError(format!(
68 "ambiguous function name '{}' ({} matches)",
69 name, n
70 ))),
71 };
72 }
73
74 Err(VMError::RuntimeError(
76 "cannot resolve function identity: no hash, id, or name provided".into(),
77 ))
78}
79
80impl VirtualMachine {
81 pub fn snapshot(&self, store: &SnapshotStore) -> Result<VmSnapshot, VMError> {
83 let mut stack = Vec::with_capacity(self.sp);
84 for nb in self.stack[..self.sp].iter() {
85 stack.push(
86 nanboxed_to_serializable(nb, store)
87 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
88 );
89 }
90 let locals = Vec::new();
92 let mut module_bindings = Vec::with_capacity(self.module_bindings.len());
93 for nb in self.module_bindings.iter() {
94 module_bindings.push(
95 nanboxed_to_serializable(nb, store)
96 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
97 );
98 }
99
100 let mut call_stack = Vec::with_capacity(self.call_stack.len());
101 for frame in self.call_stack.iter() {
102 let upvalues = match &frame.upvalues {
103 Some(values) => {
104 let mut out = Vec::new();
105 for up in values.iter() {
106 let nb = up.get();
107 out.push(
108 nanboxed_to_serializable(&nb, store)
109 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
110 );
111 }
112 Some(out)
113 }
114 None => None,
115 };
116 let (blob_hash, local_ip) =
118 if let (Some(hash), Some(fid)) = (frame.blob_hash, frame.function_id) {
119 let entry_point = self
120 .function_entry_points
121 .get(fid as usize)
122 .copied()
123 .unwrap_or(0);
124 let lip = frame.return_ip.saturating_sub(entry_point);
125 (Some(hash.0), Some(lip))
126 } else {
127 (None, None)
128 };
129
130 call_stack.push(SerializableCallFrame {
131 return_ip: frame.return_ip,
132 locals_base: frame.base_pointer,
133 locals_count: frame.locals_count,
134 function_id: frame.function_id,
135 upvalues,
136 blob_hash,
137 local_ip,
138 });
139 }
140
141 let loop_stack = self
142 .loop_stack
143 .iter()
144 .map(|l| SerializableLoopContext {
145 start: l.start,
146 end: l.end,
147 })
148 .collect();
149 let exception_handlers = self
150 .exception_handlers
151 .iter()
152 .map(|h| SerializableExceptionHandler {
153 catch_ip: h.catch_ip,
154 stack_size: h.stack_size,
155 call_depth: h.call_depth,
156 })
157 .collect();
158
159 Ok(VmSnapshot {
160 ip: self.ip,
161 stack,
162 locals,
163 module_bindings,
164 call_stack,
165 loop_stack,
166 timeframe_stack: self.timeframe_stack.clone(),
167 exception_handlers,
168 })
169 }
170
171 pub fn from_snapshot(
173 program: crate::bytecode::BytecodeProgram,
174 snapshot: &VmSnapshot,
175 store: &SnapshotStore,
176 ) -> Result<Self, VMError> {
177 let mut vm = VirtualMachine::new(VMConfig::default());
178 vm.load_program(program);
179 vm.ip = snapshot.ip;
180
181 let restored_stack: Vec<ValueWord> = snapshot
182 .stack
183 .iter()
184 .map(|v| {
185 serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
186 })
187 .collect::<Result<Vec<_>, _>>()?;
188 let restored_sp = restored_stack.len();
189 vm.stack = (0..restored_sp.max(crate::constants::DEFAULT_STACK_CAPACITY))
191 .map(|_| ValueWord::none())
192 .collect();
193 for (i, nb) in restored_stack.into_iter().enumerate() {
194 vm.stack[i] = nb;
195 }
196 vm.sp = restored_sp;
197 vm.module_bindings = snapshot
199 .module_bindings
200 .iter()
201 .map(|v| {
202 serializable_to_nanboxed(v, store).map_err(|e| VMError::RuntimeError(e.to_string()))
203 })
204 .collect::<Result<Vec<_>, _>>()?;
205
206 vm.call_stack = snapshot
207 .call_stack
208 .iter()
209 .map(|f| {
210 let upvalues = match &f.upvalues {
211 Some(values) => {
212 let mut out = Vec::new();
213 for v in values.iter() {
214 out.push(Upvalue::new(
215 serializable_to_nanboxed(v, store)
216 .map_err(|e| VMError::RuntimeError(e.to_string()))?,
217 ));
218 }
219 Some(out)
220 }
221 None => None,
222 };
223 let blob_hash = f.blob_hash.map(FunctionHash);
226 let resolved_function_id = if blob_hash.is_some() || f.function_id.is_some() {
227 Some(resolve_function_identity(
228 &vm.function_id_by_hash,
229 &vm.program.functions,
230 blob_hash,
231 f.function_id,
232 None,
233 )?)
234 } else {
235 None
236 };
237
238 let return_ip = if let (Some(hash), Some(local_ip), Some(fid)) =
239 (&blob_hash, f.local_ip, resolved_function_id)
240 {
241 let current_hash = vm.blob_hash_for_function(fid);
243 if let Some(current) = current_hash
244 && current != *hash
245 {
246 return Err(VMError::RuntimeError(format!(
247 "Snapshot blob hash mismatch for function {}: \
248 snapshot has {}, program has {}",
249 fid, hash, current
250 )));
251 }
252 let entry_point = vm
254 .function_entry_points
255 .get(fid as usize)
256 .copied()
257 .unwrap_or(0);
258 local_ip + entry_point
259 } else {
260 f.return_ip
261 };
262
263 Ok(CallFrame {
264 return_ip,
265 base_pointer: f.locals_base,
266 locals_count: f.locals_count,
267 function_id: resolved_function_id,
268 upvalues,
269 blob_hash,
270 })
271 })
272 .collect::<Result<Vec<_>, VMError>>()?;
273
274 vm.loop_stack = snapshot
275 .loop_stack
276 .iter()
277 .map(|l| LoopContext {
278 start: l.start,
279 end: l.end,
280 })
281 .collect();
282 vm.timeframe_stack = snapshot.timeframe_stack.clone();
283 vm.exception_handlers = snapshot
284 .exception_handlers
285 .iter()
286 .map(|h| ExceptionHandler {
287 catch_ip: h.catch_ip,
288 stack_size: h.stack_size,
289 call_depth: h.call_depth,
290 })
291 .collect();
292
293 Ok(vm)
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Function` fixture: only the name varies; every other field is
    /// zero, `false`, empty, or `None`.
    fn make_function(name: &str) -> Function {
        Function {
            name: name.to_string(),
            arity: 0,
            param_names: Vec::new(),
            locals_count: 0,
            entry_point: 0,
            body_length: 0,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: Vec::new(),
            ref_mutates: Vec::new(),
            mutable_captures: Vec::new(),
            frame_descriptor: None,
            osr_entry_points: Vec::new(),
        }
    }

    /// One fixture per name, in order, so a name's position is its function id.
    fn make_functions(names: &[&str]) -> Vec<Function> {
        names.iter().copied().map(make_function).collect()
    }

    /// Deterministic hash whose 32 bytes all equal `seed`.
    fn make_hash(seed: u8) -> FunctionHash {
        FunctionHash([seed; 32])
    }

    /// Fails unless `result` is an error whose rendered message contains `needle`.
    fn assert_err_contains(result: Result<u16, VMError>, needle: &str) {
        let msg = result
            .expect_err("expected function resolution to fail")
            .to_string();
        assert!(msg.contains(needle), "got: {}", msg);
    }

    #[test]
    fn test_resolve_by_hash() {
        let hash = make_hash(0xAB);
        let by_hash = HashMap::from([(hash, 3u16)]);
        let funcs = make_functions(&["a", "b", "c", "d"]);

        let resolved = resolve_function_identity(&by_hash, &funcs, Some(hash), None, None);
        assert_eq!(resolved.unwrap(), 3);
    }

    #[test]
    fn test_resolve_hash_not_found_is_error() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["a"]);

        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, Some(make_hash(0xAB)), None, None),
            "unknown function blob hash",
        );
    }

    #[test]
    fn test_resolve_hash_function_id_mismatch_is_error() {
        let hash = make_hash(0xCD);
        let by_hash = HashMap::from([(hash, 2u16)]);
        let funcs = make_functions(&["a", "b", "c"]);

        // The frame claims id 5 while the hash resolves to 2.
        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, Some(hash), Some(5), None),
            "mismatch",
        );
    }

    #[test]
    fn test_resolve_hash_function_id_agree() {
        let hash = make_hash(0xEF);
        let by_hash = HashMap::from([(hash, 1u16)]);
        let funcs = make_functions(&["a", "b"]);

        let resolved = resolve_function_identity(&by_hash, &funcs, Some(hash), Some(1), None);
        assert_eq!(resolved.unwrap(), 1);
    }

    #[test]
    fn test_resolve_by_function_id() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["a", "b", "c"]);

        let resolved = resolve_function_identity(&by_hash, &funcs, None, Some(2), None);
        assert_eq!(resolved.unwrap(), 2);
    }

    #[test]
    fn test_resolve_function_id_out_of_range() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["a"]);

        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, None, Some(99), None),
            "out of range",
        );
    }

    #[test]
    fn test_resolve_unique_name_fallback() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["alpha", "beta", "gamma"]);

        let resolved = resolve_function_identity(&by_hash, &funcs, None, None, Some("beta"));
        assert_eq!(resolved.unwrap(), 1);
    }

    #[test]
    fn test_resolve_ambiguous_name_is_error() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["dup", "other", "dup"]);

        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, None, None, Some("dup")),
            "ambiguous",
        );
    }

    #[test]
    fn test_resolve_name_not_found() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["a"]);

        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, None, None, Some("missing")),
            "no function named",
        );
    }

    #[test]
    fn test_resolve_no_identifiers_is_error() {
        let by_hash = HashMap::new();
        let funcs = make_functions(&["a"]);

        assert_err_contains(
            resolve_function_identity(&by_hash, &funcs, None, None, None),
            "no hash, id, or name",
        );
    }
}