Skip to main content

vt_muxer/
lib.rs

1use std::io;
2use std::io::Error;
3use std::pin::{pin, Pin};
4use std::sync::Arc;
5use std::task::{ready, Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
7use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use crate::thin_addr::SocketAddr;
11use crate::constructor::ConstructExt;
12use crate::poll_mutex::PollMutex;
13use crate::read::{ReaderInner, SharedReader};
14use crate::write::WriteInner;
15
16pub mod thin_addr;
17mod poll_mutex;
18mod packet_buffer;
19mod constructor;
20mod write;
21mod read;
22mod protocol;
23mod integers;
24
25type Writer = OwnedWriteHalf;
26type Reader = BufReader<OwnedReadHalf>;
27
28/// Represents a multiplexed connection.
29///
30/// # Usage
31///
32/// This struct is intended to be used in networking or IPC (Inter-Process
33/// Communication) systems where multiplexing is required. The writer and
34/// reader components work together to manage input/output streams.
35///
36/// Note: Ensure proper synchronization and error handling when dealing
37/// with concurrent reads and writes to avoid potential data races or
38/// inconsistencies.
39pub struct MuxConnection {
40    write: Box<WriteInner>,
41    read: ReaderInner
42}
43
44impl MuxConnection {
45    fn new(write: Box<WriteInner>, read: ReaderInner) -> Self {
46        Self {
47            write,
48            read
49        }
50    }
51    
52    pub fn addr(&self) -> SocketAddr {
53        self.write.addr()
54    }
55}
56
57impl AsyncWrite for MuxConnection {
58    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
59        Pin::new(&mut Pin::into_inner(self).write).poll_write(cx, buf)
60    }
61
62    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
63        Pin::new(&mut Pin::into_inner(self).write).poll_flush(cx)
64    }
65
66    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
67        Pin::new(&mut Pin::into_inner(self).write).poll_shutdown(cx)
68    }
69}
70
71impl AsyncRead for MuxConnection {
72    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
73        Pin::new(&mut Pin::into_inner(self).read).poll_read(cx, buf)
74    }
75}
76
77
78/// # MuxPipe
79/// 
80/// The client-side interface for creating multiplexed connections over a single TCP stream.
81/// This can also be used by the server, but this structure can only initiate connection
82/// and can't accept them
83///
84/// ## Important Notes
85/// 
86/// - `MuxPipe` is designed to be used with a `MuxListener` on the server side.
87/// - The struct implements `Clone`, allowing it to be safely shared between multiple tasks.
88/// - All logical connections created through a single `MuxPipe` share the same underlying TCP connection.
89/// - The socket addresses used with `add_connection` serve as identifiers for the logical connections and don't represent actual network endpoints.
90/// - The implementation uses Tokio for async I/O operations, so it must be used within a Tokio runtime.
91#[derive(Clone)]
92pub struct MuxPipe {
93    write: Arc<Mutex<Writer>>,
94    read: Arc<SharedReader>,
95}
96
97impl MuxPipe {
98    /// Creates a new `MuxPipe` from a TCP stream. This takes ownership of the stream and splits it into read and write halves for multiplexing.
99    pub fn new(stream: TcpStream) -> Self {
100        MuxListener::with_listener_capacity(stream, 0).into_pipe()
101    }
102    
103    fn make_writer(&self, addr: SocketAddr) -> Box<WriteInner> {
104        WriteInner::box_new((addr, PollMutex::new(Arc::clone(&self.write))))
105    }
106    
107    
108    /// Creates a new logical connection with the specified socket address. This performs a handshake with the remote end to establish the connection.
109    /// 
110    /// #### Parameters
111    /// - `addr`: The socket address to use for identifying the logical connection
112    /// 
113    /// #### Returns
114    /// - `io::Result<MuxConnection>`: A result containing the new multiplexed connection if successful
115    /// 
116    /// ## Usage Example
117    /// 
118    /// ```rust
119    /// # use std::io;
120    /// # use vt_muxer::MuxPipe;
121    /// # use tokio::net::TcpStream;
122    /// # use tokio::io::AsyncWriteExt;
123    ///
124    /// async fn example() -> io::Result<()> {
125    ///     // Create a TCP connection to the server
126    ///     let tcp_stream = TcpStream::connect("server_address:port").await?;
127    ///     
128    ///     // Create a multiplexer over the TCP stream
129    ///     let mux = MuxPipe::new(tcp_stream);
130    ///     
131    ///     // Create multiple logical connections over the same TCP stream
132    ///     let addr1 = "127.0.0.1:12345".parse().unwrap();
133    ///     let mut connection1 = mux.add_connection(addr1).await?;
134    ///     
135    ///     let addr2 = "127.0.0.1:12346".parse().unwrap();
136    ///     let mut connection2 = mux.add_connection(addr2).await?;
137    ///     
138    ///     // Use the connections independently
139    ///     connection1.write_all(b"Data for connection 1").await?;
140    ///     connection2.write_all(b"Data for connection 2").await?;
141    ///     
142    ///     // Don't forget to properly shut down connections when done
143    ///     connection1.shutdown().await?;
144    ///     connection2.shutdown().await?;
145    ///     
146    ///     Ok(())
147    /// }
148    /// ```
149    /// 
150    ///
151    pub async fn add_connection(&self, addr: SocketAddr) -> io::Result<MuxConnection> {
152        let reader = self.read.add_connection(addr)?;
153        let mut writer = self.make_writer(addr);
154        writer.handshake().await?;
155        Ok(MuxConnection::new(writer, reader))
156    }
157}
158
159
160
161
162/// `MuxListener` is a structure designed to manage and listen for incoming multiplexed connections.
163
164///
165/// # Purpose
166/// The `MuxListener` serves as an abstraction to handle incoming connections from a MuxPipe,
167/// please note that once discarded there is no way of listening to new incoming connections
168pub struct MuxListener {
169    pipe: MuxPipe,
170    receiver: flume::Receiver<(SocketAddr, ReaderInner)>
171}
172
173impl MuxListener {
174    pub fn new(stream: TcpStream) -> Self {
175        Self::with_listener_capacity(stream, 1)
176    }
177
178    fn with_listener_capacity(stream: TcpStream, capacity: usize) -> Self {
179        let (read, write) = stream.into_split();
180        let reader = BufReader::new(read);
181        let (sender, receiver) = flume::bounded(capacity);
182        let read = SharedReader::new(reader, sender);
183        let write = Arc::new(Mutex::new(write));
184        
185        Self {
186            pipe: MuxPipe { write, read },
187            receiver
188        }
189    }
190    
191    pub async fn add_connection(&self, addr: SocketAddr) -> io::Result<MuxConnection> {
192        self.pipe.add_connection(addr).await
193    }
194    
195    pub async fn accept(&self) -> io::Result<MuxConnection> {
196        let mut fut = pin!(self.receiver.recv_async());
197        let (addr, reader) = std::future::poll_fn(move |cx| {
198            if let Poll::Ready(res) = fut.as_mut().poll(cx) { 
199                return Poll::Ready(Ok::<_, Error>(res.expect("receiver should never close")))
200            }
201            
202            match ready!(self.pipe.read.poll(cx))? {}
203        }).await?;
204        let writer = self.pipe.make_writer(addr);
205        Ok(MuxConnection::new(writer, reader))
206    }
207
208    pub fn pipe(&self) -> &MuxPipe {
209        &self.pipe
210    }
211    
212    pub fn into_pipe(self) -> MuxPipe {
213        self.pipe
214    }
215}
216
#[cfg(all(test, not(miri)))]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Identifier used for the first logical connection in these tests.
    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:12345".parse().unwrap()
    }

    /// Connects a `MuxListener` (server side) and a `MuxPipe` (client side) over a real
    /// loopback TCP connection.
    async fn mux_pipe() -> (MuxListener, MuxPipe) {
        let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_to = tcp_listener.local_addr().unwrap();

        let accept_side = async {
            let (socket, _peer) = tcp_listener.accept().await.unwrap();
            MuxListener::new(socket)
        };
        let connect_side = async {
            let socket = TcpStream::connect(bound_to).await.unwrap();
            MuxPipe::new(socket)
        };

        tokio::join!(accept_side, connect_side)
    }

    #[tokio::test]
    async fn test_mux_listener_accept_connection() {
        let (listener, pipe) = mux_pipe().await;

        // Client: open one logical connection, send a payload and close it.
        let sender = async {
            let mut stream = pipe.add_connection(dummy_addr()).await.unwrap();
            stream.write_all(b"hello world").await.unwrap();
            stream.flush().await.unwrap();
            stream.shutdown().await.unwrap();
        };

        // Server: accept that connection and read everything until EOF.
        let receiver = async {
            let mut stream = listener.accept().await.unwrap();
            let mut received = Vec::new();
            let len = stream.read_to_end(&mut received).await.unwrap();
            assert_eq!(&received[..len], b"hello world");
        };

        tokio::join!(sender, receiver);
    }

    #[tokio::test]
    async fn test_mux_pipe_add_connection_multiple_times() {
        let (server, client) = mux_pipe().await;

        let addr1 = dummy_addr();
        let addr2 = "127.0.0.1:12346".parse::<SocketAddr>().unwrap();

        // Client: open two logical connections concurrently, each sending one payload.
        let client_task = async {
            let send = async |addr, payload| {
                let mut stream = client.add_connection(addr).await?;
                stream.write_all(payload).await?;
                stream.flush().await?;
                stream.shutdown().await
            };

            tokio::try_join!(send(addr1, b"first connection"), send(addr2, b"second connection"))
        };

        // Server: accept both connections in whatever order they arrive, then check that
        // each one carried the payload sent for its address.
        let server_task = async {
            let first = server.accept().await?;
            let second = server.accept().await?;
            let (first_addr, second_addr) = (first.addr(), second.addr());

            let (mut conn1, mut conn2) = if first_addr == addr1 && second_addr == addr2 {
                (first, second)
            } else if first_addr == addr2 && second_addr == addr1 {
                (second, first)
            } else {
                unreachable!()
            };

            let mut received1 = Vec::new();
            let len1 = conn1.read_to_end(&mut received1).await?;
            assert_eq!(&received1[..len1], b"first connection");

            let mut received2 = Vec::new();
            let len2 = conn2.read_to_end(&mut received2).await?;
            assert_eq!(&received2[..len2], b"second connection");
            Ok(())
        };

        tokio::try_join!(client_task, server_task).unwrap();
    }
}