// socks_lib/v5/server.rs

use std::io;
use std::net::SocketAddr;

use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, ToSocketAddrs};

use super::{Method, Request, Response, Stream};

10pub struct Server {
11    listener: TcpListener,
12}
13
14impl Server {
15    const VERSION_5: u8 = 0x05;
16
17    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
18        Ok(Self {
19            listener: TcpListener::bind(addr).await?,
20        })
21    }
22
23    pub fn local_addr(&self) -> io::Result<SocketAddr> {
24        self.listener.local_addr()
25    }
26
27    #[inline]
28    pub async fn accept(
29        &self,
30    ) -> io::Result<(
31        Request,
32        Stream<impl AsyncRead + AsyncWrite + Unpin + 'static>,
33    )> {
34        let (inner, from) = self.listener.accept().await?;
35        let inner = BufReader::new(inner);
36        let mut stream = Stream::with(Self::VERSION_5, from, inner);
37
38        let _methods = stream.read_methods().await?;
39
40        // TODO: impl username password
41        stream.write_auth_method(Method::NoAuthentication).await?;
42
43        let request = stream.read_request().await?;
44
45        Ok((request, stream))
46    }
47}
48
49impl<T> Stream<T>
50where
51    T: AsyncRead + AsyncWrite + Unpin,
52{
53    fn with<A: Into<SocketAddr>>(version: u8, from: A, inner: BufReader<T>) -> Self {
54        Self {
55            version,
56            from: from.into(),
57            inner,
58        }
59    }
60
61    /// # Methods
62    ///
63    /// ```text
64    ///  +----+----------+----------+
65    ///  |VER | NMETHODS | METHODS  |
66    ///  +----+----------+----------+
67    ///  | 1  |    1     | 1 to 255 |
68    ///  +----+----------+----------+
69    /// ```
70    #[inline]
71    async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
72        let mut buffer = [0u8; 2];
73        self.inner.read_exact(&mut buffer).await?;
74
75        let method_num = buffer[1];
76        if method_num == 1 {
77            let method = self.inner.read_u8().await?;
78            return Ok(vec![Method::from_u8(method)]);
79        }
80
81        let mut methods = vec![0u8; method_num as usize];
82        self.inner.read_exact(&mut methods).await?;
83
84        let result = methods.into_iter().map(|e| Method::from_u8(e)).collect();
85
86        Ok(result)
87    }
88
89    ///
90    /// ```text
91    ///  +----+--------+
92    ///  |VER | METHOD |
93    ///  +----+--------+
94    ///  | 1  |   1    |
95    ///  +----+--------+
96    ///  ```
97    #[inline]
98    async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
99        let bytes = [self.version, method.as_u8()];
100        self.inner.write(&bytes).await
101    }
102
103    ///
104    /// ```text
105    ///  +----+-----+-------+------+----------+----------+
106    ///  |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
107    ///  +----+-----+-------+------+----------+----------+
108    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
109    ///  +----+-----+-------+------+----------+----------+
110    /// ```
111    ///
112    #[inline]
113    async fn read_request(&mut self) -> io::Result<Request> {
114        let _version = self.inner.read_u8().await?;
115        Request::from_async_read(&mut self.inner).await
116    }
117
118    ///
119    /// ```text
120    ///  +----+-----+-------+------+----------+----------+
121    ///  |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
122    ///  +----+-----+-------+------+----------+----------+
123    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
124    ///  +----+-----+-------+------+----------+----------+
125    /// ```
126    ///
127    #[inline]
128    pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
129        let bytes = prepend_u8(resp.to_bytes(), self.version);
130        self.inner.write(&bytes).await
131    }
132}
133
134fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
135    bytes.reserve(1);
136
137    unsafe {
138        let ptr = bytes.as_mut_ptr();
139        std::ptr::copy(ptr, ptr.add(1), bytes.len());
140        std::ptr::write(ptr, value);
141        let new_len = bytes.len() + 1;
142        bytes.set_len(new_len);
143    }
144
145    bytes
146}