// runmat_vm/accel/fusion.rs

1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9    execute_centered_gram, execute_elementwise, execute_explained_variance,
10    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11    execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22    match value {
23        Value::Int(_) => "Int",
24        Value::Num(_) => "Num",
25        Value::Complex(_, _) => "Complex",
26        Value::Bool(_) => "Bool",
27        Value::LogicalArray(_) => "LogicalArray",
28        Value::String(_) => "String",
29        Value::StringArray(_) => "StringArray",
30        Value::Symbolic(_) => "Symbolic",
31        Value::CharArray(_) => "CharArray",
32        Value::Tensor(_) => "Tensor",
33        Value::SparseTensor(_) => "SparseTensor",
34        Value::ComplexTensor(_) => "ComplexTensor",
35        Value::Cell(_) => "Cell",
36        Value::Struct(_) => "Struct",
37        Value::GpuTensor(_) => "GpuTensor",
38        Value::Object(_) => "Object",
39        Value::HandleObject(_) => "HandleObject",
40        Value::Listener(_) => "Listener",
41        Value::FunctionHandle(_)
42        | Value::ExternalFunctionHandle(_)
43        | Value::MethodFunctionHandle(_) => "FunctionHandle",
44        Value::BoundFunctionHandle { .. } => "FunctionHandle",
45        Value::Closure(_) => "Closure",
46        Value::ClassRef(_) => "ClassRef",
47        Value::MException(_) => "MException",
48        Value::OutputList(_) => "OutputList",
49    }
50}
51
52#[inline]
53pub fn summarize_value(i: usize, v: &Value) -> String {
54    match v {
55        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
56        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
57        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
58        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
59        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
60        Value::String(s) => format!("in#{i}:String({})", s),
61        _ => format!("in#{i}:{}", value_kind(v)),
62    }
63}
64
65#[inline]
66fn is_scalarish_runtime_value(value: &Value) -> bool {
67    match value {
68        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
69        Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
70        Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
71        Value::LogicalArray(array) => is_scalar_shape(&array.shape),
72        Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
73        Value::CharArray(array) => array.rows * array.cols == 1,
74        _ => false,
75    }
76}
77
78pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
79    if span.start > span.end || span.end >= instructions.len() {
80        return None;
81    }
82    let mut current_depth = 0usize;
83    for instr in &instructions[span.start..=span.end] {
84        let effect = instr.stack_effect()?;
85        if current_depth < effect.pops {
86            current_depth = effect.pops;
87        }
88        current_depth = current_depth - effect.pops + effect.pushes;
89    }
90    Some(current_depth)
91}
92
93pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
94    if span.start > span.end || span.end >= instructions.len() {
95        return true;
96    }
97    for instr in &instructions[span.start..=span.end] {
98        if matches!(
99            instr,
100            Instr::StoreIndex(_)
101                | Instr::StoreIndexDelete(_)
102                | Instr::StoreSlice(_, _, _, _)
103                | Instr::StoreSliceDelete(_, _, _, _)
104                | Instr::StoreSliceExpr { .. }
105                | Instr::StoreSliceExprDelete { .. }
106                | Instr::StoreIndexCell { .. }
107                | Instr::StoreIndexCellDelete { .. }
108                | Instr::StoreMember(_)
109                | Instr::StoreMemberOrInit(_)
110                | Instr::StoreMemberDynamic
111                | Instr::StoreMemberDynamicOrInit
112        ) {
113            return true;
114        }
115    }
116    fusion_span_live_result_count(instructions, span) != Some(1)
117}
118
119pub struct StackSliceGuard<'a> {
120    stack: *mut Vec<Value>,
121    slice: Option<Vec<Value>>,
122    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
123}
124
125impl<'a> StackSliceGuard<'a> {
126    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
127        let slice = stack.split_off(slice_start);
128        Self {
129            stack,
130            slice: Some(slice),
131            _marker: std::marker::PhantomData,
132        }
133    }
134
135    pub fn slice(&self) -> &[Value] {
136        self.slice.as_ref().expect("stack slice missing").as_slice()
137    }
138
139    pub fn commit(mut self) {
140        self.slice = None;
141    }
142}
143
144impl Drop for StackSliceGuard<'_> {
145    fn drop(&mut self) {
146        if let Some(slice) = self.slice.take() {
147            unsafe { (&mut *self.stack).extend(slice) }
148        }
149    }
150}
151
152pub fn gather_fusion_inputs<'a>(
153    plan: &'a runmat_accelerate::FusionGroupPlan,
154    graph: &runmat_accelerate::AccelGraph,
155    stack: &'a mut Vec<Value>,
156    vars: &mut [Value],
157    context: &mut ExecutionContext,
158) -> Result<
159    (
160        StackSliceGuard<'a>,
161        FusionExecutionRequest<'a>,
162        Vec<Option<Value>>,
163    ),
164    RuntimeError,
165> {
166    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
167        return Err(mex(
168            "FusionMissingStackLayout",
169            "fusion: missing compile-time stack layout metadata",
170        ));
171    }
172    let required_stack_operands = plan
173        .group
174        .stack_layout
175        .as_ref()
176        .map(|layout| layout.required_stack_operands)
177        .unwrap_or_else(|| plan.stack_pattern.len());
178    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
179
180    for (idx, value) in &plan.constants {
181        if let Some(slot) = inputs.get_mut(*idx) {
182            if slot.is_none() {
183                *slot = Some(value.clone());
184            }
185        }
186    }
187
188    for (idx, value_id) in plan.inputs.iter().enumerate() {
189        let info = graph
190            .value(*value_id)
191            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
192        match &info.origin {
193            ValueOrigin::Variable { kind, index } => {
194                let value =
195                    match kind {
196                        VarKind::Global => vars
197                            .get(*index)
198                            .cloned()
199                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
200                        VarKind::Local => {
201                            if let Some(frame) = context.call_stack.last() {
202                                let absolute = frame.locals_start + index;
203                                context.locals.get(absolute).cloned().ok_or_else(|| {
204                                    format!("fusion: local var {index} unavailable")
205                                })?
206                            } else {
207                                vars.get(*index).cloned().ok_or_else(|| {
208                                    format!("fusion: local var {index} unavailable")
209                                })?
210                            }
211                        }
212                    };
213                debug_assert!(
214                    inputs[idx].is_none(),
215                    "fusion: duplicate input slot {} for plan {}",
216                    idx,
217                    plan.index
218                );
219                inputs[idx] = Some(value);
220            }
221            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
222        }
223    }
224
225    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
226        let stack_needed_preview = required_stack_operands;
227        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
228        let stack_kinds: Vec<&'static str> =
229            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
230        let input_meta: Vec<String> = plan
231            .inputs
232            .iter()
233            .enumerate()
234            .map(|(i, value_id)| {
235                if let Some(info) = graph.value(*value_id) {
236                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
237                } else {
238                    format!("#{i}:id={} origin=<missing>", value_id)
239                }
240            })
241            .collect();
242        log::debug!(
243            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
244            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
245        );
246    }
247
248    if stack.len() < required_stack_operands {
249        if interp_engine::fusion_debug_enabled() {
250            log::debug!(
251                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
252                plan.index,
253                required_stack_operands,
254                stack.len(),
255                plan.stack_pattern
256            );
257        }
258        return Err(mex(
259            "FusionStackUnderflow",
260            "fusion: stack underflow gathering inputs",
261        ));
262    }
263    let available = required_stack_operands;
264    let slice_start = stack.len() - available;
265    let stack_guard = StackSliceGuard::new(stack, slice_start);
266    let slice = stack_guard.slice().to_vec();
267    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
268    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
269        .inputs
270        .iter()
271        .enumerate()
272        .map(|(idx, value_id)| (*value_id, idx))
273        .collect();
274
275    let allow_stack_value = |val: &Value| {
276        if plan.group.kind.is_reduction() {
277            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
278        } else {
279            true
280        }
281    };
282
283    if let Some(layout) = plan.group.stack_layout.as_ref() {
284        for binding in &layout.bindings {
285            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
286                continue;
287            };
288            let Some(val) = slice.get(binding.stack_offset).cloned() else {
289                continue;
290            };
291            consumed_inputs[input_idx] = Some(val.clone());
292            if inputs[input_idx].is_none() && allow_stack_value(&val) {
293                inputs[input_idx] = Some(val);
294            }
295        }
296    } else {
297        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
298            let Some(val) = slice.get(offset).cloned() else {
299                continue;
300            };
301            consumed_inputs[*input_idx] = Some(val.clone());
302            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
303                inputs[*input_idx] = Some(val);
304            }
305        }
306    }
307
308    for (idx, slot) in inputs.iter_mut().enumerate() {
309        if slot.is_some() {
310            continue;
311        }
312        let vid = plan.inputs[idx];
313        let info = graph.value(vid);
314        if let Some(info) = info {
315            match &info.origin {
316                ValueOrigin::Variable { kind, index } => {
317                    let value_opt = match kind {
318                        VarKind::Global => vars.get(*index).cloned(),
319                        VarKind::Local => {
320                            if let Some(frame) = context.call_stack.last() {
321                                let absolute = frame.locals_start + index;
322                                context.locals.get(absolute).cloned()
323                            } else {
324                                vars.get(*index).cloned()
325                            }
326                        }
327                    };
328                    if let Some(value) = value_opt {
329                        *slot = Some(value);
330                        continue;
331                    }
332                }
333                ValueOrigin::Constant => {
334                    if let Some(value) = plan.const_values.get(&vid) {
335                        *slot = Some(value.clone());
336                        continue;
337                    }
338                }
339                _ => {}
340            }
341        }
342        if slot.is_none() {
343            if let Some(binding) = graph.var_binding(vid) {
344                let value_opt = match binding.kind {
345                    VarKind::Global => vars.get(binding.index).cloned(),
346                    VarKind::Local => {
347                        if let Some(frame) = context.call_stack.last() {
348                            let absolute = frame.locals_start + binding.index;
349                            context.locals.get(absolute).cloned()
350                        } else {
351                            vars.get(binding.index).cloned()
352                        }
353                    }
354                };
355                if let Some(value) = value_opt {
356                    *slot = Some(value);
357                    continue;
358                }
359            }
360        }
361        if slot.is_none() {
362            if let Some(info) = info {
363                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
364                    if let Some(binding) = graph.node_binding(node) {
365                        let value_opt = match binding.kind {
366                            VarKind::Global => vars.get(binding.index).cloned(),
367                            VarKind::Local => {
368                                if let Some(frame) = context.call_stack.last() {
369                                    let absolute = frame.locals_start + binding.index;
370                                    context.locals.get(absolute).cloned()
371                                } else {
372                                    vars.get(binding.index).cloned()
373                                }
374                            }
375                        };
376                        if let Some(value) = value_opt {
377                            *slot = Some(value);
378                            continue;
379                        }
380                    }
381                }
382            }
383        }
384        if slot.is_none() {
385            if let Some(value) = plan.const_values.get(&vid) {
386                *slot = Some(value.clone());
387            }
388        }
389    }
390
391    let inputs: Vec<Value> = inputs
392        .into_iter()
393        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
394        .collect::<Result<_, _>>()?;
395
396    if log::log_enabled!(log::Level::Debug) {
397        let summaries: Vec<String> = inputs
398            .iter()
399            .enumerate()
400            .map(|(i, v)| summarize_value(i, v))
401            .collect();
402        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
403    }
404
405    Ok((
406        stack_guard,
407        FusionExecutionRequest { plan, inputs },
408        consumed_inputs,
409    ))
410}
411
412pub fn write_elementwise_materialized_stores(
413    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
414    vars: &mut Vec<Value>,
415    context: &mut ExecutionContext,
416) {
417    for (store, value) in materialized_stores {
418        match store.binding.kind {
419            VarKind::Global => {
420                let i = store.binding.index;
421                if i < vars.len() {
422                    if let Err(err) = accel_residency::clear_value_excluding(&vars[i], &value) {
423                        log::warn!("failed to clear fused global GPU residency: {err}");
424                    }
425                }
426                if i >= vars.len() {
427                    vars.resize(i + 1, Value::Num(0.0));
428                    refresh_workspace_state(vars);
429                }
430                vars[i] = value;
431            }
432            VarKind::Local => {
433                if let Some(frame) = context.call_stack.last() {
434                    let absolute = frame.locals_start + store.binding.index;
435                    while context.locals.len() <= absolute {
436                        context.locals.push(Value::Num(0.0));
437                    }
438                    if let Err(err) =
439                        accel_residency::clear_value_excluding(&context.locals[absolute], &value)
440                    {
441                        log::warn!("failed to clear fused local GPU residency: {err}");
442                    }
443                    context.locals[absolute] = value;
444                } else {
445                    let i = store.binding.index;
446                    if i < vars.len() {
447                        if let Err(err) = accel_residency::clear_value_excluding(&vars[i], &value) {
448                            log::warn!("failed to clear fused fallback GPU residency: {err}");
449                        }
450                    }
451                    if i >= vars.len() {
452                        vars.resize(i + 1, Value::Num(0.0));
453                        refresh_workspace_state(vars);
454                    }
455                    vars[i] = value;
456                }
457            }
458        }
459    }
460}
461
462pub fn execute_fusion_elementwise(
463    request: FusionExecutionRequest<'_>,
464    stack_guard: StackSliceGuard<'_>,
465    vars: &mut Vec<Value>,
466    context: &mut ExecutionContext,
467) -> Result<Value, RuntimeError> {
468    match execute_elementwise(request) {
469        Ok(result) => {
470            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
471            stack_guard.commit();
472            Ok(result.final_value)
473        }
474        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
475    }
476}
477
478pub async fn execute_fusion_special_kind(
479    kind: FusionKind,
480    plan_inputs: &[runmat_accelerate::graph::ValueId],
481    request: FusionExecutionRequest<'_>,
482    stack_guard: StackSliceGuard<'_>,
483) -> Result<Value, RuntimeError> {
484    match kind {
485        FusionKind::CenteredGram => match execute_centered_gram(request).await {
486            Ok(result) => {
487                stack_guard.commit();
488                Ok(result)
489            }
490            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
491        },
492        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
493            Ok(result) => {
494                stack_guard.commit();
495                Ok(result)
496            }
497            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
498        },
499        FusionKind::ExplainedVariance => {
500            log::debug!("explained variance plan inputs {:?}", plan_inputs);
501            match execute_explained_variance(request).await {
502                Ok(result) => {
503                    stack_guard.commit();
504                    Ok(result)
505                }
506                Err(err) => {
507                    log::debug!("explained variance fusion fallback: {}", err);
508                    Err(mex("FusionExecutionFailed", &err.to_string()))
509                }
510            }
511        }
512        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
513            Ok(result) => {
514                stack_guard.commit();
515                Ok(result)
516            }
517            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
518        },
519        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
520            Ok(result) => {
521                stack_guard.commit();
522                Ok(result)
523            }
524            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
525        },
526        _ => Err(mex(
527            "FusionUnsupportedKind",
528            "fusion: unsupported fusion kind",
529        )),
530    }
531}
532
/// Geometry of a fused reduction, as returned by `resolve_reduction_geometry`.
// NOTE(review): the field descriptions below are inferred from the field
// names and from the visible part of `resolve_reduction_geometry`; confirm
// them against the code that constructs and consumes this struct.
pub struct ReductionGeometry {
    /// Axis being reduced — presumably zero-based (the resolver converts
    /// one-based dimension constants with `saturating_sub(1)`).
    pub axis: usize,
    /// Number of elements folded into each result — presumably the total
    /// element count for a reduce-all.
    pub reduce_len: usize,
    /// Number of independent slices reduced — presumably `1` for a
    /// reduce-all.
    pub num_slices: usize,
}
538
539pub fn resolve_reduction_geometry(
540    plan: &runmat_accelerate::FusionGroupPlan,
541    graph: &runmat_accelerate::AccelGraph,
542    request: &FusionExecutionRequest<'_>,
543    consumed_inputs: &[Option<Value>],
544    vars: &[Value],
545    context: &ExecutionContext,
546) -> Result<ReductionGeometry, RuntimeError> {
547    fn detect_reduce_all(
548        plan: &runmat_accelerate::FusionGroupPlan,
549        graph: &runmat_accelerate::AccelGraph,
550    ) -> bool {
551        let mut reduce_all = matches!(
552            plan.reduction_axes,
553            Some(runmat_accelerate::ReductionAxes::All)
554        );
555        let has_all = reduce_all
556            || plan.constants.values().any(value_is_all_keyword)
557            || plan.const_values.values().any(value_is_all_keyword);
558        if has_all {
559            return true;
560        }
561        for node_id in &plan.group.nodes {
562            if let Some(node) = graph.node(*node_id) {
563                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
564                    if name.eq_ignore_ascii_case("mean") {
565                        for input_vid in &node.inputs {
566                            if let Some(info) = graph.value(*input_vid) {
567                                if let Some(constant) = &info.constant {
568                                    if value_is_all_keyword(constant) {
569                                        reduce_all = true;
570                                        break;
571                                    }
572                                }
573                            }
574                        }
575                    }
576                }
577            }
578            if reduce_all {
579                break;
580            }
581        }
582        reduce_all
583    }
584
585    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
586        let mut axis = 0usize;
587        let mut axis_explicit = false;
588        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
589            if let Some(first) = dims.first().copied() {
590                axis = first.saturating_sub(1);
591                axis_explicit = true;
592            }
593        }
594        if let Some(dim_vid) = plan.reduction_dim {
595            if let Some(cv) = plan.const_values.get(&dim_vid) {
596                axis = match cv {
597                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
598                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
599                    _ => axis,
600                };
601                axis_explicit = true;
602            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
603                if let Some(cv) = plan.constants.get(&input_idx) {
604                    axis = match cv {
605                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
606                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
607                        _ => axis,
608                    };
609                    axis_explicit = true;
610                }
611            }
612        } else if let Some(dim_const) = plan.constants.get(&1) {
613            axis = match dim_const {
614                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
615                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
616                _ => axis,
617            };
618            axis_explicit = true;
619        }
620        (axis, axis_explicit)
621    }
622
623    fn derive_rows_cols(
624        plan: &runmat_accelerate::FusionGroupPlan,
625        graph: &runmat_accelerate::AccelGraph,
626        request: &FusionExecutionRequest<'_>,
627        consumed_inputs: &[Option<Value>],
628        vars: &[Value],
629        context: &ExecutionContext,
630    ) -> Option<(usize, usize)> {
631        let shape_of = |value: &Value| -> Option<(usize, usize)> {
632            match value {
633                Value::GpuTensor(h) => Some((
634                    h.shape.first().copied().unwrap_or(1).max(1),
635                    h.shape.get(1).copied().unwrap_or(1).max(1),
636                )),
637                Value::Tensor(t) => Some((
638                    t.shape.first().copied().unwrap_or(1).max(1),
639                    t.shape.get(1).copied().unwrap_or(1).max(1),
640                )),
641                _ => None,
642            }
643        };
644
645        if let Some(shape) = plan.reduction_data_shape(graph) {
646            if shape.len() >= 2 {
647                return Some((shape[0].max(1), shape[1].max(1)));
648            }
649            if shape.len() == 1 {
650                return Some((shape[0].max(1), 1));
651            }
652        }
653
654        for &vid in &plan.inputs {
655            if let Some(binding) = graph.var_binding(vid) {
656                let value_opt = match binding.kind {
657                    VarKind::Global => vars.get(binding.index).cloned(),
658                    VarKind::Local => {
659                        if let Some(frame) = context.call_stack.last() {
660                            let absolute = frame.locals_start + binding.index;
661                            context.locals.get(absolute).cloned()
662                        } else {
663                            vars.get(binding.index).cloned()
664                        }
665                    }
666                };
667                if let Some(value) = value_opt {
668                    if let Some(shape) = shape_of(&value) {
669                        return Some(shape);
670                    }
671                }
672            }
673        }
674
675        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
676            if let Some(shape) = shape_of(v) {
677                return Some(shape);
678            }
679        }
680
681        if let Some(data_id) = plan.reduction_data {
682            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
683                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
684                    if let Some(shape) = shape_of(val) {
685                        return Some(shape);
686                    }
687                }
688                if let Some(val) = request.inputs.get(input_index) {
689                    if let Some(shape) = shape_of(val) {
690                        return Some(shape);
691                    }
692                }
693            }
694            if let Some(info) = graph.value(data_id) {
695                if let ValueOrigin::Variable { kind, index } = &info.origin {
696                    let val = match kind {
697                        VarKind::Global => vars.get(*index).cloned(),
698                        VarKind::Local => {
699                            if let Some(frame) = context.call_stack.last() {
700                                let absolute = frame.locals_start + index;
701                                context.locals.get(absolute).cloned()
702                            } else {
703                                vars.get(*index).cloned()
704                            }
705                        }
706                    };
707                    if let Some(v) = val {
708                        if let Some(shape) = shape_of(&v) {
709                            return Some(shape);
710                        }
711                    }
712                }
713                if let ShapeInfo::Tensor(dims) = &info.shape {
714                    if !dims.is_empty() {
715                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
716                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
717                        return Some((r.max(1), c.max(1)));
718                    }
719                }
720            }
721        }
722
723        for v in &request.inputs {
724            if let Some(shape) = shape_of(v) {
725                return Some(shape);
726            }
727        }
728
729        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
730            if !dims.is_empty() {
731                let r = dims.first().and_then(|d| *d).unwrap_or(1);
732                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
733                return Some((r.max(1), c.max(1)));
734            }
735        }
736        None
737    }
738
739    if log::log_enabled!(log::Level::Debug) {
740        let meta: Vec<String> = plan
741            .inputs
742            .iter()
743            .map(|vid| {
744                if let Some(info) = graph.value(*vid) {
745                    format!(
746                        "vid={} origin={:?} shape={:?}",
747                        vid, info.origin, info.shape
748                    )
749                } else {
750                    format!("vid={} origin=<missing>", vid)
751                }
752            })
753            .collect();
754        log::debug!("reduction gather meta: [{}]", meta.join(", "));
755    }
756
757    let reduce_all = detect_reduce_all(plan, graph);
758    let (mut axis, axis_explicit) = if reduce_all {
759        (0usize, false)
760    } else {
761        resolve_reduction_axis(plan)
762    };
763    if reduce_all && interp_engine::fusion_debug_enabled() {
764        log::debug!(
765            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
766            plan.reduction_data,
767            plan.inputs,
768            plan.stack_pattern
769        );
770    }
771
772    let (r, c) =
773        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
774    let (reduce_len, num_slices) = if reduce_all {
775        let total_from_runtime = consumed_inputs
776            .iter()
777            .filter_map(|v| v.as_ref())
778            .chain(request.inputs.iter())
779            .find_map(|value| match value {
780                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
781                    1
782                } else {
783                    handle
784                        .shape
785                        .iter()
786                        .copied()
787                        .map(|d| d.max(1))
788                        .product::<usize>()
789                }),
790                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
791                    1
792                } else {
793                    tensor
794                        .shape
795                        .iter()
796                        .copied()
797                        .map(|d| d.max(1))
798                        .product::<usize>()
799                }),
800                _ => None,
801            });
802        let total = plan
803            .reduction_data_shape(graph)
804            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
805            .or(total_from_runtime)
806            .or_else(|| plan.element_count())
807            .filter(|v| *v > 0)
808            .ok_or_else(|| {
809                mex(
810                    "FusionReductionExtentUnknown",
811                    "fusion: reduction all extent unknown",
812                )
813            })?;
814        if interp_engine::fusion_debug_enabled() {
815            log::debug!(
816                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
817                total,
818                r,
819                c
820            );
821        }
822        (total, 1usize)
823    } else {
824        if !axis_explicit {
825            axis = if r == 1 && c > 1 {
826                1
827            } else if r > 1 {
828                0
829            } else {
830                axis
831            };
832        }
833        if interp_engine::fusion_debug_enabled() {
834            if r == 1 && c == 1 {
835                log::debug!(
836                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
837                    axis,
838                    plan.constants
839                );
840            } else {
841                log::debug!(
842                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
843                    r,
844                    c,
845                    axis,
846                    plan.constants
847                );
848            }
849        }
850        if axis == 0 {
851            (r, c)
852        } else {
853            (c, r)
854        }
855    };
856
857    if interp_engine::fusion_debug_enabled() {
858        log::debug!(
859            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
860            axis,
861            reduce_len,
862            num_slices,
863            plan.constants
864        );
865    }
866
867    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
868        let mut big = false;
869        let mut check_val = |v: &Value| match v {
870            Value::GpuTensor(h) => {
871                let prod = h.shape.iter().copied().product::<usize>();
872                if prod > 1 {
873                    big = true;
874                }
875            }
876            Value::Tensor(t) => {
877                let prod = t.shape.iter().copied().product::<usize>();
878                if prod > 1 {
879                    big = true;
880                }
881            }
882            _ => {}
883        };
884        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
885            check_val(v);
886        }
887        for v in &request.inputs {
888            check_val(v);
889        }
890        big
891    };
892    if looks_wrong {
893        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
894        return Err(mex(
895            "FusionReductionShapeUnresolved",
896            "fusion: reduction shape unresolved",
897        ));
898    }
899    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
900        .ok()
901        .as_deref()
902        == Some("1")
903    {
904        return Err(mex(
905            "FusionReductionDisabled",
906            "fusion: fused reductions disabled",
907        ));
908    }
909
910    Ok(ReductionGeometry {
911        axis,
912        reduce_len,
913        num_slices,
914    })
915}
916
917pub fn execute_fusion_reduction(
918    plan: &runmat_accelerate::FusionGroupPlan,
919    graph: &runmat_accelerate::AccelGraph,
920    request: FusionExecutionRequest<'_>,
921    consumed_inputs: &[Option<Value>],
922    stack_guard: StackSliceGuard<'_>,
923    vars: &[Value],
924    context: &ExecutionContext,
925) -> Result<Value, RuntimeError> {
926    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
927    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
928        Ok(result) => {
929            stack_guard.commit();
930            Ok(result)
931        }
932        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
933    }
934}
935
936pub async fn try_execute_fusion_group(
937    plan: &runmat_accelerate::FusionGroupPlan,
938    graph: &runmat_accelerate::AccelGraph,
939    stack: &mut Vec<Value>,
940    vars: &mut Vec<Value>,
941    context: &mut ExecutionContext,
942) -> Result<Value, RuntimeError> {
943    let (stack_guard, request, consumed_inputs) =
944        gather_fusion_inputs(plan, graph, stack, vars, context)?;
945    if plan.group.kind.is_elementwise()
946        && !request.inputs.is_empty()
947        && request.inputs.iter().all(is_scalarish_runtime_value)
948    {
949        return Err(mex(
950            "FusionScalarBypass",
951            "fusion: bypass scalar-only elementwise group",
952        ));
953    }
954    log::debug!(
955        "dispatch fusion kind {:?}, supported {}",
956        plan.group.kind,
957        plan.kernel.supported
958    );
959    if plan.group.kind.is_elementwise() {
960        execute_fusion_elementwise(request, stack_guard, vars, context)
961    } else if plan.group.kind.is_reduction() {
962        execute_fusion_reduction(
963            plan,
964            graph,
965            request,
966            &consumed_inputs,
967            stack_guard,
968            vars,
969            context,
970        )
971    } else {
972        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
973            .await
974    }
975}
976
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;

    /// Writing a materialized store over a variable that previously held two
    /// GPU handles must keep residency for the handle that is written back
    /// (`shared`) and drop residency for the one that is no longer referenced
    /// (`old_only`).
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        // Both handles live on the same device and differ only by buffer id.
        let make_handle = |buffer_id| GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id,
        };
        let shared = make_handle(17001);
        let old_only = make_handle(17002);

        for handle in [&shared, &old_only] {
            fusion_residency::mark(handle);
        }
        assert!(fusion_residency::is_resident(&shared));
        assert!(fusion_residency::is_resident(&old_only));

        // Slot 0 starts out holding both handles.
        let previous = Value::OutputList(vec![
            Value::GpuTensor(shared.clone()),
            Value::GpuTensor(old_only.clone()),
        ]);
        let mut vars = vec![previous];
        let mut context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: std::collections::HashSet::new(),
            next_spawn_task_id: 0,
        };

        // Overwrite global slot 0 with just the shared handle.
        let store = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        let replacement = Value::GpuTensor(shared.clone());
        write_elementwise_materialized_stores(vec![(store, replacement)], &mut vars, &mut context);

        assert!(fusion_residency::is_resident(&shared));
        assert!(!fusion_residency::is_resident(&old_only));

        // Undo the residency mark this test left on `shared`.
        fusion_residency::clear(&shared);
    }
}