1#[cfg(test)]
2mod mux_test;
3
4pub mod endpoint;
5pub mod mux_func;
6
7use std::collections::HashMap;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10
11use portable_atomic::AtomicUsize;
12use tokio::sync::{mpsc, Mutex};
13use util::{Buffer, Conn};
14
15use crate::error::Result;
16use crate::mux::endpoint::Endpoint;
17use crate::mux::mux_func::MatchFunc;
18use crate::util::Error;
19
20const MAX_BUFFER_SIZE: usize = 1000 * 1000; pub struct Config {
28 pub conn: Arc<dyn Conn + Send + Sync>,
29 pub buffer_size: usize,
30}
31
/// Multiplexes packets read from a single connection onto registered [`Endpoint`]s.
///
/// `Mux::new` spawns a background task that reads from `next_conn` and hands each packet to
/// an endpoint whose match function accepts it.
///
/// NOTE(review): `#[derive(Clone)]` also clones `closed_ch_tx`, so each clone owns its own
/// sender. Dropping one clone's sender (via `close`) does not by itself signal the read loop
/// while other clones still hold theirs — confirm this is intended.
#[derive(Clone)]
pub struct Mux {
    // Source of endpoint ids; `new_endpoint` takes the next value with `fetch_add(1)`.
    id: Arc<AtomicUsize>,
    // Connection packets are read from; shared with the read loop and every endpoint.
    next_conn: Arc<dyn Conn + Send + Sync>,
    // Registered endpoints keyed by id; shared with the read loop and every endpoint.
    endpoints: Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
    // Read-loop scratch buffer length, copied from `Config::buffer_size`.
    buffer_size: usize,
    // Sender half of the close signal for the read loop; `close` takes (drops) it, leaving `None`.
    closed_ch_tx: Option<mpsc::Sender<()>>,
}
41
42impl Mux {
43 pub fn new(config: Config) -> Self {
44 let (closed_ch_tx, closed_ch_rx) = mpsc::channel(1);
45 let m = Mux {
46 id: Arc::new(AtomicUsize::new(0)),
47 next_conn: Arc::clone(&config.conn),
48 endpoints: Arc::new(Mutex::new(HashMap::new())),
49 buffer_size: config.buffer_size,
50 closed_ch_tx: Some(closed_ch_tx),
51 };
52
53 let buffer_size = m.buffer_size;
54 let next_conn = Arc::clone(&m.next_conn);
55 let endpoints = Arc::clone(&m.endpoints);
56 tokio::spawn(async move {
57 Mux::read_loop(buffer_size, next_conn, closed_ch_rx, endpoints).await;
58 });
59
60 m
61 }
62
63 pub async fn new_endpoint(&self, f: MatchFunc) -> Arc<Endpoint> {
65 let mut endpoints = self.endpoints.lock().await;
66
67 let id = self.id.fetch_add(1, Ordering::SeqCst);
68 let e = Arc::new(Endpoint {
70 id,
71 buffer: Buffer::new(0, MAX_BUFFER_SIZE),
72 match_fn: f,
73 next_conn: Arc::clone(&self.next_conn),
74 endpoints: Arc::clone(&self.endpoints),
75 });
76
77 endpoints.insert(e.id, Arc::clone(&e));
78
79 e
80 }
81
82 pub async fn remove_endpoint(&mut self, e: &Endpoint) {
84 let mut endpoints = self.endpoints.lock().await;
85 endpoints.remove(&e.id);
86 }
87
88 pub async fn close(&mut self) {
90 self.closed_ch_tx.take();
91
92 let mut endpoints = self.endpoints.lock().await;
93 endpoints.clear();
94 }
95
96 async fn read_loop(
97 buffer_size: usize,
98 next_conn: Arc<dyn Conn + Send + Sync>,
99 mut closed_ch_rx: mpsc::Receiver<()>,
100 endpoints: Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
101 ) {
102 let mut buf = vec![0u8; buffer_size];
103 let mut n = 0usize;
104 loop {
105 tokio::select! {
106 _ = closed_ch_rx.recv() => break,
107 result = next_conn.recv(&mut buf) => {
108 if let Ok(m) = result{
109 n = m;
110 }
111 }
112 };
113
114 if let Err(err) = Mux::dispatch(&buf[..n], &endpoints).await {
115 log::error!("mux: ending readLoop dispatch error {:?}", err);
116 break;
117 }
118 }
119 }
120
121 async fn dispatch(
122 buf: &[u8],
123 endpoints: &Arc<Mutex<HashMap<usize, Arc<Endpoint>>>>,
124 ) -> Result<()> {
125 let mut endpoint = None;
126
127 {
128 let eps = endpoints.lock().await;
129 for ep in eps.values() {
130 if (ep.match_fn)(buf) {
131 endpoint = Some(Arc::clone(ep));
132 break;
133 }
134 }
135 }
136
137 if let Some(ep) = endpoint {
138 match ep.buffer.write(buf).await {
139 Err(Error::ErrBufferFull) => {
141 log::info!("mux: endpoint buffer is full, dropping packet")
142 }
143 Ok(_) => (),
144 Err(e) => return Err(crate::Error::Util(e)),
145 }
146 } else if !buf.is_empty() {
147 log::warn!(
148 "Warning: mux: no endpoint for packet starting with {}",
149 buf[0]
150 );
151 } else {
152 log::warn!("Warning: mux: no endpoint for zero length packet");
153 }
154
155 Ok(())
156 }
157}