uid_mux/
yamux.rs

1//! Yamux multiplexer.
2//!
3//! This module provides a [`yamux`](https://crates.io/crates/yamux) wrapper which implements [`UidMux`](crate::UidMux).
4
5use std::{
6    collections::{HashMap, VecDeque},
7    fmt,
8    future::IntoFuture,
9    pin::Pin,
10    sync::{
11        atomic::{AtomicBool, Ordering},
12        Arc, Mutex,
13    },
14    task::{Context, Poll, Waker},
15};
16
17use async_trait::async_trait;
18use futures::{stream::FuturesUnordered, AsyncRead, AsyncWrite, Future, FutureExt, StreamExt};
19use tokio::sync::{oneshot, Notify};
20use yamux::Connection;
21
22use crate::{
23    future::{ReadId, ReturnStream},
24    log::{debug, error, info, trace, warn},
25    InternalId, UidMux,
26};
27
28pub use yamux::{Config, ConnectionError, Mode, Stream};
29
/// Result type whose error defaults to yamux's [`ConnectionError`].
type Result<T, E = ConnectionError> = std::result::Result<T, E>;
31
/// Which side of the yamux connection this multiplexer runs on.
///
/// Implements [`fmt::Display`] so the role can be recorded in tracing span fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Role {
    /// Opens outbound streams; receiving an inbound stream is treated as an error.
    Client,
    /// Accepts inbound streams opened by the remote.
    Server,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Role::Client => "Client",
            Role::Server => "Server",
        };
        f.write_str(name)
    }
}
46
/// A yamux multiplexer.
///
/// Use [`Yamux::control`] to obtain a [`YamuxCtrl`] handle. Then drive the
/// connection by awaiting the multiplexer (it implements [`IntoFuture`]).
#[derive(Debug)]
pub struct Yamux<Io> {
    /// Whether this end acts as the client or the server.
    role: Role,
    /// The underlying yamux connection.
    conn: Connection<Io>,
    /// State shared with every [`YamuxCtrl`] handle.
    queue: Arc<Mutex<Queue>>,
    /// Used to wake callers blocked in `open` when the connection future completes.
    close_notify: Arc<Notify>,
    /// Flag set by [`YamuxCtrl::close`] to request that the connection be closed.
    shutdown_notify: Arc<AtomicBool>,
}
56
/// State shared between the connection future and its [`YamuxCtrl`] handles.
#[derive(Debug, Default)]
struct Queue {
    /// Callers blocked in `open`, keyed by stream id, each waiting for its stream.
    waiting: HashMap<InternalId, oneshot::Sender<Stream>>,
    /// Streams that arrived before any caller asked for them (filled on the server side).
    ready: HashMap<InternalId, Stream>,
    /// Number of outbound streams still to be pre-allocated (client side only).
    alloc: usize,
    /// Waker of the connection future. Handles use it to trigger a re-poll after
    /// changing this state.
    waker: Option<Waker>,
}
64
65impl<Io> Yamux<Io> {
66    /// Returns a new control handle.
67    pub fn control(&self) -> YamuxCtrl {
68        YamuxCtrl {
69            role: self.role,
70            queue: self.queue.clone(),
71            close_notify: self.close_notify.clone(),
72            shutdown_notify: self.shutdown_notify.clone(),
73        }
74    }
75}
76
77impl<Io> Yamux<Io>
78where
79    Io: AsyncWrite + AsyncRead + Unpin,
80{
81    /// Creates a new yamux multiplexer.
82    pub fn new(io: Io, config: Config, mode: Mode) -> Self {
83        let role = match mode {
84            Mode::Client => Role::Client,
85            Mode::Server => Role::Server,
86        };
87
88        Self {
89            role,
90            conn: Connection::new(io, config, mode),
91            queue: Default::default(),
92            close_notify: Default::default(),
93            shutdown_notify: Default::default(),
94        }
95    }
96}
97
98impl<Io> IntoFuture for Yamux<Io>
99where
100    Io: AsyncWrite + AsyncRead + Unpin,
101{
102    type Output = Result<()>;
103    type IntoFuture = YamuxFuture<Io>;
104
105    fn into_future(self) -> Self::IntoFuture {
106        YamuxFuture {
107            role: self.role,
108            conn: self.conn,
109            incoming: Default::default(),
110            allocated: Default::default(),
111            outgoing: Default::default(),
112            queue: self.queue,
113            closed: false,
114            remote_closed: false,
115            close_notify: self.close_notify,
116            shutdown_notify: self.shutdown_notify,
117        }
118    }
119}
120
/// A yamux connection future.
///
/// Obtained from a [`Yamux`] via [`IntoFuture`]. Streams requested through
/// [`YamuxCtrl`] are only handed out while this future is being polled.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YamuxFuture<Io> {
    /// Whether this end acts as the client or the server.
    role: Role,
    /// The underlying yamux connection.
    conn: Connection<Io>,
    /// Pending incoming streams, waiting for ids to be received.
    incoming: FuturesUnordered<ReadId<Stream>>,
    /// Streams which have been allocated but not assigned an id.
    allocated: VecDeque<Stream>,
    /// Pending outgoing streams, waiting to send ids and return streams
    /// to callers.
    outgoing: FuturesUnordered<ReturnStream<Stream>>,
    /// State shared with [`YamuxCtrl`] handles.
    queue: Arc<Mutex<Queue>>,
    /// Whether this side has closed the connection.
    closed: bool,
    /// Whether the remote has closed the connection.
    remote_closed: bool,
    /// Used to wake callers blocked in `open` when this future completes.
    close_notify: Arc<Notify>,
    /// Set by [`YamuxCtrl::close`]; checked on every poll to start closing.
    shutdown_notify: Arc<AtomicBool>,
}
142
143impl<Io> YamuxFuture<Io>
144where
145    Io: AsyncWrite + AsyncRead + Unpin,
146{
147    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
148    fn client_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
149        if let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
150            if stream.is_some() {
151                error!("client mux received incoming stream");
152                return Err(
153                    std::io::Error::other("client mode cannot accept incoming streams").into(),
154                );
155            }
156
157            info!("remote closed connection");
158            self.remote_closed = true;
159        }
160
161        Ok(())
162    }
163
164    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
165    fn client_handle_outbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
166        // Putting this in a block so the lock is released as soon as possible.
167        {
168            let mut queue = self.queue.lock().unwrap();
169
170            // Allocate new streams.
171            while queue.alloc > 0 {
172                if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
173                    self.allocated.push_back(stream);
174                    queue.alloc -= 1;
175                    debug!("allocated new stream");
176                } else {
177                    break;
178                }
179            }
180
181            while !queue.waiting.is_empty() {
182                let stream = if let Some(stream) = self.allocated.pop_front() {
183                    stream
184                } else if let Poll::Ready(stream) = self.conn.poll_new_outbound(cx)? {
185                    stream
186                } else {
187                    break;
188                };
189
190                let id = *queue.waiting.keys().next().unwrap();
191                let sender = queue.waiting.remove(&id).unwrap();
192
193                debug!("opened new stream: {}", id);
194
195                self.outgoing.push(ReturnStream::new(id, stream, sender));
196            }
197
198            // Set the waker so `YamuxCtrl` can wake up the connection.
199            queue.waker = Some(cx.waker().clone());
200        }
201
202        while let Poll::Ready(Some(result)) = self.outgoing.poll_next_unpin(cx) {
203            if let Err(err) = result {
204                warn!("connection closed while opening stream: {}", err);
205                self.remote_closed = true;
206            } else {
207                trace!("finished opening stream");
208            }
209        }
210
211        Ok(())
212    }
213
214    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
215    fn server_handle_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
216        while let Poll::Ready(stream) = self.conn.poll_next_inbound(cx).map(Option::transpose)? {
217            let Some(stream) = stream else {
218                if !self.remote_closed {
219                    info!("remote closed connection");
220                    self.remote_closed = true;
221                }
222
223                break;
224            };
225
226            debug!("received incoming stream");
227            // The size of this is bounded by yamux max streams config.
228            self.incoming.push(ReadId::new(stream));
229        }
230
231        Ok(())
232    }
233
234    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
235    fn server_process_inbound(&mut self, cx: &mut Context<'_>) -> Result<()> {
236        let mut queue = self.queue.lock().unwrap();
237        while let Poll::Ready(Some(result)) = self.incoming.poll_next_unpin(cx) {
238            match result {
239                Ok((id, stream)) => {
240                    debug!("received stream: {}", id);
241                    if let Some(sender) = queue.waiting.remove(&id) {
242                        _ = sender
243                            .send(stream)
244                            .inspect_err(|_| error!("caller dropped receiver"));
245                        trace!("returned stream to caller: {}", id);
246                    } else {
247                        trace!("queuing stream: {}", id);
248                        queue.ready.insert(id, stream);
249                    }
250                }
251                Err(err) => {
252                    warn!("connection closed while receiving stream: {}", err);
253                    self.remote_closed = true;
254                }
255            }
256        }
257
258        // Set the waker so `YamuxCtrl` can wake up the connection.
259        queue.waker = Some(cx.waker().clone());
260
261        Ok(())
262    }
263
264    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
265    fn handle_shutdown(&mut self, cx: &mut Context<'_>) -> Result<()> {
266        // Attempt to close the connection if the shutdown notify has been set.
267        if !self.closed && self.shutdown_notify.load(Ordering::Relaxed) {
268            if let Poll::Ready(()) = self.conn.poll_close(cx)? {
269                self.closed = true;
270                info!("mux connection closed");
271            }
272        }
273
274        Ok(())
275    }
276
277    fn is_complete(&self) -> bool {
278        self.remote_closed || self.closed
279    }
280
281    fn poll_client(&mut self, cx: &mut Context<'_>) -> Result<()> {
282        self.client_handle_inbound(cx)?;
283
284        if !self.remote_closed {
285            self.client_handle_outbound(cx)?;
286
287            // We need to poll the inbound again to make sure the connection
288            // flushes the write buffer.
289            self.client_handle_inbound(cx)?;
290        }
291
292        self.handle_shutdown(cx)?;
293
294        Ok(())
295    }
296
297    fn poll_server(&mut self, cx: &mut Context<'_>) -> Result<()> {
298        self.server_handle_inbound(cx)?;
299        self.server_process_inbound(cx)?;
300        self.handle_shutdown(cx)?;
301
302        Ok(())
303    }
304}
305
306impl<Io> Future for YamuxFuture<Io>
307where
308    Io: AsyncWrite + AsyncRead + Unpin,
309{
310    type Output = Result<()>;
311
312    #[cfg_attr(
313        feature = "tracing",
314        tracing::instrument(
315            fields(role = %self.role),
316            skip_all
317        )
318    )]
319    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
320        match self.role {
321            Role::Client => self.poll_client(cx)?,
322            Role::Server => self.poll_server(cx)?,
323        };
324
325        if self.is_complete() {
326            self.close_notify.notify_waiters();
327            info!("connection complete");
328            Poll::Ready(Ok(()))
329        } else {
330            Poll::Pending
331        }
332    }
333}
334
/// A yamux control handle.
///
/// Clones share state with the [`Yamux`] they were created from. Every field is either
/// `Copy` or an `Arc`, so cloning is cheap.
#[derive(Debug, Clone)]
pub struct YamuxCtrl {
    /// Whether the underlying connection is the client or the server side.
    role: Role,
    /// State shared with the connection future.
    queue: Arc<Mutex<Queue>>,
    /// Notified when the connection future completes.
    close_notify: Arc<Notify>,
    /// Set by [`YamuxCtrl::close`] to ask the connection future to close.
    shutdown_notify: Arc<AtomicBool>,
}
343
344impl YamuxCtrl {
345    /// Allocates `count` streams.
346    ///
347    /// This can be used to efficiently pre-allocate streams prior to assigning ids to them.
348    ///
349    /// # Note
350    ///
351    /// This method only has an effect for the client side of the connection.
352    pub fn alloc(&self, count: usize) {
353        if let Role::Server = self.role {
354            warn!("alloc has no effect for server side of connection");
355            return;
356        }
357
358        let mut queue = self.queue.lock().unwrap();
359        queue.alloc += count;
360        if let Some(waker) = queue.waker.as_ref() {
361            waker.wake_by_ref()
362        }
363    }
364
365    /// Closes the yamux connection.
366    pub fn close(&self) {
367        self.shutdown_notify.store(true, Ordering::Relaxed);
368
369        // Wake up the connection.
370        if let Some(waker) = self.queue.lock().unwrap().waker.as_ref() {
371            waker.wake_by_ref()
372        }
373    }
374}
375
376#[async_trait]
377impl<Id> UidMux<Id> for YamuxCtrl
378where
379    Id: fmt::Debug + AsRef<[u8]> + Sync,
380{
381    type Stream = Stream;
382    type Error = std::io::Error;
383
384    #[cfg_attr(
385        feature = "tracing",
386        tracing::instrument(
387            fields(role = %self.role, id = hex::encode(id)),
388            skip_all,
389            err
390        )
391    )]
392    async fn open(&self, id: &Id) -> Result<Self::Stream, Self::Error> {
393        let internal_id = InternalId::new(id.as_ref());
394
395        debug!("opening stream: {}", internal_id);
396
397        let receiver = {
398            let mut queue = self.queue.lock().unwrap();
399            if let Some(stream) = queue.ready.remove(&internal_id) {
400                trace!("stream already opened");
401                return Ok(stream);
402            }
403
404            let (sender, receiver) = oneshot::channel();
405
406            // Insert the oneshot into the queue.
407            queue.waiting.insert(internal_id, sender);
408            // Wake up the connection.
409            if let Some(waker) = queue.waker.as_ref() {
410                waker.wake_by_ref()
411            }
412
413            trace!("waiting for stream");
414
415            receiver
416        };
417
418        futures::select! {
419            stream = receiver.fuse() =>
420                stream
421                    .inspect(|_| debug!("caller received stream"))
422                    .inspect_err(|_| error!("connection cancelled stream"))
423                    .map_err(|_| {
424                    std::io::Error::other("connection cancelled stream".to_string())
425                }),
426            _ = self.close_notify.notified().fuse() => {
427                error!("connection closed before stream opened");
428                Err(std::io::ErrorKind::ConnectionAborted.into())
429            }
430        }
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    // Happy path: both sides open streams "0" and "00". The client writes to each
    // and the server reads back the expected bytes. Then both sides close and the
    // background connection task must finish without error.
    #[tokio::test]
    async fn test_yamux() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        // Drive both connection futures in the background.
        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        futures::join!(
            async {
                let mut stream = client_ctrl.open(b"0").await.unwrap();
                let mut stream2 = client_ctrl.open(b"00").await.unwrap();

                stream.write_all(b"ping").await.unwrap();
                stream2.write_all(b"ping2").await.unwrap();
            },
            async {
                let mut stream = server_ctrl.open(b"0").await.unwrap();
                let mut stream2 = server_ctrl.open(b"00").await.unwrap();

                let mut buf = [0; 4];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping");

                let mut buf = [0; 5];
                stream2.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping2");
            }
        );

        client_ctrl.close();
        server_ctrl.close();

        conn_task.await.unwrap();
    }

    // A close requested by the client must complete both connection futures with `Ok`.
    #[tokio::test]
    async fn test_yamux_client_close() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();

        let mut fut = futures::future::try_join(client.into_future(), server.into_future());

        // Poll once so both futures have run and registered their wakers.
        _ = futures::poll!(&mut fut);

        client_ctrl.close();

        // Both connections close cleanly.
        fut.await.unwrap();
    }

    // Test the case where the client closes the connection while the server is expecting a new stream.
    #[tokio::test]
    async fn test_yamux_client_close_early() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let mut fut_conn = futures::future::try_join(client.into_future(), server.into_future());
        _ = futures::poll!(&mut fut_conn);

        // Poll once so the server-side request is queued before the client closes.
        let mut fut_open = server_ctrl.open(b"0");
        _ = futures::poll!(&mut fut_open);

        client_ctrl.close();

        // Both connections close cleanly.
        fut_conn.await.unwrap();
        // But caller gets an error.
        assert!(fut_open.await.is_err());
    }

    // A close requested by the server must complete both connection futures with `Ok`.
    #[tokio::test]
    async fn test_yamux_server_close() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let server_ctrl = server.control();

        let mut fut = futures::future::try_join(client.into_future(), server.into_future());

        // Poll once so both futures have run and registered their wakers.
        _ = futures::poll!(&mut fut);

        server_ctrl.close();

        // Both connections close cleanly.
        fut.await.unwrap();
    }

    // Test the case where the server closes the connection while the client is opening a new stream.
    #[tokio::test]
    async fn test_yamux_server_close_early() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let mut fut_client = client.into_future();
        let mut fut_server = server.into_future();

        // Join by reference and drop the join after one poll, so the individual
        // futures can still be inspected and polled afterwards.
        let mut fut_conn = futures::future::try_join(&mut fut_client, &mut fut_server);
        _ = futures::poll!(&mut fut_conn);
        drop(fut_conn);

        let mut fut_open = client_ctrl.open(b"0");
        _ = futures::poll!(&mut fut_open);

        // We need to prevent the client from beating us to the punch here.
        // Clearing `waiting` drops the request's sender, so the client connection
        // never opens a stream for it.
        fut_client.queue.lock().unwrap().waiting.clear();

        server_ctrl.close();

        // Both connections close cleanly.
        futures::try_join!(fut_client, fut_server).unwrap();
        // But caller gets an error.
        assert!(fut_open.await.is_err());
    }

    // Pre-allocated streams are created on the first poll and consumed by a later `open`.
    #[tokio::test]
    async fn test_yamux_alloc() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = client.control();
        let server_ctrl = server.control();

        let mut fut_client = client.into_future();
        let mut fut_server = server.into_future();

        // The request is only recorded until the connection future is polled.
        client_ctrl.alloc(1);
        assert_eq!(fut_client.queue.lock().unwrap().alloc, 1);

        let mut fut_conn = futures::future::try_join(&mut fut_client, &mut fut_server);
        _ = futures::poll!(&mut fut_conn);
        drop(fut_conn);

        // One poll moved the request into an allocated stream.
        assert_eq!(fut_client.queue.lock().unwrap().alloc, 0);
        assert_eq!(fut_client.allocated.len(), 1);

        let fut_open = futures::future::try_join(client_ctrl.open(b"0"), server_ctrl.open(b"0"));

        let fut_conn = futures::future::try_join(&mut fut_client, &mut fut_server);

        futures::select! {
            _ = fut_open.fuse() => {},
            _ = fut_conn.fuse() => panic!("connection closed before stream opened"),
        }

        // Assert that the pre-allocated stream was consumed.
        assert!(fut_client.allocated.is_empty());
    }
}