Skip to main content

socks5_impl/server/connection/
associate.rs

1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4    net::SocketAddr,
5    pin::Pin,
6    sync::atomic::{AtomicUsize, Ordering},
7    task::{Context, Poll},
8    time::Duration,
9};
10use tokio::{
11    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
12    net::{TcpStream, ToSocketAddrs, UdpSocket},
13};
14
15/// Socks5 connection type `UdpAssociate`
16#[derive(Debug)]
17pub struct UdpAssociate<S> {
18    stream: TcpStream,
19    _state: S,
20}
21
22impl<S: Default> UdpAssociate<S> {
23    #[inline]
24    pub(super) fn new(stream: TcpStream) -> Self {
25        Self {
26            stream,
27            _state: S::default(),
28        }
29    }
30
31    /// Reply to the SOCKS5 client with the given reply and address.
32    ///
33    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
34    pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
35        let resp = Response::new(reply, addr);
36        resp.write_to_async_stream(&mut self.stream).await?;
37        Ok(UdpAssociate::<Ready>::new(self.stream))
38    }
39
40    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
41    #[inline]
42    pub async fn shutdown(&mut self) -> std::io::Result<()> {
43        self.stream.shutdown().await
44    }
45
46    /// Returns the local address that this stream is bound to.
47    #[inline]
48    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
49        self.stream.local_addr()
50    }
51
52    /// Returns the remote address that this stream is connected to.
53    #[inline]
54    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
55        self.stream.peer_addr()
56    }
57
58    /// Gets the value of the `TCP_NODELAY` option on this socket.
59    ///
60    /// For more information about this option, see [`set_nodelay`](#method.set_nodelay).
61    #[inline]
62    pub fn nodelay(&self) -> std::io::Result<bool> {
63        self.stream.nodelay()
64    }
65
66    /// Sets the value of the `TCP_NODELAY` option on this socket.
67    ///
68    /// If set, this option disables the Nagle algorithm. This means that segments are always sent as soon as possible,
69    /// even if there is only a small amount of data. When not set, data is buffered until there is a sufficient amount to send out,
70    /// thereby avoiding the frequent sending of small packets.
71    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
72        self.stream.set_nodelay(nodelay)
73    }
74
75    /// Gets the value of the `IP_TTL` option for this socket.
76    ///
77    /// For more information about this option, see [`set_ttl`](#method.set_ttl).
78    pub fn ttl(&self) -> std::io::Result<u32> {
79        self.stream.ttl()
80    }
81
82    /// Sets the value for the `IP_TTL` option on this socket.
83    ///
84    /// This value sets the time-to-live field that is used in every packet sent from this socket.
85    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
86        self.stream.set_ttl(ttl)
87    }
88}
89
/// State marker for [`UdpAssociate`].
///
/// NOTE(review): presumably the state of a connection before
/// `UdpAssociate::reply` has been called; the place where it is
/// constructed is outside this file, so confirm against the caller.
#[derive(Debug, Default)]
pub struct NeedReply;
92
/// State marker for [`UdpAssociate`] after the reply has been written.
///
/// `UdpAssociate::reply` returns the connection in this state; only
/// `UdpAssociate<Ready>` offers `wait_until_closed` and the
/// `Deref` / `AsyncRead` / `AsyncWrite` implementations.
#[derive(Debug, Default)]
pub struct Ready;
95
96impl UdpAssociate<Ready> {
97    /// Wait until the client closes this TCP connection.
98    ///
99    /// Socks5 protocol defines that when the client closes the TCP connection used to send the associate command,
100    /// the server should release the associated UDP socket.
101    pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
102        loop {
103            match self.stream.read(&mut [0]).await {
104                Ok(0) => break Ok(()),
105                Ok(_) => {}
106                Err(err) => break Err(err),
107            }
108        }
109    }
110}
111
112impl std::ops::Deref for UdpAssociate<Ready> {
113    type Target = TcpStream;
114
115    #[inline]
116    fn deref(&self) -> &Self::Target {
117        &self.stream
118    }
119}
120
121impl std::ops::DerefMut for UdpAssociate<Ready> {
122    #[inline]
123    fn deref_mut(&mut self) -> &mut Self::Target {
124        &mut self.stream
125    }
126}
127
128impl AsyncRead for UdpAssociate<Ready> {
129    #[inline]
130    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
131        Pin::new(&mut self.stream).poll_read(cx, buf)
132    }
133}
134
135impl AsyncWrite for UdpAssociate<Ready> {
136    #[inline]
137    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
138        Pin::new(&mut self.stream).poll_write(cx, buf)
139    }
140
141    #[inline]
142    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
143        Pin::new(&mut self.stream).poll_flush(cx)
144    }
145
146    #[inline]
147    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
148        Pin::new(&mut self.stream).poll_shutdown(cx)
149    }
150}
151
152impl<S> From<UdpAssociate<S>> for TcpStream {
153    #[inline]
154    fn from(conn: UdpAssociate<S>) -> Self {
155        conn.stream
156    }
157}
158
159/// This is a helper for managing the associated UDP socket.
160///
161/// It will add the socks5 UDP header to every UDP packet it sends, also try to parse the socks5 UDP header from any UDP packet received.
162///
163/// The receiving buffer size for each UDP packet can be set with [`set_recv_buffer_size()`](#method.set_recv_buffer_size),
164/// and be read with [`get_max_packet_size()`](#method.get_recv_buffer_size).
165///
166/// You can create this struct by using [`AssociatedUdpSocket::from::<(UdpSocket, usize)>()`](#impl-From<UdpSocket>),
167/// the first element of the tuple is the UDP socket, the second element is the receiving buffer size.
168///
169/// This struct can also be revert into a raw tokio UDP socket with [`UdpSocket::from::<AssociatedUdpSocket>()`](#impl-From<AssociatedUdpSocket>).
170///
171/// [`AssociatedUdpSocket`] can be used as the associated UDP socket.
172#[derive(Debug)]
173pub struct AssociatedUdpSocket {
174    socket: UdpSocket,
175    buf_size: AtomicUsize,
176}
177
178impl AssociatedUdpSocket {
179    /// Connects the UDP socket setting the default destination for send() and limiting packets that are read via recv from the address specified in addr.
180    #[inline]
181    pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
182        self.socket.connect(addr).await
183    }
184
185    /// Get the maximum UDP packet size, with socks5 UDP header included.
186    pub fn get_max_packet_size(&self) -> usize {
187        self.buf_size.load(Ordering::Relaxed)
188    }
189
190    /// Set the maximum UDP packet size, with socks5 UDP header included, for adjusting the receiving buffer size.
191    pub fn set_max_packet_size(&self, size: usize) {
192        self.buf_size.store(size, Ordering::Release);
193    }
194
195    /// Receives a socks5 UDP relay packet on the socket from the remote address to which it is connected.
196    /// On success, returns the packet itself, the fragment number and the remote target address.
197    ///
198    /// The [`connect`](#method.connect) method will connect this socket to a remote address.
199    /// This method will fail if the socket is not connected.
200    pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
201        loop {
202            let max_packet_size = self.buf_size.load(Ordering::Acquire);
203            let mut buf = vec![0; max_packet_size];
204            let len = self.socket.recv(&mut buf).await?;
205            buf.truncate(len);
206            let pkt = Bytes::from(buf);
207
208            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
209                let pkt = pkt.slice(header.len()..);
210                return Ok((pkt, header.frag, header.address));
211            }
212        }
213    }
214
215    /// Receives a socks5 UDP relay packet on the socket from the any remote address.
216    /// On success, returns the packet itself, the fragment number, the remote target address and the source address.
217    pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
218        loop {
219            let max_packet_size = self.buf_size.load(Ordering::Acquire);
220            let mut buf = vec![0; max_packet_size];
221            let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
222            buf.truncate(len);
223            let pkt = Bytes::from(buf);
224
225            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
226                let pkt = pkt.slice(header.len()..);
227                return Ok((pkt, header.frag, header.address, src_addr));
228            }
229        }
230    }
231
232    /// Sends a UDP relay packet to the remote address to which it is connected. The socks5 UDP header will be added to the packet.
233    pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
234        let header = UdpHeader::new(frag, from_addr);
235        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
236        header.write_to_buf(&mut buf);
237        buf.extend_from_slice(pkt.as_ref());
238
239        self.socket.send(&buf).await.map(|len| len - header.len())
240    }
241
242    /// Sends a UDP relay packet to a specified remote address to which it is connected. The socks5 UDP header will be added to the packet.
243    pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
244        let header = UdpHeader::new(frag, from_addr);
245        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
246        header.write_to_buf(&mut buf);
247        buf.extend_from_slice(pkt.as_ref());
248
249        self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
250    }
251}
252
253impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
254    #[inline]
255    fn from(from: (UdpSocket, usize)) -> Self {
256        AssociatedUdpSocket {
257            socket: from.0,
258            buf_size: AtomicUsize::new(from.1),
259        }
260    }
261}
262
263impl From<AssociatedUdpSocket> for UdpSocket {
264    #[inline]
265    fn from(from: AssociatedUdpSocket) -> Self {
266        from.socket
267    }
268}
269
270impl AsRef<UdpSocket> for AssociatedUdpSocket {
271    #[inline]
272    fn as_ref(&self) -> &UdpSocket {
273        &self.socket
274    }
275}
276
277impl AsMut<UdpSocket> for AssociatedUdpSocket {
278    #[inline]
279    fn as_mut(&mut self) -> &mut UdpSocket {
280        &mut self.socket
281    }
282}