Skip to main content

vyre_driver/
fusion.rs

//! Cross-dispatch fusion decisions shared by concrete backends.
2
3use crate::specialization::SpecMap;
4
5/// One dispatch's pre-fusion description.
6#[derive(Debug, Clone)]
7pub struct DispatchShape {
8    /// Stable id for this dispatch inside the containing program.
9    pub id: &'static str,
10    /// Workgroup size `[x, y, z]`.
11    pub workgroup_size: [u32; 3],
12    /// Per-dispatch shared memory bytes.
13    pub shared_memory_bytes: u32,
14    /// Buffers this dispatch reads.
15    pub inputs: Vec<&'static str>,
16    /// Buffers this dispatch writes.
17    pub outputs: Vec<&'static str>,
18    /// Specialization constants baked into this dispatch.
19    pub specs: SpecMap,
20}
21
/// Adapter limits that the generic fusion pass must respect.
///
/// This is a small `Copy` value. It compares field by field, so a probed
/// profile can be checked against a known one with `==` or `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusionCaps {
    /// Largest amount of workgroup-shared memory, in bytes, the adapter can serve.
    pub max_shared_memory_bytes: u32,
    /// Largest number of invocations allowed in one workgroup.
    pub max_invocations_per_workgroup: u32,
}
30
31impl Default for FusionCaps {
32    fn default() -> Self {
33        Self {
34            max_shared_memory_bytes: 16 * 1024,
35            max_invocations_per_workgroup: 256,
36        }
37    }
38}
39
40impl FusionCaps {
41    /// High-end profile for tests and capability probes.
42    #[must_use]
43    pub const fn high_end() -> Self {
44        Self {
45            max_shared_memory_bytes: 128 * 1024,
46            max_invocations_per_workgroup: 1024,
47        }
48    }
49}
50
/// Outcome of the fusion analysis for one upstream/downstream pair, including
/// the reason for a rejection.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FusionDecision {
    /// The pair may be fused; the concrete backend is free to stitch its target modules.
    Accept,
    /// The workgroup sizes differ, or the workgroup exceeds the invocation budget.
    WorkgroupSizeMismatch {
        /// Workgroup size of the upstream dispatch.
        upstream: [u32; 3],
        /// Workgroup size of the downstream dispatch.
        downstream: [u32; 3],
    },
    /// The fused kernel would need more shared memory than the adapter allows.
    SharedMemoryBudget {
        /// Total bytes the fused kernel would request.
        needed: u64,
        /// Limit imposed by the adapter.
        cap: u32,
    },
    /// A buffer flowing between the pair is still read by a third dispatch.
    OutputConsumedElsewhere,
    /// None of the upstream outputs is an input of the downstream dispatch.
    NoPipelineDependency,
}
76
/// Pure cross-dispatch fusion analysis.
///
/// The type carries no state; it only namespaces the analysis entry points.
#[derive(Debug, Clone, Copy, Default)]
pub struct FusionPass;
79
80impl FusionPass {
81    /// Decide whether `upstream` -> `downstream` is legal to fuse.
82    #[must_use]
83    pub fn decide(
84        upstream: &DispatchShape,
85        downstream: &DispatchShape,
86        caps: FusionCaps,
87        other_consumers: &[&str],
88    ) -> FusionDecision {
89        if upstream.workgroup_size != downstream.workgroup_size {
90            return FusionDecision::WorkgroupSizeMismatch {
91                upstream: upstream.workgroup_size,
92                downstream: downstream.workgroup_size,
93            };
94        }
95        let invocations = u128::from(upstream.workgroup_size[0])
96            * u128::from(upstream.workgroup_size[1])
97            * u128::from(upstream.workgroup_size[2]);
98        if invocations > u128::from(caps.max_invocations_per_workgroup) {
99            return FusionDecision::WorkgroupSizeMismatch {
100                upstream: upstream.workgroup_size,
101                downstream: downstream.workgroup_size,
102            };
103        }
104        let needed =
105            u64::from(upstream.shared_memory_bytes) + u64::from(downstream.shared_memory_bytes);
106        if needed > u64::from(caps.max_shared_memory_bytes) {
107            return FusionDecision::SharedMemoryBudget {
108                needed,
109                cap: caps.max_shared_memory_bytes,
110            };
111        }
112
113        let mut has_pipeline_dependency = false;
114        for output in &upstream.outputs {
115            if !downstream.inputs.iter().any(|input| input == output) {
116                continue;
117            }
118            has_pipeline_dependency = true;
119            if other_consumers.iter().any(|consumer| consumer == output) {
120                return FusionDecision::OutputConsumedElsewhere;
121            }
122        }
123        if !has_pipeline_dependency {
124            return FusionDecision::NoPipelineDependency;
125        }
126        FusionDecision::Accept
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[64, 1, 1]` dispatch with 1 KiB of shared memory and an empty spec map.
    fn dispatch(
        id: &'static str,
        inputs: &[&'static str],
        outputs: &[&'static str],
    ) -> DispatchShape {
        DispatchShape {
            id,
            workgroup_size: [64, 1, 1],
            shared_memory_bytes: 1024,
            inputs: Vec::from(inputs),
            outputs: Vec::from(outputs),
            specs: SpecMap::new(),
        }
    }

    /// Runs the fusion decision against the high-end capability profile.
    fn decide_high_end(
        producer: &DispatchShape,
        consumer: &DispatchShape,
        other_consumers: &[&str],
    ) -> FusionDecision {
        FusionPass::decide(producer, consumer, FusionCaps::high_end(), other_consumers)
    }

    #[test]
    fn straight_producer_consumer_fuses() {
        let producer = dispatch("load", &["in"], &["stage"]);
        let consumer = dispatch("xor", &["stage"], &["out"]);
        let decision = decide_high_end(&producer, &consumer, &[]);
        assert_eq!(decision, FusionDecision::Accept);
    }

    #[test]
    fn third_consumer_rejects() {
        let producer = dispatch("a", &[], &["x"]);
        let consumer = dispatch("b", &["x"], &[]);
        let decision = decide_high_end(&producer, &consumer, &["x"]);
        assert_eq!(decision, FusionDecision::OutputConsumedElsewhere);
    }

    #[test]
    fn workgroup_invocation_overflow_rejects_instead_of_wrapping_or_clamping() {
        let huge = [u32::MAX, u32::MAX, 2];
        let producer = DispatchShape {
            workgroup_size: huge,
            ..dispatch("wide-a", &["in"], &["stage"])
        };
        let consumer = DispatchShape {
            workgroup_size: huge,
            ..dispatch("wide-b", &["stage"], &["out"])
        };
        let expected = FusionDecision::WorkgroupSizeMismatch {
            upstream: huge,
            downstream: huge,
        };
        assert_eq!(decide_high_end(&producer, &consumer, &[]), expected);
    }

    #[test]
    fn shared_memory_overflow_rejects_instead_of_appearing_under_cap() {
        let producer = DispatchShape {
            shared_memory_bytes: u32::MAX,
            ..dispatch("smem-a", &["in"], &["stage"])
        };
        let consumer = DispatchShape {
            shared_memory_bytes: 1,
            ..dispatch("smem-b", &["stage"], &["out"])
        };
        let expected = FusionDecision::SharedMemoryBudget {
            needed: u64::from(u32::MAX) + 1,
            cap: FusionCaps::high_end().max_shared_memory_bytes,
        };
        assert_eq!(decide_high_end(&producer, &consumer, &[]), expected);
    }

    #[test]
    fn source_has_no_clamped_fusion_admission_math() {
        // The needle is assembled with `concat!` so this file does not match itself.
        let needle = concat!(".", "saturating_");
        let source = include_str!("fusion.rs");
        assert!(
            !source.contains(needle),
            "fusion admission must use widened exact arithmetic, not silent clamps"
        );
    }
}