// socks_lib/v5/server.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::net::TcpListener;
7use crate::v5::{Method, Request, Stream};
8
/// Handshake timeout used by [`Config::new`] unless overridden via
/// [`Config::with_timeout`]: 30 seconds.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Server configuration: the authenticator, the request handler and the
/// handshake timeout.
///
/// The timeout bounds the handshake performed by the server (method
/// negotiation, authentication and request parsing); it is not applied to the
/// handler call that follows.
pub struct Config<A, H> {
    auth: A,
    handler: H,
    timeout: Duration,
}

impl<A, H> Config<A, H> {
    /// Builds a configuration from `auth` and `handler`, using
    /// [`DEFAULT_TIMEOUT`] for the handshake.
    pub fn new(auth: A, handler: H) -> Self {
        let timeout = DEFAULT_TIMEOUT;
        Self {
            timeout,
            handler,
            auth,
        }
    }

    /// Returns the same configuration with its handshake timeout replaced by
    /// `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }
}
31
/// SOCKS5 server implementation.
///
/// A stateless marker type; all behavior lives in [`Server::run`], which takes
/// the listener and shared configuration as arguments.
#[derive(Debug, Clone, Copy, Default)]
pub struct Server;
34
35impl Server {
36    pub async fn run<H, A>(
37        listener: TcpListener,
38        config: Arc<Config<A, H>>,
39        shutdown_signal: impl Future<Output = ()>,
40    ) -> io::Result<()>
41    where
42        H: Handler + 'static,
43        A: Authenticator + 'static,
44    {
45        tokio::pin!(shutdown_signal);
46
47        loop {
48            tokio::select! {
49                // Bias select to prefer the shutdown signal if both are ready
50                biased;
51
52                _ = &mut shutdown_signal => return Ok(()),
53
54                result = listener.accept() => {
55                    let (inner, addr) = match result {
56                        Ok(res) => res,
57                        Err(_err) => {
58                            #[cfg(feature = "tracing")]
59                            tracing::error!("Failed to accept connection: {}", _err);
60                            continue;
61                        }
62                    };
63
64                    let local_addr = match inner.local_addr() {
65                        Ok(addr) => addr,
66                        Err(_err) => {
67                            #[cfg(feature = "tracing")]
68                            tracing::error!("Failed to get local address for connection {}: {}", addr, _err);
69                            continue;
70                        }
71                    };
72
73                    let config = config.clone();
74                    tokio::spawn(async move {
75                        let mut stream = Stream::with(inner, addr, local_addr);
76
77                        if let Err(_err) = Self::handle_connection(&mut stream, &config).await {
78                            #[cfg(feature = "tracing")]
79                            tracing::warn!("Connection {} error: {}", addr, _err);
80                        }
81                    });
82                }
83            }
84        }
85    }
86
87    async fn handle_connection<H, A, S>(
88        stream: &mut Stream<S>,
89        config: &Config<A, H>,
90    ) -> io::Result<()>
91    where
92        H: Handler + 'static,
93        A: Authenticator + 'static,
94        S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
95    {
96        // Apply timeout to handshake phase
97        let request = tokio::time::timeout(config.timeout, async {
98            let methods = stream.read_methods().await?;
99            config.auth.auth(stream, methods).await?;
100            stream.read_request().await
101        })
102        .await
103        .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "Timeout during authentication"))??;
104
105        config.handler.handle(stream, request).await
106    }
107}
108
109/// Authentication trait for SOCKS5 server
110pub trait Authenticator: Send + Sync {
111    fn auth<T>(
112        &self,
113        stream: &mut Stream<T>,
114        methods: Vec<Method>,
115    ) -> impl Future<Output = io::Result<()>> + Send
116    where
117        T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
118}
119
120/// Request handler trait for SOCKS5 server
121pub trait Handler: Send + Sync {
122    fn handle<T>(
123        &self,
124        stream: &mut Stream<T>,
125        request: Request,
126    ) -> impl Future<Output = io::Result<()>> + Send
127    where
128        T: AsyncRead + AsyncWrite + Unpin + Send + Sync;
129}
130
131pub mod auth {
132    use super::*;
133
134    pub struct NoAuthentication;
135
136    impl Authenticator for NoAuthentication {
137        async fn auth<T>(&self, stream: &mut Stream<T>, _methods: Vec<Method>) -> io::Result<()>
138        where
139            T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
140        {
141            stream.write_auth_method(Method::NoAuthentication).await?;
142            Ok(())
143        }
144    }
145
146    pub struct UserPassword {
147        username: String,
148        password: String,
149    }
150
151    impl UserPassword {
152        pub fn new(username: String, password: String) -> Self {
153            Self { username, password }
154        }
155    }
156
157    impl Authenticator for UserPassword {
158        async fn auth<T>(&self, stream: &mut Stream<T>, methods: Vec<Method>) -> io::Result<()>
159        where
160            T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
161        {
162            if !methods.contains(&Method::UsernamePassword) {
163                return Err(io::Error::new(
164                    io::ErrorKind::PermissionDenied,
165                    "Username/Password authentication required",
166                ));
167            }
168
169            stream.write_auth_method(Method::UsernamePassword).await?;
170
171            // Read username/password subnegotiation
172            let version = stream.read_u8().await?;
173            if version != 0x01 {
174                return Err(io::Error::new(
175                    io::ErrorKind::InvalidData,
176                    "Invalid subnegotiation version",
177                ));
178            }
179
180            let ulen = stream.read_u8().await?;
181            let mut username = vec![0; ulen as usize];
182            stream.read_exact(&mut username).await?;
183
184            let plen = stream.read_u8().await?;
185            let mut password = vec![0; plen as usize];
186            stream.read_exact(&mut password).await?;
187
188            // Verify credentials
189            if username != self.username.as_bytes() || password != self.password.as_bytes() {
190                stream.write_all(&[0x01, 0x01]).await?;
191                return Err(io::Error::new(
192                    io::ErrorKind::PermissionDenied,
193                    "Invalid username or password",
194                ));
195            }
196
197            stream.write_all(&[0x01, 0x00]).await?;
198
199            Ok(())
200        }
201    }
202}