1use crate::io::{self, AsyncRead, AsyncReadExt};
2use crate::v5::Address;
3
4#[derive(Debug, Clone, PartialEq)]
15pub enum Request {
16 Bind(Address),
17 Connect(Address),
18 Associate(Address),
19}
20
21#[rustfmt::skip]
22impl Request {
23 const SOCKS5_CMD_CONNECT: u8 = 0x01;
24 const SOCKS5_CMD_BIND: u8 = 0x02;
25 const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
26}
27
28impl Request {
29 pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
30 let mut buf = [0u8; 2];
31 reader.read_exact(&mut buf).await?;
32
33 let command = buf[0];
34
35 let request = match command {
36 Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
37 Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
38 Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
39 command => {
40 return Err(io::Error::new(
41 io::ErrorKind::InvalidData,
42 format!("Invalid request command: {}", command),
43 ));
44 }
45 };
46
47 Ok(request)
48 }
49}
50
#[cfg(test)]
mod tests {
    use crate::v5::{Address, Request};

    use bytes::{BufMut, BytesMut};
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Feeds `buffer` (laid out starting at the command byte) through a
    /// buffered async reader into `Request::from_async_read`.
    async fn parse(buffer: BytesMut) -> crate::io::Result<Request> {
        let mut cursor = Cursor::new(buffer.freeze());
        let mut reader = BufReader::new(&mut cursor);
        Request::from_async_read(&mut reader).await
    }

    #[tokio::test]
    async fn test_request_from_async_read_connect_ipv4() {
        let mut buffer = BytesMut::new();
        buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
        buffer.put_u8(0x00); // reserved
        buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
        buffer.put_slice(&[192, 168, 1, 1]);
        buffer.put_u16(80);

        match parse(buffer).await.unwrap() {
            Request::Connect(Address::IPv4(socket_addr)) => {
                assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
                assert_eq!(socket_addr.port(), 80);
            }
            other => panic!("expected Connect with an IPv4 address, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_request_from_async_read_bind_ipv6() {
        const IPV6: [u8; 16] = [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];

        let mut buffer = BytesMut::new();
        buffer.put_u8(Request::SOCKS5_CMD_BIND);
        buffer.put_u8(0x00); // reserved
        buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
        buffer.put_slice(&IPV6);
        buffer.put_u16(443);

        match parse(buffer).await.unwrap() {
            Request::Bind(Address::IPv6(socket_addr)) => {
                assert_eq!(socket_addr.ip().octets(), IPV6);
                assert_eq!(socket_addr.port(), 443);
            }
            other => panic!("expected Bind with an IPv6 address, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_request_from_async_read_associate_domain() {
        let mut buffer = BytesMut::new();
        buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
        buffer.put_u8(0x00); // reserved
        buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
        buffer.put_u8(11); // length of "example.com"
        buffer.put_slice(b"example.com");
        buffer.put_u16(8080);

        match parse(buffer).await.unwrap() {
            Request::Associate(Address::Domain(domain, port)) => {
                assert_eq!(**domain.as_bytes(), *b"example.com");
                assert_eq!(port, 8080);
            }
            other => panic!("expected Associate with a domain address, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_request_from_async_read_invalid_command() {
        let mut buffer = BytesMut::new();
        buffer.put_u8(0xFF); // not a known command
        buffer.put_u8(0x00); // reserved

        // Unknown commands are rejected before any address bytes are read,
        // so the two header bytes alone are enough input.
        let err = parse(buffer).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_request_from_async_read_incomplete_data() {
        let mut buffer = BytesMut::new();
        buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
        // The reserved byte is missing, so the two-byte header read is short.

        let err = parse(buffer).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }
}
191}