socks5_server/connection/
associate.rs

1//! Socks5 command type `Associate`
2//!
3//! This module also provides a [`tokio::net::UdpSocket`] wrapper, [`AssociatedUdpSocket`], which can be used to send and receive UDP packets without dealing with the SOCKS5 protocol UDP header.
4
5use bytes::{Bytes, BytesMut};
6use socks5_proto::{Address, Error as Socks5Error, Reply, Response, UdpHeader};
7use std::{
8    io::{Cursor, Error},
9    marker::PhantomData,
10    net::SocketAddr,
11    sync::atomic::{AtomicUsize, Ordering},
12};
13use tokio::{
14    io::{AsyncReadExt, AsyncWriteExt},
15    net::{TcpStream, UdpSocket},
16};
17
/// Marker types encoding the negotiation state of an `Associate` connection.
///
/// These types carry no data; they are only used as the type parameter of `Associate<S>`.
pub mod state {
    /// The command has been received but the server has not replied to the client yet.
    #[derive(Debug)]
    pub struct NeedReply;

    /// The reply has been sent; the connection is now only kept to watch for the client closing it.
    #[derive(Debug)]
    pub struct Ready;
}
26
27/// Socks5 command type `Associate`
28///
29/// Reply the client with [`Associate::reply()`] to complete the command negotiation.
30#[derive(Debug)]
31pub struct Associate<S> {
32    stream: TcpStream,
33    _state: PhantomData<S>,
34}
35
36impl Associate<state::NeedReply> {
37    /// Reply to the SOCKS5 client with the given reply and address.
38    ///
39    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
40    pub async fn reply(
41        mut self,
42        reply: Reply,
43        addr: Address,
44    ) -> Result<Associate<state::Ready>, (Error, TcpStream)> {
45        let resp = Response::new(reply, addr);
46
47        if let Err(err) = resp.write_to(&mut self.stream).await {
48            return Err((err, self.stream));
49        }
50
51        Ok(Associate::new(self.stream))
52    }
53}
54
55impl Associate<state::Ready> {
56    /// Wait until the SOCKS5 client closes this TCP connection.
57    ///
58    /// Socks5 protocol defines that when the client closes the TCP connection used to send the associate command, the server should release the associated UDP socket.
59    pub async fn wait_close(&mut self) -> Result<(), Error> {
60        loop {
61            match self.stream.read(&mut [0]).await {
62                Ok(0) => break Ok(()),
63                Ok(_) => {}
64                Err(err) => break Err(err),
65            }
66        }
67    }
68}
69
70impl<S> Associate<S> {
71    #[inline]
72    pub(super) fn new(stream: TcpStream) -> Self {
73        Self {
74            stream,
75            _state: PhantomData,
76        }
77    }
78
79    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
80    #[inline]
81    pub async fn close(&mut self) -> Result<(), Error> {
82        self.stream.shutdown().await
83    }
84
85    /// Returns the local address that this stream is bound to.
86    #[inline]
87    pub fn local_addr(&self) -> Result<SocketAddr, Error> {
88        self.stream.local_addr()
89    }
90
91    /// Returns the remote address that this stream is connected to.
92    #[inline]
93    pub fn peer_addr(&self) -> Result<SocketAddr, Error> {
94        self.stream.peer_addr()
95    }
96
97    /// Returns a shared reference to the underlying stream.
98    ///
99    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
100    #[inline]
101    pub fn get_ref(&self) -> &TcpStream {
102        &self.stream
103    }
104
105    /// Returns a mutable reference to the underlying stream.
106    ///
107    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
108    #[inline]
109    pub fn get_mut(&mut self) -> &mut TcpStream {
110        &mut self.stream
111    }
112
113    /// Consumes the [`Associate<S>`] and returns the underlying [`TcpStream`](tokio::net::TcpStream).
114    #[inline]
115    pub fn into_inner(self) -> TcpStream {
116        self.stream
117    }
118}
119
120/// A wrapper of a tokio UDP socket dealing with SOCKS5 UDP header.
121///
122/// It only provides handful of methods to send / receive UDP packets with SOCKS5 UDP header. The underlying `UdpSocket` can be accessed with [`AssociatedUdpSocket::get_ref()`] and [`AssociatedUdpSocket::get_mut()`].
123#[derive(Debug)]
124pub struct AssociatedUdpSocket {
125    socket: UdpSocket,
126    buf_size: AtomicUsize,
127}
128
129impl AssociatedUdpSocket {
130    /// Creates a new [`AssociatedUdpSocket`] with a [`UdpSocket`](tokio::net::UdpSocket) and a maximum receiving UDP packet size, with SOCKS5 UDP header included.
131    pub fn new(socket: UdpSocket, buf_size: usize) -> Self {
132        Self {
133            socket,
134            buf_size: AtomicUsize::new(buf_size),
135        }
136    }
137
138    /// Receives a SOCKS5 UDP packet on the socket from the remote address which it is connected.
139    ///
140    /// On success, it returns the packet payload and the SOCKS5 UDP header. On error, it returns the error alongside an `Option<Vec<u8>>`. If the error occurs before / when receiving the raw UDP packet, the `Option<Vec<u8>>` will be `None`. Otherwise, it will be `Some(Vec<u8>)` containing the received raw UDP packet.
141    pub async fn recv(&self) -> Result<(Bytes, UdpHeader), (Socks5Error, Option<Vec<u8>>)> {
142        let max_pkt_size = self.buf_size.load(Ordering::Acquire);
143        let mut buf = vec![0; max_pkt_size];
144
145        let len = match self.socket.recv(&mut buf).await {
146            Ok(len) => len,
147            Err(err) => return Err((Socks5Error::Io(err), None)),
148        };
149
150        buf.truncate(len);
151
152        let header = match UdpHeader::read_from(&mut Cursor::new(buf.as_slice())).await {
153            Ok(header) => header,
154            Err(err) => return Err((err, Some(buf))),
155        };
156
157        let pkt = Bytes::from(buf).slice(header.serialized_len()..);
158
159        Ok((pkt, header))
160    }
161
162    /// Receives a SOCKS5 UDP packet on the socket from a remote address.
163    ///
164    /// On success, it returns the packet payload, the SOCKS5 UDP header and the source address. On error, it returns the error alongside an `Option<Vec<u8>>`. If the error occurs before / when receiving the raw UDP packet, the `Option<Vec<u8>>` will be `None`. Otherwise, it will be `Some(Vec<u8>)` containing the received raw UDP packet.
165    pub async fn recv_from(
166        &self,
167    ) -> Result<(Bytes, UdpHeader, SocketAddr), (Socks5Error, Option<Vec<u8>>)> {
168        let max_pkt_size = self.buf_size.load(Ordering::Acquire);
169        let mut buf = vec![0; max_pkt_size];
170
171        let (len, addr) = match self.socket.recv_from(&mut buf).await {
172            Ok(res) => res,
173            Err(err) => return Err((Socks5Error::Io(err), None)),
174        };
175
176        buf.truncate(len);
177
178        let header = match UdpHeader::read_from(&mut Cursor::new(buf.as_slice())).await {
179            Ok(header) => header,
180            Err(err) => return Err((err, Some(buf))),
181        };
182
183        let pkt = Bytes::from(buf).slice(header.serialized_len()..);
184
185        Ok((pkt, header, addr))
186    }
187
188    /// Sends a UDP packet to the remote address which it is connected. The SOCKS5 UDP header will be added to the packet.
189    pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, header: &UdpHeader) -> Result<usize, Error> {
190        let mut buf = BytesMut::with_capacity(header.serialized_len() + pkt.as_ref().len());
191        header.write_to_buf(&mut buf);
192        buf.extend_from_slice(pkt.as_ref());
193
194        self.socket
195            .send(&buf)
196            .await
197            .map(|len| len - header.serialized_len())
198    }
199
200    /// Sends a UDP packet to a specified remote address. The SOCKS5 UDP header will be added to the packet.
201    pub async fn send_to<P: AsRef<[u8]>>(
202        &self,
203        pkt: P,
204        header: &UdpHeader,
205        addr: SocketAddr,
206    ) -> Result<usize, Error> {
207        let mut buf = BytesMut::with_capacity(header.serialized_len() + pkt.as_ref().len());
208        header.write_to_buf(&mut buf);
209        buf.extend_from_slice(pkt.as_ref());
210
211        self.socket
212            .send_to(&buf, addr)
213            .await
214            .map(|len| len - header.serialized_len())
215    }
216
217    /// Get the maximum receiving UDP packet size, with SOCKS5 UDP header included.
218    #[inline]
219    pub fn get_max_pkt_size(&self) -> usize {
220        self.buf_size.load(Ordering::Acquire)
221    }
222
223    /// Set the maximum receiving UDP packet size, with SOCKS5 UDP header included, for adjusting the receiving buffer size.
224    #[inline]
225    pub fn set_max_pkt_size(&self, size: usize) {
226        self.buf_size.store(size, Ordering::Release);
227    }
228
229    /// Returns a shared reference to the underlying socket.
230    ///
231    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
232    #[inline]
233    pub fn get_ref(&self) -> &UdpSocket {
234        &self.socket
235    }
236
237    /// Returns a mutable reference to the underlying socket.
238    ///
239    /// Note that this may break the encapsulation of the SOCKS5 UDP abstraction and you should not use this method unless you know what you are doing.
240    #[inline]
241    pub fn get_mut(&mut self) -> &mut UdpSocket {
242        &mut self.socket
243    }
244
245    /// Consumes the [`AssociatedUdpSocket`] and returns the underlying [`UdpSocket`](tokio::net::UdpSocket).
246    #[inline]
247    pub fn into_inner(self) -> UdpSocket {
248        self.socket
249    }
250}