// warp_types/block.rs
//! Block-level types: shared memory ownership and inter-block sessions.
//!
//! GPU parallelism has three levels:
//! - **Warp** (32/64 lanes): shuffles, lockstep, linear typestate
//! - **Block** (multiple warps): shared memory, `__syncthreads()`
//! - **Grid** (multiple blocks): global memory, cooperative groups
//!
//! This module provides typed abstractions for the block and grid levels.

use crate::data::Role;
use crate::GpuValue;
use core::marker::PhantomData;

// ============================================================================
// Shared memory with ownership semantics
// ============================================================================

/// A region of shared memory owned by a specific role.
///
/// The key insight: shared memory races happen because ownership is implicit.
/// By making ownership explicit, we prevent races at the type level.
///
/// `OWNER` is a type-level tag (u8 discriminator) that prevents cross-type access
/// at compile time. The `owner` field carries the runtime lane range metadata
/// (which lanes belong to this role). These encode different concerns: OWNER
/// prevents mixing regions at the type level; Role describes the lane geometry.
pub struct SharedRegion<T: GpuValue, const OWNER: u8> {
    // Fixed 32-slot backing store — presumably one slot per lane of a
    // 32-wide warp (TODO confirm; the module docs also mention 64-lane warps).
    data: [T; 32],
    // Runtime lane-range metadata for the owning role.
    owner: Role,
    // NOTE(review): PhantomData<()> carries no type information; it looks like
    // a placeholder rather than a variance/ownership marker — confirm intent.
    _phantom: PhantomData<()>,
}

33impl<T: GpuValue + Default, const OWNER: u8> SharedRegion<T, OWNER> {
34    pub fn new(owner: Role) -> Self {
35        SharedRegion {
36            data: [T::default(); 32],
37            owner,
38            _phantom: PhantomData,
39        }
40    }
41}
42
43impl<T: GpuValue, const OWNER: u8> SharedRegion<T, OWNER> {
44    pub fn write(&mut self, index: usize, value: T) {
45        assert!(index < 32, "Index out of bounds");
46        self.data[index] = value;
47    }
48
49    pub fn read(&self, index: usize) -> T {
50        assert!(index < 32, "Index out of bounds");
51        self.data[index]
52    }
53
54    pub fn grant_read(&self) -> SharedView<'_, T, OWNER> {
55        SharedView {
56            region: self,
57            _phantom: PhantomData,
58        }
59    }
60
61    pub fn owner(&self) -> Role {
62        self.owner
63    }
64}
65
/// A read-only view of a shared region (for non-owning roles).
///
/// Holding `&'a SharedRegion` means the owner cannot take `&mut self`
/// (and therefore cannot `write`) while any view is alive.
pub struct SharedView<'a, T: GpuValue, const OWNER: u8> {
    // Shared borrow of the owning region; grants read access only.
    region: &'a SharedRegion<T, OWNER>,
    // NOTE(review): PhantomData<()> adds no variance/ownership info here —
    // likely a placeholder; confirm intent.
    _phantom: PhantomData<()>,
}

72impl<'a, T: GpuValue, const OWNER: u8> SharedView<'a, T, OWNER> {
73    pub fn read(&self, index: usize) -> T {
74        self.region.read(index)
75    }
76}
77
/// A work queue in shared memory with typed producer/consumer roles.
///
/// Uses a circular buffer with 32 slots and one sentinel for full detection,
/// giving an effective capacity of 31 items.
pub struct WorkQueue<T: GpuValue, const PRODUCER: u8, const CONSUMER: u8> {
    // Backing ring storage, tagged with PRODUCER (the writing side).
    // NOTE(review): CONSUMER is a type-level discriminator only — it appears
    // in no field type in this chunk.
    tasks: SharedRegion<T, PRODUCER>,
    // Next slot to write (push side).
    head: usize,
    // Next slot to read (pop side); head == tail means empty.
    tail: usize,
}

/// Error returned by [`WorkQueue::push`] when the ring buffer is full.
#[derive(Debug, Clone, Copy)]
pub struct QueueFull;

91impl<T: GpuValue + Default, const PRODUCER: u8, const CONSUMER: u8>
92    WorkQueue<T, PRODUCER, CONSUMER>
93{
94    pub fn new(producer_role: Role, _consumer_role: Role) -> Self {
95        WorkQueue {
96            tasks: SharedRegion::new(producer_role),
97            head: 0,
98            tail: 0,
99        }
100    }
101
102    pub fn push(&mut self, task: T) -> Result<(), QueueFull> {
103        let next = (self.head + 1) % 32;
104        if next == self.tail {
105            return Err(QueueFull);
106        }
107        self.tasks.write(self.head, task);
108        self.head = next;
109        Ok(())
110    }
111
112    pub fn pop(&mut self) -> Option<T> {
113        if self.tail == self.head {
114            return None;
115        }
116        let task = self.tasks.read(self.tail);
117        self.tail = (self.tail + 1) % 32;
118        Some(task)
119    }
120
121    pub fn is_empty(&self) -> bool {
122        self.tail == self.head
123    }
124    pub fn is_full(&self) -> bool {
125        (self.head + 1) % 32 == self.tail
126    }
127}
128
// ============================================================================
// GPU hierarchy types
// ============================================================================

/// Identifier of a block within the grid.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BlockId(pub u32);

/// Fully-qualified thread coordinates: block, warp within the block,
/// lane within the warp.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ThreadId {
    /// The block this thread belongs to.
    pub block: BlockId,
    /// Warp within the block.
    pub warp: crate::data::WarpId,
    /// Lane within the warp.
    pub lane: crate::data::LaneId,
}

// ============================================================================
// Inter-block protocol types
// ============================================================================

/// Marker trait for the roles a block can play in an inter-block protocol.
pub trait BlockRole {
    /// Human-readable role name (for diagnostics).
    const NAME: &'static str;
}

/// The coordinating block in a protocol.
pub struct Leader;
impl BlockRole for Leader {
    const NAME: &'static str = "Leader";
}

/// A participating, non-coordinating block.
pub struct Worker;
impl BlockRole for Worker {
    const NAME: &'static str = "Worker";
}

/// Marker trait for the states of the inter-block protocol state machine.
pub trait ProtocolState {}

/// Starting state, before any work has been distributed.
pub struct Initial;
impl ProtocolState for Initial {}

/// State after work has been handed out to worker blocks.
pub struct WorkDistributed;
impl ProtocolState for WorkDistributed {}

/// State after all distributed work has finished.
pub struct WorkComplete;
impl ProtocolState for WorkComplete {}

/// A session between blocks, parameterized by role, state, and block count.
///
/// `N` is the number of participating blocks; in this chunk it is type-level
/// only and never read at runtime.
pub struct BlockSession<R: BlockRole, S: ProtocolState, const N: usize> {
    // Runtime identity of this block within the session.
    block_id: BlockId,
    // Zero-sized markers pinning role and protocol state into the type.
    _role: PhantomData<R>,
    _state: PhantomData<S>,
}

179impl<R: BlockRole, S: ProtocolState, const N: usize> BlockSession<R, S, N> {
180    #[allow(dead_code)] // Constructor for future block-level API usage
181    pub(crate) fn new(block_id: BlockId) -> Self {
182        BlockSession {
183            block_id,
184            _role: PhantomData,
185            _state: PhantomData,
186        }
187    }
188
189    pub fn block_id(&self) -> BlockId {
190        self.block_id
191    }
192}
193
// ============================================================================
// Hierarchical reduction (type-state machine)
// ============================================================================

/// Phase markers for the warp -> block -> grid reduction pipeline.
pub struct WarpPhase;
pub struct BlockPhase;
pub struct GridPhase;
pub struct Complete;

/// A reduction value carried through the hierarchy as a type-state machine.
/// Each `*_reduce` consumes the session and yields the next phase, so the
/// phases can only run in order and each at most once.
pub struct ReductionSession<Phase> {
    value: u32,
    _phase: PhantomData<Phase>,
}

impl<Phase> ReductionSession<Phase> {
    /// Re-wraps the current value in the next phase's type. Private: phase
    /// ordering is enforced by the public methods below, not by this helper.
    fn step<Next>(self) -> (u32, ReductionSession<Next>) {
        let carried = self.value;
        (
            carried,
            ReductionSession {
                value: carried,
                _phase: PhantomData,
            },
        )
    }
}

impl ReductionSession<WarpPhase> {
    /// Starts a reduction at the warp phase with `value`.
    #[allow(dead_code)] // Constructor for future reduction pipeline usage
    pub(crate) fn new(value: u32) -> Self {
        ReductionSession {
            value,
            _phase: PhantomData,
        }
    }

    /// Finishes the warp phase, yielding the warp result and a block-phase session.
    pub fn warp_reduce(self) -> (u32, ReductionSession<BlockPhase>) {
        self.step()
    }
}

impl ReductionSession<BlockPhase> {
    /// Finishes the block phase, yielding the block result and a grid-phase session.
    pub fn block_reduce(self) -> (u32, ReductionSession<GridPhase>) {
        self.step()
    }
}

impl ReductionSession<GridPhase> {
    /// Finishes the grid phase, yielding the grid result and a completed session.
    pub fn grid_reduce(self) -> (u32, ReductionSession<Complete>) {
        self.step()
    }
}

impl ReductionSession<Complete> {
    /// Consumes the completed session, returning the reduced value.
    pub fn result(self) -> u32 {
        self.value
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Type-level owner tags used to instantiate SharedRegion/WorkQueue.
    const COORDINATOR: u8 = 0;
    const WORKER_ROLE: u8 = 1;

    // Owner can write and read; a granted read-only view sees the same data.
    #[test]
    fn test_shared_region_ownership() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let mut region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        region.write(0, 42);
        assert_eq!(region.read(0), 42);
        let view = region.grant_read();
        assert_eq!(view.read(0), 42);
    }

    // FIFO order is preserved, and an emptied queue yields None.
    #[test]
    fn test_work_queue() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        let mut queue: WorkQueue<i32, COORDINATOR, WORKER_ROLE> =
            WorkQueue::new(coordinator, worker);

        assert!(queue.is_empty());
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    // The 32-slot ring holds exactly 31 items; the 32nd push fails.
    #[test]
    fn test_work_queue_full() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        let mut queue: WorkQueue<i32, COORDINATOR, WORKER_ROLE> =
            WorkQueue::new(coordinator, worker);

        // Ring buffer of size 32 has capacity 31 (one slot reserved for full detection)
        for i in 0..31 {
            assert!(queue.push(i).is_ok());
        }
        assert!(queue.is_full());
        assert!(queue.push(31).is_err());
    }

    // The typestate chain passes the value unchanged through every phase.
    #[test]
    fn test_hierarchical_reduction() {
        let session = ReductionSession::<WarpPhase>::new(42);
        let (warp_result, session) = session.warp_reduce();
        assert_eq!(warp_result, 42);
        let (block_result, session) = session.block_reduce();
        assert_eq!(block_result, 42);
        let (grid_result, session) = session.grid_reduce();
        assert_eq!(grid_result, 42);
        assert_eq!(session.result(), 42);
    }

    // Sessions of different roles carry their block ids independently.
    #[test]
    fn test_block_session() {
        let leader: BlockSession<Leader, Initial, 4> = BlockSession::new(BlockId(0));
        assert_eq!(leader.block_id().0, 0);
        let worker: BlockSession<Worker, Initial, 4> = BlockSession::new(BlockId(1));
        assert_eq!(worker.block_id().0, 1);
    }
}