uid_mux/
serio.rs

1use ::serio::{codec::Codec, IoDuplex};
2use async_trait::async_trait;
3
4use crate::UidMux;
5
6/// A multiplexer that opens framed streams with unique ids.
7#[async_trait]
8pub trait FramedUidMux<Id> {
9    /// Stream type.
10    type Framed: IoDuplex;
11    /// Error type.
12    type Error;
13
14    /// Opens a new framed stream with the given id.
15    async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error>;
16}
17
/// A multiplexer paired with a codec, so that opened streams come back framed.
///
/// When `M: UidMux<Id>` and `C` is a `Codec` for the mux's stream type, this
/// type implements [`FramedUidMux`].
#[derive(Debug)]
pub struct FramedMux<M, C> {
    /// Multiplexer used to open the raw streams.
    mux: M,
    /// Codec applied to every stream the multiplexer opens.
    codec: C,
}
24
25impl<M, C> FramedMux<M, C> {
26    /// Creates a new `FramedMux`.
27    pub fn new(mux: M, codec: C) -> Self {
28        Self { mux, codec }
29    }
30
31    /// Returns a reference to the mux.
32    pub fn mux(&self) -> &M {
33        &self.mux
34    }
35
36    /// Returns a mutable reference to the mux.
37    pub fn mux_mut(&mut self) -> &mut M {
38        &mut self.mux
39    }
40
41    /// Returns a reference to the codec.
42    pub fn codec(&self) -> &C {
43        &self.codec
44    }
45
46    /// Returns a mutable reference to the codec.
47    pub fn codec_mut(&mut self) -> &mut C {
48        &mut self.codec
49    }
50
51    /// Splits the `FramedMux` into its parts.
52    pub fn into_parts(self) -> (M, C) {
53        (self.mux, self.codec)
54    }
55}
56
57#[async_trait]
58impl<Id, M, C> FramedUidMux<Id> for FramedMux<M, C>
59where
60    Id: Sync,
61    M: UidMux<Id> + Sync,
62    C: Codec<<M as UidMux<Id>>::Stream> + Sync,
63{
64    /// Stream type.
65    type Framed = <C as Codec<<M as UidMux<Id>>::Stream>>::Framed;
66    /// Error type.
67    type Error = <M as UidMux<Id>>::Error;
68
69    /// Opens a new framed stream with the given id.
70    async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error> {
71        let stream = self.mux.open(id).await?;
72        Ok(self.codec.new_framed(stream))
73    }
74}
75
76impl<M: Clone, C: Clone> Clone for FramedMux<M, C> {
77    fn clone(&self) -> Self {
78        Self {
79            mux: self.mux.clone(),
80            codec: self.codec.clone(),
81        }
82    }
83}
84
#[cfg(test)]
mod tests {
    use std::future::IntoFuture;

    use super::*;
    use crate::yamux::{Config, Mode, Yamux};

    // Refer to the external `serio` crate with a leading `::`, matching the
    // file's top-level import. This module is itself named `serio`, so the
    // explicit global path avoids any ambiguity.
    use ::serio::{codec::Bincode, stream::IoStreamExt, SinkExt};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    /// Sends a `u128` over a framed stream opened on both ends of an in-memory
    /// yamux connection and checks that the server receives the same value.
    #[tokio::test]
    async fn test_framed_mux() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = FramedMux::new(client.control(), Bincode);
        let server_ctrl = FramedMux::new(server.control(), Bincode);

        // Drive both ends of the connection in the background.
        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        futures::join!(
            async {
                let mut stream = client_ctrl.open_framed(b"test").await.unwrap();

                stream.send(42u128).await.unwrap();

                client_ctrl.mux().close();
            },
            async {
                let mut stream = server_ctrl.open_framed(b"test").await.unwrap();

                let num: u128 = stream.expect_next().await.unwrap();

                server_ctrl.mux().close();

                assert_eq!(num, 42u128);
            }
        );

        conn_task.await.unwrap();
    }
}