tokio_stream_multiplexor_fork/
lib.rs

//! # Tokio Stream Multiplexor
//!
//! TL;DR: Multiplex multiple streams over a single stream. It has a `TcpListener` / `TcpSocket` style interface and uses u16 ports, similar to TCP itself.
//!
//! ```toml
//! [dependencies]
//! foo = "1.2.3"
//! ```
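//!
//! A minimal usage sketch (assuming the crate is imported as
//! `tokio_stream_multiplexor_fork` and that `StreamMultiplexorConfig`
//! implements `Default`; a `tokio::io::duplex` pair stands in for a real
//! TCP stream):
//!
//! ```ignore
//! use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
//! use tokio_stream_multiplexor_fork::{StreamMultiplexor, StreamMultiplexorConfig};
//!
//! let (a, b) = duplex(65_536);
//! let server = StreamMultiplexor::new(a, StreamMultiplexorConfig::default());
//! let client = StreamMultiplexor::new(b, StreamMultiplexorConfig::default());
//!
//! // Accept on one side while connecting from the other.
//! let listener = server.bind(22).await?;
//! let (inbound, outbound) = tokio::join!(listener.accept(), client.connect(22));
//! let (mut inbound, mut outbound) = (inbound?, outbound?);
//!
//! outbound.write_all(b"ping").await?;
//! let mut buf = [0u8; 4];
//! inbound.read_exact(&mut buf).await?;
//! ```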
//!
//! ## Why?
//!
//! Because sometimes you wanna cram as much shite down one TCP stream as you possibly can, rather than have your application open multiple connections.
//!
//! ## But Whyyyyyy?
//!
//! Because I needed it. Inspired by [async-smux](https://github.com/black-binary/async-smux), written for [Tokio](https://tokio.rs/).
//!
//! ## What about performance?
//!
//! Doesn't this whole protocol-in-a-protocol thing hurt perf?
//!
//! Sure, but take a look at the benches:
//!
//! ```ignore
//! throughput/tcp          time:   [28.968 ms 30.460 ms 31.744 ms]
//!                         thrpt:  [7.8755 GiB/s 8.2076 GiB/s 8.6303 GiB/s]
//!
//! throughput/mux          time:   [122.24 ms 135.96 ms 158.80 ms]
//!                         thrpt:  [1.5743 GiB/s 1.8388 GiB/s 2.0451 GiB/s]
//! ```
//!
//! Approximately 4.5 times slower than TCP, but still able to shovel 1.8 GiB/s of shite... Seems alright to me. (Numbers could probably be improved with some tuning of the config params, too.)

#![warn(missing_docs)]

mod config;
mod frame;
mod inner;
mod listener;
mod socket;

use std::{
    collections::HashMap,
    fmt::{Debug, Formatter, Result as FmtResult},
    io,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};

extern crate async_channel;
use rand::Rng;
pub use tokio::io::DuplexStream;
use tokio::{
    io::{split, AsyncRead, AsyncWrite},
    sync::{mpsc, watch, RwLock},
};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{error, trace};

pub use config::StreamMultiplexorConfig;
use frame::{FrameDecoder, FrameEncoder};
use inner::StreamMultiplexorInner;
pub use listener::MuxListener;
use socket::MuxSocket;

/// Result type returned by `bind()`, `accept()`, and `connect()`.
pub type Result<T> = std::result::Result<T, io::Error>;

/// The Stream Multiplexor.
pub struct StreamMultiplexor<T> {
    inner: Arc<StreamMultiplexorInner<T>>,
}

impl<T> Debug for StreamMultiplexor<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.debug_struct("StreamMultiplexor")
            .field("id", &self.inner.config.identifier)
            .field("inner", &self.inner)
            .finish()
    }
}

impl<T> Drop for StreamMultiplexor<T> {
    fn drop(&mut self) {
        self.inner.watch_connected_send.send_replace(false);
        error!("drop {:?}", self);
    }
}

impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> StreamMultiplexor<T> {
    /// Constructs a new `StreamMultiplexor<T>`.
    pub fn new(inner: T, config: StreamMultiplexorConfig) -> Self {
        Self::new_running(inner, config, true)
    }

    /// Constructs a new paused `StreamMultiplexor<T>`.
    ///
    /// This allows you to bind and listen on a bunch of ports before
    /// processing any packets from the inner stream, removing race conditions
    /// between bind and connect. Call `start()` to start processing the
    /// inner stream.
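    ///
    /// A minimal sketch of the pattern (hypothetical `stream` and `config`
    /// values assumed to be in scope):
    ///
    /// ```ignore
    /// let sm = StreamMultiplexor::new_paused(stream, config);
    /// // Bind before any frames from the peer are processed...
    /// let listener = sm.bind(1234).await?;
    /// // ...then start pumping the inner stream.
    /// sm.start();
    /// let conn = listener.accept().await?;
    /// ```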
    pub fn new_paused(inner: T, config: StreamMultiplexorConfig) -> Self {
        Self::new_running(inner, config, false)
    }

    fn new_running(inner: T, config: StreamMultiplexorConfig, running: bool) -> Self {
        let (reader, writer) = split(inner);

        let framed_reader = FramedRead::new(reader, FrameDecoder::new(config));
        let framed_writer = FramedWrite::new(writer, FrameEncoder::new(config));
        let (send, recv) = mpsc::channel(config.max_queued_frames);
        let (watch_connected_send, watch_connected_recv) = watch::channel(true);
        let (running, _) = watch::channel(running);
        let inner = Arc::from(StreamMultiplexorInner {
            config,
            connected: AtomicBool::from(true),
            port_connections: RwLock::from(HashMap::new()),
            port_listeners: RwLock::from(HashMap::new()),
            watch_connected_send,
            send: RwLock::from(send),
            running,
        });

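        // Spawn the background tasks: one drains queued outgoing frames into
        // the framed writer, one processes incoming frames from the framed
        // reader, and one reacts to the inner stream disconnecting.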
        tokio::spawn(inner.clone().framed_writer_sender(recv, framed_writer));
        tokio::spawn(inner.clone().framed_reader_sender(framed_reader));
        tokio::spawn(inner.clone().handle_disconnected(watch_connected_recv));

        Self { inner }
    }

    /// Start processing the inner stream.
    ///
    /// Only effective on a paused `StreamMultiplexor<T>`.
    /// See `new_paused(inner: T, config: StreamMultiplexorConfig)`.
    pub fn start(&self) {
        self.inner.running.send_replace(true);
    }

    /// Shut down the `StreamMultiplexor<T>` instance and drop the reference
    /// to the inner stream so that it is closed.
    pub fn close(&self) {
        self.inner.watch_connected_send.send_replace(false);
    }

    /// Bind to `port` and return a `MuxListener<T>`.
    ///
    /// Passing port 0 binds to a random unused port (>= 1024). Binding to a
    /// port that already has a listener fails with `AddrInUse`; binding on a
    /// disconnected multiplexor fails with `ConnectionReset`.
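    ///
    /// A minimal sketch (assuming `sm` is a connected `StreamMultiplexor`):
    ///
    /// ```ignore
    /// let listener = sm.bind(22).await?;
    /// assert!(sm.bind(22).await.is_err()); // port already in use
    /// let conn = listener.accept().await?;
    /// ```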
    #[tracing::instrument]
    pub async fn bind(&self, port: u16) -> Result<MuxListener<T>> {
        trace!("");
        if !self.inner.connected.load(Ordering::Relaxed) {
            return Err(io::Error::from(io::ErrorKind::ConnectionReset));
        }
        let mut port = port;
        if port == 0 {
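            // Port 0 means "any port": pick a random unused port >= 1024.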
            while port < 1024 || self.inner.port_listeners.read().await.contains_key(&port) {
                port = rand::thread_rng().gen_range(1024u16..u16::MAX);
            }
            trace!("port = {}", port);
        } else if self.inner.port_listeners.read().await.contains_key(&port) {
            trace!("port_listeners already contains {}", port);
            return Err(io::Error::from(io::ErrorKind::AddrInUse));
        }
        let (send, recv) = async_channel::bounded(self.inner.config.accept_queue_len);
        self.inner.port_listeners.write().await.insert(port, send);
        Ok(MuxListener::new(self.inner.clone(), port, recv))
    }

    /// Connect to `port` on the remote end.
    ///
    /// A random unused local source port (>= 1024) is chosen for the
    /// connection. Connecting on a disconnected multiplexor fails with
    /// `ConnectionReset`.
    #[tracing::instrument]
    pub async fn connect(&self, port: u16) -> Result<DuplexStream> {
        trace!("");
        if !self.inner.connected.load(Ordering::Relaxed) {
            trace!("Not connected, raise Error");
            return Err(io::Error::from(io::ErrorKind::ConnectionReset));
        }
        let mut sport: u16 = 0;
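        // Pick a random unused local (source) port >= 1024 for this connection.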
        while sport < 1024
            || self
                .inner
                .port_connections
                .read()
                .await
                .contains_key(&(sport, port))
        {
            sport = rand::thread_rng().gen_range(1024u16..u16::MAX);
        }
        trace!("sport = {}", sport);

        let mux_socket = MuxSocket::new(self.inner.clone(), sport, port, false);
        let mut rx = mux_socket.stream().await;
        self.inner
            .port_connections
            .write()
            .await
            .insert((sport, port), mux_socket.clone());
        mux_socket.start().await;

        rx.recv()
            .await
            .ok_or_else(|| io::Error::from(io::ErrorKind::Other))?
    }

    /// Return a `tokio::sync::watch::Receiver` that will update to `false`
    /// when the inner stream closes.
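    ///
    /// A minimal sketch (assuming `sm` is a `StreamMultiplexor`):
    ///
    /// ```ignore
    /// let mut connected = sm.watch_connected();
    /// while *connected.borrow() {
    ///     connected.changed().await?;
    /// }
    /// // The inner stream has closed.
    /// ```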
    #[tracing::instrument]
    pub fn watch_connected(&self) -> watch::Receiver<bool> {
        trace!("");
        self.inner.watch_connected_send.subscribe()
    }
}

#[cfg(test)]
mod test;