salvo_core/conn/
joined.rs

//! `JoinedListener` and its implementations.
2use std::fmt::{self, Debug, Formatter};
3use std::io::Result as IoResult;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::future::{BoxFuture, FutureExt};
9use pin_project::pin_project;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio_util::sync::CancellationToken;
12
13use crate::conn::{Coupler, Holding, HttpBuilder};
14use crate::fuse::ArcFuseFactory;
15use crate::service::HyperHandler;
16
17use super::{Accepted, Acceptor, Listener};
18
/// A coupler for [`JoinedListener`], wrapping the coupler of whichever inner
/// acceptor produced the connection.
pub enum JoinedCoupler<A, B> {
    /// Coupler produced by the first (`a`) acceptor.
    A(A),
    /// Coupler produced by the second (`b`) acceptor.
    B(B),
}
26
27impl<A, B> Coupler for JoinedCoupler<A, B>
28where
29    A: Coupler + Unpin + 'static,
30    B: Coupler + Unpin + 'static,
31{
32    type Stream = JoinedStream<A::Stream, B::Stream>;
33
34    fn couple(
35        &self,
36        stream: Self::Stream,
37        handler: HyperHandler,
38        builder: Arc<HttpBuilder>,
39        graceful_stop_token: Option<CancellationToken>,
40    ) -> BoxFuture<'static, IoResult<()>> {
41        match (self, stream) {
42            (Self::A(a), JoinedStream::A(stream)) => a
43                .couple(stream, handler, builder, graceful_stop_token)
44                .boxed(),
45            (Self::B(b), JoinedStream::B(stream)) => b
46                .couple(stream, handler, builder, graceful_stop_token)
47                .boxed(),
48            _ => unreachable!(),
49        }
50    }
51}
52
53impl<A, B> Debug for JoinedCoupler<A, B> {
54    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
55        f.debug_struct("JoinedCoupler").finish()
56    }
57}
58
/// An I/O stream for [`JoinedListener`], wrapping the stream of whichever inner
/// acceptor accepted the connection.
pub enum JoinedStream<A, B> {
    /// Stream accepted by the first (`a`) acceptor.
    A(A),
    /// Stream accepted by the second (`b`) acceptor.
    B(B),
}

impl<A, B> Debug for JoinedStream<A, B> {
    // Only the type name is printed, so the inner streams need not implement `Debug`.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        formatter.write_str("JoinedStream")
    }
}
72
73impl<A, B> AsyncRead for JoinedStream<A, B>
74where
75    A: AsyncRead + Send + Unpin + 'static,
76    B: AsyncRead + Send + Unpin + 'static,
77{
78    #[inline]
79    fn poll_read(
80        self: Pin<&mut Self>,
81        cx: &mut Context<'_>,
82        buf: &mut ReadBuf<'_>,
83    ) -> Poll<IoResult<()>> {
84        match &mut self.get_mut() {
85            Self::A(a) => Pin::new(a).poll_read(cx, buf),
86            Self::B(b) => Pin::new(b).poll_read(cx, buf),
87        }
88    }
89}
90
91impl<A, B> AsyncWrite for JoinedStream<A, B>
92where
93    A: AsyncWrite + Send + Unpin + 'static,
94    B: AsyncWrite + Send + Unpin + 'static,
95{
96    #[inline]
97    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
98        match &mut self.get_mut() {
99            Self::A(a) => Pin::new(a).poll_write(cx, buf),
100            Self::B(b) => Pin::new(b).poll_write(cx, buf),
101        }
102    }
103
104    #[inline]
105    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
106        match &mut self.get_mut() {
107            Self::A(a) => Pin::new(a).poll_flush(cx),
108            Self::B(b) => Pin::new(b).poll_flush(cx),
109        }
110    }
111
112    #[inline]
113    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
114        match &mut self.get_mut() {
115            Self::A(a) => Pin::new(a).poll_shutdown(cx),
116            Self::B(b) => Pin::new(b).poll_shutdown(cx),
117        }
118    }
119}
120
121/// `JoinedListener` is a listener that can join two listeners.
122#[pin_project]
123pub struct JoinedListener<A, B> {
124    #[pin]
125    a: A,
126    #[pin]
127    b: B,
128}
129
130impl<A: Debug, B: Debug> Debug for JoinedListener<A, B> {
131    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
132        f.debug_struct("JoinedListener")
133            .field("a", &self.a)
134            .field("b", &self.b)
135            .finish()
136    }
137}
138
139impl<A, B> JoinedListener<A, B> {
140    /// Create a new `JoinedListener`.
141    #[inline]
142    pub fn new(a: A, b: B) -> Self {
143        Self { a, b }
144    }
145}
146impl<A, B> Listener for JoinedListener<A, B>
147where
148    A: Listener + Send + Unpin + 'static,
149    B: Listener + Send + Unpin + 'static,
150    A::Acceptor: Acceptor + Send + Unpin + 'static,
151    B::Acceptor: Acceptor + Send + Unpin + 'static,
152{
153    type Acceptor = JoinedAcceptor<A::Acceptor, B::Acceptor>;
154
155    async fn try_bind(self) -> crate::Result<Self::Acceptor> {
156        let a = self.a.try_bind().await?;
157        let b = self.b.try_bind().await?;
158        let holdings = a
159            .holdings()
160            .iter()
161            .chain(b.holdings().iter())
162            .cloned()
163            .collect();
164        Ok(JoinedAcceptor { a, b, holdings })
165    }
166}
167
168/// `JoinedAcceptor` is an acceptor that can accept connections from two different acceptors.
169pub struct JoinedAcceptor<A, B> {
170    a: A,
171    b: B,
172    holdings: Vec<Holding>,
173}
174
175impl<A: Debug, B: Debug> Debug for JoinedAcceptor<A, B> {
176    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177        f.debug_struct("JoinedAcceptor")
178            .field("a", &self.a)
179            .field("b", &self.b)
180            .field("holdings", &self.holdings)
181            .finish()
182    }
183}
184
185impl<A, B> JoinedAcceptor<A, B>
186where
187    A: Acceptor,
188    B: Acceptor,
189{
190    /// Create a new `JoinedAcceptor`.
191    pub fn new(a: A, b: B) -> Self {
192        let holdings = a
193            .holdings()
194            .iter()
195            .chain(b.holdings().iter())
196            .cloned()
197            .collect();
198        Self { a, b, holdings }
199    }
200}
201
202impl<A, B> Acceptor for JoinedAcceptor<A, B>
203where
204    A: Acceptor + Send + Unpin + 'static,
205    B: Acceptor + Send + Unpin + 'static,
206    A::Coupler: Coupler<Stream = A::Stream> + Unpin + 'static,
207    B::Coupler: Coupler<Stream = B::Stream> + Unpin + 'static,
208    A::Stream: Unpin + Send + 'static,
209    B::Stream: Unpin + Send + 'static,
210{
211    type Coupler = JoinedCoupler<A::Coupler, B::Coupler>;
212    type Stream = JoinedStream<A::Stream, B::Stream>;
213
214    #[inline]
215    fn holdings(&self) -> &[Holding] {
216        &self.holdings
217    }
218
219    #[inline]
220    async fn accept(
221        &mut self,
222        fuse_factory: Option<ArcFuseFactory>,
223    ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
224        tokio::select! {
225            accepted = self.a.accept(fuse_factory.clone()) => {
226                Ok(accepted?.map_into(JoinedCoupler::A, JoinedStream::A))
227            }
228            accepted = self.b.accept(fuse_factory) => {
229                Ok(accepted?.map_into(JoinedCoupler::B, JoinedStream::B))
230            }
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    use super::*;
    use crate::conn::TcpListener;

    /// Connects once to each joined address and checks that both payloads arrive.
    #[tokio::test]
    async fn test_joined_listener() {
        // NOTE(review): fixed ports may collide with other tests or processes on
        // the same host; consider port 0 if the resolved address can be obtained.
        let addr1 = std::net::SocketAddr::from(([127, 0, 0, 1], 6978));
        let addr2 = std::net::SocketAddr::from(([127, 0, 0, 1], 6979));

        let mut acceptor = TcpListener::new(addr1)
            .join(TcpListener::new(addr2))
            .bind()
            .await;
        let client = tokio::spawn(async move {
            let mut first = TcpStream::connect(addr1).await.unwrap();
            first.write_i32(50).await.unwrap();

            let mut second = TcpStream::connect(addr2).await.unwrap();
            second.write_i32(100).await.unwrap();
        });

        let mut received = Vec::with_capacity(2);
        for _ in 0..2 {
            let Accepted { mut stream, .. } = acceptor.accept(None).await.unwrap();
            received.push(stream.read_i32().await.unwrap());
        }
        // Either side may be accepted first, so compare independent of order.
        received.sort_unstable();
        assert_eq!(received, [50, 100]);
        // Surface any panic from the client task instead of silently discarding it.
        client.await.unwrap();
    }
}