webrtc/mux/
mod.rs

1#[cfg(test)]
2mod mux_test;
3
4pub mod endpoint;
5pub mod mux_func;
6
7use std::collections::HashMap;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10
11use portable_atomic::AtomicUsize;
12use tokio::sync::{mpsc, Mutex};
13use util::{Buffer, Conn};
14
15use crate::error::Result;
16use crate::mux::endpoint::Endpoint;
17use crate::mux::mux_func::MatchFunc;
18use crate::util::Error;
19
// mux multiplexes packets on a single socket (RFC 7983).

/// Maximum number of bytes a single endpoint's buffer may hold: 1 MB (10^6 bytes).
///
/// Used as the size limit of each endpoint's `Buffer` when the endpoint is
/// created. Once an endpoint holds this much unread data, further writes are
/// expected to fail with `ErrBufferFull`; the dispatcher logs that case and
/// drops the packet rather than treating it as fatal.
const MAX_BUFFER_SIZE: usize = 1_000_000;
24
25/// Config collects the arguments to mux.Mux construction into
26/// a single structure
27pub struct Config {
28    pub conn: Arc<dyn Conn + Send + Sync>,
29    pub buffer_size: usize,
30}
31
32/// Mux allows multiplexing
33#[derive(Clone)]
34pub struct Mux {
35    id: Arc<AtomicUsize>,
36    next_conn: Arc<dyn Conn + Send + Sync>,
37    endpoints: Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
38    buffer_size: usize,
39    closed_ch_tx: Option<mpsc::Sender<()>>,
40}
41
42impl Mux {
43    pub fn new(config: Config) -> Self {
44        let (closed_ch_tx, closed_ch_rx) = mpsc::channel(1);
45        let m = Mux {
46            id: Arc::new(AtomicUsize::new(0)),
47            next_conn: Arc::clone(&config.conn),
48            endpoints: Arc::new(Mutex::new(HashMap::new())),
49            buffer_size: config.buffer_size,
50            closed_ch_tx: Some(closed_ch_tx),
51        };
52
53        let buffer_size = m.buffer_size;
54        let next_conn = Arc::clone(&m.next_conn);
55        let endpoints = Arc::clone(&m.endpoints);
56        tokio::spawn(async move {
57            Mux::read_loop(buffer_size, next_conn, closed_ch_rx, endpoints).await;
58        });
59
60        m
61    }
62
63    /// creates a new Endpoint
64    pub async fn new_endpoint(&self, f: MatchFunc) -> Arc<Endpoint> {
65        let mut endpoints = self.endpoints.lock().await;
66
67        let id = self.id.fetch_add(1, Ordering::SeqCst);
68        // Set a maximum size of the buffer in bytes.
69        let e = Arc::new(Endpoint {
70            id,
71            buffer: Buffer::new(0, MAX_BUFFER_SIZE),
72            match_fn: f,
73            next_conn: Arc::clone(&self.next_conn),
74            endpoints: Arc::clone(&self.endpoints),
75        });
76
77        endpoints.insert(e.id, Arc::clone(&e));
78
79        e
80    }
81
82    /// remove_endpoint removes an endpoint from the Mux
83    pub async fn remove_endpoint(&mut self, e: &Endpoint) {
84        let mut endpoints = self.endpoints.lock().await;
85        endpoints.remove(&e.id);
86    }
87
88    /// Close closes the Mux and all associated Endpoints.
89    pub async fn close(&mut self) {
90        self.closed_ch_tx.take();
91
92        let mut endpoints = self.endpoints.lock().await;
93        endpoints.clear();
94    }
95
96    async fn read_loop(
97        buffer_size: usize,
98        next_conn: Arc<dyn Conn + Send + Sync>,
99        mut closed_ch_rx: mpsc::Receiver<()>,
100        endpoints: Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
101    ) {
102        let mut buf = vec![0u8; buffer_size];
103        let mut n = 0usize;
104        loop {
105            tokio::select! {
106                _ = closed_ch_rx.recv() => break,
107                result = next_conn.recv(&mut buf) => {
108                    if let Ok(m) = result{
109                        n = m;
110                    }
111                }
112            };
113
114            if let Err(err) = Mux::dispatch(&buf[..n], &endpoints).await {
115                log::error!("mux: ending readLoop dispatch error {:?}", err);
116                break;
117            }
118        }
119    }
120
121    async fn dispatch(
122        buf: &[u8],
123        endpoints: &Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
124    ) -> Result<()> {
125        let mut endpoint = None;
126
127        {
128            let eps = endpoints.lock().await;
129            for ep in eps.values() {
130                if (ep.match_fn)(buf) {
131                    endpoint = Some(Arc::clone(ep));
132                    break;
133                }
134            }
135        }
136
137        if let Some(ep) = endpoint {
138            match ep.buffer.write(buf).await {
139                // Expected when bytes are received faster than the endpoint can process them
140                Err(Error::ErrBufferFull) => {
141                    log::info!("mux: endpoint buffer is full, dropping packet")
142                }
143                Ok(_) => (),
144                Err(e) => return Err(crate::Error::Util(e)),
145            }
146        } else if !buf.is_empty() {
147            log::warn!(
148                "Warning: mux: no endpoint for packet starting with {}",
149                buf[0]
150            );
151        } else {
152            log::warn!("Warning: mux: no endpoint for zero length packet");
153        }
154
155        Ok(())
156    }
157}