Skip to main content

warp_types/
active_set.rs

1//! Active set types: compile-time lane subset tracking.
2//!
3//! Active sets are zero-sized marker types that represent subsets of warp lanes.
4//! The type system tracks which lanes are active through diverge/merge operations,
5//! preventing shuffle-from-inactive-lane bugs at compile time.
6//!
7//! # Lattice structure
8//!
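//! Active sets are elements of the Boolean lattice of lane masks, ordered by
//! inclusion (⊆). The crate defines the following sets; each row lists a
//! diverge pattern (parent → true branch / false branch):
//!
//! ```text
//! All (32)
//! ├── Even (16)     / Odd (16)
//! │     Even     → EvenLow (8)  / EvenHigh (8)
//! │     Odd      → OddLow (8)   / OddHigh (8)
//! ├── LowHalf (16)  / HighHalf (16)
//! │     LowHalf  → EvenLow (8)  / OddLow (8)
//! │     HighHalf → EvenHigh (8) / OddHigh (8)
//! └── Lane0 (1)     / NotLane0 (31)
//! ```
//!
//! Lane counts are for 32-lane warps. With the `warp64` feature every count
//! doubles, except `Lane0` (still 1) and `NotLane0` (63).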
19//!
20//! Note: `EvenLow` appears under both `Even` and `LowHalf` — same set,
21//! reached by different diverge paths. Path independence is a key property.
22
#[doc(hidden)]
pub mod sealed {
    //! Sealing machinery for the crate's safety-critical traits.
    //!
    //! `Sealed` is public, but its only method has no default body and returns
    //! `SealToken`, a type visible to this crate alone. Code in other crates can
    //! name the trait but can never spell the method's return type, so it cannot
    //! write a conforming `impl`: the seal is enforced by the compiler.

    /// Crate-private, zero-sized witness type returned by `Sealed::_sealed`.
    #[doc(hidden)]
    pub(crate) struct SealToken;

    /// Supertrait of the active-set traits; implementable only inside this crate.
    // `private_interfaces` is allowed deliberately: exposing the crate-private
    // `SealToken` in this public signature is exactly what seals the trait.
    #[allow(private_interfaces)]
    pub trait Sealed {
        /// Required, body-less method whose signature makes the trait sealed.
        #[doc(hidden)]
        fn _sealed() -> SealToken;
    }
}
39
40/// Marker trait for active lane set types.
41///
42/// Each implementor is a zero-sized type encoding a specific bitmask of lanes.
43/// The `MASK` constant enables runtime debugging; the type itself provides
44/// compile-time tracking.
45pub trait ActiveSet: sealed::Sealed + Copy + 'static {
46    /// Bitmask of active lanes (for runtime debugging/verification).
47    const MASK: u64;
48    /// Human-readable name.
49    const NAME: &'static str;
50}
51
52/// Proof that `Self` and `Other` are complements: disjoint AND covering all lanes.
53///
54/// This is THE key safety trait. `merge(a, b)` requires `A: ComplementOf<B>`.
55/// Only implemented for valid complement pairs — the compiler rejects invalid merges.
56#[diagnostic::on_unimplemented(
57    message = "`{Self}` is not the complement of `{Other}` — cannot merge these sub-warps",
58    label = "merge requires complementary active sets (e.g., Even + Odd, LowHalf + HighHalf)",
59    note = "use `diverge_even_odd()` or `diverge_halves()` to create valid complement pairs, then merge them"
60)]
61pub trait ComplementOf<Other: ActiveSet>: sealed::Sealed + ActiveSet {}
62
63/// Proof that `Self` and `Other` are complements within a parent set `P`.
64///
65/// `S1 ∪ S2 = P` and `S1 ∩ S2 = ∅`. Used for nested divergence where
66/// merge returns to a parent set rather than `All`.
67pub trait ComplementWithin<Other: ActiveSet, Parent: ActiveSet>:
68    sealed::Sealed + ActiveSet
69{
70}
71
72/// Proof that an active set can be split into two disjoint subsets.
73///
74/// Implemented for each valid diverge pattern (e.g., `All` → `Even` + `Odd`).
75#[diagnostic::on_unimplemented(
76    message = "`{Self}` cannot be split into `{TrueBranch}` + `{FalseBranch}`",
77    label = "this diverge pattern is not defined in the active set hierarchy",
78    note = "valid diverge patterns: All → Even/Odd, All → LowHalf/HighHalf, Even → EvenLow/EvenHigh, etc."
79)]
80pub trait CanDiverge<TrueBranch: ActiveSet, FalseBranch: ActiveSet>:
81    sealed::Sealed + ActiveSet + Sized
82{
83    fn diverge(
84        warp: crate::warp::Warp<Self>,
85    ) -> (
86        crate::warp::Warp<TrueBranch>,
87        crate::warp::Warp<FalseBranch>,
88    );
89}
90
91/// No lanes active (degenerate). Not part of the diverge hierarchy.
92#[derive(Copy, Clone, Debug, Default)]
93pub struct Empty;
94#[allow(private_interfaces)]
95impl sealed::Sealed for Empty {
96    fn _sealed() -> sealed::SealToken {
97        sealed::SealToken
98    }
99}
100impl ActiveSet for Empty {
101    const MASK: u64 = 0;
102    const NAME: &'static str = "Empty";
103}
104
105// ============================================================================
106// Generated active set hierarchy
107//
108// The warp_sets! macro validates at compile time:
109//   - Each pair is disjoint (true_mask & false_mask == 0)
110//   - Each pair covers its parent (true_mask | false_mask == parent_mask)
111//   - Children are subsets of parent (child_mask & !parent_mask == 0)
112//
113// Shared types (e.g., EvenLow under both Even and LowHalf) are deduplicated.
114// ============================================================================
115
// 32-lane NVIDIA warps (default).
//
// Each group reads `Parent = mask { True = mask / False = mask, ... }`: it gives
// the parent's lane mask and the pairs it can diverge into. Bit `i` of a mask
// stands for lane `i` (so `Lane0 = 0x00000001`). The macro checks at compile
// time that every pair is disjoint, covers its parent, and stays inside it.
// A set reached along two paths (e.g. `EvenLow` under both `Even` and
// `LowHalf`) is emitted only once.
#[cfg(not(feature = "warp64"))]
warp_types_macros::warp_sets! {
    All = 0xFFFFFFFF {
        // Interleaved split: even lanes vs. odd lanes.
        Even = 0x55555555 / Odd = 0xAAAAAAAA,
        // Contiguous split: lanes 0-15 vs. lanes 16-31.
        LowHalf = 0x0000FFFF / HighHalf = 0xFFFF0000,
        // Single-lane split: lane 0 vs. every other lane.
        Lane0 = 0x00000001 / NotLane0 = 0xFFFFFFFE,
    }
    // Second level: each interleaved set splits by lane range...
    Even = 0x55555555 {
        EvenLow = 0x00005555 / EvenHigh = 0x55550000,
    }
    Odd = 0xAAAAAAAA {
        OddLow = 0x0000AAAA / OddHigh = 0xAAAA0000,
    }
    // ...and each half splits by parity, reaching the same four leaf sets.
    LowHalf = 0x0000FFFF {
        EvenLow = 0x00005555 / OddLow = 0x0000AAAA,
    }
    HighHalf = 0xFFFF0000 {
        EvenHigh = 0x55550000 / OddHigh = 0xAAAA0000,
    }
}
137
// 64-lane AMD wavefronts (warp64 feature).
//
// Same hierarchy as the 32-lane default, with every mask widened to 64 bits.
// Bit `i` stands for lane `i` (so `Lane0 = 0x0000000000000001`).
#[cfg(feature = "warp64")]
warp_types_macros::warp_sets! {
    All = 0xFFFFFFFFFFFFFFFF {
        // Interleaved split: even lanes vs. odd lanes.
        Even = 0x5555555555555555 / Odd = 0xAAAAAAAAAAAAAAAA,
        // Contiguous split: lanes 0-31 vs. lanes 32-63.
        LowHalf = 0x00000000FFFFFFFF / HighHalf = 0xFFFFFFFF00000000,
        // Single-lane split: lane 0 vs. every other lane.
        Lane0 = 0x0000000000000001 / NotLane0 = 0xFFFFFFFFFFFFFFFE,
    }
    // Second level: each interleaved set splits by lane range...
    Even = 0x5555555555555555 {
        EvenLow = 0x0000000055555555 / EvenHigh = 0x5555555500000000,
    }
    Odd = 0xAAAAAAAAAAAAAAAA {
        OddLow = 0x00000000AAAAAAAA / OddHigh = 0xAAAAAAAA00000000,
    }
    // ...and each half splits by parity, reaching the same four leaf sets.
    LowHalf = 0x00000000FFFFFFFF {
        EvenLow = 0x0000000055555555 / OddLow = 0x00000000AAAAAAAA,
    }
    HighHalf = 0xFFFFFFFF00000000 {
        EvenHigh = 0x5555555500000000 / OddHigh = 0xAAAAAAAA00000000,
    }
}
159
160// Empty/All complement pair — Empty isn't produced by any diverge,
161// so it's not part of the generated hierarchy
162impl ComplementOf<Empty> for All {}
163impl ComplementOf<All> for Empty {}
164
165// NOTE: EvenLow/EvenHigh are complements within Even, NOT within All.
166// ComplementOf requires covering ALL lanes, so these do NOT get ComplementOf impls.
167// Use merge_within<EvenLow, EvenHigh, Even>() for nested merges.
168// See also: ComplementWithin impls generated by warp_sets! macro.
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Every concrete active set defined in this module, labelled for messages.
    const SET_MASKS: [(&str, u64); 11] = [
        ("Empty", Empty::MASK),
        ("Even", Even::MASK),
        ("Odd", Odd::MASK),
        ("LowHalf", LowHalf::MASK),
        ("HighHalf", HighHalf::MASK),
        ("Lane0", Lane0::MASK),
        ("NotLane0", NotLane0::MASK),
        ("EvenLow", EvenLow::MASK),
        ("EvenHigh", EvenHigh::MASK),
        ("OddLow", OddLow::MASK),
        ("OddHigh", OddHigh::MASK),
    ];

    /// Compiles only when `A: ComplementOf<B>`; instantiating it is the assertion.
    fn assert_complement<A: ComplementOf<B>, B: ActiveSet>() {}

    #[test]
    #[cfg(not(feature = "warp64"))]
    fn test_mask_values() {
        assert_eq!(All::MASK, 0xFFFFFFFF);
        assert_eq!(Empty::MASK, 0x00000000);
        assert_eq!(Even::MASK, 0x55555555);
        assert_eq!(Odd::MASK, 0xAAAAAAAA);
        assert_eq!(LowHalf::MASK, 0x0000FFFF);
        assert_eq!(HighHalf::MASK, 0xFFFF0000);
        assert_eq!(Lane0::MASK, 0x00000001);
        assert_eq!(NotLane0::MASK, 0xFFFFFFFE);
        assert_eq!(EvenLow::MASK, 0x00005555);
        assert_eq!(EvenHigh::MASK, 0x55550000);
        assert_eq!(OddLow::MASK, 0x0000AAAA);
        assert_eq!(OddHigh::MASK, 0xAAAA0000);
    }

    #[test]
    #[cfg(feature = "warp64")]
    fn test_mask_values_64() {
        assert_eq!(All::MASK, 0xFFFFFFFFFFFFFFFF);
        assert_eq!(Empty::MASK, 0x00000000);
        assert_eq!(Even::MASK, 0x5555555555555555);
        assert_eq!(Odd::MASK, 0xAAAAAAAAAAAAAAAA);
        assert_eq!(LowHalf::MASK, 0x00000000FFFFFFFF);
        assert_eq!(HighHalf::MASK, 0xFFFFFFFF00000000);
        assert_eq!(Lane0::MASK, 0x0000000000000001);
        assert_eq!(NotLane0::MASK, 0xFFFFFFFFFFFFFFFE);
        assert_eq!(EvenLow::MASK, 0x0000000055555555);
        assert_eq!(EvenHigh::MASK, 0x5555555500000000);
        assert_eq!(OddLow::MASK, 0x00000000AAAAAAAA);
        assert_eq!(OddHigh::MASK, 0xAAAAAAAA00000000);
    }

    #[test]
    fn test_intersection_properties() {
        assert_eq!(Even::MASK & LowHalf::MASK, EvenLow::MASK);
        assert_eq!(Even::MASK & HighHalf::MASK, EvenHigh::MASK);
        assert_eq!(Odd::MASK & LowHalf::MASK, OddLow::MASK);
        assert_eq!(Odd::MASK & HighHalf::MASK, OddHigh::MASK);
    }

    #[test]
    fn test_union_properties() {
        assert_eq!(EvenLow::MASK | EvenHigh::MASK, Even::MASK);
        assert_eq!(OddLow::MASK | OddHigh::MASK, Odd::MASK);
        assert_eq!(EvenLow::MASK | OddLow::MASK, LowHalf::MASK);
        assert_eq!(EvenHigh::MASK | OddHigh::MASK, HighHalf::MASK);
        assert_eq!(
            EvenLow::MASK | EvenHigh::MASK | OddLow::MASK | OddHigh::MASK,
            All::MASK
        );
    }

    #[test]
    fn test_pairwise_disjoint() {
        let sets = [EvenLow::MASK, EvenHigh::MASK, OddLow::MASK, OddHigh::MASK];
        // Iterator form avoids clippy::needless_range_loop on index loops.
        for (i, a) in sets.iter().enumerate() {
            for (j, b) in sets.iter().enumerate().skip(i + 1) {
                assert_eq!(a & b, 0, "sets {} and {} overlap", i, j);
            }
        }
    }

    #[test]
    fn test_complement_symmetry() {
        assert_eq!(Even::MASK | Odd::MASK, All::MASK);
        assert_eq!(Even::MASK & Odd::MASK, 0);
        assert_eq!(LowHalf::MASK | HighHalf::MASK, All::MASK);
        assert_eq!(LowHalf::MASK & HighHalf::MASK, 0);
        assert_eq!(Lane0::MASK | NotLane0::MASK, All::MASK);
        assert_eq!(Lane0::MASK & NotLane0::MASK, 0);
    }

    #[test]
    fn test_empty_all_complement() {
        // Trait level: both hand-written impls must exist.
        assert_complement::<All, Empty>();
        assert_complement::<Empty, All>();
        // Mask level: disjoint and covering.
        assert_eq!(Empty::MASK & All::MASK, 0);
        assert_eq!(Empty::MASK | All::MASK, All::MASK);
        assert_eq!(Empty::NAME, "Empty");
    }

    #[test]
    fn test_every_set_within_all() {
        for (name, mask) in SET_MASKS {
            assert_eq!(mask & !All::MASK, 0, "{} has lanes outside All", name);
        }
    }

    #[test]
    fn test_lane_counts() {
        // Width-independent: holds for both 32- and 64-lane builds.
        let lanes = All::MASK.count_ones();
        assert!(lanes == 32 || lanes == 64, "unexpected warp width {}", lanes);
        for half in [Even::MASK, Odd::MASK, LowHalf::MASK, HighHalf::MASK] {
            assert_eq!(half.count_ones(), lanes / 2);
        }
        for quarter in [EvenLow::MASK, EvenHigh::MASK, OddLow::MASK, OddHigh::MASK] {
            assert_eq!(quarter.count_ones(), lanes / 4);
        }
        assert_eq!(Lane0::MASK.count_ones(), 1);
        assert_eq!(NotLane0::MASK.count_ones(), lanes - 1);
        assert_eq!(Empty::MASK.count_ones(), 0);
    }

    #[test]
    fn test_lane0_placement() {
        // Lane 0 is even and in the low half, so it lies inside EvenLow.
        assert_eq!(Lane0::MASK & EvenLow::MASK, Lane0::MASK);
        assert_eq!(NotLane0::MASK, All::MASK & !Lane0::MASK);
    }
}