socks_lib/v5/server.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::net::TcpListener;
7use crate::v5::{Method, Request, Stream};
8
/// Handshake deadline used by [`Config::new`] when no explicit timeout is given.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Server configuration.
///
/// Bundles the authenticator run during the handshake, the handler that serves the request
/// afterwards, and the deadline applied to the handshake phase of every connection.
pub struct Config<A, H> {
    /// Performs method selection / authentication for each client.
    auth: A,
    /// Serves the client's request once the handshake has completed.
    handler: H,
    /// Upper bound on the time spent in the handshake phase.
    timeout: Duration,
}

impl<A, H> Config<A, H> {
    /// Creates a configuration whose handshake timeout is the default of 30 seconds.
    pub fn new(auth: A, handler: H) -> Self {
        Self {
            timeout: DEFAULT_TIMEOUT,
            auth,
            handler,
        }
    }

    /// Returns this configuration with its handshake timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }
}
31
32/// SOCKS5 server implementation
33pub struct Server;
34
35impl Server {
36    pub async fn run<H, A>(
37        listener: TcpListener,
38        config: Arc<Config<A, H>>,
39        shutdown_signal: impl Future<Output = ()>,
40    ) -> io::Result<()>
41    where
42        H: Handler + 'static,
43        A: Authenticator + 'static,
44    {
45        tokio::pin!(shutdown_signal);
46
47        loop {
48            tokio::select! {
49                // Bias select to prefer the shutdown signal if both are ready
50                biased;
51
52                _ = &mut shutdown_signal => return Ok(()),
53
54                result = listener.accept() => {
55                    let (inner, addr) = match result {
56                        Ok(res) => res,
57                        Err(_err) => {
58                            #[cfg(feature = "tracing")]
59                            tracing::error!("Failed to accept connection: {}", _err);
60                            continue;
61                        }
62                    };
63
64                    let config = config.clone();
65                    tokio::spawn(async move {
66                        let mut stream = Stream::with(inner, addr);
67
68                        if let Err(_err) = Self::handle_connection(&mut stream, &config).await {
69                            #[cfg(feature = "tracing")]
70                            tracing::warn!("Connection {} error: {}", addr, _err);
71                        }
72                    });
73                }
74            }
75        }
76    }
77
78    async fn handle_connection<H, A, S>(
79        stream: &mut Stream<S>,
80        config: &Config<A, H>,
81    ) -> io::Result<()>
82    where
83        H: Handler + 'static,
84        A: Authenticator + 'static,
85        S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
86    {
87        // Apply timeout to handshake phase
88        let request = tokio::time::timeout(config.timeout, async {
89            let methods = stream.read_methods().await?;
90            config.auth.auth(stream, methods).await?;
91            stream.read_request().await
92        })
93        .await
94        .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout during authentication"))??;
95
96        config.handler.handle(stream, request).await
97    }
98}
99
100/// Authentication trait for SOCKS5 server
101pub trait Authenticator: Send + Sync {
102    fn auth<T>(
103        &self,
104        stream: &mut Stream<T>,
105        methods: Vec<Method>,
106    ) -> impl Future<Output = io::Result<()>> + Send
107    where
108        T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
109}
110
111/// Request handler trait for SOCKS5 server
112pub trait Handler: Send + Sync {
113    fn handle<T>(
114        &self,
115        stream: &mut Stream<T>,
116        request: Request,
117    ) -> impl Future<Output = io::Result<()>> + Send
118    where
119        T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
120}
121
122pub mod auth {
123    use super::*;
124
125    pub struct NoAuthentication;
126
127    impl Authenticator for NoAuthentication {
128        async fn auth<T>(&self, stream: &mut Stream<T>, _methods: Vec<Method>) -> io::Result<()>
129        where
130            T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
131        {
132            stream.write_auth_method(Method::NoAuthentication).await?;
133            Ok(())
134        }
135    }
136
137    pub struct UserPassword {
138        username: String,
139        password: String,
140    }
141
142    impl UserPassword {
143        pub fn new(username: String, password: String) -> Self {
144            Self { username, password }
145        }
146    }
147
148    impl Authenticator for UserPassword {
149        async fn auth<T>(&self, stream: &mut Stream<T>, methods: Vec<Method>) -> io::Result<()>
150        where
151            T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
152        {
153            if !methods.contains(&Method::UsernamePassword) {
154                return Err(io::Error::new(
155                    io::ErrorKind::PermissionDenied,
156                    "Username/Password authentication required",
157                ));
158            }
159
160            stream.write_auth_method(Method::UsernamePassword).await?;
161
162            // Read username/password subnegotiation
163            let version = stream.read_u8().await?;
164            if version != 0x01 {
165                return Err(io::Error::new(
166                    io::ErrorKind::InvalidData,
167                    "Invalid subnegotiation version",
168                ));
169            }
170
171            let ulen = stream.read_u8().await?;
172            let mut username = vec![0; ulen as usize];
173            stream.read_exact(&mut username).await?;
174
175            let plen = stream.read_u8().await?;
176            let mut password = vec![0; plen as usize];
177            stream.read_exact(&mut password).await?;
178
179            // Verify credentials
180            if username != self.username.as_bytes() || password != self.password.as_bytes() {
181                stream.write_all(&[0x01, 0x01]).await?;
182                return Err(io::Error::new(
183                    io::ErrorKind::PermissionDenied,
184                    "Invalid username or password",
185                ));
186            }
187
188            stream.write_all(&[0x01, 0x00]).await?;
189
190            Ok(())
191        }
192    }
193}