Skip to main content

runmat_vm/accel/
stack_layout.rs

1use std::collections::{HashMap, HashSet};
2
3use runmat_accelerate::graph::{AccelGraph, ValueId, ValueOrigin};
4use runmat_accelerate::{FusionGroup, FusionStackLayout, FusionStackValueBinding};
5
6use crate::instr::Instr;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9enum StackValue {
10    Unknown,
11    GraphValue(ValueId),
12}
13
14pub fn annotate_fusion_groups_with_stack_layout(
15    instructions: &[Instr],
16    graph: &AccelGraph,
17    groups: &mut [FusionGroup],
18) {
19    if groups.is_empty() {
20        return;
21    }
22
23    let mut groups_by_start: HashMap<usize, Vec<usize>> = HashMap::new();
24    for (idx, group) in groups.iter().enumerate() {
25        groups_by_start
26            .entry(group.span.start)
27            .or_default()
28            .push(idx);
29    }
30
31    let node_output_by_pc: HashMap<usize, ValueId> = graph
32        .nodes
33        .iter()
34        .filter_map(|node| {
35            node.outputs
36                .first()
37                .copied()
38                .map(|value_id| (node.span.end, value_id))
39        })
40        .collect();
41
42    let mut stack: Vec<StackValue> = Vec::new();
43    for (pc, instr) in instructions.iter().enumerate() {
44        if let Some(group_indices) = groups_by_start.get(&pc) {
45            for &group_idx in group_indices {
46                groups[group_idx].stack_layout =
47                    build_group_stack_layout(instructions, graph, &groups[group_idx], &stack);
48            }
49        }
50
51        let Some(effect) = instr.stack_effect() else {
52            stack.clear();
53            continue;
54        };
55        for _ in 0..effect.pops {
56            let _ = stack.pop();
57        }
58        if effect.pushes == 0 {
59            continue;
60        }
61
62        let pushed_value = if effect.pushes == 1 {
63            node_output_by_pc
64                .get(&pc)
65                .copied()
66                .map(StackValue::GraphValue)
67                .unwrap_or(StackValue::Unknown)
68        } else {
69            StackValue::Unknown
70        };
71        for _ in 0..effect.pushes {
72            stack.push(pushed_value);
73        }
74    }
75}
76
77fn build_group_stack_layout(
78    instructions: &[Instr],
79    graph: &AccelGraph,
80    group: &FusionGroup,
81    entry_stack: &[StackValue],
82) -> Option<FusionStackLayout> {
83    let required_stack_operands =
84        required_stack_operands(instructions, group.span.start, group.span.end)?;
85    if required_stack_operands > entry_stack.len() {
86        return None;
87    }
88
89    let stack_value_ids = stack_backed_external_values(graph, group);
90    let slice_start = entry_stack.len().saturating_sub(required_stack_operands);
91    let mut seen = HashSet::new();
92    let mut bindings = Vec::new();
93
94    for (absolute_offset, value) in entry_stack.iter().enumerate().skip(slice_start) {
95        let StackValue::GraphValue(value_id) = value else {
96            continue;
97        };
98        if !stack_value_ids.contains(value_id) || !seen.insert(*value_id) {
99            continue;
100        }
101        bindings.push(FusionStackValueBinding {
102            value_id: *value_id,
103            stack_offset: absolute_offset - slice_start,
104        });
105    }
106
107    Some(FusionStackLayout {
108        required_stack_operands,
109        bindings,
110    })
111}
112
113fn stack_backed_external_values(graph: &AccelGraph, group: &FusionGroup) -> HashSet<ValueId> {
114    let node_set: HashSet<_> = group.nodes.iter().copied().collect();
115    let mut values = HashSet::new();
116    for node_id in &group.nodes {
117        let Some(node) = graph.node(*node_id) else {
118            continue;
119        };
120        for value_id in &node.inputs {
121            let Some(info) = graph.value(*value_id) else {
122                continue;
123            };
124            match info.origin {
125                ValueOrigin::NodeOutput { node, .. } if !node_set.contains(&node) => {
126                    if graph.var_binding(*value_id).is_none() {
127                        values.insert(*value_id);
128                    }
129                }
130                _ => {}
131            }
132        }
133    }
134    values
135}
136
137fn required_stack_operands(
138    instructions: &[Instr],
139    start_pc: usize,
140    end_pc: usize,
141) -> Option<usize> {
142    if start_pc >= instructions.len() || end_pc >= instructions.len() || start_pc > end_pc {
143        return None;
144    }
145
146    let mut current_depth = 0usize;
147    let mut required_depth = 0usize;
148    for instr in &instructions[start_pc..=end_pc] {
149        let effect = instr.stack_effect()?;
150        if current_depth < effect.pops {
151            required_depth += effect.pops - current_depth;
152            current_depth = effect.pops;
153        }
154        current_depth = current_depth - effect.pops + effect.pushes;
155    }
156    Some(required_depth)
157}