Skip to main content

warp_types/
block.rs

1//! Block-level types: shared memory ownership and inter-block sessions.
2//!
3//! GPU parallelism has three levels:
4//! - **Warp** (32/64 lanes): shuffles, lockstep, linear typestate
5//! - **Block** (multiple warps): shared memory, `__syncthreads()`
6//! - **Grid** (multiple blocks): global memory, cooperative groups
7//!
8//! This module provides typed abstractions for the block and grid levels.
9
10use crate::data::Role;
11use crate::GpuValue;
12use core::marker::PhantomData;
13
14// ============================================================================
15// Shared memory with ownership semantics
16// ============================================================================
17
18/// A region of shared memory owned by a specific role.
19///
20/// The key insight: shared memory races happen because ownership is implicit.
21/// By making ownership explicit, we prevent races at the type level.
22///
23/// `OWNER` is a type-level tag (u8 discriminator) that prevents cross-type access
24/// at compile time. The `owner` field carries the runtime lane range metadata
25/// (which lanes belong to this role). These encode different concerns: OWNER
26/// prevents mixing regions at the type level; Role describes the lane geometry.
27///
28/// Ownership is enforced at compile time only (const generic tag). The `owner`
29/// field is metadata for debugging — `write`/`read` do not verify the caller's
30/// role at runtime. In a real GPU implementation, kernel launch guarantees
31/// would replace runtime checks.
32pub struct SharedRegion<T: GpuValue, const OWNER: u8> {
33    data: [T; crate::WARP_SIZE as usize],
34    owner: Role,
35    _phantom: PhantomData<()>,
36}
37
38impl<T: GpuValue + Default, const OWNER: u8> SharedRegion<T, OWNER> {
39    pub fn new(owner: Role) -> Self {
40        SharedRegion {
41            data: [T::default(); crate::WARP_SIZE as usize],
42            owner,
43            _phantom: PhantomData,
44        }
45    }
46}
47
48impl<T: GpuValue, const OWNER: u8> SharedRegion<T, OWNER> {
49    pub fn write(&mut self, index: usize, value: T) {
50        assert!(index < crate::WARP_SIZE as usize, "Index out of bounds");
51        self.data[index] = value;
52    }
53
54    pub fn read(&self, index: usize) -> T {
55        assert!(index < crate::WARP_SIZE as usize, "Index out of bounds");
56        self.data[index]
57    }
58
59    pub fn grant_read(&self) -> SharedView<'_, T, OWNER> {
60        SharedView {
61            region: self,
62            _phantom: PhantomData,
63        }
64    }
65
66    pub fn owner(&self) -> Role {
67        self.owner
68    }
69}
70
71/// A read-only view of a shared region (for non-owning roles).
72pub struct SharedView<'a, T: GpuValue, const OWNER: u8> {
73    region: &'a SharedRegion<T, OWNER>,
74    _phantom: PhantomData<()>,
75}
76
77impl<'a, T: GpuValue, const OWNER: u8> SharedView<'a, T, OWNER> {
78    pub fn read(&self, index: usize) -> T {
79        self.region.read(index)
80    }
81}
82
83/// A work queue in shared memory with typed producer/consumer roles.
84///
85/// Uses a circular buffer with `WARP_SIZE` slots and one sentinel for full
86/// detection, giving an effective capacity of `WARP_SIZE - 1` items.
87pub struct WorkQueue<T: GpuValue, const PRODUCER: u8, const CONSUMER: u8> {
88    tasks: SharedRegion<T, PRODUCER>,
89    head: usize,
90    tail: usize,
91}
92
93#[derive(Debug, Clone, Copy)]
94pub struct QueueFull;
95
96impl<T: GpuValue + Default, const PRODUCER: u8, const CONSUMER: u8>
97    WorkQueue<T, PRODUCER, CONSUMER>
98{
99    pub fn new(producer_role: Role, _consumer_role: Role) -> Self {
100        WorkQueue {
101            tasks: SharedRegion::new(producer_role),
102            head: 0,
103            tail: 0,
104        }
105    }
106
107    pub fn push(&mut self, task: T) -> Result<(), QueueFull> {
108        let next = (self.head + 1) % crate::WARP_SIZE as usize;
109        if next == self.tail {
110            return Err(QueueFull);
111        }
112        self.tasks.write(self.head, task);
113        self.head = next;
114        Ok(())
115    }
116
117    pub fn pop(&mut self) -> Option<T> {
118        if self.tail == self.head {
119            return None;
120        }
121        let task = self.tasks.read(self.tail);
122        self.tail = (self.tail + 1) % crate::WARP_SIZE as usize;
123        Some(task)
124    }
125
126    pub fn is_empty(&self) -> bool {
127        self.tail == self.head
128    }
129    pub fn is_full(&self) -> bool {
130        (self.head + 1) % crate::WARP_SIZE as usize == self.tail
131    }
132}
133
134// ============================================================================
135// GPU hierarchy types
136// ============================================================================
137
/// Identifier of a thread block within the grid.
///
/// A transparent newtype over the raw `u32` block index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct BlockId(u32);

impl BlockId {
    /// Wraps a raw block index.
    pub const fn new(id: u32) -> Self {
        Self(id)
    }

    /// Unwraps the raw block index.
    pub const fn get(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
151
152#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
153#[repr(C)]
154pub struct ThreadId {
155    block: BlockId,
156    warp: crate::data::WarpId,
157    lane: crate::data::LaneId,
158}
159
160impl ThreadId {
161    pub const fn new(block: BlockId, warp: crate::data::WarpId, lane: crate::data::LaneId) -> Self {
162        ThreadId { block, warp, lane }
163    }
164
165    pub const fn block(self) -> BlockId {
166        self.block
167    }
168
169    pub const fn warp(self) -> crate::data::WarpId {
170        self.warp
171    }
172
173    pub const fn lane(self) -> crate::data::LaneId {
174        self.lane
175    }
176}
177
178// ============================================================================
179// Inter-block protocol types
180// ============================================================================
181
/// A role a thread block can play in an inter-block protocol.
pub trait BlockRole {
    /// Human-readable name of the role.
    const NAME: &'static str;
}

/// Block role named `"Leader"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Leader;
impl BlockRole for Leader {
    const NAME: &'static str = "Leader";
}

/// Block role named `"Worker"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Worker;
impl BlockRole for Worker {
    const NAME: &'static str = "Worker";
}

/// Marker trait for the type-level protocol state of a block session.
pub trait ProtocolState {}

/// Protocol state: initial.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Initial;
impl ProtocolState for Initial {}

/// Protocol state: work distributed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WorkDistributed;
impl ProtocolState for WorkDistributed {}

/// Protocol state: work complete.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WorkComplete;
impl ProtocolState for WorkComplete {}
206
207/// A session between blocks, parameterized by role, state, and block count.
208pub struct BlockSession<R: BlockRole, S: ProtocolState, const N: usize> {
209    block_id: BlockId,
210    _role: PhantomData<R>,
211    _state: PhantomData<S>,
212}
213
214impl<R: BlockRole, S: ProtocolState, const N: usize> BlockSession<R, S, N> {
215    #[allow(dead_code)] // Constructor for future block-level API usage
216    pub(crate) fn new(block_id: BlockId) -> Self {
217        BlockSession {
218            block_id,
219            _role: PhantomData,
220            _state: PhantomData,
221        }
222    }
223
224    pub fn block_id(&self) -> BlockId {
225        self.block_id
226    }
227}
228
229// ============================================================================
230// Hierarchical reduction (type-state machine)
231// ============================================================================
232
/// Typestate marker: the session is at its warp-level step.
pub struct WarpPhase;
/// Typestate marker: the session is at its block-level step.
pub struct BlockPhase;
/// Typestate marker: the session is at its grid-level step.
pub struct GridPhase;
/// Typestate marker: all steps have run; only `result` remains.
pub struct Complete;

/// A hierarchical reduction modeled as a linear typestate machine:
/// `WarpPhase -> BlockPhase -> GridPhase -> Complete`.
///
/// Each transition consumes the session, so a phase cannot be skipped or
/// repeated. In this model every step passes the carried value through
/// unchanged and returns it alongside the next-phase session.
#[must_use = "ReductionSession is a linear state machine — dropping abandons the reduction"]
pub struct ReductionSession<Phase> {
    value: u32,
    _phase: PhantomData<Phase>,
}

impl<Phase> ReductionSession<Phase> {
    /// Consumes this session and re-tags its value with the `Next` phase,
    /// also returning the value itself.
    fn transition_to<Next>(self) -> (u32, ReductionSession<Next>) {
        let ReductionSession { value, .. } = self;
        let next = ReductionSession {
            value,
            _phase: PhantomData,
        };
        (value, next)
    }
}

impl ReductionSession<WarpPhase> {
    /// Starts a reduction in the warp phase carrying `value`.
    #[allow(dead_code)] // Constructor for future reduction pipeline usage
    pub(crate) fn new(value: u32) -> Self {
        Self {
            value,
            _phase: PhantomData,
        }
    }

    /// Warp-level step: yields the value and the block-phase session.
    pub fn warp_reduce(self) -> (u32, ReductionSession<BlockPhase>) {
        self.transition_to::<BlockPhase>()
    }
}

impl ReductionSession<BlockPhase> {
    /// Block-level step: yields the value and the grid-phase session.
    pub fn block_reduce(self) -> (u32, ReductionSession<GridPhase>) {
        self.transition_to::<GridPhase>()
    }
}

impl ReductionSession<GridPhase> {
    /// Grid-level step: yields the value and the completed session.
    pub fn grid_reduce(self) -> (u32, ReductionSession<Complete>) {
        self.transition_to::<Complete>()
    }
}

impl ReductionSession<Complete> {
    /// Consumes the completed session and returns the final value.
    pub fn result(self) -> u32 {
        self.value
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;

    const COORDINATOR: u8 = 0;
    const WORKER_ROLE: u8 = 1;

    /// Builds the coordinator→worker queue shared by the queue tests.
    fn new_queue() -> WorkQueue<i32, COORDINATOR, WORKER_ROLE> {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");
        WorkQueue::new(coordinator, worker)
    }

    #[test]
    fn test_shared_region_ownership() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let mut region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        region.write(0, 42);
        assert_eq!(region.read(0), 42);
        let view = region.grant_read();
        assert_eq!(view.read(0), 42);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_shared_region_out_of_bounds() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let region: SharedRegion<i32, COORDINATOR> = SharedRegion::new(coordinator);
        let _ = region.read(crate::WARP_SIZE as usize);
    }

    #[test]
    fn test_work_queue() {
        let mut queue = new_queue();

        assert!(queue.is_empty());
        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();
        assert!(!queue.is_empty());
        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_work_queue_full() {
        let mut queue = new_queue();
        // Ring buffer of WARP_SIZE has capacity WARP_SIZE-1 (one slot reserved for full detection)
        let capacity = crate::WARP_SIZE as i32 - 1;

        for i in 0..capacity {
            assert!(queue.push(i).is_ok());
        }
        assert!(queue.is_full());
        assert!(queue.push(crate::WARP_SIZE as i32).is_err());

        // A rejected push must leave the queued items untouched.
        for i in 0..capacity {
            assert_eq!(queue.pop(), Some(i));
        }
        assert!(queue.is_empty());
    }

    #[test]
    fn test_work_queue_wraparound() {
        let mut queue = new_queue();
        let capacity = crate::WARP_SIZE as i32 - 1;

        // Fill and drain several times so head/tail wrap past the last slot.
        for round in 0..3 {
            let base = round * 1000;
            for i in 0..capacity {
                queue.push(base + i).unwrap();
            }
            assert!(queue.is_full());
            for i in 0..capacity {
                assert_eq!(queue.pop(), Some(base + i));
            }
            assert!(queue.is_empty());
        }

        // Interleaved single push/pop also walks the indices around the ring.
        for i in 0..(2 * crate::WARP_SIZE as i32) {
            queue.push(i).unwrap();
            assert_eq!(queue.pop(), Some(i));
        }
        assert!(queue.is_empty());
    }

    #[test]
    fn test_hierarchical_reduction() {
        let session = ReductionSession::<WarpPhase>::new(42);
        let (warp_result, session) = session.warp_reduce();
        assert_eq!(warp_result, 42);
        let (block_result, session) = session.block_reduce();
        assert_eq!(block_result, 42);
        let (grid_result, session) = session.grid_reduce();
        assert_eq!(grid_result, 42);
        assert_eq!(session.result(), 42);
    }

    #[test]
    fn test_block_session() {
        let leader: BlockSession<Leader, Initial, 4> = BlockSession::new(BlockId::new(0));
        assert_eq!(leader.block_id().get(), 0);
        assert_eq!(leader.block_id(), BlockId::new(0));
        let worker: BlockSession<Worker, Initial, 4> = BlockSession::new(BlockId::new(1));
        assert_eq!(worker.block_id().get(), 1);
    }
}