runmat_vm/accel/fusion.rs

use crate::accel::residency as accel_residency;
use crate::bytecode::ExecutionContext;
use crate::bytecode::Instr;
use crate::interpreter::engine as interp_engine;
use crate::interpreter::errors::mex;
use crate::runtime::workspace::refresh_workspace_state;
use runmat_accelerate::fusion::FusionStoreMaterialization;
use runmat_accelerate::fusion_exec::{
    execute_centered_gram, execute_elementwise, execute_explained_variance,
    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
    execute_reduction, FusionExecutionRequest,
};
use runmat_accelerate::InstrSpan;
use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
use runmat_builtins::Value;
use runmat_runtime::RuntimeError;
use std::collections::HashMap;

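/// Maps a `Value` to the short static tag of its variant; used when
/// summarizing fusion operands in debug logs.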
#[inline]
pub fn value_kind(value: &Value) -> &'static str {
    match value {
        Value::Int(_) => "Int",
        Value::Num(_) => "Num",
        Value::Complex(_, _) => "Complex",
        Value::Bool(_) => "Bool",
        Value::LogicalArray(_) => "LogicalArray",
        Value::String(_) => "String",
        Value::StringArray(_) => "StringArray",
        Value::CharArray(_) => "CharArray",
        Value::Tensor(_) => "Tensor",
        Value::ComplexTensor(_) => "ComplexTensor",
        Value::Cell(_) => "Cell",
        Value::Struct(_) => "Struct",
        Value::GpuTensor(_) => "GpuTensor",
        Value::Object(_) => "Object",
        Value::HandleObject(_) => "HandleObject",
        Value::Listener(_) => "Listener",
        Value::FunctionHandle(_) => "FunctionHandle",
        Value::Closure(_) => "Closure",
        Value::ClassRef(_) => "ClassRef",
        Value::MException(_) => "MException",
        Value::OutputList(_) => "OutputList",
    }
}

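/// Renders a one-line summary of input `i` for debug logging: shapes for
/// tensors, payloads for scalars and strings, and the variant tag otherwise.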
#[inline]
pub fn summarize_value(i: usize, v: &Value) -> String {
    match v {
        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
        Value::String(s) => format!("in#{i}:String({})", s),
        _ => format!("in#{i}:{}", value_kind(v)),
    }
}

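/// Simulates the stack effect of every instruction in `span` and returns the
/// number of values the span leaves live on the stack (padding the simulated
/// stack whenever an instruction pops more than is available). Returns
/// `None` if the span is out of range or contains an instruction whose stack
/// effect is unknown.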
pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
    if span.start > span.end || span.end >= instructions.len() {
        return None;
    }
    let mut current_depth = 0usize;
    for instr in &instructions[span.start..=span.end] {
        let effect = instr.stack_effect()?;
        if current_depth < effect.pops {
            current_depth = effect.pops;
        }
        current_depth = current_depth - effect.pops + effect.pushes;
    }
    Some(current_depth)
}

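/// Returns `true` when the span cannot be treated as a single fusible
/// expression: it contains an indexed/member store (a VM-visible side
/// effect), or it does not leave exactly one live result on the stack.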
pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
    if span.start > span.end || span.end >= instructions.len() {
        return true;
    }
    for instr in &instructions[span.start..=span.end] {
        if matches!(
            instr,
            Instr::StoreIndex(_)
                | Instr::StoreSlice(_, _, _, _)
                | Instr::StoreSliceExpr { .. }
                | Instr::StoreIndexCell(_)
                | Instr::StoreMember(_)
                | Instr::StoreMemberOrInit(_)
                | Instr::StoreMemberDynamic
                | Instr::StoreMemberDynamicOrInit
        ) {
            return true;
        }
    }
    fusion_span_live_result_count(instructions, span) != Some(1)
}

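/// RAII guard over the operand tail of the VM value stack. `new` splits the
/// operands off the stack; if the guard is dropped without `commit` (e.g.
/// the fusion failed), `Drop` pushes them back so the interpreter can fall
/// back to the unfused path. `commit` discards the slice, consuming the
/// operands permanently.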
pub struct StackSliceGuard<'a> {
    stack: *mut Vec<Value>,
    slice: Option<Vec<Value>>,
    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
}

impl<'a> StackSliceGuard<'a> {
    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
        let slice = stack.split_off(slice_start);
        Self {
            stack,
            slice: Some(slice),
            _marker: std::marker::PhantomData,
        }
    }

    pub fn slice(&self) -> &[Value] {
        self.slice.as_ref().expect("stack slice missing").as_slice()
    }

    pub fn commit(mut self) {
        self.slice = None;
    }
}

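// SAFETY: the raw pointer is only dereferenced while the guard is alive, and
// the `PhantomData<&'a mut Vec<Value>>` keeps the originating mutable borrow
// of the stack alive for `'a`, so the pointer cannot dangle here.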
impl Drop for StackSliceGuard<'_> {
    fn drop(&mut self) {
        if let Some(slice) = self.slice.take() {
            unsafe { (&mut *self.stack).extend(slice) }
        }
    }
}

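/// Gathers the runtime values for every input slot of a fusion group plan.
///
/// Slots are filled, in order of precedence, from the plan's compile-time
/// constants, variable bindings recorded in the accel graph, the operand
/// slice split off the top of the VM stack, and finally any remaining
/// constant, variable, or node-output binding. Returns the stack guard (so
/// operands are restored if execution fails), the populated execution
/// request, and the values actually consumed from the stack, which later
/// shape resolution can inspect.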
pub fn gather_fusion_inputs<'a>(
    plan: &'a runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    stack: &'a mut Vec<Value>,
    vars: &mut [Value],
    context: &mut ExecutionContext,
) -> Result<
    (
        StackSliceGuard<'a>,
        FusionExecutionRequest<'a>,
        Vec<Option<Value>>,
    ),
    RuntimeError,
> {
    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
        return Err(mex(
            "FusionMissingStackLayout",
            "fusion: missing compile-time stack layout metadata",
        ));
    }
    let required_stack_operands = plan
        .group
        .stack_layout
        .as_ref()
        .map(|layout| layout.required_stack_operands)
        .unwrap_or_else(|| plan.stack_pattern.len());
    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];

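    // Seed input slots from the plan's compile-time constants, then resolve
    // slots whose origin is a variable binding (globals live in `vars`,
    // locals in the current call frame).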
    for (idx, value) in &plan.constants {
        if let Some(slot) = inputs.get_mut(*idx) {
            if slot.is_none() {
                *slot = Some(value.clone());
            }
        }
    }

    for (idx, value_id) in plan.inputs.iter().enumerate() {
        let info = graph
            .value(*value_id)
            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
        match &info.origin {
            ValueOrigin::Variable { kind, index } => {
                let value =
                    match kind {
                        VarKind::Global => vars
                            .get(*index)
                            .cloned()
                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned().ok_or_else(|| {
                                    format!("fusion: local var {index} unavailable")
                                })?
                            } else {
                                vars.get(*index).cloned().ok_or_else(|| {
                                    format!("fusion: local var {index} unavailable")
                                })?
                            }
                        }
                    };
                debug_assert!(
                    inputs[idx].is_none(),
                    "fusion: duplicate input slot {} for plan {}",
                    idx,
                    plan.index
                );
                inputs[idx] = Some(value);
            }
            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
        }
    }

    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
        let stack_needed_preview = required_stack_operands;
        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
        let stack_kinds: Vec<&'static str> =
            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
        let input_meta: Vec<String> = plan
            .inputs
            .iter()
            .enumerate()
            .map(|(i, value_id)| {
                if let Some(info) = graph.value(*value_id) {
                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
                } else {
                    format!("#{i}:id={} origin=<missing>", value_id)
                }
            })
            .collect();
        log::debug!(
            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
        );
    }

    if stack.len() < required_stack_operands {
        if interp_engine::fusion_debug_enabled() {
            log::debug!(
                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
                plan.index,
                required_stack_operands,
                stack.len(),
                plan.stack_pattern
            );
        }
        return Err(mex(
            "FusionStackUnderflow",
            "fusion: stack underflow gathering inputs",
        ));
    }
    let available = required_stack_operands;
    let slice_start = stack.len() - available;
    let stack_guard = StackSliceGuard::new(stack, slice_start);
    let slice = stack_guard.slice().to_vec();
    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
        .inputs
        .iter()
        .enumerate()
        .map(|(idx, value_id)| (*value_id, idx))
        .collect();

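    // Bind stack operands to their input slots. For reductions, only tensor
    // values are adopted from the stack; scalar stack operands (such as dim
    // arguments) are resolved from constants or bindings instead.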
    let allow_stack_value = |val: &Value| {
        if plan.group.kind.is_reduction() {
            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
        } else {
            true
        }
    };

    if let Some(layout) = plan.group.stack_layout.as_ref() {
        for binding in &layout.bindings {
            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
                continue;
            };
            let Some(val) = slice.get(binding.stack_offset).cloned() else {
                continue;
            };
            consumed_inputs[input_idx] = Some(val.clone());
            if inputs[input_idx].is_none() && allow_stack_value(&val) {
                inputs[input_idx] = Some(val);
            }
        }
    } else {
        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
            let Some(val) = slice.get(offset).cloned() else {
                continue;
            };
            consumed_inputs[*input_idx] = Some(val.clone());
            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
                inputs[*input_idx] = Some(val);
            }
        }
    }

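    // Last-chance resolution for slots that are still empty: re-check the
    // value's variable origin, the graph's variable and node-output bindings,
    // and finally the plan's constant table.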
    for (idx, slot) in inputs.iter_mut().enumerate() {
        if slot.is_some() {
            continue;
        }
        let vid = plan.inputs[idx];
        let info = graph.value(vid);
        if let Some(info) = info {
            match &info.origin {
                ValueOrigin::Variable { kind, index } => {
                    let value_opt = match kind {
                        VarKind::Global => vars.get(*index).cloned(),
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned()
                            } else {
                                vars.get(*index).cloned()
                            }
                        }
                    };
                    if let Some(value) = value_opt {
                        *slot = Some(value);
                        continue;
                    }
                }
                ValueOrigin::Constant => {
                    if let Some(value) = plan.const_values.get(&vid) {
                        *slot = Some(value.clone());
                        continue;
                    }
                }
                _ => {}
            }
        }
        if slot.is_none() {
            if let Some(binding) = graph.var_binding(vid) {
                let value_opt = match binding.kind {
                    VarKind::Global => vars.get(binding.index).cloned(),
                    VarKind::Local => {
                        if let Some(frame) = context.call_stack.last() {
                            let absolute = frame.locals_start + binding.index;
                            context.locals.get(absolute).cloned()
                        } else {
                            vars.get(binding.index).cloned()
                        }
                    }
                };
                if let Some(value) = value_opt {
                    *slot = Some(value);
                    continue;
                }
            }
        }
        if slot.is_none() {
            if let Some(info) = info {
                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
                    if let Some(binding) = graph.node_binding(node) {
                        let value_opt = match binding.kind {
                            VarKind::Global => vars.get(binding.index).cloned(),
                            VarKind::Local => {
                                if let Some(frame) = context.call_stack.last() {
                                    let absolute = frame.locals_start + binding.index;
                                    context.locals.get(absolute).cloned()
                                } else {
                                    vars.get(binding.index).cloned()
                                }
                            }
                        };
                        if let Some(value) = value_opt {
                            *slot = Some(value);
                            continue;
                        }
                    }
                }
            }
        }
        if slot.is_none() {
            if let Some(value) = plan.const_values.get(&vid) {
                *slot = Some(value.clone());
            }
        }
    }

    let inputs: Vec<Value> = inputs
        .into_iter()
        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
        .collect::<Result<_, _>>()?;

    if log::log_enabled!(log::Level::Debug) {
        let summaries: Vec<String> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| summarize_value(i, v))
            .collect();
        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
    }

    Ok((
        stack_guard,
        FusionExecutionRequest { plan, inputs },
        consumed_inputs,
    ))
}

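/// Writes each materialized store produced by an element-wise fusion back to
/// its variable binding, releasing any GPU residency held by the overwritten
/// value (unless it is the same GPU handle) and growing the variable storage
/// when the binding index is out of range.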
pub fn write_elementwise_materialized_stores(
    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) {
    for (store, value) in materialized_stores {
        match store.binding.kind {
            VarKind::Global => {
                let i = store.binding.index;
                if i < vars.len() && !accel_residency::same_gpu_handle(&vars[i], &value) {
                    accel_residency::clear_value(&vars[i]);
                }
                if i >= vars.len() {
                    vars.resize(i + 1, Value::Num(0.0));
                    refresh_workspace_state(vars);
                }
                vars[i] = value;
            }
            VarKind::Local => {
                if let Some(frame) = context.call_stack.last() {
                    let absolute = frame.locals_start + store.binding.index;
                    while context.locals.len() <= absolute {
                        context.locals.push(Value::Num(0.0));
                    }
                    if !accel_residency::same_gpu_handle(&context.locals[absolute], &value) {
                        accel_residency::clear_value(&context.locals[absolute]);
                    }
                    context.locals[absolute] = value;
                } else {
                    let i = store.binding.index;
                    if i < vars.len() && !accel_residency::same_gpu_handle(&vars[i], &value) {
                        accel_residency::clear_value(&vars[i]);
                    }
                    if i >= vars.len() {
                        vars.resize(i + 1, Value::Num(0.0));
                        refresh_workspace_state(vars);
                    }
                    vars[i] = value;
                }
            }
        }
    }
}

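/// Runs an element-wise fusion group: on success, writes materialized stores
/// back to their variables, commits the stack guard (permanently consuming
/// the operands), and returns the group's single live result.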
pub fn execute_fusion_elementwise(
    request: FusionExecutionRequest<'_>,
    stack_guard: StackSliceGuard<'_>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) -> Result<Value, RuntimeError> {
    match execute_elementwise(request) {
        Ok(result) => {
            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
            stack_guard.commit();
            Ok(result.final_value)
        }
        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
    }
}

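/// Dispatches the pattern-matched fusion kinds (centered Gram, power-step
/// normalize, explained variance, matmul epilogue, image normalize) to their
/// dedicated executors. The stack guard is committed only on success, so a
/// failed kernel leaves the operands on the stack for the fallback path.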
pub async fn execute_fusion_special_kind(
    kind: FusionKind,
    plan_inputs: &[runmat_accelerate::graph::ValueId],
    request: FusionExecutionRequest<'_>,
    stack_guard: StackSliceGuard<'_>,
) -> Result<Value, RuntimeError> {
    match kind {
        FusionKind::CenteredGram => match execute_centered_gram(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::ExplainedVariance => {
            log::debug!("explained variance plan inputs {:?}", plan_inputs);
            match execute_explained_variance(request).await {
                Ok(result) => {
                    stack_guard.commit();
                    Ok(result)
                }
                Err(err) => {
                    log::debug!("explained variance fusion fallback: {}", err);
                    Err(mex("FusionExecutionFailed", &err.to_string()))
                }
            }
        }
        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        _ => Err(mex(
            "FusionUnsupportedKind",
            "fusion: unsupported fusion kind",
        )),
    }
}

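/// Geometry of a fused reduction: the 0-based `axis` being reduced, the
/// number of elements reduced per slice, and the number of independent
/// slices.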
pub struct ReductionGeometry {
    pub axis: usize,
    pub reduce_len: usize,
    pub num_slices: usize,
}

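/// Resolves the reduction geometry for a fused reduction group, combining
/// compile-time metadata from the plan and graph with runtime shapes
/// recovered from the consumed stack operands and variable bindings. Fails
/// (so the caller can fall back to the provider path) when the extent cannot
/// be determined or when fused reductions are disabled via
/// `RUNMAT_DISABLE_FUSED_REDUCTION=1`.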
pub fn resolve_reduction_geometry(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    request: &FusionExecutionRequest<'_>,
    consumed_inputs: &[Option<Value>],
    vars: &[Value],
    context: &ExecutionContext,
) -> Result<ReductionGeometry, RuntimeError> {
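    // A reduction collapses all elements when the plan records
    // `ReductionAxes::All`, when any constant operand is the "all" keyword,
    // or when a fused `mean` node takes "all" as a constant input.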
    fn detect_reduce_all(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
    ) -> bool {
        let mut reduce_all = matches!(
            plan.reduction_axes,
            Some(runmat_accelerate::ReductionAxes::All)
        );
        let has_all = reduce_all
            || plan.constants.values().any(value_is_all_keyword)
            || plan.const_values.values().any(value_is_all_keyword);
        if has_all {
            return true;
        }
        for node_id in &plan.group.nodes {
            if let Some(node) = graph.node(*node_id) {
                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
                    if name.eq_ignore_ascii_case("mean") {
                        for input_vid in &node.inputs {
                            if let Some(info) = graph.value(*input_vid) {
                                if let Some(constant) = &info.constant {
                                    if value_is_all_keyword(constant) {
                                        reduce_all = true;
                                        break;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            if reduce_all {
                break;
            }
        }
        reduce_all
    }

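    // Resolve the 0-based reduction axis from explicit axis metadata, the
    // plan's `reduction_dim` value (as a plan constant or a constant input
    // slot), or the conventional dim constant in input slot 1. Also reports
    // whether the axis was explicit or merely defaulted.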
    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
        let mut axis = 0usize;
        let mut axis_explicit = false;
        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
            if let Some(first) = dims.first().copied() {
                axis = first.saturating_sub(1);
                axis_explicit = true;
            }
        }
        if let Some(dim_vid) = plan.reduction_dim {
            if let Some(cv) = plan.const_values.get(&dim_vid) {
                axis = match cv {
                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                    _ => axis,
                };
                axis_explicit = true;
            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
                if let Some(cv) = plan.constants.get(&input_idx) {
                    axis = match cv {
                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                        _ => axis,
                    };
                    axis_explicit = true;
                }
            }
        } else if let Some(dim_const) = plan.constants.get(&1) {
            axis = match dim_const {
                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                _ => axis,
            };
            axis_explicit = true;
        }
        (axis, axis_explicit)
    }

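    // Best-effort recovery of the reduced operand's (rows, cols): plan shape
    // metadata first, then live variable bindings, values consumed from the
    // stack, the request inputs, and finally the group's inferred shape.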
    fn derive_rows_cols(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
        request: &FusionExecutionRequest<'_>,
        consumed_inputs: &[Option<Value>],
        vars: &[Value],
        context: &ExecutionContext,
    ) -> Option<(usize, usize)> {
        let shape_of = |value: &Value| -> Option<(usize, usize)> {
            match value {
                Value::GpuTensor(h) => Some((
                    h.shape.first().copied().unwrap_or(1).max(1),
                    h.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                Value::Tensor(t) => Some((
                    t.shape.first().copied().unwrap_or(1).max(1),
                    t.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                _ => None,
            }
        };

        if let Some(shape) = plan.reduction_data_shape(graph) {
            if shape.len() >= 2 {
                return Some((shape[0].max(1), shape[1].max(1)));
            }
            if shape.len() == 1 {
                return Some((shape[0].max(1), 1));
            }
        }

        for &vid in &plan.inputs {
            if let Some(binding) = graph.var_binding(vid) {
                let value_opt = match binding.kind {
                    VarKind::Global => vars.get(binding.index).cloned(),
                    VarKind::Local => {
                        if let Some(frame) = context.call_stack.last() {
                            let absolute = frame.locals_start + binding.index;
                            context.locals.get(absolute).cloned()
                        } else {
                            vars.get(binding.index).cloned()
                        }
                    }
                };
                if let Some(value) = value_opt {
                    if let Some(shape) = shape_of(&value) {
                        return Some(shape);
                    }
                }
            }
        }

        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        if let Some(data_id) = plan.reduction_data {
            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
                if let Some(val) = request.inputs.get(input_index) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
            }
            if let Some(info) = graph.value(data_id) {
                if let ValueOrigin::Variable { kind, index } = &info.origin {
                    let val = match kind {
                        VarKind::Global => vars.get(*index).cloned(),
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned()
                            } else {
                                vars.get(*index).cloned()
                            }
                        }
                    };
                    if let Some(v) = val {
                        if let Some(shape) = shape_of(&v) {
                            return Some(shape);
                        }
                    }
                }
                if let ShapeInfo::Tensor(dims) = &info.shape {
                    if !dims.is_empty() {
                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                        return Some((r.max(1), c.max(1)));
                    }
                }
            }
        }

        for v in &request.inputs {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
            if !dims.is_empty() {
                let r = dims.first().and_then(|d| *d).unwrap_or(1);
                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                return Some((r.max(1), c.max(1)));
            }
        }
        None
    }

    if log::log_enabled!(log::Level::Debug) {
        let meta: Vec<String> = plan
            .inputs
            .iter()
            .map(|vid| {
                if let Some(info) = graph.value(*vid) {
                    format!(
                        "vid={} origin={:?} shape={:?}",
                        vid, info.origin, info.shape
                    )
                } else {
                    format!("vid={} origin=<missing>", vid)
                }
            })
            .collect();
        log::debug!("reduction gather meta: [{}]", meta.join(", "));
    }

    let reduce_all = detect_reduce_all(plan, graph);
    let (mut axis, axis_explicit) = if reduce_all {
        (0usize, false)
    } else {
        resolve_reduction_axis(plan)
    };
    if reduce_all && interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
            plan.reduction_data,
            plan.inputs,
            plan.stack_pattern
        );
    }

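    // Derive (reduce_len, num_slices): for a reduce-all the extent is the
    // total element count; otherwise, when no axis was explicit, infer it
    // from the operand's orientation (row vectors reduce along columns).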
    let (r, c) =
        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
    let (reduce_len, num_slices) = if reduce_all {
        let total_from_runtime = consumed_inputs
            .iter()
            .filter_map(|v| v.as_ref())
            .chain(request.inputs.iter())
            .find_map(|value| match value {
                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
                    1
                } else {
                    handle
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
                    1
                } else {
                    tensor
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                _ => None,
            });
        let total = plan
            .reduction_data_shape(graph)
            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
            .or(total_from_runtime)
            .or_else(|| plan.element_count())
            .filter(|v| *v > 0)
            .ok_or_else(|| {
                mex(
                    "FusionReductionExtentUnknown",
                    "fusion: reduction all extent unknown",
                )
            })?;
        if interp_engine::fusion_debug_enabled() {
            log::debug!(
                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
                total,
                r,
                c
            );
        }
        (total, 1usize)
    } else {
        if !axis_explicit {
            axis = if r == 1 && c > 1 {
                1
            } else if r > 1 {
                0
            } else {
                axis
            };
        }
        if interp_engine::fusion_debug_enabled() {
            if r == 1 && c == 1 {
                log::debug!(
                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
                    axis,
                    plan.constants
                );
            } else {
                log::debug!(
                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
                    r,
                    c,
                    axis,
                    plan.constants
                );
            }
        }
        if axis == 0 {
            (r, c)
        } else {
            (c, r)
        }
    };

    if interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
            axis,
            reduce_len,
            num_slices,
            plan.constants
        );
    }

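    // Sanity check: a 1x1 geometry paired with a multi-element operand means
    // shape resolution failed; bail out so the caller falls back to the
    // unfused provider path instead of reducing a single element.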
    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
        let mut big = false;
        let mut check_val = |v: &Value| match v {
            Value::GpuTensor(h) => {
                let prod = h.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            Value::Tensor(t) => {
                let prod = t.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            _ => {}
        };
        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            check_val(v);
        }
        for v in &request.inputs {
            check_val(v);
        }
        big
    };
    if looks_wrong {
        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
        return Err(mex(
            "FusionReductionShapeUnresolved",
            "fusion: reduction shape unresolved",
        ));
    }
    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
        .ok()
        .as_deref()
        == Some("1")
    {
        return Err(mex(
            "FusionReductionDisabled",
            "fusion: fused reductions disabled",
        ));
    }

    Ok(ReductionGeometry {
        axis,
        reduce_len,
        num_slices,
    })
}

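/// Runs a fused reduction: resolves the geometry, then dispatches the kernel
/// with a fixed workgroup size of 256. The stack guard is committed only on
/// success.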
pub fn execute_fusion_reduction(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    request: FusionExecutionRequest<'_>,
    consumed_inputs: &[Option<Value>],
    stack_guard: StackSliceGuard<'_>,
    vars: &[Value],
    context: &ExecutionContext,
) -> Result<Value, RuntimeError> {
    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
        Ok(result) => {
            stack_guard.commit();
            Ok(result)
        }
        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
    }
}

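/// Entry point used when the interpreter reaches a fusion group: gathers
/// inputs (splitting the operands off the stack behind a guard), then
/// dispatches on the group kind to the element-wise, reduction, or
/// special-kind executor. On error, the un-committed guard restores the
/// stack so the caller can run the unfused fallback.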
pub async fn try_execute_fusion_group(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    stack: &mut Vec<Value>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) -> Result<Value, RuntimeError> {
    let (stack_guard, request, consumed_inputs) =
        gather_fusion_inputs(plan, graph, stack, vars, context)?;
    log::debug!(
        "dispatch fusion kind {:?}, supported {}",
        plan.group.kind,
        plan.kernel.supported
    );
    if plan.group.kind.is_elementwise() {
        execute_fusion_elementwise(request, stack_guard, vars, context)
    } else if plan.group.kind.is_reduction() {
        execute_fusion_reduction(
            plan,
            graph,
            request,
            &consumed_inputs,
            stack_guard,
            vars,
            context,
        )
    } else {
        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
            .await
    }
}