Skip to main content

socks5_impl/server/connection/associate.rs

1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4    net::SocketAddr,
5    pin::Pin,
6    sync::atomic::{AtomicUsize, Ordering},
7    task::{Context, Poll},
8};
9use tokio::{
10    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
11    net::{TcpStream, ToSocketAddrs, UdpSocket},
12};
13
14/// Socks5 connection type `UdpAssociate`
15#[derive(Debug)]
16pub struct UdpAssociate<S> {
17    stream: TcpStream,
18    _state: S,
19}
20
21impl<S: Default> UdpAssociate<S> {
22    #[inline]
23    pub(super) fn new(stream: TcpStream) -> Self {
24        Self {
25            stream,
26            _state: S::default(),
27        }
28    }
29
30    /// Reply to the SOCKS5 client with the given reply and address.
31    ///
32    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
33    pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
34        let resp = Response::new(reply, addr);
35        resp.write_to_async_stream(&mut self.stream).await?;
36        Ok(UdpAssociate::<Ready>::new(self.stream))
37    }
38
39    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
40    #[inline]
41    pub async fn shutdown(&mut self) -> std::io::Result<()> {
42        self.stream.shutdown().await
43    }
44
45    /// Returns the local address that this stream is bound to.
46    #[inline]
47    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
48        self.stream.local_addr()
49    }
50
51    /// Returns the remote address that this stream is connected to.
52    #[inline]
53    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
54        self.stream.peer_addr()
55    }
56
57    /// Gets the value of the `TCP_NODELAY` option on this socket.
58    ///
59    /// For more information about this option, see [`set_nodelay`](#method.set_nodelay).
60    #[inline]
61    pub fn nodelay(&self) -> std::io::Result<bool> {
62        self.stream.nodelay()
63    }
64
65    /// Sets the value of the `TCP_NODELAY` option on this socket.
66    ///
67    /// If set, this option disables the Nagle algorithm. This means that segments are always sent as soon as possible,
68    /// even if there is only a small amount of data. When not set, data is buffered until there is a sufficient amount to send out,
69    /// thereby avoiding the frequent sending of small packets.
70    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
71        self.stream.set_nodelay(nodelay)
72    }
73
74    /// Gets the value of the `IP_TTL` option for this socket.
75    ///
76    /// For more information about this option, see [`set_ttl`](#method.set_ttl).
77    pub fn ttl(&self) -> std::io::Result<u32> {
78        self.stream.ttl()
79    }
80
81    /// Sets the value for the `IP_TTL` option on this socket.
82    ///
83    /// This value sets the time-to-live field that is used in every packet sent from this socket.
84    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
85        self.stream.set_ttl(ttl)
86    }
87}
88
/// Typestate marker for a `UDP ASSOCIATE` connection that has not yet sent its reply
/// to the client (see `UdpAssociate::reply`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NeedReply;
91
/// Typestate marker for a `UDP ASSOCIATE` connection whose reply has been sent.
///
/// Only connections in this state expose the stream I/O traits and `wait_until_closed`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ready;
94
95impl UdpAssociate<Ready> {
96    /// Wait until the client closes this TCP connection.
97    ///
98    /// Socks5 protocol defines that when the client closes the TCP connection used to send the associate command,
99    /// the server should release the associated UDP socket.
100    pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
101        loop {
102            match self.stream.read(&mut [0]).await {
103                Ok(0) => break Ok(()),
104                Ok(_) => {}
105                Err(err) => break Err(err),
106            }
107        }
108    }
109}
110
111impl std::ops::Deref for UdpAssociate<Ready> {
112    type Target = TcpStream;
113
114    #[inline]
115    fn deref(&self) -> &Self::Target {
116        &self.stream
117    }
118}
119
120impl std::ops::DerefMut for UdpAssociate<Ready> {
121    #[inline]
122    fn deref_mut(&mut self) -> &mut Self::Target {
123        &mut self.stream
124    }
125}
126
127impl AsyncRead for UdpAssociate<Ready> {
128    #[inline]
129    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
130        Pin::new(&mut self.stream).poll_read(cx, buf)
131    }
132}
133
134impl AsyncWrite for UdpAssociate<Ready> {
135    #[inline]
136    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
137        Pin::new(&mut self.stream).poll_write(cx, buf)
138    }
139
140    #[inline]
141    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
142        Pin::new(&mut self.stream).poll_flush(cx)
143    }
144
145    #[inline]
146    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
147        Pin::new(&mut self.stream).poll_shutdown(cx)
148    }
149}
150
151impl<S> From<UdpAssociate<S>> for TcpStream {
152    #[inline]
153    fn from(conn: UdpAssociate<S>) -> Self {
154        conn.stream
155    }
156}
157
158/// This is a helper for managing the associated UDP socket.
159///
160/// It will add the socks5 UDP header to every UDP packet it sends, also try to parse the socks5 UDP header from any UDP packet received.
161///
162/// The receiving buffer size for each UDP packet can be set with [`set_recv_buffer_size()`](#method.set_recv_buffer_size),
163/// and be read with [`get_max_packet_size()`](#method.get_recv_buffer_size).
164///
165/// You can create this struct by using [`AssociatedUdpSocket::from::<(UdpSocket, usize)>()`](#impl-From<UdpSocket>),
166/// the first element of the tuple is the UDP socket, the second element is the receiving buffer size.
167///
168/// This struct can also be revert into a raw tokio UDP socket with [`UdpSocket::from::<AssociatedUdpSocket>()`](#impl-From<AssociatedUdpSocket>).
169///
170/// [`AssociatedUdpSocket`] can be used as the associated UDP socket.
171#[derive(Debug)]
172pub struct AssociatedUdpSocket {
173    socket: UdpSocket,
174    buf_size: AtomicUsize,
175}
176
177impl AssociatedUdpSocket {
178    /// Connects the UDP socket setting the default destination for send() and limiting packets that are read via recv from the address specified in addr.
179    #[inline]
180    pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
181        self.socket.connect(addr).await
182    }
183
184    /// Get the maximum UDP packet size, with socks5 UDP header included.
185    pub fn get_max_packet_size(&self) -> usize {
186        self.buf_size.load(Ordering::Relaxed)
187    }
188
189    /// Set the maximum UDP packet size, with socks5 UDP header included, for adjusting the receiving buffer size.
190    pub fn set_max_packet_size(&self, size: usize) {
191        self.buf_size.store(size, Ordering::Release);
192    }
193
194    /// Receives a socks5 UDP relay packet on the socket from the remote address to which it is connected.
195    /// On success, returns the packet itself, the fragment number and the remote target address.
196    ///
197    /// The [`connect`](#method.connect) method will connect this socket to a remote address.
198    /// This method will fail if the socket is not connected.
199    pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
200        loop {
201            let max_packet_size = self.buf_size.load(Ordering::Acquire);
202            let mut buf = vec![0; max_packet_size];
203            let len = self.socket.recv(&mut buf).await?;
204            buf.truncate(len);
205            let pkt = Bytes::from(buf);
206
207            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
208                let pkt = pkt.slice(header.len()..);
209                return Ok((pkt, header.frag, header.address));
210            }
211        }
212    }
213
214    /// Receives a socks5 UDP relay packet on the socket from the any remote address.
215    /// On success, returns the packet itself, the fragment number, the remote target address and the source address.
216    pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
217        loop {
218            let max_packet_size = self.buf_size.load(Ordering::Acquire);
219            let mut buf = vec![0; max_packet_size];
220            let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
221            buf.truncate(len);
222            let pkt = Bytes::from(buf);
223
224            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
225                let pkt = pkt.slice(header.len()..);
226                return Ok((pkt, header.frag, header.address, src_addr));
227            }
228        }
229    }
230
231    /// Sends a UDP relay packet to the remote address to which it is connected. The socks5 UDP header will be added to the packet.
232    pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
233        let header = UdpHeader::new(frag, from_addr);
234        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
235        header.write_to_buf(&mut buf);
236        buf.extend_from_slice(pkt.as_ref());
237
238        self.socket.send(&buf).await.map(|len| len - header.len())
239    }
240
241    /// Sends a UDP relay packet to a specified remote address to which it is connected. The socks5 UDP header will be added to the packet.
242    pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
243        let header = UdpHeader::new(frag, from_addr);
244        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
245        header.write_to_buf(&mut buf);
246        buf.extend_from_slice(pkt.as_ref());
247
248        self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
249    }
250}
251
252impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
253    #[inline]
254    fn from(from: (UdpSocket, usize)) -> Self {
255        AssociatedUdpSocket {
256            socket: from.0,
257            buf_size: AtomicUsize::new(from.1),
258        }
259    }
260}
261
262impl From<AssociatedUdpSocket> for UdpSocket {
263    #[inline]
264    fn from(from: AssociatedUdpSocket) -> Self {
265        from.socket
266    }
267}
268
269impl AsRef<UdpSocket> for AssociatedUdpSocket {
270    #[inline]
271    fn as_ref(&self) -> &UdpSocket {
272        &self.socket
273    }
274}
275
276impl AsMut<UdpSocket> for AssociatedUdpSocket {
277    #[inline]
278    fn as_mut(&mut self) -> &mut UdpSocket {
279        &mut self.socket
280    }
281}