runmat_vm/accel/
stack_layout.rs1use std::collections::{HashMap, HashSet};
2
3use runmat_accelerate::graph::{AccelGraph, ValueId, ValueOrigin};
4use runmat_accelerate::{FusionGroup, FusionStackLayout, FusionStackValueBinding};
5
6use crate::instr::Instr;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9enum StackValue {
10 Unknown,
11 GraphValue(ValueId),
12}
13
14pub fn annotate_fusion_groups_with_stack_layout(
15 instructions: &[Instr],
16 graph: &AccelGraph,
17 groups: &mut [FusionGroup],
18) {
19 if groups.is_empty() {
20 return;
21 }
22
23 let mut groups_by_start: HashMap<usize, Vec<usize>> = HashMap::new();
24 for (idx, group) in groups.iter().enumerate() {
25 groups_by_start
26 .entry(group.span.start)
27 .or_default()
28 .push(idx);
29 }
30
31 let node_output_by_pc: HashMap<usize, ValueId> = graph
32 .nodes
33 .iter()
34 .filter_map(|node| {
35 node.outputs
36 .first()
37 .copied()
38 .map(|value_id| (node.span.end, value_id))
39 })
40 .collect();
41
42 let mut stack: Vec<StackValue> = Vec::new();
43 for (pc, instr) in instructions.iter().enumerate() {
44 if let Some(group_indices) = groups_by_start.get(&pc) {
45 for &group_idx in group_indices {
46 groups[group_idx].stack_layout =
47 build_group_stack_layout(instructions, graph, &groups[group_idx], &stack);
48 }
49 }
50
51 let Some(effect) = instr.stack_effect() else {
52 stack.clear();
53 continue;
54 };
55 for _ in 0..effect.pops {
56 let _ = stack.pop();
57 }
58 if effect.pushes == 0 {
59 continue;
60 }
61
62 let pushed_value = if effect.pushes == 1 {
63 node_output_by_pc
64 .get(&pc)
65 .copied()
66 .map(StackValue::GraphValue)
67 .unwrap_or(StackValue::Unknown)
68 } else {
69 StackValue::Unknown
70 };
71 for _ in 0..effect.pushes {
72 stack.push(pushed_value);
73 }
74 }
75}
76
77fn build_group_stack_layout(
78 instructions: &[Instr],
79 graph: &AccelGraph,
80 group: &FusionGroup,
81 entry_stack: &[StackValue],
82) -> Option<FusionStackLayout> {
83 let required_stack_operands =
84 required_stack_operands(instructions, group.span.start, group.span.end)?;
85 if required_stack_operands > entry_stack.len() {
86 return None;
87 }
88
89 let stack_value_ids = stack_backed_external_values(graph, group);
90 let slice_start = entry_stack.len().saturating_sub(required_stack_operands);
91 let mut seen = HashSet::new();
92 let mut bindings = Vec::new();
93
94 for (absolute_offset, value) in entry_stack.iter().enumerate().skip(slice_start) {
95 let StackValue::GraphValue(value_id) = value else {
96 continue;
97 };
98 if !stack_value_ids.contains(value_id) || !seen.insert(*value_id) {
99 continue;
100 }
101 bindings.push(FusionStackValueBinding {
102 value_id: *value_id,
103 stack_offset: absolute_offset - slice_start,
104 });
105 }
106
107 Some(FusionStackLayout {
108 required_stack_operands,
109 bindings,
110 })
111}
112
113fn stack_backed_external_values(graph: &AccelGraph, group: &FusionGroup) -> HashSet<ValueId> {
114 let node_set: HashSet<_> = group.nodes.iter().copied().collect();
115 let mut values = HashSet::new();
116 for node_id in &group.nodes {
117 let Some(node) = graph.node(*node_id) else {
118 continue;
119 };
120 for value_id in &node.inputs {
121 let Some(info) = graph.value(*value_id) else {
122 continue;
123 };
124 match info.origin {
125 ValueOrigin::NodeOutput { node, .. } if !node_set.contains(&node) => {
126 if graph.var_binding(*value_id).is_none() {
127 values.insert(*value_id);
128 }
129 }
130 _ => {}
131 }
132 }
133 }
134 values
135}
136
137fn required_stack_operands(
138 instructions: &[Instr],
139 start_pc: usize,
140 end_pc: usize,
141) -> Option<usize> {
142 if start_pc >= instructions.len() || end_pc >= instructions.len() || start_pc > end_pc {
143 return None;
144 }
145
146 let mut current_depth = 0usize;
147 let mut required_depth = 0usize;
148 for instr in &instructions[start_pc..=end_pc] {
149 let effect = instr.stack_effect()?;
150 if current_depth < effect.pops {
151 required_depth += effect.pops - current_depth;
152 current_depth = effect.pops;
153 }
154 current_depth = current_depth - effect.pops + effect.pushes;
155 }
156 Some(required_depth)
157}