socks5_impl/server/connection/
bind.rs1use crate::protocol::{Address, AsyncStreamOperation, Reply, Response};
2use std::{
3 marker::PhantomData,
4 net::SocketAddr,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::{
9 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
10 net::{
11 TcpStream,
12 tcp::{ReadHalf, WriteHalf},
13 },
14};
15
16#[derive(Debug)]
26pub struct Bind<S> {
27 stream: TcpStream,
28 _state: PhantomData<S>,
29}
30
/// Type-state marker: the `BIND` request has been received but no reply has been sent yet.
#[derive(Debug, Default)]
pub struct NeedFirstReply;
34
/// Type-state marker: the first `BIND` reply has been sent and the second one is still pending.
#[derive(Debug, Default)]
pub struct NeedSecondReply;
38
/// Type-state marker: both `BIND` replies have been sent and the stream is ready for data.
#[derive(Debug, Default)]
pub struct Ready;
42
43impl Bind<NeedFirstReply> {
44 #[inline]
45 pub(super) fn new(stream: TcpStream) -> Self {
46 Self {
47 stream,
48 _state: PhantomData,
49 }
50 }
51
52 pub async fn reply(mut self, reply: Reply, addr: Address) -> std::io::Result<Bind<NeedSecondReply>> {
56 let resp = Response::new(reply, addr);
57 resp.write_to_async_stream(&mut self.stream).await?;
58 Ok(Bind::<NeedSecondReply>::new(self.stream))
59 }
60
61 #[inline]
63 pub async fn shutdown(&mut self) -> std::io::Result<()> {
64 self.stream.shutdown().await
65 }
66
67 #[inline]
69 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
70 self.stream.local_addr()
71 }
72
73 #[inline]
75 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
76 self.stream.peer_addr()
77 }
78
79 #[inline]
83 pub fn nodelay(&self) -> std::io::Result<bool> {
84 self.stream.nodelay()
85 }
86
87 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
93 self.stream.set_nodelay(nodelay)
94 }
95
96 pub fn ttl(&self) -> std::io::Result<u32> {
100 self.stream.ttl()
101 }
102
103 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
107 self.stream.set_ttl(ttl)
108 }
109}
110
111impl Bind<NeedSecondReply> {
112 #[inline]
113 fn new(stream: TcpStream) -> Self {
114 Self {
115 stream,
116 _state: PhantomData,
117 }
118 }
119
120 pub async fn reply(mut self, reply: Reply, addr: Address) -> Result<Bind<Ready>, (std::io::Error, TcpStream)> {
124 let resp = Response::new(reply, addr);
125
126 if let Err(err) = resp.write_to_async_stream(&mut self.stream).await {
127 return Err((err, self.stream));
128 }
129
130 Ok(Bind::<Ready>::new(self.stream))
131 }
132
133 #[inline]
135 pub async fn shutdown(&mut self) -> std::io::Result<()> {
136 self.stream.shutdown().await
137 }
138
139 #[inline]
141 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
142 self.stream.local_addr()
143 }
144
145 #[inline]
147 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
148 self.stream.peer_addr()
149 }
150
151 #[inline]
156 pub fn nodelay(&self) -> std::io::Result<bool> {
157 self.stream.nodelay()
158 }
159
160 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
166 self.stream.set_nodelay(nodelay)
167 }
168
169 pub fn ttl(&self) -> std::io::Result<u32> {
173 self.stream.ttl()
174 }
175
176 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
180 self.stream.set_ttl(ttl)
181 }
182}
183
184impl Bind<Ready> {
185 #[inline]
186 fn new(stream: TcpStream) -> Self {
187 Self {
188 stream,
189 _state: PhantomData,
190 }
191 }
192
193 #[inline]
195 pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
196 self.stream.split()
197 }
198}
199
200impl std::ops::Deref for Bind<Ready> {
201 type Target = TcpStream;
202
203 #[inline]
204 fn deref(&self) -> &Self::Target {
205 &self.stream
206 }
207}
208
209impl std::ops::DerefMut for Bind<Ready> {
210 #[inline]
211 fn deref_mut(&mut self) -> &mut Self::Target {
212 &mut self.stream
213 }
214}
215
216impl AsyncRead for Bind<Ready> {
217 #[inline]
218 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
219 Pin::new(&mut self.stream).poll_read(cx, buf)
220 }
221}
222
223impl AsyncWrite for Bind<Ready> {
224 #[inline]
225 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
226 Pin::new(&mut self.stream).poll_write(cx, buf)
227 }
228
229 #[inline]
230 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
231 Pin::new(&mut self.stream).poll_flush(cx)
232 }
233
234 #[inline]
235 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
236 Pin::new(&mut self.stream).poll_shutdown(cx)
237 }
238}
239
240impl<S> From<Bind<S>> for TcpStream {
241 #[inline]
242 fn from(conn: Bind<S>) -> Self {
243 conn.stream
244 }
245}