1use std::fmt::{self, Debug, Formatter};
3use std::io::Result as IoResult;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::future::{BoxFuture, FutureExt};
9use pin_project::pin_project;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio_util::sync::CancellationToken;
12
13use crate::conn::{Coupler, Holding, HttpBuilder};
14use crate::fuse::ArcFuseFactory;
15use crate::service::HyperHandler;
16
17use super::{Accepted, Acceptor, Listener};
18
/// Coupler for a connection accepted by a `JoinedAcceptor`.
///
/// It wraps the coupler of whichever inner acceptor produced the connection.
pub enum JoinedCoupler<A, B> {
    /// Coupler from the first (`a`) acceptor.
    A(A),
    /// Coupler from the second (`b`) acceptor.
    B(B),
}
26
27impl<A, B> Coupler for JoinedCoupler<A, B>
28where
29 A: Coupler + Unpin + 'static,
30 B: Coupler + Unpin + 'static,
31{
32 type Stream = JoinedStream<A::Stream, B::Stream>;
33
34 fn couple(
35 &self,
36 stream: Self::Stream,
37 handler: HyperHandler,
38 builder: Arc<HttpBuilder>,
39 graceful_stop_token: Option<CancellationToken>,
40 ) -> BoxFuture<'static, IoResult<()>> {
41 match (self, stream) {
42 (Self::A(a), JoinedStream::A(stream)) => a
43 .couple(stream, handler, builder, graceful_stop_token)
44 .boxed(),
45 (Self::B(b), JoinedStream::B(stream)) => b
46 .couple(stream, handler, builder, graceful_stop_token)
47 .boxed(),
48 _ => unreachable!(),
49 }
50 }
51}
52
53impl<A, B> Debug for JoinedCoupler<A, B> {
54 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
55 f.debug_struct("JoinedCoupler").finish()
56 }
57}
58
/// Stream for a connection accepted by a `JoinedAcceptor`.
///
/// It wraps the stream of whichever inner acceptor produced the connection.
pub enum JoinedStream<A, B> {
    /// Stream from the first (`a`) acceptor.
    A(A),
    /// Stream from the second (`b`) acceptor.
    B(B),
}
66
67impl<A, B> Debug for JoinedStream<A, B> {
68 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69 f.debug_struct("JoinedStream").finish()
70 }
71}
72
73impl<A, B> AsyncRead for JoinedStream<A, B>
74where
75 A: AsyncRead + Send + Unpin + 'static,
76 B: AsyncRead + Send + Unpin + 'static,
77{
78 #[inline]
79 fn poll_read(
80 self: Pin<&mut Self>,
81 cx: &mut Context<'_>,
82 buf: &mut ReadBuf<'_>,
83 ) -> Poll<IoResult<()>> {
84 match &mut self.get_mut() {
85 Self::A(a) => Pin::new(a).poll_read(cx, buf),
86 Self::B(b) => Pin::new(b).poll_read(cx, buf),
87 }
88 }
89}
90
91impl<A, B> AsyncWrite for JoinedStream<A, B>
92where
93 A: AsyncWrite + Send + Unpin + 'static,
94 B: AsyncWrite + Send + Unpin + 'static,
95{
96 #[inline]
97 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
98 match &mut self.get_mut() {
99 Self::A(a) => Pin::new(a).poll_write(cx, buf),
100 Self::B(b) => Pin::new(b).poll_write(cx, buf),
101 }
102 }
103
104 #[inline]
105 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
106 match &mut self.get_mut() {
107 Self::A(a) => Pin::new(a).poll_flush(cx),
108 Self::B(b) => Pin::new(b).poll_flush(cx),
109 }
110 }
111
112 #[inline]
113 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
114 match &mut self.get_mut() {
115 Self::A(a) => Pin::new(a).poll_shutdown(cx),
116 Self::B(b) => Pin::new(b).poll_shutdown(cx),
117 }
118 }
119}
120
#[pin_project]
/// A listener that combines two listeners, `a` and `b`.
///
/// Binding it binds `a` first, then `b`, and yields a [`JoinedAcceptor`] that
/// returns a connection from whichever side accepts one first.
pub struct JoinedListener<A, B> {
    // First inner listener; its connections are tagged with the `A` variants.
    #[pin]
    a: A,
    // Second inner listener; its connections are tagged with the `B` variants.
    #[pin]
    b: B,
}
129
130impl<A: Debug, B: Debug> Debug for JoinedListener<A, B> {
131 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
132 f.debug_struct("JoinedListener")
133 .field("a", &self.a)
134 .field("b", &self.b)
135 .finish()
136 }
137}
138
139impl<A, B> JoinedListener<A, B> {
140 #[inline]
142 pub fn new(a: A, b: B) -> Self {
143 Self { a, b }
144 }
145}
146impl<A, B> Listener for JoinedListener<A, B>
147where
148 A: Listener + Send + Unpin + 'static,
149 B: Listener + Send + Unpin + 'static,
150 A::Acceptor: Acceptor + Send + Unpin + 'static,
151 B::Acceptor: Acceptor + Send + Unpin + 'static,
152{
153 type Acceptor = JoinedAcceptor<A::Acceptor, B::Acceptor>;
154
155 async fn try_bind(self) -> crate::Result<Self::Acceptor> {
156 let a = self.a.try_bind().await?;
157 let b = self.b.try_bind().await?;
158 let holdings = a
159 .holdings()
160 .iter()
161 .chain(b.holdings().iter())
162 .cloned()
163 .collect();
164 Ok(JoinedAcceptor { a, b, holdings })
165 }
166}
167
/// Acceptor that returns connections from two inner acceptors, `a` and `b`.
///
/// Produced by binding a [`JoinedListener`], or built directly with `JoinedAcceptor::new`.
pub struct JoinedAcceptor<A, B> {
    // First inner acceptor; its connections are wrapped in the `A` variants.
    a: A,
    // Second inner acceptor; its connections are wrapped in the `B` variants.
    b: B,
    // Holdings of `a` followed by holdings of `b`, copied once at construction
    // so `holdings()` can return a borrowed slice.
    holdings: Vec<Holding>,
}
174
175impl<A: Debug, B: Debug> Debug for JoinedAcceptor<A, B> {
176 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177 f.debug_struct("JoinedAcceptor")
178 .field("a", &self.a)
179 .field("b", &self.b)
180 .field("holdings", &self.holdings)
181 .finish()
182 }
183}
184
185impl<A, B> JoinedAcceptor<A, B>
186where
187 A: Acceptor,
188 B: Acceptor,
189{
190 pub fn new(a: A, b: B) -> Self {
192 let holdings = a
193 .holdings()
194 .iter()
195 .chain(b.holdings().iter())
196 .cloned()
197 .collect();
198 Self { a, b, holdings }
199 }
200}
201
202impl<A, B> Acceptor for JoinedAcceptor<A, B>
203where
204 A: Acceptor + Send + Unpin + 'static,
205 B: Acceptor + Send + Unpin + 'static,
206 A::Coupler: Coupler<Stream = A::Stream> + Unpin + 'static,
207 B::Coupler: Coupler<Stream = B::Stream> + Unpin + 'static,
208 A::Stream: Unpin + Send + 'static,
209 B::Stream: Unpin + Send + 'static,
210{
211 type Coupler = JoinedCoupler<A::Coupler, B::Coupler>;
212 type Stream = JoinedStream<A::Stream, B::Stream>;
213
214 #[inline]
215 fn holdings(&self) -> &[Holding] {
216 &self.holdings
217 }
218
219 #[inline]
220 async fn accept(
221 &mut self,
222 fuse_factory: Option<ArcFuseFactory>,
223 ) -> IoResult<Accepted<Self::Coupler, Self::Stream>> {
224 tokio::select! {
225 accepted = self.a.accept(fuse_factory.clone()) => {
226 Ok(accepted?.map_into(JoinedCoupler::A, JoinedStream::A))
227 }
228 accepted = self.b.accept(fuse_factory) => {
229 Ok(accepted?.map_into(JoinedCoupler::B, JoinedStream::B))
230 }
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    use super::*;
    use crate::conn::TcpListener;

    /// Binds two TCP listeners behind one joined listener and checks that a
    /// connection made to either address comes out of the joined acceptor.
    #[tokio::test]
    async fn test_joined_listener() {
        let first_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6978));
        let second_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 6979));

        let mut acceptor = TcpListener::new(first_addr)
            .join(TcpListener::new(second_addr))
            .bind()
            .await;

        // Client side: one connection per address, each sending a single i32.
        tokio::spawn(async move {
            let mut client_a = TcpStream::connect(first_addr).await.unwrap();
            client_a.write_i32(50).await.unwrap();

            let mut client_b = TcpStream::connect(second_addr).await.unwrap();
            client_b.write_i32(100).await.unwrap();
        });

        // Either side may be accepted first, so only the sum of the two values is checked.
        let mut total = 0;
        for _ in 0..2 {
            let Accepted { mut stream, .. } = acceptor.accept(None).await.unwrap();
            total += stream.read_i32().await.unwrap();
        }
        assert_eq!(total, 150);
    }
}
265}