penguin_mux/lib.rs
1//! Multiplexing streamed data and datagrams over a single WebSocket
2//! connection.
3//!
4//! This is not a general-purpose WebSocket multiplexing library.
5//! It is tailored to the needs of `penguin`.
6//
7// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
8#![deny(missing_docs, missing_debug_implementations)]
9
10pub mod config;
11mod dupe;
12pub mod frame;
13mod proto_version;
14mod stream;
15mod task;
16#[cfg(test)]
17mod tests;
18pub mod timing;
19pub mod ws;
20
21use crate::frame::{BindPayload, BindType, FinalizedFrame, Frame};
22use crate::task::{Task, TaskData};
23use crate::ws::WebSocket;
24use bytes::Bytes;
25use futures_util::task::AtomicWaker;
26use parking_lot::{Mutex, RwLock};
27use rand::distr::uniform::SampleUniform;
28use std::collections::HashMap;
29use std::future::poll_fn;
30use std::hash::Hash;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
33use thiserror::Error;
34use tokio::sync::mpsc::error::TrySendError;
35use tokio::sync::{mpsc, oneshot};
36use tokio::task::JoinSet;
37use tracing::{error, trace, warn};
38
39pub use crate::dupe::Dupe;
40pub use crate::proto_version::{PROTOCOL_VERSION, PROTOCOL_VERSION_NUMBER};
41pub use crate::stream::MuxStream;
42
43/// Multiplexor error
44#[derive(Debug, Error)]
45#[non_exhaustive]
46pub enum Error {
47 /// Requester exited before receiving the stream
48 /// (i.e. the `Receiver` was dropped before the task could send the stream).
49 #[error("Requester exited before receiving the stream")]
50 SendStreamToClient,
51 /// The multiplexor is closed.
52 #[error("Mux is already closed")]
53 Closed,
54 /// The peer does not support the requested operation.
55 #[error("Peer does not support requested operation")]
56 PeerUnsupportedOperation,
57 /// This `Multiplexor` is not configured for this operation.
58 #[error("Unsupported operation")]
59 UnsupportedOperation,
60 /// Peer rejected the flow ID selection.
61 #[error("Peer rejected flow ID selection")]
62 FlowIdRejected,
63
64 /// WebSocket errors
65 #[error("WebSocket Error: {0}")]
66 WebSocket(Box<dyn std::error::Error + Send>),
67
68 // These are the ones that shouldn't normally happen
69 /// A `Datagram` frame with a target host longer than 255 octets.
70 #[error("Datagram target host longer than 255 octets")]
71 DatagramHostTooLong,
72 /// Received an invalid frame.
73 #[error("Invalid frame: {0}")]
74 InvalidFrame(#[from] frame::Error),
75 /// The peer sent a `Text` message.
76 /// "The client and server MUST NOT use other WebSocket data frame types"
77 #[error("Received `Text` message")]
78 TextMessage,
79 /// A `Acknowledge` frame that does not match any pending [`Connect`](frame::OpCode::Connect) request.
80 #[error("Bogus `Acknowledge` frame")]
81 ConnAckGone,
82 /// An internal channel closed
83 #[error("Internal channel `{0}` closed")]
84 ChannelClosed(&'static str),
85}
86
87/// A variant of [`std::result::Result`] with [`enum@Error`] as the error type.
88pub type Result<T> = std::result::Result<T, Error>;
89
90/// A multiplexor over a `WebSocket` connection.
91#[derive(Debug)]
92pub struct Multiplexor {
93 /// Open stream channels: `flow_id` -> `FlowSlot`
94 flows: Arc<RwLock<HashMap<u32, FlowSlot>>>,
95 /// Where tasks queue frames to be sent
96 tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
97 /// We only use this to inform the task that the multiplexor is closed
98 /// and it should stop processing.
99 dropped_ports_tx: mpsc::UnboundedSender<u32>,
100 /// Channel of received datagram frames for processing.
101 datagram_rx: Mutex<mpsc::Receiver<Datagram>>,
102 /// Channel for a `Multiplexor` to receive newly
103 /// established streams after the peer requests one.
104 con_recv_stream_rx: Mutex<mpsc::Receiver<MuxStream>>,
105 /// Channel for `Bnd` requests.
106 bnd_request_rx: Option<Mutex<mpsc::Receiver<BindRequest<'static>>>>,
107 /// Number of retries to find a suitable flow ID
108 /// See [`config::Options`] for more details.
109 max_flow_id_retries: usize,
110 /// Number of `StreamFrame`s to buffer in `MuxStream`'s channels before blocking
111 /// See [`config::Options`] for more details.
112 rwnd: u32,
113}
114
115impl Multiplexor {
116 /// Create a new `Multiplexor`.
117 ///
118 /// # Arguments
119 ///
120 /// * `ws`: The `WebSocket` connection to multiplex over.
121 ///
122 /// * `options`: Multiplexor options. See [`config::Options`] for more details.
123 /// If `None`, the default options will be used.
124 ///
125 /// * `task_joinset`: A [`JoinSet`] to spawn the multiplexor task into so
126 /// that the caller can notice if the task exits. If it is `None`, the
127 /// task will be spawned by `tokio::spawn` and errors will be logged.
128 #[tracing::instrument(skip_all, level = "debug")]
129 pub fn new<S: WebSocket>(
130 ws: S,
131 options: Option<config::Options>,
132 task_joinset: Option<&mut JoinSet<Result<()>>>,
133 ) -> Self {
134 let options = options.unwrap_or_default();
135 let (datagram_tx, datagram_rx) = mpsc::channel(options.datagram_buffer_size);
136 let (con_recv_stream_tx, con_recv_stream_rx) = mpsc::channel(options.stream_buffer_size);
137 // This one is unbounded because the protocol provides its own flow control for `Push` frames
138 // and other frame types are to be immediately processed without any backpressure,
139 // so they are ok to be unbounded channels.
140 let (tx_frame_tx, tx_frame_rx) = mpsc::unbounded_channel();
141 // This one cannot be bounded because it needs to be used in Drop
142 let (dropped_ports_tx, dropped_ports_rx) = mpsc::unbounded_channel();
143
144 let (bnd_request_tx, bnd_request_rx) = if options.bind_buffer_size > 0 {
145 let (tx, rx) = mpsc::channel(options.bind_buffer_size);
146 (Some(tx), Some(rx))
147 } else {
148 (None, None)
149 };
150 let flows = Arc::new(RwLock::new(HashMap::new()));
151
152 let mux = Self {
153 tx_frame_tx: tx_frame_tx.dupe(),
154 flows: flows.dupe(),
155 dropped_ports_tx: dropped_ports_tx.dupe(),
156 datagram_rx: Mutex::new(datagram_rx),
157 con_recv_stream_rx: Mutex::new(con_recv_stream_rx),
158 bnd_request_rx: bnd_request_rx.map(Mutex::new),
159 max_flow_id_retries: options.max_flow_id_retries,
160 rwnd: options.rwnd,
161 };
162 let taskdata = TaskData {
163 task: Task {
164 ws: Mutex::new(ws),
165 tx_frame_tx,
166 flows,
167 dropped_ports_tx,
168 con_recv_stream_tx,
169 default_rwnd_threshold: options.default_rwnd_threshold,
170 rwnd: options.rwnd,
171 datagram_tx,
172 bnd_request_tx,
173 keepalive_interval: options.keepalive_interval,
174 },
175 dropped_ports_rx,
176 tx_frame_rx,
177 };
178 taskdata.spawn(task_joinset);
179 mux
180 }
181
182 /// Request a channel for `host` and `port`.
183 ///
184 /// # Arguments
185 /// * `host`: The host to forward to. While the current implementation
186 /// supports a domain of arbitrary length, Section 3.2.2 of
187 /// [RFC 3986](https://www.rfc-editor.org/rfc/rfc3986#section-3.2.2)
188 /// specifies that the host component of a URI is limited to 255 octets.
189 /// * `port`: The port to forward to.
190 ///
191 /// # Cancel safety
192 /// This function is not cancel safe. If the task is cancelled while waiting
193 /// for the channel to be established, that channel may be established but
194 /// inaccessible through normal means. Subsequent calls to this function
195 /// will result in a new channel being established.
196 #[tracing::instrument(skip(self), level = "debug")]
197 pub async fn new_stream_channel(&self, host: &[u8], port: u16) -> Result<MuxStream> {
198 let mut retries_left = self.max_flow_id_retries;
199 // Normally this should terminate in one loop
200 while retries_left > 0 {
201 retries_left -= 1;
202 let (stream_tx, stream_rx) = oneshot::channel();
203 let flow_id = {
204 let mut streams = self.flows.write();
205 // Allocate a new port
206 let flow_id = u32::next_available_key(&*streams);
207 trace!("flow_id = {flow_id:08x}");
208 streams.insert(flow_id, FlowSlot::Requested(stream_tx));
209 flow_id
210 };
211 trace!("sending `Connect`");
212 self.tx_frame_tx
213 .send(Frame::new_connect(host, port, flow_id, self.rwnd).finalize())
214 .map_err(|_| Error::Closed)?;
215 trace!("sending stream to user");
216 let stream = stream_rx
217 .await
218 // Happens if the task exits before sending the stream,
219 // thus `Closed` is the correct error
220 .map_err(|_| Error::Closed)?;
221 if let Some(s) = stream {
222 return Ok(s);
223 }
224 // For testing purposes. Make sure the previous flow ID is gone
225 debug_assert!(!self.flows.read().contains_key(&flow_id));
226 }
227 Err(Error::FlowIdRejected)
228 }
229
230 /// Accept a new stream channel from the remote peer.
231 ///
232 /// # Errors
233 /// Returns [`Error::Closed`] if the connection is closed.
234 ///
235 /// # Cancel Safety
236 /// This function is cancel safe. If the task is cancelled while waiting
237 /// for a new connection, it is guaranteed that no connected stream will
238 /// be lost.
239 #[tracing::instrument(skip(self), level = "debug")]
240 pub async fn accept_stream_channel(&self) -> Result<MuxStream> {
241 poll_fn(|cx| self.con_recv_stream_rx.lock().poll_recv(cx))
242 .await
243 .ok_or(Error::Closed)
244 }
245
246 /// Get the next available datagram.
247 ///
248 /// # Errors
249 /// Returns [`Error::Closed`] if the connection is closed.
250 ///
251 /// # Cancel Safety
252 /// This function is cancel safe. If the task is cancelled while waiting
253 /// for a datagram, it is guaranteed that no datagram will be lost.
254 #[tracing::instrument(skip(self), level = "debug")]
255 #[inline]
256 pub async fn get_datagram(&self) -> Result<Datagram> {
257 poll_fn(|cx| self.datagram_rx.lock().poll_recv(cx))
258 .await
259 .ok_or(Error::Closed)
260 }
261
262 /// Send a datagram
263 ///
264 /// # Errors
265 /// * Returns [`Error::DatagramHostTooLong`] if the destination host is
266 /// longer than 255 octets.
267 /// * Returns [`Error::Closed`] if the Multiplexor is already closed.
268 ///
269 /// # Cancel Safety
270 /// This function is cancel safe. If the task is cancelled, it is
271 /// guaranteed that the datagram has not been sent.
272 #[tracing::instrument(skip(self), level = "debug")]
273 #[inline]
274 pub async fn send_datagram(&self, datagram: Datagram) -> Result<()> {
275 if datagram.target_host.len() > 255 {
276 return Err(Error::DatagramHostTooLong);
277 }
278 let frame = Frame::new_datagram_owned(
279 datagram.flow_id,
280 datagram.target_host,
281 datagram.target_port,
282 datagram.data,
283 );
284 self.tx_frame_tx
285 .send(frame.finalize())
286 .map_err(|_| Error::Closed)?;
287 Ok(())
288 }
289
290 /// Request a `Bind` for `host` and `port`.
291 ///
292 /// # Arguments
293 /// * `host`: The local address or host to bind to. Hostname resolution might
294 /// not be supported by the remote peer.
295 /// * `port`: The local port to bind to.
296 ///
297 /// # Cancel Safety
298 /// This function is not cancel safe. If the task is cancelled while waiting
299 /// for the peer to reply, the user will not be able to receive whether the
300 /// peer accepted the bind request.
301 #[tracing::instrument(skip(self), level = "debug")]
302 #[inline]
303 pub async fn request_bind(&self, host: &[u8], port: u16, bind_type: BindType) -> Result<bool> {
304 let (result_tx, result_rx) = oneshot::channel();
305 let flow_id = {
306 let mut streams = self.flows.write();
307 // Allocate a new port
308 let flow_id = u32::next_available_key(&*streams);
309 trace!("flow_id = {flow_id:08x}");
310 streams.insert(flow_id, FlowSlot::BindRequested(result_tx));
311 flow_id
312 };
313 let bnd_frame = Frame::new_bind(flow_id, bind_type, host, port).finalize();
314 self.tx_frame_tx
315 .send(bnd_frame)
316 .map_err(|_| Error::Closed)?;
317 let result = result_rx.await.map_err(|_| Error::Closed)?;
318 Ok(result)
319 }
320
321 /// Accept a `Bind` request from the remote peer.
322 ///
323 /// # Cancel Safety
324 /// This function is cancel safe. If the task is cancelled while waiting
325 /// for a `Bind` request, it is guaranteed that no request will be lost.
326 #[tracing::instrument(skip(self), level = "debug")]
327 pub async fn next_bind_request(&self) -> Result<BindRequest<'static>> {
328 if let Some(rx) = self.bnd_request_rx.as_ref() {
329 poll_fn(|cx| rx.lock().poll_recv(cx))
330 .await
331 .ok_or(Error::Closed)
332 } else {
333 Err(Error::UnsupportedOperation)
334 }
335 }
336}
337
338impl Drop for Multiplexor {
339 fn drop(&mut self) {
340 if self.dropped_ports_tx.send(0).is_err() {
341 error!("Failed to inform task of dropped multiplexor");
342 }
343 }
344}
345
346#[derive(Debug)]
347struct EstablishedStreamData {
348 /// Channel for sending data to `MuxStream`'s `AsyncRead`
349 /// If `None`, we have received `Finish` from the peer but we can possibly still send data.
350 sender: Option<mpsc::Sender<Bytes>>,
351 /// Whether writes should succeed.
352 /// There are two cases for `true`:
353 /// 1. `Finish` has been sent.
354 /// 2. The stream has been removed from `inner.streams`.
355 // In general, our `Atomic*` types don't need more than `Relaxed` ordering
356 // because we are not protecting memory accesses, but rather counting the
357 // frames we have sent and received.
358 finish_sent: Arc<AtomicBool>,
359 /// Number of `Push` frames we are allowed to send before waiting for a `Acknowledge` frame.
360 psh_send_remaining: Arc<AtomicU32>,
361 /// Waker to wake up the task that sends frames because their `psh_send_remaining`
362 /// has increased.
363 writer_waker: Arc<AtomicWaker>,
364}
365
366impl EstablishedStreamData {
367 /// Process a `Finish` frame from the peer and thus disallowing further `AsyncRead` operations
368 /// Returns the sender if it was not already taken.
369 #[inline]
370 const fn disallow_read(&mut self) -> Option<mpsc::Sender<Bytes>> {
371 self.sender.take()
372 }
373
374 /// Process a `Acknowledge` frame from the peer
375 #[inline]
376 fn acknowledge(&self, acknowledged: u32) {
377 // Atomic ordering: as long as the value is incremented atomically,
378 // whether a writer sees the new value or the old value is not
379 // important. If it sees the old value and decides to return
380 // `Poll::Pending`, it will be woken up by the `Waker` anyway.
381 self.psh_send_remaining
382 .fetch_add(acknowledged, Ordering::Relaxed);
383 // Wake up the writer if it is waiting for `Acknowledge`
384 self.writer_waker.wake();
385 }
386
387 /// Disallow any `AsyncWrite` operations.
388 /// Note that this should not be used from inside the `MuxStream` itself
389 #[inline]
390 fn disallow_write(&self) -> bool {
391 // Atomic ordering:
392 // Load part:
393 // If the user calls `poll_shutdown`, but we see `true` here,
394 // the other end will receive a bogus `Reset` frame, which is fine.
395 // Store part:
396 // We need to make sure the writer can see the new value
397 // before we call `wake()`.
398 let old = self.finish_sent.swap(true, Ordering::AcqRel);
399 // If there is a writer waiting for `Acknowledge`, wake it up because it will never receive one.
400 // Waking it here and the user should receive a `BrokenPipe` error.
401 self.writer_waker.wake();
402 old
403 }
404}
405
406#[derive(Debug)]
407enum FlowSlot {
408 /// A `Connect` frame was sent and waiting for the peer to `Acknowledge`.
409 Requested(oneshot::Sender<Option<MuxStream>>),
410 /// The stream is established.
411 Established(EstablishedStreamData),
412 /// A `Bind` request was sent and waiting for the peer to `Acknowledge` or `Reset`.
413 BindRequested(oneshot::Sender<bool>),
414}
415
416impl FlowSlot {
417 /// Take the sender and set the slot to `Established`.
418 /// Returns `None` if the slot is already established.
419 #[inline]
420 fn establish(
421 &mut self,
422 data: EstablishedStreamData,
423 ) -> Option<oneshot::Sender<Option<MuxStream>>> {
424 // Make sure it is not replaced in the error case
425 if matches!(self, Self::Established(_) | Self::BindRequested(_)) {
426 error!("establishing an established or invalid slot");
427 return None;
428 }
429 let sender = match std::mem::replace(self, Self::Established(data)) {
430 Self::Requested(sender) => sender,
431 Self::Established(_) | Self::BindRequested(_) => unreachable!(),
432 };
433 Some(sender)
434 }
435
436 /// If the slot is established, send data. Otherwise, return `None`.
437 #[inline]
438 fn dispatch(&self, data: Bytes) -> Option<std::result::Result<(), TrySendError<()>>> {
439 if let Self::Established(stream_data) = self {
440 let r = stream_data
441 .sender
442 .as_ref()
443 .map(|sender| sender.try_send(data))?
444 .map_err(|e| match e {
445 TrySendError::Full(_) => TrySendError::Full(()),
446 TrySendError::Closed(_) => TrySendError::Closed(()),
447 });
448 Some(r)
449 } else {
450 None
451 }
452 }
453}
454
455/// Datagram frame data
456#[derive(Clone, Debug)]
457pub struct Datagram {
458 /// Flow ID
459 pub flow_id: u32,
460 /// Target host
461 pub target_host: Bytes,
462 /// Target port
463 pub target_port: u16,
464 /// Data
465 pub data: Bytes,
466}
467
468/// A `Bind` request that the user can respond to
469#[derive(Debug)]
470pub struct BindRequest<'data> {
471 /// Flow ID
472 flow_id: u32,
473 /// Bind payload
474 payload: BindPayload<'data>,
475 /// Place to respond to the bind request
476 tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
477}
478
479impl BindRequest<'_> {
480 /// Get the flow ID of the bind request
481 #[inline]
482 pub const fn flow_id(&self) -> u32 {
483 self.flow_id
484 }
485
486 /// Get the bind type of the bind request
487 #[inline]
488 pub const fn bind_type(&self) -> BindType {
489 self.payload.bind_type
490 }
491
492 /// Get the host of the bind request
493 #[inline]
494 pub fn host(&self) -> &[u8] {
495 self.payload.target_host.as_ref()
496 }
497
498 /// Get the port of the bind request
499 #[inline]
500 pub const fn port(&self) -> u16 {
501 self.payload.target_port
502 }
503
504 /// Accept or reject the bind request
505 #[tracing::instrument(skip(self), level = "debug")]
506 pub fn reply(&self, accepted: bool) -> Result<()> {
507 if accepted {
508 self.tx_frame_tx
509 .send(Frame::new_finish(self.flow_id).finalize())
510 } else {
511 self.tx_frame_tx
512 .send(Frame::new_reset(self.flow_id).finalize())
513 }
514 .map_err(|_| Error::Closed)
515 }
516}
517
518impl Drop for BindRequest<'_> {
519 /// Dropping a `BindRequest` will reject the request
520 fn drop(&mut self) {
521 self.reply(false).ok();
522 }
523}
524
525/// Randomly generate a new number
526pub trait IntKey: Eq + Hash + Copy + SampleUniform + PartialOrd {
527 /// The minimum value of the key
528 const MIN: Self;
529 /// The maximum value of the key
530 const MAX: Self;
531
532 /// Generate a new key that is not in the map
533 #[inline]
534 #[must_use]
535 fn next_available_key<V>(map: &HashMap<Self, V>) -> Self {
536 loop {
537 let i = rand::random_range(Self::MIN..Self::MAX);
538 if !map.contains_key(&i) {
539 break i;
540 }
541 }
542 }
543}
544
545macro_rules! impl_int_key {
546 ($($t:ty),*) => {
547 $(
548 impl IntKey for $t {
549 // 0 is for special use
550 const MIN : Self = 1;
551 const MAX : Self = Self::MAX;
552 }
553 )*
554 };
555}
556
557impl_int_key!(u8, u16, u32, u64, u128, usize);