1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9 execute_centered_gram, execute_elementwise, execute_explained_variance,
10 execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11 execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22 match value {
23 Value::Int(_) => "Int",
24 Value::Num(_) => "Num",
25 Value::Complex(_, _) => "Complex",
26 Value::Bool(_) => "Bool",
27 Value::LogicalArray(_) => "LogicalArray",
28 Value::String(_) => "String",
29 Value::StringArray(_) => "StringArray",
30 Value::CharArray(_) => "CharArray",
31 Value::Tensor(_) => "Tensor",
32 Value::ComplexTensor(_) => "ComplexTensor",
33 Value::Cell(_) => "Cell",
34 Value::Struct(_) => "Struct",
35 Value::GpuTensor(_) => "GpuTensor",
36 Value::Object(_) => "Object",
37 Value::HandleObject(_) => "HandleObject",
38 Value::Listener(_) => "Listener",
39 Value::FunctionHandle(_)
40 | Value::ExternalFunctionHandle(_)
41 | Value::MethodFunctionHandle(_) => "FunctionHandle",
42 Value::BoundFunctionHandle { .. } => "FunctionHandle",
43 Value::Closure(_) => "Closure",
44 Value::ClassRef(_) => "ClassRef",
45 Value::MException(_) => "MException",
46 Value::OutputList(_) => "OutputList",
47 }
48}
49
50#[inline]
51pub fn summarize_value(i: usize, v: &Value) -> String {
52 match v {
53 Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
54 Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
55 Value::Num(n) => format!("in#{i}:Num({n:.6})"),
56 Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
57 Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
58 Value::String(s) => format!("in#{i}:String({})", s),
59 _ => format!("in#{i}:{}", value_kind(v)),
60 }
61}
62
63#[inline]
64fn is_scalarish_runtime_value(value: &Value) -> bool {
65 match value {
66 Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
67 Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
68 Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
69 Value::LogicalArray(array) => is_scalar_shape(&array.shape),
70 Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
71 Value::CharArray(array) => array.rows * array.cols == 1,
72 _ => false,
73 }
74}
75
76pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
77 if span.start > span.end || span.end >= instructions.len() {
78 return None;
79 }
80 let mut current_depth = 0usize;
81 for instr in &instructions[span.start..=span.end] {
82 let effect = instr.stack_effect()?;
83 if current_depth < effect.pops {
84 current_depth = effect.pops;
85 }
86 current_depth = current_depth - effect.pops + effect.pushes;
87 }
88 Some(current_depth)
89}
90
91pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
92 if span.start > span.end || span.end >= instructions.len() {
93 return true;
94 }
95 for instr in &instructions[span.start..=span.end] {
96 if matches!(
97 instr,
98 Instr::StoreIndex(_)
99 | Instr::StoreIndexDelete(_)
100 | Instr::StoreSlice(_, _, _, _)
101 | Instr::StoreSliceDelete(_, _, _, _)
102 | Instr::StoreSliceExpr { .. }
103 | Instr::StoreSliceExprDelete { .. }
104 | Instr::StoreIndexCell { .. }
105 | Instr::StoreIndexCellDelete { .. }
106 | Instr::StoreMember(_)
107 | Instr::StoreMemberOrInit(_)
108 | Instr::StoreMemberDynamic
109 | Instr::StoreMemberDynamicOrInit
110 ) {
111 return true;
112 }
113 }
114 fusion_span_live_result_count(instructions, span) != Some(1)
115}
116
117pub struct StackSliceGuard<'a> {
118 stack: *mut Vec<Value>,
119 slice: Option<Vec<Value>>,
120 _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
121}
122
123impl<'a> StackSliceGuard<'a> {
124 pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
125 let slice = stack.split_off(slice_start);
126 Self {
127 stack,
128 slice: Some(slice),
129 _marker: std::marker::PhantomData,
130 }
131 }
132
133 pub fn slice(&self) -> &[Value] {
134 self.slice.as_ref().expect("stack slice missing").as_slice()
135 }
136
137 pub fn commit(mut self) {
138 self.slice = None;
139 }
140}
141
142impl Drop for StackSliceGuard<'_> {
143 fn drop(&mut self) {
144 if let Some(slice) = self.slice.take() {
145 unsafe { (&mut *self.stack).extend(slice) }
146 }
147 }
148}
149
150pub fn gather_fusion_inputs<'a>(
151 plan: &'a runmat_accelerate::FusionGroupPlan,
152 graph: &runmat_accelerate::AccelGraph,
153 stack: &'a mut Vec<Value>,
154 vars: &mut [Value],
155 context: &mut ExecutionContext,
156) -> Result<
157 (
158 StackSliceGuard<'a>,
159 FusionExecutionRequest<'a>,
160 Vec<Option<Value>>,
161 ),
162 RuntimeError,
163> {
164 if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
165 return Err(mex(
166 "FusionMissingStackLayout",
167 "fusion: missing compile-time stack layout metadata",
168 ));
169 }
170 let required_stack_operands = plan
171 .group
172 .stack_layout
173 .as_ref()
174 .map(|layout| layout.required_stack_operands)
175 .unwrap_or_else(|| plan.stack_pattern.len());
176 let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
177
178 for (idx, value) in &plan.constants {
179 if let Some(slot) = inputs.get_mut(*idx) {
180 if slot.is_none() {
181 *slot = Some(value.clone());
182 }
183 }
184 }
185
186 for (idx, value_id) in plan.inputs.iter().enumerate() {
187 let info = graph
188 .value(*value_id)
189 .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
190 match &info.origin {
191 ValueOrigin::Variable { kind, index } => {
192 let value =
193 match kind {
194 VarKind::Global => vars
195 .get(*index)
196 .cloned()
197 .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
198 VarKind::Local => {
199 if let Some(frame) = context.call_stack.last() {
200 let absolute = frame.locals_start + index;
201 context.locals.get(absolute).cloned().ok_or_else(|| {
202 format!("fusion: local var {index} unavailable")
203 })?
204 } else {
205 vars.get(*index).cloned().ok_or_else(|| {
206 format!("fusion: local var {index} unavailable")
207 })?
208 }
209 }
210 };
211 debug_assert!(
212 inputs[idx].is_none(),
213 "fusion: duplicate input slot {} for plan {}",
214 idx,
215 plan.index
216 );
217 inputs[idx] = Some(value);
218 }
219 ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
220 }
221 }
222
223 if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
224 let stack_needed_preview = required_stack_operands;
225 let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
226 let stack_kinds: Vec<&'static str> =
227 stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
228 let input_meta: Vec<String> = plan
229 .inputs
230 .iter()
231 .enumerate()
232 .map(|(i, value_id)| {
233 if let Some(info) = graph.value(*value_id) {
234 format!("#{i}:id={} origin={:?}", value_id, info.origin)
235 } else {
236 format!("#{i}:id={} origin=<missing>", value_id)
237 }
238 })
239 .collect();
240 log::debug!(
241 "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
242 plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
243 );
244 }
245
246 if stack.len() < required_stack_operands {
247 if interp_engine::fusion_debug_enabled() {
248 log::debug!(
249 "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
250 plan.index,
251 required_stack_operands,
252 stack.len(),
253 plan.stack_pattern
254 );
255 }
256 return Err(mex(
257 "FusionStackUnderflow",
258 "fusion: stack underflow gathering inputs",
259 ));
260 }
261 let available = required_stack_operands;
262 let slice_start = stack.len() - available;
263 let stack_guard = StackSliceGuard::new(stack, slice_start);
264 let slice = stack_guard.slice().to_vec();
265 let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
266 let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
267 .inputs
268 .iter()
269 .enumerate()
270 .map(|(idx, value_id)| (*value_id, idx))
271 .collect();
272
273 let allow_stack_value = |val: &Value| {
274 if plan.group.kind.is_reduction() {
275 matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
276 } else {
277 true
278 }
279 };
280
281 if let Some(layout) = plan.group.stack_layout.as_ref() {
282 for binding in &layout.bindings {
283 let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
284 continue;
285 };
286 let Some(val) = slice.get(binding.stack_offset).cloned() else {
287 continue;
288 };
289 consumed_inputs[input_idx] = Some(val.clone());
290 if inputs[input_idx].is_none() && allow_stack_value(&val) {
291 inputs[input_idx] = Some(val);
292 }
293 }
294 } else {
295 for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
296 let Some(val) = slice.get(offset).cloned() else {
297 continue;
298 };
299 consumed_inputs[*input_idx] = Some(val.clone());
300 if inputs[*input_idx].is_none() && allow_stack_value(&val) {
301 inputs[*input_idx] = Some(val);
302 }
303 }
304 }
305
306 for (idx, slot) in inputs.iter_mut().enumerate() {
307 if slot.is_some() {
308 continue;
309 }
310 let vid = plan.inputs[idx];
311 let info = graph.value(vid);
312 if let Some(info) = info {
313 match &info.origin {
314 ValueOrigin::Variable { kind, index } => {
315 let value_opt = match kind {
316 VarKind::Global => vars.get(*index).cloned(),
317 VarKind::Local => {
318 if let Some(frame) = context.call_stack.last() {
319 let absolute = frame.locals_start + index;
320 context.locals.get(absolute).cloned()
321 } else {
322 vars.get(*index).cloned()
323 }
324 }
325 };
326 if let Some(value) = value_opt {
327 *slot = Some(value);
328 continue;
329 }
330 }
331 ValueOrigin::Constant => {
332 if let Some(value) = plan.const_values.get(&vid) {
333 *slot = Some(value.clone());
334 continue;
335 }
336 }
337 _ => {}
338 }
339 }
340 if slot.is_none() {
341 if let Some(binding) = graph.var_binding(vid) {
342 let value_opt = match binding.kind {
343 VarKind::Global => vars.get(binding.index).cloned(),
344 VarKind::Local => {
345 if let Some(frame) = context.call_stack.last() {
346 let absolute = frame.locals_start + binding.index;
347 context.locals.get(absolute).cloned()
348 } else {
349 vars.get(binding.index).cloned()
350 }
351 }
352 };
353 if let Some(value) = value_opt {
354 *slot = Some(value);
355 continue;
356 }
357 }
358 }
359 if slot.is_none() {
360 if let Some(info) = info {
361 if let ValueOrigin::NodeOutput { node, .. } = info.origin {
362 if let Some(binding) = graph.node_binding(node) {
363 let value_opt = match binding.kind {
364 VarKind::Global => vars.get(binding.index).cloned(),
365 VarKind::Local => {
366 if let Some(frame) = context.call_stack.last() {
367 let absolute = frame.locals_start + binding.index;
368 context.locals.get(absolute).cloned()
369 } else {
370 vars.get(binding.index).cloned()
371 }
372 }
373 };
374 if let Some(value) = value_opt {
375 *slot = Some(value);
376 continue;
377 }
378 }
379 }
380 }
381 }
382 if slot.is_none() {
383 if let Some(value) = plan.const_values.get(&vid) {
384 *slot = Some(value.clone());
385 }
386 }
387 }
388
389 let inputs: Vec<Value> = inputs
390 .into_iter()
391 .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
392 .collect::<Result<_, _>>()?;
393
394 if log::log_enabled!(log::Level::Debug) {
395 let summaries: Vec<String> = inputs
396 .iter()
397 .enumerate()
398 .map(|(i, v)| summarize_value(i, v))
399 .collect();
400 log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
401 }
402
403 Ok((
404 stack_guard,
405 FusionExecutionRequest { plan, inputs },
406 consumed_inputs,
407 ))
408}
409
410pub fn write_elementwise_materialized_stores(
411 materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
412 vars: &mut Vec<Value>,
413 context: &mut ExecutionContext,
414) {
415 for (store, value) in materialized_stores {
416 match store.binding.kind {
417 VarKind::Global => {
418 let i = store.binding.index;
419 if i < vars.len() {
420 accel_residency::clear_value_excluding(&vars[i], &value);
421 }
422 if i >= vars.len() {
423 vars.resize(i + 1, Value::Num(0.0));
424 refresh_workspace_state(vars);
425 }
426 vars[i] = value;
427 }
428 VarKind::Local => {
429 if let Some(frame) = context.call_stack.last() {
430 let absolute = frame.locals_start + store.binding.index;
431 while context.locals.len() <= absolute {
432 context.locals.push(Value::Num(0.0));
433 }
434 accel_residency::clear_value_excluding(&context.locals[absolute], &value);
435 context.locals[absolute] = value;
436 } else {
437 let i = store.binding.index;
438 if i < vars.len() {
439 accel_residency::clear_value_excluding(&vars[i], &value);
440 }
441 if i >= vars.len() {
442 vars.resize(i + 1, Value::Num(0.0));
443 refresh_workspace_state(vars);
444 }
445 vars[i] = value;
446 }
447 }
448 }
449 }
450}
451
452pub fn execute_fusion_elementwise(
453 request: FusionExecutionRequest<'_>,
454 stack_guard: StackSliceGuard<'_>,
455 vars: &mut Vec<Value>,
456 context: &mut ExecutionContext,
457) -> Result<Value, RuntimeError> {
458 match execute_elementwise(request) {
459 Ok(result) => {
460 write_elementwise_materialized_stores(result.materialized_stores, vars, context);
461 stack_guard.commit();
462 Ok(result.final_value)
463 }
464 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
465 }
466}
467
468pub async fn execute_fusion_special_kind(
469 kind: FusionKind,
470 plan_inputs: &[runmat_accelerate::graph::ValueId],
471 request: FusionExecutionRequest<'_>,
472 stack_guard: StackSliceGuard<'_>,
473) -> Result<Value, RuntimeError> {
474 match kind {
475 FusionKind::CenteredGram => match execute_centered_gram(request).await {
476 Ok(result) => {
477 stack_guard.commit();
478 Ok(result)
479 }
480 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
481 },
482 FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
483 Ok(result) => {
484 stack_guard.commit();
485 Ok(result)
486 }
487 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
488 },
489 FusionKind::ExplainedVariance => {
490 log::debug!("explained variance plan inputs {:?}", plan_inputs);
491 match execute_explained_variance(request).await {
492 Ok(result) => {
493 stack_guard.commit();
494 Ok(result)
495 }
496 Err(err) => {
497 log::debug!("explained variance fusion fallback: {}", err);
498 Err(mex("FusionExecutionFailed", &err.to_string()))
499 }
500 }
501 }
502 FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
503 Ok(result) => {
504 stack_guard.commit();
505 Ok(result)
506 }
507 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
508 },
509 FusionKind::ImageNormalize => match execute_image_normalize(request).await {
510 Ok(result) => {
511 stack_guard.commit();
512 Ok(result)
513 }
514 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
515 },
516 _ => Err(mex(
517 "FusionUnsupportedKind",
518 "fusion: unsupported fusion kind",
519 )),
520 }
521}
522
/// Launch geometry for a fused reduction, as computed by `resolve_reduction_geometry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionGeometry {
    /// Zero-based dimension being reduced; left at 0 for reductions over all elements.
    pub axis: usize,
    /// Number of elements folded into each output
    /// (the row count when reducing axis 0 of a 2-D view).
    pub reduce_len: usize,
    /// Number of independent outputs (1 for reductions over all elements).
    pub num_slices: usize,
}
528
/// Resolves the launch geometry (`axis`, `reduce_len`, `num_slices`) for a fused reduction.
///
/// The reduction is modelled as a 2-D problem. `derive_rows_cols` condenses whatever shape
/// it can find into `(rows, cols)`, using only the first two dimensions. For a reduction
/// over all elements, the full element count becomes `reduce_len` with a single slice.
/// Otherwise the resolved axis picks `(rows, cols)` or `(cols, rows)`.
///
/// `vars` and `context` are only read, to look up the current values of variable-backed
/// inputs.
///
/// # Errors
/// - `FusionReductionExtentUnknown`: all-element reduction whose total element count
///   cannot be determined (or is zero).
/// - `FusionReductionShapeUnresolved`: the geometry collapsed to 1x1 although some input
///   tensor holds more than one element.
/// - `FusionReductionDisabled`: the `RUNMAT_DISABLE_FUSED_REDUCTION` environment variable
///   is set to `1`.
pub fn resolve_reduction_geometry(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    request: &FusionExecutionRequest<'_>,
    consumed_inputs: &[Option<Value>],
    vars: &[Value],
    context: &ExecutionContext,
) -> Result<ReductionGeometry, RuntimeError> {
    /// Returns `true` when the plan reduces over all elements.
    ///
    /// That holds when the plan's reduction axes are `All`, when any plan constant is the
    /// `'all'` keyword, or when a `mean` builtin node in the group has an `'all'` constant
    /// input.
    fn detect_reduce_all(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
    ) -> bool {
        let mut reduce_all = matches!(
            plan.reduction_axes,
            Some(runmat_accelerate::ReductionAxes::All)
        );
        // Cheap plan-level checks first.
        let has_all = reduce_all
            || plan.constants.values().any(value_is_all_keyword)
            || plan.const_values.values().any(value_is_all_keyword);
        if has_all {
            return true;
        }
        // Otherwise scan the group's graph nodes. Only `mean` nodes are inspected here.
        for node_id in &plan.group.nodes {
            if let Some(node) = graph.node(*node_id) {
                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
                    if name.eq_ignore_ascii_case("mean") {
                        for input_vid in &node.inputs {
                            if let Some(info) = graph.value(*input_vid) {
                                if let Some(constant) = &info.constant {
                                    if value_is_all_keyword(constant) {
                                        reduce_all = true;
                                        break;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // The inner `break` only leaves the input loop; stop scanning nodes as well.
            if reduce_all {
                break;
            }
        }
        reduce_all
    }

    /// Resolves the zero-based reduction axis and whether it was given explicitly.
    ///
    /// Sources are consulted in order, and a later match overwrites an earlier one:
    /// - the first dim of `plan.reduction_axes` (`Explicit`);
    /// - the `reduction_dim` value: from `const_values`, else from `constants` at that
    ///   value's input position;
    /// - or, when there is no `reduction_dim`, the constant at input slot 1.
    ///
    /// Dims are 1-based and converted with `saturating_sub(1)`. A constant that is not a
    /// usable dim (non-numeric, or a `Num` below 1) keeps the previous axis but still
    /// marks it explicit.
    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
        let mut axis = 0usize;
        let mut axis_explicit = false;
        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
            if let Some(first) = dims.first().copied() {
                axis = first.saturating_sub(1);
                axis_explicit = true;
            }
        }
        if let Some(dim_vid) = plan.reduction_dim {
            if let Some(cv) = plan.const_values.get(&dim_vid) {
                axis = match cv {
                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                    _ => axis,
                };
                axis_explicit = true;
            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
                if let Some(cv) = plan.constants.get(&input_idx) {
                    axis = match cv {
                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                        _ => axis,
                    };
                    axis_explicit = true;
                }
            }
        } else if let Some(dim_const) = plan.constants.get(&1) {
            // NOTE(review): presumably input slot 1 is the `dim` argument of calls like
            // `f(x, dim)` when no `reduction_dim` was recorded. Confirm.
            axis = match dim_const {
                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                _ => axis,
            };
            axis_explicit = true;
        }
        (axis, axis_explicit)
    }

    /// Finds a 2-D `(rows, cols)` view of the reduction data.
    ///
    /// Missing dimensions count as 1 and zero extents are clamped to 1. The first source
    /// that yields a shape wins:
    /// 1. `plan.reduction_data_shape(graph)`;
    /// 2. the current value of any plan input that has a variable binding;
    /// 3. any consumed stack operand;
    /// 4. the reduction data input specifically: its consumed operand, its request input,
    ///    its variable value, then its static `ShapeInfo`;
    /// 5. any request input;
    /// 6. the group's static shape.
    ///
    /// Only `Tensor`/`GpuTensor` runtime values have a shape here.
    /// NOTE(review): step 3 accepts any tensor operand, so it can shadow the more specific
    /// step 4. Confirm that this ordering is intended.
    fn derive_rows_cols(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
        request: &FusionExecutionRequest<'_>,
        consumed_inputs: &[Option<Value>],
        vars: &[Value],
        context: &ExecutionContext,
    ) -> Option<(usize, usize)> {
        let shape_of = |value: &Value| -> Option<(usize, usize)> {
            match value {
                Value::GpuTensor(h) => Some((
                    h.shape.first().copied().unwrap_or(1).max(1),
                    h.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                Value::Tensor(t) => Some((
                    t.shape.first().copied().unwrap_or(1).max(1),
                    t.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                _ => None,
            }
        };

        // 1. Planned data shape.
        if let Some(shape) = plan.reduction_data_shape(graph) {
            if shape.len() >= 2 {
                return Some((shape[0].max(1), shape[1].max(1)));
            }
            if shape.len() == 1 {
                return Some((shape[0].max(1), 1));
            }
        }

        // 2. Current values of variable-bound inputs.
        for &vid in &plan.inputs {
            if let Some(binding) = graph.var_binding(vid) {
                let value_opt = match binding.kind {
                    VarKind::Global => vars.get(binding.index).cloned(),
                    VarKind::Local => {
                        if let Some(frame) = context.call_stack.last() {
                            let absolute = frame.locals_start + binding.index;
                            context.locals.get(absolute).cloned()
                        } else {
                            vars.get(binding.index).cloned()
                        }
                    }
                };
                if let Some(value) = value_opt {
                    if let Some(shape) = shape_of(&value) {
                        return Some(shape);
                    }
                }
            }
        }

        // 3. Any consumed stack operand.
        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        // 4. The reduction data input specifically.
        if let Some(data_id) = plan.reduction_data {
            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
                if let Some(val) = request.inputs.get(input_index) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
            }
            if let Some(info) = graph.value(data_id) {
                if let ValueOrigin::Variable { kind, index } = &info.origin {
                    let val = match kind {
                        VarKind::Global => vars.get(*index).cloned(),
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned()
                            } else {
                                vars.get(*index).cloned()
                            }
                        }
                    };
                    if let Some(v) = val {
                        if let Some(shape) = shape_of(&v) {
                            return Some(shape);
                        }
                    }
                }
                // Static shape metadata; unknown dims count as 1.
                if let ShapeInfo::Tensor(dims) = &info.shape {
                    if !dims.is_empty() {
                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                        return Some((r.max(1), c.max(1)));
                    }
                }
            }
        }

        // 5. Any request input.
        for v in &request.inputs {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        // 6. The group's static shape.
        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
            if !dims.is_empty() {
                let r = dims.first().and_then(|d| *d).unwrap_or(1);
                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                return Some((r.max(1), c.max(1)));
            }
        }
        None
    }

    if log::log_enabled!(log::Level::Debug) {
        let meta: Vec<String> = plan
            .inputs
            .iter()
            .map(|vid| {
                if let Some(info) = graph.value(*vid) {
                    format!(
                        "vid={} origin={:?} shape={:?}",
                        vid, info.origin, info.shape
                    )
                } else {
                    format!("vid={} origin=<missing>", vid)
                }
            })
            .collect();
        log::debug!("reduction gather meta: [{}]", meta.join(", "));
    }

    let reduce_all = detect_reduce_all(plan, graph);
    // An all-element reduction does not use an axis; it stays 0 and non-explicit.
    let (mut axis, axis_explicit) = if reduce_all {
        (0usize, false)
    } else {
        resolve_reduction_axis(plan)
    };
    if reduce_all && interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
            plan.reduction_data,
            plan.inputs,
            plan.stack_pattern
        );
    }

    // Fall back to 1x1 when no shape can be found; the sanity check below rejects
    // that fallback if an input is actually larger.
    let (r, c) =
        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
    let (reduce_len, num_slices) = if reduce_all {
        // Element count of the first runtime tensor (consumed operands first, then request
        // inputs). Zero dims are clamped to 1 and an empty shape counts as one element.
        let total_from_runtime = consumed_inputs
            .iter()
            .filter_map(|v| v.as_ref())
            .chain(request.inputs.iter())
            .find_map(|value| match value {
                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
                    1
                } else {
                    handle
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
                    1
                } else {
                    tensor
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                _ => None,
            });
        // Prefer the planned data shape, then the runtime tensor, then the plan's element
        // count. A zero total is treated as unknown.
        let total = plan
            .reduction_data_shape(graph)
            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
            .or(total_from_runtime)
            .or_else(|| plan.element_count())
            .filter(|v| *v > 0)
            .ok_or_else(|| {
                mex(
                    "FusionReductionExtentUnknown",
                    "fusion: reduction all extent unknown",
                )
            })?;
        if interp_engine::fusion_debug_enabled() {
            log::debug!(
                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
                total,
                r,
                c
            );
        }
        // Every element is folded into a single output.
        (total, 1usize)
    } else {
        // With no explicit dim, a row vector reduces along axis 1 and anything with more
        // than one row along axis 0. A 1x1 view keeps the current axis.
        if !axis_explicit {
            axis = if r == 1 && c > 1 {
                1
            } else if r > 1 {
                0
            } else {
                axis
            };
        }
        if interp_engine::fusion_debug_enabled() {
            if r == 1 && c == 1 {
                log::debug!(
                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
                    axis,
                    plan.constants
                );
            } else {
                log::debug!(
                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
                    r,
                    c,
                    axis,
                    plan.constants
                );
            }
        }
        // axis 0 -> (reduce_len, num_slices) = (rows, cols); any other axis -> (cols, rows).
        // NOTE(review): an explicit dim >= 3 also lands in the `else` arm and is treated
        // like dim 2. Confirm that such plans cannot reach this path.
        if axis == 0 {
            (r, c)
        } else {
            (c, r)
        }
    };

    if interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
            axis,
            reduce_len,
            num_slices,
            plan.constants
        );
    }

    // Sanity check: a 1x1 geometry while some input tensor has more than one element
    // means the shape was not resolved, so decline the fused path.
    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
        let mut big = false;
        let mut check_val = |v: &Value| match v {
            Value::GpuTensor(h) => {
                let prod = h.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            Value::Tensor(t) => {
                let prod = t.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            _ => {}
        };
        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            check_val(v);
        }
        for v in &request.inputs {
            check_val(v);
        }
        big
    };
    if looks_wrong {
        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
        return Err(mex(
            "FusionReductionShapeUnresolved",
            "fusion: reduction shape unresolved",
        ));
    }
    // Opt-out switch: `RUNMAT_DISABLE_FUSED_REDUCTION=1` makes every fused reduction
    // decline. It is checked last, so the errors above take precedence.
    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
        .ok()
        .as_deref()
        == Some("1")
    {
        return Err(mex(
            "FusionReductionDisabled",
            "fusion: fused reductions disabled",
        ));
    }

    Ok(ReductionGeometry {
        axis,
        reduce_len,
        num_slices,
    })
}
906
907pub fn execute_fusion_reduction(
908 plan: &runmat_accelerate::FusionGroupPlan,
909 graph: &runmat_accelerate::AccelGraph,
910 request: FusionExecutionRequest<'_>,
911 consumed_inputs: &[Option<Value>],
912 stack_guard: StackSliceGuard<'_>,
913 vars: &[Value],
914 context: &ExecutionContext,
915) -> Result<Value, RuntimeError> {
916 let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
917 match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
918 Ok(result) => {
919 stack_guard.commit();
920 Ok(result)
921 }
922 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
923 }
924}
925
926pub async fn try_execute_fusion_group(
927 plan: &runmat_accelerate::FusionGroupPlan,
928 graph: &runmat_accelerate::AccelGraph,
929 stack: &mut Vec<Value>,
930 vars: &mut Vec<Value>,
931 context: &mut ExecutionContext,
932) -> Result<Value, RuntimeError> {
933 let (stack_guard, request, consumed_inputs) =
934 gather_fusion_inputs(plan, graph, stack, vars, context)?;
935 if plan.group.kind.is_elementwise()
936 && !request.inputs.is_empty()
937 && request.inputs.iter().all(is_scalarish_runtime_value)
938 {
939 return Err(mex(
940 "FusionScalarBypass",
941 "fusion: bypass scalar-only elementwise group",
942 ));
943 }
944 log::debug!(
945 "dispatch fusion kind {:?}, supported {}",
946 plan.group.kind,
947 plan.kernel.supported
948 );
949 if plan.group.kind.is_elementwise() {
950 execute_fusion_elementwise(request, stack_guard, vars, context)
951 } else if plan.group.kind.is_reduction() {
952 execute_fusion_reduction(
953 plan,
954 graph,
955 request,
956 &consumed_inputs,
957 stack_guard,
958 vars,
959 context,
960 )
961 } else {
962 execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
963 .await
964 }
965}
966
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;

    /// An execution context with no call frames, locals or spawned tasks.
    fn empty_context() -> ExecutionContext {
        ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: std::collections::HashSet::new(),
            next_spawn_task_id: 0,
        }
    }

    /// Overwriting a global whose old value shares a GPU buffer with the new value must
    /// keep the shared buffer resident and release the buffer only the old value used.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        let kept = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17001,
        };
        let evicted = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17002,
        };
        for handle in [&kept, &evicted] {
            fusion_residency::mark(handle);
            assert!(fusion_residency::is_resident(handle));
        }

        let mut vars = vec![Value::OutputList(vec![
            Value::GpuTensor(kept.clone()),
            Value::GpuTensor(evicted.clone()),
        ])];
        let mut context = empty_context();
        let store = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        write_elementwise_materialized_stores(
            vec![(store, Value::GpuTensor(kept.clone()))],
            &mut vars,
            &mut context,
        );

        assert!(fusion_residency::is_resident(&kept));
        assert!(!fusion_residency::is_resident(&evicted));
        fusion_residency::clear(&kept);
    }
}