1use ::serio::{codec::Codec, IoDuplex};
2use async_trait::async_trait;
3
4use crate::UidMux;
5
6#[async_trait]
8pub trait FramedUidMux<Id> {
9 type Framed: IoDuplex;
11 type Error;
13
14 async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error>;
16}
17
/// Pairs a multiplexer `M` with a codec `C`.
///
/// The codec is applied to every substream opened through the multiplexer.
#[derive(Debug)]
pub struct FramedMux<M, C> {
    /// The underlying multiplexer that opens raw substreams.
    mux: M,
    /// The codec used to frame each opened substream.
    codec: C,
}
24
25impl<M, C> FramedMux<M, C> {
26 pub fn new(mux: M, codec: C) -> Self {
28 Self { mux, codec }
29 }
30
31 pub fn mux(&self) -> &M {
33 &self.mux
34 }
35
36 pub fn mux_mut(&mut self) -> &mut M {
38 &mut self.mux
39 }
40
41 pub fn codec(&self) -> &C {
43 &self.codec
44 }
45
46 pub fn codec_mut(&mut self) -> &mut C {
48 &mut self.codec
49 }
50
51 pub fn into_parts(self) -> (M, C) {
53 (self.mux, self.codec)
54 }
55}
56
57#[async_trait]
58impl<Id, M, C> FramedUidMux<Id> for FramedMux<M, C>
59where
60 Id: Sync,
61 M: UidMux<Id> + Sync,
62 C: Codec<<M as UidMux<Id>>::Stream> + Sync,
63{
64 type Framed = <C as Codec<<M as UidMux<Id>>::Stream>>::Framed;
66 type Error = <M as UidMux<Id>>::Error;
68
69 async fn open_framed(&self, id: &Id) -> Result<Self::Framed, Self::Error> {
71 let stream = self.mux.open(id).await?;
72 Ok(self.codec.new_framed(stream))
73 }
74}
75
76impl<M: Clone, C: Clone> Clone for FramedMux<M, C> {
77 fn clone(&self) -> Self {
78 Self {
79 mux: self.mux.clone(),
80 codec: self.codec.clone(),
81 }
82 }
83}
84
#[cfg(test)]
mod tests {
    use std::future::IntoFuture;

    use ::serio::{codec::Bincode, stream::IoStreamExt, SinkExt};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    use super::*;
    use crate::yamux::{Config, Mode, Yamux};

    /// Sends one value over a framed substream opened on both ends of an
    /// in-memory yamux connection and checks that it arrives intact.
    #[tokio::test]
    async fn test_framed_mux() {
        let (client_io, server_io) = duplex(1024);
        let client = Yamux::new(client_io.compat(), Config::default(), Mode::Client);
        let server = Yamux::new(server_io.compat(), Config::default(), Mode::Server);

        let client_ctrl = FramedMux::new(client.control(), Bincode);
        let server_ctrl = FramedMux::new(server.control(), Bincode);

        // Drive both connection halves in the background until they finish.
        let conn_task = tokio::spawn(async {
            futures::try_join!(client.into_future(), server.into_future()).unwrap();
        });

        let client_side = async {
            let mut framed = client_ctrl.open_framed(b"test").await.unwrap();
            framed.send(42u128).await.unwrap();
            client_ctrl.mux().close();
        };

        let server_side = async {
            let mut framed = server_ctrl.open_framed(b"test").await.unwrap();
            let received: u128 = framed.expect_next().await.unwrap();
            server_ctrl.mux().close();
            assert_eq!(received, 42u128);
        };

        // Both sides must be polled concurrently so the open/send/receive can rendezvous.
        futures::join!(client_side, server_side);

        conn_task.await.unwrap();
    }
}
131}