1use futures::{future, prelude::*, ready, stream::{BoxStream, LocalBoxStream}};
25use tetsy_libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
26use tetsy_libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
27use parking_lot::Mutex;
28use std::{fmt, io, iter, pin::Pin, task::{Context, Poll}};
29use thiserror::Error;
30
31pub struct Remux<S>(Mutex<Inner<S>>);
33
34impl<S> fmt::Debug for Remux<S> {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 f.write_str("Remux")
37 }
38}
39
/// State shared by every `Remux` operation, guarded by the `Remux` mutex.
struct Inner<S> {
    /// Source of substreams opened by the remote peer (polled by `poll_event`).
    incoming: S,
    /// Connection handle used to open outbound substreams and to close the
    /// connection.
    control: remux::Control,
}
46
/// Token handed out by `Remux::open_outbound`.
///
/// It carries no state: the substream is actually opened by `poll_outbound`
/// through the shared `remux::Control`. The private `()` field keeps the
/// token from being constructed outside this module.
#[derive(Debug)]
pub struct OpenSubstreamToken(());
50
51impl<C> Remux<Incoming<C>>
52where
53 C: AsyncRead + AsyncWrite + Send + Unpin + 'static
54{
55 fn new(io: C, cfg: remux::Config, mode: remux::Mode) -> Self {
57 let conn = remux::Connection::new(io, cfg, mode);
58 let ctrl = conn.control();
59 let inner = Inner {
60 incoming: Incoming {
61 stream: remux::into_stream(conn).err_into().boxed(),
62 _marker: std::marker::PhantomData
63 },
64 control: ctrl,
65 };
66 Remux(Mutex::new(inner))
67 }
68}
69
70impl<C> Remux<LocalIncoming<C>>
71where
72 C: AsyncRead + AsyncWrite + Unpin + 'static
73{
74 fn local(io: C, cfg: remux::Config, mode: remux::Mode) -> Self {
76 let conn = remux::Connection::new(io, cfg, mode);
77 let ctrl = conn.control();
78 let inner = Inner {
79 incoming: LocalIncoming {
80 stream: remux::into_stream(conn).err_into().boxed_local(),
81 _marker: std::marker::PhantomData
82 },
83 control: ctrl,
84 };
85 Remux(Mutex::new(inner))
86 }
87}
88
/// Result type used by the `Remux` muxer, with `RemuxError` as the error.
pub type RemuxResult<T> = Result<T, RemuxError>;
90
91impl<S> StreamMuxer for Remux<S>
93where
94 S: Stream<Item = Result<remux::Stream, RemuxError>> + Unpin
95{
96 type Substream = remux::Stream;
97 type OutboundSubstream = OpenSubstreamToken;
98 type Error = RemuxError;
99
100 fn poll_event(&self, c: &mut Context<'_>)
101 -> Poll<RemuxResult<StreamMuxerEvent<Self::Substream>>>
102 {
103 let mut inner = self.0.lock();
104 match ready!(inner.incoming.poll_next_unpin(c)) {
105 Some(Ok(s)) => Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(s))),
106 Some(Err(e)) => Poll::Ready(Err(e)),
107 None => Poll::Ready(Err(remux::ConnectionError::Closed.into()))
108 }
109 }
110
111 fn open_outbound(&self) -> Self::OutboundSubstream {
112 OpenSubstreamToken(())
113 }
114
115 fn poll_outbound(&self, c: &mut Context<'_>, _: &mut OpenSubstreamToken)
116 -> Poll<RemuxResult<Self::Substream>>
117 {
118 let mut inner = self.0.lock();
119 Pin::new(&mut inner.control).poll_open_stream(c).map_err(RemuxError)
120 }
121
122 fn destroy_outbound(&self, _: Self::OutboundSubstream) {
123 self.0.lock().control.abort_open_stream()
124 }
125
126 fn read_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &mut [u8])
127 -> Poll<RemuxResult<usize>>
128 {
129 Pin::new(s).poll_read(c, b).map_err(|e| RemuxError(e.into()))
130 }
131
132 fn write_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &[u8])
133 -> Poll<RemuxResult<usize>>
134 {
135 Pin::new(s).poll_write(c, b).map_err(|e| RemuxError(e.into()))
136 }
137
138 fn flush_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream)
139 -> Poll<RemuxResult<()>>
140 {
141 Pin::new(s).poll_flush(c).map_err(|e| RemuxError(e.into()))
142 }
143
144 fn shutdown_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream)
145 -> Poll<RemuxResult<()>>
146 {
147 Pin::new(s).poll_close(c).map_err(|e| RemuxError(e.into()))
148 }
149
150 fn destroy_substream(&self, _: Self::Substream) { }
151
152 fn close(&self, c: &mut Context<'_>) -> Poll<RemuxResult<()>> {
153 let mut inner = self.0.lock();
154 if let std::task::Poll::Ready(x) = Pin::new(&mut inner.control).poll_close(c) {
155 return Poll::Ready(x.map_err(RemuxError))
156 }
157 while let std::task::Poll::Ready(x) = inner.incoming.poll_next_unpin(c) {
158 match x {
159 Some(Ok(_)) => {} Some(Err(e)) => return Poll::Ready(Err(e)),
161 None => return Poll::Ready(Ok(()))
162 }
163 }
164 Poll::Pending
165 }
166
167 fn flush_all(&self, _: &mut Context<'_>) -> Poll<RemuxResult<()>> {
168 Poll::Ready(Ok(()))
169 }
170}
171
/// Configuration for upgrading a connection to a `Remux` multiplexer.
#[derive(Clone)]
pub struct RemuxConfig {
    /// Settings passed through to `remux::Connection::new`.
    inner: remux::Config,
    /// Fixed client/server mode. `None` means the mode follows the upgrade
    /// direction: server for inbound upgrades, client for outbound ones.
    mode: Option<remux::Mode>
}
178
/// Newtype over `remux::WindowUpdateMode`, applied with
/// `RemuxConfig::set_window_update_mode`. Construct it with
/// `WindowUpdateMode::on_receive` or `WindowUpdateMode::on_read`.
pub struct WindowUpdateMode(remux::WindowUpdateMode);
182
183impl WindowUpdateMode {
184 pub fn on_receive() -> Self {
197 WindowUpdateMode(remux::WindowUpdateMode::OnReceive)
198 }
199
200 pub fn on_read() -> Self {
215 WindowUpdateMode(remux::WindowUpdateMode::OnRead)
216 }
217}
218
/// A `RemuxConfig` whose upgrades produce a `Remux<LocalIncoming<C>>`, which
/// does not require the I/O type `C` to be `Send`.
#[derive(Clone)]
pub struct RemuxLocalConfig(RemuxConfig);
222
223impl RemuxConfig {
224 pub fn client() -> Self {
227 let mut cfg = Self::default();
228 cfg.mode = Some(remux::Mode::Client);
229 cfg
230 }
231
232 pub fn server() -> Self {
235 let mut cfg = Self::default();
236 cfg.mode = Some(remux::Mode::Server);
237 cfg
238 }
239
240 pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
242 self.inner.set_receive_window(num_bytes);
243 self
244 }
245
246 pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
248 self.inner.set_max_buffer_size(num_bytes);
249 self
250 }
251
252 pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
254 self.inner.set_max_num_streams(num_streams);
255 self
256 }
257
258 pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
261 self.inner.set_window_update_mode(mode.0);
262 self
263 }
264
265 pub fn into_local(self) -> RemuxLocalConfig {
268 RemuxLocalConfig(self)
269 }
270}
271
272impl Default for RemuxConfig {
273 fn default() -> Self {
274 let mut inner = remux::Config::default();
275 inner.set_read_after_close(false);
278 RemuxConfig { inner, mode: None }
279 }
280}
281
282impl UpgradeInfo for RemuxConfig {
283 type Info = &'static [u8];
284 type InfoIter = iter::Once<Self::Info>;
285
286 fn protocol_info(&self) -> Self::InfoIter {
287 iter::once(b"/remux/1.0.0")
288 }
289}
290
291impl UpgradeInfo for RemuxLocalConfig {
292 type Info = &'static [u8];
293 type InfoIter = iter::Once<Self::Info>;
294
295 fn protocol_info(&self) -> Self::InfoIter {
296 iter::once(b"/remux/1.0.0")
297 }
298}
299
300impl<C> InboundUpgrade<C> for RemuxConfig
301where
302 C: AsyncRead + AsyncWrite + Send + Unpin + 'static
303{
304 type Output = Remux<Incoming<C>>;
305 type Error = io::Error;
306 type Future = future::Ready<Result<Self::Output, Self::Error>>;
307
308 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
309 let mode = self.mode.unwrap_or(remux::Mode::Server);
310 future::ready(Ok(Remux::new(io, self.inner, mode)))
311 }
312}
313
314impl<C> InboundUpgrade<C> for RemuxLocalConfig
315where
316 C: AsyncRead + AsyncWrite + Unpin + 'static
317{
318 type Output = Remux<LocalIncoming<C>>;
319 type Error = io::Error;
320 type Future = future::Ready<Result<Self::Output, Self::Error>>;
321
322 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
323 let cfg = self.0;
324 let mode = cfg.mode.unwrap_or(remux::Mode::Server);
325 future::ready(Ok(Remux::local(io, cfg.inner, mode)))
326 }
327}
328
329impl<C> OutboundUpgrade<C> for RemuxConfig
330where
331 C: AsyncRead + AsyncWrite + Send + Unpin + 'static
332{
333 type Output = Remux<Incoming<C>>;
334 type Error = io::Error;
335 type Future = future::Ready<Result<Self::Output, Self::Error>>;
336
337 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
338 let mode = self.mode.unwrap_or(remux::Mode::Client);
339 future::ready(Ok(Remux::new(io, self.inner, mode)))
340 }
341}
342
343impl<C> OutboundUpgrade<C> for RemuxLocalConfig
344where
345 C: AsyncRead + AsyncWrite + Unpin + 'static
346{
347 type Output = Remux<LocalIncoming<C>>;
348 type Error = io::Error;
349 type Future = future::Ready<Result<Self::Output, Self::Error>>;
350
351 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
352 let cfg = self.0;
353 let mode = cfg.mode.unwrap_or(remux::Mode::Client);
354 future::ready(Ok(Remux::local(io, cfg.inner, mode)))
355 }
356}
357
/// Error type of the `Remux` muxer: a thin wrapper around
/// `remux::ConnectionError`, which converts into it via `From` (`#[from]`).
#[derive(Debug, Error)]
#[error("remux error: {0}")]
pub struct RemuxError(#[from] remux::ConnectionError);
362
363impl Into<io::Error> for RemuxError {
364 fn into(self: RemuxError) -> io::Error {
365 match self.0 {
366 remux::ConnectionError::Io(e) => e,
367 e => io::Error::new(io::ErrorKind::Other, e)
368 }
369 }
370}
371
372pub struct Incoming<T> {
374 stream: BoxStream<'static, Result<remux::Stream, RemuxError>>,
375 _marker: std::marker::PhantomData<T>
376}
377
378impl<T> fmt::Debug for Incoming<T> {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 f.write_str("Incoming")
381 }
382}
383
384pub struct LocalIncoming<T> {
386 stream: LocalBoxStream<'static, Result<remux::Stream, RemuxError>>,
387 _marker: std::marker::PhantomData<T>
388}
389
390impl<T> fmt::Debug for LocalIncoming<T> {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 f.write_str("LocalIncoming")
393 }
394}
395
396impl<T> Stream for Incoming<T> {
397 type Item = Result<remux::Stream, RemuxError>;
398
399 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Option<Self::Item>> {
400 self.stream.as_mut().poll_next_unpin(cx)
401 }
402
403 fn size_hint(&self) -> (usize, Option<usize>) {
404 self.stream.size_hint()
405 }
406}
407
408impl<T> Unpin for Incoming<T> {
409}
410
411impl<T> Stream for LocalIncoming<T> {
412 type Item = Result<remux::Stream, RemuxError>;
413
414 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Option<Self::Item>> {
415 self.stream.as_mut().poll_next_unpin(cx)
416 }
417
418 fn size_hint(&self) -> (usize, Option<usize>) {
419 self.stream.size_hint()
420 }
421}
422
423impl<T> Unpin for LocalIncoming<T> {
424}