Skip to main content

warp_types/
data.rs

//! GPU data types: uniform vs per-lane value distinction.
//!
//! The fundamental insight: GPU values are either uniform across all lanes
//! or vary per-lane. Making this distinction in the type system prevents
//! a large class of bugs (reading reduction results from wrong lanes,
//! passing divergent data where uniform is expected, etc.).
7
8use crate::GpuValue;
9
/// Index of a single lane inside a warp / wavefront.
///
/// Valid values are `0..=63`: lanes `0..=31` cover NVIDIA warps and lanes
/// `0..=63` cover AMD wavefronts. Wrapping the raw `u8` in a newtype keeps
/// an arbitrary integer from being used where a lane id is expected.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct LaneId(u8);

impl LaneId {
    /// Wraps `id` as a lane identifier.
    ///
    /// # Panics
    /// Panics if `id` is 64 or greater.
    pub const fn new(id: u8) -> Self {
        // Guard clause: reject anything outside the 64-lane range up front.
        if id >= 64 {
            panic!("Lane ID must be < 64 (supports NVIDIA 32-lane and AMD 64-lane)");
        }
        Self(id)
    }

    /// Returns the raw lane number.
    pub const fn get(self) -> u8 {
        let LaneId(raw) = self;
        raw
    }

    /// Returns the lane number widened to `usize`, e.g. for indexing.
    pub const fn index(self) -> usize {
        self.get() as usize
    }
}
34
/// Index of a warp inside its thread block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct WarpId(u16);

impl WarpId {
    /// Wraps a raw warp index; every `u16` value is accepted.
    pub const fn new(id: u16) -> Self {
        Self(id)
    }

    /// Returns the raw warp index.
    pub const fn get(self) -> u16 {
        let WarpId(raw) = self;
        raw
    }
}
48
49/// A value guaranteed to be the same across all lanes in a warp.
50///
51/// You can only create `Uniform` values through operations that guarantee
52/// uniformity (broadcasts, constants, ballot results). This prevents the
53/// common bug of assuming a value is uniform when it isn't.
54#[derive(Clone, Copy, Debug, PartialEq)]
55pub struct Uniform<T: GpuValue> {
56    value: T,
57}
58
59impl<T: GpuValue> Uniform<T> {
60    /// Create a uniform value from a compile-time constant.
61    pub const fn from_const(value: T) -> Self {
62        Uniform { value }
63    }
64
65    /// Get the value. Safe because it's the same in all lanes.
66    pub fn get(self) -> T {
67        self.value
68    }
69
70    /// Broadcast: convert uniform to per-lane (identity, but changes type).
71    pub fn broadcast(self) -> PerLane<T> {
72        PerLane { value: self.value }
73    }
74}
75
76/// A value that MAY DIFFER across lanes in a warp.
77///
78/// This is the default for most GPU computations. Each lane has its own
79/// value, and you can only access other lanes' values through explicit
80/// shuffle operations.
81#[must_use = "PerLane values carry per-lane GPU data — dropping discards computation"]
82#[derive(Clone, Copy, Debug, PartialEq)]
83pub struct PerLane<T: GpuValue> {
84    value: T,
85}
86
87impl<T: GpuValue> PerLane<T> {
88    pub fn new(value: T) -> Self {
89        PerLane { value }
90    }
91
92    pub fn get(self) -> T {
93        self.value
94    }
95
96    /// Assert this value is actually uniform.
97    ///
98    /// # Safety
99    /// Caller must ensure all lanes hold the same value.
100    pub unsafe fn assume_uniform(self) -> Uniform<T> {
101        Uniform { value: self.value }
102    }
103}
104
105impl<T: GpuValue + core::ops::Add<Output = T>> core::ops::Add for PerLane<T> {
106    type Output = PerLane<T>;
107    fn add(self, rhs: PerLane<T>) -> PerLane<T> {
108        PerLane {
109            value: self.value + rhs.value,
110        }
111    }
112}
113
114/// A value that exists ONLY in a specific lane.
115///
116/// Models the result of a reduction — only one lane has the answer.
117/// Prevents the common bug of reading a reduction result from all lanes
118/// (undefined behavior in CUDA).
119#[derive(Clone, Copy, Debug, PartialEq)]
120pub struct SingleLane<T: GpuValue, const LANE: u8> {
121    value: T,
122}
123
124impl<T: GpuValue, const LANE: u8> SingleLane<T, LANE> {
125    pub fn new(value: T) -> Self {
126        SingleLane { value }
127    }
128
129    /// Read the value. Only valid in the owning lane.
130    pub fn get(self) -> T {
131        self.value
132    }
133
134    /// Broadcast to all lanes — the ONLY safe way to share with other lanes.
135    pub fn broadcast(self) -> Uniform<T> {
136        Uniform::from_const(self.value)
137    }
138}
139
140/// A role within a warp (e.g., coordinator vs worker lanes).
141///
142/// Roles enable modeling warp-level protocols where different lanes
143/// have different responsibilities. Uses `u64` mask to match
144/// `ActiveSet::MASK` width (supporting AMD 64-lane wavefronts).
145#[derive(Clone, Copy, Debug, PartialEq, Eq)]
146pub struct Role {
147    pub mask: u64,
148    pub name: &'static str,
149}
150
151impl Role {
152    pub const fn lanes(start: u8, end: u8, name: &'static str) -> Self {
153        assert!(start < 64 && end <= 64 && start < end);
154        let width = (end - start) as u64;
155        let mask = if width >= 64 {
156            u64::MAX
157        } else {
158            ((1u64 << width) - 1) << start
159        };
160        Role { mask, name }
161    }
162
163    pub const fn lane(id: u8, name: &'static str) -> Self {
164        assert!(id < 64);
165        Role {
166            mask: 1u64 << id,
167            name,
168        }
169    }
170
171    pub const fn contains(self, lane: LaneId) -> bool {
172        (self.mask & (1u64 << lane.0)) != 0
173    }
174
175    pub const fn count(self) -> u32 {
176        self.mask.count_ones()
177    }
178}
179
/// Unit tests for the lane/warp identifiers, the uniform / per-lane /
/// single-lane wrappers, and `Role` mask construction.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lane_id() {
        let lane = LaneId::new(15);
        assert_eq!(lane.get(), 15);
        assert_eq!(lane.index(), 15);
    }

    #[test]
    fn test_lane_id_boundary_31() {
        let lane = LaneId::new(31);
        assert_eq!(lane.get(), 31);
        assert_eq!(lane.index(), 31);
    }

    #[test]
    fn test_lane_id_boundary_63() {
        let lane = LaneId::new(63);
        assert_eq!(lane.get(), 63);
    }

    #[test]
    #[should_panic]
    fn test_lane_id_out_of_range() {
        LaneId::new(64);
    }

    #[test]
    fn test_warp_id_roundtrip() {
        assert_eq!(WarpId::new(0).get(), 0);
        assert_eq!(WarpId::new(u16::MAX).get(), u16::MAX);
    }

    #[test]
    fn test_uniform_broadcast() {
        let u: Uniform<i32> = Uniform::from_const(42);
        let p: PerLane<i32> = u.broadcast();
        assert_eq!(p.get(), 42);
    }

    #[test]
    fn test_per_lane_add() {
        let sum = PerLane::new(2_i32) + PerLane::new(40_i32);
        assert_eq!(sum.get(), 42);
    }

    #[test]
    fn test_per_lane_assume_uniform() {
        let p = PerLane::new(7_i32);
        // SAFETY: this host-side test has a single value, so there is no
        // other lane that could disagree.
        let u = unsafe { p.assume_uniform() };
        assert_eq!(u.get(), 7);
    }

    #[test]
    fn test_single_lane_broadcast() {
        let reduced: SingleLane<i32, 0> = SingleLane::new(42);
        let uniform: Uniform<i32> = reduced.broadcast();
        assert_eq!(uniform.get(), 42);
    }

    #[test]
    fn test_single_lane_last_lane() {
        let reduced: SingleLane<i32, 63> = SingleLane::new(-1);
        assert_eq!(reduced.get(), -1);
    }

    #[test]
    fn test_role_coverage() {
        let coordinator = Role::lanes(0, 4, "coordinator");
        let worker = Role::lanes(4, 32, "worker");

        assert!(coordinator.contains(LaneId::new(0)));
        assert!(coordinator.contains(LaneId::new(3)));
        assert!(!coordinator.contains(LaneId::new(4)));

        assert!(!worker.contains(LaneId::new(3)));
        assert!(worker.contains(LaneId::new(4)));
        assert!(worker.contains(LaneId::new(31)));

        assert_eq!(coordinator.count(), 4);
        assert_eq!(worker.count(), 28);
    }

    // The 64-wide range takes the path where a plain `1 << width` would
    // overflow, so it is pinned explicitly.
    #[test]
    fn test_role_full_wavefront() {
        let all = Role::lanes(0, 64, "all");
        assert_eq!(all.mask, u64::MAX);
        assert_eq!(all.count(), 64);
        assert!(all.contains(LaneId::new(0)));
        assert!(all.contains(LaneId::new(63)));
    }

    #[test]
    fn test_role_upper_half() {
        let upper = Role::lanes(32, 64, "upper");
        assert_eq!(upper.mask, 0xFFFF_FFFF_0000_0000);
        assert!(!upper.contains(LaneId::new(31)));
        assert!(upper.contains(LaneId::new(32)));
    }

    #[test]
    fn test_role_single_lane() {
        let last = Role::lane(63, "last");
        assert_eq!(last.mask, 1u64 << 63);
        assert_eq!(last.count(), 1);
        assert!(last.contains(LaneId::new(63)));
        assert!(!last.contains(LaneId::new(62)));
    }

    #[test]
    #[should_panic]
    fn test_role_empty_range() {
        Role::lanes(4, 4, "empty");
    }

    #[test]
    #[should_panic]
    fn test_role_range_past_64() {
        Role::lanes(0, 65, "too wide");
    }

    #[test]
    #[should_panic]
    fn test_role_lane_out_of_range() {
        Role::lane(64, "out of range");
    }
}