1use crate::GpuValue;
9
/// Index of a single lane within a warp.
///
/// Valid IDs are `0..64`, which covers both 32-lane and 64-lane hardware.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct LaneId(u8);

impl LaneId {
    /// Creates a lane ID from its raw index.
    ///
    /// # Panics
    ///
    /// Panics if `id >= 64`. When evaluated in a `const` context this becomes a
    /// compile-time error instead.
    pub const fn new(id: u8) -> Self {
        if id >= 64 {
            panic!("Lane ID must be < 64 (supports NVIDIA 32-lane and AMD 64-lane)");
        }
        Self(id)
    }

    /// Returns the raw lane index.
    pub const fn get(self) -> u8 {
        let Self(raw) = self;
        raw
    }

    /// Returns the lane index widened to `usize`, convenient for slice indexing.
    pub const fn index(self) -> usize {
        self.get() as usize
    }
}
34
/// Newtype identifier for a warp, backed by a `u16`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct WarpId(u16);

impl WarpId {
    /// Wraps a raw warp index. Every `u16` is accepted; no range check is performed.
    pub const fn new(id: u16) -> Self {
        Self(id)
    }

    /// Returns the raw warp index.
    pub const fn get(self) -> u16 {
        let Self(raw) = self;
        raw
    }
}
48
49#[derive(Clone, Copy, Debug, PartialEq)]
55pub struct Uniform<T: GpuValue> {
56 value: T,
57}
58
59impl<T: GpuValue> Uniform<T> {
60 pub const fn from_const(value: T) -> Self {
62 Uniform { value }
63 }
64
65 pub fn get(self) -> T {
67 self.value
68 }
69
70 pub fn broadcast(self) -> PerLane<T> {
72 PerLane { value: self.value }
73 }
74}
75
76#[must_use = "PerLane values carry per-lane GPU data — dropping discards computation"]
82#[derive(Clone, Copy, Debug, PartialEq)]
83pub struct PerLane<T: GpuValue> {
84 value: T,
85}
86
87impl<T: GpuValue> PerLane<T> {
88 pub fn new(value: T) -> Self {
89 PerLane { value }
90 }
91
92 pub fn get(self) -> T {
93 self.value
94 }
95
96 pub unsafe fn assume_uniform(self) -> Uniform<T> {
101 Uniform { value: self.value }
102 }
103}
104
105impl<T: GpuValue + core::ops::Add<Output = T>> core::ops::Add for PerLane<T> {
106 type Output = PerLane<T>;
107 fn add(self, rhs: PerLane<T>) -> PerLane<T> {
108 PerLane {
109 value: self.value + rhs.value,
110 }
111 }
112}
113
114#[derive(Clone, Copy, Debug, PartialEq)]
120pub struct SingleLane<T: GpuValue, const LANE: u8> {
121 value: T,
122}
123
124impl<T: GpuValue, const LANE: u8> SingleLane<T, LANE> {
125 pub fn new(value: T) -> Self {
126 SingleLane { value }
127 }
128
129 pub fn get(self) -> T {
131 self.value
132 }
133
134 pub fn broadcast(self) -> Uniform<T> {
136 Uniform::from_const(self.value)
137 }
138}
139
140#[derive(Clone, Copy, Debug, PartialEq, Eq)]
146pub struct Role {
147 pub mask: u64,
148 pub name: &'static str,
149}
150
151impl Role {
152 pub const fn lanes(start: u8, end: u8, name: &'static str) -> Self {
153 assert!(start < 64 && end <= 64 && start < end);
154 let width = (end - start) as u64;
155 let mask = if width >= 64 {
156 u64::MAX
157 } else {
158 ((1u64 << width) - 1) << start
159 };
160 Role { mask, name }
161 }
162
163 pub const fn lane(id: u8, name: &'static str) -> Self {
164 assert!(id < 64);
165 Role {
166 mask: 1u64 << id,
167 name,
168 }
169 }
170
171 pub const fn contains(self, lane: LaneId) -> bool {
172 (self.mask & (1u64 << lane.0)) != 0
173 }
174
175 pub const fn count(self) -> u32 {
176 self.mask.count_ones()
177 }
178}
179
#[cfg(test)]
mod tests {
    // Host-side unit tests for the lane/warp wrapper types and `Role` masks.
    use super::*;

    #[test]
    fn test_lane_id() {
        let lane = LaneId::new(15);
        assert_eq!(lane.get(), 15);
        assert_eq!(lane.index(), 15);
    }

    #[test]
    fn test_lane_id_boundary_31() {
        let lane = LaneId::new(31);
        assert_eq!(lane.get(), 31);
        assert_eq!(lane.index(), 31);
    }

    #[test]
    fn test_lane_id_boundary_63() {
        let lane = LaneId::new(63);
        assert_eq!(lane.get(), 63);
    }

    #[test]
    #[should_panic]
    fn test_lane_id_out_of_range() {
        LaneId::new(64);
    }

    #[test]
    fn test_warp_id() {
        assert_eq!(WarpId::new(7).get(), 7);
        assert_eq!(WarpId::new(u16::MAX).get(), u16::MAX);
    }

    #[test]
    fn test_uniform_broadcast() {
        let u: Uniform<i32> = Uniform::from_const(42);
        let p: PerLane<i32> = u.broadcast();
        assert_eq!(p.get(), 42);
    }

    #[test]
    fn test_per_lane_add() {
        let sum = PerLane::new(40i32) + PerLane::new(2i32);
        assert_eq!(sum.get(), 42);
    }

    #[test]
    fn test_assume_uniform_roundtrip() {
        let p = PerLane::new(9i32);
        // SAFETY: the test holds a single value, so it is trivially uniform.
        let u = unsafe { p.assume_uniform() };
        assert_eq!(u.get(), 9);
    }

    #[test]
    fn test_single_lane_broadcast() {
        let reduced: SingleLane<i32, 0> = SingleLane::new(42);
        let uniform: Uniform<i32> = reduced.broadcast();
        assert_eq!(uniform.get(), 42);
    }

    #[test]
    fn test_role_coverage() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");

        assert!(coordinator.contains(LaneId::new(0)));
        assert!(coordinator.contains(LaneId::new(3)));
        assert!(!coordinator.contains(LaneId::new(4)));

        assert!(!worker.contains(LaneId::new(3)));
        assert!(worker.contains(LaneId::new(4)));
        assert!(worker.contains(LaneId::new(31)));

        assert_eq!(coordinator.count(), 4);
        assert_eq!(worker.count(), 28);
    }

    // Exercises the `width >= 64` special case in `Role::lanes`.
    #[test]
    fn test_role_full_warp_mask() {
        let all = Role::lanes(0, 64, "all");
        assert_eq!(all.mask, u64::MAX);
        assert_eq!(all.count(), 64);
        assert!(all.contains(LaneId::new(0)));
        assert!(all.contains(LaneId::new(63)));
    }

    #[test]
    fn test_role_upper_lanes() {
        let upper = Role::lanes(32, 64, "upper");
        assert!(!upper.contains(LaneId::new(31)));
        assert!(upper.contains(LaneId::new(32)));
        assert!(upper.contains(LaneId::new(63)));
        assert_eq!(upper.count(), 32);
    }

    #[test]
    fn test_role_single_lane() {
        let leader = Role::lane(5, "leader");
        assert_eq!(leader.count(), 1);
        assert!(leader.contains(LaneId::new(5)));
        assert!(!leader.contains(LaneId::new(4)));
    }

    #[test]
    #[should_panic]
    fn test_role_empty_range() {
        let _ = Role::lanes(8, 8, "empty");
    }
}