// ss_light/udprelay.rs

1use std::{
2    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
3    sync::Arc,
4    time::Duration,
5};
6
7use bytes::Bytes;
8use futures::future;
9use lru_time_cache::LruCache;
10use tokio::{
11    io,
12    net::{lookup_host, UdpSocket},
13    sync::mpsc,
14    task::JoinHandle,
15    time,
16};
17use tracing::{debug, error, trace, warn};
18
19use crate::{
20    consts::{MAXIMUM_UDP_PAYLOAD_SIZE, UDP_KEEP_ALIVE_CHANNEL_SIZE, UDP_SEND_CHANNEL_SIZE},
21    crypto::PacketCipher,
22    Address, CipherKind,
23};
24
25pub struct UdpServer {
26    cipher: Arc<PacketCipher>,
27    socket: Arc<UdpSocket>,
28    route_table: LruCache<SocketAddr, UdpTunnelWorkerHandle>, // peer addr -> worker
29    keepalive_tx: mpsc::Sender<SocketAddr>,
30    keepalive_rx: mpsc::Receiver<SocketAddr>,
31    time_to_live: Duration,
32}
33
34impl UdpServer {
35    pub fn new(
36        socket: UdpSocket,
37        kind: CipherKind,
38        key: &[u8],
39        cap: usize,
40        time_to_live: Duration,
41    ) -> Self {
42        let cipher = PacketCipher::new(kind, key);
43        let cipher = Arc::new(cipher);
44        let route_table = LruCache::with_expiry_duration_and_capacity(time_to_live, cap);
45        let (keepalive_tx, keepalive_rx) = mpsc::channel(UDP_KEEP_ALIVE_CHANNEL_SIZE);
46        let socket = Arc::new(socket);
47        UdpServer {
48            cipher,
49            socket,
50            route_table,
51            keepalive_tx,
52            keepalive_rx,
53            time_to_live,
54        }
55    }
56
57    pub async fn run(mut self) {
58        let recv_buf = &mut [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
59        let mut cleanup_timer = time::interval(self.time_to_live);
60        loop {
61            tokio::select! {
62                result = self.cipher.recv_from(&self.socket, recv_buf) => {
63                    match result {
64                        Ok((n, peer, target)) => {
65                            if n == 0 {continue;}
66                            let data = &recv_buf[..n];
67
68                            if let Err(e) = self.send_to_tunnle_worker(peer, target, data).await {
69                                error!("udp proxy peer {} with {} bytes, send to tunnle worker error: {}", peer,  n, e);
70                            }
71                        }
72                        Err(e) => {
73                            error!("udp proxy recv error {}", e);
74                            continue;
75                        }
76                    }
77                }
78
79                _ = cleanup_timer.tick() => {
80                    let _ = self.route_table.iter();
81                }
82
83                peer_addr_keep_opt = self.keepalive_rx.recv() => {
84                    let peer = peer_addr_keep_opt.expect("keep-alive channel closed unexpectly");
85                    self.route_table.get(&peer);
86                }
87            }
88        }
89    }
90
91    async fn send_to_tunnle_worker(
92        &mut self,
93        peer: SocketAddr,
94        target: Address,
95        data: &[u8],
96    ) -> io::Result<()> {
97        if let Some(worker_handle) = self.route_table.get(&peer) {
98            return worker_handle.try_send_to_worker((target, Bytes::copy_from_slice(data)));
99        }
100        // create a new worker
101        debug!("new udp proxy request {} <-> ...", peer);
102        let woker_handle = UdpTunnelWorkerHandle::new(
103            self.socket.clone(),
104            self.keepalive_tx.clone(),
105            peer,
106            self.cipher.clone(),
107        );
108
109        woker_handle.try_send_to_worker((target, Bytes::copy_from_slice(data)))?;
110        self.route_table.insert(peer, woker_handle);
111        Ok(())
112    }
113}
114
115struct UdpTunnelWorkerHandle {
116    join_handle: JoinHandle<()>,
117    sender: mpsc::Sender<(Address, Bytes)>,
118}
119
120impl Drop for UdpTunnelWorkerHandle {
121    fn drop(&mut self) {
122        self.join_handle.abort();
123    }
124}
125
126impl UdpTunnelWorkerHandle {
127    fn new(
128        server_socket: Arc<UdpSocket>,
129        keepalive_tx: mpsc::Sender<SocketAddr>,
130        peer_addr: SocketAddr,
131        cipher: Arc<PacketCipher>,
132    ) -> Self {
133        let (join_handle, sender) =
134            UdpTunnelWorker::create(server_socket, keepalive_tx, peer_addr, cipher);
135        UdpTunnelWorkerHandle {
136            join_handle,
137            sender,
138        }
139    }
140    fn try_send_to_worker(&self, data: (Address, Bytes)) -> io::Result<()> {
141        if let Err(..) = self.sender.try_send(data) {
142            let err = io::Error::new(io::ErrorKind::Other, "udp send channel full");
143            return Err(err);
144        }
145        Ok(())
146    }
147}
148
149struct UdpTunnelWorker {
150    keepalive_tx: mpsc::Sender<SocketAddr>,
151    keepalive_flag: bool,
152    server_socket: Arc<UdpSocket>,
153    peer_addr: SocketAddr,
154    outbound_ipv4_socket: Option<UdpSocket>,
155    outbound_ipv6_socket: Option<UdpSocket>,
156    cipher: Arc<PacketCipher>,
157}
158
159impl UdpTunnelWorker {
160    fn create(
161        server_socket: Arc<UdpSocket>,
162        keepalive_tx: mpsc::Sender<SocketAddr>,
163        peer_addr: SocketAddr,
164        cipher: Arc<PacketCipher>,
165    ) -> (JoinHandle<()>, mpsc::Sender<(Address, Bytes)>) {
166        let (tx, rx) = mpsc::channel(UDP_SEND_CHANNEL_SIZE);
167
168        let woker = UdpTunnelWorker {
169            keepalive_tx,
170            keepalive_flag: false,
171            server_socket,
172            peer_addr,
173            outbound_ipv4_socket: None,
174            outbound_ipv6_socket: None,
175            cipher,
176        };
177
178        let join_handle = tokio::spawn(async move { woker.run(rx).await });
179
180        return (join_handle, tx);
181    }
182
183    async fn run(mut self, mut rx: mpsc::Receiver<(Address, Bytes)>) {
184        let mut outbound_ipv4_buffer = Vec::new();
185        let mut outbound_ipv6_buffer = Vec::new();
186        let mut keepalive_interval = time::interval(Duration::from_secs(1));
187        loop {
188            tokio::select! {
189                recevied_opt = rx.recv() => {
190                    let (target_addr, data) = match recevied_opt {
191                        Some(d) => d,
192                        None => {
193                            trace!("udp tunnel worker for peer {} -> ... channel closed", self.peer_addr);
194                            break;
195                        }
196
197                    };
198                    if let Err(e) = self.send_data_to_target(&target_addr, &data).await {
199                        error!("udp proxy {} <-> {}, L2R {} bytes err: {}", self.peer_addr, target_addr, data.len(), e);
200                    }
201                    debug!("udp proxy {} <-> {}, L2R {} bytes", self.peer_addr, target_addr, data.len())
202                }
203
204                recevied_opt = Self::recv_data_from_target(&self.outbound_ipv4_socket,&mut outbound_ipv4_buffer) => {
205                    let (n, target_addr) = match recevied_opt {
206                        Ok(r) => r,
207                        Err(e) => {
208                            error!("udp tunnel worker for peer {} <- ... failed, error: {}", self.peer_addr, e);
209                            continue;
210                        }
211                    };
212                    self.send_data_to_peer(target_addr, &outbound_ipv4_buffer[..n]).await;
213
214                }
215
216                recevied_opt = Self::recv_data_from_target(&self.outbound_ipv6_socket,&mut outbound_ipv6_buffer) => {
217                    let (n, target_addr) = match recevied_opt {
218                        Ok(r) => r,
219                        Err(e) => {
220                            error!("udp tunnel worker for peer {} <- ... failed, error: {}", self.peer_addr, e);
221                            continue;
222                        }
223                    };
224                    self.send_data_to_peer(target_addr, &outbound_ipv6_buffer[..n]).await;
225                }
226
227                _ = keepalive_interval.tick() => {
228                    if self.keepalive_flag {
229                        if let Err(..) = self.keepalive_tx.try_send(self.peer_addr) {
230                            debug!("udp tunnel worker for peer {} keep-alive failed, channel full or closed", self.peer_addr);
231                        } else {
232                            self.keepalive_flag = false;
233                        }
234                    }
235                }
236            }
237        }
238    }
239
240    async fn send_data_to_target(&mut self, target_addr: &Address, data: &[u8]) -> io::Result<()> {
241        let target_sa: SocketAddr;
242        match *target_addr {
243            Address::SocketAddress(sa) => target_sa = sa,
244            Address::DomainNameAddress(ref domain, port) => {
245                match lookup_host((domain.as_str(), port)).await {
246                    Ok(mut v) => {
247                        match v.next() {
248                            Some(sa) => target_sa = sa,
249                            None => {
250                                return Err(io::Error::new(
251                                    io::ErrorKind::Other,
252                                    format!("dns resolve exmpty: {}", domain),
253                                ))
254                            }
255                        };
256                    }
257                    Err(e) => {
258                        return Err(io::Error::new(
259                            io::ErrorKind::Other,
260                            format!("dns resolve {} error: {}", domain, e),
261                        ))
262                    }
263                };
264            }
265        }
266
267        let socket = match target_sa {
268            SocketAddr::V4(..) => match self.outbound_ipv4_socket {
269                Some(ref mut socket) => socket,
270                None => {
271                    let socket =
272                        UdpSocket::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0)).await?;
273                    self.outbound_ipv4_socket.insert(socket)
274                }
275            },
276            SocketAddr::V6(..) => match self.outbound_ipv6_socket {
277                Some(ref mut socket) => socket,
278                None => {
279                    let socket =
280                        UdpSocket::bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0)).await?;
281                    self.outbound_ipv6_socket.insert(socket)
282                }
283            },
284        };
285
286        let n = socket.send_to(data, target_sa).await?;
287        if n != data.len() {
288            warn!(
289                "udp proxy {} -> {} sent {} bytes != expected {} bytes",
290                self.peer_addr,
291                target_addr,
292                n,
293                data.len()
294            );
295        }
296        Ok(())
297    }
298
299    async fn recv_data_from_target(
300        socket: &Option<UdpSocket>,
301        buf: &mut Vec<u8>,
302    ) -> io::Result<(usize, SocketAddr)> {
303        match *socket {
304            None => future::pending().await,
305            Some(ref s) => {
306                if buf.is_empty() {
307                    buf.resize(MAXIMUM_UDP_PAYLOAD_SIZE, 0);
308                }
309                s.recv_from(buf).await
310            }
311        }
312    }
313
314    async fn send_data_to_peer(&mut self, target: SocketAddr, data: &[u8]) {
315        self.keepalive_flag = true;
316
317        if let Err(e) = self
318            .cipher
319            .send_to(&self.server_socket, data, self.peer_addr, target)
320            .await
321        {
322            warn!(
323                "udp tunnel worker sendback {} bytes to peer {}, from target {}, err: {}",
324                data.len(),
325                self.peer_addr,
326                target,
327                e
328            );
329        } else {
330            debug!(
331                "udp proxy {} <-> {}, R2L {} bytes",
332                self.peer_addr,
333                target,
334                data.len()
335            );
336        }
337    }
338}