// shape_vm/executor/call_convention.rs
1//! Function and closure call convention, execution wrappers, and async resolution.
2
3use shape_value::{Upvalue, VMError, ValueWord};
4
5use super::{CallFrame, ExecutionResult, VirtualMachine, task_scheduler};
6
7impl VirtualMachine {
8    /// Execute a named function with arguments, returning its result.
9    ///
10    /// If the program has module-level bindings, the top-level code is executed
11    /// first (once) to initialize them before calling the target function.
12    pub fn execute_function_by_name(
13        &mut self,
14        name: &str,
15        args: Vec<ValueWord>,
16        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
17    ) -> Result<ValueWord, VMError> {
18        let func_id = self
19            .program
20            .functions
21            .iter()
22            .position(|f| f.name == name)
23            .ok_or_else(|| VMError::RuntimeError(format!("Function '{}' not found", name)))?;
24
25        // Run the top-level code first to initialize module bindings,
26        // but only if there are module bindings that need initialization.
27        if !self.program.module_binding_names.is_empty() && !self.module_init_done {
28            self.reset();
29            self.execute(None)?;
30            self.module_init_done = true;
31        }
32
33        // Now call the target function.
34        // Use reset_stack to keep module_bindings intact.
35        self.reset_stack();
36        self.ip = self.program.instructions.len();
37        self.call_function_with_nb_args(func_id as u16, &args)?;
38        self.execute(ctx)
39    }
40
41    /// Execute a function by its ID with positional arguments.
42    ///
43    /// Used by the remote execution system when the caller already knows the
44    /// function index (e.g., from a `RemoteCallRequest.function_id`).
45    pub fn execute_function_by_id(
46        &mut self,
47        func_id: u16,
48        args: Vec<ValueWord>,
49        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
50    ) -> Result<ValueWord, VMError> {
51        self.reset();
52        self.ip = self.program.instructions.len();
53        self.call_function_with_nb_args(func_id, &args)?;
54        self.execute(ctx)
55    }
56
57    /// Execute a closure with its captured upvalues and arguments.
58    ///
59    /// Used by the remote execution system to run closures that were
60    /// serialized with their captured values.
61    pub fn execute_closure(
62        &mut self,
63        function_id: u16,
64        upvalues: Vec<Upvalue>,
65        args: Vec<ValueWord>,
66        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
67    ) -> Result<ValueWord, VMError> {
68        self.reset();
69        self.ip = self.program.instructions.len();
70        self.call_closure_with_nb_args(function_id, upvalues, &args)?;
71        self.execute(ctx)
72    }
73
74    /// Fast function execution for hot loops (backtesting)
75    /// - Uses pre-computed function ID (no name lookup)
76    /// - Uses reset_minimal() for minimum overhead
77    /// - Uses execute_fast() which skips debugging overhead
78    /// - Assumes function doesn't create GC objects or use exceptions
79    pub fn execute_function_fast(
80        &mut self,
81        func_id: u16,
82        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
83    ) -> Result<ValueWord, VMError> {
84        // Minimal reset - only essential state, no GC overhead
85        self.reset_minimal();
86        self.ip = self.program.instructions.len();
87        self.call_function_with_nb_args(func_id, &[])?;
88        self.execute_fast(ctx)
89    }
90
91    /// Execute a function with named arguments
92    /// Maps named args to positional based on function's param_names
93    pub fn execute_function_with_named_args(
94        &mut self,
95        func_id: u16,
96        named_args: &[(String, ValueWord)],
97        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
98    ) -> Result<ValueWord, VMError> {
99        let function = self
100            .program
101            .functions
102            .get(func_id as usize)
103            .ok_or(VMError::InvalidCall)?;
104
105        // Map named args to positional based on param_names
106        let mut args = vec![ValueWord::none(); function.arity as usize];
107        for (name, value) in named_args {
108            if let Some(idx) = function.param_names.iter().position(|p| p == name) {
109                if idx < args.len() {
110                    args[idx] = value.clone();
111                }
112            }
113        }
114
115        self.reset_minimal();
116        self.ip = self.program.instructions.len();
117        self.call_function_with_nb_args(func_id, &args)?;
118        self.execute_fast(ctx)
119    }
120
121    /// Resume execution after a suspension.
122    ///
123    /// The resolved value is pushed onto the stack, and execution continues
124    /// from where it left off (the IP is already set to the resume point).
125    pub fn resume(
126        &mut self,
127        value: ValueWord,
128        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
129    ) -> Result<ExecutionResult, VMError> {
130        self.push_vw(value)?;
131        self.execute_with_suspend(ctx)
132    }
133
134    /// Execute with automatic async task resolution.
135    ///
136    /// Runs `execute_with_suspend` in a loop. Each time the VM suspends on a
137    /// `Future { id }`, the host resolves the task via the TaskScheduler
138    /// (synchronously executing the spawned callable inline) and resumes the
139    /// VM with the result. This continues until execution completes or an
140    /// unresolvable suspension is encountered.
141    pub fn execute_with_async(
142        &mut self,
143        mut ctx: Option<&mut shape_runtime::context::ExecutionContext>,
144    ) -> Result<ValueWord, VMError> {
145        loop {
146            match self.execute_with_suspend(ctx.as_deref_mut())? {
147                ExecutionResult::Completed(value) => return Ok(value),
148                ExecutionResult::Suspended { future_id, .. } => {
149                    // Try to resolve via the task scheduler
150                    let result = self.resolve_spawned_task(future_id)?;
151                    // Push the result so the resumed VM finds it on the stack
152                    self.push_vw(result)?;
153                    // Loop continues with execute_with_suspend
154                }
155            }
156        }
157    }
158
159    /// Resolve a spawned task by executing its callable synchronously.
160    ///
161    /// Looks up the callable in the TaskScheduler, then executes it:
162    /// - NanTag::Function -> calls via call_function_with_nb_args
163    /// - HeapValue::Closure -> calls via call_closure_with_nb_args
164    /// - Other values -> returns them directly (already-resolved value)
165    ///
166    /// For externally-completed tasks (remote calls), checks the oneshot
167    /// receiver first (non-blocking).
168    fn resolve_spawned_task(&mut self, task_id: u64) -> Result<ValueWord, VMError> {
169        // Check if already resolved (cached)
170        if let Some(task_scheduler::TaskStatus::Completed(val)) =
171            self.task_scheduler.get_result(task_id)
172        {
173            return Ok(val.clone());
174        }
175        if let Some(task_scheduler::TaskStatus::Cancelled) = self.task_scheduler.get_result(task_id)
176        {
177            return Err(VMError::RuntimeError(format!(
178                "Task {} was cancelled",
179                task_id
180            )));
181        }
182
183        // Check external receivers (non-blocking) before inline execution
184        if let Some(result) = self.task_scheduler.try_resolve_external(task_id) {
185            return result;
186        }
187
188        // If this is an external task that hasn't completed yet, block on it
189        // using tokio's block_in_place to avoid deadlocking the runtime.
190        if self.task_scheduler.has_external(task_id) {
191            if let Some(rx) = self.task_scheduler.take_external_receiver(task_id) {
192                let result = tokio::task::block_in_place(|| {
193                    tokio::runtime::Handle::current().block_on(rx)
194                })
195                .map_err(|_| VMError::RuntimeError("Remote task dropped".to_string()))?
196                .map_err(VMError::RuntimeError)?;
197                self.task_scheduler.complete(task_id, result.clone());
198                return Ok(result);
199            }
200        }
201
202        // Take the callable
203        let callable_nb = self.task_scheduler.take_callable(task_id).ok_or_else(|| {
204            VMError::RuntimeError(format!("No callable registered for task {}", task_id))
205        })?;
206
207        // Execute based on callable type.
208        // We save/restore the instruction pointer and stack depth so the
209        // nested execution doesn't corrupt the outer (suspended) state.
210        use shape_value::NanTag;
211
212        let result_nb = match callable_nb.tag() {
213            NanTag::Function => {
214                let func_id = callable_nb.as_function().ok_or(VMError::InvalidCall)?;
215                let saved_ip = self.ip;
216                let saved_sp = self.sp;
217
218                self.ip = self.program.instructions.len();
219                self.call_function_with_nb_args(func_id, &[])?;
220                let res = self.execute_fast(None);
221
222                self.ip = saved_ip;
223                // Restore stack pointer (clear anything left above saved_sp)
224                for i in saved_sp..self.sp {
225                    self.stack[i] = ValueWord::none();
226                }
227                self.sp = saved_sp;
228
229                res?
230            }
231            NanTag::Heap => {
232                if let Some((function_id, upvalues)) = callable_nb.as_closure() {
233                    let upvalues = upvalues.to_vec();
234                    let saved_ip = self.ip;
235                    let saved_sp = self.sp;
236
237                    self.ip = self.program.instructions.len();
238                    self.call_closure_with_nb_args(function_id, upvalues, &[])?;
239                    let res = self.execute_fast(None);
240
241                    self.ip = saved_ip;
242                    for i in saved_sp..self.sp {
243                        self.stack[i] = ValueWord::none();
244                    }
245                    self.sp = saved_sp;
246
247                    res?
248                } else {
249                    // If someone spawned an already-resolved value, just return it
250                    callable_nb
251                }
252            }
253            // If someone spawned an already-resolved value, just return it
254            _ => callable_nb,
255        };
256
257        // Cache the result
258        self.task_scheduler.complete(task_id, result_nb.clone());
259
260        Ok(result_nb)
261    }
262
    /// ValueWord-module function call: takes ValueWord args directly.
    ///
    /// Pushes a new call frame for `func_id`, copies `args` into the first
    /// parameter slots of the frame's register window, materializes shadow
    /// slots for ref-inferred parameters, and points the ip at the function's
    /// entry point. Does NOT run the function; the caller drives execution.
    ///
    /// # Errors
    /// - `VMError::InvalidCall` if `func_id` is out of range.
    /// - `VMError::StackOverflow` if the call-depth limit is reached.
    pub(crate) fn call_function_with_nb_args(
        &mut self,
        func_id: u16,
        args: &[ValueWord],
    ) -> Result<(), VMError> {
        let function = self
            .program
            .functions
            .get(func_id as usize)
            .ok_or(VMError::InvalidCall)?;

        if self.call_stack.len() >= self.config.max_call_depth {
            return Err(VMError::StackOverflow);
        }

        let locals_count = function.locals_count as usize;
        let param_count = function.arity as usize;
        let entry_point = function.entry_point;
        // Cloned so no borrow of `function` outlives the stack writes below.
        let ref_params = function.ref_params.clone();

        // Count ref params that need shadow slots for their actual values.
        // DerefLoad/DerefStore expect the param slot to contain a TAG_REF
        // pointing to a *different* slot that holds the real value.
        let ref_shadow_count = ref_params
            .iter()
            .enumerate()
            .filter(|&(i, &is_ref)| is_ref && i < param_count && i < locals_count)
            .count();

        // The new register window starts at the current top of stack.
        let bp = self.sp;
        let total_slots = locals_count + ref_shadow_count;
        let needed = bp + total_slots;
        if needed > self.stack.len() {
            // Over-allocate (2x + 1) to amortize future stack growth.
            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
        }

        // Copy args into the parameter slots; missing args default to none.
        // NOTE(review): locals in param_count..locals_count are not cleared
        // here and may hold stale values from earlier frames — presumably the
        // compiler writes locals before reading them; confirm.
        for i in 0..param_count {
            if i < locals_count {
                self.stack[bp + i] = args.get(i).cloned().unwrap_or_else(ValueWord::none);
            }
        }

        // For ref-inferred parameters: move the actual value to a shadow slot
        // beyond locals_count, then replace the param slot with a TAG_REF
        // pointing to the shadow slot. This way DerefLoad follows the ref
        // to the actual value (not a circular self-reference).
        let mut shadow_idx = 0;
        for (i, &is_ref) in ref_params.iter().enumerate() {
            if is_ref && i < param_count && i < locals_count {
                let shadow_slot = bp + locals_count + shadow_idx;
                self.stack[shadow_slot] = self.stack[bp + i].clone();
                self.stack[bp + i] = ValueWord::from_ref(shadow_slot);
                shadow_idx += 1;
            }
        }

        self.sp = needed;

        let blob_hash = self.blob_hash_for_function(func_id);
        let frame = CallFrame {
            return_ip: self.ip,
            base_pointer: bp,
            // total_slots: the frame accounts for the shadow slots as well.
            locals_count: total_slots,
            function_id: Some(func_id),
            upvalues: None,
            blob_hash,
        };
        self.call_stack.push(frame);
        self.ip = entry_point;
        Ok(())
    }
335
    /// ValueWord-host closure call: takes ValueWord args directly.
    ///
    /// Pushes a call frame for `func_id` whose first locals are the captured
    /// `upvalues`, followed by the regular `args`. Does NOT run the closure;
    /// the caller drives execution from `entry_point`.
    ///
    /// # Errors
    /// - `VMError::InvalidCall` if `func_id` is out of range.
    /// - `VMError::StackOverflow` if the call-depth limit is reached.
    pub(crate) fn call_closure_with_nb_args(
        &mut self,
        func_id: u16,
        upvalues: Vec<Upvalue>,
        args: &[ValueWord],
    ) -> Result<(), VMError> {
        let function = self
            .program
            .functions
            .get(func_id as usize)
            .ok_or(VMError::InvalidCall)?;

        if self.call_stack.len() >= self.config.max_call_depth {
            return Err(VMError::StackOverflow);
        }

        let locals_count = function.locals_count as usize;
        let captures_count = function.captures_count as usize;
        let arity = function.arity as usize;
        let entry_point = function.entry_point;

        // The new register window starts at the current top of stack.
        let bp = self.sp;
        let needed = bp + locals_count;
        if needed > self.stack.len() {
            // Over-allocate (2x + 1) to amortize future stack growth.
            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
        }

        // Bind upvalue values as the first N locals
        for (i, upvalue) in upvalues.iter().enumerate() {
            if i < locals_count {
                self.stack[bp + i] = upvalue.get();
            }
        }

        // Bind the regular arguments after the upvalues
        for (i, arg) in args.iter().enumerate() {
            let local_idx = captures_count + i;
            if local_idx < locals_count {
                self.stack[bp + local_idx] = arg.clone();
            }
        }

        // Fill remaining parameters with None
        // NOTE(review): the upper bound is `arity.min(locals_count)`; if
        // `arity` counts only the regular params (excluding captures), slots
        // in `arity..captures_count + arity` stay unfilled — confirm whether
        // the compiler's closure arity includes the captured values.
        for i in (captures_count + args.len())..arity.min(locals_count) {
            self.stack[bp + i] = ValueWord::none();
        }

        self.sp = needed;

        let blob_hash = self.blob_hash_for_function(func_id);
        self.call_stack.push(CallFrame {
            return_ip: self.ip,
            base_pointer: bp,
            locals_count,
            function_id: Some(func_id),
            // Kept on the frame so the closure's captures stay reachable.
            upvalues: Some(upvalues),
            blob_hash,
        });

        self.ip = entry_point;
        Ok(())
    }
399
400    /// ValueWord-native call_value_immediate: dispatches on NanTag/HeapKind.
401    ///
402    /// Returns ValueWord directly.
403    pub(in crate::executor) fn call_value_immediate_nb(
404        &mut self,
405        callee: &ValueWord,
406        args: &[ValueWord],
407        ctx: Option<&mut shape_runtime::context::ExecutionContext>,
408    ) -> Result<ValueWord, VMError> {
409        use shape_value::NanTag;
410        let target_depth = self.call_stack.len();
411
412        match callee.tag() {
413            NanTag::Function => {
414                let func_id = callee.as_function().ok_or(VMError::InvalidCall)?;
415                self.call_function_with_nb_args(func_id, args)?;
416            }
417            NanTag::ModuleFunction => {
418                let func_id = callee.as_module_function().ok_or(VMError::InvalidCall)?;
419                let module_fn = self.module_fn_table.get(func_id).cloned().ok_or_else(|| {
420                    VMError::RuntimeError(format!(
421                        "Module function ID {} not found in registry",
422                        func_id
423                    ))
424                })?;
425                let args_vec: Vec<ValueWord> = args.to_vec();
426                let result_nb = self.invoke_module_fn(&module_fn, &args_vec)?;
427                return Ok(result_nb);
428            }
429            NanTag::Heap => match callee.as_heap_ref() {
430                Some(shape_value::HeapValue::Closure {
431                    function_id,
432                    upvalues,
433                }) => {
434                    self.call_closure_with_nb_args(*function_id, upvalues.clone(), args)?;
435                }
436                Some(shape_value::HeapValue::HostClosure(callable)) => {
437                    let args_vec: Vec<ValueWord> = args.to_vec();
438                    let result_nb = callable.call(&args_vec).map_err(VMError::RuntimeError)?;
439                    return Ok(result_nb);
440                }
441                _ => return Err(VMError::InvalidCall),
442            },
443            _ => return Err(VMError::InvalidCall),
444        }
445
446        self.execute_until_call_depth(target_depth, ctx)?;
447        self.pop_vw()
448    }
449
450    /// Fast-path function call: reads `arg_count` arguments directly from the
451    /// value stack instead of collecting them into a temporary `Vec`.
452    ///
453    /// Precondition: the top `arg_count` values on the stack (below sp) are the
454    /// arguments in left-to-right order (arg0 deepest, argN-1 at top).
455    /// These args become the first locals of the new frame's register window.
456    pub(crate) fn call_function_from_stack(
457        &mut self,
458        func_id: u16,
459        arg_count: usize,
460    ) -> Result<(), VMError> {
461        let function = self
462            .program
463            .functions
464            .get(func_id as usize)
465            .ok_or(VMError::InvalidCall)?;
466
467        if self.call_stack.len() >= self.config.max_call_depth {
468            return Err(VMError::StackOverflow);
469        }
470
471        let locals_count = function.locals_count as usize;
472        let entry_point = function.entry_point;
473        let arity = function.arity as usize;
474
475        // The args are already on the stack at positions [sp - arg_count .. sp).
476        // They become the first locals in the register window.
477        // bp = sp - arg_count (args are already in place as the first locals)
478        let bp = self.sp.saturating_sub(arg_count);
479
480        // Ensure stack has room for all locals (some may be beyond the args)
481        let needed = bp + locals_count;
482        if needed > self.stack.len() {
483            self.stack.resize_with(needed * 2 + 1, ValueWord::none);
484        }
485
486        // Zero remaining local slots (including omitted args that the compiler
487        // may intentionally represent as null sentinels for default params).
488        let copy_count = arg_count.min(arity).min(locals_count);
489        for i in copy_count..locals_count {
490            self.stack[bp + i] = ValueWord::none();
491        }
492
493        // Advance sp past all locals
494        self.sp = needed;
495
496        let blob_hash = self.blob_hash_for_function(func_id);
497        self.call_stack.push(CallFrame {
498            return_ip: self.ip,
499            base_pointer: bp,
500            locals_count,
501            function_id: Some(func_id),
502            upvalues: None,
503            blob_hash,
504        });
505        self.ip = entry_point;
506        Ok(())
507    }
508}