1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9 execute_centered_gram, execute_elementwise, execute_explained_variance,
10 execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11 execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22 match value {
23 Value::Int(_) => "Int",
24 Value::Num(_) => "Num",
25 Value::Complex(_, _) => "Complex",
26 Value::Bool(_) => "Bool",
27 Value::LogicalArray(_) => "LogicalArray",
28 Value::String(_) => "String",
29 Value::StringArray(_) => "StringArray",
30 Value::Symbolic(_) => "Symbolic",
31 Value::CharArray(_) => "CharArray",
32 Value::Tensor(_) => "Tensor",
33 Value::SparseTensor(_) => "SparseTensor",
34 Value::ComplexTensor(_) => "ComplexTensor",
35 Value::Cell(_) => "Cell",
36 Value::Struct(_) => "Struct",
37 Value::GpuTensor(_) => "GpuTensor",
38 Value::Object(_) => "Object",
39 Value::HandleObject(_) => "HandleObject",
40 Value::Listener(_) => "Listener",
41 Value::FunctionHandle(_)
42 | Value::ExternalFunctionHandle(_)
43 | Value::MethodFunctionHandle(_) => "FunctionHandle",
44 Value::BoundFunctionHandle { .. } => "FunctionHandle",
45 Value::Closure(_) => "Closure",
46 Value::ClassRef(_) => "ClassRef",
47 Value::MException(_) => "MException",
48 Value::OutputList(_) => "OutputList",
49 }
50}
51
52#[inline]
53pub fn summarize_value(i: usize, v: &Value) -> String {
54 match v {
55 Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
56 Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
57 Value::Num(n) => format!("in#{i}:Num({n:.6})"),
58 Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
59 Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
60 Value::String(s) => format!("in#{i}:String({})", s),
61 _ => format!("in#{i}:{}", value_kind(v)),
62 }
63}
64
65#[inline]
66fn is_scalarish_runtime_value(value: &Value) -> bool {
67 match value {
68 Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
69 Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
70 Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
71 Value::LogicalArray(array) => is_scalar_shape(&array.shape),
72 Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
73 Value::CharArray(array) => array.rows * array.cols == 1,
74 _ => false,
75 }
76}
77
78pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
79 if span.start > span.end || span.end >= instructions.len() {
80 return None;
81 }
82 let mut current_depth = 0usize;
83 for instr in &instructions[span.start..=span.end] {
84 let effect = instr.stack_effect()?;
85 if current_depth < effect.pops {
86 current_depth = effect.pops;
87 }
88 current_depth = current_depth - effect.pops + effect.pushes;
89 }
90 Some(current_depth)
91}
92
93pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
94 if span.start > span.end || span.end >= instructions.len() {
95 return true;
96 }
97 for instr in &instructions[span.start..=span.end] {
98 if matches!(
99 instr,
100 Instr::StoreIndex(_)
101 | Instr::StoreIndexDelete(_)
102 | Instr::StoreSlice(_, _, _, _)
103 | Instr::StoreSliceDelete(_, _, _, _)
104 | Instr::StoreSliceExpr { .. }
105 | Instr::StoreSliceExprDelete { .. }
106 | Instr::StoreIndexCell { .. }
107 | Instr::StoreIndexCellDelete { .. }
108 | Instr::StoreMember(_)
109 | Instr::StoreMemberOrInit(_)
110 | Instr::StoreMemberDynamic
111 | Instr::StoreMemberDynamicOrInit
112 ) {
113 return true;
114 }
115 }
116 fusion_span_live_result_count(instructions, span) != Some(1)
117}
118
119pub struct StackSliceGuard<'a> {
120 stack: *mut Vec<Value>,
121 slice: Option<Vec<Value>>,
122 _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
123}
124
125impl<'a> StackSliceGuard<'a> {
126 pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
127 let slice = stack.split_off(slice_start);
128 Self {
129 stack,
130 slice: Some(slice),
131 _marker: std::marker::PhantomData,
132 }
133 }
134
135 pub fn slice(&self) -> &[Value] {
136 self.slice.as_ref().expect("stack slice missing").as_slice()
137 }
138
139 pub fn commit(mut self) {
140 self.slice = None;
141 }
142}
143
144impl Drop for StackSliceGuard<'_> {
145 fn drop(&mut self) {
146 if let Some(slice) = self.slice.take() {
147 unsafe { (&mut *self.stack).extend(slice) }
148 }
149 }
150}
151
152pub fn gather_fusion_inputs<'a>(
153 plan: &'a runmat_accelerate::FusionGroupPlan,
154 graph: &runmat_accelerate::AccelGraph,
155 stack: &'a mut Vec<Value>,
156 vars: &mut [Value],
157 context: &mut ExecutionContext,
158) -> Result<
159 (
160 StackSliceGuard<'a>,
161 FusionExecutionRequest<'a>,
162 Vec<Option<Value>>,
163 ),
164 RuntimeError,
165> {
166 if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
167 return Err(mex(
168 "FusionMissingStackLayout",
169 "fusion: missing compile-time stack layout metadata",
170 ));
171 }
172 let required_stack_operands = plan
173 .group
174 .stack_layout
175 .as_ref()
176 .map(|layout| layout.required_stack_operands)
177 .unwrap_or_else(|| plan.stack_pattern.len());
178 let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
179
180 for (idx, value) in &plan.constants {
181 if let Some(slot) = inputs.get_mut(*idx) {
182 if slot.is_none() {
183 *slot = Some(value.clone());
184 }
185 }
186 }
187
188 for (idx, value_id) in plan.inputs.iter().enumerate() {
189 let info = graph
190 .value(*value_id)
191 .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
192 match &info.origin {
193 ValueOrigin::Variable { kind, index } => {
194 let value =
195 match kind {
196 VarKind::Global => vars
197 .get(*index)
198 .cloned()
199 .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
200 VarKind::Local => {
201 if let Some(frame) = context.call_stack.last() {
202 let absolute = frame.locals_start + index;
203 context.locals.get(absolute).cloned().ok_or_else(|| {
204 format!("fusion: local var {index} unavailable")
205 })?
206 } else {
207 vars.get(*index).cloned().ok_or_else(|| {
208 format!("fusion: local var {index} unavailable")
209 })?
210 }
211 }
212 };
213 debug_assert!(
214 inputs[idx].is_none(),
215 "fusion: duplicate input slot {} for plan {}",
216 idx,
217 plan.index
218 );
219 inputs[idx] = Some(value);
220 }
221 ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
222 }
223 }
224
225 if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
226 let stack_needed_preview = required_stack_operands;
227 let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
228 let stack_kinds: Vec<&'static str> =
229 stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
230 let input_meta: Vec<String> = plan
231 .inputs
232 .iter()
233 .enumerate()
234 .map(|(i, value_id)| {
235 if let Some(info) = graph.value(*value_id) {
236 format!("#{i}:id={} origin={:?}", value_id, info.origin)
237 } else {
238 format!("#{i}:id={} origin=<missing>", value_id)
239 }
240 })
241 .collect();
242 log::debug!(
243 "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
244 plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
245 );
246 }
247
248 if stack.len() < required_stack_operands {
249 if interp_engine::fusion_debug_enabled() {
250 log::debug!(
251 "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
252 plan.index,
253 required_stack_operands,
254 stack.len(),
255 plan.stack_pattern
256 );
257 }
258 return Err(mex(
259 "FusionStackUnderflow",
260 "fusion: stack underflow gathering inputs",
261 ));
262 }
263 let available = required_stack_operands;
264 let slice_start = stack.len() - available;
265 let stack_guard = StackSliceGuard::new(stack, slice_start);
266 let slice = stack_guard.slice().to_vec();
267 let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
268 let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
269 .inputs
270 .iter()
271 .enumerate()
272 .map(|(idx, value_id)| (*value_id, idx))
273 .collect();
274
275 let allow_stack_value = |val: &Value| {
276 if plan.group.kind.is_reduction() {
277 matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
278 } else {
279 true
280 }
281 };
282
283 if let Some(layout) = plan.group.stack_layout.as_ref() {
284 for binding in &layout.bindings {
285 let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
286 continue;
287 };
288 let Some(val) = slice.get(binding.stack_offset).cloned() else {
289 continue;
290 };
291 consumed_inputs[input_idx] = Some(val.clone());
292 if inputs[input_idx].is_none() && allow_stack_value(&val) {
293 inputs[input_idx] = Some(val);
294 }
295 }
296 } else {
297 for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
298 let Some(val) = slice.get(offset).cloned() else {
299 continue;
300 };
301 consumed_inputs[*input_idx] = Some(val.clone());
302 if inputs[*input_idx].is_none() && allow_stack_value(&val) {
303 inputs[*input_idx] = Some(val);
304 }
305 }
306 }
307
308 for (idx, slot) in inputs.iter_mut().enumerate() {
309 if slot.is_some() {
310 continue;
311 }
312 let vid = plan.inputs[idx];
313 let info = graph.value(vid);
314 if let Some(info) = info {
315 match &info.origin {
316 ValueOrigin::Variable { kind, index } => {
317 let value_opt = match kind {
318 VarKind::Global => vars.get(*index).cloned(),
319 VarKind::Local => {
320 if let Some(frame) = context.call_stack.last() {
321 let absolute = frame.locals_start + index;
322 context.locals.get(absolute).cloned()
323 } else {
324 vars.get(*index).cloned()
325 }
326 }
327 };
328 if let Some(value) = value_opt {
329 *slot = Some(value);
330 continue;
331 }
332 }
333 ValueOrigin::Constant => {
334 if let Some(value) = plan.const_values.get(&vid) {
335 *slot = Some(value.clone());
336 continue;
337 }
338 }
339 _ => {}
340 }
341 }
342 if slot.is_none() {
343 if let Some(binding) = graph.var_binding(vid) {
344 let value_opt = match binding.kind {
345 VarKind::Global => vars.get(binding.index).cloned(),
346 VarKind::Local => {
347 if let Some(frame) = context.call_stack.last() {
348 let absolute = frame.locals_start + binding.index;
349 context.locals.get(absolute).cloned()
350 } else {
351 vars.get(binding.index).cloned()
352 }
353 }
354 };
355 if let Some(value) = value_opt {
356 *slot = Some(value);
357 continue;
358 }
359 }
360 }
361 if slot.is_none() {
362 if let Some(info) = info {
363 if let ValueOrigin::NodeOutput { node, .. } = info.origin {
364 if let Some(binding) = graph.node_binding(node) {
365 let value_opt = match binding.kind {
366 VarKind::Global => vars.get(binding.index).cloned(),
367 VarKind::Local => {
368 if let Some(frame) = context.call_stack.last() {
369 let absolute = frame.locals_start + binding.index;
370 context.locals.get(absolute).cloned()
371 } else {
372 vars.get(binding.index).cloned()
373 }
374 }
375 };
376 if let Some(value) = value_opt {
377 *slot = Some(value);
378 continue;
379 }
380 }
381 }
382 }
383 }
384 if slot.is_none() {
385 if let Some(value) = plan.const_values.get(&vid) {
386 *slot = Some(value.clone());
387 }
388 }
389 }
390
391 let inputs: Vec<Value> = inputs
392 .into_iter()
393 .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
394 .collect::<Result<_, _>>()?;
395
396 if log::log_enabled!(log::Level::Debug) {
397 let summaries: Vec<String> = inputs
398 .iter()
399 .enumerate()
400 .map(|(i, v)| summarize_value(i, v))
401 .collect();
402 log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
403 }
404
405 Ok((
406 stack_guard,
407 FusionExecutionRequest { plan, inputs },
408 consumed_inputs,
409 ))
410}
411
412pub fn write_elementwise_materialized_stores(
413 materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
414 vars: &mut Vec<Value>,
415 context: &mut ExecutionContext,
416) {
417 for (store, value) in materialized_stores {
418 match store.binding.kind {
419 VarKind::Global => {
420 let i = store.binding.index;
421 if i < vars.len() {
422 if let Err(err) = accel_residency::clear_value_excluding(&vars[i], &value) {
423 log::warn!("failed to clear fused global GPU residency: {err}");
424 }
425 }
426 if i >= vars.len() {
427 vars.resize(i + 1, Value::Num(0.0));
428 refresh_workspace_state(vars);
429 }
430 vars[i] = value;
431 }
432 VarKind::Local => {
433 if let Some(frame) = context.call_stack.last() {
434 let absolute = frame.locals_start + store.binding.index;
435 while context.locals.len() <= absolute {
436 context.locals.push(Value::Num(0.0));
437 }
438 if let Err(err) =
439 accel_residency::clear_value_excluding(&context.locals[absolute], &value)
440 {
441 log::warn!("failed to clear fused local GPU residency: {err}");
442 }
443 context.locals[absolute] = value;
444 } else {
445 let i = store.binding.index;
446 if i < vars.len() {
447 if let Err(err) = accel_residency::clear_value_excluding(&vars[i], &value) {
448 log::warn!("failed to clear fused fallback GPU residency: {err}");
449 }
450 }
451 if i >= vars.len() {
452 vars.resize(i + 1, Value::Num(0.0));
453 refresh_workspace_state(vars);
454 }
455 vars[i] = value;
456 }
457 }
458 }
459 }
460}
461
462pub fn execute_fusion_elementwise(
463 request: FusionExecutionRequest<'_>,
464 stack_guard: StackSliceGuard<'_>,
465 vars: &mut Vec<Value>,
466 context: &mut ExecutionContext,
467) -> Result<Value, RuntimeError> {
468 match execute_elementwise(request) {
469 Ok(result) => {
470 write_elementwise_materialized_stores(result.materialized_stores, vars, context);
471 stack_guard.commit();
472 Ok(result.final_value)
473 }
474 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
475 }
476}
477
478pub async fn execute_fusion_special_kind(
479 kind: FusionKind,
480 plan_inputs: &[runmat_accelerate::graph::ValueId],
481 request: FusionExecutionRequest<'_>,
482 stack_guard: StackSliceGuard<'_>,
483) -> Result<Value, RuntimeError> {
484 match kind {
485 FusionKind::CenteredGram => match execute_centered_gram(request).await {
486 Ok(result) => {
487 stack_guard.commit();
488 Ok(result)
489 }
490 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
491 },
492 FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
493 Ok(result) => {
494 stack_guard.commit();
495 Ok(result)
496 }
497 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
498 },
499 FusionKind::ExplainedVariance => {
500 log::debug!("explained variance plan inputs {:?}", plan_inputs);
501 match execute_explained_variance(request).await {
502 Ok(result) => {
503 stack_guard.commit();
504 Ok(result)
505 }
506 Err(err) => {
507 log::debug!("explained variance fusion fallback: {}", err);
508 Err(mex("FusionExecutionFailed", &err.to_string()))
509 }
510 }
511 }
512 FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
513 Ok(result) => {
514 stack_guard.commit();
515 Ok(result)
516 }
517 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
518 },
519 FusionKind::ImageNormalize => match execute_image_normalize(request).await {
520 Ok(result) => {
521 stack_guard.commit();
522 Ok(result)
523 }
524 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
525 },
526 _ => Err(mex(
527 "FusionUnsupportedKind",
528 "fusion: unsupported fusion kind",
529 )),
530 }
531}
532
/// Launch geometry resolved for a fused reduction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionGeometry {
    /// Zero-based axis the reduction runs along (`0` for a reduce-all).
    pub axis: usize,
    /// Number of elements folded into each output value.
    pub reduce_len: usize,
    /// Number of independent slices reduced (`1` for a reduce-all).
    pub num_slices: usize,
}
538
539pub fn resolve_reduction_geometry(
540 plan: &runmat_accelerate::FusionGroupPlan,
541 graph: &runmat_accelerate::AccelGraph,
542 request: &FusionExecutionRequest<'_>,
543 consumed_inputs: &[Option<Value>],
544 vars: &[Value],
545 context: &ExecutionContext,
546) -> Result<ReductionGeometry, RuntimeError> {
547 fn detect_reduce_all(
548 plan: &runmat_accelerate::FusionGroupPlan,
549 graph: &runmat_accelerate::AccelGraph,
550 ) -> bool {
551 let mut reduce_all = matches!(
552 plan.reduction_axes,
553 Some(runmat_accelerate::ReductionAxes::All)
554 );
555 let has_all = reduce_all
556 || plan.constants.values().any(value_is_all_keyword)
557 || plan.const_values.values().any(value_is_all_keyword);
558 if has_all {
559 return true;
560 }
561 for node_id in &plan.group.nodes {
562 if let Some(node) = graph.node(*node_id) {
563 if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
564 if name.eq_ignore_ascii_case("mean") {
565 for input_vid in &node.inputs {
566 if let Some(info) = graph.value(*input_vid) {
567 if let Some(constant) = &info.constant {
568 if value_is_all_keyword(constant) {
569 reduce_all = true;
570 break;
571 }
572 }
573 }
574 }
575 }
576 }
577 }
578 if reduce_all {
579 break;
580 }
581 }
582 reduce_all
583 }
584
585 fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
586 let mut axis = 0usize;
587 let mut axis_explicit = false;
588 if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
589 if let Some(first) = dims.first().copied() {
590 axis = first.saturating_sub(1);
591 axis_explicit = true;
592 }
593 }
594 if let Some(dim_vid) = plan.reduction_dim {
595 if let Some(cv) = plan.const_values.get(&dim_vid) {
596 axis = match cv {
597 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
598 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
599 _ => axis,
600 };
601 axis_explicit = true;
602 } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
603 if let Some(cv) = plan.constants.get(&input_idx) {
604 axis = match cv {
605 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
606 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
607 _ => axis,
608 };
609 axis_explicit = true;
610 }
611 }
612 } else if let Some(dim_const) = plan.constants.get(&1) {
613 axis = match dim_const {
614 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
615 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
616 _ => axis,
617 };
618 axis_explicit = true;
619 }
620 (axis, axis_explicit)
621 }
622
623 fn derive_rows_cols(
624 plan: &runmat_accelerate::FusionGroupPlan,
625 graph: &runmat_accelerate::AccelGraph,
626 request: &FusionExecutionRequest<'_>,
627 consumed_inputs: &[Option<Value>],
628 vars: &[Value],
629 context: &ExecutionContext,
630 ) -> Option<(usize, usize)> {
631 let shape_of = |value: &Value| -> Option<(usize, usize)> {
632 match value {
633 Value::GpuTensor(h) => Some((
634 h.shape.first().copied().unwrap_or(1).max(1),
635 h.shape.get(1).copied().unwrap_or(1).max(1),
636 )),
637 Value::Tensor(t) => Some((
638 t.shape.first().copied().unwrap_or(1).max(1),
639 t.shape.get(1).copied().unwrap_or(1).max(1),
640 )),
641 _ => None,
642 }
643 };
644
645 if let Some(shape) = plan.reduction_data_shape(graph) {
646 if shape.len() >= 2 {
647 return Some((shape[0].max(1), shape[1].max(1)));
648 }
649 if shape.len() == 1 {
650 return Some((shape[0].max(1), 1));
651 }
652 }
653
654 for &vid in &plan.inputs {
655 if let Some(binding) = graph.var_binding(vid) {
656 let value_opt = match binding.kind {
657 VarKind::Global => vars.get(binding.index).cloned(),
658 VarKind::Local => {
659 if let Some(frame) = context.call_stack.last() {
660 let absolute = frame.locals_start + binding.index;
661 context.locals.get(absolute).cloned()
662 } else {
663 vars.get(binding.index).cloned()
664 }
665 }
666 };
667 if let Some(value) = value_opt {
668 if let Some(shape) = shape_of(&value) {
669 return Some(shape);
670 }
671 }
672 }
673 }
674
675 for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
676 if let Some(shape) = shape_of(v) {
677 return Some(shape);
678 }
679 }
680
681 if let Some(data_id) = plan.reduction_data {
682 if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
683 if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
684 if let Some(shape) = shape_of(val) {
685 return Some(shape);
686 }
687 }
688 if let Some(val) = request.inputs.get(input_index) {
689 if let Some(shape) = shape_of(val) {
690 return Some(shape);
691 }
692 }
693 }
694 if let Some(info) = graph.value(data_id) {
695 if let ValueOrigin::Variable { kind, index } = &info.origin {
696 let val = match kind {
697 VarKind::Global => vars.get(*index).cloned(),
698 VarKind::Local => {
699 if let Some(frame) = context.call_stack.last() {
700 let absolute = frame.locals_start + index;
701 context.locals.get(absolute).cloned()
702 } else {
703 vars.get(*index).cloned()
704 }
705 }
706 };
707 if let Some(v) = val {
708 if let Some(shape) = shape_of(&v) {
709 return Some(shape);
710 }
711 }
712 }
713 if let ShapeInfo::Tensor(dims) = &info.shape {
714 if !dims.is_empty() {
715 let r = dims.first().and_then(|d| *d).unwrap_or(1);
716 let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
717 return Some((r.max(1), c.max(1)));
718 }
719 }
720 }
721 }
722
723 for v in &request.inputs {
724 if let Some(shape) = shape_of(v) {
725 return Some(shape);
726 }
727 }
728
729 if let ShapeInfo::Tensor(dims) = &plan.group.shape {
730 if !dims.is_empty() {
731 let r = dims.first().and_then(|d| *d).unwrap_or(1);
732 let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
733 return Some((r.max(1), c.max(1)));
734 }
735 }
736 None
737 }
738
739 if log::log_enabled!(log::Level::Debug) {
740 let meta: Vec<String> = plan
741 .inputs
742 .iter()
743 .map(|vid| {
744 if let Some(info) = graph.value(*vid) {
745 format!(
746 "vid={} origin={:?} shape={:?}",
747 vid, info.origin, info.shape
748 )
749 } else {
750 format!("vid={} origin=<missing>", vid)
751 }
752 })
753 .collect();
754 log::debug!("reduction gather meta: [{}]", meta.join(", "));
755 }
756
757 let reduce_all = detect_reduce_all(plan, graph);
758 let (mut axis, axis_explicit) = if reduce_all {
759 (0usize, false)
760 } else {
761 resolve_reduction_axis(plan)
762 };
763 if reduce_all && interp_engine::fusion_debug_enabled() {
764 log::debug!(
765 "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
766 plan.reduction_data,
767 plan.inputs,
768 plan.stack_pattern
769 );
770 }
771
772 let (r, c) =
773 derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
774 let (reduce_len, num_slices) = if reduce_all {
775 let total_from_runtime = consumed_inputs
776 .iter()
777 .filter_map(|v| v.as_ref())
778 .chain(request.inputs.iter())
779 .find_map(|value| match value {
780 Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
781 1
782 } else {
783 handle
784 .shape
785 .iter()
786 .copied()
787 .map(|d| d.max(1))
788 .product::<usize>()
789 }),
790 Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
791 1
792 } else {
793 tensor
794 .shape
795 .iter()
796 .copied()
797 .map(|d| d.max(1))
798 .product::<usize>()
799 }),
800 _ => None,
801 });
802 let total = plan
803 .reduction_data_shape(graph)
804 .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
805 .or(total_from_runtime)
806 .or_else(|| plan.element_count())
807 .filter(|v| *v > 0)
808 .ok_or_else(|| {
809 mex(
810 "FusionReductionExtentUnknown",
811 "fusion: reduction all extent unknown",
812 )
813 })?;
814 if interp_engine::fusion_debug_enabled() {
815 log::debug!(
816 "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
817 total,
818 r,
819 c
820 );
821 }
822 (total, 1usize)
823 } else {
824 if !axis_explicit {
825 axis = if r == 1 && c > 1 {
826 1
827 } else if r > 1 {
828 0
829 } else {
830 axis
831 };
832 }
833 if interp_engine::fusion_debug_enabled() {
834 if r == 1 && c == 1 {
835 log::debug!(
836 "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
837 axis,
838 plan.constants
839 );
840 } else {
841 log::debug!(
842 "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
843 r,
844 c,
845 axis,
846 plan.constants
847 );
848 }
849 }
850 if axis == 0 {
851 (r, c)
852 } else {
853 (c, r)
854 }
855 };
856
857 if interp_engine::fusion_debug_enabled() {
858 log::debug!(
859 "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
860 axis,
861 reduce_len,
862 num_slices,
863 plan.constants
864 );
865 }
866
867 let looks_wrong = reduce_len == 1 && num_slices == 1 && {
868 let mut big = false;
869 let mut check_val = |v: &Value| match v {
870 Value::GpuTensor(h) => {
871 let prod = h.shape.iter().copied().product::<usize>();
872 if prod > 1 {
873 big = true;
874 }
875 }
876 Value::Tensor(t) => {
877 let prod = t.shape.iter().copied().product::<usize>();
878 if prod > 1 {
879 big = true;
880 }
881 }
882 _ => {}
883 };
884 for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
885 check_val(v);
886 }
887 for v in &request.inputs {
888 check_val(v);
889 }
890 big
891 };
892 if looks_wrong {
893 log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
894 return Err(mex(
895 "FusionReductionShapeUnresolved",
896 "fusion: reduction shape unresolved",
897 ));
898 }
899 if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
900 .ok()
901 .as_deref()
902 == Some("1")
903 {
904 return Err(mex(
905 "FusionReductionDisabled",
906 "fusion: fused reductions disabled",
907 ));
908 }
909
910 Ok(ReductionGeometry {
911 axis,
912 reduce_len,
913 num_slices,
914 })
915}
916
917pub fn execute_fusion_reduction(
918 plan: &runmat_accelerate::FusionGroupPlan,
919 graph: &runmat_accelerate::AccelGraph,
920 request: FusionExecutionRequest<'_>,
921 consumed_inputs: &[Option<Value>],
922 stack_guard: StackSliceGuard<'_>,
923 vars: &[Value],
924 context: &ExecutionContext,
925) -> Result<Value, RuntimeError> {
926 let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
927 match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
928 Ok(result) => {
929 stack_guard.commit();
930 Ok(result)
931 }
932 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
933 }
934}
935
936pub async fn try_execute_fusion_group(
937 plan: &runmat_accelerate::FusionGroupPlan,
938 graph: &runmat_accelerate::AccelGraph,
939 stack: &mut Vec<Value>,
940 vars: &mut Vec<Value>,
941 context: &mut ExecutionContext,
942) -> Result<Value, RuntimeError> {
943 let (stack_guard, request, consumed_inputs) =
944 gather_fusion_inputs(plan, graph, stack, vars, context)?;
945 if plan.group.kind.is_elementwise()
946 && !request.inputs.is_empty()
947 && request.inputs.iter().all(is_scalarish_runtime_value)
948 {
949 return Err(mex(
950 "FusionScalarBypass",
951 "fusion: bypass scalar-only elementwise group",
952 ));
953 }
954 log::debug!(
955 "dispatch fusion kind {:?}, supported {}",
956 plan.group.kind,
957 plan.kernel.supported
958 );
959 if plan.group.kind.is_elementwise() {
960 execute_fusion_elementwise(request, stack_guard, vars, context)
961 } else if plan.group.kind.is_reduction() {
962 execute_fusion_reduction(
963 plan,
964 graph,
965 request,
966 &consumed_inputs,
967 stack_guard,
968 vars,
969 context,
970 )
971 } else {
972 execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
973 .await
974 }
975}
976
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;

    /// Overwriting a global that holds two GPU handles with one of those same
    /// handles must keep the reused handle resident and release the other.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        // Handle present in both the old slot contents and the new value.
        let kept_handle = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17001,
        };
        // Handle present only in the old slot contents.
        let dropped_handle = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17002,
        };
        for handle in [&kept_handle, &dropped_handle] {
            fusion_residency::mark(handle);
        }
        assert!(fusion_residency::is_resident(&kept_handle));
        assert!(fusion_residency::is_resident(&dropped_handle));

        let previous_contents = Value::OutputList(vec![
            Value::GpuTensor(kept_handle.clone()),
            Value::GpuTensor(dropped_handle.clone()),
        ]);
        let mut globals = vec![previous_contents];
        let mut exec_context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: std::collections::HashSet::new(),
            next_spawn_task_id: 0,
        };

        let store = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        write_elementwise_materialized_stores(
            vec![(store, Value::GpuTensor(kept_handle.clone()))],
            &mut globals,
            &mut exec_context,
        );

        assert!(fusion_residency::is_resident(&kept_handle));
        assert!(!fusion_residency::is_resident(&dropped_handle));

        // Leave the shared residency registry clean for other tests.
        fusion_residency::clear(&kept_handle);
    }
}