socks5_server/connection/
mod.rs

1//! Connection abstraction of the SOCKS5 protocol
2
3use self::{associate::Associate, bind::Bind, connect::Connect};
4use crate::AuthAdaptor;
5use socks5_proto::{
6    handshake::{
7        Method as HandshakeMethod, Request as HandshakeRequest, Response as HandshakeResponse,
8    },
9    Address, Command as ProtocolCommand, Error, ProtocolError, Request,
10};
11use std::{fmt::Debug, io::Error as IoError, marker::PhantomData, net::SocketAddr};
12use tokio::{io::AsyncWriteExt, net::TcpStream};
13
14pub mod associate;
15pub mod bind;
16pub mod connect;
17
/// Type-state markers describing how far an incoming connection has progressed.
pub mod state {
    /// The connection has been accepted, but the SOCKS5 method-selection handshake
    /// and authentication have not been performed yet.
    #[derive(Debug)]
    pub struct NeedAuthenticate;

    /// Authentication has completed; the connection is waiting for the client's command request.
    #[derive(Debug)]
    pub struct NeedCommand;
}
26
27/// An incoming SOCKS5 connection.
28///
29/// This may not be a valid SOCKS5 connection. You should call [`IncomingConnection::authenticate()`] and [`IncomingConnection::wait()`] to perform a SOCKS5 connection negotiation.
30pub struct IncomingConnection<A, S> {
31    stream: TcpStream,
32    auth: AuthAdaptor<A>,
33    _state: PhantomData<S>,
34}
35
36impl<A> IncomingConnection<A, state::NeedAuthenticate> {
37    /// Perform a SOCKS5 authentication handshake using the given [`Auth`](crate::Auth) adapter.
38    ///
39    /// If the handshake succeeds, an [`IncomingConnection<A, state::NeedCommand>`] alongs with the output of the [`Auth`](crate::Auth) adapter `A` is returned. Otherwise, the error and the underlying [`TcpStream`](tokio::net::TcpStream) is returned.
40    ///
41    /// Note that this method will not implicitly close the connection even if the handshake failed.
42    pub async fn authenticate(
43        mut self,
44    ) -> Result<(IncomingConnection<A, state::NeedCommand>, A), (Error, TcpStream)> {
45        let req = match HandshakeRequest::read_from(&mut self.stream).await {
46            Ok(req) => req,
47            Err(err) => return Err((err, self.stream)),
48        };
49        let chosen_method = self.auth.as_handshake_method();
50
51        if req.methods.contains(&chosen_method) {
52            let resp = HandshakeResponse::new(chosen_method);
53
54            if let Err(err) = resp.write_to(&mut self.stream).await {
55                return Err((Error::Io(err), self.stream));
56            }
57
58            let output = self.auth.execute(&mut self.stream).await;
59
60            Ok((IncomingConnection::new(self.stream, self.auth), output))
61        } else {
62            let resp = HandshakeResponse::new(HandshakeMethod::UNACCEPTABLE);
63
64            if let Err(err) = resp.write_to(&mut self.stream).await {
65                return Err((Error::Io(err), self.stream));
66            }
67
68            Err((
69                Error::Protocol(ProtocolError::NoAcceptableHandshakeMethod {
70                    version: socks5_proto::SOCKS_VERSION,
71                    chosen_method,
72                    methods: req.methods,
73                }),
74                self.stream,
75            ))
76        }
77    }
78}
79
80impl<A> IncomingConnection<A, state::NeedCommand> {
81    /// Waits the SOCKS5 client to send a request.
82    ///
83    /// This method will return a [`Command`] if the client sends a valid command.
84    ///
85    /// When encountering an error, the stream will be returned alongside the error.
86    ///
87    /// Note that this method will not implicitly close the connection even if the client sends an invalid command.
88    pub async fn wait(mut self) -> Result<Command, (Error, TcpStream)> {
89        let req = match Request::read_from(&mut self.stream).await {
90            Ok(req) => req,
91            Err(err) => return Err((err, self.stream)),
92        };
93
94        match req.command {
95            ProtocolCommand::Associate => {
96                Ok(Command::Associate(Associate::new(self.stream), req.address))
97            }
98            ProtocolCommand::Bind => Ok(Command::Bind(Bind::new(self.stream), req.address)),
99            ProtocolCommand::Connect => {
100                Ok(Command::Connect(Connect::new(self.stream), req.address))
101            }
102        }
103    }
104}
105
106impl<A, S> IncomingConnection<A, S> {
107    #[inline]
108    pub(crate) fn new(stream: TcpStream, auth: AuthAdaptor<A>) -> Self {
109        Self {
110            stream,
111            auth,
112            _state: PhantomData,
113        }
114    }
115
116    /// Causes the other peer to receive a read of length 0, indicating that no more data will be sent. This only closes the stream in one direction.
117    #[inline]
118    pub async fn close(&mut self) -> Result<(), IoError> {
119        self.stream.shutdown().await
120    }
121
122    /// Returns the local address that this stream is bound to.
123    #[inline]
124    pub fn local_addr(&self) -> Result<SocketAddr, IoError> {
125        self.stream.local_addr()
126    }
127
128    /// Returns the remote address that this stream is connected to.
129    #[inline]
130    pub fn peer_addr(&self) -> Result<SocketAddr, IoError> {
131        self.stream.peer_addr()
132    }
133
134    /// Returns a shared reference to the underlying stream.
135    ///
136    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
137    #[inline]
138    pub fn get_ref(&self) -> &TcpStream {
139        &self.stream
140    }
141
142    /// Returns a mutable reference to the underlying stream.
143    ///
144    /// Note that this may break the encapsulation of the SOCKS5 connection and you should not use this method unless you know what you are doing.
145    #[inline]
146    pub fn get_mut(&mut self) -> &mut TcpStream {
147        &mut self.stream
148    }
149
150    /// Consumes the [`IncomingConnection`] and returns the underlying [`TcpStream`](tokio::net::TcpStream).
151    #[inline]
152    pub fn into_inner(self) -> TcpStream {
153        self.stream
154    }
155}
156
157impl<A, S> Debug for IncomingConnection<A, S> {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("IncomingConnection")
160            .field("stream", &self.stream)
161            .finish()
162    }
163}
164
165/// A command sent from the SOCKS5 client.
166#[derive(Debug)]
167pub enum Command {
168    Associate(Associate<associate::state::NeedReply>, Address),
169    Bind(Bind<bind::state::NeedFirstReply>, Address),
170    Connect(Connect<connect::state::NeedReply>, Address),
171}