1#![deny(missing_docs)]
31
32use std::sync::Arc;
33use core::sync::atomic::{AtomicUsize, Ordering};
34use cache_line_size::CacheAligned;
35
36struct BipBuffer {
37 sequestered: Box<std::any::Any>,
38 buf: *mut u8,
39 len: usize,
40 read: CacheAligned<AtomicUsize>,
41 write: CacheAligned<AtomicUsize>,
42 last: CacheAligned<AtomicUsize>,
43}
44
#[cfg(feature = "debug")]
impl BipBuffer {
    /// Renders the shared cursors and the storage length for debug tracing.
    fn dbg_info(&self) -> String {
        let BipBuffer { read, write, last, len, .. } = self;
        format!(
            " read: {:?} -- write: {:?} -- last: {:?} [len: {:?}] ",
            read, write, last, len
        )
    }
}
55
56pub struct BipBufferWriter {
60 buffer: Arc<BipBuffer>,
61 write: usize,
62 last: usize,
63}
64
65unsafe impl Send for BipBufferWriter {}
66
67pub struct BipBufferReader {
71 buffer: Arc<BipBuffer>,
72 read: usize,
73 priv_write: usize,
74 priv_last: usize,
75}
76
77unsafe impl Send for BipBufferReader {}
78
79pub fn bip_buffer_from<B: std::ops::DerefMut<Target=[u8]>+'static>(from: B) -> (BipBufferWriter, BipBufferReader) {
90 let mut sequestered = Box::new(from);
91 let len = sequestered.len();
92 let buf = sequestered.as_mut_ptr();
93
94 let buffer = Arc::new(BipBuffer {
95 sequestered,
96 buf,
97 len,
98 read: CacheAligned(AtomicUsize::new(0)),
99 write: CacheAligned(AtomicUsize::new(0)),
100 last: CacheAligned(AtomicUsize::new(0)),
101 });
102
103 (
104 BipBufferWriter {
105 buffer: buffer.clone(),
106 write: 0,
107 last: len,
108 },
109 BipBufferReader {
110 buffer,
111 read: 0,
112 priv_write: 0,
113 priv_last: len,
114 },
115 )
116}
117
118pub fn bip_buffer_with_len(len: usize) -> (BipBufferWriter, BipBufferReader) {
124 bip_buffer_from(vec![0u8; len].into_boxed_slice())
125}
126
127impl BipBuffer {
128 fn into_inner<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> B {
130 let BipBuffer { sequestered, .. } = self;
131 *sequestered.downcast::<B>().expect("incorrect underlying type")
132 }
133}
134
/// A reservation computed by `reserve_core` that has not been handed out yet.
#[derive(Copy, Clone)]
struct PendingReservation {
    /// Offset of the first reserved byte.
    start: usize,
    /// Number of reserved bytes.
    len: usize,
    /// Whether committing moves the writer back to the start of the storage.
    wraparound: bool,
}
141
142impl BipBufferWriter {
143 fn reserve_core(&mut self, len: usize) -> Option<PendingReservation> {
144 assert!(len > 0);
145 let read = self.buffer.read.0.load(Ordering::Acquire);
146 if self.write >= read {
147 if self.buffer.len.saturating_sub(self.write) >= len {
148 Some(PendingReservation {
149 start: self.write,
150 len,
151 wraparound: false,
152 })
153 } else {
154 if read.saturating_sub(1) >= len {
155 Some(PendingReservation {
156 start: 0,
157 len,
158 wraparound: true,
159 })
160 } else {
161 None
162 }
163 }
164 } else {
165 if (read - self.write).saturating_sub(1) >= len {
166 Some(PendingReservation {
167 start: self.write,
168 len,
169 wraparound: false,
170 })
171 } else {
172 None
173 }
174 }
175 }
176
177 pub fn reserve(&mut self, len: usize) -> Option<BipBufferWriterReservation<'_>> {
185 let reserved = self.reserve_core(len);
186 if let Some(PendingReservation { start, len, wraparound }) = reserved {
187 Some(BipBufferWriterReservation { writer: self, start, len, wraparound })
188 } else {
189 None
190 }
191 }
192
193 pub fn spin_reserve(&mut self, len: usize) -> BipBufferWriterReservation<'_> {
205 assert!(len <= self.buffer.len);
206 let PendingReservation { start, len, wraparound } = loop {
207 match self.reserve_core(len) {
208 None => continue,
209 Some(r) => break r,
210 }
211 };
212 BipBufferWriterReservation { writer: self, start, len, wraparound }
213 }
214
215 pub fn try_unwrap<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> Result<B, Self> {
224 let BipBufferWriter { buffer, write, last, } = self;
225 match Arc::try_unwrap(buffer) {
226 Ok(b) => Ok(b.into_inner()),
227 Err(buffer) => Err(BipBufferWriter { buffer, write, last, }),
228 }
229 }
230}
231
232pub struct BipBufferWriterReservation<'a> {
269 writer: &'a mut BipBufferWriter,
270 start: usize,
271 len: usize,
272 wraparound: bool,
273}
274
275impl<'a> core::ops::Deref for BipBufferWriterReservation<'a> {
276 type Target = [u8];
277
278 fn deref(&self) -> &[u8] {
279 unsafe {
280 core::slice::from_raw_parts(self.writer.buffer.buf.add(self.start), self.len)
281 }
282 }
283}
284
285impl<'a> core::ops::DerefMut for BipBufferWriterReservation<'a> {
286 fn deref_mut(&mut self) -> &mut [u8] {
287 unsafe {
288 core::slice::from_raw_parts_mut(self.writer.buffer.buf.add(self.start), self.len)
289 }
290 }
291}
292
293impl<'a> core::ops::Drop for BipBufferWriterReservation<'a> {
294 fn drop(&mut self) {
295 if self.wraparound {
296 self.writer.buffer.last.0.store(self.writer.write, Ordering::Relaxed);
297 self.writer.write = 0;
298 }
299 self.writer.write += self.len;
300 if self.writer.write > self.writer.last {
301 self.writer.last = self.writer.write;
302 self.writer.buffer.last.0.store(self.writer.last, Ordering::Relaxed);
303 }
304 self.writer.buffer.write.0.store(self.writer.write, Ordering::Release);
305
306 #[cfg(feature = "debug")]
307 eprintln!("+++{}", self.writer.buffer.dbg_info());
308 }
309}
310
311impl<'a> BipBufferWriterReservation<'a> {
312 pub fn send(self) {
315 }
317}
318
319impl BipBufferReader {
320 pub fn valid(&mut self) -> &mut [u8] {
325 #[cfg(feature = "debug")]
326 eprintln!("???{}", self.buffer.dbg_info());
327 self.priv_write = self.buffer.write.0.load(Ordering::Acquire);
328
329 if self.priv_write >= self.read {
330 unsafe {
331 core::slice::from_raw_parts_mut(self.buffer.buf.add(self.read), self.priv_write - self.read)
332 }
333 } else {
334 self.priv_last = self.buffer.last.0.load(Ordering::Relaxed);
335 if self.read == self.priv_last {
336 self.read = 0;
337 return self.valid();
338 }
339 unsafe {
340 core::slice::from_raw_parts_mut(self.buffer.buf.add(self.read), self.priv_last - self.read)
341 }
342 }
343 }
344
345 pub fn consume(&mut self, len: usize) -> bool {
349 if self.priv_write >= self.read {
350 if len <= self.priv_write - self.read {
351 self.read += len;
352 } else {
353 return false;
354 }
355 } else {
356 let remaining = self.priv_last - self.read;
357 if len == remaining {
358 self.read = 0;
359 } else if len <= remaining {
360 self.read += len;
361 } else {
362 return false;
363 }
364 }
365 self.buffer.read.0.store(self.read, Ordering::Release);
366 #[cfg(feature = "debug")]
367 eprintln!("---{}", self.buffer.dbg_info());
368 true
369 }
370
371 pub fn try_unwrap<B: std::ops::DerefMut<Target=[u8]>+'static>(self) -> Result<B, Self> {
380 let BipBufferReader { buffer, read, priv_write, priv_last, } = self;
381 match Arc::try_unwrap(buffer) {
382 Ok(b) => Ok(b.into_inner()),
383 Err(buffer) => Err(BipBufferReader { buffer, read, priv_write, priv_last, }),
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use crate::bip_buffer_from;

    #[test]
    fn basic() {
        for i in 0..128 {
            let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 16].into_boxed_slice());
            let sender = std::thread::spawn(move || {
                writer.reserve(8).as_mut().expect("reserve").copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, i]);
            });
            let receiver = std::thread::spawn(move || {
                while reader.valid().len() < 8 {}
                assert_eq!(reader.valid(), &[10, 11, 12, 13, 14, 15, 16, i]);
                assert!(reader.consume(8));
            });
            sender.join().unwrap();
            receiver.join().unwrap();
        }
    }

    #[test]
    fn spsc() {
        let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 256].into_boxed_slice());
        let sender = std::thread::spawn(move || {
            for i in 0..128 {
                writer.spin_reserve(8).copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, i]);
            }
        });
        let receiver = std::thread::spawn(move || {
            for i in 0..128 {
                while reader.valid().len() < 8 {}
                assert_eq!(&reader.valid()[..8], &[10, 11, 12, 13, 14, 15, 16, i]);
                assert!(reader.consume(8));
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }

    #[test]
    fn provided_storage() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (mut writer, mut reader) = bip_buffer_from(storage);
        let sender = std::thread::spawn(move || {
            writer.spin_reserve(8).copy_from_slice(&[10, 11, 12, 13, 14, 15, 16, 17]);
        });
        let receiver = std::thread::spawn(move || {
            while reader.valid().len() < 8 {}
            assert!(reader.consume(8));
            reader
        });
        sender.join().unwrap();
        let reader = receiver.join().unwrap();
        let _: Box<[u8]> = reader.try_unwrap().map_err(|_| ()).expect("failed to recover storage");
    }

    #[test]
    #[should_panic]
    fn provided_storage_wrong_type() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (writer, reader) = bip_buffer_from(storage);
        std::mem::drop(writer);
        let _: Vec<u8> = reader.try_unwrap().map_err(|_| ()).expect("failed to recover storage");
    }

    #[test]
    fn provided_storage_still_alive() {
        let storage = vec![0u8; 256].into_boxed_slice();
        let (writer, reader) = bip_buffer_from(storage);
        let result: Result<Box<[u8]>, _> = reader.try_unwrap();
        assert!(result.is_err());
        std::mem::drop(writer);
    }

    #[test]
    fn static_prime_length() {
        const MSG_LENGTH: u8 = 17;
        let (mut writer, mut reader) = bip_buffer_from(vec![128u8; 64].into_boxed_slice());
        let sender = std::thread::spawn(move || {
            let mut msg = [0u8; MSG_LENGTH as usize];
            for _ in 0..1024 {
                for i in 0..128u8 {
                    msg[..].copy_from_slice(&[i; MSG_LENGTH as usize][..]);
                    msg[i as usize % (MSG_LENGTH as usize)] = 0;
                    writer.spin_reserve(MSG_LENGTH as usize).copy_from_slice(&msg[..]);
                }
            }
        });
        let receiver = std::thread::spawn(move || {
            let mut msg = [0u8; MSG_LENGTH as usize];
            for _ in 0..1024 {
                for i in 0..128u8 {
                    msg[..].copy_from_slice(&[i; MSG_LENGTH as usize][..]);
                    msg[i as usize % (MSG_LENGTH as usize)] = 0;
                    while reader.valid().len() < (MSG_LENGTH as usize) {}
                    assert_eq!(&reader.valid()[..MSG_LENGTH as usize], &msg[..]);
                    assert!(reader.consume(MSG_LENGTH as usize));
                }
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }

    #[test]
    fn random_length() {
        use rand::Rng;

        const MAX_LENGTH: usize = 127;
        let (mut writer, mut reader) = bip_buffer_from(vec![0u8; 1024]);
        let sender = std::thread::spawn(move || {
            let mut rng = rand::thread_rng();
            let mut msg = [0u8; MAX_LENGTH];
            for _ in 0..1024 {
                for round in 0..128u8 {
                    let length: u8 = rng.gen_range(1, MAX_LENGTH as u8);
                    msg[0] = length;
                    for i in 1..length {
                        msg[i as usize] = round;
                    }
                    writer.spin_reserve(length as usize).copy_from_slice(&msg[..length as usize]);
                }
            }
        });
        let receiver = std::thread::spawn(move || {
            let mut msg = [0u8; MAX_LENGTH];
            for _ in 0..1024 {
                for round in 0..128u8 {
                    let msg_len = loop {
                        let valid = reader.valid();
                        if valid.is_empty() { continue; }
                        break valid[0] as usize;
                    };
                    let recv_msg = loop {
                        let valid = reader.valid();
                        if valid.len() < msg_len { continue; }
                        break valid;
                    };
                    msg[0] = msg_len as u8;
                    for i in 1..msg_len {
                        msg[i] = round;
                    }
                    assert_eq!(&recv_msg[..msg_len], &msg[..msg_len]);
                    assert!(reader.consume(msg_len));
                }
            }
        });
        sender.join().unwrap();
        receiver.join().unwrap();
    }
}
540}