// trillium_websockets/bidirectional_stream.rs

1use futures_lite::Stream;
2use std::{
3    fmt::{Debug, Formatter, Result},
4    pin::Pin,
5    task::{Context, Poll},
6};
7pin_project_lite::pin_project! {
8
9
10pub(crate) struct BidirectionalStream<I, O> {
11    pub(crate) inbound: Option<I>,
12    #[pin]
13    pub(crate) outbound: O,
14}
15}
16impl<I, O> Debug for BidirectionalStream<I, O> {
17    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
18        f.debug_struct("BidirectionalStream")
19            .field(
20                "inbound",
21                &match self.inbound {
22                    Some(_) => "Some(_)",
23                    None => "None",
24                },
25            )
26            .field("outbound", &"..")
27            .finish()
28    }
29}
30
/// Labels an item with the half of a bidirectional stream that produced it.
#[derive(Debug)]
pub(crate) enum Direction<In, Out> {
    /// Item that came from the inbound half.
    Inbound(In),
    /// Item that came from the outbound half.
    Outbound(Out),
}
36
37impl<I, O> Stream for BidirectionalStream<I, O>
38where
39    I: Stream + Unpin + Send + Sync + 'static,
40    O: Stream + Send + Sync + 'static,
41{
42    type Item = Direction<I::Item, O::Item>;
43
44    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
45        let this = self.project();
46
47        macro_rules! poll_inbound {
48            () => {
49                if let Some(inbound) = &mut *this.inbound {
50                    match Pin::new(inbound).poll_next(cx) {
51                        Poll::Ready(Some(t)) => return Poll::Ready(Some(Direction::Inbound(t))),
52                        Poll::Ready(None) => return Poll::Ready(None),
53                        _ => (),
54                    }
55                }
56            };
57        }
58        macro_rules! poll_outbound {
59            () => {
60                match this.outbound.poll_next(cx) {
61                    Poll::Ready(Some(t)) => return Poll::Ready(Some(Direction::Outbound(t))),
62                    Poll::Ready(None) => return Poll::Ready(None),
63                    _ => (),
64                }
65            };
66        }
67
68        if fastrand::bool() {
69            poll_inbound!();
70            poll_outbound!();
71        } else {
72            poll_outbound!();
73            poll_inbound!();
74        }
75
76        Poll::Pending
77    }
78}