Skip to main content

vyre_driver/
persistent.rs

1//! Persistent-thread engine + host-side work queue (G7).
2//!
3//! # What this is
4//!
5//! A single long-lived GPU dispatch owns a chunk of the device.
6//! Host workers push `PersistentWorkItem`s into a device-visible ring buffer
7//! via an atomic head counter; the device's persistent threads
8//! poll a tail counter and pick up items. The host waits on
9//! per-item completion markers to gather results.
10//!
11//! Eliminates the per-file kernel-launch cost (~5–20 µs on today's
12//! drivers) so a stream of 10 000 × 1 KiB scan jobs pays launch
13//! overhead once, not 10 000 times.
14//!
15//! # Scope of this file
16//!
//! This module owns the **host-side ring buffer**: the atomic head/tail
//! pair, the per-slot publication sequences, the CAS-based claim protocol,
//! and its tests. The persistent GPU kernel that consumes the queue lives
//! behind the `persistent` cargo feature and talks to the owning backend's
//! native queue API. The host queue is tested in isolation so device
//! integration only has to worry about the kernel side.
23//!
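//! # Memory ordering
//!
//! Each slot carries a `ready` sequence next to its payload:
//!
//! - A producer reserves position `head` with an `AcqRel` CAS on `head`,
//!   writes the slot payload, and then publishes it by storing `head + 1`
//!   into the slot's `ready` sequence with `Release`.
//! - A consumer loads the slot's `ready` sequence with `Acquire`. Observing
//!   `tail + 1` synchronizes with the producer's `Release` store, so the
//!   payload written before it is visible. Consumers advance `tail` with an
//!   `AcqRel` CAS.
//! - A consumer hands a slot back to producers by leaving `tail + ring_size`
//!   in its `ready` sequence (`Release`), which is the value the producer one
//!   lap ahead waits for.
//! - No standalone fences are used: the `Release`/`Acquire` pairs on the
//!   per-slot sequences carry payload visibility on weakly ordered hosts
//!   (ARM, RISC-V) as well as x86. A device-side consumer has to mirror the
//!   same pairing in its own memory model.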
34
35use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
36
/// Dispatch policy for the persistent-thread path, chosen by the caller.
///
/// Defaults to [`PersistentThreadMode::Auto`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PersistentThreadMode {
    /// Take the persistent path whenever the backend reports support for it.
    #[default]
    Auto,
    /// Insist on the persistent path and fail loudly when it is unavailable.
    Force,
    /// Never take the persistent path.
    Disable,
}
48
/// One scan-unit descriptor.
///
/// `#[repr(C)]` over four `u32` fields gives a 16-byte, 4-byte-aligned layout
/// with no padding, so the same struct can be mirrored on host and device.
/// The constant assertions after this definition turn an accidental layout
/// change (a new field, a wider type) into a compile error instead of a
/// silent host/device mismatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct PersistentWorkItem {
    /// Byte offset into the persistent input buffer.
    pub input_offset: u32,
    /// Number of bytes in this scan unit.
    pub input_len: u32,
    /// Rule-set / fused-megakernel output-slot bank id.
    pub rule_set_id: u32,
    /// Caller-opaque correlation id. The ring carries it through unchanged,
    /// so the consumer can match results back to a scan job without a
    /// shadow map.
    pub correlation: u32,
}

// Host/device layout contract: exactly four tightly packed `u32`s.
const _: () = assert!(std::mem::size_of::<PersistentWorkItem>() == 16);
const _: () = assert!(std::mem::align_of::<PersistentWorkItem>() == 4);
67
68/// Shared atomics between host producers and device consumers.
69#[derive(Debug)]
70pub struct RingAtomics {
71    /// Monotonically increasing next-slot-to-claim by a producer.
72    pub head: AtomicU64,
73    /// Monotonically increasing next-slot-to-claim by a consumer.
74    pub tail: AtomicU64,
75    /// Per-slot publication sequence. A producer writes the slot payload first
76    /// and then publishes `head + 1` here with `Release`; consumers wait for
77    /// that exact sequence before reading the packed payload.
78    pub ready: Vec<AtomicU64>,
79    /// Per-slot completion marker (1 = done).
80    pub done: Vec<AtomicU32>,
81}
82
83impl RingAtomics {
84    fn try_new(ring_size: u32) -> Result<Self, String> {
85        let capacity = persistent_ring_capacity(ring_size)?;
86        let mut ready = Vec::new();
87        crate::allocation::try_reserve_vec_to_capacity(&mut ready, capacity).map_err(|error| {
88            format!("Fix: persistent ring could not reserve {capacity} ready marker(s): {error}.")
89        })?;
90        for slot in 0..ring_size {
91            ready.push(AtomicU64::new(u64::from(slot)));
92        }
93
94        let mut done = Vec::new();
95        crate::allocation::try_reserve_vec_to_capacity(&mut done, capacity).map_err(|error| {
96            format!("Fix: persistent ring could not reserve {capacity} done marker(s): {error}.")
97        })?;
98        for _ in 0..ring_size {
99            done.push(AtomicU32::new(0));
100        }
101
102        Ok(Self {
103            head: AtomicU64::new(0),
104            tail: AtomicU64::new(0),
105            ready,
106            done,
107        })
108    }
109}
110
111#[derive(Debug)]
112struct WorkSlot {
113    lo: AtomicU64,
114    hi: AtomicU64,
115}
116
117impl WorkSlot {
118    fn new(item: PersistentWorkItem) -> Self {
119        let (lo, hi) = pack_work_item(item);
120        Self {
121            lo: AtomicU64::new(lo),
122            hi: AtomicU64::new(hi),
123        }
124    }
125
126    fn store(&self, item: PersistentWorkItem) {
127        let (lo, hi) = pack_work_item(item);
128        self.lo.store(lo, Ordering::Relaxed);
129        self.hi.store(hi, Ordering::Relaxed);
130    }
131
132    fn load(&self) -> PersistentWorkItem {
133        unpack_work_item(
134            self.lo.load(Ordering::Relaxed),
135            self.hi.load(Ordering::Relaxed),
136        )
137    }
138}
139
140fn pack_work_item(item: PersistentWorkItem) -> (u64, u64) {
141    (
142        u64::from(item.input_offset) | (u64::from(item.input_len) << 32),
143        u64::from(item.rule_set_id) | (u64::from(item.correlation) << 32),
144    )
145}
146
147fn unpack_work_item(lo: u64, hi: u64) -> PersistentWorkItem {
148    PersistentWorkItem {
149        input_offset: lo as u32,
150        input_len: (lo >> 32) as u32,
151        rule_set_id: hi as u32,
152        correlation: (hi >> 32) as u32,
153    }
154}
155
156/// Persistent-engine handle. Owns the host-side view of the ring
157/// buffer. The GPU kernel is a separate concern gated behind
158/// the `persistent` cargo feature.
159#[derive(Debug)]
160pub struct PersistentEngine {
161    slots: Vec<WorkSlot>,
162    atomics: RingAtomics,
163    ring_size: u32,
164}
165
166impl PersistentEngine {
167    /// Construct an engine with a ring capacity of `ring_size`
168    /// slots. Must be a nonzero power of two so
169    /// `index = slot & (cap-1)` is correct.
170    pub fn new(ring_size: u32) -> Self {
171        let ring_size = ring_size
172            .checked_next_power_of_two()
173            .filter(|&size| size > 0)
174            .unwrap_or(1);
175        Self::with_valid_ring_size(ring_size)
176    }
177
178    /// Construct an engine only when the ring capacity already satisfies
179    /// the persistent-ring indexing contract.
180    pub fn try_new(ring_size: u32) -> Result<Self, String> {
181        if ring_size.is_power_of_two() && ring_size > 0 {
182            Self::try_with_valid_ring_size(ring_size)
183        } else {
184            Err(format!(
185                "Fix: ring_size must be a nonzero power of two, got {ring_size}."
186            ))
187        }
188    }
189
190    fn with_valid_ring_size(ring_size: u32) -> Self {
191        match Self::try_with_valid_ring_size(ring_size) {
192            Ok(engine) => engine,
193            Err(_) => Self::try_with_valid_ring_size(1).unwrap_or_else(|_| std::process::abort()),
194        }
195    }
196
197    fn try_with_valid_ring_size(ring_size: u32) -> Result<Self, String> {
198        let zero = PersistentWorkItem {
199            input_offset: 0,
200            input_len: 0,
201            rule_set_id: 0,
202            correlation: 0,
203        };
204        let capacity = persistent_ring_capacity(ring_size)?;
205        let mut slots = Vec::new();
206        crate::allocation::try_reserve_vec_to_capacity(&mut slots, capacity).map_err(|error| {
207            format!("Fix: persistent ring could not reserve {capacity} work slot(s): {error}.")
208        })?;
209        for _ in 0..ring_size {
210            slots.push(WorkSlot::new(zero));
211        }
212
213        Ok(Self {
214            slots,
215            atomics: RingAtomics::try_new(ring_size)?,
216            ring_size,
217        })
218    }
219
220    /// Capacity of the ring buffer.
221    pub fn ring_size(&self) -> u32 {
222        self.ring_size
223    }
224
225    /// Enqueue a PersistentWorkItem. Returns `Ok(slot_index)` on success, or
226    /// `Err(QueueFull)` if the ring is full. Thread-safe under
227    /// concurrent producers (lock-free CAS on `head`).
228    pub fn enqueue(&self, item: PersistentWorkItem) -> Result<u32, QueueFull> {
229        loop {
230            let head = self.atomics.head.load(Ordering::Acquire);
231            let slot_idx = (head as u32) & (self.ring_size - 1);
232            let slot_offset = slot_idx as usize;
233            let Some(ready) = self.atomics.ready.get(slot_offset) else {
234                return Err(QueueFull);
235            };
236            match ring_sequence_order(ready.load(Ordering::Acquire), head) {
237                RingSequenceOrder::Free => {}
238                RingSequenceOrder::Behind => return Err(QueueFull),
239                RingSequenceOrder::Ahead => {
240                    std::hint::spin_loop();
241                    continue;
242                }
243            }
244            match self.atomics.head.compare_exchange(
245                head,
246                head.wrapping_add(1),
247                Ordering::AcqRel,
248                Ordering::Acquire,
249            ) {
250                Ok(_) => {
251                    let Some(slot) = self.slots.get(slot_offset) else {
252                        return Err(QueueFull);
253                    };
254                    slot.store(item);
255                    self.atomics.done[slot_offset].store(0, Ordering::Release);
256                    self.atomics.ready[slot_offset].store(head.wrapping_add(1), Ordering::Release);
257                    return Ok(slot_idx);
258                }
259                Err(_) => continue,
260            }
261        }
262    }
263
264    /// Consumer-side claim. Returns the next available item or
265    /// `None` if the queue is empty. Thread-safe under concurrent
266    /// consumers.
267    pub fn claim(&self) -> Option<PersistentWorkItem> {
268        loop {
269            let tail = self.atomics.tail.load(Ordering::Acquire);
270            let slot_idx = (tail as u32) & (self.ring_size - 1);
271            let slot_offset = slot_idx as usize;
272            let published = tail.wrapping_add(1);
273            let Some(ready) = self.atomics.ready.get(slot_offset) else {
274                return None;
275            };
276            match ring_sequence_order(ready.load(Ordering::Acquire), published) {
277                RingSequenceOrder::Free => {}
278                RingSequenceOrder::Behind => {
279                    if tail >= self.atomics.head.load(Ordering::Acquire) {
280                        return None;
281                    }
282                    std::hint::spin_loop();
283                    continue;
284                }
285                RingSequenceOrder::Ahead => {
286                    std::hint::spin_loop();
287                    continue;
288                }
289            }
290            match self.atomics.tail.compare_exchange(
291                tail,
292                tail.wrapping_add(1),
293                Ordering::AcqRel,
294                Ordering::Acquire,
295            ) {
296                Ok(_) => {
297                    let slot = self.slots.get(slot_offset)?;
298                    let item = slot.load();
299                    self.atomics.ready[slot_offset].store(
300                        tail.wrapping_add(u64::from(self.ring_size)),
301                        Ordering::Release,
302                    );
303                    return Some(item);
304                }
305                Err(_) => continue,
306            }
307        }
308    }
309
310    /// Mark item at `slot_idx` as done.
311    pub fn mark_done(&self, slot_idx: u32) -> Result<(), String> {
312        let Some(done) = self.atomics.done.get(slot_idx as usize) else {
313            return Err(format!(
314                "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before marking done.",
315                self.ring_size
316            ));
317        };
318        done.store(1, Ordering::Release);
319        Ok(())
320    }
321
322    /// Whether the consumer finished the item at `slot_idx`.
323    pub fn is_done(&self, slot_idx: u32) -> Result<bool, String> {
324        let Some(done) = self.atomics.done.get(slot_idx as usize) else {
325            return Err(format!(
326                "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before reading done state.",
327                self.ring_size
328            ));
329        };
330        Ok(done.load(Ordering::Acquire) != 0)
331    }
332
333    /// Number of items queued and pending claim.
334    pub fn try_in_flight(&self) -> Result<u32, String> {
335        let pending = self
336            .atomics
337            .head
338            .load(Ordering::Acquire)
339            .wrapping_sub(self.atomics.tail.load(Ordering::Acquire));
340        u32::try_from(pending).map_err(|_| {
341            format!(
342                "Fix: persistent engine in-flight count {pending} exceeds u32::MAX. Drain the ring or use the 64-bit counters before exporting GPU-visible queue metadata."
343            )
344        })
345    }
346
347    /// Number of items queued and pending claim.
348    pub fn in_flight(&self) -> u32 {
349        self.try_in_flight().unwrap_or(u32::MAX)
350    }
351
352    /// Monotonic head counter (modulo `ring_size` = slot index).
353    pub fn head_counter(&self) -> u64 {
354        self.atomics.head.load(Ordering::Acquire)
355    }
356
357    /// Monotonic head counter exposed through the legacy u32 API.
358    pub fn head(&self) -> u32 {
359        let head = self.head_counter();
360        u32::try_from(head).unwrap_or(u32::MAX)
361    }
362
363    /// Monotonic tail counter.
364    pub fn tail_counter(&self) -> u64 {
365        self.atomics.tail.load(Ordering::Acquire)
366    }
367
368    /// Monotonic tail counter exposed through the legacy u32 API.
369    pub fn tail(&self) -> u32 {
370        let tail = self.tail_counter();
371        u32::try_from(tail).unwrap_or(u32::MAX)
372    }
373}
374
/// Convert a `u32` ring size into a `usize` element count for allocation.
///
/// Fails only on targets whose `usize` is narrower than 32 bits.
fn persistent_ring_capacity(ring_size: u32) -> Result<usize, String> {
    let Ok(capacity) = usize::try_from(ring_size) else {
        return Err(format!(
            "Fix: persistent ring_size {ring_size} does not fit this target's address space."
        ));
    };
    Ok(capacity)
}
380
/// Where a slot's `ready` sequence stands relative to the ring position a
/// caller wants to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RingSequenceOrder {
    /// The sequence has not reached the position yet.
    Behind,
    /// The sequence equals the position.
    Free,
    /// The sequence has already moved past the position.
    Ahead,
}

/// Compare `sequence` with `position` as a signed distance.
fn ring_sequence_order(sequence: u64, position: u64) -> RingSequenceOrder {
    // Reading the wrapped difference as `i64` keeps the comparison correct
    // across a `u64` counter wrap, as long as the two values are within
    // 2^63 of each other.
    let distance = sequence.wrapping_sub(position) as i64;
    match distance.signum() {
        -1 => RingSequenceOrder::Behind,
        0 => RingSequenceOrder::Free,
        _ => RingSequenceOrder::Ahead,
    }
}
395
/// Error returned when an enqueue finds no free slot in the ring.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueueFull;

impl std::error::Error for QueueFull {}

impl std::fmt::Display for QueueFull {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "persistent engine ring buffer is full")
    }
}
407
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Work item `i`: a 1 KiB scan unit at offset `i * 1024`, tagged with `i`.
    fn item(i: u32) -> PersistentWorkItem {
        PersistentWorkItem {
            input_offset: i * 1024,
            input_len: 1024,
            rule_set_id: 0,
            correlation: i,
        }
    }

    #[test]
    fn invalid_ring_size_has_explicit_error_api() {
        let err = PersistentEngine::try_new(7).unwrap_err();
        assert!(err.contains("Fix:"));
        assert!(PersistentEngine::try_new(0).is_err());
    }

    #[test]
    fn infallible_constructor_normalizes_ring_size() {
        assert_eq!(PersistentEngine::new(7).ring_size(), 8);
        assert_eq!(PersistentEngine::new(0).ring_size(), 1);
    }

    #[test]
    fn enqueue_claim_fifo_single_thread() {
        let eng = PersistentEngine::new(8);
        for i in 0..8 {
            assert_eq!(eng.enqueue(item(i)).unwrap(), i);
        }
        for i in 0..8 {
            assert_eq!(eng.claim().unwrap().correlation, i);
        }
        assert!(eng.claim().is_none());
    }

    #[test]
    fn queue_full_on_overflow() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.enqueue(item(99)), Err(QueueFull));
    }

    #[test]
    fn space_reclaims_after_claim() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert!(eng.enqueue(item(99)).is_err());
        let claimed = eng.claim().unwrap();
        assert_eq!(claimed.correlation, 0);
        assert!(eng.enqueue(item(99)).is_ok());
    }

    #[test]
    fn in_flight_tracks_correctly() {
        let eng = PersistentEngine::new(16);
        assert_eq!(eng.in_flight(), 0);
        for i in 0..5 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.in_flight(), 5);
        eng.claim().unwrap();
        eng.claim().unwrap();
        assert_eq!(eng.in_flight(), 3);
    }

    #[test]
    fn done_marker_flows_through() {
        let eng = PersistentEngine::new(4);
        let slot = eng.enqueue(item(1)).unwrap();
        assert!(!eng.is_done(slot).unwrap());
        let claimed = eng.claim().unwrap();
        assert_eq!(claimed.correlation, 1);
        eng.mark_done(slot).unwrap();
        assert!(eng.is_done(slot).unwrap());
    }

    #[test]
    fn multi_producer_single_consumer_no_item_lost() {
        let eng = Arc::new(PersistentEngine::new(128));
        let producers: u32 = 4;
        let items_per_producer: u32 = 16;
        let mut handles = Vec::new();
        for p in 0..producers {
            let eng = Arc::clone(&eng);
            handles.push(thread::spawn(move || {
                for i in 0..items_per_producer {
                    let corr = p * 1000 + i;
                    // 128 slots hold all 64 items, so QueueFull is not
                    // expected; retrying keeps the test robust regardless.
                    while eng.enqueue(item(corr)).is_err() {
                        std::hint::spin_loop();
                    }
                }
            }));
        }
        let consumer_eng = Arc::clone(&eng);
        let total = (producers * items_per_producer) as usize;
        let consumer = thread::spawn(move || {
            let mut seen = Vec::with_capacity(total);
            while seen.len() < total {
                match consumer_eng.claim() {
                    Some(it) => seen.push(it.correlation),
                    None => std::hint::spin_loop(),
                }
            }
            seen
        });
        for h in handles {
            h.join().unwrap();
        }
        let mut seen = consumer.join().unwrap();
        seen.sort_unstable();
        // Every id exactly once: comparing both sorted lists makes a lost or
        // duplicated item show up directly in the assertion diff.
        let mut expected: Vec<u32> = (0..producers)
            .flat_map(|p| (0..items_per_producer).map(move |i| p * 1000 + i))
            .collect();
        expected.sort_unstable();
        assert_eq!(seen, expected, "items lost or consumed more than once");
    }

    #[test]
    fn wrap_around_works_for_large_throughput() {
        let eng = PersistentEngine::new(16);
        let passes = 10;
        for p in 0..passes {
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert!(eng.enqueue(item(corr)).is_ok());
            }
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert_eq!(eng.claim().unwrap().correlation, corr);
            }
        }
        assert_eq!(eng.head(), (passes * 16) as u32);
        assert_eq!(eng.tail(), (passes * 16) as u32);
        assert_eq!(eng.in_flight(), 0);
    }

    #[test]
    fn multi_consumer_no_double_claim() {
        let eng = Arc::new(PersistentEngine::new(128));
        let total = 100_u32;
        for i in 0..total {
            eng.enqueue(item(i)).unwrap();
        }
        let consumers = 4;
        let mut handles = Vec::new();
        let shared_consumed = Arc::new(std::sync::Mutex::new(Vec::new()));
        for _ in 0..consumers {
            let eng = Arc::clone(&eng);
            let out = Arc::clone(&shared_consumed);
            handles.push(thread::spawn(move || {
                let mut local = Vec::new();
                while let Some(it) = eng.claim() {
                    local.push(it.correlation);
                }
                out.lock().unwrap().extend(local);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let mut consumed = Arc::try_unwrap(shared_consumed)
            .unwrap()
            .into_inner()
            .unwrap();
        consumed.sort_unstable();
        let expected: Vec<u32> = (0..total).collect();
        assert_eq!(consumed, expected, "duplicated or missing items");
    }

    #[test]
    fn queue_full_error_display_is_useful() {
        let s = format!("{QueueFull}");
        assert!(s.contains("ring buffer"));
    }
}