//! Function and closure call convention, execution wrappers, and async resolution.

use shape_value::{Upvalue, VMError, ValueWord};

use super::{CallFrame, ExecutionResult, VirtualMachine, task_scheduler};

impl VirtualMachine {
8    /// Execute a named function with arguments, returning its result.
9    ///
10    /// If the program has module-level bindings, the top-level code is executed
11    /// first (once) to initialize them before calling the target function.
12    pub fn execute_function_by_name(
13        &mut self,
14        name: &str,
15        args: Vec<ValueWord>,
16        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
17    ) -> Result<ValueWord, VMError> {
18        let func_id = self
19            .program
20            .functions
21            .iter()
22            .position(|f| f.name == name)
23            .ok_or_else(|| VMError::RuntimeError(format!("Function '{}' not found", name)))?;
24
25        // Run the top-level code first to initialize module bindings,
26        // but only if there are module bindings that need initialization.
27        if !self.program.module_binding_names.is_empty() && !self.module_init_done {
28            self.reset();
29            self.execute(None)?;
30            self.module_init_done = true;
31        }
32
33        // Now call the target function.
34        // Use reset_stack to keep module_bindings intact.
35        self.reset_stack();
36        self.ip = self.program.instructions.len();
37        self.call_function_with_nb_args(func_id as u16, &args)?;
38        self.execute(ctx)
39    }
40
    /// Execute a function by its ID with positional arguments.
    ///
    /// Used by the remote execution system when the caller already knows the
    /// function index (e.g., from a `RemoteCallRequest.function_id`).
    ///
    /// Note: performs a full `reset()` (unlike `execute_function_by_name`,
    /// which uses `reset_stack` to keep module bindings intact), so prior
    /// module-level state is discarded.
    pub fn execute_function_by_id(
        &mut self,
        func_id: u16,
        args: Vec<ValueWord>,
        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
    ) -> Result<ValueWord, VMError> {
        self.reset();
        // Park the IP past the program end so the called frame's return
        // terminates execution.
        self.ip = self.program.instructions.len();
        self.call_function_with_nb_args(func_id, &args)?;
        self.execute(ctx)
    }
    /// Execute a closure with its captured upvalues and arguments.
    ///
    /// Used by the remote execution system to run closures that were
    /// serialized with their captured values.
    pub fn execute_closure(
        &mut self,
        function_id: u16,
        upvalues: Vec<Upvalue>,
        args: Vec<ValueWord>,
        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
    ) -> Result<ValueWord, VMError> {
        // Full reset, then park the IP past the program end so the closure's
        // return terminates execution.
        self.reset();
        self.ip = self.program.instructions.len();
        self.call_closure_with_nb_args(function_id, upvalues, &args)?;
        self.execute(ctx)
    }
    /// Fast function execution for hot loops (backtesting)
    /// - Uses pre-computed function ID (no name lookup)
    /// - Uses reset_minimal() for minimum overhead
    /// - Uses execute_fast() which skips debugging overhead
    /// - Assumes function doesn't create GC objects or use exceptions
    /// - Always calls the function with zero arguments
    pub fn execute_function_fast(
        &mut self,
        func_id: u16,
        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
    ) -> Result<ValueWord, VMError> {
        // Minimal reset - only essential state, no GC overhead
        self.reset_minimal();
        // Park the IP past the program end so the function's return
        // terminates execute_fast.
        self.ip = self.program.instructions.len();
        self.call_function_with_nb_args(func_id, &[])?;
        self.execute_fast(ctx)
    }
91    /// Execute a function with named arguments
92    /// Maps named args to positional based on function's param_names
93    pub fn execute_function_with_named_args(
94        &mut self,
95        func_id: u16,
96        named_args: &[(String, ValueWord)],
97        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
98    ) -> Result<ValueWord, VMError> {
99        let function = self
100            .program
101            .functions
102            .get(func_id as usize)
103            .ok_or(VMError::InvalidCall)?;
104
105        // Map named args to positional based on param_names
106        let mut args = vec![ValueWord::none(); function.arity as usize];
107        for (name, value) in named_args {
108            if let Some(idx) = function.param_names.iter().position(|p| p == name) {
109                if idx < args.len() {
110                    args[idx] = value.clone();
111                }
112            }
113        }
114
115        self.reset_minimal();
116        self.ip = self.program.instructions.len();
117        self.call_function_with_nb_args(func_id, &args)?;
118        self.execute_fast(ctx)
119    }
120
    /// Resume execution after a suspension.
    ///
    /// The resolved value is pushed onto the stack, and execution continues
    /// from where it left off (the IP is already set to the resume point).
    ///
    /// Returns an `ExecutionResult` so the caller can observe a further
    /// suspension.
    pub fn resume(
        &mut self,
        value: ValueWord,
        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
    ) -> Result<ExecutionResult, VMError> {
        // The suspended code expects the awaited value on top of the stack.
        self.push_vw(value)?;
        self.execute_with_suspend(ctx)
    }
134    /// Execute with automatic async task resolution.
135    ///
136    /// Runs `execute_with_suspend` in a loop. Each time the VM suspends on a
137    /// `Future { id }`, the host resolves the task via the TaskScheduler
138    /// (synchronously executing the spawned callable inline) and resumes the
139    /// VM with the result. This continues until execution completes or an
140    /// unresolvable suspension is encountered.
141    pub fn execute_with_async(
142        &mut self,
143        mut ctx: Option<&mut shape_runtime::context::ExecutionContext>,
144    ) -> Result<ValueWord, VMError> {
145        loop {
146            match self.execute_with_suspend(ctx.as_deref_mut())? {
147                ExecutionResult::Completed(value) => return Ok(value),
148                ExecutionResult::Suspended { future_id, .. } => {
149                    // Try to resolve via the task scheduler
150                    let result = self.resolve_spawned_task(future_id)?;
151                    // Push the result so the resumed VM finds it on the stack
152                    self.push_vw(result)?;
153                    // Loop continues with execute_with_suspend
154                }
155            }
156        }
157    }
158
159    /// Resolve a spawned task by executing its callable synchronously.
160    ///
161    /// Looks up the callable in the TaskScheduler, then executes it:
162    /// - NanTag::Function -> calls via call_function_with_nb_args
163    /// - HeapValue::Closure -> calls via call_closure_with_nb_args
164    /// - Other values -> returns them directly (already-resolved value)
165    ///
166    /// For externally-completed tasks (remote calls), checks the oneshot
167    /// receiver first (non-blocking).
168    fn resolve_spawned_task(&mut self, task_id: u64) -> Result<ValueWord, VMError> {
169        // Check if already resolved (cached)
170        if let Some(task_scheduler::TaskStatus::Completed(val)) =
171            self.task_scheduler.get_result(task_id)
172        {
173            return Ok(val.clone());
174        }
175        if let Some(task_scheduler::TaskStatus::Cancelled) = self.task_scheduler.get_result(task_id)
176        {
177            return Err(VMError::RuntimeError(format!(
178                "Task {} was cancelled",
179                task_id
180            )));
181        }
182
183        // Check external receivers (non-blocking) before inline execution
184        if let Some(result) = self.task_scheduler.try_resolve_external(task_id) {
185            return result;
186        }
187
188        // If this is an external task that hasn't completed yet, block on it
189        // using tokio's block_in_place to avoid deadlocking the runtime.
190        if self.task_scheduler.has_external(task_id) {
191            if let Some(rx) = self.task_scheduler.take_external_receiver(task_id) {
192                let result =
193                    tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(rx))
194                        .map_err(|_| VMError::RuntimeError("Remote task dropped".to_string()))?
195                        .map_err(VMError::RuntimeError)?;
196                self.task_scheduler.complete(task_id, result.clone());
197                return Ok(result);
198            }
199        }
200
201        // Take the callable
202        let callable_nb = self.task_scheduler.take_callable(task_id).ok_or_else(|| {
203            VMError::RuntimeError(format!("No callable registered for task {}", task_id))
204        })?;
205
206        // Execute based on callable type.
207        // We save/restore the instruction pointer and stack depth so the
208        // nested execution doesn't corrupt the outer (suspended) state.
209        use shape_value::NanTag;
210
211        let result_nb = match callable_nb.tag() {
212            NanTag::Function => {
213                let func_id = callable_nb.as_function().ok_or(VMError::InvalidCall)?;
214                let saved_ip = self.ip;
215                let saved_sp = self.sp;
216
217                self.ip = self.program.instructions.len();
218                self.call_function_with_nb_args(func_id, &[])?;
219                let res = self.execute_fast(None);
220
221                self.ip = saved_ip;
222                // Restore stack pointer (clear anything left above saved_sp)
223                for i in saved_sp..self.sp {
224                    self.stack[i] = ValueWord::none();
225                }
226                self.sp = saved_sp;
227
228                res?
229            }
230            NanTag::Heap => {
231                if let Some((function_id, upvalues)) = callable_nb.as_closure() {
232                    let upvalues = upvalues.to_vec();
233                    let saved_ip = self.ip;
234                    let saved_sp = self.sp;
235
236                    self.ip = self.program.instructions.len();
237                    self.call_closure_with_nb_args(function_id, upvalues, &[])?;
238                    let res = self.execute_fast(None);
239
240                    self.ip = saved_ip;
241                    for i in saved_sp..self.sp {
242                        self.stack[i] = ValueWord::none();
243                    }
244                    self.sp = saved_sp;
245
246                    res?
247                } else {
248                    // If someone spawned an already-resolved value, just return it
249                    callable_nb
250                }
251            }
252            // If someone spawned an already-resolved value, just return it
253            _ => callable_nb,
254        };
255
256        // Cache the result
257        self.task_scheduler.complete(task_id, result_nb.clone());
258
259        Ok(result_nb)
260    }
261
    /// ValueWord-module function call: takes ValueWord args directly.
    ///
    /// Pushes a new call frame for `func_id`, copying `args` into the first
    /// parameter slots of the frame's register window. By-reference
    /// parameters get an extra "shadow" slot holding the actual value, with
    /// the parameter slot replaced by a TAG_REF to it.
    ///
    /// Errors: `VMError::InvalidCall` for an unknown `func_id`,
    /// `VMError::StackOverflow` when `max_call_depth` is reached.
    pub(crate) fn call_function_with_nb_args(
        &mut self,
        func_id: u16,
        args: &[ValueWord],
    ) -> Result<(), VMError> {
        let function = self
            .program
            .functions
            .get(func_id as usize)
            .ok_or(VMError::InvalidCall)?;

        if self.call_stack.len() >= self.config.max_call_depth {
            return Err(VMError::StackOverflow);
        }

        let locals_count = function.locals_count as usize;
        let param_count = function.arity as usize;
        let entry_point = function.entry_point;
        // Cloned so the borrow of `function` (and thus of self.program) ends
        // before the stack is mutated below.
        let ref_params = function.ref_params.clone();

        // Count ref params that need shadow slots for their actual values.
        // DerefLoad/DerefStore expect the param slot to contain a TAG_REF
        // pointing to a *different* slot that holds the real value.
        let ref_shadow_count = ref_params
            .iter()
            .enumerate()
            .filter(|&(i, &is_ref)| is_ref && i < param_count && i < locals_count)
            .count();

        let bp = self.sp;
        let total_slots = locals_count + ref_shadow_count;
        let needed = bp + total_slots;
        if needed > self.stack.len() {
            // Grow with headroom (roughly doubling) to amortize resizes.
            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
        }

        // Copy args into the first parameter slots; missing args become None.
        // Params beyond locals_count are silently dropped.
        for i in 0..param_count {
            if i < locals_count {
                self.stack[bp + i] = args.get(i).cloned().unwrap_or_else(ValueWord::none);
            }
        }

        // For ref-inferred parameters: move the actual value to a shadow slot
        // beyond locals_count, then replace the param slot with a TAG_REF
        // pointing to the shadow slot. This way DerefLoad follows the ref
        // to the actual value (not a circular self-reference).
        let mut shadow_idx = 0;
        for (i, &is_ref) in ref_params.iter().enumerate() {
            if is_ref && i < param_count && i < locals_count {
                let shadow_slot = bp + locals_count + shadow_idx;
                self.stack[shadow_slot] = self.stack[bp + i].clone();
                self.stack[bp + i] = ValueWord::from_ref(shadow_slot);
                shadow_idx += 1;
            }
        }

        self.sp = needed;

        let blob_hash = self.blob_hash_for_function(func_id);
        let frame = CallFrame {
            // Returning restores this IP; callers park it past the program
            // end when the return should terminate execution.
            return_ip: self.ip,
            base_pointer: bp,
            // Includes the ref-param shadow slots so frame teardown covers
            // them too.
            locals_count: total_slots,
            function_id: Some(func_id),
            upvalues: None,
            blob_hash,
        };
        self.call_stack.push(frame);
        self.ip = entry_point;
        Ok(())
    }
    /// ValueWord-host closure call: takes ValueWord args directly.
    ///
    /// Frame layout: the first `captures_count` locals hold the upvalue
    /// values, followed by the regular arguments.
    ///
    /// Errors: `VMError::InvalidCall` for an unknown `func_id`,
    /// `VMError::StackOverflow` when `max_call_depth` is reached.
    pub(crate) fn call_closure_with_nb_args(
        &mut self,
        func_id: u16,
        upvalues: Vec<Upvalue>,
        args: &[ValueWord],
    ) -> Result<(), VMError> {
        let function = self
            .program
            .functions
            .get(func_id as usize)
            .ok_or(VMError::InvalidCall)?;

        if self.call_stack.len() >= self.config.max_call_depth {
            return Err(VMError::StackOverflow);
        }

        let locals_count = function.locals_count as usize;
        let captures_count = function.captures_count as usize;
        let arity = function.arity as usize;
        let entry_point = function.entry_point;

        let bp = self.sp;
        let needed = bp + locals_count;
        if needed > self.stack.len() {
            // Grow with headroom (roughly doubling) to amortize resizes.
            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
        }

        // Bind upvalue values as the first N locals
        for (i, upvalue) in upvalues.iter().enumerate() {
            if i < locals_count {
                self.stack[bp + i] = upvalue.get();
            }
        }

        // Bind the regular arguments after the upvalues
        for (i, arg) in args.iter().enumerate() {
            let local_idx = captures_count + i;
            if local_idx < locals_count {
                self.stack[bp + local_idx] = arg.clone();
            }
        }

        // Fill remaining parameters with None
        // NOTE(review): this range ends at `arity`, which is consistent only
        // if `arity` counts the captured slots as well as regular parameters.
        // If `arity` counts regular parameters only, the end should be
        // `(captures_count + arity).min(locals_count)`, otherwise unfilled
        // param slots may retain stale stack values — confirm against the
        // compiler's closure layout.
        for i in (captures_count + args.len())..arity.min(locals_count) {
            self.stack[bp + i] = ValueWord::none();
        }

        self.sp = needed;

        let blob_hash = self.blob_hash_for_function(func_id);
        self.call_stack.push(CallFrame {
            // Returning restores this IP; callers park it past the program
            // end when the return should terminate execution.
            return_ip: self.ip,
            base_pointer: bp,
            locals_count,
            function_id: Some(func_id),
            upvalues: Some(upvalues),
            blob_hash,
        });

        self.ip = entry_point;
        Ok(())
    }
399    /// ValueWord-native call_value_immediate: dispatches on NanTag/HeapKind.
400    ///
401    /// Returns ValueWord directly.
402    pub(in crate::executor) fn call_value_immediate_nb(
403        &mut self,
404        callee: &ValueWord,
405        args: &[ValueWord],
406        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
407    ) -> Result<ValueWord, VMError> {
408        use shape_value::NanTag;
409        let target_depth = self.call_stack.len();
410
411        match callee.tag() {
412            NanTag::Function => {
413                let func_id = callee.as_function().ok_or(VMError::InvalidCall)?;
414                self.call_function_with_nb_args(func_id, args)?;
415            }
416            NanTag::ModuleFunction => {
417                let func_id = callee.as_module_function().ok_or(VMError::InvalidCall)?;
418                let module_fn = self.module_fn_table.get(func_id).cloned().ok_or_else(|| {
419                    VMError::RuntimeError(format!(
420                        "Module function ID {} not found in registry",
421                        func_id
422                    ))
423                })?;
424                let args_vec: Vec<ValueWord> = args.to_vec();
425                let result_nb = self.invoke_module_fn(&module_fn, &args_vec)?;
426                return Ok(result_nb);
427            }
428            NanTag::Heap => match callee.as_heap_ref() {
429                Some(shape_value::HeapValue::Closure {
430                    function_id,
431                    upvalues,
432                }) => {
433                    self.call_closure_with_nb_args(*function_id, upvalues.clone(), args)?;
434                }
435                Some(shape_value::HeapValue::HostClosure(callable)) => {
436                    let args_vec: Vec<ValueWord> = args.to_vec();
437                    let result_nb = callable.call(&args_vec).map_err(VMError::RuntimeError)?;
438                    return Ok(result_nb);
439                }
440                _ => return Err(VMError::InvalidCall),
441            },
442            _ => return Err(VMError::InvalidCall),
443        }
444
445        self.execute_until_call_depth(target_depth, ctx)?;
446        self.pop_vw()
447    }
448
    /// Fast-path function call: reads `arg_count` arguments directly from the
    /// value stack instead of collecting them into a temporary `Vec`.
    ///
    /// Precondition: the top `arg_count` values on the stack (below sp) are the
    /// arguments in left-to-right order (arg0 deepest, argN-1 at top).
    /// These args become the first locals of the new frame's register window.
    ///
    /// Errors: `VMError::InvalidCall` for an unknown `func_id`,
    /// `VMError::StackOverflow` when `max_call_depth` is reached.
    pub(crate) fn call_function_from_stack(
        &mut self,
        func_id: u16,
        arg_count: usize,
    ) -> Result<(), VMError> {
        let function = self
            .program
            .functions
            .get(func_id as usize)
            .ok_or(VMError::InvalidCall)?;

        if self.call_stack.len() >= self.config.max_call_depth {
            return Err(VMError::StackOverflow);
        }

        let locals_count = function.locals_count as usize;
        let entry_point = function.entry_point;
        let arity = function.arity as usize;

        // The args are already on the stack at positions [sp - arg_count .. sp).
        // They become the first locals in the register window.
        // bp = sp - arg_count (args are already in place as the first locals)
        let bp = self.sp.saturating_sub(arg_count);

        // Ensure stack has room for all locals (some may be beyond the args)
        let needed = bp + locals_count;
        if needed > self.stack.len() {
            // Grow with headroom (roughly doubling) to amortize resizes.
            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
        }

        // Zero remaining local slots (including omitted args that the compiler
        // may intentionally represent as null sentinels for default params).
        // Extra args beyond arity/locals_count are overwritten here as well.
        let copy_count = arg_count.min(arity).min(locals_count);
        for i in copy_count..locals_count {
            self.stack[bp + i] = ValueWord::none();
        }

        // Advance sp past all locals
        self.sp = needed;

        let blob_hash = self.blob_hash_for_function(func_id);
        self.call_stack.push(CallFrame {
            // Returning restores this IP (the instruction after the call).
            return_ip: self.ip,
            base_pointer: bp,
            locals_count,
            function_id: Some(func_id),
            upvalues: None,
            blob_hash,
        });
        self.ip = entry_point;
        Ok(())
    }
}