tokio_stream_multiplexor_fork/
lib.rs1#![warn(missing_docs)]
34
35mod config;
36mod frame;
37mod inner;
38mod listener;
39mod socket;
40
41use std::{
42 collections::HashMap,
43 fmt::{Debug, Formatter, Result as FmtResult},
44 io,
45 sync::{
46 atomic::{AtomicBool, Ordering},
47 Arc,
48 },
49};
50
51extern crate async_channel;
52use rand::Rng;
53pub use tokio::io::DuplexStream;
54use tokio::{
55 io::{split, AsyncRead, AsyncWrite},
56 sync::{mpsc, watch, RwLock},
57};
58use tokio_util::codec::{FramedRead, FramedWrite};
59use tracing::{error, trace};
60
61pub use config::StreamMultiplexorConfig;
62use frame::{FrameDecoder, FrameEncoder};
63use inner::StreamMultiplexorInner;
64pub use listener::MuxListener;
65use socket::MuxSocket;
66
/// Result type used throughout this crate; the error is always an [`io::Error`].
pub type Result<T> = io::Result<T>;
69
/// Handle to a multiplexor running over a single underlying stream of type `T`.
///
/// The handle only holds a reference-counted pointer to the shared state; the
/// same state is also handed to the background tasks spawned at construction
/// and to every listener and socket created through this handle.
pub struct StreamMultiplexor<T> {
    /// Shared multiplexor state (configuration, port tables, channels).
    inner: Arc<StreamMultiplexorInner<T>>,
}
74
75impl<T> Debug for StreamMultiplexor<T> {
76 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
77 f.debug_struct("StreamMultiplexor")
78 .field("id", &self.inner.config.identifier)
79 .field("inner", &self.inner)
80 .finish()
81 }
82}
83
84impl<T> Drop for StreamMultiplexor<T> {
85 fn drop(&mut self) {
86 self.inner.watch_connected_send.send_replace(false);
87 error!("drop {:?}", self);
88 }
89}
90
91impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> StreamMultiplexor<T> {
92 pub fn new(inner: T, config: StreamMultiplexorConfig) -> Self {
94 Self::new_running(inner, config, true)
95 }
96
97 pub fn new_paused(inner: T, config: StreamMultiplexorConfig) -> Self {
104 Self::new_running(inner, config, false)
105 }
106
107 fn new_running(inner: T, config: StreamMultiplexorConfig, running: bool) -> Self {
108 let (reader, writer) = split(inner);
109
110 let framed_reader = FramedRead::new(reader, FrameDecoder::new(config));
111 let framed_writer = FramedWrite::new(writer, FrameEncoder::new(config));
112 let (send, recv) = mpsc::channel(config.max_queued_frames);
113 let (watch_connected_send, watch_connected_recv) = watch::channel(true);
114 let (running, _) = watch::channel(running);
115 let inner = Arc::from(StreamMultiplexorInner {
116 config,
117 connected: AtomicBool::from(true),
118 port_connections: RwLock::from(HashMap::new()),
119 port_listeners: RwLock::from(HashMap::new()),
120 watch_connected_send,
121 send: RwLock::from(send),
122 running,
123 });
124
125 tokio::spawn(inner.clone().framed_writer_sender(recv, framed_writer));
126 tokio::spawn(inner.clone().framed_reader_sender(framed_reader));
127 tokio::spawn(inner.clone().handle_disconnected(watch_connected_recv));
128
129 Self { inner }
130 }
131
132 pub fn start(&self) {
137 self.inner.running.send_replace(true);
138 }
139
140 pub fn close(&self) {
143 self.inner.watch_connected_send.send_replace(false);
144 }
145
146 #[tracing::instrument]
148 pub async fn bind(&self, port: u16) -> Result<MuxListener<T>> {
149 trace!("");
150 if !self.inner.connected.load(Ordering::Relaxed) {
151 return Err(io::Error::from(io::ErrorKind::ConnectionReset));
152 }
153 let mut port = port;
154 if port == 0 {
155 while port < 1024 || self.inner.port_listeners.read().await.contains_key(&port) {
156 port = rand::thread_rng().gen_range(1024u16..u16::MAX);
157 }
158 trace!("port = {}", port);
159 } else if self.inner.port_listeners.read().await.contains_key(&port) {
160 trace!("port_listeners already contains {}", port);
161 return Err(io::Error::from(io::ErrorKind::AddrInUse));
162 }
163 let (send, recv) = async_channel::bounded(self.inner.config.accept_queue_len);
164 self.inner.port_listeners.write().await.insert(port, send);
165 Ok(MuxListener::new(self.inner.clone(), port, recv))
166 }
167
168 #[tracing::instrument]
170 pub async fn connect(&self, port: u16) -> Result<DuplexStream> {
171 trace!("");
172 if !self.inner.connected.load(Ordering::Relaxed) {
173 trace!("Not connected, raise Error");
174 return Err(io::Error::from(io::ErrorKind::ConnectionReset));
175 }
176 let mut sport: u16 = 0;
177 while sport < 1024
178 || self
179 .inner
180 .port_connections
181 .read()
182 .await
183 .contains_key(&(sport, port))
184 {
185 sport = rand::thread_rng().gen_range(1024u16..u16::MAX);
186 }
187 trace!("sport = {}", sport);
188
189 let mux_socket = MuxSocket::new(self.inner.clone(), sport, port, false);
190 let mut rx = mux_socket.stream().await;
191 self.inner
192 .port_connections
193 .write()
194 .await
195 .insert((sport, port), mux_socket.clone());
196 mux_socket.start().await;
197
198 rx.recv()
199 .await
200 .ok_or_else(|| io::Error::from(io::ErrorKind::Other))?
201 }
202
203 #[tracing::instrument]
206 pub fn watch_connected(&self) -> watch::Receiver<bool> {
207 trace!("");
208 self.inner.watch_connected_send.subscribe()
209 }
210}
211
212#[cfg(test)]
213mod test;