webrtc_ice/udp_mux/
mod.rs

1use std::collections::HashMap;
2use std::io::ErrorKind;
3use std::net::SocketAddr;
4use std::sync::{Arc, Weak};
5
6use async_trait::async_trait;
7use tokio::sync::{watch, Mutex};
8use util::sync::RwLock;
9use util::{Conn, Error};
10
11mod udp_mux_conn;
12pub use udp_mux_conn::{UDPMuxConn, UDPMuxConnParams, UDPMuxWriter};
13
14#[cfg(test)]
15mod udp_mux_test;
16
17mod socket_addr_ext;
18
19use stun::attributes::ATTR_USERNAME;
20use stun::message::{is_message as is_stun_message, Message as STUNMessage};
21
22use crate::candidate::RECEIVE_MTU;
23
/// Express `target` in the address family of the local socket `socket_addr`.
///
/// A dual-stack (IPv6) socket cannot be handed a plain IPv4 destination. So when `socket_addr` is
/// IPv6 and `target` is IPv4, the result is the IPv4-mapped IPv6 form of `target`
/// (`::ffff:a.b.c.d`) with the same port.
///
/// Every other combination returns `target` unchanged. An IPv6 target paired with an IPv4
/// socket is deliberately not handled here; it will fail later, when sending.
fn normalize_socket_addr(target: &SocketAddr, socket_addr: &SocketAddr) -> SocketAddr {
    match target {
        SocketAddr::V4(v4_target) if socket_addr.is_ipv6() => {
            let mapped_ip = std::net::IpAddr::V6(v4_target.ip().to_ipv6_mapped());
            SocketAddr::new(mapped_ip, v4_target.port())
        }
        _ => *target,
    }
}
38
39#[async_trait]
40pub trait UDPMux {
41    /// Close the muxing.
42    async fn close(&self) -> Result<(), Error>;
43
44    /// Get the underlying connection for a given ufrag.
45    async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error>;
46
47    /// Remove the underlying connection for a given ufrag.
48    async fn remove_conn_by_ufrag(&self, ufrag: &str);
49}
50
51pub struct UDPMuxParams {
52    conn: Box<dyn Conn + Send + Sync>,
53}
54
55impl UDPMuxParams {
56    pub fn new<C>(conn: C) -> Self
57    where
58        C: Conn + Send + Sync + 'static,
59    {
60        Self {
61            conn: Box::new(conn),
62        }
63    }
64}
65
66pub struct UDPMuxDefault {
67    /// The params this instance is configured with.
68    /// Contains the underlying UDP socket in use
69    params: UDPMuxParams,
70
71    /// Maps from ufrag to the underlying connection.
72    conns: Mutex<HashMap<String, UDPMuxConn>>,
73
74    /// Maps from ip address to the underlying connection.
75    address_map: RwLock<HashMap<SocketAddr, UDPMuxConn>>,
76
77    // Close sender
78    closed_watch_tx: Mutex<Option<watch::Sender<()>>>,
79
80    /// Close receiver
81    closed_watch_rx: watch::Receiver<()>,
82}
83
84impl UDPMuxDefault {
85    pub fn new(params: UDPMuxParams) -> Arc<Self> {
86        let (closed_watch_tx, closed_watch_rx) = watch::channel(());
87
88        let mux = Arc::new(Self {
89            params,
90            conns: Mutex::default(),
91            address_map: RwLock::default(),
92            closed_watch_tx: Mutex::new(Some(closed_watch_tx)),
93            closed_watch_rx: closed_watch_rx.clone(),
94        });
95
96        let cloned_mux = Arc::clone(&mux);
97        cloned_mux.start_conn_worker(closed_watch_rx);
98
99        mux
100    }
101
102    pub async fn is_closed(&self) -> bool {
103        self.closed_watch_tx.lock().await.is_none()
104    }
105
106    /// Create a muxed connection for a given ufrag.
107    fn create_muxed_conn(self: &Arc<Self>, ufrag: &str) -> Result<UDPMuxConn, Error> {
108        let local_addr = self.params.conn.local_addr()?;
109
110        let params = UDPMuxConnParams {
111            local_addr,
112            key: ufrag.into(),
113            udp_mux: Arc::downgrade(self) as Weak<dyn UDPMuxWriter + Send + Sync>,
114        };
115
116        Ok(UDPMuxConn::new(params))
117    }
118
119    async fn conn_from_stun_message(&self, buffer: &[u8], addr: &SocketAddr) -> Option<UDPMuxConn> {
120        let (result, message) = {
121            let mut m = STUNMessage::new();
122
123            (m.unmarshal_binary(buffer), m)
124        };
125
126        match result {
127            Err(err) => {
128                log::warn!("Failed to handle decode ICE from {addr}: {err}");
129                None
130            }
131            Ok(_) => {
132                let (attr, found) = message.attributes.get(ATTR_USERNAME);
133                if !found {
134                    log::warn!("No username attribute in STUN message from {}", &addr);
135                    return None;
136                }
137
138                let s = match String::from_utf8(attr.value) {
139                    // Per the RFC this shouldn't happen
140                    // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3
141                    Err(err) => {
142                        log::warn!("Failed to decode USERNAME from STUN message as UTF-8: {err}");
143                        return None;
144                    }
145                    Ok(s) => s,
146                };
147
148                let conns = self.conns.lock().await;
149                let conn = s
150                    .split(':')
151                    .next()
152                    .and_then(|ufrag| conns.get(ufrag))
153                    .cloned();
154
155                conn
156            }
157        }
158    }
159
160    fn start_conn_worker(self: Arc<Self>, mut closed_watch_rx: watch::Receiver<()>) {
161        tokio::spawn(async move {
162            let mut buffer = [0u8; RECEIVE_MTU];
163
164            loop {
165                let loop_self = Arc::clone(&self);
166                let conn = &loop_self.params.conn;
167
168                tokio::select! {
169                    res = conn.recv_from(&mut buffer) => {
170                        match res {
171                            Ok((len, addr)) => {
172                                // Find connection based on previously having seen this source address
173                                let conn = {
174                                    let address_map = loop_self
175                                        .address_map
176                                        .read();
177
178                                    address_map.get(&addr).cloned()
179                                };
180
181                                let conn = match conn {
182                                    // If we couldn't find the connection based on source address, see if
183                                    // this is a STUN message and if so if we can find the connection based on ufrag.
184                                    None if is_stun_message(&buffer) => {
185                                        loop_self.conn_from_stun_message(&buffer, &addr).await
186                                    }
187                                    s @ Some(_) => s,
188                                    _ => None,
189                                };
190
191                                match conn {
192                                    None => {
193                                        log::trace!("Dropping packet from {}", &addr);
194                                    }
195                                    Some(conn) => {
196                                        if let Err(err) = conn.write_packet(&buffer[..len], addr).await {
197                                            log::error!("Failed to write packet: {err}");
198                                        }
199                                    }
200                                }
201                            }
202                            Err(Error::Io(err)) if err.0.kind() == ErrorKind::TimedOut => continue,
203                            Err(err) => {
204                                log::error!("Could not read udp packet: {err}");
205                                break;
206                            }
207                        }
208                    }
209                    _ = closed_watch_rx.changed() => {
210                        return;
211                    }
212                }
213            }
214        });
215    }
216}
217
218#[async_trait]
219impl UDPMux for UDPMuxDefault {
220    async fn close(&self) -> Result<(), Error> {
221        if self.is_closed().await {
222            return Err(Error::ErrAlreadyClosed);
223        }
224
225        let mut closed_tx = self.closed_watch_tx.lock().await;
226
227        if let Some(tx) = closed_tx.take() {
228            let _ = tx.send(());
229            drop(closed_tx);
230
231            let old_conns = {
232                let mut conns = self.conns.lock().await;
233
234                std::mem::take(&mut (*conns))
235            };
236
237            // NOTE: We don't wait for these closure to complete
238            for (_, conn) in old_conns {
239                conn.close();
240            }
241
242            {
243                let mut address_map = self.address_map.write();
244
245                // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to
246                // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides.
247                let _ = std::mem::take(&mut (*address_map));
248            }
249        }
250
251        Ok(())
252    }
253
254    async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
255        if self.is_closed().await {
256            return Err(Error::ErrUseClosedNetworkConn);
257        }
258
259        {
260            let mut conns = self.conns.lock().await;
261            if let Some(conn) = conns.get(ufrag) {
262                // UDPMuxConn uses `Arc` internally so it's cheap to clone, but because
263                // we implement `Conn` we need to further wrap it in an `Arc` here.
264                return Ok(Arc::new(conn.clone()) as Arc<dyn Conn + Send + Sync>);
265            }
266
267            let muxed_conn = self.create_muxed_conn(ufrag)?;
268            let mut close_rx = muxed_conn.close_rx();
269            let cloned_self = Arc::clone(&self);
270            let cloned_ufrag = ufrag.to_string();
271            tokio::spawn(async move {
272                let _ = close_rx.changed().await;
273
274                // Arc needed
275                cloned_self.remove_conn_by_ufrag(&cloned_ufrag).await;
276            });
277
278            conns.insert(ufrag.into(), muxed_conn.clone());
279
280            Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>)
281        }
282    }
283
284    async fn remove_conn_by_ufrag(&self, ufrag: &str) {
285        // Pion's ice implementation has both `RemoveConnByFrag` and `RemoveConn`, but since `conns`
286        // is keyed on `ufrag` their implementation is equivalent.
287
288        let removed_conn = {
289            let mut conns = self.conns.lock().await;
290            conns.remove(ufrag)
291        };
292
293        if let Some(conn) = removed_conn {
294            let mut address_map = self.address_map.write();
295
296            for address in conn.get_addresses() {
297                address_map.remove(&address);
298            }
299        }
300    }
301}
302
303#[async_trait]
304impl UDPMuxWriter for UDPMuxDefault {
305    async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
306        if self.is_closed().await {
307            return;
308        }
309
310        let key = conn.key();
311        {
312            let mut addresses = self.address_map.write();
313
314            addresses
315                .entry(addr)
316                .and_modify(|e| {
317                    if e.key() != key {
318                        e.remove_address(&addr);
319                        *e = conn.clone();
320                    }
321                })
322                .or_insert_with(|| conn.clone());
323        }
324
325        log::debug!("Registered {addr} for {key}");
326    }
327
328    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
329        self.params.conn.send_to(buf, *target).await
330    }
331}