// socks5_impl/server/connection/associate.rs

1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response, StreamOperation, UdpHeader};
2use bytes::{Bytes, BytesMut};
3use std::{
4    net::SocketAddr,
5    pin::Pin,
6    sync::atomic::{AtomicUsize, Ordering},
7    task::{Context, Poll},
8    time::Duration,
9};
10use tokio::{
11    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
12    net::{TcpStream, ToSocketAddrs, UdpSocket},
13};
14
15/// Socks5 connection type `UdpAssociate`
16#[derive(Debug)]
17pub struct UdpAssociate<S> {
18    stream: TcpStream,
19    _state: S,
20}
21
22impl<S: Default> UdpAssociate<S> {
23    #[inline]
24    pub(super) fn new(stream: TcpStream) -> Self {
25        Self {
26            stream,
27            _state: S::default(),
28        }
29    }
30
31    /// Reply to the SOCKS5 client with the given reply and address.
32    ///
33    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
34    pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<UdpAssociate<Ready>> {
35        let resp = Response::new(reply, addr);
36        resp.write_to_async_stream(&mut self.stream).await?;
37        Ok(UdpAssociate::<Ready>::new(self.stream))
38    }
39
40    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
41    #[inline]
42    pub async fn shutdown(&mut self) -> std::io::Result<()> {
43        self.stream.shutdown().await
44    }
45
46    /// Returns the local address that this stream is bound to.
47    #[inline]
48    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
49        self.stream.local_addr()
50    }
51
52    /// Returns the remote address that this stream is connected to.
53    #[inline]
54    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
55        self.stream.peer_addr()
56    }
57
58    /// Reads the linger duration for this socket by getting the `SO_LINGER` option.
59    ///
60    /// For more information about this option, see [`set_linger`](#method.set_linger).
61    #[inline]
62    pub fn linger(&self) -> std::io::Result<Option<Duration>> {
63        self.stream.linger()
64    }
65
66    /// Sets the linger duration of this socket by setting the `SO_LINGER` option.
67    ///
68    /// This option controls the action taken when a stream has unsent messages and the stream is closed. If `SO_LINGER` is set,
69    /// the system shall block the process until it can transmit the data or until the time expires.
70    ///
71    /// If `SO_LINGER` is not specified, and the stream is closed, the system handles the call in a way
72    /// that allows the process to continue as quickly as possible.
73    #[inline]
74    pub fn set_linger(&self, dur: Option<Duration>) -> std::io::Result<()> {
75        self.stream.set_linger(dur)
76    }
77
78    /// Gets the value of the `TCP_NODELAY` option on this socket.
79    ///
80    /// For more information about this option, see [`set_nodelay`](#method.set_nodelay).
81    #[inline]
82    pub fn nodelay(&self) -> std::io::Result<bool> {
83        self.stream.nodelay()
84    }
85
86    /// Sets the value of the `TCP_NODELAY` option on this socket.
87    ///
88    /// If set, this option disables the Nagle algorithm. This means that segments are always sent as soon as possible,
89    /// even if there is only a small amount of data. When not set, data is buffered until there is a sufficient amount to send out,
90    /// thereby avoiding the frequent sending of small packets.
91    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
92        self.stream.set_nodelay(nodelay)
93    }
94
95    /// Gets the value of the `IP_TTL` option for this socket.
96    ///
97    /// For more information about this option, see [`set_ttl`](#method.set_ttl).
98    pub fn ttl(&self) -> std::io::Result<u32> {
99        self.stream.ttl()
100    }
101
102    /// Sets the value for the `IP_TTL` option on this socket.
103    ///
104    /// This value sets the time-to-live field that is used in every packet sent from this socket.
105    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
106        self.stream.set_ttl(ttl)
107    }
108}
109
/// State marker for an [`UdpAssociate`] connection.
///
/// NOTE(review): judging by the name, this marks a connection that has not been answered with
/// [`UdpAssociate::reply`] yet; the code that assigns this state is outside this file — confirm.
#[derive(Debug, Default)]
pub struct NeedReply;
112
/// State marker for an [`UdpAssociate`] connection whose reply has been written.
///
/// [`UdpAssociate::reply`] returns the connection in this state on success.
#[derive(Debug, Default)]
pub struct Ready;
115
116impl UdpAssociate<Ready> {
117    /// Wait until the client closes this TCP connection.
118    ///
119    /// Socks5 protocol defines that when the client closes the TCP connection used to send the associate command,
120    /// the server should release the associated UDP socket.
121    pub async fn wait_until_closed(&mut self) -> std::io::Result<()> {
122        loop {
123            match self.stream.read(&mut [0]).await {
124                Ok(0) => break Ok(()),
125                Ok(_) => {}
126                Err(err) => break Err(err),
127            }
128        }
129    }
130}
131
132impl std::ops::Deref for UdpAssociate<Ready> {
133    type Target = TcpStream;
134
135    #[inline]
136    fn deref(&self) -> &Self::Target {
137        &self.stream
138    }
139}
140
141impl std::ops::DerefMut for UdpAssociate<Ready> {
142    #[inline]
143    fn deref_mut(&mut self) -> &mut Self::Target {
144        &mut self.stream
145    }
146}
147
148impl AsyncRead for UdpAssociate<Ready> {
149    #[inline]
150    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
151        Pin::new(&mut self.stream).poll_read(cx, buf)
152    }
153}
154
155impl AsyncWrite for UdpAssociate<Ready> {
156    #[inline]
157    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
158        Pin::new(&mut self.stream).poll_write(cx, buf)
159    }
160
161    #[inline]
162    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
163        Pin::new(&mut self.stream).poll_flush(cx)
164    }
165
166    #[inline]
167    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
168        Pin::new(&mut self.stream).poll_shutdown(cx)
169    }
170}
171
172impl<S> From<UdpAssociate<S>> for TcpStream {
173    #[inline]
174    fn from(conn: UdpAssociate<S>) -> Self {
175        conn.stream
176    }
177}
178
179/// This is a helper for managing the associated UDP socket.
180///
181/// It will add the socks5 UDP header to every UDP packet it sends, also try to parse the socks5 UDP header from any UDP packet received.
182///
183/// The receiving buffer size for each UDP packet can be set with [`set_recv_buffer_size()`](#method.set_recv_buffer_size),
184/// and be read with [`get_max_packet_size()`](#method.get_recv_buffer_size).
185///
186/// You can create this struct by using [`AssociatedUdpSocket::from::<(UdpSocket, usize)>()`](#impl-From<UdpSocket>),
187/// the first element of the tuple is the UDP socket, the second element is the receiving buffer size.
188///
189/// This struct can also be revert into a raw tokio UDP socket with [`UdpSocket::from::<AssociatedUdpSocket>()`](#impl-From<AssociatedUdpSocket>).
190///
191/// [`AssociatedUdpSocket`] can be used as the associated UDP socket.
192#[derive(Debug)]
193pub struct AssociatedUdpSocket {
194    socket: UdpSocket,
195    buf_size: AtomicUsize,
196}
197
198impl AssociatedUdpSocket {
199    /// Connects the UDP socket setting the default destination for send() and limiting packets that are read via recv from the address specified in addr.
200    #[inline]
201    pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> std::io::Result<()> {
202        self.socket.connect(addr).await
203    }
204
205    /// Get the maximum UDP packet size, with socks5 UDP header included.
206    pub fn get_max_packet_size(&self) -> usize {
207        self.buf_size.load(Ordering::Relaxed)
208    }
209
210    /// Set the maximum UDP packet size, with socks5 UDP header included, for adjusting the receiving buffer size.
211    pub fn set_max_packet_size(&self, size: usize) {
212        self.buf_size.store(size, Ordering::Release);
213    }
214
215    /// Receives a socks5 UDP relay packet on the socket from the remote address to which it is connected.
216    /// On success, returns the packet itself, the fragment number and the remote target address.
217    ///
218    /// The [`connect`](#method.connect) method will connect this socket to a remote address.
219    /// This method will fail if the socket is not connected.
220    pub async fn recv(&self) -> std::io::Result<(Bytes, u8, Address)> {
221        loop {
222            let max_packet_size = self.buf_size.load(Ordering::Acquire);
223            let mut buf = vec![0; max_packet_size];
224            let len = self.socket.recv(&mut buf).await?;
225            buf.truncate(len);
226            let pkt = Bytes::from(buf);
227
228            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
229                let pkt = pkt.slice(header.len()..);
230                return Ok((pkt, header.frag, header.address));
231            }
232        }
233    }
234
235    /// Receives a socks5 UDP relay packet on the socket from the any remote address.
236    /// On success, returns the packet itself, the fragment number, the remote target address and the source address.
237    pub async fn recv_from(&self) -> std::io::Result<(Bytes, u8, Address, SocketAddr)> {
238        loop {
239            let max_packet_size = self.buf_size.load(Ordering::Acquire);
240            let mut buf = vec![0; max_packet_size];
241            let (len, src_addr) = self.socket.recv_from(&mut buf).await?;
242            buf.truncate(len);
243            let pkt = Bytes::from(buf);
244
245            if let Ok(header) = UdpHeader::retrieve_from_async_stream(&mut pkt.as_ref()).await {
246                let pkt = pkt.slice(header.len()..);
247                return Ok((pkt, header.frag, header.address, src_addr));
248            }
249        }
250    }
251
252    /// Sends a UDP relay packet to the remote address to which it is connected. The socks5 UDP header will be added to the packet.
253    pub async fn send<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address) -> std::io::Result<usize> {
254        let header = UdpHeader::new(frag, from_addr);
255        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
256        header.write_to_buf(&mut buf);
257        buf.extend_from_slice(pkt.as_ref());
258
259        self.socket.send(&buf).await.map(|len| len - header.len())
260    }
261
262    /// Sends a UDP relay packet to a specified remote address to which it is connected. The socks5 UDP header will be added to the packet.
263    pub async fn send_to<P: AsRef<[u8]>>(&self, pkt: P, frag: u8, from_addr: Address, to_addr: SocketAddr) -> std::io::Result<usize> {
264        let header = UdpHeader::new(frag, from_addr);
265        let mut buf = BytesMut::with_capacity(header.len() + pkt.as_ref().len());
266        header.write_to_buf(&mut buf);
267        buf.extend_from_slice(pkt.as_ref());
268
269        self.socket.send_to(&buf, to_addr).await.map(|len| len - header.len())
270    }
271}
272
273impl From<(UdpSocket, usize)> for AssociatedUdpSocket {
274    #[inline]
275    fn from(from: (UdpSocket, usize)) -> Self {
276        AssociatedUdpSocket {
277            socket: from.0,
278            buf_size: AtomicUsize::new(from.1),
279        }
280    }
281}
282
283impl From<AssociatedUdpSocket> for UdpSocket {
284    #[inline]
285    fn from(from: AssociatedUdpSocket) -> Self {
286        from.socket
287    }
288}
289
290impl AsRef<UdpSocket> for AssociatedUdpSocket {
291    #[inline]
292    fn as_ref(&self) -> &UdpSocket {
293        &self.socket
294    }
295}
296
297impl AsMut<UdpSocket> for AssociatedUdpSocket {
298    #[inline]
299    fn as_mut(&mut self) -> &mut UdpSocket {
300        &mut self.socket
301    }
302}