// runmat_vm/accel/fusion.rs

1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9    execute_centered_gram, execute_elementwise, execute_explained_variance,
10    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11    execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22    match value {
23        Value::Int(_) => "Int",
24        Value::Num(_) => "Num",
25        Value::Complex(_, _) => "Complex",
26        Value::Bool(_) => "Bool",
27        Value::LogicalArray(_) => "LogicalArray",
28        Value::String(_) => "String",
29        Value::StringArray(_) => "StringArray",
30        Value::CharArray(_) => "CharArray",
31        Value::Tensor(_) => "Tensor",
32        Value::SparseTensor(_) => "SparseTensor",
33        Value::ComplexTensor(_) => "ComplexTensor",
34        Value::Cell(_) => "Cell",
35        Value::Struct(_) => "Struct",
36        Value::GpuTensor(_) => "GpuTensor",
37        Value::Object(_) => "Object",
38        Value::HandleObject(_) => "HandleObject",
39        Value::Listener(_) => "Listener",
40        Value::FunctionHandle(_)
41        | Value::ExternalFunctionHandle(_)
42        | Value::MethodFunctionHandle(_) => "FunctionHandle",
43        Value::BoundFunctionHandle { .. } => "FunctionHandle",
44        Value::Closure(_) => "Closure",
45        Value::ClassRef(_) => "ClassRef",
46        Value::MException(_) => "MException",
47        Value::OutputList(_) => "OutputList",
48    }
49}
50
51#[inline]
52pub fn summarize_value(i: usize, v: &Value) -> String {
53    match v {
54        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
55        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
56        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
57        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
58        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
59        Value::String(s) => format!("in#{i}:String({})", s),
60        _ => format!("in#{i}:{}", value_kind(v)),
61    }
62}
63
64#[inline]
65fn is_scalarish_runtime_value(value: &Value) -> bool {
66    match value {
67        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
68        Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
69        Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
70        Value::LogicalArray(array) => is_scalar_shape(&array.shape),
71        Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
72        Value::CharArray(array) => array.rows * array.cols == 1,
73        _ => false,
74    }
75}
76
77pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
78    if span.start > span.end || span.end >= instructions.len() {
79        return None;
80    }
81    let mut current_depth = 0usize;
82    for instr in &instructions[span.start..=span.end] {
83        let effect = instr.stack_effect()?;
84        if current_depth < effect.pops {
85            current_depth = effect.pops;
86        }
87        current_depth = current_depth - effect.pops + effect.pushes;
88    }
89    Some(current_depth)
90}
91
92pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
93    if span.start > span.end || span.end >= instructions.len() {
94        return true;
95    }
96    for instr in &instructions[span.start..=span.end] {
97        if matches!(
98            instr,
99            Instr::StoreIndex(_)
100                | Instr::StoreIndexDelete(_)
101                | Instr::StoreSlice(_, _, _, _)
102                | Instr::StoreSliceDelete(_, _, _, _)
103                | Instr::StoreSliceExpr { .. }
104                | Instr::StoreSliceExprDelete { .. }
105                | Instr::StoreIndexCell { .. }
106                | Instr::StoreIndexCellDelete { .. }
107                | Instr::StoreMember(_)
108                | Instr::StoreMemberOrInit(_)
109                | Instr::StoreMemberDynamic
110                | Instr::StoreMemberDynamicOrInit
111        ) {
112            return true;
113        }
114    }
115    fusion_span_live_result_count(instructions, span) != Some(1)
116}
117
118pub struct StackSliceGuard<'a> {
119    stack: *mut Vec<Value>,
120    slice: Option<Vec<Value>>,
121    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
122}
123
124impl<'a> StackSliceGuard<'a> {
125    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
126        let slice = stack.split_off(slice_start);
127        Self {
128            stack,
129            slice: Some(slice),
130            _marker: std::marker::PhantomData,
131        }
132    }
133
134    pub fn slice(&self) -> &[Value] {
135        self.slice.as_ref().expect("stack slice missing").as_slice()
136    }
137
138    pub fn commit(mut self) {
139        self.slice = None;
140    }
141}
142
143impl Drop for StackSliceGuard<'_> {
144    fn drop(&mut self) {
145        if let Some(slice) = self.slice.take() {
146            unsafe { (&mut *self.stack).extend(slice) }
147        }
148    }
149}
150
151pub fn gather_fusion_inputs<'a>(
152    plan: &'a runmat_accelerate::FusionGroupPlan,
153    graph: &runmat_accelerate::AccelGraph,
154    stack: &'a mut Vec<Value>,
155    vars: &mut [Value],
156    context: &mut ExecutionContext,
157) -> Result<
158    (
159        StackSliceGuard<'a>,
160        FusionExecutionRequest<'a>,
161        Vec<Option<Value>>,
162    ),
163    RuntimeError,
164> {
165    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
166        return Err(mex(
167            "FusionMissingStackLayout",
168            "fusion: missing compile-time stack layout metadata",
169        ));
170    }
171    let required_stack_operands = plan
172        .group
173        .stack_layout
174        .as_ref()
175        .map(|layout| layout.required_stack_operands)
176        .unwrap_or_else(|| plan.stack_pattern.len());
177    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
178
179    for (idx, value) in &plan.constants {
180        if let Some(slot) = inputs.get_mut(*idx) {
181            if slot.is_none() {
182                *slot = Some(value.clone());
183            }
184        }
185    }
186
187    for (idx, value_id) in plan.inputs.iter().enumerate() {
188        let info = graph
189            .value(*value_id)
190            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
191        match &info.origin {
192            ValueOrigin::Variable { kind, index } => {
193                let value =
194                    match kind {
195                        VarKind::Global => vars
196                            .get(*index)
197                            .cloned()
198                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
199                        VarKind::Local => {
200                            if let Some(frame) = context.call_stack.last() {
201                                let absolute = frame.locals_start + index;
202                                context.locals.get(absolute).cloned().ok_or_else(|| {
203                                    format!("fusion: local var {index} unavailable")
204                                })?
205                            } else {
206                                vars.get(*index).cloned().ok_or_else(|| {
207                                    format!("fusion: local var {index} unavailable")
208                                })?
209                            }
210                        }
211                    };
212                debug_assert!(
213                    inputs[idx].is_none(),
214                    "fusion: duplicate input slot {} for plan {}",
215                    idx,
216                    plan.index
217                );
218                inputs[idx] = Some(value);
219            }
220            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
221        }
222    }
223
224    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
225        let stack_needed_preview = required_stack_operands;
226        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
227        let stack_kinds: Vec<&'static str> =
228            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
229        let input_meta: Vec<String> = plan
230            .inputs
231            .iter()
232            .enumerate()
233            .map(|(i, value_id)| {
234                if let Some(info) = graph.value(*value_id) {
235                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
236                } else {
237                    format!("#{i}:id={} origin=<missing>", value_id)
238                }
239            })
240            .collect();
241        log::debug!(
242            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
243            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
244        );
245    }
246
247    if stack.len() < required_stack_operands {
248        if interp_engine::fusion_debug_enabled() {
249            log::debug!(
250                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
251                plan.index,
252                required_stack_operands,
253                stack.len(),
254                plan.stack_pattern
255            );
256        }
257        return Err(mex(
258            "FusionStackUnderflow",
259            "fusion: stack underflow gathering inputs",
260        ));
261    }
262    let available = required_stack_operands;
263    let slice_start = stack.len() - available;
264    let stack_guard = StackSliceGuard::new(stack, slice_start);
265    let slice = stack_guard.slice().to_vec();
266    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
267    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
268        .inputs
269        .iter()
270        .enumerate()
271        .map(|(idx, value_id)| (*value_id, idx))
272        .collect();
273
274    let allow_stack_value = |val: &Value| {
275        if plan.group.kind.is_reduction() {
276            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
277        } else {
278            true
279        }
280    };
281
282    if let Some(layout) = plan.group.stack_layout.as_ref() {
283        for binding in &layout.bindings {
284            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
285                continue;
286            };
287            let Some(val) = slice.get(binding.stack_offset).cloned() else {
288                continue;
289            };
290            consumed_inputs[input_idx] = Some(val.clone());
291            if inputs[input_idx].is_none() && allow_stack_value(&val) {
292                inputs[input_idx] = Some(val);
293            }
294        }
295    } else {
296        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
297            let Some(val) = slice.get(offset).cloned() else {
298                continue;
299            };
300            consumed_inputs[*input_idx] = Some(val.clone());
301            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
302                inputs[*input_idx] = Some(val);
303            }
304        }
305    }
306
307    for (idx, slot) in inputs.iter_mut().enumerate() {
308        if slot.is_some() {
309            continue;
310        }
311        let vid = plan.inputs[idx];
312        let info = graph.value(vid);
313        if let Some(info) = info {
314            match &info.origin {
315                ValueOrigin::Variable { kind, index } => {
316                    let value_opt = match kind {
317                        VarKind::Global => vars.get(*index).cloned(),
318                        VarKind::Local => {
319                            if let Some(frame) = context.call_stack.last() {
320                                let absolute = frame.locals_start + index;
321                                context.locals.get(absolute).cloned()
322                            } else {
323                                vars.get(*index).cloned()
324                            }
325                        }
326                    };
327                    if let Some(value) = value_opt {
328                        *slot = Some(value);
329                        continue;
330                    }
331                }
332                ValueOrigin::Constant => {
333                    if let Some(value) = plan.const_values.get(&vid) {
334                        *slot = Some(value.clone());
335                        continue;
336                    }
337                }
338                _ => {}
339            }
340        }
341        if slot.is_none() {
342            if let Some(binding) = graph.var_binding(vid) {
343                let value_opt = match binding.kind {
344                    VarKind::Global => vars.get(binding.index).cloned(),
345                    VarKind::Local => {
346                        if let Some(frame) = context.call_stack.last() {
347                            let absolute = frame.locals_start + binding.index;
348                            context.locals.get(absolute).cloned()
349                        } else {
350                            vars.get(binding.index).cloned()
351                        }
352                    }
353                };
354                if let Some(value) = value_opt {
355                    *slot = Some(value);
356                    continue;
357                }
358            }
359        }
360        if slot.is_none() {
361            if let Some(info) = info {
362                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
363                    if let Some(binding) = graph.node_binding(node) {
364                        let value_opt = match binding.kind {
365                            VarKind::Global => vars.get(binding.index).cloned(),
366                            VarKind::Local => {
367                                if let Some(frame) = context.call_stack.last() {
368                                    let absolute = frame.locals_start + binding.index;
369                                    context.locals.get(absolute).cloned()
370                                } else {
371                                    vars.get(binding.index).cloned()
372                                }
373                            }
374                        };
375                        if let Some(value) = value_opt {
376                            *slot = Some(value);
377                            continue;
378                        }
379                    }
380                }
381            }
382        }
383        if slot.is_none() {
384            if let Some(value) = plan.const_values.get(&vid) {
385                *slot = Some(value.clone());
386            }
387        }
388    }
389
390    let inputs: Vec<Value> = inputs
391        .into_iter()
392        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
393        .collect::<Result<_, _>>()?;
394
395    if log::log_enabled!(log::Level::Debug) {
396        let summaries: Vec<String> = inputs
397            .iter()
398            .enumerate()
399            .map(|(i, v)| summarize_value(i, v))
400            .collect();
401        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
402    }
403
404    Ok((
405        stack_guard,
406        FusionExecutionRequest { plan, inputs },
407        consumed_inputs,
408    ))
409}
410
411pub fn write_elementwise_materialized_stores(
412    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
413    vars: &mut Vec<Value>,
414    context: &mut ExecutionContext,
415) {
416    for (store, value) in materialized_stores {
417        match store.binding.kind {
418            VarKind::Global => {
419                let i = store.binding.index;
420                if i < vars.len() {
421                    accel_residency::clear_value_excluding(&vars[i], &value);
422                }
423                if i >= vars.len() {
424                    vars.resize(i + 1, Value::Num(0.0));
425                    refresh_workspace_state(vars);
426                }
427                vars[i] = value;
428            }
429            VarKind::Local => {
430                if let Some(frame) = context.call_stack.last() {
431                    let absolute = frame.locals_start + store.binding.index;
432                    while context.locals.len() <= absolute {
433                        context.locals.push(Value::Num(0.0));
434                    }
435                    accel_residency::clear_value_excluding(&context.locals[absolute], &value);
436                    context.locals[absolute] = value;
437                } else {
438                    let i = store.binding.index;
439                    if i < vars.len() {
440                        accel_residency::clear_value_excluding(&vars[i], &value);
441                    }
442                    if i >= vars.len() {
443                        vars.resize(i + 1, Value::Num(0.0));
444                        refresh_workspace_state(vars);
445                    }
446                    vars[i] = value;
447                }
448            }
449        }
450    }
451}
452
453pub fn execute_fusion_elementwise(
454    request: FusionExecutionRequest<'_>,
455    stack_guard: StackSliceGuard<'_>,
456    vars: &mut Vec<Value>,
457    context: &mut ExecutionContext,
458) -> Result<Value, RuntimeError> {
459    match execute_elementwise(request) {
460        Ok(result) => {
461            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
462            stack_guard.commit();
463            Ok(result.final_value)
464        }
465        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
466    }
467}
468
469pub async fn execute_fusion_special_kind(
470    kind: FusionKind,
471    plan_inputs: &[runmat_accelerate::graph::ValueId],
472    request: FusionExecutionRequest<'_>,
473    stack_guard: StackSliceGuard<'_>,
474) -> Result<Value, RuntimeError> {
475    match kind {
476        FusionKind::CenteredGram => match execute_centered_gram(request).await {
477            Ok(result) => {
478                stack_guard.commit();
479                Ok(result)
480            }
481            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
482        },
483        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
484            Ok(result) => {
485                stack_guard.commit();
486                Ok(result)
487            }
488            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
489        },
490        FusionKind::ExplainedVariance => {
491            log::debug!("explained variance plan inputs {:?}", plan_inputs);
492            match execute_explained_variance(request).await {
493                Ok(result) => {
494                    stack_guard.commit();
495                    Ok(result)
496                }
497                Err(err) => {
498                    log::debug!("explained variance fusion fallback: {}", err);
499                    Err(mex("FusionExecutionFailed", &err.to_string()))
500                }
501            }
502        }
503        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
504            Ok(result) => {
505                stack_guard.commit();
506                Ok(result)
507            }
508            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
509        },
510        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
511            Ok(result) => {
512                stack_guard.commit();
513                Ok(result)
514            }
515            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
516        },
517        _ => Err(mex(
518            "FusionUnsupportedKind",
519            "fusion: unsupported fusion kind",
520        )),
521    }
522}
523
/// Geometry of a fused reduction: which axis is reduced and how the data is
/// partitioned into independent slices.
///
/// NOTE(review): presumably produced by `resolve_reduction_geometry`; the
/// construction site is outside the reviewed range, so the field notes below
/// should be confirmed against it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionGeometry {
    /// Axis being reduced (presumably zero-based; confirm at the construction site).
    pub axis: usize,
    /// Number of elements folded together in each slice.
    pub reduce_len: usize,
    /// Number of independent slices being reduced.
    pub num_slices: usize,
}
529
530pub fn resolve_reduction_geometry(
531    plan: &runmat_accelerate::FusionGroupPlan,
532    graph: &runmat_accelerate::AccelGraph,
533    request: &FusionExecutionRequest<'_>,
534    consumed_inputs: &[Option<Value>],
535    vars: &[Value],
536    context: &ExecutionContext,
537) -> Result<ReductionGeometry, RuntimeError> {
538    fn detect_reduce_all(
539        plan: &runmat_accelerate::FusionGroupPlan,
540        graph: &runmat_accelerate::AccelGraph,
541    ) -> bool {
542        let mut reduce_all = matches!(
543            plan.reduction_axes,
544            Some(runmat_accelerate::ReductionAxes::All)
545        );
546        let has_all = reduce_all
547            || plan.constants.values().any(value_is_all_keyword)
548            || plan.const_values.values().any(value_is_all_keyword);
549        if has_all {
550            return true;
551        }
552        for node_id in &plan.group.nodes {
553            if let Some(node) = graph.node(*node_id) {
554                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
555                    if name.eq_ignore_ascii_case("mean") {
556                        for input_vid in &node.inputs {
557                            if let Some(info) = graph.value(*input_vid) {
558                                if let Some(constant) = &info.constant {
559                                    if value_is_all_keyword(constant) {
560                                        reduce_all = true;
561                                        break;
562                                    }
563                                }
564                            }
565                        }
566                    }
567                }
568            }
569            if reduce_all {
570                break;
571            }
572        }
573        reduce_all
574    }
575
576    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
577        let mut axis = 0usize;
578        let mut axis_explicit = false;
579        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
580            if let Some(first) = dims.first().copied() {
581                axis = first.saturating_sub(1);
582                axis_explicit = true;
583            }
584        }
585        if let Some(dim_vid) = plan.reduction_dim {
586            if let Some(cv) = plan.const_values.get(&dim_vid) {
587                axis = match cv {
588                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
589                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
590                    _ => axis,
591                };
592                axis_explicit = true;
593            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
594                if let Some(cv) = plan.constants.get(&input_idx) {
595                    axis = match cv {
596                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
597                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
598                        _ => axis,
599                    };
600                    axis_explicit = true;
601                }
602            }
603        } else if let Some(dim_const) = plan.constants.get(&1) {
604            axis = match dim_const {
605                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
606                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
607                _ => axis,
608            };
609            axis_explicit = true;
610        }
611        (axis, axis_explicit)
612    }
613
614    fn derive_rows_cols(
615        plan: &runmat_accelerate::FusionGroupPlan,
616        graph: &runmat_accelerate::AccelGraph,
617        request: &FusionExecutionRequest<'_>,
618        consumed_inputs: &[Option<Value>],
619        vars: &[Value],
620        context: &ExecutionContext,
621    ) -> Option<(usize, usize)> {
622        let shape_of = |value: &Value| -> Option<(usize, usize)> {
623            match value {
624                Value::GpuTensor(h) => Some((
625                    h.shape.first().copied().unwrap_or(1).max(1),
626                    h.shape.get(1).copied().unwrap_or(1).max(1),
627                )),
628                Value::Tensor(t) => Some((
629                    t.shape.first().copied().unwrap_or(1).max(1),
630                    t.shape.get(1).copied().unwrap_or(1).max(1),
631                )),
632                _ => None,
633            }
634        };
635
636        if let Some(shape) = plan.reduction_data_shape(graph) {
637            if shape.len() >= 2 {
638                return Some((shape[0].max(1), shape[1].max(1)));
639            }
640            if shape.len() == 1 {
641                return Some((shape[0].max(1), 1));
642            }
643        }
644
645        for &vid in &plan.inputs {
646            if let Some(binding) = graph.var_binding(vid) {
647                let value_opt = match binding.kind {
648                    VarKind::Global => vars.get(binding.index).cloned(),
649                    VarKind::Local => {
650                        if let Some(frame) = context.call_stack.last() {
651                            let absolute = frame.locals_start + binding.index;
652                            context.locals.get(absolute).cloned()
653                        } else {
654                            vars.get(binding.index).cloned()
655                        }
656                    }
657                };
658                if let Some(value) = value_opt {
659                    if let Some(shape) = shape_of(&value) {
660                        return Some(shape);
661                    }
662                }
663            }
664        }
665
666        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
667            if let Some(shape) = shape_of(v) {
668                return Some(shape);
669            }
670        }
671
672        if let Some(data_id) = plan.reduction_data {
673            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
674                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
675                    if let Some(shape) = shape_of(val) {
676                        return Some(shape);
677                    }
678                }
679                if let Some(val) = request.inputs.get(input_index) {
680                    if let Some(shape) = shape_of(val) {
681                        return Some(shape);
682                    }
683                }
684            }
685            if let Some(info) = graph.value(data_id) {
686                if let ValueOrigin::Variable { kind, index } = &info.origin {
687                    let val = match kind {
688                        VarKind::Global => vars.get(*index).cloned(),
689                        VarKind::Local => {
690                            if let Some(frame) = context.call_stack.last() {
691                                let absolute = frame.locals_start + index;
692                                context.locals.get(absolute).cloned()
693                            } else {
694                                vars.get(*index).cloned()
695                            }
696                        }
697                    };
698                    if let Some(v) = val {
699                        if let Some(shape) = shape_of(&v) {
700                            return Some(shape);
701                        }
702                    }
703                }
704                if let ShapeInfo::Tensor(dims) = &info.shape {
705                    if !dims.is_empty() {
706                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
707                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
708                        return Some((r.max(1), c.max(1)));
709                    }
710                }
711            }
712        }
713
714        for v in &request.inputs {
715            if let Some(shape) = shape_of(v) {
716                return Some(shape);
717            }
718        }
719
720        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
721            if !dims.is_empty() {
722                let r = dims.first().and_then(|d| *d).unwrap_or(1);
723                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
724                return Some((r.max(1), c.max(1)));
725            }
726        }
727        None
728    }
729
730    if log::log_enabled!(log::Level::Debug) {
731        let meta: Vec<String> = plan
732            .inputs
733            .iter()
734            .map(|vid| {
735                if let Some(info) = graph.value(*vid) {
736                    format!(
737                        "vid={} origin={:?} shape={:?}",
738                        vid, info.origin, info.shape
739                    )
740                } else {
741                    format!("vid={} origin=<missing>", vid)
742                }
743            })
744            .collect();
745        log::debug!("reduction gather meta: [{}]", meta.join(", "));
746    }
747
748    let reduce_all = detect_reduce_all(plan, graph);
749    let (mut axis, axis_explicit) = if reduce_all {
750        (0usize, false)
751    } else {
752        resolve_reduction_axis(plan)
753    };
754    if reduce_all && interp_engine::fusion_debug_enabled() {
755        log::debug!(
756            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
757            plan.reduction_data,
758            plan.inputs,
759            plan.stack_pattern
760        );
761    }
762
763    let (r, c) =
764        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
765    let (reduce_len, num_slices) = if reduce_all {
766        let total_from_runtime = consumed_inputs
767            .iter()
768            .filter_map(|v| v.as_ref())
769            .chain(request.inputs.iter())
770            .find_map(|value| match value {
771                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
772                    1
773                } else {
774                    handle
775                        .shape
776                        .iter()
777                        .copied()
778                        .map(|d| d.max(1))
779                        .product::<usize>()
780                }),
781                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
782                    1
783                } else {
784                    tensor
785                        .shape
786                        .iter()
787                        .copied()
788                        .map(|d| d.max(1))
789                        .product::<usize>()
790                }),
791                _ => None,
792            });
793        let total = plan
794            .reduction_data_shape(graph)
795            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
796            .or(total_from_runtime)
797            .or_else(|| plan.element_count())
798            .filter(|v| *v > 0)
799            .ok_or_else(|| {
800                mex(
801                    "FusionReductionExtentUnknown",
802                    "fusion: reduction all extent unknown",
803                )
804            })?;
805        if interp_engine::fusion_debug_enabled() {
806            log::debug!(
807                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
808                total,
809                r,
810                c
811            );
812        }
813        (total, 1usize)
814    } else {
815        if !axis_explicit {
816            axis = if r == 1 && c > 1 {
817                1
818            } else if r > 1 {
819                0
820            } else {
821                axis
822            };
823        }
824        if interp_engine::fusion_debug_enabled() {
825            if r == 1 && c == 1 {
826                log::debug!(
827                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
828                    axis,
829                    plan.constants
830                );
831            } else {
832                log::debug!(
833                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
834                    r,
835                    c,
836                    axis,
837                    plan.constants
838                );
839            }
840        }
841        if axis == 0 {
842            (r, c)
843        } else {
844            (c, r)
845        }
846    };
847
848    if interp_engine::fusion_debug_enabled() {
849        log::debug!(
850            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
851            axis,
852            reduce_len,
853            num_slices,
854            plan.constants
855        );
856    }
857
858    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
859        let mut big = false;
860        let mut check_val = |v: &Value| match v {
861            Value::GpuTensor(h) => {
862                let prod = h.shape.iter().copied().product::<usize>();
863                if prod > 1 {
864                    big = true;
865                }
866            }
867            Value::Tensor(t) => {
868                let prod = t.shape.iter().copied().product::<usize>();
869                if prod > 1 {
870                    big = true;
871                }
872            }
873            _ => {}
874        };
875        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
876            check_val(v);
877        }
878        for v in &request.inputs {
879            check_val(v);
880        }
881        big
882    };
883    if looks_wrong {
884        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
885        return Err(mex(
886            "FusionReductionShapeUnresolved",
887            "fusion: reduction shape unresolved",
888        ));
889    }
890    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
891        .ok()
892        .as_deref()
893        == Some("1")
894    {
895        return Err(mex(
896            "FusionReductionDisabled",
897            "fusion: fused reductions disabled",
898        ));
899    }
900
901    Ok(ReductionGeometry {
902        axis,
903        reduce_len,
904        num_slices,
905    })
906}
907
908pub fn execute_fusion_reduction(
909    plan: &runmat_accelerate::FusionGroupPlan,
910    graph: &runmat_accelerate::AccelGraph,
911    request: FusionExecutionRequest<'_>,
912    consumed_inputs: &[Option<Value>],
913    stack_guard: StackSliceGuard<'_>,
914    vars: &[Value],
915    context: &ExecutionContext,
916) -> Result<Value, RuntimeError> {
917    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
918    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
919        Ok(result) => {
920            stack_guard.commit();
921            Ok(result)
922        }
923        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
924    }
925}
926
927pub async fn try_execute_fusion_group(
928    plan: &runmat_accelerate::FusionGroupPlan,
929    graph: &runmat_accelerate::AccelGraph,
930    stack: &mut Vec<Value>,
931    vars: &mut Vec<Value>,
932    context: &mut ExecutionContext,
933) -> Result<Value, RuntimeError> {
934    let (stack_guard, request, consumed_inputs) =
935        gather_fusion_inputs(plan, graph, stack, vars, context)?;
936    if plan.group.kind.is_elementwise()
937        && !request.inputs.is_empty()
938        && request.inputs.iter().all(is_scalarish_runtime_value)
939    {
940        return Err(mex(
941            "FusionScalarBypass",
942            "fusion: bypass scalar-only elementwise group",
943        ));
944    }
945    log::debug!(
946        "dispatch fusion kind {:?}, supported {}",
947        plan.group.kind,
948        plan.kernel.supported
949    );
950    if plan.group.kind.is_elementwise() {
951        execute_fusion_elementwise(request, stack_guard, vars, context)
952    } else if plan.group.kind.is_reduction() {
953        execute_fusion_reduction(
954            plan,
955            graph,
956            request,
957            &consumed_inputs,
958            stack_guard,
959            vars,
960            context,
961        )
962    } else {
963        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
964            .await
965    }
966}
967
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;
    use std::collections::HashSet;

    /// Overwriting global slot 0 with a value that reuses one of the handles
    /// previously stored there must leave that handle marked resident, while
    /// the handle that only the old value referenced must no longer be
    /// resident afterwards.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        // Both handles are single-element tensors on device 17; they differ
        // only in their buffer id.
        let make_handle = |buffer_id| GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id,
        };
        let shared = make_handle(17001);
        let old_only = make_handle(17002);

        for handle in [&shared, &old_only] {
            fusion_residency::mark(handle);
        }
        assert!(fusion_residency::is_resident(&shared));
        assert!(fusion_residency::is_resident(&old_only));

        // The slot initially holds both handles inside one output list.
        let previous = Value::OutputList(vec![
            Value::GpuTensor(shared.clone()),
            Value::GpuTensor(old_only.clone()),
        ]);
        let mut vars = vec![previous];
        let mut context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: HashSet::new(),
            next_spawn_task_id: 0,
        };

        // Write a value that carries only the shared handle into global 0.
        let binding = VarBinding {
            kind: VarKind::Global,
            index: 0,
        };
        let store = FusionStoreMaterialization {
            value_id: 1,
            binding,
        };
        let pending = vec![(store, Value::GpuTensor(shared.clone()))];
        write_elementwise_materialized_stores(pending, &mut vars, &mut context);

        assert!(fusion_residency::is_resident(&shared));
        assert!(!fusion_residency::is_resident(&old_only));

        // Remove the residency mark this test added for the shared handle.
        fusion_residency::clear(&shared);
    }
}