1use crate::GpuValue;
9
/// Index of a single lane inside a warp.
///
/// Valid values are `0..64`, which covers both 32-lane (NVIDIA) and
/// 64-lane (AMD) hardware with one type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct LaneId(u8);

impl LaneId {
    /// Wraps `id` as a lane identifier.
    ///
    /// # Panics
    ///
    /// Panics if `id >= 64`; when evaluated in a `const` context this
    /// surfaces as a compile-time error instead.
    pub const fn new(id: u8) -> Self {
        assert!(id < 64, "Lane ID must be < 64 (supports NVIDIA 32-lane and AMD 64-lane)");
        Self(id)
    }

    /// Returns the raw lane number.
    pub const fn get(self) -> u8 {
        let Self(raw) = self;
        raw
    }

    /// Returns the lane number widened to `usize`, convenient for indexing.
    pub const fn index(self) -> usize {
        self.get() as usize
    }
}
35
/// Identifier of a warp.
///
/// No range check is applied: every `u16` is accepted as-is.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct WarpId(u16);

impl WarpId {
    /// Wraps `id` unchanged.
    pub const fn new(id: u16) -> Self {
        Self(id)
    }

    /// Returns the raw warp number.
    pub const fn get(self) -> u16 {
        match self {
            Self(raw) => raw,
        }
    }
}
50
51#[must_use = "Uniform values represent warp-wide reduction results — dropping discards the result"]
57#[derive(Clone, Copy, Debug, PartialEq)]
58#[repr(transparent)]
59pub struct Uniform<T: GpuValue> {
60 value: T,
61}
62
63impl<T: GpuValue> Uniform<T> {
64 pub const fn from_const(value: T) -> Self {
66 Uniform { value }
67 }
68
69 pub fn get(self) -> T {
71 self.value
72 }
73
74 pub fn broadcast(self) -> PerLane<T> {
76 PerLane { value: self.value }
77 }
78}
79
80#[must_use = "PerLane values carry per-lane GPU data — dropping discards computation"]
86#[derive(Clone, Copy, Debug, PartialEq)]
87#[repr(transparent)]
88pub struct PerLane<T: GpuValue> {
89 value: T,
90}
91
92impl<T: GpuValue> PerLane<T> {
93 pub fn new(value: T) -> Self {
94 PerLane { value }
95 }
96
97 pub fn get(self) -> T {
98 self.value
99 }
100
101 pub unsafe fn assume_uniform(self) -> Uniform<T> {
106 Uniform { value: self.value }
107 }
108}
109
110impl<T: GpuValue + core::ops::Add<Output = T>> core::ops::Add for PerLane<T> {
111 type Output = PerLane<T>;
112 fn add(self, rhs: PerLane<T>) -> PerLane<T> {
113 PerLane {
114 value: self.value + rhs.value,
115 }
116 }
117}
118
119#[must_use = "SingleLane values exist in one lane only — dropping discards the reduction result"]
125#[derive(Clone, Copy, Debug, PartialEq)]
126#[repr(transparent)]
127pub struct SingleLane<T: GpuValue, const LANE: u8> {
128 value: T,
129}
130
131impl<T: GpuValue, const LANE: u8> SingleLane<T, LANE> {
132 pub fn new(value: T) -> Self {
133 SingleLane { value }
134 }
135
136 pub fn get(self) -> T {
138 self.value
139 }
140
141 pub fn broadcast(self) -> Uniform<T> {
143 Uniform::from_const(self.value)
144 }
145}
146
147#[derive(Clone, Copy, Debug, PartialEq, Eq)]
153pub struct Role {
154 mask: u64,
155 name: &'static str,
156}
157
158impl Role {
159 pub const fn mask(self) -> u64 {
161 self.mask
162 }
163
164 pub const fn name(self) -> &'static str {
166 self.name
167 }
168}
169
170impl Role {
171 pub const fn lanes(start: u8, end: u8, name: &'static str) -> Self {
172 assert!(start < 64 && end <= 64 && start < end);
173 let width = (end - start) as u64;
174 let mask = if width >= 64 {
175 u64::MAX
176 } else {
177 ((1u64 << width) - 1) << start
178 };
179 Role { mask, name }
180 }
181
182 pub const fn lane(id: u8, name: &'static str) -> Self {
183 assert!(id < 64);
184 Role {
185 mask: 1u64 << id,
186 name,
187 }
188 }
189
190 pub const fn contains(self, lane: LaneId) -> bool {
191 (self.mask & (1u64 << lane.0)) != 0
192 }
193
194 pub const fn count(self) -> u32 {
195 self.mask.count_ones()
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lane_id() {
        let lane = LaneId::new(15);
        assert_eq!(lane.get(), 15);
        assert_eq!(lane.index(), 15);
    }

    #[test]
    fn test_lane_id_boundary_31() {
        let lane = LaneId::new(31);
        assert_eq!(lane.get(), 31);
        assert_eq!(lane.index(), 31);
    }

    #[test]
    fn test_lane_id_boundary_63() {
        let lane = LaneId::new(63);
        assert_eq!(lane.get(), 63);
    }

    #[test]
    #[should_panic]
    fn test_lane_id_out_of_range() {
        LaneId::new(64);
    }

    #[test]
    fn test_warp_id_roundtrip() {
        assert_eq!(WarpId::new(7).get(), 7);
        assert_eq!(WarpId::new(u16::MAX).get(), u16::MAX);
    }

    #[test]
    fn test_uniform_broadcast() {
        let u: Uniform<i32> = Uniform::from_const(42);
        let p: PerLane<i32> = u.broadcast();
        assert_eq!(p.get(), 42);
    }

    #[test]
    fn test_per_lane_add() {
        let sum = PerLane::<i32>::new(40) + PerLane::<i32>::new(2);
        assert_eq!(sum.get(), 42);
    }

    #[test]
    fn test_assume_uniform_preserves_value() {
        // SAFETY: the test runs on the host; the value is trivially identical
        // for every (single) "lane" observing it.
        let u: Uniform<i32> = unsafe { PerLane::<i32>::new(7).assume_uniform() };
        assert_eq!(u.get(), 7);
    }

    #[test]
    fn test_single_lane_broadcast() {
        let reduced: SingleLane<i32, 0> = SingleLane::new(42);
        let uniform: Uniform<i32> = reduced.broadcast();
        assert_eq!(uniform.get(), 42);
    }

    #[test]
    fn test_role_coverage() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");

        assert!(coordinator.contains(LaneId::new(0)));
        assert!(coordinator.contains(LaneId::new(3)));
        assert!(!coordinator.contains(LaneId::new(4)));

        assert!(!worker.contains(LaneId::new(3)));
        assert!(worker.contains(LaneId::new(4)));
        assert!(worker.contains(LaneId::new(31)));

        assert_eq!(coordinator.count(), 4);
        assert_eq!(worker.count(), 28);
    }

    #[test]
    fn test_role_full_warp() {
        // Exercises the `width >= 64` branch that avoids `1u64 << 64`.
        let all = Role::lanes(0, 64, "all");
        assert_eq!(all.mask(), u64::MAX);
        assert_eq!(all.count(), 64);
        assert!(all.contains(LaneId::new(0)));
        assert!(all.contains(LaneId::new(63)));
    }

    #[test]
    fn test_role_upper_half() {
        let upper = Role::lanes(32, 64, "upper");
        assert_eq!(upper.mask(), 0xFFFF_FFFF_0000_0000);
        assert_eq!(upper.count(), 32);
        assert!(!upper.contains(LaneId::new(31)));
        assert!(upper.contains(LaneId::new(63)));
    }

    #[test]
    fn test_role_single_lane() {
        let leader = Role::lane(5, "leader");
        assert_eq!(leader.mask(), 1u64 << 5);
        assert_eq!(leader.name(), "leader");
        assert_eq!(leader.count(), 1);
        assert!(leader.contains(LaneId::new(5)));
        assert!(!leader.contains(LaneId::new(4)));
    }

    #[test]
    #[should_panic]
    fn test_role_empty_range() {
        Role::lanes(4, 4, "empty");
    }

    #[test]
    #[should_panic]
    fn test_role_end_out_of_range() {
        Role::lanes(0, 65, "too_wide");
    }

    #[test]
    #[should_panic]
    fn test_role_lane_out_of_range() {
        Role::lane(64, "bad");
    }
}