// socks_lib/v5/stream.rs

1use std::net::SocketAddr;
2
3use bytes::BytesMut;
4
5use crate::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use crate::v5::{Request, Response, method::Method};
7
/// A SOCKS5 connection: a wrapped transport `T` plus the socket addresses of
/// both endpoints, as supplied at construction time.
pub struct Stream<T> {
    /// Address of the peer end of the connection.
    peer_addr: SocketAddr,
    /// Address of the local end of the connection.
    local_addr: SocketAddr,
    /// The wrapped transport; the `AsyncRead`/`AsyncWrite` impls delegate to it.
    inner: T,
}
13
14impl<T> Stream<T> {
15    #[inline]
16    pub fn version(&self) -> u8 {
17        0x05
18    }
19
20    #[inline]
21    pub fn peer_addr(&self) -> SocketAddr {
22        self.peer_addr
23    }
24
25    #[inline]
26    pub fn local_addr(&self) -> SocketAddr {
27        self.local_addr
28    }
29
30    #[inline]
31    pub fn with(inner: T, peer_addr: SocketAddr, local_addr: SocketAddr) -> Self {
32        Self {
33            inner,
34            peer_addr,
35            local_addr,
36        }
37    }
38}
39
// ===== STREAM Server Side Impl =====
41impl<T> Stream<T>
42where
43    T: AsyncRead + AsyncWrite + Unpin,
44{
45    /// # Methods
46    ///
47    /// ```text
48    ///  +----+----------+----------+
49    ///  |VER | NMETHODS | METHODS  |
50    ///  +----+----------+----------+
51    ///  | 1  |    1     | 1 to 255 |
52    ///  +----+----------+----------+
53    /// ```
54    #[inline]
55    pub async fn read_methods(&mut self) -> io::Result<Vec<Method>> {
56        let mut buffer = [0u8; 2];
57        self.read_exact(&mut buffer).await?;
58
59        let method_num = buffer[1];
60        if method_num == 1 {
61            let method = self.read_u8().await?;
62            return Ok(vec![Method::from_u8(method)]);
63        }
64
65        let mut methods = vec![0u8; method_num as usize];
66        self.read_exact(&mut methods).await?;
67
68        let result = methods.into_iter().map(Method::from_u8).collect();
69
70        Ok(result)
71    }
72
73    ///
74    /// ```text
75    ///  +----+--------+
76    ///  |VER | METHOD |
77    ///  +----+--------+
78    ///  | 1  |   1    |
79    ///  +----+--------+
80    ///  ```
81    #[inline]
82    pub async fn write_auth_method(&mut self, method: Method) -> io::Result<usize> {
83        let bytes = [self.version(), method.as_u8()];
84        self.write(&bytes).await
85    }
86
87    ///
88    /// ```text
89    ///  +----+-----+-------+------+----------+----------+
90    ///  |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
91    ///  +----+-----+-------+------+----------+----------+
92    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
93    ///  +----+-----+-------+------+----------+----------+
94    /// ```
95    ///
96    #[inline]
97    pub async fn read_request(&mut self) -> io::Result<Request> {
98        let _version = self.read_u8().await?;
99        Request::from_async_read(self).await
100    }
101
102    ///
103    /// ```text
104    ///  +----+-----+-------+------+----------+----------+
105    ///  |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
106    ///  +----+-----+-------+------+----------+----------+
107    ///  | 1  |  1  | X'00' |  1   | Variable |    2     |
108    ///  +----+-----+-------+------+----------+----------+
109    /// ```
110    ///
111    #[inline]
112    pub async fn write_response<'a>(&mut self, resp: &Response<'a>) -> io::Result<usize> {
113        let bytes = prepend_u8(resp.to_bytes(), self.version());
114        self.write(&bytes).await
115    }
116
117    #[inline]
118    pub async fn write_response_unspecified(&mut self) -> io::Result<usize> {
119        use crate::v5::Address;
120        self.write_response(&Response::Success(Address::unspecified()))
121            .await
122    }
123
124    #[inline]
125    pub async fn write_response_unsupported(&mut self) -> io::Result<usize> {
126        self.write_response(&Response::CommandNotSupported).await
127    }
128}
129
130#[inline]
131fn prepend_u8(mut bytes: BytesMut, value: u8) -> BytesMut {
132    bytes.reserve(1);
133
134    unsafe {
135        let ptr = bytes.as_mut_ptr();
136        std::ptr::copy(ptr, ptr.add(1), bytes.len());
137        std::ptr::write(ptr, value);
138        let new_len = bytes.len() + 1;
139        bytes.set_len(new_len);
140    }
141
142    bytes
143}
144
145mod async_impl {
146    use super::Stream;
147
148    use std::io;
149    use std::pin::Pin;
150    use std::task::{Context, Poll};
151    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
152
153    impl<T: AsyncRead + Unpin> AsyncRead for Stream<T> {
154        fn poll_read(
155            mut self: Pin<&mut Self>,
156            cx: &mut Context<'_>,
157            buf: &mut ReadBuf<'_>,
158        ) -> Poll<io::Result<()>> {
159            Pin::new(&mut self.inner).poll_read(cx, buf)
160        }
161    }
162
163    impl<T: AsyncWrite + Unpin> AsyncWrite for Stream<T> {
164        fn poll_write(
165            mut self: Pin<&mut Self>,
166            cx: &mut Context<'_>,
167            buf: &[u8],
168        ) -> Poll<io::Result<usize>> {
169            Pin::new(&mut self.inner).poll_write(cx, buf)
170        }
171
172        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173            Pin::new(&mut self.inner).poll_flush(cx)
174        }
175
176        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
177            Pin::new(&mut self.inner).poll_shutdown(cx)
178        }
179    }
180}