1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
36
/// Policy selector for persistent-thread execution.
///
/// NOTE(review): the consumers of this enum are not visible in this file; the
/// per-variant descriptions below are inferred from the names — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PersistentThreadMode {
    /// The `Default` variant; presumably lets the runtime decide.
    #[default]
    Auto,
    /// Presumably always requests persistent-thread execution.
    Force,
    /// Presumably never uses persistent-thread execution.
    Disable,
}
48
/// One unit of work carried through the persistent ring.
///
/// `#[repr(C)]` pins the field order and layout (four `u32`s, 16 bytes),
/// presumably so the descriptor can be shared with GPU or foreign code —
/// confirm against the consumers of this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub struct PersistentWorkItem {
    /// Start of this item's input window (units/buffer not visible here — confirm).
    pub input_offset: u32,
    /// Length of this item's input window, in the same units as `input_offset`.
    pub input_len: u32,
    /// Identifier of the rule set the item should be evaluated against.
    pub rule_set_id: u32,
    /// Caller-chosen tag; the ring stores and returns it unchanged.
    pub correlation: u32,
}
67
68#[derive(Debug)]
70pub struct RingAtomics {
71 pub head: AtomicU64,
73 pub tail: AtomicU64,
75 pub ready: Vec<AtomicU64>,
79 pub done: Vec<AtomicU32>,
81}
82
83impl RingAtomics {
84 fn try_new(ring_size: u32) -> Result<Self, String> {
85 let capacity = persistent_ring_capacity(ring_size)?;
86 let mut ready = Vec::new();
87 crate::allocation::try_reserve_vec_to_capacity(&mut ready, capacity).map_err(|error| {
88 format!("Fix: persistent ring could not reserve {capacity} ready marker(s): {error}.")
89 })?;
90 for slot in 0..ring_size {
91 ready.push(AtomicU64::new(u64::from(slot)));
92 }
93
94 let mut done = Vec::new();
95 crate::allocation::try_reserve_vec_to_capacity(&mut done, capacity).map_err(|error| {
96 format!("Fix: persistent ring could not reserve {capacity} done marker(s): {error}.")
97 })?;
98 for _ in 0..ring_size {
99 done.push(AtomicU32::new(0));
100 }
101
102 Ok(Self {
103 head: AtomicU64::new(0),
104 tail: AtomicU64::new(0),
105 ready,
106 done,
107 })
108 }
109}
110
111#[derive(Debug)]
112struct WorkSlot {
113 lo: AtomicU64,
114 hi: AtomicU64,
115}
116
117impl WorkSlot {
118 fn new(item: PersistentWorkItem) -> Self {
119 let (lo, hi) = pack_work_item(item);
120 Self {
121 lo: AtomicU64::new(lo),
122 hi: AtomicU64::new(hi),
123 }
124 }
125
126 fn store(&self, item: PersistentWorkItem) {
127 let (lo, hi) = pack_work_item(item);
128 self.lo.store(lo, Ordering::Relaxed);
129 self.hi.store(hi, Ordering::Relaxed);
130 }
131
132 fn load(&self) -> PersistentWorkItem {
133 unpack_work_item(
134 self.lo.load(Ordering::Relaxed),
135 self.hi.load(Ordering::Relaxed),
136 )
137 }
138}
139
140fn pack_work_item(item: PersistentWorkItem) -> (u64, u64) {
141 (
142 u64::from(item.input_offset) | (u64::from(item.input_len) << 32),
143 u64::from(item.rule_set_id) | (u64::from(item.correlation) << 32),
144 )
145}
146
147fn unpack_work_item(lo: u64, hi: u64) -> PersistentWorkItem {
148 PersistentWorkItem {
149 input_offset: lo as u32,
150 input_len: (lo >> 32) as u32,
151 rule_set_id: hi as u32,
152 correlation: (hi >> 32) as u32,
153 }
154}
155
156#[derive(Debug)]
160pub struct PersistentEngine {
161 slots: Vec<WorkSlot>,
162 atomics: RingAtomics,
163 ring_size: u32,
164}
165
166impl PersistentEngine {
167 pub fn new(ring_size: u32) -> Self {
171 let ring_size = ring_size
172 .checked_next_power_of_two()
173 .filter(|&size| size > 0)
174 .unwrap_or_else(|| {
175 panic!(
176 "Fix: persistent ring_size {ring_size} cannot be rounded to a nonzero power of two without overflow."
177 )
178 });
179 Self::with_valid_ring_size(ring_size)
180 }
181
182 pub fn try_new(ring_size: u32) -> Result<Self, String> {
185 if ring_size.is_power_of_two() && ring_size > 0 {
186 Self::try_with_valid_ring_size(ring_size)
187 } else {
188 Err(format!(
189 "Fix: ring_size must be a nonzero power of two, got {ring_size}."
190 ))
191 }
192 }
193
194 fn with_valid_ring_size(ring_size: u32) -> Self {
195 match Self::try_with_valid_ring_size(ring_size) {
196 Ok(engine) => engine,
197 Err(error) => panic!("{error}"),
198 }
199 }
200
201 fn try_with_valid_ring_size(ring_size: u32) -> Result<Self, String> {
202 let zero = PersistentWorkItem {
203 input_offset: 0,
204 input_len: 0,
205 rule_set_id: 0,
206 correlation: 0,
207 };
208 let capacity = persistent_ring_capacity(ring_size)?;
209 let mut slots = Vec::new();
210 crate::allocation::try_reserve_vec_to_capacity(&mut slots, capacity).map_err(|error| {
211 format!("Fix: persistent ring could not reserve {capacity} work slot(s): {error}.")
212 })?;
213 for _ in 0..ring_size {
214 slots.push(WorkSlot::new(zero));
215 }
216
217 Ok(Self {
218 slots,
219 atomics: RingAtomics::try_new(ring_size)?,
220 ring_size,
221 })
222 }
223
224 pub fn ring_size(&self) -> u32 {
226 self.ring_size
227 }
228
229 pub fn enqueue(&self, item: PersistentWorkItem) -> Result<u32, QueueFull> {
233 loop {
234 let head = self.atomics.head.load(Ordering::Acquire);
235 let slot_idx = (head as u32) & (self.ring_size - 1);
236 let slot_offset = slot_idx as usize;
237 let Some(ready) = self.atomics.ready.get(slot_offset) else {
238 return Err(QueueFull);
239 };
240 match ring_sequence_order(ready.load(Ordering::Acquire), head) {
241 RingSequenceOrder::Free => {}
242 RingSequenceOrder::Behind => return Err(QueueFull),
243 RingSequenceOrder::Ahead => {
244 std::hint::spin_loop();
245 continue;
246 }
247 }
248 match self.atomics.head.compare_exchange(
249 head,
250 head.wrapping_add(1),
251 Ordering::AcqRel,
252 Ordering::Acquire,
253 ) {
254 Ok(_) => {
255 let Some(slot) = self.slots.get(slot_offset) else {
256 return Err(QueueFull);
257 };
258 slot.store(item);
259 self.atomics.done[slot_offset].store(0, Ordering::Release);
260 self.atomics.ready[slot_offset].store(head.wrapping_add(1), Ordering::Release);
261 return Ok(slot_idx);
262 }
263 Err(_) => continue,
264 }
265 }
266 }
267
268 pub fn claim(&self) -> Option<PersistentWorkItem> {
272 loop {
273 let tail = self.atomics.tail.load(Ordering::Acquire);
274 let slot_idx = (tail as u32) & (self.ring_size - 1);
275 let slot_offset = slot_idx as usize;
276 let published = tail.wrapping_add(1);
277 let Some(ready) = self.atomics.ready.get(slot_offset) else {
278 return None;
279 };
280 match ring_sequence_order(ready.load(Ordering::Acquire), published) {
281 RingSequenceOrder::Free => {}
282 RingSequenceOrder::Behind => {
283 if tail >= self.atomics.head.load(Ordering::Acquire) {
284 return None;
285 }
286 std::hint::spin_loop();
287 continue;
288 }
289 RingSequenceOrder::Ahead => {
290 std::hint::spin_loop();
291 continue;
292 }
293 }
294 match self.atomics.tail.compare_exchange(
295 tail,
296 tail.wrapping_add(1),
297 Ordering::AcqRel,
298 Ordering::Acquire,
299 ) {
300 Ok(_) => {
301 let slot = self.slots.get(slot_offset)?;
302 let item = slot.load();
303 self.atomics.ready[slot_offset].store(
304 tail.wrapping_add(u64::from(self.ring_size)),
305 Ordering::Release,
306 );
307 return Some(item);
308 }
309 Err(_) => continue,
310 }
311 }
312 }
313
314 pub fn mark_done(&self, slot_idx: u32) -> Result<(), String> {
316 let Some(done) = self.atomics.done.get(slot_idx as usize) else {
317 return Err(format!(
318 "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before marking done.",
319 self.ring_size
320 ));
321 };
322 done.store(1, Ordering::Release);
323 Ok(())
324 }
325
326 pub fn is_done(&self, slot_idx: u32) -> Result<bool, String> {
328 let Some(done) = self.atomics.done.get(slot_idx as usize) else {
329 return Err(format!(
330 "Fix: persistent ring slot_idx={slot_idx} is outside ring_size={}. Reject stale or corrupt completion markers before reading done state.",
331 self.ring_size
332 ));
333 };
334 Ok(done.load(Ordering::Acquire) != 0)
335 }
336
337 pub fn try_in_flight(&self) -> Result<u32, String> {
339 let pending = self
340 .atomics
341 .head
342 .load(Ordering::Acquire)
343 .wrapping_sub(self.atomics.tail.load(Ordering::Acquire));
344 u32::try_from(pending).map_err(|_| {
345 format!(
346 "Fix: persistent engine in-flight count {pending} exceeds u32::MAX. Drain the ring or use the 64-bit counters before exporting GPU-visible queue metadata."
347 )
348 })
349 }
350
351 pub fn in_flight(&self) -> u32 {
353 self.try_in_flight()
354 .unwrap_or_else(|message| panic!("{message}"))
355 }
356
357 pub fn head_counter(&self) -> u64 {
359 self.atomics.head.load(Ordering::Acquire)
360 }
361
362 pub fn head(&self) -> u32 {
364 let head = self.head_counter();
365 u32::try_from(head).unwrap_or_else(|_| {
366 panic!(
367 "Fix: persistent engine head counter {head} exceeds u32::MAX. Use head_counter() for long-running queues instead of truncating telemetry."
368 )
369 })
370 }
371
372 pub fn tail_counter(&self) -> u64 {
374 self.atomics.tail.load(Ordering::Acquire)
375 }
376
377 pub fn tail(&self) -> u32 {
379 let tail = self.tail_counter();
380 u32::try_from(tail).unwrap_or_else(|_| {
381 panic!(
382 "Fix: persistent engine tail counter {tail} exceeds u32::MAX. Use tail_counter() for long-running queues instead of truncating telemetry."
383 )
384 })
385 }
386}
387
/// Converts a ring size into a `usize` element count for allocation.
///
/// Fails only on targets whose `usize` is narrower than 32 bits.
fn persistent_ring_capacity(ring_size: u32) -> Result<usize, String> {
    match usize::try_from(ring_size) {
        Ok(capacity) => Ok(capacity),
        Err(_) => Err(format!(
            "Fix: persistent ring_size {ring_size} does not fit this target's address space."
        )),
    }
}
393
/// How a slot's sequence number compares with the position being attempted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RingSequenceOrder {
    /// The sequence lags the position.
    Behind,
    /// The sequence equals the position: the slot is ready for this step.
    Free,
    /// The sequence is past the position: another thread already moved on.
    Ahead,
}

/// Classifies `sequence` against `position` by wrap-around-safe signed distance.
fn ring_sequence_order(sequence: u64, position: u64) -> RingSequenceOrder {
    // Reinterpreting the wrapped difference as i64 keeps the comparison correct
    // across u64 overflow while the two values are within 2^63 of each other.
    let distance = sequence.wrapping_sub(position) as i64;
    if distance < 0 {
        RingSequenceOrder::Behind
    } else if distance > 0 {
        RingSequenceOrder::Ahead
    } else {
        RingSequenceOrder::Free
    }
}
408
/// Error returned by `PersistentEngine::enqueue` when the ring has no free slot.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QueueFull;

impl std::error::Error for QueueFull {}

impl std::fmt::Display for QueueFull {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "persistent engine ring buffer is full")
    }
}
420
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    /// Work item tagged with correlation id `i`, covering the `i`-th 1024-unit window.
    fn item(i: u32) -> PersistentWorkItem {
        PersistentWorkItem {
            input_offset: i * 1024,
            input_len: 1024,
            rule_set_id: 0,
            correlation: i,
        }
    }

    /// Enqueues `item(0)` through `item(count - 1)`, panicking if any is rejected.
    fn fill(engine: &PersistentEngine, count: u32) {
        for i in 0..count {
            engine.enqueue(item(i)).expect("ring unexpectedly full");
        }
    }

    #[test]
    fn invalid_ring_size_has_explicit_error_api() {
        let message = PersistentEngine::try_new(7).expect_err("7 is not a power of two");
        assert!(message.contains("Fix:"));
        assert!(PersistentEngine::try_new(0).is_err());
    }

    #[test]
    fn infallible_constructor_normalizes_ring_size() {
        for (requested, expected) in [(7, 8), (0, 1)] {
            assert_eq!(PersistentEngine::new(requested).ring_size(), expected);
        }
    }

    #[test]
    fn enqueue_claim_fifo_single_thread() {
        let engine = PersistentEngine::new(8);
        (0..8).for_each(|i| assert_eq!(engine.enqueue(item(i)).unwrap(), i));
        for expected in 0..8 {
            let claimed = engine.claim().expect("item should be available");
            assert_eq!(claimed.correlation, expected);
        }
        assert_eq!(engine.claim(), None);
    }

    #[test]
    fn queue_full_on_overflow() {
        let engine = PersistentEngine::new(4);
        fill(&engine, 4);
        assert_eq!(engine.enqueue(item(99)), Err(QueueFull));
    }

    #[test]
    fn space_reclaims_after_claim() {
        let engine = PersistentEngine::new(4);
        fill(&engine, 4);
        assert!(engine.enqueue(item(99)).is_err());
        assert_eq!(engine.claim().unwrap().correlation, 0);
        assert!(engine.enqueue(item(99)).is_ok());
    }

    #[test]
    fn in_flight_tracks_correctly() {
        let engine = PersistentEngine::new(16);
        assert_eq!(engine.in_flight(), 0);
        fill(&engine, 5);
        assert_eq!(engine.in_flight(), 5);
        for _ in 0..2 {
            engine.claim().unwrap();
        }
        assert_eq!(engine.in_flight(), 3);
    }

    #[test]
    fn done_marker_flows_through() {
        let engine = PersistentEngine::new(4);
        let slot = engine.enqueue(item(1)).unwrap();
        assert_eq!(engine.is_done(slot), Ok(false));
        assert_eq!(engine.claim().unwrap().correlation, 1);
        engine.mark_done(slot).unwrap();
        assert_eq!(engine.is_done(slot), Ok(true));
    }

    #[test]
    fn multi_producer_single_consumer_no_item_lost() {
        const PRODUCERS: u32 = 4;
        const ITEMS_PER_PRODUCER: u32 = 16;
        fn correlation(producer: u32, index: u32) -> u32 {
            producer * 1000 + index
        }

        let engine = Arc::new(PersistentEngine::new(128));
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|producer| {
                let engine = Arc::clone(&engine);
                thread::spawn(move || {
                    for index in 0..ITEMS_PER_PRODUCER {
                        let work = item(correlation(producer, index));
                        while engine.enqueue(work).is_err() {
                            std::hint::spin_loop();
                        }
                    }
                })
            })
            .collect();

        let total = (PRODUCERS * ITEMS_PER_PRODUCER) as usize;
        let consumer = {
            let engine = Arc::clone(&engine);
            thread::spawn(move || {
                let mut seen = Vec::with_capacity(total);
                while seen.len() < total {
                    match engine.claim() {
                        Some(claimed) => seen.push(claimed.correlation),
                        None => std::hint::spin_loop(),
                    }
                }
                seen
            })
        };

        for handle in producers {
            handle.join().unwrap();
        }
        let seen = consumer.join().unwrap();

        let unique: HashSet<u32> = seen.iter().copied().collect();
        assert_eq!(unique.len(), seen.len(), "duplicate items consumed");
        for producer in 0..PRODUCERS {
            for index in 0..ITEMS_PER_PRODUCER {
                let expected = correlation(producer, index);
                assert!(
                    unique.contains(&expected),
                    "missing correlation id {expected}"
                );
            }
        }
    }

    #[test]
    fn wrap_around_works_for_large_throughput() {
        let engine = PersistentEngine::new(16);
        let passes = 10_u32;
        for pass in 0..passes {
            let batch: Vec<u32> = (0..16).map(|i| pass * 1000 + i).collect();
            for &corr in &batch {
                assert!(engine.enqueue(item(corr)).is_ok());
            }
            for &corr in &batch {
                assert_eq!(engine.claim().unwrap().correlation, corr);
            }
        }
        assert_eq!(engine.head(), passes * 16);
        assert_eq!(engine.tail(), passes * 16);
        assert_eq!(engine.in_flight(), 0);
    }

    #[test]
    fn multi_consumer_no_double_claim() {
        let engine = PersistentEngine::new(128);
        let total = 100_u32;
        fill(&engine, total);

        let shared = &engine;
        let mut consumed: Vec<u32> = thread::scope(|scope| {
            // Spawn every consumer before joining any of them.
            let workers: Vec<_> = (0..4)
                .map(|_| {
                    scope.spawn(move || {
                        let mut local = Vec::new();
                        while let Some(claimed) = shared.claim() {
                            local.push(claimed.correlation);
                        }
                        local
                    })
                })
                .collect();
            workers
                .into_iter()
                .flat_map(|worker| worker.join().unwrap())
                .collect()
        });

        consumed.sort_unstable();
        assert_eq!(consumed.len(), total as usize);
        for (i, &c) in (0..total).zip(consumed.iter()) {
            assert_eq!(c, i, "duplicated or missing item at idx {i}");
        }
    }

    #[test]
    fn queue_full_error_display_is_useful() {
        let rendered = QueueFull.to_string();
        assert!(rendered.contains("ring buffer"));
    }
}
617}