socks5_impl/server/connection/
mod.rs1use self::{associate::UdpAssociate, bind::Bind, connect::Connect};
2use crate::{
3 protocol::{self, Address, AsyncStreamOperation, AuthMethod, Command, handshake},
4 server::AuthAdaptor,
5};
6use std::{net::SocketAddr, time::Duration};
7use tokio::{io::AsyncWriteExt, net::TcpStream};
8
9pub mod associate;
10pub mod bind;
11pub mod connect;
12
13pub struct IncomingConnection {
16 stream: TcpStream,
17 auth: AuthAdaptor,
18}
19
20impl IncomingConnection {
21 #[inline]
22 pub(crate) fn new(stream: TcpStream, auth: AuthAdaptor) -> Self {
23 IncomingConnection { stream, auth }
24 }
25
26 #[inline]
28 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
29 self.stream.local_addr()
30 }
31
32 #[inline]
34 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
35 self.stream.peer_addr()
36 }
37
38 #[inline]
40 pub async fn shutdown(&mut self) -> std::io::Result<()> {
41 self.stream.shutdown().await
42 }
43
44 #[inline]
49 pub fn nodelay(&self) -> std::io::Result<bool> {
50 self.stream.nodelay()
51 }
52
53 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
59 self.stream.set_nodelay(nodelay)
60 }
61
62 pub fn ttl(&self) -> std::io::Result<u32> {
67 self.stream.ttl()
68 }
69
70 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
74 self.stream.set_ttl(ttl)
75 }
76
77 pub async fn authenticate_with_timeout(self, timeout: Duration) -> crate::Result<Authenticated> {
79 tokio::time::timeout(timeout, self.authenticate())
80 .await
81 .map_err(|_| crate::Error::String("handshake timeout".into()))?
82 }
83
84 pub async fn authenticate(mut self) -> crate::Result<Authenticated> {
92 let request = handshake::Request::retrieve_from_async_stream(&mut self.stream).await?;
93 if let Some(method) = self.evaluate_request(&request) {
94 let response = handshake::Response::new(method);
95 response.write_to_async_stream(&mut self.stream).await?;
96 if !self.auth.execute(&mut self.stream).await? {
97 use std::io::{Error, ErrorKind::PermissionDenied};
98 return Err(crate::Error::Io(Error::new(PermissionDenied, "authentication failed")));
99 }
100 Ok(Authenticated::new(self.stream))
101 } else {
102 let response = handshake::Response::new(AuthMethod::NoAcceptableMethods);
103 response.write_to_async_stream(&mut self.stream).await?;
104 let err = "No available handshake method provided by client";
105 Err(crate::Error::Io(std::io::Error::new(std::io::ErrorKind::Unsupported, err)))
106 }
107 }
108
109 fn evaluate_request(&self, req: &handshake::Request) -> Option<AuthMethod> {
110 let method = self.auth.auth_method();
111 if req.evaluate_method(method) { Some(method) } else { None }
112 }
113}
114
115impl std::fmt::Debug for IncomingConnection {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("IncomingConnection").field("stream", &self.stream).finish()
118 }
119}
120
121impl From<IncomingConnection> for TcpStream {
122 #[inline]
123 fn from(conn: IncomingConnection) -> Self {
124 conn.stream
125 }
126}
127
128pub struct Authenticated(TcpStream);
135
136impl Authenticated {
137 #[inline]
138 fn new(stream: TcpStream) -> Self {
139 Self(stream)
140 }
141
142 pub async fn wait_request(mut self) -> crate::Result<ClientConnection> {
150 let req = protocol::Request::retrieve_from_async_stream(&mut self.0).await?;
151
152 match req.command {
153 Command::UdpAssociate => Ok(ClientConnection::UdpAssociate(
154 UdpAssociate::<associate::NeedReply>::new(self.0),
155 req.address,
156 )),
157 Command::Bind => Ok(ClientConnection::Bind(Bind::<bind::NeedFirstReply>::new(self.0), req.address)),
158 Command::Connect => Ok(ClientConnection::Connect(Connect::<connect::NeedReply>::new(self.0), req.address)),
159 }
160 }
161
162 #[inline]
164 pub async fn shutdown(&mut self) -> std::io::Result<()> {
165 self.0.shutdown().await
166 }
167
168 #[inline]
170 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
171 self.0.local_addr()
172 }
173
174 #[inline]
176 pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
177 self.0.peer_addr()
178 }
179
180 #[inline]
185 pub fn nodelay(&self) -> std::io::Result<bool> {
186 self.0.nodelay()
187 }
188
189 pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
195 self.0.set_nodelay(nodelay)
196 }
197
198 pub fn ttl(&self) -> std::io::Result<u32> {
203 self.0.ttl()
204 }
205
206 pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
210 self.0.set_ttl(ttl)
211 }
212}
213
214impl From<Authenticated> for TcpStream {
215 #[inline]
216 fn from(conn: Authenticated) -> Self {
217 conn.0
218 }
219}
220
221#[derive(Debug)]
227pub enum ClientConnection {
228 UdpAssociate(UdpAssociate<associate::NeedReply>, Address),
229 Bind(Bind<bind::NeedFirstReply>, Address),
230 Connect(Connect<connect::NeedReply>, Address),
231}