webrtc_ice/udp_mux/
mod.rs1use std::collections::HashMap;
2use std::io::ErrorKind;
3use std::net::SocketAddr;
4use std::sync::{Arc, Weak};
5
6use async_trait::async_trait;
7use tokio::sync::{watch, Mutex};
8use util::sync::RwLock;
9use util::{Conn, Error};
10
11mod udp_mux_conn;
12pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter};
13
14#[cfg(test)]
15mod udp_mux_test;
16
17mod socket_addr_ext;
18
19use stun::attributes::ATTR_USERNAME;
20use stun::message::{is_message as is_stun_message, Message as STUNMessage};
21
22use crate::candidate::RECEIVE_MTU;
23
/// Adapts `target` so it can be used as a destination on a socket bound to `socket_addr`.
///
/// When `target` is IPv4 and `socket_addr` is IPv6, the target is rewritten to its
/// IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with the same port, so an IPv4 peer can be
/// addressed through the IPv6 socket. Every other combination is returned unchanged.
fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr {
    if let (SocketAddr::V4(v4_target), true) = (target, socket_addr.is_ipv6()) {
        let mapped_ip = std::net::IpAddr::V6(v4_target.ip().to_ipv6_mapped());
        return SocketAddr::new(mapped_ip, v4_target.port());
    }

    // IPv6 target (on any socket) or IPv4 target on an IPv4 socket: nothing to adapt.
    *target
}
38
39#[async_trait]
40pub trait UDPMux {
41 async fn close(&self) -> Result<(), Error>;
43
44 async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error>;
46
47 async fn remove_conn_by_ufrag(&self, ufrag: &str);
49}
50
51pub struct UDPMuxParams {
52 conn: Box<dyn Conn + Send + Sync>,
53}
54
55impl UDPMuxParams {
56 pub fn new<C>(conn: C) -> Self
57 where
58 C: Conn + Send + Sync + 'static,
59 {
60 Self {
61 conn: Box::new(conn),
62 }
63 }
64}
65
66pub struct UDPMuxDefault {
67 params: UDPMuxParams,
70
71 conns: Mutex<HashMap<String, UDPMuxConn>>,
73
74 address_map: RwLock<HashMap<SocketAddr, UDPMuxConn>>,
76
77 closed_watch_tx: Mutex<Option<watch::Sender<()>>>,
79
80 closed_watch_rx: watch::Receiver<()>,
82}
83
84impl UDPMuxDefault {
85 pub fn new(params: UDPMuxParams) -> Arc<Self> {
86 let (closed_watch_tx, closed_watch_rx) = watch::channel(());
87
88 let mux = Arc::new(Self {
89 params,
90 conns: Mutex::default(),
91 address_map: RwLock::default(),
92 closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
93 closed_watch_rx: closed_watch_rx.clone(),
94 });
95
96 let cloned_mux = Arc::clone(&mux);
97 cloned_mux.start_conn_worker(closed_watch_rx);
98
99 mux
100 }
101
102 pub async fn is_closed(&self) -> bool {
103 self.closed_watch_tx.lock().await.is_none()
104 }
105
106 fn create_muxed_conn(self: &Arc<Self>, ufrag: &str) -> Result<UDPMuxConn, Error> {
108 let local_addr = self.params.conn.local_addr()?;
109
110 let params = UDPMuxConnParams {
111 local_addr,
112 key: ufrag.into(),
113 udp_mux: Arc::downgrade(self) as Weak<dyn UDPMuxWriter + Send + Sync>,
114 };
115
116 Ok(UDPMuxConn::new(params))
117 }
118
119 async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option<UDPMuxConn> {
120 let (result, message) = {
121 let mut m = STUNMessage::new();
122
123 (m.unmarshal_binary(buffer), m)
124 };
125
126 match result {
127 Err(err) => {
128 log::warn!("Failed to handle decode ICE from {addr}: {err}");
129 None
130 }
131 Ok(_) => {
132 let (attr, found) = message.attributes.get(ATTR_USERNAME);
133 if !found {
134 log::warn!("No username attribute in STUN message from {}", &addr);
135 return None;
136 }
137
138 let s = match String::from_utf8(attr.value) {
139 Err(err) => {
142 log::warn!("Failed to decode USERNAME from STUN message as UTF-8: {err}");
143 return None;
144 }
145 Ok(s) => s,
146 };
147
148 let conns = self.conns.lock().await;
149 let conn = s
150 .split(':')
151 .next()
152 .and_then(|ufrag| conns.get(ufrag))
153 .cloned();
154
155 conn
156 }
157 }
158 }
159
160 fn start_conn_worker(self: Arc<Self>, mut closed_watch_rx: watch::Receiver<()>) {
161 tokio::spawn(async move {
162 let mut buffer = [0u8; RECEIVE_MTU];
163
164 loop {
165 let loop_self = Arc::clone(&self);
166 let conn = &loop_self.params.conn;
167
168 tokio::select! {
169 res = conn.recv_from(&mut buffer) => {
170 match res {
171 Ok((len, addr)) => {
172 let conn = {
174 let address_map = loop_self
175 .address_map
176 .read();
177
178 address_map.get(&addr).cloned()
179 };
180
181 let conn = match conn {
182 None if is_stun_message(&buffer) => {
185 loop_self.conn_from_stun_message(&buffer, &addr).await
186 }
187 s @ Some(_) => s,
188 _ => None,
189 };
190
191 match conn {
192 None => {
193 log::trace!("Dropping packet from {}", &addr);
194 }
195 Some(conn) => {
196 if let Err(err) = conn.write_packet(&buffer[..len], addr).await {
197 log::error!("Failed to write packet: {err}");
198 }
199 }
200 }
201 }
202 Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue,
203 Err(err) => {
204 log::error!("Could not read udp packet: {err}");
205 break;
206 }
207 }
208 }
209 _ = closed_watch_rx.changed() => {
210 return;
211 }
212 }
213 }
214 });
215 }
216}
217
218#[async_trait]
219impl UDPMux for UDPMuxDefault {
220 async fn close(&self) -> Result<(), Error> {
221 if self.is_closed().await {
222 return Err(Error::ErrAlreadyClosed);
223 }
224
225 let mut closed_tx = self.closed_watch_tx.lock().await;
226
227 if let Some(tx) = closed_tx.take() {
228 let _ = tx.send(());
229 drop(closed_tx);
230
231 let old_conns = {
232 let mut conns = self.conns.lock().await;
233
234 std::mem::take(&mut (*conns))
235 };
236
237 for (_, conn) in old_conns {
239 conn.close();
240 }
241
242 {
243 let mut address_map = self.address_map.write();
244
245 let _ = std::mem::take(&mut (*address_map));
248 }
249 }
250
251 Ok(())
252 }
253
254 async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
255 if self.is_closed().await {
256 return Err(Error::ErrUseClosedNetworkConn);
257 }
258
259 {
260 let mut conns = self.conns.lock().await;
261 if let Some(conn) = conns.get(ufrag) {
262 return Ok(Arc::new(conn.clone()) as Arc<dyn Conn + Send + Sync>);
265 }
266
267 let muxed_conn = self.create_muxed_conn(ufrag)?;
268 let mut close_rx = muxed_conn.close_rx();
269 let cloned_self = Arc::clone(&self);
270 let cloned_ufrag = ufrag.to_string();
271 tokio::spawn(async move {
272 let _ = close_rx.changed().await;
273
274 cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await;
276 });
277
278 conns.insert(ufrag.into(), muxed_conn.clone());
279
280 Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>)
281 }
282 }
283
284 async fn remove_conn_by_ufrag(&self, ufrag: &str) {
285 let removed_conn = {
289 let mut conns = self.conns.lock().await;
290 conns.remove(ufrag)
291 };
292
293 if let Some(conn) = removed_conn {
294 let mut address_map = self.address_map.write();
295
296 for address in conn.get_addresses() {
297 address_map.remove(&address);
298 }
299 }
300 }
301}
302
303#[async_trait]
304impl UDPMuxWriter for UDPMuxDefault {
305 async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
306 if self.is_closed().await {
307 return;
308 }
309
310 let key = conn.key();
311 {
312 let mut addresses = self.address_map.write();
313
314 addresses
315 .entry(addr)
316 .and_modify(|e| {
317 if e.key() != key {
318 e.remove_address(&addr);
319 *e = conn.clone();
320 }
321 })
322 .or_insert_with(|| conn.clone());
323 }
324
325 log::debug!("Registered {addr} for {key}");
326 }
327
328 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
329 self.params.conn.send_to(buf, *target).await
330 }
331}