1use crate::data::Role;
11use crate::GpuValue;
12use core::marker::PhantomData;
13
/// Fixed-size buffer of 32 `T` values tagged with a [`Role`] supplied at
/// construction.
///
/// The `OWNER` const parameter distinguishes regions at the type level; no code
/// in this module compares it against the stored `owner` at runtime.
pub struct SharedRegion<T: GpuValue, const OWNER: u8> {
    // Backing storage; `read` and `write` bounds-check every index against it.
    data: [T; 32],
    // Role given to `new` and returned unchanged by `owner()`.
    owner: Role,
    // Carries no type information (`()`).
    // NOTE(review): appears redundant — const params need no PhantomData.
    _phantom: PhantomData<()>,
}
32
33impl<T: GpuValue + Default, const OWNER: u8> SharedRegion<T, OWNER> {
34 pub fn new(owner: Role) -> Self {
35 SharedRegion {
36 data: [T::default(); 32],
37 owner,
38 _phantom: PhantomData,
39 }
40 }
41}
42
43impl<T: GpuValue, const OWNER: u8> SharedRegion<T, OWNER> {
44 pub fn write(&mut self, index: usize, value: T) {
45 assert!(index < 32, "Index out of bounds");
46 self.data[index] = value;
47 }
48
49 pub fn read(&self, index: usize) -> T {
50 assert!(index < 32, "Index out of bounds");
51 self.data[index]
52 }
53
54 pub fn grant_read(&self) -> SharedView<'_, T, OWNER> {
55 SharedView {
56 region: self,
57 _phantom: PhantomData,
58 }
59 }
60
61 pub fn owner(&self) -> Role {
62 self.owner
63 }
64}
65
/// Read-only view of a [`SharedRegion`], produced by `SharedRegion::grant_read`.
///
/// Because it holds a shared reference, the region cannot be mutated while the
/// view exists.
pub struct SharedView<'a, T: GpuValue, const OWNER: u8> {
    // Borrowed region that all reads are forwarded to.
    region: &'a SharedRegion<T, OWNER>,
    // Carries no type information (`()`).
    _phantom: PhantomData<()>,
}
71
72impl<'a, T: GpuValue, const OWNER: u8> SharedView<'a, T, OWNER> {
73 pub fn read(&self, index: usize) -> T {
74 self.region.read(index)
75 }
76}
77
/// FIFO ring buffer of tasks backed by a 32-slot [`SharedRegion`] typed with
/// the `PRODUCER` role id.
///
/// One slot is always left unused so that "full" and "empty" can be told apart,
/// so at most 31 tasks can be queued at once. `CONSUMER` is only a type-level
/// tag; no code in this module reads it.
pub struct WorkQueue<T: GpuValue, const PRODUCER: u8, const CONSUMER: u8> {
    // Backing storage for queued tasks.
    tasks: SharedRegion<T, PRODUCER>,
    // Index of the slot the next `push` writes to.
    head: usize,
    // Index of the slot the next `pop` reads from.
    tail: usize,
}
87
/// Error returned by `WorkQueue::push` when every usable slot is occupied.
///
/// The queue is left unchanged when this error is returned.
// NOTE(review): consider also implementing `core::error::Error` once the
// crate's minimum Rust version is known to be >= 1.81.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueFull;

impl core::fmt::Display for QueueFull {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("work queue is full")
    }
}
90
91impl<T: GpuValue + Default, const PRODUCER: u8, const CONSUMER: u8>
92 WorkQueue<T, PRODUCER, CONSUMER>
93{
94 pub fn new(producer_role: Role, _consumer_role: Role) -> Self {
95 WorkQueue {
96 tasks: SharedRegion::new(producer_role),
97 head: 0,
98 tail: 0,
99 }
100 }
101
102 pub fn push(&mut self, task: T) -> Result<(), QueueFull> {
103 let next = (self.head + 1) % 32;
104 if next == self.tail {
105 return Err(QueueFull);
106 }
107 self.tasks.write(self.head, task);
108 self.head = next;
109 Ok(())
110 }
111
112 pub fn pop(&mut self) -> Option<T> {
113 if self.tail == self.head {
114 return None;
115 }
116 let task = self.tasks.read(self.tail);
117 self.tail = (self.tail + 1) % 32;
118 Some(task)
119 }
120
121 pub fn is_empty(&self) -> bool {
122 self.tail == self.head
123 }
124 pub fn is_full(&self) -> bool {
125 (self.head + 1) % 32 == self.tail
126 }
127}
128
/// Newtype identifier for a block (presumably a GPU thread block — confirm),
/// wrapping its raw `u32` index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BlockId(pub u32);
135
/// Identifies a single thread by its block, warp, and lane coordinates.
///
/// Derives `Eq` and `Hash`, so it can be used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ThreadId {
    // Block containing the thread.
    pub block: BlockId,
    // Warp within the block.
    pub warp: crate::data::WarpId,
    // Lane within the warp.
    pub lane: crate::data::LaneId,
}
142
/// Type-level role a block plays in a [`BlockSession`].
pub trait BlockRole {
    /// Human-readable label for the role.
    const NAME: &'static str;
}

/// Role marker labelled "Leader".
pub struct Leader;
impl BlockRole for Leader {
    const NAME: &'static str = "Leader";
}

/// Role marker labelled "Worker".
pub struct Worker;
impl BlockRole for Worker {
    const NAME: &'static str = "Worker";
}
160
/// Marker trait for the protocol-state type parameter of [`BlockSession`].
///
/// NOTE(review): no methods in this module move a session between states.
pub trait ProtocolState {}

/// Protocol-state marker: initial state.
pub struct Initial;
impl ProtocolState for Initial {}

/// Protocol-state marker: work has been distributed.
pub struct WorkDistributed;
impl ProtocolState for WorkDistributed {}

/// Protocol-state marker: work is complete.
pub struct WorkComplete;
impl ProtocolState for WorkComplete {}
171
/// Per-block session whose role `R` and protocol state `S` exist only at the
/// type level (both are stored as `PhantomData`).
///
/// NOTE(review): the const parameter `N` is not read by any code in this
/// module; its meaning (perhaps a participant count) is unconfirmed.
pub struct BlockSession<R: BlockRole, S: ProtocolState, const N: usize> {
    // Block this session is bound to; exposed via `block_id()`.
    block_id: BlockId,
    // Zero-sized role marker.
    _role: PhantomData<R>,
    // Zero-sized protocol-state marker.
    _state: PhantomData<S>,
}
178
179impl<R: BlockRole, S: ProtocolState, const N: usize> BlockSession<R, S, N> {
180 #[allow(dead_code)] pub(crate) fn new(block_id: BlockId) -> Self {
182 BlockSession {
183 block_id,
184 _role: PhantomData,
185 _state: PhantomData,
186 }
187 }
188
189 pub fn block_id(&self) -> BlockId {
190 self.block_id
191 }
192}
193
/// Phase marker: the warp-level step is next.
pub struct WarpPhase;
/// Phase marker: the block-level step is next.
pub struct BlockPhase;
/// Phase marker: the grid-level step is next.
pub struct GridPhase;
/// Phase marker: all steps have been taken; only `result` remains.
pub struct Complete;

/// A `u32` value carried through the warp → block → grid step sequence.
///
/// Each phase exposes exactly one transition, which consumes the session, so
/// the steps can only be taken in order. Every step returns the carried value
/// unchanged.
/// NOTE(review): no cross-lane combining happens here; confirm whether this is
/// an intentional single-value stand-in.
pub struct ReductionSession<Phase> {
    value: u32,
    _phase: PhantomData<Phase>,
}

impl<Phase> ReductionSession<Phase> {
    /// Re-tags the carried value with phase `Next` and also returns the value
    /// as the result of the step just taken.
    fn advance<Next>(self) -> (u32, ReductionSession<Next>) {
        let carried = self.value;
        let next = ReductionSession {
            value: carried,
            _phase: PhantomData,
        };
        (carried, next)
    }
}

impl ReductionSession<WarpPhase> {
    /// Starts a session carrying `value`.
    // The only caller visible in this module is the test suite, so the
    // `dead_code` lint is silenced for non-test builds.
    #[allow(dead_code)]
    pub(crate) fn new(value: u32) -> Self {
        Self {
            value,
            _phase: PhantomData,
        }
    }

    /// Takes the warp-level step, returning its value and the block-phase session.
    pub fn warp_reduce(self) -> (u32, ReductionSession<BlockPhase>) {
        self.advance()
    }
}

impl ReductionSession<BlockPhase> {
    /// Takes the block-level step, returning its value and the grid-phase session.
    pub fn block_reduce(self) -> (u32, ReductionSession<GridPhase>) {
        self.advance()
    }
}

impl ReductionSession<GridPhase> {
    /// Takes the grid-level step, returning its value and the completed session.
    pub fn grid_reduce(self) -> (u32, ReductionSession<Complete>) {
        self.advance()
    }
}

impl ReductionSession<Complete> {
    /// Consumes the finished session and yields the carried value.
    pub fn result(self) -> u32 {
        let Self { value, .. } = self;
        value
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;

    // Role ids used as the `OWNER` / `PRODUCER` / `CONSUMER` const parameters.
    const COORDINATOR: u8 = 0;
    const WORKER_ROLE: u8 = 1;

    #[test]
    fn test_shared_region_ownership() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let mut region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        region.write(0, 42);
        assert_eq!(region.read(0), 42);
        let view = region.grant_read();
        assert_eq!(view.read(0), 42);
    }

    /// Slot 32 is one past the end of the 32-element backing array.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_shared_region_out_of_bounds_panics() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        region.read(32);
    }

    #[test]
    fn test_work_queue() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        let mut queue: WorkQueue<i32, COORDINATOR, WORKER_ROLE> =
            WorkQueue::new(coordinator, worker);

        assert!(queue.is_empty());
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_work_queue_full() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        let mut queue: WorkQueue<i32, COORDINATOR, WORKER_ROLE> =
            WorkQueue::new(coordinator, worker);

        // 31 pushes fill the queue: one of the 32 slots is always kept free.
        for i in 0..31 {
            assert!(queue.push(i).is_ok());
        }
        assert!(queue.is_full());
        assert!(queue.push(31).is_err());
    }

    /// Exercises the `% 32` index wrap-around for both `head` and `tail`,
    /// including filling and draining a queue whose indices have wrapped.
    #[test]
    fn test_work_queue_wraparound() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        let mut queue: WorkQueue<i32, COORDINATOR, WORKER_ROLE> =
            WorkQueue::new(coordinator, worker);

        // Advance head and tail together well past the end of the ring.
        for i in 0..100 {
            queue.push(i).unwrap();
            assert_eq!(queue.pop(), Some(i));
        }
        assert!(queue.is_empty());

        // Fill and drain again starting from a non-zero, wrapped position.
        for i in 0..31 {
            assert!(queue.push(i).is_ok());
        }
        assert!(queue.is_full());
        for i in 0..31 {
            assert_eq!(queue.pop(), Some(i));
        }
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_hierarchical_reduction() {
        let session = ReductionSession::<WarpPhase>::new(42);
        let (warp_result, session) = session.warp_reduce();
        assert_eq!(warp_result, 42);
        let (block_result, session) = session.block_reduce();
        assert_eq!(block_result, 42);
        let (grid_result, session) = session.grid_reduce();
        assert_eq!(grid_result, 42);
        assert_eq!(session.result(), 42);
    }

    #[test]
    fn test_block_session() {
        let leader: BlockSession<Leader, Initial, 4> = BlockSession::new(BlockId(0));
        assert_eq!(leader.block_id().0, 0);
        let worker: BlockSession<Worker, Initial, 4> = BlockSession::new(BlockId(1));
        assert_eq!(worker.block_id().0, 1);
    }
}
327}