// webrtc_util/vnet/conn.rs

// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod conn_test;

4use std::net::{IpAddr, SocketAddr};
5use std::sync::atomic::Ordering;
6use std::sync::{Arc, Weak};
7
8use async_trait::async_trait;
9use portable_atomic::AtomicBool;
10use tokio::sync::{mpsc, Mutex};
11
12use crate::conn::Conn;
13use crate::error::*;
14use crate::sync::RwLock;
15use crate::vnet::chunk::{Chunk, ChunkUdp};
16
17const MAX_READ_QUEUE_SIZE: usize = 1024;
18
19/// vNet implements this
20#[async_trait]
21pub(crate) trait ConnObserver {
22    async fn write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>;
23    async fn on_closed(&self, addr: SocketAddr);
24    fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>;
25}
26
27pub(crate) type ChunkChTx = mpsc::Sender<Box<dyn Chunk + Send + Sync>>;
28
29/// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
30/// compatible with net.PacketConn and net.Conn
31pub(crate) struct UdpConn {
32    loc_addr: SocketAddr,
33    rem_addr: RwLock<Option<SocketAddr>>,
34    read_ch_tx: Arc<Mutex<Option<ChunkChTx>>>,
35    read_ch_rx: Mutex<mpsc::Receiver<Box<dyn Chunk + Send + Sync>>>,
36    closed: AtomicBool,
37    obs: Weak<Mutex<dyn ConnObserver + Send + Sync>>,
38}
39
40impl UdpConn {
41    pub(crate) fn new(
42        loc_addr: SocketAddr,
43        rem_addr: Option<SocketAddr>,
44        obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>,
45    ) -> Self {
46        let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE);
47
48        let weak_obs = Arc::downgrade(&obs);
49        UdpConn {
50            loc_addr,
51            rem_addr: RwLock::new(rem_addr),
52            read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))),
53            read_ch_rx: Mutex::new(read_ch_rx),
54            closed: AtomicBool::new(false),
55            obs: weak_obs,
56        }
57    }
58
59    pub(crate) fn get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>> {
60        Arc::clone(&self.read_ch_tx)
61    }
62}
63
64#[async_trait]
65impl Conn for UdpConn {
66    async fn connect(&self, addr: SocketAddr) -> Result<()> {
67        self.rem_addr.write().replace(addr);
68
69        Ok(())
70    }
71    async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
72        let (n, _) = self.recv_from(buf).await?;
73        Ok(n)
74    }
75
76    /// recv_from reads a packet from the connection,
77    /// copying the payload into p. It returns the number of
78    /// bytes copied into p and the return address that
79    /// was on the packet.
80    /// It returns the number of bytes read (0 <= n <= len(p))
81    /// and any error encountered. Callers should always process
82    /// the n > 0 bytes returned before considering the error err.
83    async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
84        let mut read_ch = self.read_ch_rx.lock().await;
85        let rem_addr = *self.rem_addr.read();
86        while let Some(chunk) = read_ch.recv().await {
87            let user_data = chunk.user_data();
88            let n = std::cmp::min(buf.len(), user_data.len());
89            buf[..n].copy_from_slice(&user_data[..n]);
90            let addr = chunk.source_addr();
91            {
92                if let Some(rem_addr) = &rem_addr {
93                    if &addr != rem_addr {
94                        continue; // discard (shouldn't happen)
95                    }
96                }
97            }
98            return Ok((n, addr));
99        }
100
101        Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Connection Aborted").into())
102    }
103
104    async fn send(&self, buf: &[u8]) -> Result<usize> {
105        let rem_addr = *self.rem_addr.read();
106        if let Some(rem_addr) = rem_addr {
107            self.send_to(buf, rem_addr).await
108        } else {
109            Err(Error::ErrNoRemAddr)
110        }
111    }
112
113    /// send_to writes a packet with payload p to addr.
114    /// send_to can be made to time out and return
115    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
116        let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;
117
118        let src_ip = {
119            let obs = obs.lock().await;
120            match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) {
121                Some(ip) => ip,
122                None => return Err(Error::ErrLocAddr),
123            }
124        };
125
126        let src_addr = SocketAddr::new(src_ip, self.loc_addr.port());
127
128        let mut chunk = ChunkUdp::new(src_addr, target);
129        chunk.user_data = buf.to_vec();
130        {
131            let c: Box<dyn Chunk + Send + Sync> = Box::new(chunk);
132            let obs = obs.lock().await;
133            obs.write(c).await?
134        }
135
136        Ok(buf.len())
137    }
138
139    fn local_addr(&self) -> Result<SocketAddr> {
140        Ok(self.loc_addr)
141    }
142
143    fn remote_addr(&self) -> Option<SocketAddr> {
144        *self.rem_addr.read()
145    }
146
147    async fn close(&self) -> Result<()> {
148        let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;
149
150        if self.closed.load(Ordering::SeqCst) {
151            return Err(Error::ErrAlreadyClosed);
152        }
153        self.closed.store(true, Ordering::SeqCst);
154        {
155            let mut reach_ch = self.read_ch_tx.lock().await;
156            reach_ch.take();
157        }
158        {
159            let obs = obs.lock().await;
160            obs.on_closed(self.loc_addr).await;
161        }
162
163        Ok(())
164    }
165
166    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
167        self
168    }
169}