1#[cfg(test)]
2mod conn_test;
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::atomic::Ordering;
6use std::sync::{Arc, Weak};
7
8use async_trait::async_trait;
9use portable_atomic::AtomicBool;
10use tokio::sync::{mpsc, Mutex};
11
12use crate::conn::Conn;
13use crate::error::*;
14use crate::sync::RwLock;
15use crate::vnet::chunk::{Chunk, ChunkUdp};
16
/// Capacity of the bounded inbound chunk channel created in `UdpConn::new`.
///
/// At most this many received chunks can be buffered before `recv_from` drains them.
/// NOTE(review): what happens when the queue is full is decided by whoever sends
/// through the handle returned by `get_inbound_ch` (not visible here) — confirm.
const MAX_READ_QUEUE_SIZE: usize = 1024;
18
/// Callback interface a `UdpConn` uses to reach the virtual-network component it is
/// attached to (presumably the vnet router/net — confirm against implementors).
///
/// `UdpConn` holds only a weak reference to its observer; if the observer has been
/// dropped, its send/close operations fail with `Error::ErrVnetDisabled`.
#[async_trait]
pub(crate) trait ConnObserver {
    /// Accepts an outbound chunk built by `UdpConn::send_to`. An error returned here is
    /// propagated to the caller of `send_to`.
    async fn write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>;
    /// Called by `UdpConn::close` with the connection's local address after the
    /// connection has been marked closed.
    async fn on_closed(&self, addr: SocketAddr);
    /// Chooses the source IP for a datagram from local IP `loc_ip` to `dst_ip`.
    /// Returning `None` makes `UdpConn::send_to` fail with `Error::ErrLocAddr`.
    fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>;
}
26
/// Sending half of a connection's inbound chunk channel. `UdpConn::get_inbound_ch`
/// hands it out wrapped in `Arc<Mutex<Option<_>>>`; it becomes `None` after `close`.
pub(crate) type ChunkChTx = mpsc::Sender<Box<dyn Chunk + Send + Sync>>;
28
/// A UDP connection on the virtual network.
///
/// Inbound data arrives as chunks on an internal bounded channel. Outbound data is
/// wrapped in a `ChunkUdp` and handed to the `ConnObserver`.
pub(crate) struct UdpConn {
    // Local address: reported by `local_addr`, its port is the source port of sent
    // chunks, and its IP is passed to `determine_source_ip`.
    loc_addr: SocketAddr,
    // Optional peer set at construction or by `connect`. It is the default destination
    // for `send`, and when set, `recv_from` drops chunks from any other source.
    rem_addr: RwLock<Option<SocketAddr>>,
    // Sending half of the inbound channel, shared via `get_inbound_ch`; taken (set to
    // `None`) by `close`.
    read_ch_tx: Arc<Mutex<Option<ChunkChTx>>>,
    // Receiving half of the inbound channel, drained by `recv_from`.
    read_ch_rx: Mutex<mpsc::Receiver<Box<dyn Chunk + Send + Sync>>>,
    // Set once by `close`; a second `close` returns `ErrAlreadyClosed`.
    closed: AtomicBool,
    // Weak handle to the observer; a failed upgrade maps to `ErrVnetDisabled`.
    obs: Weak<Mutex<dyn ConnObserver + Send + Sync>>,
}
39
40impl UdpConn {
41 pub(crate) fn new(
42 loc_addr: SocketAddr,
43 rem_addr: Option<SocketAddr>,
44 obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>,
45 ) -> Self {
46 let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE);
47
48 let weak_obs = Arc::downgrade(&obs);
49 UdpConn {
50 loc_addr,
51 rem_addr: RwLock::new(rem_addr),
52 read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))),
53 read_ch_rx: Mutex::new(read_ch_rx),
54 closed: AtomicBool::new(false),
55 obs: weak_obs,
56 }
57 }
58
59 pub(crate) fn get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>> {
60 Arc::clone(&self.read_ch_tx)
61 }
62}
63
64#[async_trait]
65impl Conn for UdpConn {
66 async fn connect(&self, addr: SocketAddr) -> Result<()> {
67 self.rem_addr.write().replace(addr);
68
69 Ok(())
70 }
71 async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
72 let (n, _) = self.recv_from(buf).await?;
73 Ok(n)
74 }
75
76 async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
84 let mut read_ch = self.read_ch_rx.lock().await;
85 let rem_addr = *self.rem_addr.read();
86 while let Some(chunk) = read_ch.recv().await {
87 let user_data = chunk.user_data();
88 let n = std::cmp::min(buf.len(), user_data.len());
89 buf[..n].copy_from_slice(&user_data[..n]);
90 let addr = chunk.source_addr();
91 {
92 if let Some(rem_addr) = &rem_addr {
93 if &addr != rem_addr {
94 continue; }
96 }
97 }
98 return Ok((n, addr));
99 }
100
101 Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Connection Aborted").into())
102 }
103
104 async fn send(&self, buf: &[u8]) -> Result<usize> {
105 let rem_addr = *self.rem_addr.read();
106 if let Some(rem_addr) = rem_addr {
107 self.send_to(buf, rem_addr).await
108 } else {
109 Err(Error::ErrNoRemAddr)
110 }
111 }
112
113 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
116 let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;
117
118 let src_ip = {
119 let obs = obs.lock().await;
120 match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) {
121 Some(ip) => ip,
122 None => return Err(Error::ErrLocAddr),
123 }
124 };
125
126 let src_addr = SocketAddr::new(src_ip, self.loc_addr.port());
127
128 let mut chunk = ChunkUdp::new(src_addr, target);
129 chunk.user_data = buf.to_vec();
130 {
131 let c: Box<dyn Chunk + Send + Sync> = Box::new(chunk);
132 let obs = obs.lock().await;
133 obs.write(c).await?
134 }
135
136 Ok(buf.len())
137 }
138
139 fn local_addr(&self) -> Result<SocketAddr> {
140 Ok(self.loc_addr)
141 }
142
143 fn remote_addr(&self) -> Option<SocketAddr> {
144 *self.rem_addr.read()
145 }
146
147 async fn close(&self) -> Result<()> {
148 let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;
149
150 if self.closed.load(Ordering::SeqCst) {
151 return Err(Error::ErrAlreadyClosed);
152 }
153 self.closed.store(true, Ordering::SeqCst);
154 {
155 let mut reach_ch = self.read_ch_tx.lock().await;
156 reach_ch.take();
157 }
158 {
159 let obs = obs.lock().await;
160 obs.on_closed(self.loc_addr).await;
161 }
162
163 Ok(())
164 }
165
166 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
167 self
168 }
169}