socks5_server/connection/
mod.rs1use self::{associate::Associate, bind::Bind, connect::Connect};
4use crate::AuthAdaptor;
5use socks5_proto::{
6 handshake::{
7 Method as HandshakeMethod, Request as HandshakeRequest, Response as HandshakeResponse,
8 },
9 Address, Command as ProtocolCommand, Error, ProtocolError, Request,
10};
11use std::{fmt::Debug, io::Error as IoError, marker::PhantomData, net::SocketAddr};
12use tokio::{io::AsyncWriteExt, net::TcpStream};
13
14pub mod associate;
15pub mod bind;
16pub mod connect;
17
/// Zero-sized type-state markers describing how far an `IncomingConnection` has progressed
/// through the SOCKS5 handshake.
pub mod state {
    /// Marker: the client's method-selection handshake has not been processed yet.
    /// The next step is `IncomingConnection::authenticate`.
    #[derive(Debug)]
    pub struct NeedAuthenticate;

    /// Marker: authentication has completed and the client's command request is expected.
    /// The next step is `IncomingConnection::wait`.
    #[derive(Debug)]
    pub struct NeedCommand;
}
26
27pub struct IncomingConnection<A, S> {
31 stream: TcpStream,
32 auth: AuthAdaptor<A>,
33 _state: PhantomData<S>,
34}
35
36impl<A> IncomingConnection<A, state::NeedAuthenticate> {
37 pub async fn authenticate(
43 mut self,
44 ) -> Result<(IncomingConnection<A, state::NeedCommand>, A), (Error, TcpStream)> {
45 let req = match HandshakeRequest::read_from(&mut self.stream).await {
46 Ok(req) => req,
47 Err(err) => return Err((err, self.stream)),
48 };
49 let chosen_method = self.auth.as_handshake_method();
50
51 if req.methods.contains(&chosen_method) {
52 let resp = HandshakeResponse::new(chosen_method);
53
54 if let Err(err) = resp.write_to(&mut self.stream).await {
55 return Err((Error::Io(err), self.stream));
56 }
57
58 let output = self.auth.execute(&mut self.stream).await;
59
60 Ok((IncomingConnection::new(self.stream, self.auth), output))
61 } else {
62 let resp = HandshakeResponse::new(HandshakeMethod::UNACCEPTABLE);
63
64 if let Err(err) = resp.write_to(&mut self.stream).await {
65 return Err((Error::Io(err), self.stream));
66 }
67
68 Err((
69 Error::Protocol(ProtocolError::NoAcceptableHandshakeMethod {
70 version: socks5_proto::SOCKS_VERSION,
71 chosen_method,
72 methods: req.methods,
73 }),
74 self.stream,
75 ))
76 }
77 }
78}
79
80impl<A> IncomingConnection<A, state::NeedCommand> {
81 pub async fn wait(mut self) -> Result<Command, (Error, TcpStream)> {
89 let req = match Request::read_from(&mut self.stream).await {
90 Ok(req) => req,
91 Err(err) => return Err((err, self.stream)),
92 };
93
94 match req.command {
95 ProtocolCommand::Associate => {
96 Ok(Command::Associate(Associate::new(self.stream), req.address))
97 }
98 ProtocolCommand::Bind => Ok(Command::Bind(Bind::new(self.stream), req.address)),
99 ProtocolCommand::Connect => {
100 Ok(Command::Connect(Connect::new(self.stream), req.address))
101 }
102 }
103 }
104}
105
106impl<A, S> IncomingConnection<A, S> {
107 #[inline]
108 pub(crate) fn new(stream: TcpStream, auth: AuthAdaptor<A>) -> Self {
109 Self {
110 stream,
111 auth,
112 _state: PhantomData,
113 }
114 }
115
116 #[inline]
118 pub async fn close(&mut self) -> Result<(), IoError> {
119 self.stream.shutdown().await
120 }
121
122 #[inline]
124 pub fn local_addr(&self) -> Result<SocketAddr, IoError> {
125 self.stream.local_addr()
126 }
127
128 #[inline]
130 pub fn peer_addr(&self) -> Result<SocketAddr, IoError> {
131 self.stream.peer_addr()
132 }
133
134 #[inline]
138 pub fn get_ref(&self) -> &TcpStream {
139 &self.stream
140 }
141
142 #[inline]
146 pub fn get_mut(&mut self) -> &mut TcpStream {
147 &mut self.stream
148 }
149
150 #[inline]
152 pub fn into_inner(self) -> TcpStream {
153 self.stream
154 }
155}
156
157impl<A, S> Debug for IncomingConnection<A, S> {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("IncomingConnection")
160 .field("stream", &self.stream)
161 .finish()
162 }
163}
164
165#[derive(Debug)]
167pub enum Command {
168 Associate(Associate<associate::state::NeedReply>, Address),
169 Bind(Bind<bind::state::NeedFirstReply>, Address),
170 Connect(Connect<connect::state::NeedReply>, Address),
171}