Skip to main content

warp_types/
fence.rs

//! Fence-divergence interaction types (§5.6).
//!
//! Global memory writes from diverged warps must be carefully tracked.
//! A fence is only valid after ALL lanes have written — which requires
//! the same complement proof used for merge.
//!
//! # Type-state protocol
//!
//! ```text
//! GlobalRegion::with_region(|region| {
//!   // region: GlobalRegion<'r, Unwritten>
//!   //   → Warp<S1>.global_store() → GlobalRegion<'r, PartialWrite<S1>>
//!   //   → Warp<S2>.global_store_complement(PartialWrite<S1>) → GlobalRegion<'r, FullWrite>
//!   //     or, with two halves obtained from region.split():
//!   //     merge_writes(PartialWrite<S1>, PartialWrite<S2>) → GlobalRegion<'r, FullWrite>
//!   //     (either path requires complementary sets via ComplementOf, same 'r)
//!   //   → threadfence(FullWrite) → GlobalRegion<'r, Fenced>
//! })
//! ```
//!
//! The lifetime `'r` ties all partial writes to the region that created them.
//! Two partial writes from *different* `GlobalRegion::with_region` calls have
//! different lifetimes and **cannot** be merged — the compiler rejects it.
//!
//! This turns both memory-ordering bugs and cross-region confusion into type errors.
24
25use crate::active_set::{ActiveSet, ComplementOf, ComplementWithin};
26use crate::warp::Warp;
27use core::marker::PhantomData;
28
29// ============================================================================
30// Write state markers
31// ============================================================================
32
33/// Marker trait for global region write states.
34///
35/// Sealed — external crates cannot implement this trait, preventing
36/// forgery of write-state transitions.
37pub trait WriteState: crate::active_set::sealed::Sealed {}
38
39/// No writes have occurred.
40#[derive(Debug, Clone, Copy)]
41pub struct Unwritten;
42#[allow(private_interfaces)]
43impl crate::active_set::sealed::Sealed for Unwritten {
44    fn _sealed() -> crate::active_set::sealed::SealToken {
45        crate::active_set::sealed::SealToken
46    }
47}
48impl WriteState for Unwritten {}
49
50/// Partial write: only lanes in `S` have written.
51#[derive(Debug, Clone, Copy)]
52pub struct PartialWrite<S: ActiveSet> {
53    _phantom: PhantomData<S>,
54}
55#[allow(private_interfaces)]
56impl<S: ActiveSet> crate::active_set::sealed::Sealed for PartialWrite<S> {
57    fn _sealed() -> crate::active_set::sealed::SealToken {
58        crate::active_set::sealed::SealToken
59    }
60}
61impl<S: ActiveSet> WriteState for PartialWrite<S> {}
62
63/// All lanes have written (complement-verified).
64#[derive(Debug, Clone, Copy)]
65pub struct FullWrite;
66#[allow(private_interfaces)]
67impl crate::active_set::sealed::Sealed for FullWrite {
68    fn _sealed() -> crate::active_set::sealed::SealToken {
69        crate::active_set::sealed::SealToken
70    }
71}
72impl WriteState for FullWrite {}
73
74/// Fence has been issued after full write.
75#[derive(Debug, Clone, Copy)]
76pub struct Fenced;
77#[allow(private_interfaces)]
78impl crate::active_set::sealed::Sealed for Fenced {
79    fn _sealed() -> crate::active_set::sealed::SealToken {
80        crate::active_set::sealed::SealToken
81    }
82}
83impl WriteState for Fenced {}
84
85// ============================================================================
86// Global region with write tracking + region identity
87// ============================================================================
88
89/// A global memory region with type-state tracked write progress.
90///
91/// The lifetime `'r` is an *identity brand*: every call to
92/// [`GlobalRegion::with_region`] introduces a fresh, unnameable lifetime,
93/// so partial writes from different regions cannot be mixed.
94///
95/// The type parameter `S` tracks which write state the region is in.
96/// Operations are only available in the correct state:
97/// - `global_store()` requires a warp (any active set)
98/// - `merge_writes()` requires complementary partial writes **from the same region**
99/// - `threadfence()` requires full write
100/// - Reading requires fenced state
101///
102/// # Region identity
103///
104/// ```compile_fail
105/// use warp_types::prelude::*;
106/// use warp_types::fence::*;
107/// // Two different regions — lifetimes differ, merge is rejected.
108/// GlobalRegion::with_region(|region1| {
109///     GlobalRegion::with_region(|region2| {
110///         let warp1 = Warp::kernel_entry();
111///         let (evens, _odds) = warp1.diverge_even_odd();
112///         let warp2 = Warp::kernel_entry();
113///         let (_odds2, odds2b) = warp2.diverge_even_odd();
114///         let (_evens, partial_even) = evens.global_store(region1);
115///         let (_odds2b, partial_odd) = odds2b.global_store(region2);
116///         // Cross-region merge — compile error: lifetime mismatch
117///         let _full = merge_writes(partial_even, partial_odd);
118///     });
119/// });
120/// ```
121#[must_use = "GlobalRegion tracks write progress — dropping it loses the write-state proof"]
122pub struct GlobalRegion<'r, S: WriteState> {
123    // fn(&'r ()) -> &'r () makes 'r invariant (cannot be widened or narrowed).
124    // This is critical: covariant 'r would let the compiler unify distinct
125    // region lifetimes by widening one to match the other.
126    _brand: PhantomData<fn(&'r ()) -> &'r ()>,
127    _state: PhantomData<S>,
128}
129
130impl GlobalRegion<'_, Unwritten> {
131    /// Enter a region scope. The callback receives a fresh `GlobalRegion`
132    /// whose lifetime `'r` is unique and unnameable — partial writes
133    /// derived from it cannot be merged with writes from any other region.
134    ///
135    /// # Examples
136    ///
137    /// ```
138    /// use warp_types::prelude::*;
139    /// use warp_types::fence::*;
140    /// GlobalRegion::with_region(|region| {
141    ///     let warp = Warp::kernel_entry();
142    ///     let (evens, odds) = warp.diverge_even_odd();
143    ///     let (evens, partial_even) = evens.global_store(region);
144    ///     let (odds, full) = odds.global_store_complement(partial_even);
145    ///     let _fenced = threadfence(full);
146    /// });
147    /// ```
148    pub fn with_region<R>(f: impl for<'r> FnOnce(GlobalRegion<'r, Unwritten>) -> R) -> R {
149        f(GlobalRegion {
150            _brand: PhantomData,
151            _state: PhantomData,
152        })
153    }
154}
155
156impl<'r> GlobalRegion<'r, Unwritten> {
157    /// Split an unwritten region into two halves sharing the same
158    /// lifetime brand. Each half can be stored to independently,
159    /// then the resulting partial writes can be merged (they share `'r`).
160    ///
161    /// This is the safe way to create two partial writes from one region
162    /// when using `merge_writes` or `merge_writes_within` instead of
163    /// the sequential `global_store` / `global_store_complement` path.
164    ///
165    /// **Note:** This is a type-level model — `GlobalRegion` is phantom
166    /// (zero-sized, no real memory). The lifetime brand `'r` ties the
167    /// halves together but does not track which memory addresses each
168    /// half covers. A real GPU memory model would need address-level
169    /// tracking to prevent combining writes to disjoint address ranges.
170    pub fn split(self) -> (GlobalRegion<'r, Unwritten>, GlobalRegion<'r, Unwritten>) {
171        (
172            GlobalRegion {
173                _brand: PhantomData,
174                _state: PhantomData,
175            },
176            GlobalRegion {
177                _brand: PhantomData,
178                _state: PhantomData,
179            },
180        )
181    }
182}
183
184/// Warp writes to a global region, producing a partial write.
185impl<S: ActiveSet> Warp<S> {
186    /// Store values to global memory.
187    ///
188    /// Returns the warp (unchanged) and a partially-written region
189    /// that tracks which lanes have written. The lifetime `'r` is
190    /// preserved, tying the partial write to its origin region.
191    ///
192    /// **Note:** Even `Warp<All>` produces `PartialWrite<All>`, not `FullWrite`.
193    /// Use `global_store_complement` with the complement's partial write to
194    /// advance to `FullWrite`. The sequential path is:
195    /// `warp.global_store(region)` → `warp.global_store_complement(partial)` → `FullWrite`.
196    pub fn global_store<'r>(
197        self,
198        _region: GlobalRegion<'r, Unwritten>,
199    ) -> (Self, GlobalRegion<'r, PartialWrite<S>>) {
200        (
201            self,
202            GlobalRegion {
203                _brand: PhantomData,
204                _state: PhantomData,
205            },
206        )
207    }
208
209    /// Store values to a region that already has a partial write from
210    /// complementary lanes, producing a full write.
211    ///
212    /// Returns the warp (unchanged) so it can still be merged.
213    /// The lifetime `'r` must match — both writes must target the same region.
214    pub fn global_store_complement<'r, S2: ActiveSet>(
215        self,
216        _region: GlobalRegion<'r, PartialWrite<S2>>,
217    ) -> (Self, GlobalRegion<'r, FullWrite>)
218    where
219        S: ComplementOf<S2>,
220    {
221        (
222            self,
223            GlobalRegion {
224                _brand: PhantomData,
225                _state: PhantomData,
226            },
227        )
228    }
229}
230
231/// Merge writes from complementary partial writes (top-level: covers All).
232///
233/// Requires the same `ComplementOf` proof as warp merge, AND the same
234/// region lifetime `'r` — preventing cross-region merging.
235///
236/// Writing with wrong complement type fails:
237///
238/// ```compile_fail
239/// use warp_types::prelude::*;
240/// use warp_types::fence::*;
241/// GlobalRegion::with_region(|region1| {
242///     GlobalRegion::with_region(|region2| {
243///         let warp1 = Warp::kernel_entry();
244///         let (evens, _odds) = warp1.diverge_even_odd();
245///         let warp2 = Warp::kernel_entry();
246///         let (low, _high) = warp2.diverge_halves();
247///         let (_evens, partial_even) = evens.global_store(region1);
248///         let (_low, partial_low) = low.global_store(region2);
249///         // Even and LowHalf are not complements (they overlap) — compile error
250///         let _full = merge_writes(partial_even, partial_low);
251///     });
252/// });
253/// ```
254pub fn merge_writes<'r, S1, S2>(
255    _a: GlobalRegion<'r, PartialWrite<S1>>,
256    _b: GlobalRegion<'r, PartialWrite<S2>>,
257) -> GlobalRegion<'r, FullWrite>
258where
259    S1: ComplementOf<S2>,
260    S2: ActiveSet,
261{
262    GlobalRegion {
263        _brand: PhantomData,
264        _state: PhantomData,
265    }
266}
267
268/// Merge writes from partial writes that are complements within a parent set.
269///
270/// For nested divergence: e.g., EvenLow + EvenHigh within Even.
271/// Returns a partial write for the parent set, not a full write.
272/// The lifetime `'r` must match — both writes must target the same region.
273pub fn merge_writes_within<'r, S1, S2, P>(
274    _a: GlobalRegion<'r, PartialWrite<S1>>,
275    _b: GlobalRegion<'r, PartialWrite<S2>>,
276) -> GlobalRegion<'r, PartialWrite<P>>
277where
278    S1: ComplementWithin<S2, P>,
279    S2: ActiveSet,
280    P: ActiveSet,
281{
282    GlobalRegion {
283        _brand: PhantomData,
284        _state: PhantomData,
285    }
286}
287
288/// Issue a thread fence after all writes are complete.
289///
290/// Only callable on `GlobalRegion<FullWrite>` — the type system ensures
291/// all lanes have written before the fence.
292pub fn threadfence(_proof: GlobalRegion<FullWrite>) -> GlobalRegion<Fenced> {
293    // In real implementation: __threadfence()
294    GlobalRegion {
295        _brand: PhantomData,
296        _state: PhantomData,
297    }
298}
299
300impl GlobalRegion<'_, Fenced> {
301    /// Read from a fenced global region. Safe because:
302    /// 1. All lanes have written (FullWrite)
303    /// 2. Fence ensures visibility (Fenced)
304    pub fn read<T: Default>(&self) -> T {
305        T::default() // placeholder
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::active_set::*;

    #[test]
    fn test_full_fence_protocol() {
        GlobalRegion::with_region(|region| {
            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();

            // Even lanes store first; the warp is handed back for later use.
            let (even_warp, partial) = even_warp.global_store(region);

            // Odd lanes complete the region, advancing it to FullWrite.
            let (odd_warp, complete) = odd_warp.global_store_complement(partial);

            // Storing did not consume the warps, so they still reconverge.
            let _reconverged: Warp<All> = crate::merge(even_warp, odd_warp);

            // Fence is legal now that every lane has written.
            let fenced = threadfence(complete);
            let _value: i32 = fenced.read();
        });
    }

    #[test]
    fn test_merge_writes_same_region() {
        // Both halves of one split region share 'r, so merge_writes accepts them.
        GlobalRegion::with_region(|region| {
            let (left, right) = region.split();

            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();

            let (_even_warp, even_part) = even_warp.global_store(left);
            let (_odd_warp, odd_part) = odd_warp.global_store(right);

            let _fenced = threadfence(merge_writes(even_part, odd_part));
        });
    }

    #[test]
    fn test_nested_fence_protocol() {
        // Two-level divergence: EvenLow and EvenHigh merge within Even into
        // PartialWrite<Even>, which then merges with PartialWrite<Odd> to
        // produce FullWrite.
        GlobalRegion::with_region(|region| {
            let (odd_region, even_region) = region.split();
            let (even_low_region, even_high_region) = even_region.split();

            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();
            let (even_low, even_high) = even_warp.diverge_halves();

            let (_odd_warp, odd_part) = odd_warp.global_store(odd_region);
            let (_even_low, even_low_part) = even_low.global_store(even_low_region);
            let (_even_high, even_high_part) = even_high.global_store(even_high_region);

            // Inner merge: EvenLow + EvenHigh → Even (still partial).
            let even_part: GlobalRegion<'_, PartialWrite<Even>> =
                merge_writes_within(even_low_part, even_high_part);

            // Outer merge: Even + Odd → FullWrite.
            let complete = merge_writes(even_part, odd_part);
            let _fenced = threadfence(complete);
        });
    }

    #[test]
    fn test_global_store_complement_same_region() {
        // Sequential path: one store, then the complementary store, on one region.
        GlobalRegion::with_region(|region| {
            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();

            let (_even_warp, partial) = even_warp.global_store(region);
            let (_odd_warp, complete) = odd_warp.global_store_complement(partial);
            let _fenced = threadfence(complete);
        });
    }

    #[test]
    fn test_with_region_returns_value() {
        let value = GlobalRegion::with_region(|region| {
            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();
            let (_even_warp, partial) = even_warp.global_store(region);
            let (_odd_warp, complete) = odd_warp.global_store_complement(partial);
            threadfence(complete).read::<i32>()
        });
        // The placeholder read yields i32::default().
        assert_eq!(value, 0);
    }

    #[test]
    fn test_split_preserves_region_identity() {
        // Every descendant of a split keeps the parent's 'r, so writes from
        // any of them can be merged together.
        GlobalRegion::with_region(|region| {
            let (even_region, odd_region) = region.split();
            let (even_low_region, even_high_region) = even_region.split();

            let converged: Warp<All> = Warp::new();
            let (even_warp, odd_warp) = converged.diverge_even_odd();
            let (even_low, even_high) = even_warp.diverge_halves();

            let (_even_low, even_low_part) = even_low.global_store(even_low_region);
            let (_even_high, even_high_part) = even_high.global_store(even_high_region);
            let (_odd_warp, odd_part) = odd_warp.global_store(odd_region);

            let even_part: GlobalRegion<'_, PartialWrite<Even>> =
                merge_writes_within(even_low_part, even_high_part);
            let complete = merge_writes(even_part, odd_part);
            let _fenced = threadfence(complete);
        });
    }
}