Skip to main content

vyre_driver/
persistent.rs

1//! Persistent-thread engine + host-side work queue (G7).
2//!
3//! # What this is
4//!
5//! A single long-lived GPU dispatch owns a chunk of the device.
6//! Host workers push `PersistentWorkItem`s into a device-visible ring buffer
7//! via an atomic head counter; the device's persistent threads
8//! poll a tail counter and pick up items. The host waits on
9//! per-item completion markers to gather results.
10//!
11//! Eliminates the per-file kernel-launch cost (~5–20 µs on today's
12//! drivers) so a stream of 10 000 × 1 KiB scan jobs pays launch
13//! overhead once, not 10 000 times.
14//!
15//! # Scope of this file
16//!
17//! This module owns the **host-side ring buffer**  -  the atomic
18//! head/tail pair, the lock-free claim protocol, and exhaustive
19//! tests. The actual persistent GPU kernel that consumes the queue
20//! lives behind the `persistent` cargo feature and talks to the owning
21//! backend's native queue API. The host queue is proven correct in isolation
22//! so device integration only worries about the kernel side.
23//!
//! # Memory ordering
//!
//! - Producers claim a position with an `AcqRel` CAS on `head`, write the
//!   slot payload, and then publish it by storing `head + 1` into the
//!   slot's `ready` sequence with `Release`.
//! - Consumers load the slot's `ready` sequence with `Acquire` before the
//!   `AcqRel` CAS on `tail`; once they observe the published sequence, the
//!   producer's slot writes are visible to them.
//! - After reading the payload, a consumer hands the slot back by storing
//!   `tail + ring_size` into `ready` with `Release`, which is the value the
//!   next producer to reach that slot waits for.
//! - This `Release`/`Acquire` pairing on `ready` guarantees visibility
//!   across the weakest memory models we need to support (x86, ARM,
//!   RISC-V GPU consumers).
34
35use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
36
/// Policy knob that lets the caller decide whether work is routed through
/// the persistent-thread dispatch path.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PersistentThreadMode {
    /// Take the persistent path whenever the backend advertises support.
    /// This is the default mode.
    #[default]
    Auto,
    /// Insist on the persistent path; fail loudly when it is unavailable.
    Force,
    /// Never take the persistent path.
    Disable,
}
48
/// Descriptor for a single scan unit.
///
/// Every field is a plain 32-bit number and the struct is `#[repr(C)]`, so
/// it lays out identically on host and device.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PersistentWorkItem {
    /// Where this scan unit starts, as a byte offset into the persistent
    /// input buffer.
    pub input_offset: u32,
    /// Length of this scan unit in bytes.
    pub input_len: u32,
    /// Rule-set / fused-megakernel output-slot bank id.
    pub rule_set_id: u32,
    /// Caller-opaque correlation id. It is echoed into the per-item
    /// completion counter so the host can match results back to a scan job
    /// without keeping a shadow map.
    pub correlation: u32,
}
67
68/// Shared atomics between host producers and device consumers.
69#[derive(Debug)]
70pub struct RingAtomics {
71    /// Monotonically increasing next-slot-to-claim by a producer.
72    pub head: AtomicU64,
73    /// Monotonically increasing next-slot-to-claim by a consumer.
74    pub tail: AtomicU64,
75    /// Per-slot publication sequence. A producer writes the slot payload first
76    /// and then publishes `head + 1` here with `Release`; consumers wait for
77    /// that exact sequence before reading the packed payload.
78    pub ready: Vec<AtomicU64>,
79    /// Per-slot completion marker (1 = done).
80    pub done: Vec<AtomicU32>,
81}
82
83impl RingAtomics {
84    fn try_new(ring_size: u32) -> Result<Self, String> {
85        let capacity = persistent_ring_capacity(ring_size)?;
86        let mut ready = Vec::new();
87        crate::allocation::try_reserve_vec_to_capacity(&mut ready, capacity).map_err(|error| {
88            format!("Fix: persistent ring could not reserve {capacity} ready marker(s): {error}.")
89        })?;
90        for slot in 0..ring_size {
91            ready.push(AtomicU64::new(u64::from(slot)));
92        }
93
94        let mut done = Vec::new();
95        crate::allocation::try_reserve_vec_to_capacity(&mut done, capacity).map_err(|error| {
96            format!("Fix: persistent ring could not reserve {capacity} done marker(s): {error}.")
97        })?;
98        for _ in 0..ring_size {
99            done.push(AtomicU32::new(0));
100        }
101
102        Ok(Self {
103            head: AtomicU64::new(0),
104            tail: AtomicU64::new(0),
105            ready,
106            done,
107        })
108    }
109}
110
111#[derive(Debug)]
112struct WorkSlot {
113    lo: AtomicU64,
114    hi: AtomicU64,
115}
116
117impl WorkSlot {
118    fn new(item: PersistentWorkItem) -> Self {
119        let (lo, hi) = pack_work_item(item);
120        Self {
121            lo: AtomicU64::new(lo),
122            hi: AtomicU64::new(hi),
123        }
124    }
125
126    fn store(&self, item: PersistentWorkItem) {
127        let (lo, hi) = pack_work_item(item);
128        self.lo.store(lo, Ordering::Relaxed);
129        self.hi.store(hi, Ordering::Relaxed);
130    }
131
132    fn load(&self) -> PersistentWorkItem {
133        unpack_work_item(
134            self.lo.load(Ordering::Relaxed),
135            self.hi.load(Ordering::Relaxed),
136        )
137    }
138}
139
140fn pack_work_item(item: PersistentWorkItem) -> (u64, u64) {
141    (
142        u64::from(item.input_offset) | (u64::from(item.input_len) << 32),
143        u64::from(item.rule_set_id) | (u64::from(item.correlation) << 32),
144    )
145}
146
147fn unpack_work_item(lo: u64, hi: u64) -> PersistentWorkItem {
148    PersistentWorkItem {
149        input_offset: lo as u32,
150        input_len: (lo >> 32) as u32,
151        rule_set_id: hi as u32,
152        correlation: (hi >> 32) as u32,
153    }
154}
155
156/// Persistent-engine handle. Owns the host-side view of the ring
157/// buffer. The GPU kernel is a separate concern gated behind
158/// the `persistent` cargo feature.
159#[derive(Debug)]
160pub struct PersistentEngine {
161    slots: Vec<WorkSlot>,
162    atomics: RingAtomics,
163    ring_size: u32,
164}
165
166impl PersistentEngine {
167    /// Construct an engine with a ring capacity of `ring_size`
168    /// slots. Must be a nonzero power of two so
169    /// `index = slot & (cap-1)` is correct.
170    pub fn new(ring_size: u32) -> Self {
171        let ring_size = ring_size
172            .checked_next_power_of_two()
173            .filter(|&size| size > 0)
174            .unwrap_or_else(|| {
175                panic!(
176                    "Fix: persistent ring_size {ring_size} cannot be rounded to a nonzero power of two without overflow."
177                )
178            });
179        Self::with_valid_ring_size(ring_size)
180    }
181
182    /// Construct an engine only when the ring capacity already satisfies
183    /// the persistent-ring indexing contract.
184    pub fn try_new(ring_size: u32) -> Result<Self, String> {
185        if ring_size.is_power_of_two() && ring_size > 0 {
186            Self::try_with_valid_ring_size(ring_size)
187        } else {
188            Err(format!(
189                "Fix: ring_size must be a nonzero power of two, got {ring_size}."
190            ))
191        }
192    }
193
194    fn with_valid_ring_size(ring_size: u32) -> Self {
195        match Self::try_with_valid_ring_size(ring_size) {
196            Ok(engine) => engine,
197            Err(error) => panic!("{error}"),
198        }
199    }
200
201    fn try_with_valid_ring_size(ring_size: u32) -> Result<Self, String> {
202        let zero = PersistentWorkItem {
203            input_offset: 0,
204            input_len: 0,
205            rule_set_id: 0,
206            correlation: 0,
207        };
208        let capacity = persistent_ring_capacity(ring_size)?;
209        let mut slots = Vec::new();
210        crate::allocation::try_reserve_vec_to_capacity(&mut slots, capacity).map_err(|error| {
211            format!("Fix: persistent ring could not reserve {capacity} work slot(s): {error}.")
212        })?;
213        for _ in 0..ring_size {
214            slots.push(WorkSlot::new(zero));
215        }
216
217        Ok(Self {
218            slots,
219            atomics: RingAtomics::try_new(ring_size)?,
220            ring_size,
221        })
222    }
223
224    /// Capacity of the ring buffer.
225    pub fn ring_size(&self) -> u32 {
226        self.ring_size
227    }
228
229    /// Enqueue a PersistentWorkItem. Returns `Ok(slot_index)` on success, or
230    /// `Err(QueueFull)` if the ring is full. Thread-safe under
231    /// concurrent producers (lock-free CAS on `head`).
232    pub fn enqueue(&self, item: PersistentWorkItem) -> Result<u32, QueueFull> {
233        loop {
234            let head = self.atomics.head.load(Ordering::Acquire);
235            let slot_idx = (head as u32) & (self.ring_size - 1);
236            let slot_offset = slot_idx as usize;
237            let Some(ready) = self.atomics.ready.get(slot_offset) else {
238                return Err(QueueFull);
239            };
240            match ring_sequence_order(ready.load(Ordering::Acquire), head) {
241                RingSequenceOrder::Free => {}
242                RingSequenceOrder::Behind => return Err(QueueFull),
243                RingSequenceOrder::Ahead => {
244                    std::hint::spin_loop();
245                    continue;
246                }
247            }
248            match self.atomics.head.compare_exchange(
249                head,
250                head.wrapping_add(1),
251                Ordering::AcqRel,
252                Ordering::Acquire,
253            ) {
254                Ok(_) => {
255                    let Some(slot) = self.slots.get(slot_offset) else {
256                        return Err(QueueFull);
257                    };
258                    slot.store(item);
259                    self.atomics.done[slot_offset].store(0, Ordering::Release);
260                    self.atomics.ready[slot_offset].store(head.wrapping_add(1), Ordering::Release);
261                    return Ok(slot_idx);
262                }
263                Err(_) => continue,
264            }
265        }
266    }
267
268    /// Consumer-side claim. Returns the next available item or
269    /// `None` if the queue is empty. Thread-safe under concurrent
270    /// consumers.
271    pub fn claim(&self) -> Option<PersistentWorkItem> {
272        loop {
273            let tail = self.atomics.tail.load(Ordering::Acquire);
274            let slot_idx = (tail as u32) & (self.ring_size - 1);
275            let slot_offset = slot_idx as usize;
276            let published = tail.wrapping_add(1);
277            let Some(ready) = self.atomics.ready.get(slot_offset) else {
278                return None;
279            };
280            match ring_sequence_order(ready.load(Ordering::Acquire), published) {
281                RingSequenceOrder::Free => {}
282                RingSequenceOrder::Behind => {
283                    if tail >= self.atomics.head.load(Ordering::Acquire) {
284                        return None;
285                    }
286                    std::hint::spin_loop();
287                    continue;
288                }
289                RingSequenceOrder::Ahead => {
290                    std::hint::spin_loop();
291                    continue;
292                }
293            }
294            match self.atomics.tail.compare_exchange(
295                tail,
296                tail.wrapping_add(1),
297                Ordering::AcqRel,
298                Ordering::Acquire,
299            ) {
300                Ok(_) => {
301                    let slot = self.slots.get(slot_offset)?;
302                    let item = slot.load();
303                    self.atomics.ready[slot_offset].store(
304                        tail.wrapping_add(u64::from(self.ring_size)),
305                        Ordering::Release,
306                    );
307                    return Some(item);
308                }
309                Err(_) => continue,
310            }
311        }
312    }
313
314    /// Mark item at `slot_idx` as done.
315    pub fn mark_done(&self, slot_idx: u32) -> Result<(), String> {
316        let Some(done) = self.atomics.done.get(slot_idx as usize) else {
317            return Err(format!(
318                "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before marking done.",
319                self.ring_size
320            ));
321        };
322        done.store(1, Ordering::Release);
323        Ok(())
324    }
325
326    /// Whether the consumer finished the item at `slot_idx`.
327    pub fn is_done(&self, slot_idx: u32) -> Result<bool, String> {
328        let Some(done) = self.atomics.done.get(slot_idx as usize) else {
329            return Err(format!(
330                "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before reading done state.",
331                self.ring_size
332            ));
333        };
334        Ok(done.load(Ordering::Acquire) != 0)
335    }
336
337    /// Number of items queued and pending claim.
338    pub fn try_in_flight(&self) -> Result<u32, String> {
339        let pending = self
340            .atomics
341            .head
342            .load(Ordering::Acquire)
343            .wrapping_sub(self.atomics.tail.load(Ordering::Acquire));
344        u32::try_from(pending).map_err(|_| {
345            format!(
346                "Fix: persistent engine in-flight count {pending} exceeds u32::MAX. Drain the ring or use the 64-bit counters before exporting GPU-visible queue metadata."
347            )
348        })
349    }
350
351    /// Number of items queued and pending claim.
352    pub fn in_flight(&self) -> u32 {
353        self.try_in_flight()
354            .unwrap_or_else(|message| panic!("{message}"))
355    }
356
357    /// Monotonic head counter (modulo `ring_size` = slot index).
358    pub fn head_counter(&self) -> u64 {
359        self.atomics.head.load(Ordering::Acquire)
360    }
361
362    /// Monotonic head counter exposed through the legacy u32 API.
363    pub fn head(&self) -> u32 {
364        let head = self.head_counter();
365        u32::try_from(head).unwrap_or_else(|_| {
366            panic!(
367                "Fix: persistent engine head counter {head} exceeds u32::MAX. Use head_counter() for long-running queues instead of truncating telemetry."
368            )
369        })
370    }
371
372    /// Monotonic tail counter.
373    pub fn tail_counter(&self) -> u64 {
374        self.atomics.tail.load(Ordering::Acquire)
375    }
376
377    /// Monotonic tail counter exposed through the legacy u32 API.
378    pub fn tail(&self) -> u32 {
379        let tail = self.tail_counter();
380        u32::try_from(tail).unwrap_or_else(|_| {
381            panic!(
382                "Fix: persistent engine tail counter {tail} exceeds u32::MAX. Use tail_counter() for long-running queues instead of truncating telemetry."
383            )
384        })
385    }
386}
387
/// Converts a ring size into a `usize` element count.
///
/// Returns a `"Fix: ..."` message when `ring_size` cannot be represented as
/// a `usize` on the current target.
fn persistent_ring_capacity(ring_size: u32) -> Result<usize, String> {
    match usize::try_from(ring_size) {
        Ok(capacity) => Ok(capacity),
        Err(_) => Err(format!(
            "Fix: persistent ring_size {ring_size} does not fit this target's address space."
        )),
    }
}
393
/// Where a slot's sequence number sits relative to the position a caller
/// wants to use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RingSequenceOrder {
    /// The sequence is lower than the position.
    Behind,
    /// The sequence equals the position, so the slot may be used.
    Free,
    /// The sequence is higher than the position.
    Ahead,
}
400
401fn ring_sequence_order(sequence: u64, position: u64) -> RingSequenceOrder {
402    match (sequence.wrapping_sub(position) as i64).cmp(&0) {
403        std::cmp::Ordering::Less => RingSequenceOrder::Behind,
404        std::cmp::Ordering::Equal => RingSequenceOrder::Free,
405        std::cmp::Ordering::Greater => RingSequenceOrder::Ahead,
406    }
407}
408
/// Error returned when an enqueue is attempted while the ring is full.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QueueFull;

impl std::fmt::Display for QueueFull {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "persistent engine ring buffer is full")
    }
}

impl std::error::Error for QueueFull {}
420
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// Builds a 1 KiB work item whose offset and correlation id derive
    /// from `i`.
    fn item(i: u32) -> PersistentWorkItem {
        PersistentWorkItem {
            correlation: i,
            rule_set_id: 0,
            input_len: 1024,
            input_offset: i * 1024,
        }
    }

    /// `try_new` rejects sizes that are not a nonzero power of two.
    #[test]
    fn invalid_ring_size_has_explicit_error_api() {
        let message = PersistentEngine::try_new(7).unwrap_err();
        assert!(message.contains("Fix:"));
        assert!(PersistentEngine::try_new(0).is_err());
    }

    /// `new` rounds the request up instead of rejecting it.
    #[test]
    fn infallible_constructor_normalizes_ring_size() {
        let rounded_up = PersistentEngine::new(7);
        assert_eq!(rounded_up.ring_size(), 8);
        let from_zero = PersistentEngine::new(0);
        assert_eq!(from_zero.ring_size(), 1);
    }

    /// Items come back out in the order they went in.
    #[test]
    fn enqueue_claim_fifo_single_thread() {
        let engine = PersistentEngine::new(8);
        for expected_slot in 0..8 {
            let slot = engine.enqueue(item(expected_slot)).unwrap();
            assert_eq!(slot, expected_slot);
        }
        for expected in 0..8 {
            let claimed = engine.claim().unwrap();
            assert_eq!(claimed.correlation, expected);
        }
        assert!(engine.claim().is_none());
    }

    /// A full ring reports `QueueFull` rather than overwriting.
    #[test]
    fn queue_full_on_overflow() {
        let engine = PersistentEngine::new(4);
        (0..4).for_each(|i| {
            engine.enqueue(item(i)).unwrap();
        });
        assert_eq!(engine.enqueue(item(99)), Err(QueueFull));
    }

    /// Claiming one item frees exactly enough room for one more enqueue.
    #[test]
    fn space_reclaims_after_claim() {
        let engine = PersistentEngine::new(4);
        (0..4).for_each(|i| {
            engine.enqueue(item(i)).unwrap();
        });
        assert!(engine.enqueue(item(99)).is_err());
        let first = engine.claim().unwrap();
        assert_eq!(first.correlation, 0);
        assert!(engine.enqueue(item(99)).is_ok());
    }

    /// `in_flight` follows enqueues and claims.
    #[test]
    fn in_flight_tracks_correctly() {
        let engine = PersistentEngine::new(16);
        assert_eq!(engine.in_flight(), 0);
        (0..5).for_each(|i| {
            engine.enqueue(item(i)).unwrap();
        });
        assert_eq!(engine.in_flight(), 5);
        for _ in 0..2 {
            engine.claim().unwrap();
        }
        assert_eq!(engine.in_flight(), 3);
    }

    /// The completion marker starts cleared and flips after `mark_done`.
    #[test]
    fn done_marker_flows_through() {
        let engine = PersistentEngine::new(4);
        let slot = engine.enqueue(item(1)).unwrap();
        assert!(!engine.is_done(slot).unwrap());
        let claimed = engine.claim().unwrap();
        assert_eq!(claimed.correlation, 1);
        engine.mark_done(slot).unwrap();
        assert!(engine.is_done(slot).unwrap());
    }

    /// Four producers feed one consumer; every item arrives exactly once.
    #[test]
    fn multi_producer_single_consumer_no_item_lost() {
        let engine = Arc::new(PersistentEngine::new(128));
        let producers = 4;
        let items_per_producer = 16;

        let producer_handles: Vec<_> = (0..producers)
            .map(|p| {
                let engine = Arc::clone(&engine);
                thread::spawn(move || {
                    for i in 0..items_per_producer {
                        let corr = (p * 1000 + i) as u32;
                        // Retry until the ring has room for this item.
                        while engine.enqueue(item(corr)).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        let consumer_engine = Arc::clone(&engine);
        let consumer = thread::spawn(move || {
            let total = (producers * items_per_producer) as usize;
            let mut seen = Vec::with_capacity(total);
            while seen.len() < total {
                match consumer_engine.claim() {
                    Some(claimed) => seen.push(claimed.correlation),
                    None => std::hint::spin_loop(),
                }
            }
            seen
        });

        for handle in producer_handles {
            handle.join().unwrap();
        }
        let seen = consumer.join().unwrap();

        let mut unique = seen.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), seen.len(), "duplicate items consumed");

        for p in 0..producers {
            for i in 0..items_per_producer {
                let expected = (p * 1000 + i) as u32;
                assert!(
                    seen.contains(&expected),
                    "missing correlation id {expected}"
                );
            }
        }
    }

    /// Filling and draining the ring repeatedly keeps FIFO order and
    /// advances both counters past the ring size.
    #[test]
    fn wrap_around_works_for_large_throughput() {
        let engine = PersistentEngine::new(16);
        let passes = 10;
        for p in 0..passes {
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert!(engine.enqueue(item(corr)).is_ok());
            }
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert_eq!(engine.claim().unwrap().correlation, corr);
            }
        }
        let expected_counter = (passes * 16) as u32;
        assert_eq!(engine.head(), expected_counter);
        assert_eq!(engine.tail(), expected_counter);
        assert_eq!(engine.in_flight(), 0);
    }

    /// Four consumers drain a pre-filled ring without claiming any item
    /// twice.
    #[test]
    fn multi_consumer_no_double_claim() {
        let engine = Arc::new(PersistentEngine::new(128));
        let total = 100_u32;
        for i in 0..total {
            engine.enqueue(item(i)).unwrap();
        }

        let consumers = 4;
        let shared_consumed = Arc::new(Mutex::new(Vec::new()));
        let handles: Vec<_> = (0..consumers)
            .map(|_| {
                let engine = Arc::clone(&engine);
                let sink = Arc::clone(&shared_consumed);
                thread::spawn(move || {
                    let mut local = Vec::new();
                    while let Some(claimed) = engine.claim() {
                        local.push(claimed.correlation);
                    }
                    sink.lock().unwrap().extend(local);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        let mut consumed = Arc::try_unwrap(shared_consumed)
            .unwrap()
            .into_inner()
            .unwrap();
        consumed.sort();
        assert_eq!(consumed.len(), total as usize);
        for (i, c) in consumed.iter().enumerate() {
            assert_eq!(*c, i as u32, "duplicated or missing item at idx {i}");
        }
    }

    /// The `Display` text names the ring buffer.
    #[test]
    fn queue_full_error_display_is_useful() {
        let rendered = format!("{QueueFull}");
        assert!(rendered.contains("ring buffer"));
    }
}