1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
36
/// Three-way switch for persistent-thread execution.
///
/// NOTE(review): nothing in this file consumes the enum, so the variant
/// descriptions below are inferred from the names only — confirm against the
/// code that reads this setting.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PersistentThreadMode {
    /// The value produced by `PersistentThreadMode::default()`.
    /// Presumably lets the runtime decide whether to use persistent threads.
    #[default]
    Auto,
    /// Presumably always uses persistent threads — TODO confirm.
    Force,
    /// Presumably never uses persistent threads — TODO confirm.
    Disable,
}
48
/// One unit of work carried through the persistent ring.
///
/// The struct is `#[repr(C)]` and consists of four `u32` fields, so its layout
/// is fixed at 16 bytes with the fields in declaration order.
///
/// NOTE(review): the field descriptions are inferred from the field names;
/// the code that interprets them is not part of this file view.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PersistentWorkItem {
    /// Presumably the start of this item's input inside a shared buffer — TODO confirm units.
    pub input_offset: u32,
    /// Presumably the length of this item's input — TODO confirm units.
    pub input_len: u32,
    /// Presumably selects which rule set to evaluate for this item.
    pub rule_set_id: u32,
    /// Caller-chosen tag carried unchanged through the ring.
    pub correlation: u32,
}
67
/// Shared counters and per-slot markers that coordinate producers and
/// consumers of the persistent ring.
#[derive(Debug)]
pub struct RingAtomics {
    /// Producer position. `PersistentEngine::enqueue` advances it by one with a
    /// compare-exchange for every item it accepts.
    pub head: AtomicU64,
    /// Consumer position. `PersistentEngine::claim` advances it by one with a
    /// compare-exchange for every item it hands out.
    pub tail: AtomicU64,
    /// One sequence marker per slot. It starts at the slot index, is set to
    /// `position + 1` when a producer publishes an item at `position`, and to
    /// `position + ring_size` when a consumer releases the slot.
    pub ready: Vec<AtomicU64>,
    /// One completion flag per slot: cleared to `0` on enqueue, set to `1` by
    /// `PersistentEngine::mark_done`.
    pub done: Vec<AtomicU32>,
}
82
83impl RingAtomics {
84 fn try_new(ring_size: u32) -> Result<Self, String> {
85 let capacity = persistent_ring_capacity(ring_size)?;
86 let mut ready = Vec::new();
87 crate::allocation::try_reserve_vec_to_capacity(&mut ready, capacity).map_err(|error| {
88 format!("Fix: persistent ring could not reserve {capacity} ready marker(s): {error}.")
89 })?;
90 for slot in 0..ring_size {
91 ready.push(AtomicU64::new(u64::from(slot)));
92 }
93
94 let mut done = Vec::new();
95 crate::allocation::try_reserve_vec_to_capacity(&mut done, capacity).map_err(|error| {
96 format!("Fix: persistent ring could not reserve {capacity} done marker(s): {error}.")
97 })?;
98 for _ in 0..ring_size {
99 done.push(AtomicU32::new(0));
100 }
101
102 Ok(Self {
103 head: AtomicU64::new(0),
104 tail: AtomicU64::new(0),
105 ready,
106 done,
107 })
108 }
109}
110
/// Storage for one queued work item, kept as two atomic 64-bit words so it
/// can be written and read through a shared reference.
#[derive(Debug)]
struct WorkSlot {
    /// First packed word, as produced by `pack_work_item`.
    lo: AtomicU64,
    /// Second packed word, as produced by `pack_work_item`.
    hi: AtomicU64,
}
116
117impl WorkSlot {
118 fn new(item: PersistentWorkItem) -> Self {
119 let (lo, hi) = pack_work_item(item);
120 Self {
121 lo: AtomicU64::new(lo),
122 hi: AtomicU64::new(hi),
123 }
124 }
125
126 fn store(&self, item: PersistentWorkItem) {
127 let (lo, hi) = pack_work_item(item);
128 self.lo.store(lo, Ordering::Relaxed);
129 self.hi.store(hi, Ordering::Relaxed);
130 }
131
132 fn load(&self) -> PersistentWorkItem {
133 unpack_work_item(
134 self.lo.load(Ordering::Relaxed),
135 self.hi.load(Ordering::Relaxed),
136 )
137 }
138}
139
140fn pack_work_item(item: PersistentWorkItem) -> (u64, u64) {
141 (
142 u64::from(item.input_offset) | (u64::from(item.input_len) << 32),
143 u64::from(item.rule_set_id) | (u64::from(item.correlation) << 32),
144 )
145}
146
147fn unpack_work_item(lo: u64, hi: u64) -> PersistentWorkItem {
148 PersistentWorkItem {
149 input_offset: lo as u32,
150 input_len: (lo >> 32) as u32,
151 rule_set_id: hi as u32,
152 correlation: (hi >> 32) as u32,
153 }
154}
155
156#[derive(Debug)]
160pub struct PersistentEngine {
161 slots: Vec<WorkSlot>,
162 atomics: RingAtomics,
163 ring_size: u32,
164}
165
166impl PersistentEngine {
167 pub fn new(ring_size: u32) -> Self {
171 let ring_size = ring_size
172 .checked_next_power_of_two()
173 .filter(|&size| size > 0)
174 .unwrap_or(1);
175 Self::with_valid_ring_size(ring_size)
176 }
177
178 pub fn try_new(ring_size: u32) -> Result<Self, String> {
181 if ring_size.is_power_of_two() && ring_size > 0 {
182 Self::try_with_valid_ring_size(ring_size)
183 } else {
184 Err(format!(
185 "Fix: ring_size must be a nonzero power of two, got {ring_size}."
186 ))
187 }
188 }
189
190 fn with_valid_ring_size(ring_size: u32) -> Self {
191 match Self::try_with_valid_ring_size(ring_size) {
192 Ok(engine) => engine,
193 Err(_) => Self::try_with_valid_ring_size(1).unwrap_or_else(|_| std::process::abort()),
194 }
195 }
196
197 fn try_with_valid_ring_size(ring_size: u32) -> Result<Self, String> {
198 let zero = PersistentWorkItem {
199 input_offset: 0,
200 input_len: 0,
201 rule_set_id: 0,
202 correlation: 0,
203 };
204 let capacity = persistent_ring_capacity(ring_size)?;
205 let mut slots = Vec::new();
206 crate::allocation::try_reserve_vec_to_capacity(&mut slots, capacity).map_err(|error| {
207 format!("Fix: persistent ring could not reserve {capacity} work slot(s): {error}.")
208 })?;
209 for _ in 0..ring_size {
210 slots.push(WorkSlot::new(zero));
211 }
212
213 Ok(Self {
214 slots,
215 atomics: RingAtomics::try_new(ring_size)?,
216 ring_size,
217 })
218 }
219
220 pub fn ring_size(&self) -> u32 {
222 self.ring_size
223 }
224
225 pub fn enqueue(&self, item: PersistentWorkItem) -> Result<u32, QueueFull> {
229 loop {
230 let head = self.atomics.head.load(Ordering::Acquire);
231 let slot_idx = (head as u32) & (self.ring_size - 1);
232 let slot_offset = slot_idx as usize;
233 let Some(ready) = self.atomics.ready.get(slot_offset) else {
234 return Err(QueueFull);
235 };
236 match ring_sequence_order(ready.load(Ordering::Acquire), head) {
237 RingSequenceOrder::Free => {}
238 RingSequenceOrder::Behind => return Err(QueueFull),
239 RingSequenceOrder::Ahead => {
240 std::hint::spin_loop();
241 continue;
242 }
243 }
244 match self.atomics.head.compare_exchange(
245 head,
246 head.wrapping_add(1),
247 Ordering::AcqRel,
248 Ordering::Acquire,
249 ) {
250 Ok(_) => {
251 let Some(slot) = self.slots.get(slot_offset) else {
252 return Err(QueueFull);
253 };
254 slot.store(item);
255 self.atomics.done[slot_offset].store(0, Ordering::Release);
256 self.atomics.ready[slot_offset].store(head.wrapping_add(1), Ordering::Release);
257 return Ok(slot_idx);
258 }
259 Err(_) => continue,
260 }
261 }
262 }
263
264 pub fn claim(&self) -> Option<PersistentWorkItem> {
268 loop {
269 let tail = self.atomics.tail.load(Ordering::Acquire);
270 let slot_idx = (tail as u32) & (self.ring_size - 1);
271 let slot_offset = slot_idx as usize;
272 let published = tail.wrapping_add(1);
273 let Some(ready) = self.atomics.ready.get(slot_offset) else {
274 return None;
275 };
276 match ring_sequence_order(ready.load(Ordering::Acquire), published) {
277 RingSequenceOrder::Free => {}
278 RingSequenceOrder::Behind => {
279 if tail >= self.atomics.head.load(Ordering::Acquire) {
280 return None;
281 }
282 std::hint::spin_loop();
283 continue;
284 }
285 RingSequenceOrder::Ahead => {
286 std::hint::spin_loop();
287 continue;
288 }
289 }
290 match self.atomics.tail.compare_exchange(
291 tail,
292 tail.wrapping_add(1),
293 Ordering::AcqRel,
294 Ordering::Acquire,
295 ) {
296 Ok(_) => {
297 let slot = self.slots.get(slot_offset)?;
298 let item = slot.load();
299 self.atomics.ready[slot_offset].store(
300 tail.wrapping_add(u64::from(self.ring_size)),
301 Ordering::Release,
302 );
303 return Some(item);
304 }
305 Err(_) => continue,
306 }
307 }
308 }
309
310 pub fn mark_done(&self, slot_idx: u32) -> Result<(), String> {
312 let Some(done) = self.atomics.done.get(slot_idx as usize) else {
313 return Err(format!(
314 "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before marking done.",
315 self.ring_size
316 ));
317 };
318 done.store(1, Ordering::Release);
319 Ok(())
320 }
321
322 pub fn is_done(&self, slot_idx: u32) -> Result<bool, String> {
324 let Some(done) = self.atomics.done.get(slot_idx as usize) else {
325 return Err(format!(
326 "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before reading done state.",
327 self.ring_size
328 ));
329 };
330 Ok(done.load(Ordering::Acquire) != 0)
331 }
332
333 pub fn try_in_flight(&self) -> Result<u32, String> {
335 let pending = self
336 .atomics
337 .head
338 .load(Ordering::Acquire)
339 .wrapping_sub(self.atomics.tail.load(Ordering::Acquire));
340 u32::try_from(pending).map_err(|_| {
341 format!(
342 "Fix: persistent engine in-flight count {pending} exceeds u32::MAX. Drain the ring or use the 64-bit counters before exporting GPU-visible queue metadata."
343 )
344 })
345 }
346
347 pub fn in_flight(&self) -> u32 {
349 self.try_in_flight().unwrap_or(u32::MAX)
350 }
351
352 pub fn head_counter(&self) -> u64 {
354 self.atomics.head.load(Ordering::Acquire)
355 }
356
357 pub fn head(&self) -> u32 {
359 let head = self.head_counter();
360 u32::try_from(head).unwrap_or(u32::MAX)
361 }
362
363 pub fn tail_counter(&self) -> u64 {
365 self.atomics.tail.load(Ordering::Acquire)
366 }
367
368 pub fn tail(&self) -> u32 {
370 let tail = self.tail_counter();
371 u32::try_from(tail).unwrap_or(u32::MAX)
372 }
373}
374
/// Converts a ring size into a `usize` element count.
///
/// # Errors
///
/// Returns a `Fix:`-prefixed message when `ring_size` cannot be represented
/// as `usize` on the current target.
fn persistent_ring_capacity(ring_size: u32) -> Result<usize, String> {
    match usize::try_from(ring_size) {
        Ok(capacity) => Ok(capacity),
        Err(_) => Err(format!(
            "Fix: persistent ring_size {ring_size} does not fit this target's address space."
        )),
    }
}
380
/// Where a slot's sequence marker sits relative to the position being tested,
/// as computed by `ring_sequence_order`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RingSequenceOrder {
    /// The marker is lower than the position.
    Behind,
    /// The marker equals the position, so the caller may take the slot.
    Free,
    /// The marker is higher than the position.
    Ahead,
}
387
388fn ring_sequence_order(sequence: u64, position: u64) -> RingSequenceOrder {
389 match (sequence.wrapping_sub(position) as i64).cmp(&0) {
390 std::cmp::Ordering::Less => RingSequenceOrder::Behind,
391 std::cmp::Ordering::Equal => RingSequenceOrder::Free,
392 std::cmp::Ordering::Greater => RingSequenceOrder::Ahead,
393 }
394}
395
/// Error returned by `PersistentEngine::enqueue` when no slot is available
/// for the item.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QueueFull;

impl std::fmt::Display for QueueFull {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "persistent engine ring buffer is full")
    }
}

/// Marker implementation: the error carries no source or extra context.
impl std::error::Error for QueueFull {}
407
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // Builds a distinguishable work item whose `correlation` is `i`.
    fn item(i: u32) -> PersistentWorkItem {
        PersistentWorkItem {
            input_offset: i * 1024,
            input_len: 1024,
            rule_set_id: 0,
            correlation: i,
        }
    }

    #[test]
    fn invalid_ring_size_has_explicit_error_api() {
        let err = PersistentEngine::try_new(7).unwrap_err();
        assert!(err.contains("Fix:"));
        assert!(PersistentEngine::try_new(0).is_err());
    }

    #[test]
    fn infallible_constructor_normalizes_ring_size() {
        assert_eq!(PersistentEngine::new(7).ring_size(), 8);
        assert_eq!(PersistentEngine::new(0).ring_size(), 1);
    }

    // Requests above 2^31 have no `u32` power of two to round up to.
    #[test]
    fn oversized_ring_request_falls_back_to_single_slot() {
        assert_eq!(PersistentEngine::new(u32::MAX).ring_size(), 1);
        assert_eq!(PersistentEngine::new((1 << 31) + 1).ring_size(), 1);
    }

    #[test]
    fn empty_ring_claims_nothing() {
        let eng = PersistentEngine::new(4);
        assert!(eng.claim().is_none());
        assert_eq!(eng.in_flight(), 0);
    }

    #[test]
    fn enqueue_claim_fifo_single_thread() {
        let eng = PersistentEngine::new(8);
        for i in 0..8 {
            assert_eq!(eng.enqueue(item(i)).unwrap(), i);
        }
        for i in 0..8 {
            assert_eq!(eng.claim().unwrap().correlation, i);
        }
        assert!(eng.claim().is_none());
    }

    #[test]
    fn queue_full_on_overflow() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.enqueue(item(99)), Err(QueueFull));
    }

    #[test]
    fn space_reclaims_after_claim() {
        let eng = PersistentEngine::new(4);
        for i in 0..4 {
            eng.enqueue(item(i)).unwrap();
        }
        assert!(eng.enqueue(item(99)).is_err());
        let claimed = eng.claim().unwrap();
        assert_eq!(claimed.correlation, 0);
        assert!(eng.enqueue(item(99)).is_ok());
    }

    #[test]
    fn in_flight_tracks_correctly() {
        let eng = PersistentEngine::new(16);
        assert_eq!(eng.in_flight(), 0);
        for i in 0..5 {
            eng.enqueue(item(i)).unwrap();
        }
        assert_eq!(eng.in_flight(), 5);
        eng.claim().unwrap();
        eng.claim().unwrap();
        assert_eq!(eng.in_flight(), 3);
    }

    #[test]
    fn done_marker_flows_through() {
        let eng = PersistentEngine::new(4);
        let slot = eng.enqueue(item(1)).unwrap();
        assert!(!eng.is_done(slot).unwrap());
        let claimed = eng.claim().unwrap();
        assert_eq!(claimed.correlation, 1);
        eng.mark_done(slot).unwrap();
        assert!(eng.is_done(slot).unwrap());
    }

    #[test]
    fn done_markers_reject_out_of_range_slot() {
        let eng = PersistentEngine::new(4);
        let err = eng.mark_done(4).unwrap_err();
        assert!(err.contains("Fix:"));
        assert!(err.contains("slot_idx=4"));
        assert!(err.contains("ring_size=4"));
        assert!(eng.is_done(4).is_err());
        assert!(eng.is_done(u32::MAX).is_err());
    }

    // Refilling a slot must clear the completion flag left by its previous item.
    #[test]
    fn enqueue_resets_done_marker_on_slot_reuse() {
        let eng = PersistentEngine::new(2);
        let first = eng.enqueue(item(1)).unwrap();
        eng.claim().unwrap();
        eng.mark_done(first).unwrap();
        assert!(eng.is_done(first).unwrap());

        let second = eng.enqueue(item(2)).unwrap();
        assert_ne!(second, first);
        let reused = eng.enqueue(item(3)).unwrap();
        assert_eq!(reused, first);
        assert!(!eng.is_done(reused).unwrap());
    }

    #[test]
    fn work_item_packing_round_trips() {
        let original = PersistentWorkItem {
            input_offset: u32::MAX,
            input_len: 1,
            rule_set_id: 0x8000_0000,
            correlation: u32::MAX - 1,
        };
        let (lo, hi) = pack_work_item(original);
        assert_eq!(unpack_work_item(lo, hi), original);
    }

    #[test]
    fn sequence_order_handles_counter_wrap() {
        assert_eq!(ring_sequence_order(7, 7), RingSequenceOrder::Free);
        assert_eq!(ring_sequence_order(0, u64::MAX), RingSequenceOrder::Ahead);
        assert_eq!(ring_sequence_order(u64::MAX, 0), RingSequenceOrder::Behind);
    }

    #[test]
    fn multi_producer_single_consumer_no_item_lost() {
        let eng = Arc::new(PersistentEngine::new(128));
        let producers = 4;
        let items_per_producer = 16;
        let mut handles = Vec::new();
        for p in 0..producers {
            let eng = Arc::clone(&eng);
            handles.push(thread::spawn(move || {
                for i in 0..items_per_producer {
                    let corr = (p * 1000 + i) as u32;
                    // Retry until the ring accepts the item.
                    while eng.enqueue(item(corr)).is_err() {
                        std::hint::spin_loop();
                    }
                }
            }));
        }
        let consumer_eng = Arc::clone(&eng);
        let consumer = thread::spawn(move || {
            let total = (producers * items_per_producer) as usize;
            let mut seen = Vec::with_capacity(total);
            while seen.len() < total {
                if let Some(it) = consumer_eng.claim() {
                    seen.push(it.correlation);
                } else {
                    std::hint::spin_loop();
                }
            }
            seen
        });
        for h in handles {
            h.join().unwrap();
        }
        let seen = consumer.join().unwrap();
        let mut sorted = seen.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), seen.len(), "duplicate items consumed");
        for p in 0..producers {
            for i in 0..items_per_producer {
                let expected = (p * 1000 + i) as u32;
                assert!(
                    seen.contains(&expected),
                    "missing correlation id {expected}"
                );
            }
        }
    }

    #[test]
    fn wrap_around_works_for_large_throughput() {
        let eng = PersistentEngine::new(16);
        let passes = 10;
        for p in 0..passes {
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert!(eng.enqueue(item(corr)).is_ok());
            }
            for i in 0..16 {
                let corr = (p * 1000 + i) as u32;
                assert_eq!(eng.claim().unwrap().correlation, corr);
            }
        }
        assert_eq!(eng.head(), (passes * 16) as u32);
        assert_eq!(eng.tail(), (passes * 16) as u32);
        assert_eq!(eng.in_flight(), 0);
    }

    #[test]
    fn multi_consumer_no_double_claim() {
        let eng = Arc::new(PersistentEngine::new(128));
        let total = 100_u32;
        for i in 0..total {
            eng.enqueue(item(i)).unwrap();
        }
        let consumers = 4;
        let mut handles = Vec::new();
        let shared_consumed = Arc::new(std::sync::Mutex::new(Vec::new()));
        for _ in 0..consumers {
            let eng = Arc::clone(&eng);
            let out = Arc::clone(&shared_consumed);
            handles.push(thread::spawn(move || {
                // Collect locally and merge once to keep the lock out of the hot loop.
                let mut local = Vec::new();
                while let Some(it) = eng.claim() {
                    local.push(it.correlation);
                }
                out.lock().unwrap().extend(local);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let mut consumed = Arc::try_unwrap(shared_consumed)
            .unwrap()
            .into_inner()
            .unwrap();
        consumed.sort_unstable();
        assert_eq!(consumed.len(), total as usize);
        for (i, c) in consumed.iter().enumerate() {
            assert_eq!(*c, i as u32, "duplicated or missing item at idx {i}");
        }
    }

    #[test]
    fn queue_full_error_display_is_useful() {
        let s = format!("{QueueFull}");
        assert!(s.contains("ring buffer"));
    }
}