socks5_impl/server/connection/
bind.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response};
2use std::{
3 marker::PhantomData,
4 net::{SocketAddr, ToSocketAddrs},
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::{
9 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
10 net::{
11 TcpListener, TcpStream,
12 tcp::{ReadHalf, WriteHalf},
13 },
14};
15
/// Server-side handle for a SOCKS5 BIND request, tracked as a typestate.
///
/// `S` is one of [`NeedFirstReply`], [`NeedSecondReply`] or [`Ready`]. Each
/// state only exposes the operations valid at that point of the exchange, and
/// sending a reply consumes the handle and returns it in the next state.
#[derive(Debug)]
pub struct Bind<S> {
    // The connection that BIND replies are written to; presumably the
    // client's control connection (TODO confirm against the caller of `new`).
    stream: TcpStream,
    // Zero-sized marker carrying the handshake state; never read at runtime.
    _state: PhantomData<S>,
}
30
/// Typestate marker: the BIND request was received and no reply has been sent yet.
#[derive(Debug, Default)]
pub struct NeedFirstReply;
34
/// Typestate marker: the first BIND reply was sent; the second reply is still pending.
#[derive(Debug, Default)]
pub struct NeedSecondReply;
38
/// Typestate marker: both BIND replies were sent; the wrapped stream is exposed
/// for direct reading and writing.
#[derive(Debug, Default)]
pub struct Ready;
42
43impl Bind<NeedFirstReply> {
44 #[inline]
45 pub(super) fn new(stream: TcpStream) -> Self {
46 Self {
47 stream,
48 _state: PhantomData,
49 }
50 }
51
52 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<Bind<NeedSecondReply>> {
56 let resp = Response::new(reply, addr);
57 resp.write_to_async_stream(&mut self.stream).await?;
58 Ok(Bind::<NeedSecondReply>::new(self.stream))
59 }
60
61 pub async fn accept(self, bind_addr: Address) -> std::io::Result<(Bind<NeedSecondReply>, TcpStream)> {
67 let bind_addr = bind_addr
68 .to_socket_addrs()?
69 .next()
70 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid bind address"))?;
71
72 let listener = match TcpListener::bind(bind_addr).await {
73 Ok(listener) => listener,
74 Err(err) => {
75 let _ = self.reply(Reply::GeneralFailure, Address::unspecified()).await;
76 return Err(err);
77 }
78 };
79
80 let local_addr = listener.local_addr()?;
81 let bind = self.reply(Reply::Succeeded, Address::from(local_addr)).await?;
82 let (incoming, _) = listener.accept().await?;
83 Ok((bind, incoming))
84 }
85
86 pub async fn bind(self, bind_addr: Address) -> std::io::Result<(Bind<Ready>, TcpStream)> {
88 let (bind, incoming) = self.accept(bind_addr).await?;
89 let remote_addr = incoming.peer_addr()?;
90 let conn = match bind.reply(Reply::Succeeded, Address::from(remote_addr)).await {
91 Ok(conn) => conn,
92 Err((err, _stream)) => return Err(err),
93 };
94 Ok((conn, incoming))
95 }
96
97 #[inline]
99 pub async fn shutdown(&mut self) -> std::io::Result<()> {
100 self.stream.shutdown().await
101 }
102
103 #[inline]
105 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
106 self.stream.local_addr()
107 }
108
109 #[inline]
111 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
112 self.stream.peer_addr()
113 }
114
115 #[inline]
119 pub fn nodelay(&self) -> std::io::Result<bool> {
120 self.stream.nodelay()
121 }
122
123 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
129 self.stream.set_nodelay(nodelay)
130 }
131
132 pub fn ttl(&self) -> std::io::Result<u32> {
136 self.stream.ttl()
137 }
138
139 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
143 self.stream.set_ttl(ttl)
144 }
145}
146
147impl Bind<NeedSecondReply> {
148 #[inline]
149 fn new(stream: TcpStream) -> Self {
150 Self {
151 stream,
152 _state: PhantomData,
153 }
154 }
155
156 pub async fn reply(mut self, reply: Reply, addr: Address) -> Result<Bind<Ready>, (std::io::Error, TcpStream)> {
160 let resp = Response::new(reply, addr);
161
162 if let Err(err) = resp.write_to_async_stream(&mut self.stream).await {
163 return Err((err, self.stream));
164 }
165
166 Ok(Bind::<Ready>::new(self.stream))
167 }
168
169 #[inline]
171 pub async fn shutdown(&mut self) -> std::io::Result<()> {
172 self.stream.shutdown().await
173 }
174
175 #[inline]
177 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
178 self.stream.local_addr()
179 }
180
181 #[inline]
183 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
184 self.stream.peer_addr()
185 }
186
187 #[inline]
192 pub fn nodelay(&self) -> std::io::Result<bool> {
193 self.stream.nodelay()
194 }
195
196 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
202 self.stream.set_nodelay(nodelay)
203 }
204
205 pub fn ttl(&self) -> std::io::Result<u32> {
209 self.stream.ttl()
210 }
211
212 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
216 self.stream.set_ttl(ttl)
217 }
218}
219
220impl Bind<Ready> {
221 #[inline]
222 fn new(stream: TcpStream) -> Self {
223 Self {
224 stream,
225 _state: PhantomData,
226 }
227 }
228
229 #[inline]
231 pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
232 self.stream.split()
233 }
234}
235
236impl std::ops::Deref for Bind<Ready> {
237 type Target = TcpStream;
238
239 #[inline]
240 fn deref(&self) -> &Self::Target {
241 &self.stream
242 }
243}
244
245impl std::ops::DerefMut for Bind<Ready> {
246 #[inline]
247 fn deref_mut(&mut self) -> &mut Self::Target {
248 &mut self.stream
249 }
250}
251
252impl AsyncRead for Bind<Ready> {
253 #[inline]
254 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
255 Pin::new(&mut self.stream).poll_read(cx, buf)
256 }
257}
258
259impl AsyncWrite for Bind<Ready> {
260 #[inline]
261 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
262 Pin::new(&mut self.stream).poll_write(cx, buf)
263 }
264
265 #[inline]
266 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
267 Pin::new(&mut self.stream).poll_flush(cx)
268 }
269
270 #[inline]
271 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
272 Pin::new(&mut self.stream).poll_shutdown(cx)
273 }
274}
275
276impl<S> From<Bind<S>> for TcpStream {
277 #[inline]
278 fn from(conn: Bind<S>) -> Self {
279 conn.stream
280 }
281}