Skip to main content

socks5_impl/server/connection/
bind.rs

1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response};
2use std::{
3    marker::PhantomData,
4    net::{SocketAddr, ToSocketAddrs},
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::{
9    io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
10    net::{
11        TcpListener, TcpStream,
12        tcp::{ReadHalf, WriteHalf},
13    },
14};
15
16/// Socks5 command type `Bind`
17///
18/// By [`wait_request`](crate::server::connection::Authenticated::wait_request)
19/// on an [`Authenticated`](crate::server::connection::Authenticated) from SOCKS5 client,
20/// you may get a `Bind<NeedFirstReply>`. After replying the client 2 times
21/// using [`reply()`](crate::server::connection::Bind::reply),
22/// you will get a `Bind<Ready>`, which can be used as a regular async TCP stream.
23///
24/// A `Bind<S>` can be converted to a regular tokio [`TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html) by using the `From` trait.
25#[derive(Debug)]
26pub struct Bind<S> {
27    stream: TcpStream,
28    _state: PhantomData<S>,
29}
30
/// Type-state marker: the first BIND reply has not been sent yet.
#[derive(Default, Debug)]
pub struct NeedFirstReply;

/// Type-state marker: the first BIND reply was sent and the second one is still pending.
#[derive(Default, Debug)]
pub struct NeedSecondReply;

/// Type-state marker: both BIND replies were sent, so the connection can be used as a plain TCP stream.
#[derive(Default, Debug)]
pub struct Ready;
42
43impl Bind<NeedFirstReply> {
44    #[inline]
45    pub(super) fn new(stream: TcpStream) -> Self {
46        Self {
47            stream,
48            _state: PhantomData,
49        }
50    }
51
52    /// Reply to the SOCKS5 client with the given reply and address.
53    ///
54    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
55    pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<Bind<NeedSecondReply>> {
56        let resp = Response::new(reply, addr);
57        resp.write_to_async_stream(&mut self.stream).await?;
58        Ok(Bind::<NeedSecondReply>::new(self.stream))
59    }
60
61    /// Accept an incoming connection for SOCKS5 BIND.
62    ///
63    /// This binds a TCP listener to the requested `bind_addr`, sends the first BIND
64    /// reply with the actual bound address, then waits for the remote peer to connect.
65    /// The returned `Bind<NeedSecondReply>` can be used to send the second reply.
66    pub async fn accept(self, bind_addr: Address) -> std::io::Result<(Bind<NeedSecondReply>, TcpStream)> {
67        let bind_addr = bind_addr
68            .to_socket_addrs()?
69            .next()
70            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid bind address"))?;
71
72        let listener = match TcpListener::bind(bind_addr).await {
73            Ok(listener) => listener,
74            Err(err) => {
75                let _ = self.reply(Reply::GeneralFailure, Address::unspecified()).await;
76                return Err(err);
77            }
78        };
79
80        let local_addr = listener.local_addr()?;
81        let bind = self.reply(Reply::Succeeded, Address::from(local_addr)).await?;
82        let (incoming, _) = listener.accept().await?;
83        Ok((bind, incoming))
84    }
85
86    /// Fully complete BIND handling, including first reply, accept, and second reply.
87    pub async fn bind(self, bind_addr: Address) -> std::io::Result<(Bind<Ready>, TcpStream)> {
88        let (bind, incoming) = self.accept(bind_addr).await?;
89        let remote_addr = incoming.peer_addr()?;
90        let conn = match bind.reply(Reply::Succeeded, Address::from(remote_addr)).await {
91            Ok(conn) => conn,
92            Err((err, _stream)) => return Err(err),
93        };
94        Ok((conn, incoming))
95    }
96
97    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
98    #[inline]
99    pub async fn shutdown(&mut self) -> std::io::Result<()> {
100        self.stream.shutdown().await
101    }
102
103    /// Returns the local address that this stream is bound to.
104    #[inline]
105    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
106        self.stream.local_addr()
107    }
108
109    /// Returns the remote address that this stream is connected to.
110    #[inline]
111    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
112        self.stream.peer_addr()
113    }
114
115    /// Gets the value of the `TCP_NODELAY` option on this socket.
116    ///
117    /// For more information about this option, see [`set_nodelay`](crate::server::connection::Bind::set_nodelay).
118    #[inline]
119    pub fn nodelay(&self) -> std::io::Result<bool> {
120        self.stream.nodelay()
121    }
122
123    /// Sets the value of the `TCP_NODELAY` option on this socket.
124    ///
125    /// If set, this option disables the Nagle algorithm. This means that segments are always sent as soon as possible,
126    /// even if there is only a small amount of data. When not set, data is buffered until there is a sufficient amount to send out,
127    /// thereby avoiding the frequent sending of small packets.
128    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
129        self.stream.set_nodelay(nodelay)
130    }
131
132    /// Gets the value of the `IP_TTL` option for this socket.
133    ///
134    /// For more information about this option, see [`set_ttl`](crate::server::connection::Bind::set_ttl).
135    pub fn ttl(&self) -> std::io::Result<u32> {
136        self.stream.ttl()
137    }
138
139    /// Sets the value for the `IP_TTL` option on this socket.
140    ///
141    /// This value sets the time-to-live field that is used in every packet sent from this socket.
142    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
143        self.stream.set_ttl(ttl)
144    }
145}
146
147impl Bind<NeedSecondReply> {
148    #[inline]
149    fn new(stream: TcpStream) -> Self {
150        Self {
151            stream,
152            _state: PhantomData,
153        }
154    }
155
156    /// Reply to the SOCKS5 client with the given reply and address.
157    ///
158    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
159    pub async fn reply(mut self, reply: Reply, addr: Address) -> Result<Bind<Ready>, (std::io::Error, TcpStream)> {
160        let resp = Response::new(reply, addr);
161
162        if let Err(err) = resp.write_to_async_stream(&mut self.stream).await {
163            return Err((err, self.stream));
164        }
165
166        Ok(Bind::<Ready>::new(self.stream))
167    }
168
169    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
170    #[inline]
171    pub async fn shutdown(&mut self) -> std::io::Result<()> {
172        self.stream.shutdown().await
173    }
174
175    /// Returns the local address that this stream is bound to.
176    #[inline]
177    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
178        self.stream.local_addr()
179    }
180
181    /// Returns the remote address that this stream is connected to.
182    #[inline]
183    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
184        self.stream.peer_addr()
185    }
186
187    /// Gets the value of the `TCP_NODELAY` option on this socket.
188    ///
189    /// For more information about this option, see
190    /// [`set_nodelay`](crate::server::connection::Bind::set_nodelay).
191    #[inline]
192    pub fn nodelay(&self) -> std::io::Result<bool> {
193        self.stream.nodelay()
194    }
195
196    /// Sets the value of the `TCP_NODELAY` option on this socket.
197    ///
198    /// If set, this option disables the Nagle algorithm. This means that segments are always sent as soon as possible,
199    /// even if there is only a small amount of data. When not set, data is buffered until there is a sufficient amount to send out,
200    /// thereby avoiding the frequent sending of small packets.
201    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
202        self.stream.set_nodelay(nodelay)
203    }
204
205    /// Gets the value of the `IP_TTL` option for this socket.
206    ///
207    /// For more information about this option, see [`set_ttl`](crate::server::connection::Bind::set_ttl).
208    pub fn ttl(&self) -> std::io::Result<u32> {
209        self.stream.ttl()
210    }
211
212    /// Sets the value for the `IP_TTL` option on this socket.
213    ///
214    /// This value sets the time-to-live field that is used in every packet sent from this socket.
215    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
216        self.stream.set_ttl(ttl)
217    }
218}
219
220impl Bind<Ready> {
221    #[inline]
222    fn new(stream: TcpStream) -> Self {
223        Self {
224            stream,
225            _state: PhantomData,
226        }
227    }
228
229    /// Split the connection into a read and a write half.
230    #[inline]
231    pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
232        self.stream.split()
233    }
234}
235
236impl std::ops::Deref for Bind<Ready> {
237    type Target = TcpStream;
238
239    #[inline]
240    fn deref(&self) -> &Self::Target {
241        &self.stream
242    }
243}
244
245impl std::ops::DerefMut for Bind<Ready> {
246    #[inline]
247    fn deref_mut(&mut self) -> &mut Self::Target {
248        &mut self.stream
249    }
250}
251
252impl AsyncRead for Bind<Ready> {
253    #[inline]
254    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
255        Pin::new(&mut self.stream).poll_read(cx, buf)
256    }
257}
258
259impl AsyncWrite for Bind<Ready> {
260    #[inline]
261    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
262        Pin::new(&mut self.stream).poll_write(cx, buf)
263    }
264
265    #[inline]
266    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
267        Pin::new(&mut self.stream).poll_flush(cx)
268    }
269
270    #[inline]
271    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
272        Pin::new(&mut self.stream).poll_shutdown(cx)
273    }
274}
275
276impl<S> From<Bind<S>> for TcpStream {
277    #[inline]
278    fn from(conn: Bind<S>) -> Self {
279        conn.stream
280    }
281}