Skip to main content

runmat_vm/accel/
fusion.rs

1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9    execute_centered_gram, execute_elementwise, execute_explained_variance,
10    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11    execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22    match value {
23        Value::Int(_) => "Int",
24        Value::Num(_) => "Num",
25        Value::Complex(_, _) => "Complex",
26        Value::Bool(_) => "Bool",
27        Value::LogicalArray(_) => "LogicalArray",
28        Value::String(_) => "String",
29        Value::StringArray(_) => "StringArray",
30        Value::Symbolic(_) => "Symbolic",
31        Value::CharArray(_) => "CharArray",
32        Value::Tensor(_) => "Tensor",
33        Value::SparseTensor(_) => "SparseTensor",
34        Value::ComplexTensor(_) => "ComplexTensor",
35        Value::Cell(_) => "Cell",
36        Value::Struct(_) => "Struct",
37        Value::GpuTensor(_) => "GpuTensor",
38        Value::Object(_) => "Object",
39        Value::HandleObject(_) => "HandleObject",
40        Value::Listener(_) => "Listener",
41        Value::FunctionHandle(_)
42        | Value::ExternalFunctionHandle(_)
43        | Value::MethodFunctionHandle(_) => "FunctionHandle",
44        Value::BoundFunctionHandle { .. } => "FunctionHandle",
45        Value::Closure(_) => "Closure",
46        Value::ClassRef(_) => "ClassRef",
47        Value::MException(_) => "MException",
48        Value::OutputList(_) => "OutputList",
49    }
50}
51
52#[inline]
53pub fn summarize_value(i: usize, v: &Value) -> String {
54    match v {
55        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
56        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
57        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
58        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
59        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
60        Value::String(s) => format!("in#{i}:String({})", s),
61        _ => format!("in#{i}:{}", value_kind(v)),
62    }
63}
64
65#[inline]
66fn is_scalarish_runtime_value(value: &Value) -> bool {
67    match value {
68        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
69        Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
70        Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
71        Value::LogicalArray(array) => is_scalar_shape(&array.shape),
72        Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
73        Value::CharArray(array) => array.rows * array.cols == 1,
74        _ => false,
75    }
76}
77
78pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
79    if span.start > span.end || span.end >= instructions.len() {
80        return None;
81    }
82    let mut current_depth = 0usize;
83    for instr in &instructions[span.start..=span.end] {
84        let effect = instr.stack_effect()?;
85        if current_depth < effect.pops {
86            current_depth = effect.pops;
87        }
88        current_depth = current_depth - effect.pops + effect.pushes;
89    }
90    Some(current_depth)
91}
92
93pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
94    if span.start > span.end || span.end >= instructions.len() {
95        return true;
96    }
97    for instr in &instructions[span.start..=span.end] {
98        if matches!(
99            instr,
100            Instr::StoreIndex(_)
101                | Instr::StoreIndexDelete(_)
102                | Instr::StoreSlice(_, _, _, _)
103                | Instr::StoreSliceDelete(_, _, _, _)
104                | Instr::StoreSliceExpr { .. }
105                | Instr::StoreSliceExprDelete { .. }
106                | Instr::StoreIndexCell { .. }
107                | Instr::StoreIndexCellDelete { .. }
108                | Instr::StoreMember(_)
109                | Instr::StoreMemberOrInit(_)
110                | Instr::StoreMemberDynamic
111                | Instr::StoreMemberDynamicOrInit
112        ) {
113            return true;
114        }
115    }
116    fusion_span_live_result_count(instructions, span) != Some(1)
117}
118
119pub struct StackSliceGuard<'a> {
120    stack: *mut Vec<Value>,
121    slice: Option<Vec<Value>>,
122    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
123}
124
125impl<'a> StackSliceGuard<'a> {
126    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
127        let slice = stack.split_off(slice_start);
128        Self {
129            stack,
130            slice: Some(slice),
131            _marker: std::marker::PhantomData,
132        }
133    }
134
135    pub fn slice(&self) -> &[Value] {
136        self.slice.as_ref().expect("stack slice missing").as_slice()
137    }
138
139    pub fn commit(mut self) {
140        self.slice = None;
141    }
142}
143
144impl Drop for StackSliceGuard<'_> {
145    fn drop(&mut self) {
146        if let Some(slice) = self.slice.take() {
147            unsafe { (&mut *self.stack).extend(slice) }
148        }
149    }
150}
151
152pub fn gather_fusion_inputs<'a>(
153    plan: &'a runmat_accelerate::FusionGroupPlan,
154    graph: &runmat_accelerate::AccelGraph,
155    stack: &'a mut Vec<Value>,
156    vars: &mut [Value],
157    context: &mut ExecutionContext,
158) -> Result<
159    (
160        StackSliceGuard<'a>,
161        FusionExecutionRequest<'a>,
162        Vec<Option<Value>>,
163    ),
164    RuntimeError,
165> {
166    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
167        return Err(mex(
168            "FusionMissingStackLayout",
169            "fusion: missing compile-time stack layout metadata",
170        ));
171    }
172    let required_stack_operands = plan
173        .group
174        .stack_layout
175        .as_ref()
176        .map(|layout| layout.required_stack_operands)
177        .unwrap_or_else(|| plan.stack_pattern.len());
178    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
179
180    for (idx, value) in &plan.constants {
181        if let Some(slot) = inputs.get_mut(*idx) {
182            if slot.is_none() {
183                *slot = Some(value.clone());
184            }
185        }
186    }
187
188    for (idx, value_id) in plan.inputs.iter().enumerate() {
189        let info = graph
190            .value(*value_id)
191            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
192        match &info.origin {
193            ValueOrigin::Variable { kind, index } => {
194                let value =
195                    match kind {
196                        VarKind::Global => vars
197                            .get(*index)
198                            .cloned()
199                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
200                        VarKind::Local => {
201                            if let Some(frame) = context.call_stack.last() {
202                                let absolute = frame.locals_start + index;
203                                context.locals.get(absolute).cloned().ok_or_else(|| {
204                                    format!("fusion: local var {index} unavailable")
205                                })?
206                            } else {
207                                vars.get(*index).cloned().ok_or_else(|| {
208                                    format!("fusion: local var {index} unavailable")
209                                })?
210                            }
211                        }
212                    };
213                debug_assert!(
214                    inputs[idx].is_none(),
215                    "fusion: duplicate input slot {} for plan {}",
216                    idx,
217                    plan.index
218                );
219                inputs[idx] = Some(value);
220            }
221            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
222        }
223    }
224
225    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
226        let stack_needed_preview = required_stack_operands;
227        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
228        let stack_kinds: Vec<&'static str> =
229            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
230        let input_meta: Vec<String> = plan
231            .inputs
232            .iter()
233            .enumerate()
234            .map(|(i, value_id)| {
235                if let Some(info) = graph.value(*value_id) {
236                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
237                } else {
238                    format!("#{i}:id={} origin=<missing>", value_id)
239                }
240            })
241            .collect();
242        log::debug!(
243            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
244            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
245        );
246    }
247
248    if stack.len() < required_stack_operands {
249        if interp_engine::fusion_debug_enabled() {
250            log::debug!(
251                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
252                plan.index,
253                required_stack_operands,
254                stack.len(),
255                plan.stack_pattern
256            );
257        }
258        return Err(mex(
259            "FusionStackUnderflow",
260            "fusion: stack underflow gathering inputs",
261        ));
262    }
263    let available = required_stack_operands;
264    let slice_start = stack.len() - available;
265    let stack_guard = StackSliceGuard::new(stack, slice_start);
266    let slice = stack_guard.slice().to_vec();
267    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
268    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
269        .inputs
270        .iter()
271        .enumerate()
272        .map(|(idx, value_id)| (*value_id, idx))
273        .collect();
274
275    let allow_stack_value = |val: &Value| {
276        if plan.group.kind.is_reduction() {
277            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
278        } else {
279            true
280        }
281    };
282
283    if let Some(layout) = plan.group.stack_layout.as_ref() {
284        for binding in &layout.bindings {
285            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
286                continue;
287            };
288            let Some(val) = slice.get(binding.stack_offset).cloned() else {
289                continue;
290            };
291            consumed_inputs[input_idx] = Some(val.clone());
292            if inputs[input_idx].is_none() && allow_stack_value(&val) {
293                inputs[input_idx] = Some(val);
294            }
295        }
296    } else {
297        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
298            let Some(val) = slice.get(offset).cloned() else {
299                continue;
300            };
301            consumed_inputs[*input_idx] = Some(val.clone());
302            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
303                inputs[*input_idx] = Some(val);
304            }
305        }
306    }
307
308    for (idx, slot) in inputs.iter_mut().enumerate() {
309        if slot.is_some() {
310            continue;
311        }
312        let vid = plan.inputs[idx];
313        let info = graph.value(vid);
314        if let Some(info) = info {
315            match &info.origin {
316                ValueOrigin::Variable { kind, index } => {
317                    let value_opt = match kind {
318                        VarKind::Global => vars.get(*index).cloned(),
319                        VarKind::Local => {
320                            if let Some(frame) = context.call_stack.last() {
321                                let absolute = frame.locals_start + index;
322                                context.locals.get(absolute).cloned()
323                            } else {
324                                vars.get(*index).cloned()
325                            }
326                        }
327                    };
328                    if let Some(value) = value_opt {
329                        *slot = Some(value);
330                        continue;
331                    }
332                }
333                ValueOrigin::Constant => {
334                    if let Some(value) = plan.const_values.get(&vid) {
335                        *slot = Some(value.clone());
336                        continue;
337                    }
338                }
339                _ => {}
340            }
341        }
342        if slot.is_none() {
343            if let Some(binding) = graph.var_binding(vid) {
344                let value_opt = match binding.kind {
345                    VarKind::Global => vars.get(binding.index).cloned(),
346                    VarKind::Local => {
347                        if let Some(frame) = context.call_stack.last() {
348                            let absolute = frame.locals_start + binding.index;
349                            context.locals.get(absolute).cloned()
350                        } else {
351                            vars.get(binding.index).cloned()
352                        }
353                    }
354                };
355                if let Some(value) = value_opt {
356                    *slot = Some(value);
357                    continue;
358                }
359            }
360        }
361        if slot.is_none() {
362            if let Some(info) = info {
363                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
364                    if let Some(binding) = graph.node_binding(node) {
365                        let value_opt = match binding.kind {
366                            VarKind::Global => vars.get(binding.index).cloned(),
367                            VarKind::Local => {
368                                if let Some(frame) = context.call_stack.last() {
369                                    let absolute = frame.locals_start + binding.index;
370                                    context.locals.get(absolute).cloned()
371                                } else {
372                                    vars.get(binding.index).cloned()
373                                }
374                            }
375                        };
376                        if let Some(value) = value_opt {
377                            *slot = Some(value);
378                            continue;
379                        }
380                    }
381                }
382            }
383        }
384        if slot.is_none() {
385            if let Some(value) = plan.const_values.get(&vid) {
386                *slot = Some(value.clone());
387            }
388        }
389    }
390
391    let inputs: Vec<Value> = inputs
392        .into_iter()
393        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
394        .collect::<Result<_, _>>()?;
395
396    if log::log_enabled!(log::Level::Debug) {
397        let summaries: Vec<String> = inputs
398            .iter()
399            .enumerate()
400            .map(|(i, v)| summarize_value(i, v))
401            .collect();
402        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
403    }
404
405    Ok((
406        stack_guard,
407        FusionExecutionRequest { plan, inputs },
408        consumed_inputs,
409    ))
410}
411
412pub fn write_elementwise_materialized_stores(
413    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
414    vars: &mut Vec<Value>,
415    context: &mut ExecutionContext,
416) {
417    for (store, value) in materialized_stores {
418        match store.binding.kind {
419            VarKind::Global => {
420                let i = store.binding.index;
421                if i < vars.len() {
422                    accel_residency::clear_value_excluding(&vars[i], &value);
423                }
424                if i >= vars.len() {
425                    vars.resize(i + 1, Value::Num(0.0));
426                    refresh_workspace_state(vars);
427                }
428                vars[i] = value;
429            }
430            VarKind::Local => {
431                if let Some(frame) = context.call_stack.last() {
432                    let absolute = frame.locals_start + store.binding.index;
433                    while context.locals.len() <= absolute {
434                        context.locals.push(Value::Num(0.0));
435                    }
436                    accel_residency::clear_value_excluding(&context.locals[absolute], &value);
437                    context.locals[absolute] = value;
438                } else {
439                    let i = store.binding.index;
440                    if i < vars.len() {
441                        accel_residency::clear_value_excluding(&vars[i], &value);
442                    }
443                    if i >= vars.len() {
444                        vars.resize(i + 1, Value::Num(0.0));
445                        refresh_workspace_state(vars);
446                    }
447                    vars[i] = value;
448                }
449            }
450        }
451    }
452}
453
454pub fn execute_fusion_elementwise(
455    request: FusionExecutionRequest<'_>,
456    stack_guard: StackSliceGuard<'_>,
457    vars: &mut Vec<Value>,
458    context: &mut ExecutionContext,
459) -> Result<Value, RuntimeError> {
460    match execute_elementwise(request) {
461        Ok(result) => {
462            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
463            stack_guard.commit();
464            Ok(result.final_value)
465        }
466        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
467    }
468}
469
470pub async fn execute_fusion_special_kind(
471    kind: FusionKind,
472    plan_inputs: &[runmat_accelerate::graph::ValueId],
473    request: FusionExecutionRequest<'_>,
474    stack_guard: StackSliceGuard<'_>,
475) -> Result<Value, RuntimeError> {
476    match kind {
477        FusionKind::CenteredGram => match execute_centered_gram(request).await {
478            Ok(result) => {
479                stack_guard.commit();
480                Ok(result)
481            }
482            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
483        },
484        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
485            Ok(result) => {
486                stack_guard.commit();
487                Ok(result)
488            }
489            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
490        },
491        FusionKind::ExplainedVariance => {
492            log::debug!("explained variance plan inputs {:?}", plan_inputs);
493            match execute_explained_variance(request).await {
494                Ok(result) => {
495                    stack_guard.commit();
496                    Ok(result)
497                }
498                Err(err) => {
499                    log::debug!("explained variance fusion fallback: {}", err);
500                    Err(mex("FusionExecutionFailed", &err.to_string()))
501                }
502            }
503        }
504        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
505            Ok(result) => {
506                stack_guard.commit();
507                Ok(result)
508            }
509            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
510        },
511        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
512            Ok(result) => {
513                stack_guard.commit();
514                Ok(result)
515            }
516            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
517        },
518        _ => Err(mex(
519            "FusionUnsupportedKind",
520            "fusion: unsupported fusion kind",
521        )),
522    }
523}
524
/// Geometry of a fused reduction.
///
/// NOTE(review): presumably produced by `resolve_reduction_geometry` from its
/// `axis`, `reduce_len` and `num_slices` locals — the construction site is not
/// visible from this part of the file; confirm there.
pub struct ReductionGeometry {
    /// Axis being reduced (zero-based if it comes from the axis resolution in
    /// `resolve_reduction_geometry`, which subtracts one from the requested
    /// dimension — TODO confirm).
    pub axis: usize,
    /// Number of elements folded together in each reduction.
    pub reduce_len: usize,
    /// Number of independent reductions (one output element per slice).
    pub num_slices: usize,
}
530
531pub fn resolve_reduction_geometry(
532    plan: &runmat_accelerate::FusionGroupPlan,
533    graph: &runmat_accelerate::AccelGraph,
534    request: &FusionExecutionRequest<'_>,
535    consumed_inputs: &[Option<Value>],
536    vars: &[Value],
537    context: &ExecutionContext,
538) -> Result<ReductionGeometry, RuntimeError> {
539    fn detect_reduce_all(
540        plan: &runmat_accelerate::FusionGroupPlan,
541        graph: &runmat_accelerate::AccelGraph,
542    ) -> bool {
543        let mut reduce_all = matches!(
544            plan.reduction_axes,
545            Some(runmat_accelerate::ReductionAxes::All)
546        );
547        let has_all = reduce_all
548            || plan.constants.values().any(value_is_all_keyword)
549            || plan.const_values.values().any(value_is_all_keyword);
550        if has_all {
551            return true;
552        }
553        for node_id in &plan.group.nodes {
554            if let Some(node) = graph.node(*node_id) {
555                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
556                    if name.eq_ignore_ascii_case("mean") {
557                        for input_vid in &node.inputs {
558                            if let Some(info) = graph.value(*input_vid) {
559                                if let Some(constant) = &info.constant {
560                                    if value_is_all_keyword(constant) {
561                                        reduce_all = true;
562                                        break;
563                                    }
564                                }
565                            }
566                        }
567                    }
568                }
569            }
570            if reduce_all {
571                break;
572            }
573        }
574        reduce_all
575    }
576
577    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
578        let mut axis = 0usize;
579        let mut axis_explicit = false;
580        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
581            if let Some(first) = dims.first().copied() {
582                axis = first.saturating_sub(1);
583                axis_explicit = true;
584            }
585        }
586        if let Some(dim_vid) = plan.reduction_dim {
587            if let Some(cv) = plan.const_values.get(&dim_vid) {
588                axis = match cv {
589                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
590                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
591                    _ => axis,
592                };
593                axis_explicit = true;
594            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
595                if let Some(cv) = plan.constants.get(&input_idx) {
596                    axis = match cv {
597                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
598                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
599                        _ => axis,
600                    };
601                    axis_explicit = true;
602                }
603            }
604        } else if let Some(dim_const) = plan.constants.get(&1) {
605            axis = match dim_const {
606                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
607                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
608                _ => axis,
609            };
610            axis_explicit = true;
611        }
612        (axis, axis_explicit)
613    }
614
615    fn derive_rows_cols(
616        plan: &runmat_accelerate::FusionGroupPlan,
617        graph: &runmat_accelerate::AccelGraph,
618        request: &FusionExecutionRequest<'_>,
619        consumed_inputs: &[Option<Value>],
620        vars: &[Value],
621        context: &ExecutionContext,
622    ) -> Option<(usize, usize)> {
623        let shape_of = |value: &Value| -> Option<(usize, usize)> {
624            match value {
625                Value::GpuTensor(h) => Some((
626                    h.shape.first().copied().unwrap_or(1).max(1),
627                    h.shape.get(1).copied().unwrap_or(1).max(1),
628                )),
629                Value::Tensor(t) => Some((
630                    t.shape.first().copied().unwrap_or(1).max(1),
631                    t.shape.get(1).copied().unwrap_or(1).max(1),
632                )),
633                _ => None,
634            }
635        };
636
637        if let Some(shape) = plan.reduction_data_shape(graph) {
638            if shape.len() >= 2 {
639                return Some((shape[0].max(1), shape[1].max(1)));
640            }
641            if shape.len() == 1 {
642                return Some((shape[0].max(1), 1));
643            }
644        }
645
646        for &vid in &plan.inputs {
647            if let Some(binding) = graph.var_binding(vid) {
648                let value_opt = match binding.kind {
649                    VarKind::Global => vars.get(binding.index).cloned(),
650                    VarKind::Local => {
651                        if let Some(frame) = context.call_stack.last() {
652                            let absolute = frame.locals_start + binding.index;
653                            context.locals.get(absolute).cloned()
654                        } else {
655                            vars.get(binding.index).cloned()
656                        }
657                    }
658                };
659                if let Some(value) = value_opt {
660                    if let Some(shape) = shape_of(&value) {
661                        return Some(shape);
662                    }
663                }
664            }
665        }
666
667        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
668            if let Some(shape) = shape_of(v) {
669                return Some(shape);
670            }
671        }
672
673        if let Some(data_id) = plan.reduction_data {
674            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
675                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
676                    if let Some(shape) = shape_of(val) {
677                        return Some(shape);
678                    }
679                }
680                if let Some(val) = request.inputs.get(input_index) {
681                    if let Some(shape) = shape_of(val) {
682                        return Some(shape);
683                    }
684                }
685            }
686            if let Some(info) = graph.value(data_id) {
687                if let ValueOrigin::Variable { kind, index } = &info.origin {
688                    let val = match kind {
689                        VarKind::Global => vars.get(*index).cloned(),
690                        VarKind::Local => {
691                            if let Some(frame) = context.call_stack.last() {
692                                let absolute = frame.locals_start + index;
693                                context.locals.get(absolute).cloned()
694                            } else {
695                                vars.get(*index).cloned()
696                            }
697                        }
698                    };
699                    if let Some(v) = val {
700                        if let Some(shape) = shape_of(&v) {
701                            return Some(shape);
702                        }
703                    }
704                }
705                if let ShapeInfo::Tensor(dims) = &info.shape {
706                    if !dims.is_empty() {
707                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
708                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
709                        return Some((r.max(1), c.max(1)));
710                    }
711                }
712            }
713        }
714
715        for v in &request.inputs {
716            if let Some(shape) = shape_of(v) {
717                return Some(shape);
718            }
719        }
720
721        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
722            if !dims.is_empty() {
723                let r = dims.first().and_then(|d| *d).unwrap_or(1);
724                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
725                return Some((r.max(1), c.max(1)));
726            }
727        }
728        None
729    }
730
731    if log::log_enabled!(log::Level::Debug) {
732        let meta: Vec<String> = plan
733            .inputs
734            .iter()
735            .map(|vid| {
736                if let Some(info) = graph.value(*vid) {
737                    format!(
738                        "vid={} origin={:?} shape={:?}",
739                        vid, info.origin, info.shape
740                    )
741                } else {
742                    format!("vid={} origin=<missing>", vid)
743                }
744            })
745            .collect();
746        log::debug!("reduction gather meta: [{}]", meta.join(", "));
747    }
748
749    let reduce_all = detect_reduce_all(plan, graph);
750    let (mut axis, axis_explicit) = if reduce_all {
751        (0usize, false)
752    } else {
753        resolve_reduction_axis(plan)
754    };
755    if reduce_all && interp_engine::fusion_debug_enabled() {
756        log::debug!(
757            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
758            plan.reduction_data,
759            plan.inputs,
760            plan.stack_pattern
761        );
762    }
763
764    let (r, c) =
765        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
766    let (reduce_len, num_slices) = if reduce_all {
767        let total_from_runtime = consumed_inputs
768            .iter()
769            .filter_map(|v| v.as_ref())
770            .chain(request.inputs.iter())
771            .find_map(|value| match value {
772                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
773                    1
774                } else {
775                    handle
776                        .shape
777                        .iter()
778                        .copied()
779                        .map(|d| d.max(1))
780                        .product::<usize>()
781                }),
782                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
783                    1
784                } else {
785                    tensor
786                        .shape
787                        .iter()
788                        .copied()
789                        .map(|d| d.max(1))
790                        .product::<usize>()
791                }),
792                _ => None,
793            });
794        let total = plan
795            .reduction_data_shape(graph)
796            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
797            .or(total_from_runtime)
798            .or_else(|| plan.element_count())
799            .filter(|v| *v > 0)
800            .ok_or_else(|| {
801                mex(
802                    "FusionReductionExtentUnknown",
803                    "fusion: reduction all extent unknown",
804                )
805            })?;
806        if interp_engine::fusion_debug_enabled() {
807            log::debug!(
808                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
809                total,
810                r,
811                c
812            );
813        }
814        (total, 1usize)
815    } else {
816        if !axis_explicit {
817            axis = if r == 1 && c > 1 {
818                1
819            } else if r > 1 {
820                0
821            } else {
822                axis
823            };
824        }
825        if interp_engine::fusion_debug_enabled() {
826            if r == 1 && c == 1 {
827                log::debug!(
828                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
829                    axis,
830                    plan.constants
831                );
832            } else {
833                log::debug!(
834                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
835                    r,
836                    c,
837                    axis,
838                    plan.constants
839                );
840            }
841        }
842        if axis == 0 {
843            (r, c)
844        } else {
845            (c, r)
846        }
847    };
848
849    if interp_engine::fusion_debug_enabled() {
850        log::debug!(
851            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
852            axis,
853            reduce_len,
854            num_slices,
855            plan.constants
856        );
857    }
858
859    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
860        let mut big = false;
861        let mut check_val = |v: &Value| match v {
862            Value::GpuTensor(h) => {
863                let prod = h.shape.iter().copied().product::<usize>();
864                if prod > 1 {
865                    big = true;
866                }
867            }
868            Value::Tensor(t) => {
869                let prod = t.shape.iter().copied().product::<usize>();
870                if prod > 1 {
871                    big = true;
872                }
873            }
874            _ => {}
875        };
876        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
877            check_val(v);
878        }
879        for v in &request.inputs {
880            check_val(v);
881        }
882        big
883    };
884    if looks_wrong {
885        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
886        return Err(mex(
887            "FusionReductionShapeUnresolved",
888            "fusion: reduction shape unresolved",
889        ));
890    }
891    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
892        .ok()
893        .as_deref()
894        == Some("1")
895    {
896        return Err(mex(
897            "FusionReductionDisabled",
898            "fusion: fused reductions disabled",
899        ));
900    }
901
902    Ok(ReductionGeometry {
903        axis,
904        reduce_len,
905        num_slices,
906    })
907}
908
909pub fn execute_fusion_reduction(
910    plan: &runmat_accelerate::FusionGroupPlan,
911    graph: &runmat_accelerate::AccelGraph,
912    request: FusionExecutionRequest<'_>,
913    consumed_inputs: &[Option<Value>],
914    stack_guard: StackSliceGuard<'_>,
915    vars: &[Value],
916    context: &ExecutionContext,
917) -> Result<Value, RuntimeError> {
918    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
919    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
920        Ok(result) => {
921            stack_guard.commit();
922            Ok(result)
923        }
924        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
925    }
926}
927
928pub async fn try_execute_fusion_group(
929    plan: &runmat_accelerate::FusionGroupPlan,
930    graph: &runmat_accelerate::AccelGraph,
931    stack: &mut Vec<Value>,
932    vars: &mut Vec<Value>,
933    context: &mut ExecutionContext,
934) -> Result<Value, RuntimeError> {
935    let (stack_guard, request, consumed_inputs) =
936        gather_fusion_inputs(plan, graph, stack, vars, context)?;
937    if plan.group.kind.is_elementwise()
938        && !request.inputs.is_empty()
939        && request.inputs.iter().all(is_scalarish_runtime_value)
940    {
941        return Err(mex(
942            "FusionScalarBypass",
943            "fusion: bypass scalar-only elementwise group",
944        ));
945    }
946    log::debug!(
947        "dispatch fusion kind {:?}, supported {}",
948        plan.group.kind,
949        plan.kernel.supported
950    );
951    if plan.group.kind.is_elementwise() {
952        execute_fusion_elementwise(request, stack_guard, vars, context)
953    } else if plan.group.kind.is_reduction() {
954        execute_fusion_reduction(
955            plan,
956            graph,
957            request,
958            &consumed_inputs,
959            stack_guard,
960            vars,
961            context,
962        )
963    } else {
964        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
965            .await
966    }
967}
968
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;
    use std::collections::HashSet;

    /// Global slot 0 starts as an `OutputList` holding two resident GPU
    /// handles. After the write-back stores only the first handle into that
    /// slot, the first handle must still be resident and the second must not.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        // Both handles share the same shape and device; only the buffer differs.
        let make_handle = |buffer_id| GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id,
        };
        let shared = make_handle(17001);
        let old_only = make_handle(17002);

        fusion_residency::mark(&shared);
        fusion_residency::mark(&old_only);
        assert!(fusion_residency::is_resident(&shared));
        assert!(fusion_residency::is_resident(&old_only));

        // Previous contents of the slot the write-back targets.
        let previous = Value::OutputList(vec![
            Value::GpuTensor(shared.clone()),
            Value::GpuTensor(old_only.clone()),
        ]);
        let mut vars = vec![previous];
        let mut context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: HashSet::new(),
            next_spawn_task_id: 0,
        };

        let target = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        let stores = vec![(target, Value::GpuTensor(shared.clone()))];
        write_elementwise_materialized_stores(stores, &mut vars, &mut context);

        assert!(fusion_residency::is_resident(&shared));
        assert!(!fusion_residency::is_resident(&old_only));

        // Undo the residency mark this test left on the surviving handle.
        fusion_residency::clear(&shared);
    }
}