socks5_server/connection/bind.rs

1//! Socks5 command type `Bind`
2
use socks5_proto::{Address, Reply, Response};
use std::{
    io::{Error, IoSlice},
    marker::PhantomData,
    net::SocketAddr,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
    net::TcpStream,
};
15
/// Connection state types.
///
/// These zero-sized markers are used as the type parameter of `Bind` to track
/// how many replies have been sent to the client so far.
pub mod state {
    /// No reply has been sent to the client yet.
    #[derive(Debug)]
    pub struct NeedFirstReply;

    /// The first reply has been sent; the second one is still pending.
    #[derive(Debug)]
    pub struct NeedSecondReply;

    /// Both replies have been sent; `Bind` in this state implements
    /// `AsyncRead` and `AsyncWrite`.
    #[derive(Debug)]
    pub struct Ready;
}
27
28/// Socks5 command type `Bind`
29///
30/// Reply the client 2 times with [`Bind::reply()`] to complete the command negotiation.
31#[derive(Debug)]
32pub struct Bind<S> {
33    stream: TcpStream,
34    _state: PhantomData<S>,
35}
36
37impl Bind<state::NeedFirstReply> {
38    /// Reply to the SOCKS5 client with the given reply and address.
39    ///
40    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
41    pub async fn reply(
42        mut self,
43        reply: Reply,
44        addr: Address,
45    ) -> Result<Bind<state::NeedSecondReply>, (Error, TcpStream)> {
46        let resp = Response::new(reply, addr);
47
48        if let Err(err) = resp.write_to(&mut self.stream).await {
49            return Err((err, self.stream));
50        }
51
52        Ok(Bind::new(self.stream))
53    }
54}
55
56impl Bind<state::NeedSecondReply> {
57    /// Reply to the SOCKS5 client with the given reply and address.
58    ///
59    /// If encountered an error while writing the reply, the error alongside the original `TcpStream` is returned.
60    pub async fn reply(
61        mut self,
62        reply: Reply,
63        addr: Address,
64    ) -> Result<Bind<state::Ready>, (Error, TcpStream)> {
65        let resp = Response::new(reply, addr);
66
67        if let Err(err) = resp.write_to(&mut self.stream).await {
68            return Err((err, self.stream));
69        }
70
71        Ok(Bind::new(self.stream))
72    }
73}
74
75impl<S> Bind<S> {
76    #[inline]
77    pub(super) fn new(stream: TcpStream) -> Self {
78        Self {
79            stream,
80            _state: PhantomData,
81        }
82    }
83
84    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
85    #[inline]
86    pub async fn close(&mut self) -> Result<(), Error> {
87        self.stream.shutdown().await
88    }
89
90    /// Returns the local address that this stream is bound to.
91    #[inline]
92    pub fn local_addr(&self) -> Result<SocketAddr, Error> {
93        self.stream.local_addr()
94    }
95
96    /// Returns the remote address that this stream is connected to.
97    #[inline]
98    pub fn peer_addr(&self) -> Result<SocketAddr, Error> {
99        self.stream.peer_addr()
100    }
101
102    /// Returns a shared reference to the underlying stream.
103    ///
104    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
105    #[inline]
106    pub fn get_ref(&self) -> &TcpStream {
107        &self.stream
108    }
109
110    /// Returns a mutable reference to the underlying stream.
111    ///
112    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
113    #[inline]
114    pub fn get_mut(&mut self) -> &mut TcpStream {
115        &mut self.stream
116    }
117
118    /// Consumes the [`Bind<S>`] and returns the underlying [`TcpStream`](tokio::net::TcpStream).
119    #[inline]
120    pub fn into_inner(self) -> TcpStream {
121        self.stream
122    }
123}
124
125impl AsyncRead for Bind<state::Ready> {
126    #[inline]
127    fn poll_read(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        buf: &mut ReadBuf<'_>,
131    ) -> Poll<Result<(), Error>> {
132        Pin::new(&mut self.stream).poll_read(cx, buf)
133    }
134}
135
136impl AsyncWrite for Bind<state::Ready> {
137    #[inline]
138    fn poll_write(
139        mut self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141        buf: &[u8],
142    ) -> Poll<Result<usize, Error>> {
143        Pin::new(&mut self.stream).poll_write(cx, buf)
144    }
145
146    #[inline]
147    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
148        Pin::new(&mut self.stream).poll_flush(cx)
149    }
150
151    #[inline]
152    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
153        Pin::new(&mut self.stream).poll_shutdown(cx)
154    }
155}