1#![deny(missing_docs)]
2use std::{cell::UnsafeCell, collections::VecDeque, io, sync::Arc, task::Waker};
5
6#[cfg(feature = "loom")]
7use loom::sync::{
8 Mutex, RwLock,
9 atomic::{AtomicBool, AtomicU64, Ordering},
10};
11#[cfg(not(feature = "loom"))]
12use std::sync::{
13 Mutex, RwLock,
14 atomic::{AtomicBool, AtomicU64, Ordering},
15};
16
17use stable_vec::StableVec;
18use thiserror::Error;
19use tracing::{Span, debug, field::Empty, instrument};
20
21mod driver;
22mod id_factory;
23mod reader;
24mod writer;
25
26pub use driver::{ChannelDriver, ChannelStrongIoDriver, ChannelWeakIoDriver};
27pub use reader::{Reader, StrongReader, WeakReader};
28pub use writer::{StrongWriter, WeakWriter, Writer};
29
30use crate::id_factory::{Id, IdFactory};
31
/// Crate-internal result type whose error is always a [`ChannelError`].
type Result<T> = std::result::Result<T, ChannelError>;
33
/// What a writer does when the ring buffer has no room for more data.
///
/// The default is [`Backpressure::Park`], the policy used by `Channel::new`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Backpressure {
    /// Park the writer: its write stays pending until readers free up space.
    #[default]
    Park,
    /// Drop the data: a write into a full buffer completes immediately having written
    /// zero bytes, leaving the data already buffered untouched.
    Drop,
}
42
/// Fixed-capacity byte storage addressed by ever-increasing stream positions.
///
/// A position is mapped to a buffer index with `pos & mask`, so the capacity is always a
/// power of two. Access is unsynchronised; callers coordinate who touches which range.
struct RingBuffer {
    /// Backing bytes; wrapped in `UnsafeCell` so they can be written through `&self`.
    buf: UnsafeCell<Box<[u8]>>,
    /// Capacity in bytes (a power of two).
    size: usize,
    /// `size - 1`, used to reduce a stream position to a buffer index.
    mask: u64,
}
61
/// Bookkeeping for one frame registered in the channel's byte stream.
///
/// All fields are plain integers, so the type is `Copy`. It also derives `Debug` and
/// equality so frames can be logged and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct FrameMeta {
    /// Stream position of the frame's first byte.
    start: u64,
    /// Length of the frame in bytes; `start + len` is one past its last byte.
    len: u64,
    /// Id of the writer that registered the frame.
    writer_id: u16,
}
68
/// A byte channel backed by one shared ring buffer, with any number of writers and readers.
///
/// Every reader tracks its own stream position. Strong readers and strong writers register
/// their positions with the channel. The slowest strong reader bounds how far writers may
/// advance, and the slowest strong writer marks the end of readable data.
pub struct Channel {
    /// Byte storage shared by all readers and writers.
    buf: RingBuffer,
    /// Parked tasks as `(stream position, waker)` pairs. They are woken by
    /// `schedule_writers`, `schedule_readers` and `terminate`.
    queue: Mutex<Vec<(u64, Waker)>>,
    /// Positions of the registered strong readers.
    heads: RwLock<StableVec<Arc<AtomicU64>>>,
    /// Positions of the registered strong writers.
    tails: RwLock<StableVec<Arc<AtomicU64>>>,
    /// Source of writer ids.
    idf: Arc<IdFactory>,
    /// Registered frames, oldest first; pruned once readers have moved past them.
    frames: Mutex<VecDeque<FrameMeta>>,
    /// Highest position reached by a removed writer. It is used as the tail while no
    /// strong writer is registered.
    tail_cache: AtomicU64,
    /// Next unreserved stream position, advanced by `reserve_slice`.
    next_tail: AtomicU64,
    /// Set once the channel has been terminated.
    terminated: AtomicBool,
    /// Set once draining has started.
    draining: AtomicBool,
    /// Policy applied when a writer finds the buffer full.
    backpressure: Backpressure,
    /// Tracing span captured at creation; parent of the spans of instrumented methods.
    span: Span,
}
103
/// Errors reported by [`Channel`] operations.
#[derive(Error, Debug, PartialEq)]
pub enum ChannelError {
    /// A reader fell more than one buffer length behind the tail, so the data it wanted has
    /// been overwritten. The payload is the oldest position still held in the buffer.
    #[error("Reader was too slow and got left behind")]
    ReaderBehind(u64),
    /// `terminate` was called on a channel that is already draining.
    #[error("Cannot terminate channel that is draining")]
    TerminateDraining,
    /// `drain` was called on a channel that is already terminated.
    #[error("Cannot drain a terminated channel")]
    DrainTerminated,
    /// An I/O error, stored as its rendered message.
    #[error("io error: {0}")]
    Io(String),
}
120
121impl From<io::Error> for ChannelError {
122 fn from(value: io::Error) -> Self {
123 Self::Io(value.to_string())
124 }
125}
126
127impl RingBuffer {
128 fn new(mut size: usize) -> Self {
129 size = size.next_power_of_two();
130 let buf: Vec<u8> = vec![0u8; size];
134
135 Self {
136 buf: UnsafeCell::new(buf.into_boxed_slice()),
137 size,
138 mask: (size - 1) as u64,
139 }
140 }
141
142 unsafe fn read(&self, dst: &mut [u8], pos: u64) {
151 let dstlen = dst.len().min(self.size);
152
153 let idx: usize = (pos & self.mask)
155 .try_into()
156 .expect("pointer size less than 64b");
157
158 let taillen = dstlen.min(self.size - idx);
159 let headlen = idx.min(dstlen - taillen);
160
161 let src = unsafe { (&mut *self.buf.get()).as_mut_ptr() };
163 let dst = dst.as_mut_ptr();
164
165 unsafe {
167 src.add(idx).copy_to_nonoverlapping(dst, taillen);
168 }
169
170 if headlen > 0 {
171 unsafe {
173 src.copy_to_nonoverlapping(dst.add(taillen), headlen);
174 }
175 }
176 }
177
178 unsafe fn write(&self, src: &[u8], pos: u64) {
187 let srclen = src.len().min(self.size);
188
189 let idx: usize = (pos & self.mask)
191 .try_into()
192 .expect("pointer size less than 64b");
193
194 let taillen = srclen.min(self.size - idx);
195 let headlen = idx.min(srclen - taillen);
196
197 let src = src.as_ptr();
198 let dst = unsafe { (&mut *self.buf.get()).as_mut_ptr() };
200
201 unsafe {
203 src.copy_to_nonoverlapping(dst.add(idx), taillen);
204 }
205
206 if headlen > 0 {
207 unsafe {
209 src.add(taillen).copy_to_nonoverlapping(dst, headlen);
210 }
211 }
212 }
213}
214
215unsafe impl Sync for RingBuffer {}
216
217impl Channel {
218 pub fn new(size: usize) -> Arc<Self> {
220 Self::with_parameters(size, Backpressure::Park)
221 }
222
223 #[instrument(name = "Channel", skip_all, fields(ptr = Empty))]
225 pub fn with_parameters(size: usize, backpressure: Backpressure) -> Arc<Self> {
226 let this = Arc::new(Self {
227 buf: RingBuffer::new(size),
228 queue: Mutex::new(Vec::new()),
229 heads: RwLock::new(StableVec::new()),
230 tails: RwLock::new(StableVec::new()),
231 idf: IdFactory::new(),
232 frames: Mutex::new(VecDeque::new()),
233 tail_cache: AtomicU64::new(0),
234 next_tail: AtomicU64::new(0),
235 terminated: AtomicBool::new(false),
236 draining: AtomicBool::new(false),
237 backpressure,
238 span: Span::current(),
239 });
240
241 this.span
242 .record("ptr", format_args!("{:p}", this.as_ref() as *const _));
243 debug!("create channel");
244
245 this
246 }
247
248 fn get_head(&self) -> Option<u64> {
249 let heads = self
250 .heads
251 .read()
252 .unwrap_or_else(|poison| poison.into_inner());
253 self.get_head_locked(&heads)
254 }
255
256 fn get_head_locked(&self, heads: &StableVec<Arc<AtomicU64>>) -> Option<u64> {
257 heads.iter().map(|(_, t)| t.load(Ordering::Acquire)).min()
258 }
259
260 fn get_tail(&self) -> u64 {
261 let tails = self
262 .tails
263 .read()
264 .unwrap_or_else(|poison| poison.into_inner());
265 self.get_tail_locked(&tails)
266 }
267
268 fn get_tail_locked(&self, tails: &StableVec<Arc<AtomicU64>>) -> u64 {
269 tails
270 .iter()
271 .map(|(_, t)| t.load(Ordering::Acquire))
272 .min()
273 .unwrap_or(self.tail_cache.load(Ordering::Acquire))
274 }
275
276 fn reader_start_pos(&self, head: Option<u64>) -> u64 {
277 if let Some(head) = head {
278 return head;
279 }
280
281 let tail = self.get_tail();
282 let floor = tail.saturating_sub(self.buf.size as u64);
283 let frames = self
284 .frames
285 .lock()
286 .unwrap_or_else(|poison| poison.into_inner());
287
288 frames
289 .iter()
290 .find(|frame| frame.start >= floor)
291 .map(|frame| frame.start)
292 .unwrap_or(tail)
293 }
294
295 fn remove_head(&self, idx: usize) {
296 self.heads
297 .write()
298 .unwrap_or_else(|poison| poison.into_inner())
299 .remove(idx);
300 self.prune_frames();
301 }
302
303 fn remove_tail(&self, idx: usize) {
304 let mut tails = self
305 .tails
306 .write()
307 .unwrap_or_else(|poison| poison.into_inner());
308 if let Some(tail) = tails.remove(idx) {
309 self.tail_cache
310 .fetch_max(tail.load(Ordering::Acquire), Ordering::AcqRel);
311 }
312 }
313
314 fn register_frame(&self, start: u64, len: u64, writer_id: u16) {
315 let mut frames = self
316 .frames
317 .lock()
318 .unwrap_or_else(|poison| poison.into_inner());
319 frames.push_back(FrameMeta {
320 start,
321 len,
322 writer_id,
323 });
324 }
325
326 fn frame_for(&self, pos: u64) -> Option<FrameMeta> {
327 let frames = self
328 .frames
329 .lock()
330 .unwrap_or_else(|poison| poison.into_inner());
331 frames.iter().find(|frame| frame.start == pos).cloned()
332 }
333
334 fn frame_from(&self, pos: u64) -> Option<FrameMeta> {
335 let frames = self
336 .frames
337 .lock()
338 .unwrap_or_else(|poison| poison.into_inner());
339 frames.iter().find(|frame| frame.start >= pos).cloned()
340 }
341
342 fn prune_frames(&self) {
343 let head = match self.get_head() {
344 Some(head) => head,
345 None => self.get_tail().saturating_sub(self.buf.size as u64),
346 };
347
348 let mut frames = self
349 .frames
350 .lock()
351 .unwrap_or_else(|poison| poison.into_inner());
352 while let Some(frame) = frames.front() {
353 if frame.start + frame.len <= head {
354 frames.pop_front();
355 } else {
356 break;
357 }
358 }
359 }
360
361 fn writable_size(&self, pos: u64) -> u64 {
363 (self.buf.size as u64).saturating_sub(pos - self.get_head().unwrap_or(pos))
364 }
365
366 fn write(&self, pos: u64, buf: &[u8]) -> usize {
368 let len = (buf.len() as u64).min(self.writable_size(pos));
370
371 if len == 0 {
373 return 0;
374 }
375
376 let ulen: usize = len.try_into().expect("pointer size less than 64b");
377
378 unsafe { self.buf.write(&buf[..ulen], pos) };
381
382 ulen
383 }
384
385 fn read(&self, pos: u64, buf: &mut [u8]) -> Result<usize> {
387 let tail = self.get_tail();
389 if pos + (self.buf.size as u64) < tail {
390 return Err(ChannelError::ReaderBehind(tail - self.buf.size as u64));
391 }
392
393 Ok(unsafe { self.read_unsafe(pos, buf) })
394 }
395
396 unsafe fn read_unsafe(&self, pos: u64, buf: &mut [u8]) -> usize {
401 let len = (buf.len() as u64).min(self.get_tail().saturating_sub(pos));
403
404 if len == 0 {
406 return 0;
407 }
408
409 let ulen: usize = len.try_into().expect("pointer size less than 64b");
410
411 unsafe { self.buf.read(&mut buf[..ulen], pos) };
416
417 ulen
418 }
419
420 #[instrument(parent = &self.span, skip(self, waker))]
424 fn enqueue(&self, pos: u64, waker: Waker) {
425 debug!(pos, "channel enqueue");
426 self.queue
427 .lock()
428 .unwrap_or_else(|poison| poison.into_inner())
429 .push((pos, waker));
430 }
431
432 #[instrument(parent = &self.span, skip(self))]
434 fn schedule_writers(&self) {
435 if self
437 .tails
438 .read()
439 .unwrap_or_else(|poison| poison.into_inner())
440 .is_empty()
441 {
442 return;
443 }
444
445 let tail_pos = self.get_tail();
446 let head_pos = self.get_head().unwrap_or(tail_pos);
447 let mut queue = self
448 .queue
449 .lock()
450 .unwrap_or_else(|poison| poison.into_inner());
451
452 debug!(
453 queued = queue.len(),
454 head_pos, tail_pos, "channel schedule_writers"
455 );
456
457 queue
459 .extract_if(.., |(pos, _)| {
460 let wake = *pos < (head_pos + self.buf.size as u64) && *pos >= tail_pos;
461 if wake {
462 debug!(pos, "channel wake writer");
463 }
464 wake
465 })
466 .for_each(|(_, waker)| waker.wake());
467 }
468
469 #[instrument(parent = &self.span, skip(self))]
471 fn schedule_readers(&self) {
472 let tail_pos = self.get_tail();
473 let mut queue = self
474 .queue
475 .lock()
476 .unwrap_or_else(|poison| poison.into_inner());
477
478 debug!(queued = queue.len(), tail_pos, "channel schedule_readers");
479 queue
481 .extract_if(.., |(pos, _)| {
482 let wake = *pos < tail_pos;
483 if wake {
484 debug!(pos, "channel wake reader");
485 }
486 wake
487 })
488 .for_each(|(_, waker)| waker.wake());
489 }
490
491 pub fn new_writer(self: &Arc<Channel>) -> Writer {
501 Writer::Strong(self.new_strong_writer())
502 }
503
504 pub fn new_strong_writer(self: &Arc<Channel>) -> StrongWriter {
506 self.new_strong_writer_with_id(self.idf.generate())
507 }
508
509 fn new_strong_writer_with_id(self: &Arc<Channel>, id: Id) -> StrongWriter {
510 let mut tails = self
511 .tails
512 .write()
513 .unwrap_or_else(|poison| poison.into_inner());
514 let pos = Arc::new(AtomicU64::new(self.get_tail_locked(&tails)));
515 let pos_id = tails.push(pos.clone());
516 drop(tails);
517
518 StrongWriter::new(id, self.clone(), pos, Some(pos_id))
519 }
520
521 pub fn new_weak_writer(self: &Arc<Channel>) -> WeakWriter {
523 WeakWriter::new(self.idf.generate(), self.clone())
524 }
525
526 pub fn new_strong_reader(self: &Arc<Channel>) -> StrongReader {
531 let mut heads = self
532 .heads
533 .write()
534 .unwrap_or_else(|poison| poison.into_inner());
535 let pos = Arc::new(AtomicU64::new(
536 self.reader_start_pos(self.get_head_locked(&heads)),
537 ));
538 let id = heads.push(pos.clone());
539 drop(heads);
540
541 StrongReader::new(self.clone(), pos, id)
542 }
543
544 pub fn new_weak_reader(self: &Arc<Channel>) -> WeakReader {
548 WeakReader::new(self.clone(), self.reader_start_pos(self.get_head()))
549 }
550
551 pub fn reserve_slice(&self, len: u64) -> u64 {
556 self.next_tail.fetch_add(len, Ordering::SeqCst)
557 }
558
559 #[instrument(parent = &self.span, skip(self))]
563 pub fn terminate(&self) -> Result<()> {
564 debug!("terminate channel");
565
566 if self.draining.load(Ordering::Acquire) {
567 return Err(ChannelError::TerminateDraining);
568 }
569 self.terminated.store(true, Ordering::Release);
570
571 self.queue
573 .lock()
574 .unwrap_or_else(|poison| poison.into_inner())
575 .drain(..)
576 .for_each(|(_, waker)| waker.wake());
577
578 Ok(())
579 }
580
581 #[instrument(parent = &self.span, skip(self))]
586 pub fn drain(&self) -> Result<()> {
587 debug!("start draining channel");
588
589 if self.terminated.load(Ordering::Acquire) {
590 Err(ChannelError::DrainTerminated)
591 } else {
592 self.draining.store(true, Ordering::Release);
593 Ok(())
594 }
595 }
596}
597
598impl Drop for Channel {
599 fn drop(&mut self) {
600 let _ = self.terminate();
601 }
602}
603
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    // Unit tests for the ring buffer, the channel primitives and the reader/writer
    // `AsyncRead`/`AsyncWrite` implementations. `Channel::new` and
    // `Channel::with_parameters` already return an `Arc<Channel>`, so channels are used
    // directly instead of being wrapped in a second `Arc`.
    use std::{
        pin::pin,
        task::{Context, Poll},
    };

    use futures::task::{noop_waker, noop_waker_ref};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    #[test]
    fn drop_backpressure_returns_zero_when_full() {
        let channel = Channel::with_parameters(2, Backpressure::Drop);
        let _reader_guard = channel.new_strong_reader();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let writer = channel.new_writer();
        let mut pinned = pin!(writer);

        match pinned.as_mut().poll_write(&mut cx, &[1, 2]) {
            Poll::Ready(Ok(2)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        match pinned.as_mut().poll_write(&mut cx, &[3, 4]) {
            Poll::Ready(Ok(0)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
    }

    #[test]
    fn drop_backpressure_leaves_existing_bytes_intact() {
        let channel = Channel::with_parameters(2, Backpressure::Drop);
        let reader = channel.new_strong_reader();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let writer = channel.new_writer();
        let mut pinned_writer = pin!(writer);

        match pinned_writer.as_mut().poll_write(&mut cx, &[7, 8]) {
            Poll::Ready(Ok(2)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        match pinned_writer.as_mut().poll_write(&mut cx, &[9, 10, 11]) {
            Poll::Ready(Ok(0)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }

        let mut pinned_reader = pin!(reader);
        let mut buf = [0u8; 2];
        let mut rb = ReadBuf::new(&mut buf);
        match pinned_reader.as_mut().poll_read(&mut cx, &mut rb) {
            Poll::Ready(Ok(())) if rb.filled().len() == 2 => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        assert_eq!(&buf, &[7, 8]);
    }

    #[test]
    fn ring_buffer_write_clamps_to_capacity() {
        let ring = RingBuffer::new(8);
        let data = vec![42u8; 32];
        unsafe { ring.write(&data, 0) };
        let mut buf = [0u8; 8];
        unsafe { ring.read(&mut buf, 0) };
        assert!(buf.iter().all(|b| *b == 42));
    }

    #[test]
    fn ring_buffer_read_large_destination_stays_within_bounds() {
        let ring = RingBuffer::new(8);
        let data: Vec<u8> = (0u8..8).collect();
        unsafe { ring.write(&data, 0) };
        let mut dst = [0u8; 16];
        unsafe { ring.read(&mut dst, 0) };
        assert_eq!(&dst[..8], &data[..]);
        assert!(dst[8..].iter().all(|b| *b == 0));
    }

    /// Builds a ring buffer of `size` bytes pre-filled with `0, 1, .., size - 1`.
    fn new_buf(size: u8) -> RingBuffer {
        let buf = RingBuffer::new(size as usize);
        unsafe {
            let dst = (&mut *buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((0..size).collect::<Vec<_>>().as_ptr(), size as usize);
        }
        buf
    }

    #[test]
    fn test_ring_read_small() {
        let mut dst = [0; 2];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(dst, [1, 2]);
    }

    #[test]
    fn test_ring_read_large() {
        let mut dst = [0; 4];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(dst, [1, 2, 3, 0]);
    }

    #[test]
    fn test_ring_read_too_large() {
        let mut dst = [0; 5];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(&dst[..4], [1, 2, 3, 0].as_ref());
        assert_eq!(dst[4], 0);
    }

    #[test]
    fn test_ring_write_small() {
        let src = [4; 2];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [0, 4, 4, 3]);
    }

    #[test]
    fn test_ring_write_large() {
        let src = [4; 4];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [4, 4, 4, 4]);
    }

    #[test]
    fn test_ring_write_too_large() {
        let src = [4; 5];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [4, 4, 4, 4]);
    }

    #[test]
    fn test_channel_write() {
        let channel = Channel::new(4);
        assert_eq!(channel.write(0, &[]), 0);
        assert_eq!(channel.write(4, &[0; 3]), 3);
        assert_eq!(channel.write(1, &[0; 4]), 4);
        assert_eq!(channel.write(1, &[0; 5]), 4);

        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(1)));

        assert_eq!(channel.write(5, &[0; 3]), 0);
        assert_eq!(channel.write(3, &[0; 3]), 2);
    }

    #[test]
    fn test_channel_write_returns_zero_when_full() {
        let channel = Channel::new(4);
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));
        assert_eq!(channel.write(0, &[1, 2, 3, 4]), 4);
        assert_eq!(channel.write(4, &[9]), 0);
    }

    #[test]
    fn test_channel_read() {
        let channel = Channel::new(4);
        let mut buf = [0; 3];

        let tail = Arc::new(AtomicU64::new(2));
        channel.tails.write().unwrap().push(tail.clone());

        assert_eq!(channel.read(0, &mut buf).unwrap(), 2);

        tail.store(5, Ordering::Release);

        assert_eq!(
            channel.read(0, &mut buf).unwrap_err(),
            ChannelError::ReaderBehind(1)
        );
    }

    #[test]
    fn test_channel_read_unsafe() {
        let channel = Channel::new(4);
        let mut buf = [0; 3];

        assert_eq!(unsafe { channel.read_unsafe(0, &mut buf) }, 0);

        let tail = Arc::new(AtomicU64::new(2));
        channel.tails.write().unwrap().push(tail.clone());

        assert_eq!(channel.read(0, &mut buf).unwrap(), 2);

        tail.store(5, Ordering::Release);

        assert_eq!(channel.read(1, &mut buf).unwrap(), 3);
        assert_eq!(channel.read(3, &mut buf).unwrap(), 2);
    }

    #[test]
    fn test_channel_schedule_writers() {
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));

        let mut queue = channel.queue.lock().unwrap();
        queue.push((0, noop_waker()));
        queue.push((1, noop_waker()));
        queue.push((4, noop_waker()));
        drop(queue);

        channel.schedule_writers();

        let queue = channel.queue.lock().unwrap();
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.first().unwrap().0, 4);
    }

    #[test]
    fn test_channel_schedule_readers() {
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(2)));
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(1)));

        let mut queue = channel.queue.lock().unwrap();
        queue.push((1, noop_waker()));
        queue.push((2, noop_waker()));
        drop(queue);

        channel.schedule_readers();

        let queue = channel.queue.lock().unwrap();
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.first().unwrap().0, 2);
    }

    #[test]
    fn test_channel_new_writer() {
        let channel = Channel::new(4);
        let writer = channel.new_writer();
        assert_eq!(channel.tails.read().unwrap().num_elements(), 1);
        drop(writer);
        assert!(channel.tails.read().unwrap().is_empty());
    }

    #[test]
    fn test_channel_new_strong_reader() {
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(5)));
        let reader = channel.new_strong_reader();
        assert_eq!(channel.heads.read().unwrap().num_elements(), 1);
        assert_eq!(reader.pos.load(Ordering::Acquire), 5);
        drop(reader);
        assert!(channel.heads.read().unwrap().is_empty());
    }

    #[test]
    fn test_writer_poll_write() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(4);
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));
        let mut writer = pin!(channel.new_strong_writer());

        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
            Poll::Ready(Ok(3))
        ));
        assert_eq!(channel.next_tail.load(Ordering::Acquire), 3);
        assert_eq!(writer.pos.load(Ordering::Acquire), 3);
        assert_eq!(writer.rem, 0);

        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
            Poll::Ready(Ok(1))
        ));
        assert_eq!(channel.next_tail.load(Ordering::Acquire), 6);
        assert_eq!(writer.pos.load(Ordering::Acquire), 4);
        assert_eq!(writer.rem, 2);

        assert!(writer.as_mut().poll_write(&mut cx, &[1, 2, 3]).is_pending());
    }

    #[test]
    fn test_writer_poll_strong_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(4)));
        unsafe {
            let dst = (&mut *channel.buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((1..=4).collect::<Vec<_>>().as_ptr(), 4);
        }
        let mut reader = pin!(channel.new_strong_reader());
        reader.pos.store(0, Ordering::Release);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 3);
        assert_eq!(reader.pos.load(Ordering::Acquire), 3);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 1);
        assert_eq!(reader.pos.load(Ordering::Acquire), 4);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
    }

    #[test]
    fn test_writer_poll_weak_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(4);
        let tail_pos = Arc::new(AtomicU64::new(4));
        channel.tails.write().unwrap().push(tail_pos.clone());
        unsafe {
            let dst = (&mut *channel.buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((1..=4).collect::<Vec<_>>().as_ptr(), 4);
        }
        let mut reader = pin!(channel.new_weak_reader());
        reader.pos = 0;

        let mut buf = [0; 4];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 4);
        assert_eq!(reader.pos, 4);

        let mut buf = [0; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        tail_pos.store(9, Ordering::Release);

        let mut buf = [0; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Err(_))
        ));
    }

    #[test]
    fn test_writer_wraparound_preserves_order() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(4);
        let mut writer = pin!(channel.new_strong_writer());
        let mut reader = pin!(channel.new_strong_reader());

        let mut buf = [0u8; 3];
        {
            let mut rb = ReadBuf::new(&mut buf);
            assert!(matches!(
                writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
                Poll::Ready(Ok(3))
            ));
            assert!(matches!(
                reader.as_mut().poll_read(&mut cx, &mut rb),
                Poll::Ready(Ok(()))
            ));
            assert_eq!(rb.filled().len(), 3);
        }
        assert_eq!(buf, [1, 2, 3]);

        let mut buf2 = [0u8; 2];
        let mut rb2 = ReadBuf::new(&mut buf2);
        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[4, 5]),
            Poll::Ready(Ok(2))
        ));
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb2),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb2.filled().len(), 2);
        assert_eq!(buf2, [4, 5]);
    }

    #[test]
    fn test_terminate() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(4);
        let mut writer = channel.new_writer();
        let mut strong_reader = pin!(channel.new_strong_reader());
        let mut weak_reader = pin!(channel.new_weak_reader());
        let mut weak_reader2 = pin!(channel.new_weak_reader());
        let mut buf = [0; 4];
        let mut rb = ReadBuf::new(&mut buf);

        writer.terminate();
        let mut writer = pin!(writer);
        assert!(check_poll_aborted(
            writer.as_mut().poll_write(&mut cx, &[]).map_ok(|_| ())
        ));

        assert!(matches!(
            strong_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        strong_reader.terminate();
        assert!(check_poll_aborted(
            strong_reader.poll_read(&mut cx, &mut rb)
        ));

        assert!(matches!(
            weak_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        weak_reader.terminate();
        assert!(check_poll_aborted(weak_reader.poll_read(&mut cx, &mut rb)));

        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        channel.terminate().unwrap();
        assert!(check_poll_aborted(weak_reader2.poll_read(&mut cx, &mut rb)));
    }

    #[test]
    fn strong_reader_terminate_aborts_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(16);
        let reader = channel.new_strong_reader();
        reader.terminate();
        let mut reader = pin!(reader);
        let mut buf = [0u8; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(check_poll_aborted(
            reader.as_mut().poll_read(&mut cx, &mut rb)
        ));
    }

    #[test]
    fn writer_terminate_aborts_write() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::new(16);
        let mut writer = channel.new_writer();
        writer.terminate();
        let mut writer = pin!(writer);
        assert!(check_poll_aborted(
            writer.as_mut().poll_write(&mut cx, &[]).map_ok(|_| ())
        ));
    }

    /// Returns `true` if `poll` completed with an `io::ErrorKind::ConnectionAborted` error.
    fn check_poll_aborted(poll: Poll<std::io::Result<()>>) -> bool {
        match poll {
            Poll::Ready(r) => r.is_err_and(|e| e.kind() == std::io::ErrorKind::ConnectionAborted),
            _ => false,
        }
    }

    #[test]
    fn test_drop_backpressure_writer_returns_zero_when_full() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Channel::with_parameters(4, Backpressure::Drop);
        let mut writer = pin!(channel.new_writer());
        let _reader = pin!(channel.new_strong_reader());

        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3, 4]),
            Poll::Ready(Ok(4))
        ));
        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[9]),
            Poll::Ready(Ok(0))
        ));
    }
}
1173
#[cfg(all(test, feature = "loom"))]
mod loom_tests {
    use futures::future;
    use loom::future::block_on;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Under every interleaving loom explores, a one-byte write must reach a strong reader
    /// that is concurrently waiting for it.
    #[test]
    fn strong_reader_and_writer_progress() {
        loom::model(|| {
            let channel = Channel::new(4);
            let mut rx = channel.new_strong_reader();
            let mut tx = channel.new_writer();

            block_on(async move {
                let mut received = [0u8; 1];
                // The writer future is polled first, then the reader, as in `join(write, read)`.
                let (_, read_outcome) =
                    future::join(tx.write_all(&[42]), rx.read_exact(&mut received)).await;
                read_outcome.unwrap();
                assert_eq!(received[0], 42);
            });
        });
    }
}