// warp_types/data.rs

//! GPU data types: uniform vs per-lane value distinction.
//!
//! The fundamental insight: GPU values are either uniform across all lanes
//! or vary per-lane. Making this distinction in the type system prevents
//! a large class of bugs (reading reduction results from wrong lanes,
//! passing divergent data where uniform is expected, etc.).
use crate::GpuValue;
/// A lane identifier within a warp/wavefront.
///
/// Valid ids are `0..=31` on NVIDIA (32-lane warps) and `0..=63` on AMD
/// (64-lane wavefronts); construction enforces the 64-lane upper bound.
///
/// Type-safe: you can't accidentally use an arbitrary int as a lane id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct LaneId(u8);

impl LaneId {
    /// Create a lane id, validating the 64-lane upper bound at
    /// compile time when called in const context.
    ///
    /// # Panics
    /// Panics if `id >= 64`.
    pub const fn new(id: u8) -> Self {
        assert!(
            id < 64,
            "Lane ID must be < 64 (supports NVIDIA 32-lane and AMD 64-lane)"
        );
        LaneId(id)
    }

    /// The raw lane number.
    pub const fn get(self) -> u8 {
        self.0
    }

    /// The lane number widened to `usize`, for indexing slices/arrays.
    pub const fn index(self) -> usize {
        self.0 as usize
    }
}
/// Identifies a warp inside its thread block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct WarpId(u16);

impl WarpId {
    /// Wrap a raw warp index.
    pub const fn new(id: u16) -> Self {
        Self(id)
    }

    /// The raw warp index.
    pub const fn get(self) -> u16 {
        self.0
    }
}
51/// A value guaranteed to be the same across all lanes in a warp.
52///
53/// You can only create `Uniform` values through operations that guarantee
54/// uniformity (broadcasts, constants, ballot results). This prevents the
55/// common bug of assuming a value is uniform when it isn't.
56#[must_use = "Uniform values represent warp-wide reduction results — dropping discards the result"]
57#[derive(Clone, Copy, Debug, PartialEq)]
58#[repr(transparent)]
59pub struct Uniform<T: GpuValue> {
60    value: T,
61}
62
63impl<T: GpuValue> Uniform<T> {
64    /// Create a uniform value from a compile-time constant.
65    pub const fn from_const(value: T) -> Self {
66        Uniform { value }
67    }
68
69    /// Get the value. Safe because it's the same in all lanes.
70    pub fn get(self) -> T {
71        self.value
72    }
73
74    /// Broadcast: convert uniform to per-lane (identity, but changes type).
75    pub fn broadcast(self) -> PerLane<T> {
76        PerLane { value: self.value }
77    }
78}
79
80/// A value that MAY DIFFER across lanes in a warp.
81///
82/// This is the default for most GPU computations. Each lane has its own
83/// value, and you can only access other lanes' values through explicit
84/// shuffle operations.
85#[must_use = "PerLane values carry per-lane GPU data — dropping discards computation"]
86#[derive(Clone, Copy, Debug, PartialEq)]
87#[repr(transparent)]
88pub struct PerLane<T: GpuValue> {
89    value: T,
90}
91
92impl<T: GpuValue> PerLane<T> {
93    pub fn new(value: T) -> Self {
94        PerLane { value }
95    }
96
97    pub fn get(self) -> T {
98        self.value
99    }
100
101    /// Assert this value is actually uniform.
102    ///
103    /// # Safety
104    /// Caller must ensure all lanes hold the same value.
105    pub unsafe fn assume_uniform(self) -> Uniform<T> {
106        Uniform { value: self.value }
107    }
108}
109
110impl<T: GpuValue + core::ops::Add<Output = T>> core::ops::Add for PerLane<T> {
111    type Output = PerLane<T>;
112    fn add(self, rhs: PerLane<T>) -> PerLane<T> {
113        PerLane {
114            value: self.value + rhs.value,
115        }
116    }
117}
118
119/// A value that exists ONLY in a specific lane.
120///
121/// Models the result of a reduction — only one lane has the answer.
122/// Prevents the common bug of reading a reduction result from all lanes
123/// (undefined behavior in CUDA).
124#[must_use = "SingleLane values exist in one lane only — dropping discards the reduction result"]
125#[derive(Clone, Copy, Debug, PartialEq)]
126#[repr(transparent)]
127pub struct SingleLane<T: GpuValue, const LANE: u8> {
128    value: T,
129}
130
131impl<T: GpuValue, const LANE: u8> SingleLane<T, LANE> {
132    pub fn new(value: T) -> Self {
133        SingleLane { value }
134    }
135
136    /// Read the value. Only valid in the owning lane.
137    pub fn get(self) -> T {
138        self.value
139    }
140
141    /// Broadcast to all lanes — the ONLY safe way to share with other lanes.
142    pub fn broadcast(self) -> Uniform<T> {
143        Uniform::from_const(self.value)
144    }
145}
146
147/// A role within a warp (e.g., coordinator vs worker lanes).
148///
149/// Roles enable modeling warp-level protocols where different lanes
150/// have different responsibilities. Uses `u64` mask to match
151/// `ActiveSet::MASK` width (supporting AMD 64-lane wavefronts).
152#[derive(Clone, Copy, Debug, PartialEq, Eq)]
153pub struct Role {
154    mask: u64,
155    name: &'static str,
156}
157
158impl Role {
159    /// The bitmask of lanes in this role.
160    pub const fn mask(self) -> u64 {
161        self.mask
162    }
163
164    /// Human-readable role name.
165    pub const fn name(self) -> &'static str {
166        self.name
167    }
168}
169
170impl Role {
171    pub const fn lanes(start: u8, end: u8, name: &'static str) -> Self {
172        assert!(start < 64 && end <= 64 && start < end);
173        let width = (end - start) as u64;
174        let mask = if width >= 64 {
175            u64::MAX
176        } else {
177            ((1u64 << width) - 1) << start
178        };
179        Role { mask, name }
180    }
181
182    pub const fn lane(id: u8, name: &'static str) -> Self {
183        assert!(id < 64);
184        Role {
185            mask: 1u64 << id,
186            name,
187        }
188    }
189
190    pub const fn contains(self, lane: LaneId) -> bool {
191        (self.mask & (1u64 << lane.0)) != 0
192    }
193
194    pub const fn count(self) -> u32 {
195        self.mask.count_ones()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lane_id() {
        let lane = LaneId::new(15);
        assert_eq!((lane.get(), lane.index()), (15, 15));
    }

    #[test]
    fn test_lane_id_boundary_31() {
        // Last valid NVIDIA lane.
        let lane = LaneId::new(31);
        assert_eq!((lane.get(), lane.index()), (31, 31));
    }

    #[test]
    fn test_lane_id_boundary_63() {
        // Last valid AMD lane.
        assert_eq!(LaneId::new(63).get(), 63);
    }

    #[test]
    #[should_panic]
    fn test_lane_id_out_of_range() {
        LaneId::new(64);
    }

    #[test]
    fn test_uniform_broadcast() {
        let uniform: Uniform<i32> = Uniform::from_const(42);
        let per_lane: PerLane<i32> = uniform.broadcast();
        assert_eq!(per_lane.get(), 42);
    }

    #[test]
    fn test_single_lane_broadcast() {
        let reduced: SingleLane<i32, 0> = SingleLane::new(42);
        assert_eq!(reduced.broadcast().get(), 42);
    }

    #[test]
    fn test_role_coverage() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");

        // Roles partition lane 3/4 boundary exactly.
        assert!(coordinator.contains(LaneId::new(0)));
        assert!(coordinator.contains(LaneId::new(3)));
        assert!(!coordinator.contains(LaneId::new(4)));
        assert!(!worker.contains(LaneId::new(3)));
        assert!(worker.contains(LaneId::new(4)));
        assert!(worker.contains(LaneId::new(31)));

        assert_eq!(coordinator.count(), 4);
        assert_eq!(worker.count(), 28);
    }
}
259}