Skip to main content

runmat_vm/interpreter/
runner.rs

1use crate::accel::fusion as accel_fusion;
2use crate::accel::residency as accel_residency;
3use crate::bytecode::{Bytecode, Instr, UserFunction};
4use crate::call::shared as call_shared;
5use crate::call::user as call_user;
6use crate::interpreter::api::{InterpreterOutcome, InterpreterState};
7use crate::interpreter::dispatch::{self as interp_dispatch, DispatchDecision};
8use crate::interpreter::engine as interp_engine;
9use crate::interpreter::errors::{attach_span_from_pc, mex, set_vm_pc};
10use crate::interpreter::timing::InterpreterTiming;
11use crate::runtime::call_stack::attach_call_frames;
12use crate::runtime::globals as runtime_globals;
13use crate::runtime::workspace::{
14    refresh_workspace_state, workspace_assign, workspace_clear, workspace_lookup, workspace_remove,
15    workspace_snapshot,
16};
17use runmat_builtins::Value;
18use runmat_runtime::{
19    user_functions,
20    workspace::{self as runtime_workspace, WorkspaceResolver},
21    RuntimeError,
22};
23use runmat_thread_local::runmat_thread_local;
24use std::cell::RefCell;
25use std::collections::HashMap;
26use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::sync::Once;
30use tracing::{debug, info_span};
31
/// Scope guard that tears down the active fusion plan.
///
/// Create it right after `activate_fusion_plan`. Its `Drop` impl calls
/// `deactivate_fusion_plan`, so every exit path of the owning frame
/// deactivates the plan, including early `?` returns.
#[cfg(feature = "native-accel")]
struct FusionPlanGuard;

#[cfg(feature = "native-accel")]
impl Drop for FusionPlanGuard {
    fn drop(&mut self) {
        // Activation happens at the guard's creation site; only teardown lives here.
        deactivate_fusion_plan();
    }
}
46
47type VmResult<T> = Result<T, RuntimeError>;
48
49fn invoke_user_for_end_expr_adapter<'a>(
50    name: &'a str,
51    argv: Vec<Value>,
52    functions: &'a HashMap<String, UserFunction>,
53    vars_ref: &'a [Value],
54) -> Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + 'a>> {
55    Box::pin(async move {
56        let mut local_vars = vars_ref.to_owned();
57        invoke_user_function_value(name, &argv, functions, &mut local_vars).await
58    })
59}
60
61fn builtin_fallback_user_call_adapter(
62    name: String,
63    args: Vec<Value>,
64    out_count: usize,
65) -> Pin<Box<dyn Future<Output = Result<Option<Value>, RuntimeError>>>> {
66    Box::pin(async move {
67        if out_count == 1 {
68            call_user::try_builtin_fallback_single(&name, &args).await
69        } else {
70            call_user::try_builtin_fallback_multi(&name, &args, out_count).await
71        }
72    })
73}
74
75fn interpret_counts_adapter(
76    bc: Bytecode,
77    vars: Vec<Value>,
78    name: String,
79    out_count: usize,
80    in_count: usize,
81) -> Pin<Box<dyn Future<Output = Result<Vec<Value>, RuntimeError>>>> {
82    Box::pin(
83        async move { interpret_function_with_counts(&bc, vars, &name, out_count, in_count).await },
84    )
85}
86
87runmat_thread_local! {
88    static CALL_COUNTS: RefCell<Vec<(usize, usize)>> = const { RefCell::new(Vec::new()) };
89}
90
91runmat_thread_local! {
92    static USER_FUNCTION_VARS: RefCell<Option<*mut Vec<Value>>> = const { RefCell::new(None) };
93}
94
95runmat_thread_local! {
96    static DYNAMIC_USER_FUNCTIONS: RefCell<HashMap<String, UserFunction>> = RefCell::new(HashMap::new());
97}
98
99pub fn dynamic_user_functions_snapshot() -> HashMap<String, UserFunction> {
100    DYNAMIC_USER_FUNCTIONS.with(|slot| slot.borrow().clone())
101}
102
103fn clear_dynamic_user_functions() {
104    DYNAMIC_USER_FUNCTIONS.with(|slot| slot.borrow_mut().clear());
105}
106
107fn register_dynamic_user_functions(functions: &HashMap<String, UserFunction>) {
108    DYNAMIC_USER_FUNCTIONS.with(|slot| {
109        let mut map = slot.borrow_mut();
110        for (k, v) in functions {
111            map.insert(k.clone(), v.clone());
112        }
113    });
114}
115
116struct UserFunctionVarsGuard {
117    previous: Option<*mut Vec<Value>>,
118}
119
120impl Drop for UserFunctionVarsGuard {
121    fn drop(&mut self) {
122        let previous = self.previous.take();
123        USER_FUNCTION_VARS.with(|slot| {
124            *slot.borrow_mut() = previous;
125        });
126    }
127}
128
129fn install_user_function_vars(vars: &mut Vec<Value>) -> UserFunctionVarsGuard {
130    let vars_ptr = vars as *mut Vec<Value>;
131    let previous = USER_FUNCTION_VARS.with(|slot| slot.borrow_mut().replace(vars_ptr));
132    UserFunctionVarsGuard { previous }
133}
134
/// Copies `vars` back into `initial` over their overlapping prefix.
///
/// Only the first `min(initial.len(), vars.len())` slots are written. Extra
/// source values are ignored and extra destination slots are left untouched.
/// `clone_from` lets element types reuse existing allocations where their
/// `Clone` impl supports it.
fn sync_initial_vars<T: Clone>(initial: &mut [T], vars: &[T]) {
    for (dst, src) in initial.iter_mut().zip(vars) {
        dst.clone_from(src);
    }
}
142
143fn ensure_workspace_resolver_registered() {
144    static REGISTER: Once = Once::new();
145    REGISTER.call_once(|| {
146        runtime_workspace::register_workspace_resolver(WorkspaceResolver {
147            lookup: workspace_lookup,
148            snapshot: workspace_snapshot,
149            globals: runtime_globals::workspace_global_names,
150            assign: Some(workspace_assign),
151            clear: Some(workspace_clear),
152            remove: Some(workspace_remove),
153        });
154    });
155}
156
/// Registers the wasm builtin table exactly once per process.
///
/// On every target other than `wasm32` the body is compiled out and this is a no-op.
fn ensure_wasm_builtins_registered() {
    #[cfg(target_arch = "wasm32")]
    {
        static WASM_BUILTINS_REGISTERED: Once = Once::new();
        let register = || {
            runmat_runtime::builtins::wasm_registry::register_all();
        };
        WASM_BUILTINS_REGISTERED.call_once(register);
    }
}
166
/// Forwards to `accel_residency::clear_value` for `value`.
#[cfg(feature = "native-accel")]
fn clear_residency(value: &Value) {
    accel_residency::clear_value(value);
}

/// Reports whether `lhs` and `rhs` refer to the same GPU handle, as decided by
/// `accel_residency::same_gpu_handle`.
#[cfg(feature = "native-accel")]
fn same_gpu_handle(lhs: &Value, rhs: &Value) -> bool {
    let shared = accel_residency::same_gpu_handle(lhs, rhs);
    shared
}
176
177async fn invoke_user_function_value(
178    name: &str,
179    args: &[Value],
180    functions: &HashMap<String, UserFunction>,
181    vars: &mut [Value],
182) -> Result<Value, RuntimeError> {
183    let func = call_shared::lookup_user_function(name, functions)?;
184    let arg_count = args.len();
185    call_shared::validate_user_function_arity(name, &func, arg_count)?;
186    let prepared = call_shared::prepare_user_call(func, args, vars)?;
187    let crate::call::shared::PreparedUserCall {
188        func,
189        var_map,
190        func_program,
191        func_vars,
192    } = prepared;
193    let func_bytecode = crate::compile(&func_program, functions)?;
194    register_dynamic_user_functions(&func_bytecode.functions);
195    let func_result_vars =
196        interpret_function_with_counts(&func_bytecode, func_vars, name, 1, arg_count).await?;
197    Ok(call_shared::first_output_value(
198        &func,
199        &var_map,
200        &func_result_vars,
201    ))
202}
203
204pub async fn interpret_with_vars(
205    bytecode: &Bytecode,
206    initial_vars: &mut [Value],
207    current_function_name: Option<&str>,
208) -> VmResult<InterpreterOutcome> {
209    let is_top_level = CALL_COUNTS.with(|cc| cc.borrow().is_empty());
210    if is_top_level {
211        clear_dynamic_user_functions();
212    }
213    let call_counts = CALL_COUNTS.with(|cc| cc.borrow().clone());
214    let state = Box::new(InterpreterState::new(
215        bytecode.clone(),
216        initial_vars,
217        current_function_name,
218        call_counts,
219    ));
220    match Box::pin(run_interpreter(state, initial_vars)).await {
221        Ok(outcome) => Ok(outcome),
222        Err(err) => {
223            let err = attach_span_from_pc(bytecode, err);
224            let current_name = current_function_name.unwrap_or("<main>");
225            Err(attach_call_frames(bytecode, current_name, err))
226        }
227    }
228}
229
230async fn run_interpreter(
231    state: Box<InterpreterState>,
232    initial_vars: &mut [Value],
233) -> VmResult<InterpreterOutcome> {
234    let state = *state;
235    Box::pin(run_interpreter_inner(state, initial_vars)).await
236}
237
238async fn run_interpreter_inner(
239    state: InterpreterState,
240    initial_vars: &mut [Value],
241) -> VmResult<InterpreterOutcome> {
242    let run_span = info_span!(
243        "interpreter.run",
244        function = state.current_function_name.as_str()
245    );
246    let _run_guard = run_span.enter();
247    ensure_wasm_builtins_registered();
248    ensure_workspace_resolver_registered();
249    #[cfg(feature = "native-accel")]
250    activate_fusion_plan(state.fusion_plan.clone());
251    #[cfg(feature = "native-accel")]
252    let _fusion_guard = FusionPlanGuard;
253    let InterpreterState {
254        mut stack,
255        mut vars,
256        mut pc,
257        mut context,
258        mut try_stack,
259        mut last_exception,
260        mut imports,
261        mut global_aliases,
262        mut persistent_aliases,
263        current_function_name,
264        call_counts,
265        #[cfg(feature = "native-accel")]
266            fusion_plan: _,
267        bytecode,
268    } = state;
269    let functions = Arc::new(context.functions.clone());
270    let _user_function_vars_guard = install_user_function_vars(&mut vars);
271    let _user_function_guard = user_functions::install_user_function_invoker(Some(Arc::new(
272        move |name: &str, args: &[Value]| {
273            let name = name.to_string();
274            let args = args.to_vec();
275            let functions = Arc::clone(&functions);
276            Box::pin(async move {
277                let vars_ptr = USER_FUNCTION_VARS.with(|slot| *slot.borrow());
278                let Some(vars_ptr) = vars_ptr else {
279                    return Err(mex(
280                        "InternalStateUnavailable",
281                        "user function vars not installed",
282                    ));
283                };
284                let vars = unsafe { &mut *vars_ptr };
285                invoke_user_function_value(&name, &args, &functions, vars).await
286            })
287        },
288    )));
289    CALL_COUNTS.with(|cc| {
290        *cc.borrow_mut() = call_counts.clone();
291    });
292    let _workspace_guard = interp_engine::prepare_workspace_guard(&mut vars);
293    let thread_roots: Vec<Value> = runtime_globals::collect_thread_roots();
294    let mut _gc_context = interp_engine::create_gc_context(&stack, &vars, thread_roots)?;
295    let debug_stack = interp_engine::debug_stack_enabled();
296    let mut interpreter_timing = InterpreterTiming::new();
297    while pc < bytecode.instructions.len() {
298        set_vm_pc(pc);
299        #[cfg(feature = "native-accel")]
300        set_current_pc(pc);
301        interp_engine::check_cancelled()?;
302        #[cfg(feature = "native-accel")]
303        if let (Some(plan), Some(graph)) =
304            (active_group_plan_clone(), bytecode.accel_graph.as_ref())
305        {
306            if plan.group.span.start == pc {
307                #[cfg(feature = "native-accel")]
308                {
309                    interp_engine::note_fusion_gate(
310                        &mut interpreter_timing,
311                        &plan,
312                        &bytecode,
313                        pc,
314                        accel_fusion::fusion_span_has_vm_barrier(
315                            &bytecode.instructions,
316                            &plan.group.span,
317                        ),
318                        accel_fusion::fusion_span_live_result_count(
319                            &bytecode.instructions,
320                            &plan.group.span,
321                        ),
322                    );
323                }
324                let span = plan.group.span.clone();
325                let has_barrier =
326                    accel_fusion::fusion_span_has_vm_barrier(&bytecode.instructions, &span);
327                let _fusion_span = info_span!(
328                    "fusion.execute",
329                    span_start = plan.group.span.start,
330                    span_end = plan.group.span.end,
331                    kind = ?plan.group.kind
332                )
333                .entered();
334                if !has_barrier {
335                    match accel_fusion::try_execute_fusion_group(
336                        &plan,
337                        graph,
338                        &mut stack,
339                        &mut vars,
340                        &mut context,
341                    )
342                    .await
343                    {
344                        Ok(result) => {
345                            stack.push(result);
346                            pc = plan.group.span.end + 1;
347                            continue;
348                        }
349                        Err(err) => {
350                            log::debug!("fusion fallback at pc {}: {}", pc, err);
351                        }
352                    }
353                } else {
354                    interp_engine::note_fusion_skip(pc, &span);
355                }
356            }
357        }
358        interp_engine::note_pre_dispatch(
359            &mut interpreter_timing,
360            debug_stack,
361            pc,
362            &bytecode.instructions[pc],
363            stack.len(),
364        );
365        let next_instr = bytecode.instructions.get(pc + 1);
366        let call_counts_snapshot = CALL_COUNTS.with(|cc| cc.borrow().clone());
367        let store_var_global_aliases = match &bytecode.instructions[pc] {
368            Instr::StoreVar(_) => Some(global_aliases.clone()),
369            _ => None,
370        };
371        let mut clear_value_residency = |value: &Value| {
372            #[cfg(feature = "native-accel")]
373            clear_residency(value);
374        };
375        let mut store_var_before_overwrite = |current: &Value, incoming: &Value| {
376            #[cfg(feature = "native-accel")]
377            if !same_gpu_handle(current, incoming) {
378                clear_residency(current);
379            }
380        };
381        let mut store_var_after_store = |stored_index: usize, stored_value: &Value| {
382            if let Some(ref aliases) = store_var_global_aliases {
383                runtime_globals::update_global_store(stored_index, stored_value, aliases);
384            }
385        };
386        let mut store_local_before_local_overwrite = |current: &Value, incoming: &Value| {
387            #[cfg(feature = "native-accel")]
388            if !same_gpu_handle(current, incoming) {
389                clear_residency(current);
390            }
391        };
392        let mut store_local_before_var_overwrite = |current: &Value, incoming: &Value| {
393            #[cfg(feature = "native-accel")]
394            if !same_gpu_handle(current, incoming) {
395                clear_residency(current);
396            }
397        };
398        let mut store_local_after_fallback_store =
399            |func_name: &str, stored_offset: usize, stored_value: &Value| {
400                runtime_globals::update_persistent_local_store(
401                    func_name,
402                    stored_offset,
403                    stored_value,
404                );
405            };
406        let dispatch_result = interp_dispatch::dispatch_instruction(
407            interp_dispatch::DispatchMeta {
408                instr: &bytecode.instructions[pc],
409                var_names: &bytecode.var_names,
410                bytecode_functions: &bytecode.functions,
411                source_id: bytecode.source_id,
412                call_arg_spans: bytecode.call_arg_spans.get(pc).cloned().flatten(),
413                call_counts: &call_counts_snapshot,
414                current_function_name: &current_function_name,
415                next_instr,
416            },
417            interp_dispatch::DispatchState {
418                stack: &mut stack,
419                vars: &mut vars,
420                context: &mut context,
421                try_stack: &mut try_stack,
422                last_exception: &mut last_exception,
423                imports: &mut imports,
424                global_aliases: &mut global_aliases,
425                persistent_aliases: &mut persistent_aliases,
426                pc: &mut pc,
427            },
428            interp_dispatch::DispatchHooks {
429                clear_value_residency: &mut clear_value_residency,
430                invoke_user_for_end_expr: &invoke_user_for_end_expr_adapter,
431                builtin_fallback_user_call: &builtin_fallback_user_call_adapter,
432                interpret_function_counts: &interpret_counts_adapter,
433                store_var_before_overwrite: &mut store_var_before_overwrite,
434                store_var_after_store: &mut store_var_after_store,
435                store_local_before_local_overwrite: &mut store_local_before_local_overwrite,
436                store_local_before_var_overwrite: &mut store_local_before_var_overwrite,
437                store_local_after_fallback_store: &mut store_local_after_fallback_store,
438            },
439        )
440        .await;
441        let dispatch_result = match dispatch_result {
442            Ok(result) => result,
443            Err(err) => match interp_dispatch::redirect_exception_to_catch(
444                err,
445                &mut try_stack,
446                &mut vars,
447                &mut last_exception,
448                &mut pc,
449                refresh_workspace_state,
450            ) {
451                interp_dispatch::ExceptionHandling::Caught => {
452                    continue;
453                }
454                interp_dispatch::ExceptionHandling::Uncaught(err) => return Err(*err),
455            },
456        };
457        if let Some(decision) = dispatch_result {
458            match decision {
459                interp_dispatch::DispatchHandled::Generic(DispatchDecision::ContinueLoop) => {
460                    continue
461                }
462                interp_dispatch::DispatchHandled::Generic(DispatchDecision::FallThrough) => {
463                    pc += 1;
464                    continue;
465                }
466                interp_dispatch::DispatchHandled::Generic(DispatchDecision::Return) => {
467                    interpreter_timing.flush_host_span("return", None);
468                    break;
469                }
470                interp_dispatch::DispatchHandled::ReturnValue(DispatchDecision::ContinueLoop)
471                | interp_dispatch::DispatchHandled::Return(DispatchDecision::ContinueLoop) => {
472                    continue
473                }
474                interp_dispatch::DispatchHandled::ReturnValue(DispatchDecision::Return) => {
475                    interpreter_timing.flush_host_span("return_value", None);
476                    break;
477                }
478                interp_dispatch::DispatchHandled::Return(DispatchDecision::Return) => {
479                    interpreter_timing.flush_host_span("return", None);
480                    break;
481                }
482                interp_dispatch::DispatchHandled::ReturnValue(DispatchDecision::FallThrough)
483                | interp_dispatch::DispatchHandled::Return(DispatchDecision::FallThrough) => {
484                    pc += 1;
485                    continue;
486                }
487            }
488        }
489        match bytecode.instructions[pc].clone() {
490            Instr::EmitStackTop { .. }
491            | Instr::EmitVar { .. }
492            | Instr::AndAnd(_)
493            | Instr::OrOr(_)
494            | Instr::JumpIfFalse(_)
495            | Instr::Jump(_)
496            | Instr::LoadConst(_)
497            | Instr::LoadComplex(_, _)
498            | Instr::LoadBool(_)
499            | Instr::LoadString(_)
500            | Instr::LoadCharRow(_)
501            | Instr::LoadLocal(_)
502            | Instr::LoadVar(_)
503            | Instr::StoreVar(_)
504            | Instr::StoreLocal(_)
505            | Instr::Swap
506            | Instr::Pop
507            | Instr::EnterTry(_, _)
508            | Instr::PopTry
509            | Instr::ReturnValue
510            | Instr::Return
511            | Instr::EnterScope(_)
512            | Instr::LoadMember(_)
513            | Instr::LoadMemberOrInit(_)
514            | Instr::LoadMemberDynamic
515            | Instr::LoadMemberDynamicOrInit
516            | Instr::StoreMember(_)
517            | Instr::StoreMemberOrInit(_)
518            | Instr::StoreMemberDynamic
519            | Instr::StoreMemberDynamicOrInit
520            | Instr::Index(_)
521            | Instr::IndexSlice(_, _, _, _)
522            | Instr::IndexSliceExpr { .. }
523            | Instr::IndexCell(_)
524            | Instr::IndexCellExpand(_, _)
525            | Instr::StoreIndex(_)
526            | Instr::StoreIndexCell(_)
527            | Instr::StoreSlice(_, _, _, _)
528            | Instr::StoreSliceExpr { .. }
529            | Instr::CallMethod(_, _)
530            | Instr::CallMethodOrMemberIndex(_, _)
531            | Instr::LoadMethod(_)
532            | Instr::CreateClosure(_, _)
533            | Instr::LoadStaticProperty(_, _)
534            | Instr::CallStaticMethod(_, _, _)
535            | Instr::RegisterClass { .. }
536            | Instr::CallFeval(_)
537            | Instr::CallFevalExpandMulti(_)
538            | Instr::CallBuiltin(_, _)
539            | Instr::CallFunction(_, _)
540            | Instr::CallFunctionMulti(_, _, _)
541            | Instr::CallFunctionExpandMulti(_, _)
542            | Instr::CallBuiltinExpandLast(_, _, _)
543            | Instr::CallBuiltinExpandAt(_, _, _, _)
544            | Instr::CallBuiltinExpandMulti(_, _)
545            | Instr::CallFunctionExpandAt(_, _, _, _)
546            | Instr::ExitScope(_)
547            | Instr::RegisterImport { .. }
548            | Instr::DeclareGlobal(_)
549            | Instr::DeclareGlobalNamed(_, _)
550            | Instr::DeclarePersistent(_)
551            | Instr::DeclarePersistentNamed(_, _)
552            | Instr::CreateCell2D(_, _)
553            | Instr::Add
554            | Instr::Sub
555            | Instr::Mul
556            | Instr::ElemMul
557            | Instr::ElemDiv
558            | Instr::ElemPow
559            | Instr::ElemLeftDiv
560            | Instr::Neg
561            | Instr::UPlus
562            | Instr::Transpose
563            | Instr::ConjugateTranspose
564            | Instr::Pow
565            | Instr::RightDiv
566            | Instr::LeftDiv
567            | Instr::LessEqual
568            | Instr::Less
569            | Instr::Greater
570            | Instr::GreaterEqual
571            | Instr::Equal
572            | Instr::NotEqual
573            | Instr::Unpack(_)
574            | Instr::CreateMatrix(_, _)
575            | Instr::CreateMatrixDynamic(_)
576            | Instr::CreateRange(_)
577            | Instr::PackToRow(_)
578            | Instr::PackToCol(_) => unreachable!("handled by dispatch_instruction"),
579            Instr::StochasticEvolution => {
580                let steps_value = stack
581                    .pop()
582                    .ok_or(mex("StackUnderflow", "stack underflow"))?;
583                let scale_value = stack
584                    .pop()
585                    .ok_or(mex("StackUnderflow", "stack underflow"))?;
586                let drift_value = stack
587                    .pop()
588                    .ok_or(mex("StackUnderflow", "stack underflow"))?;
589                let state_value = stack
590                    .pop()
591                    .ok_or(mex("StackUnderflow", "stack underflow"))?;
592                let evolved =
593                    crate::accel::idioms::stochastic_evolution::execute_stochastic_evolution(
594                        state_value,
595                        drift_value,
596                        scale_value,
597                        steps_value,
598                    )
599                    .await?;
600                stack.push(evolved);
601            }
602        }
603        if debug_stack {
604            debug!(pc, stack_len = stack.len(), "[vm] after exec");
605        }
606        pc += 1;
607    }
608    interpreter_timing.flush_host_span("loop_complete", None);
609    sync_initial_vars(initial_vars, &vars);
610    Ok(InterpreterOutcome::Completed(vars))
611}
612
613pub async fn interpret(bytecode: &Bytecode) -> Result<Vec<Value>, RuntimeError> {
614    let mut vars = vec![Value::Num(0.0); bytecode.var_count];
615    match interpret_with_vars(bytecode, &mut vars, Some("<main>")).await {
616        Ok(InterpreterOutcome::Completed(values)) => Ok(values),
617        Err(e) => Err(e),
618    }
619}
620
621pub async fn interpret_function(
622    bytecode: &Bytecode,
623    vars: Vec<Value>,
624) -> Result<Vec<Value>, RuntimeError> {
625    interpret_function_with_counts(bytecode, vars, "<anonymous>", 0, 0).await
626}
627
628pub async fn interpret_function_with_counts(
629    bytecode: &Bytecode,
630    vars: Vec<Value>,
631    name: &str,
632    out_count: usize,
633    in_count: usize,
634) -> Result<Vec<Value>, RuntimeError> {
635    let mut vars = vars;
636    CALL_COUNTS.with(|cc| {
637        cc.borrow_mut().push((in_count, out_count));
638    });
639    let res = Box::pin(interpret_with_vars(bytecode, &mut vars, Some(name))).await;
640    CALL_COUNTS.with(|cc| {
641        cc.borrow_mut().pop();
642    });
643    let res = match res {
644        Ok(InterpreterOutcome::Completed(values)) => Ok(values),
645        Err(e) => Err(e),
646    }?;
647    runtime_globals::persist_declared_for_bytecode(bytecode, name, &vars);
648    Ok(res)
649}