1use std::{
6 collections::{HashMap, VecDeque},
7 fmt,
8 future::IntoFuture,
9 pin::Pin,
10 sync::{
11 atomic::{AtomicBool, Ordering},
12 Arc, Mutex,
13 },
14 task::{Context, Poll, Waker},
15};
16
17use async_trait::async_trait;
18use futures::{stream::FuturesUnordered, AsyncRead, AsyncWrite, Future, FutureExt, StreamExt};
19use tokio::sync::{oneshot, Notify};
20use yamux::Connection;
21
22use crate::{
23 future::{ReadId, ReturnStream},
24 log::{debug, error, info, trace, warn},
25 InternalId, UidMux,
26};
27
28pub use yamux::{Config, ConnectionError, Mode, Stream};
29
30type Result<T, E = ConnectionError> = std::result::Result<T, E>;
31
/// Which side of the yamux session this endpoint plays.
#[derive(Debug, Clone, Copy)]
enum Role {
    /// Opens outbound streams; never accepts inbound ones.
    Client,
    /// Accepts inbound streams opened by the remote.
    Server,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Server => "Server",
            Self::Client => "Client",
        };
        f.write_str(label)
    }
}
46
47#[derive(Debug)]
49pub struct Yamux<Io> {
50 role: Role,
51 conn: Connection<Io>,
52 queue: Arc<Mutex<Queue>>,
53 close_notify: Arc<Notify>,
54 shutdown_notify: Arc<AtomicBool>,
55}
56
57#[derive(Debug, Default)]
58struct Queue {
59 waiting: HashMap<InternalId, oneshot::Sender<Stream>>,
60 ready: HashMap<InternalId, Stream>,
61 alloc: usize,
62 waker: Option<Waker>,
63}
64
65impl<Io> Yamux<Io> {
66 pub fn control(&self) -> YamuxCtrl {
68 YamuxCtrl {
69 role: self.role,
70 queue: self.queue.clone(),
71 close_notify: self.close_notify.clone(),
72 shutdown_notify: self.shutdown_notify.clone(),
73 }
74 }
75}
76
77impl<Io> Yamux<Io>
78where
79 Io: AsyncWrite + AsyncRead + Unpin,
80{
81 pub fn new(io: Io, config: Config, mode: Mode) -> Self {
83 let role = match mode {
84 Mode::Client => Role::Client,
85 Mode::Server => Role::Server,
86 };
87
88 Self {
89 role,
90 conn: Connection::new(io, config, mode),
91 queue: Default::default(),
92 close_notify: Default::default(),
93 shutdown_notify: Default::default(),
94 }
95 }
96}
97
98impl<Io> IntoFuture for Yamux<Io>
99where
100 Io: AsyncWrite + AsyncRead + Unpin,
101{
102 type Output = Result<()>;
103 type IntoFuture = YamuxFuture<Io>;
104
105 fn into_future(self) -> Self::IntoFuture {
106 YamuxFuture {
107 role: self.role,
108 conn: self.conn,
109 incoming: Default::default(),
110 allocated: Default::default(),
111 outgoing: Default::default(),
112 queue: self.queue,
113 closed: false,
114 remote_closed: false,
115 close_notify: self.close_notify,
116 shutdown_notify: self.shutdown_notify,
117 }
118 }
119}
120
121#[derive(Debug)]
123#[must_use = "futures do nothing unless you `.await` or poll them"]
124pub struct YamuxFuture<Io> {
125 role: Role,
126 conn: Connection<Io>,
127 incoming: FuturesUnordered<ReadId<Stream>>,
129 allocated: VecDeque<Stream>,
131 outgoing: FuturesUnordered<ReturnStream<Stream>>,
134 queue: Arc<Mutex<Queue>>,
135 closed: bool,
137 remote_closed: bool,
139 close_notify: Arc<Notify>,
140 shutdown_notify: Arc<AtomicBool>,
141}
142
143impl<Io> YamuxFuture<Io>
144where
145 Io: AsyncWrite + AsyncRead + Unpin,
146{
147 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
148 fn client_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
149 if let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
150 if stream.is_some() {
151 error!("client mux received incoming stream");
152 return Err(
153 std::io::Error::other("client mode cannot accept incoming streams").into(),
154 );
155 }
156
157 info!("remote closed connection");
158 self.remote_closed = true;
159 }
160
161 Ok(())
162 }
163
164 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
165 fn client_handle_outbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
166 {
168 let mut queue = self.queue.lock().unwrap();
169
170 while queue.alloc > 0 {
172 if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
173 self.allocated.push_back(stream);
174 queue.alloc -= 1;
175 debug!("allocated new stream");
176 } else {
177 break;
178 }
179 }
180
181 while !queue.waiting.is_empty() {
182 let stream = if let Some(stream) = self.allocated.pop_front() {
183 stream
184 } else if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
185 stream
186 } else {
187 break;
188 };
189
190 let id = *queue.waiting.keys().next().unwrap();
191 let sender = queue.waiting.remove(&id).unwrap();
192
193 debug!("opened new stream: {}", id);
194
195 self.outgoing.push(ReturnStream::new(id, stream, sender));
196 }
197
198 queue.waker = Some(cx.waker().clone());
200 }
201
202 while let Poll::Ready(Some(result)) = self.outgoing.poll_next_unpin(cx) {
203 if let Err(err) = result {
204 warn!("connection closed while opening stream: {}", err);
205 self.remote_closed = true;
206 } else {
207 trace!("finished opening stream");
208 }
209 }
210
211 Ok(())
212 }
213
214 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
215 fn server_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
216 while let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
217 let Some(stream) = stream else {
218 if !self.remote_closed {
219 info!("remote closed connection");
220 self.remote_closed = true;
221 }
222
223 break;
224 };
225
226 debug!("received incoming stream");
227 self.incoming.push(ReadId::new(stream));
229 }
230
231 Ok(())
232 }
233
234 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
235 fn server_process_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
236 let mut queue = self.queue.lock().unwrap();
237 while let Poll::Ready(Some(result)) = self.incoming.poll_next_unpin(cx) {
238 match result {
239 Ok((id, stream)) => {
240 debug!("received stream: {}", id);
241 if let Some(sender) = queue.waiting.remove(&id) {
242 _ = sender
243 .send(stream)
244 .inspect_err(|_| error!("caller dropped receiver"));
245 trace!("returned stream to caller: {}", id);
246 } else {
247 trace!("queuing stream: {}", id);
248 queue.ready.insert(id, stream);
249 }
250 }
251 Err(err) => {
252 warn!("connection closed while receiving stream: {}", err);
253 self.remote_closed = true;
254 }
255 }
256 }
257
258 queue.waker = Some(cx.waker().clone());
260
261 Ok(())
262 }
263
264 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
265 fn handle_shutdown(&mut self, cx: &mut Context<'_>) -> Result<()> {
266 if !self.closed && self.shutdown_notify.load(Ordering::Relaxed) {
268 if let Poll::Ready(()) = self.conn.poll_close(cx)? {
269 self.closed = true;
270 info!("mux connection closed");
271 }
272 }
273
274 Ok(())
275 }
276
277 fn is_complete(&self) -> bool {
278 self.remote_closed || self.closed
279 }
280
281 fn poll_client(&mut self, cx: &mut Context<'_>) -> Result<()> {
282 self.client_handle_inbound(cx)?;
283
284 if !self.remote_closed {
285 self.client_handle_outbound(cx)?;
286
287 self.client_handle_inbound(cx)?;
290 }
291
292 self.handle_shutdown(cx)?;
293
294 Ok(())
295 }
296
297 fn poll_server(&mut self, cx: &mut Context<'_>) -> Result<()> {
298 self.server_handle_inbound(cx)?;
299 self.server_process_inbound(cx)?;
300 self.handle_shutdown(cx)?;
301
302 Ok(())
303 }
304}
305
306impl<Io> Future for YamuxFuture<Io>
307where
308 Io: AsyncWrite + AsyncRead + Unpin,
309{
310 type Output = Result<()>;
311
312 #[cfg_attr(
313 feature = "tracing",
314 tracing::instrument(
315 fields(role = %self.role),
316 skip_all
317 )
318 )]
319 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
320 match self.role {
321 Role::Client => self.poll_client(cx)?,
322 Role::Server => self.poll_server(cx)?,
323 };
324
325 if self.is_complete() {
326 self.close_notify.notify_waiters();
327 info!("connection complete");
328 Poll::Ready(Ok(()))
329 } else {
330 Poll::Pending
331 }
332 }
333}
334
335#[derive(Debug, Clone)]
337pub struct YamuxCtrl {
338 role: Role,
339 queue: Arc<Mutex<Queue>>,
340 close_notify: Arc<Notify>,
341 shutdown_notify: Arc<AtomicBool>,
342}
343
344impl YamuxCtrl {
345 pub fn alloc(&self, count: usize) {
353 if let Role::Server = self.role {
354 warn!("alloc has no effect for server side of connection");
355 return;
356 }
357
358 let mut queue = self.queue.lock().unwrap();
359 queue.alloc += count;
360 if let Some(waker) = queue.waker.as_ref() {
361 waker.wake_by_ref()
362 }
363 }
364
365 pub fn close(&self) {
367 self.shutdown_notify.store(true, Ordering::Relaxed);
368
369 if let Some(waker) = self.queue.lock().unwrap().waker.as_ref() {
371 waker.wake_by_ref()
372 }
373 }
374}
375
376#[async_trait]
377impl<Id> UidMux<Id> for YamuxCtrl
378where
379 Id: fmt::Debug + AsRef<[u8]> + Sync,
380{
381 type Stream = Stream;
382 type Error = std::io::Error;
383
384 #[cfg_attr(
385 feature = "tracing",
386 tracing::instrument(
387 fields(role = %self.role, id = hex::encode(id)),
388 skip_all,
389 err
390 )
391 )]
392 async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error> {
393 let internal_id = InternalId::new(id.as_ref());
394
395 debug!("opening stream: {}", internal_id);
396
397 let receiver = {
398 let mut queue = self.queue.lock().unwrap();
399 if let Some(stream) = queue.ready.remove(&internal_id) {
400 trace!("stream already opened");
401 return Ok(stream);
402 }
403
404 let (sender, receiver) = oneshot::channel();
405
406 queue.waiting.insert(internal_id, sender);
408 if let Some(waker) = queue.waker.as_ref() {
410 waker.wake_by_ref()
411 }
412
413 trace!("waiting for stream");
414
415 receiver
416 };
417
418 futures::select! {
419 stream = receiver.fuse() =>
420 stream
421 .inspect(|_| debug!("caller received stream"))
422 .inspect_err(|_| error!("connection cancelled stream"))
423 .map_err(|_| {
424 std::io::Error::other("connection cancelled stream".to_string())
425 }),
426 _ = self.close_notify.notified().fuse() => {
427 error!("connection closed before stream opened");
428 Err(std::io::ErrorKind::ConnectionAborted.into())
429 }
430 }
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use tokio::io::{duplex, DuplexStream};
    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

    /// In-memory transport shared by every test.
    type TestIo = Compat<DuplexStream>;

    /// Builds a client/server pair connected by an in-memory duplex pipe.
    fn connection_pair() -> (Yamux<TestIo>, Yamux<TestIo>) {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);
        (client, server)
    }

    #[tokio::test]
    async fn test_yamux() {
        let (client, server) = connection_pair();
        let (client_ctrl, server_ctrl) = (client.control(), server.control());

        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        let client_side = async {
            let mut first = client_ctrl.open(b"0").await.unwrap();
            let mut second = client_ctrl.open(b"00").await.unwrap();

            first.write_all(b"ping").await.unwrap();
            second.write_all(b"ping2").await.unwrap();
        };
        let server_side = async {
            let mut first = server_ctrl.open(b"0").await.unwrap();
            let mut second = server_ctrl.open(b"00").await.unwrap();

            let mut buf = [0; 4];
            first.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");

            let mut buf = [0; 5];
            second.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping2");
        };
        futures::join!(client_side, server_side);

        client_ctrl.close();
        server_ctrl.close();

        conn_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_yamux_client_close() {
        let (client, server) = connection_pair();
        let client_ctrl = client.control();

        let mut conn = futures::future::try_join(client.into_future(), server.into_future());
        _ = futures::poll!(&mut conn);

        client_ctrl.close();

        conn.await.unwrap();
    }

    #[tokio::test]
    async fn test_yamux_client_close_early() {
        let (client, server) = connection_pair();
        let (client_ctrl, server_ctrl) = (client.control(), server.control());

        let mut conn = futures::future::try_join(client.into_future(), server.into_future());
        _ = futures::poll!(&mut conn);

        let mut pending_open = server_ctrl.open(b"0");
        _ = futures::poll!(&mut pending_open);

        client_ctrl.close();

        conn.await.unwrap();
        assert!(pending_open.await.is_err());
    }

    #[tokio::test]
    async fn test_yamux_server_close() {
        let (client, server) = connection_pair();
        let server_ctrl = server.control();

        let mut conn = futures::future::try_join(client.into_future(), server.into_future());
        _ = futures::poll!(&mut conn);

        server_ctrl.close();

        conn.await.unwrap();
    }

    #[tokio::test]
    async fn test_yamux_server_close_early() {
        let (client, server) = connection_pair();
        let (client_ctrl, server_ctrl) = (client.control(), server.control());

        let mut client_fut = client.into_future();
        let mut server_fut = server.into_future();

        {
            let mut conn = futures::future::try_join(&mut client_fut, &mut server_fut);
            _ = futures::poll!(&mut conn);
        }

        let mut pending_open = client_ctrl.open(b"0");
        _ = futures::poll!(&mut pending_open);

        // Drop the queued request's sender: the client future will not serve
        // it, and the pending `open` fails through its receiver.
        client_fut.queue.lock().unwrap().waiting.clear();

        server_ctrl.close();

        futures::try_join!(client_fut, server_fut).unwrap();
        assert!(pending_open.await.is_err());
    }

    #[tokio::test]
    async fn test_yamux_alloc() {
        let (client, server) = connection_pair();
        let (client_ctrl, server_ctrl) = (client.control(), server.control());

        let mut client_fut = client.into_future();
        let mut server_fut = server.into_future();

        client_ctrl.alloc(1);
        assert_eq!(client_fut.queue.lock().unwrap().alloc, 1);

        {
            let mut conn = futures::future::try_join(&mut client_fut, &mut server_fut);
            _ = futures::poll!(&mut conn);
        }

        assert_eq!(client_fut.queue.lock().unwrap().alloc, 0);
        assert_eq!(client_fut.allocated.len(), 1);

        let opens = futures::future::try_join(client_ctrl.open(b"0"), server_ctrl.open(b"0"));
        let conn = futures::future::try_join(&mut client_fut, &mut server_fut);

        futures::select! {
            _ = opens.fuse() => {},
            _ = conn.fuse() => panic!("connection closed before stream opened"),
        }

        assert!(client_fut.allocated.is_empty());
    }
}