1use crate::data::Role;
11use crate::GpuValue;
12use core::marker::PhantomData;
13
14pub struct SharedRegion<T: GpuValue, const OWNER: u8> {
33 data: [T; crate::WARP_SIZE as usize],
34 owner: Role,
35 _phantom: PhantomData<()>,
36}
37
38impl<T: GpuValue + Default, const OWNER: u8> SharedRegion<T, OWNER> {
39 pub fn new(owner: Role) -> Self {
40 SharedRegion {
41 data: [T::default(); crate::WARP_SIZE as usize],
42 owner,
43 _phantom: PhantomData,
44 }
45 }
46}
47
48impl<T: GpuValue, const OWNER: u8> SharedRegion<T, OWNER> {
49 pub fn write(&mut self, index: usize, value: T) {
50 assert!(index < crate::WARP_SIZE as usize, "Index out of bounds");
51 self.data[index] = value;
52 }
53
54 pub fn read(&self, index: usize) -> T {
55 assert!(index < crate::WARP_SIZE as usize, "Index out of bounds");
56 self.data[index]
57 }
58
59 pub fn grant_read(&self) -> SharedView<'_, T, OWNER> {
60 SharedView {
61 region: self,
62 _phantom: PhantomData,
63 }
64 }
65
66 pub fn owner(&self) -> Role {
67 self.owner
68 }
69}
70
71pub struct SharedView<'a, T: GpuValue, const OWNER: u8> {
73 region: &'a SharedRegion<T, OWNER>,
74 _phantom: PhantomData<()>,
75}
76
77impl<'a, T: GpuValue, const OWNER: u8> SharedView<'a, T, OWNER> {
78 pub fn read(&self, index: usize) -> T {
79 self.region.read(index)
80 }
81}
82
83pub struct WorkQueue<T: GpuValue, const PRODUCER: u8, const CONSUMER: u8> {
88 tasks: SharedRegion<T, PRODUCER>,
89 head: usize,
90 tail: usize,
91}
92
/// Error returned by `WorkQueue::push` when every usable slot is occupied.
///
/// The rejected task is not returned to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueFull;

// A public error type should be printable for users, not only debug-formattable.
// `core::fmt` keeps this usable without `std`.
impl core::fmt::Display for QueueFull {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("work queue is full")
    }
}
95
96impl<T: GpuValue + Default, const PRODUCER: u8, const CONSUMER: u8>
97 WorkQueue<T, PRODUCER, CONSUMER>
98{
99 pub fn new(producer_role: Role, _consumer_role: Role) -> Self {
100 WorkQueue {
101 tasks: SharedRegion::new(producer_role),
102 head: 0,
103 tail: 0,
104 }
105 }
106
107 pub fn push(&mut self, task: T) -> Result<(), QueueFull> {
108 let next = (self.head + 1) % crate::WARP_SIZE as usize;
109 if next == self.tail {
110 return Err(QueueFull);
111 }
112 self.tasks.write(self.head, task);
113 self.head = next;
114 Ok(())
115 }
116
117 pub fn pop(&mut self) -> Option<T> {
118 if self.tail == self.head {
119 return None;
120 }
121 let task = self.tasks.read(self.tail);
122 self.tail = (self.tail + 1) % crate::WARP_SIZE as usize;
123 Some(task)
124 }
125
126 pub fn is_empty(&self) -> bool {
127 self.tail == self.head
128 }
129 pub fn is_full(&self) -> bool {
130 (self.head + 1) % crate::WARP_SIZE as usize == self.tail
131 }
132}
133
/// Opaque identifier of a block, wrapping a raw `u32`.
///
/// `#[repr(transparent)]` gives it the same layout as `u32`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct BlockId(u32);

impl BlockId {
    /// Wraps a raw block index.
    pub const fn new(id: u32) -> Self {
        Self(id)
    }

    /// Returns the raw block index.
    pub const fn get(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
151
152#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
153#[repr(C)]
154pub struct ThreadId {
155 block: BlockId,
156 warp: crate::data::WarpId,
157 lane: crate::data::LaneId,
158}
159
160impl ThreadId {
161 pub const fn new(block: BlockId, warp: crate::data::WarpId, lane: crate::data::LaneId) -> Self {
162 ThreadId { block, warp, lane }
163 }
164
165 pub const fn block(self) -> BlockId {
166 self.block
167 }
168
169 pub const fn warp(self) -> crate::data::WarpId {
170 self.warp
171 }
172
173 pub const fn lane(self) -> crate::data::LaneId {
174 self.lane
175 }
176}
177
/// A type-level role that a block plays in a `BlockSession`.
pub trait BlockRole {
    /// Human-readable name of the role.
    const NAME: &'static str;
}

/// Role marker for the leading block.
pub struct Leader;

/// Role marker for a worker block.
pub struct Worker;

impl BlockRole for Leader {
    const NAME: &'static str = "Leader";
}

impl BlockRole for Worker {
    const NAME: &'static str = "Worker";
}
195
/// Marker trait for the type-level protocol states of a `BlockSession`.
pub trait ProtocolState {}

/// State before any work has been handed out.
pub struct Initial;

/// State after work has been distributed.
pub struct WorkDistributed;

/// State after work has finished.
pub struct WorkComplete;

impl ProtocolState for Initial {}
impl ProtocolState for WorkDistributed {}
impl ProtocolState for WorkComplete {}
206
207pub struct BlockSession<R: BlockRole, S: ProtocolState, const N: usize> {
209 block_id: BlockId,
210 _role: PhantomData<R>,
211 _state: PhantomData<S>,
212}
213
214impl<R: BlockRole, S: ProtocolState, const N: usize> BlockSession<R, S, N> {
215 #[allow(dead_code)] pub(crate) fn new(block_id: BlockId) -> Self {
217 BlockSession {
218 block_id,
219 _role: PhantomData,
220 _state: PhantomData,
221 }
222 }
223
224 pub fn block_id(&self) -> BlockId {
225 self.block_id
226 }
227}
228
/// Phase marker: the session is ready for its warp-level step.
pub struct WarpPhase;
/// Phase marker: the session is ready for its block-level step.
pub struct BlockPhase;
/// Phase marker: the session is ready for its grid-level step.
pub struct GridPhase;
/// Phase marker: every step has run and the result can be taken.
pub struct Complete;

/// A `u32` carried through the fixed phase sequence
/// `WarpPhase` → `BlockPhase` → `GridPhase` → `Complete`.
///
/// Every transition consumes `self`, and the type is neither `Clone` nor
/// `Copy`, so a given session value can be advanced only once.
/// NOTE(review): in this file each step returns the stored value unchanged;
/// no arithmetic reduction is performed. Confirm whether that is intended.
#[must_use = "ReductionSession is a linear state machine — dropping abandons the reduction"]
pub struct ReductionSession<Phase> {
    value: u32,
    _phase: PhantomData<Phase>,
}

impl<Phase> ReductionSession<Phase> {
    /// Shared transition: yields the carried value together with the same
    /// value re-tagged as phase `Next`.
    fn advance<Next>(self) -> (u32, ReductionSession<Next>) {
        let carried = self.value;
        let next = ReductionSession {
            value: carried,
            _phase: PhantomData,
        };
        (carried, next)
    }
}

impl ReductionSession<WarpPhase> {
    /// Starts a session holding `value`.
    // In this file the only caller is the test module, hence the allow.
    #[allow(dead_code)]
    pub(crate) fn new(value: u32) -> Self {
        Self {
            _phase: PhantomData,
            value,
        }
    }

    /// Warp step: returns the current value and the block-phase session.
    pub fn warp_reduce(self) -> (u32, ReductionSession<BlockPhase>) {
        self.advance()
    }
}

impl ReductionSession<BlockPhase> {
    /// Block step: returns the current value and the grid-phase session.
    pub fn block_reduce(self) -> (u32, ReductionSession<GridPhase>) {
        self.advance()
    }
}

impl ReductionSession<GridPhase> {
    /// Grid step: returns the current value and the completed session.
    pub fn grid_reduce(self) -> (u32, ReductionSession<Complete>) {
        self.advance()
    }
}

impl ReductionSession<Complete> {
    /// Consumes the completed session and returns its value.
    pub fn result(self) -> u32 {
        let ReductionSession { value, .. } = self;
        value
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;

    const COORDINATOR: u8 = 0;
    const WORKER_ROLE: u8 = 1;

    /// Builds an empty coordinator → worker queue shared by the queue tests.
    fn new_queue() -> WorkQueue<i32, COORDINATOR, WORKER_ROLE> {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        WorkQueue::new(coordinator, worker)
    }

    #[test]
    fn test_shared_region_ownership() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let mut region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        region.write(0, 42);
        assert_eq!(region.read(0), 42);
        let view = region.grant_read();
        assert_eq!(view.read(0), 42);
    }

    #[test]
    fn test_work_queue() {
        let mut queue = new_queue();

        assert!(queue.is_empty());
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_work_queue_full() {
        let mut queue = new_queue();

        // One slot is reserved to distinguish full from empty, so capacity is WARP_SIZE - 1.
        for i in 0..(crate::WARP_SIZE as i32 - 1) {
            assert!(queue.push(i).is_ok());
        }
        assert!(queue.is_full());
        assert!(queue.push(crate::WARP_SIZE as i32).is_err());

        // Freeing one slot must make the queue accept a task again.
        assert_eq!(queue.pop(), Some(0));
        assert!(!queue.is_full());
        assert!(queue.push(crate::WARP_SIZE as i32).is_ok());
    }

    #[test]
    fn test_work_queue_wraparound() {
        let mut queue = new_queue();
        let capacity = crate::WARP_SIZE as i32 - 1;

        // Several fill/drain rounds force head and tail past the end of the buffer.
        for round in 0..3 {
            for i in 0..capacity {
                queue.push(round * 100 + i).unwrap();
            }
            assert!(queue.is_full());
            for i in 0..capacity {
                assert_eq!(queue.pop(), Some(round * 100 + i));
            }
            assert!(queue.is_empty());
            assert_eq!(queue.pop(), None);
        }
    }

    #[test]
    fn test_hierarchical_reduction() {
        let session = ReductionSession::<WarpPhase>::new(42);
        let (warp_result, session) = session.warp_reduce();
        assert_eq!(warp_result, 42);
        let (block_result, session) = session.block_reduce();
        assert_eq!(block_result, 42);
        let (grid_result, session) = session.grid_reduce();
        assert_eq!(grid_result, 42);
        assert_eq!(session.result(), 42);
    }

    #[test]
    fn test_block_session() {
        // Use the public accessor rather than the private tuple field.
        let leader: BlockSession<Leader, Initial, 4> = BlockSession::new(BlockId::new(0));
        assert_eq!(leader.block_id().get(), 0);
        let worker: BlockSession<Worker, Initial, 4> = BlockSession::new(BlockId::new(1));
        assert_eq!(worker.block_id().get(), 1);
    }
}