socks_lib/v5/
mod.rs

1pub mod server;
2
3use std::io;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::LazyLock;
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
9
/// # Method
///
/// Authentication method identifier exchanged during SOCKS5 method
/// negotiation (RFC 1928, section 3). On the wire a method is one octet:
///
/// ```text
///  +--------+
///  | METHOD |
///  +--------+
///  |   1    |
///  +--------+
/// ```
///
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Method {
    /// `X'00'`: no authentication required.
    NoAuthentication,
    /// `X'01'`: GSSAPI.
    GSSAPI,
    /// `X'02'`: username/password.
    UsernamePassword,
    /// `X'03'..=X'7F'`: IANA-assigned method; the payload is the raw octet.
    IanaAssigned(u8),
    /// `X'80'..=X'FE'`: reserved for private methods; the payload is the raw octet.
    ReservedPrivate(u8),
    /// `X'FF'`: no acceptable method.
    NoAcceptableMethod,
}

impl Method {
    /// Returns the wire octet for this method.
    ///
    /// This is the inverse of [`Method::from_u8`]: for every octet `b`,
    /// `Method::from_u8(b).as_u8() == b`. The payload of `IanaAssigned` and
    /// `ReservedPrivate` is emitted verbatim; its range is not re-validated.
    #[rustfmt::skip]
    #[inline]
    fn as_u8(&self) -> u8 {
        match self {
            Self::NoAuthentication            => 0x00,
            Self::GSSAPI                      => 0x01,
            // Username/password is X'02'; X'03' is the start of the
            // IANA-assigned range handled by `from_u8`.
            Self::UsernamePassword            => 0x02,
            Self::IanaAssigned(value)         => *value,
            Self::ReservedPrivate(value)      => *value,
            Self::NoAcceptableMethod          => 0xFF,
        }
    }

    /// Classifies a wire octet. The mapping is total, so this never fails.
    #[rustfmt::skip]
    #[inline]
    fn from_u8(value: u8) -> Self {
        match value {
            0x00        => Self::NoAuthentication,
            0x01        => Self::GSSAPI,
            0x02        => Self::UsernamePassword,
            0x03..=0x7F => Self::IanaAssigned(value),
            0x80..=0xFE => Self::ReservedPrivate(value),
            0xFF        => Self::NoAcceptableMethod,
        }
    }
}
57
58/// # Request
59///
60/// ```text
61///  +-----+-------+------+----------+----------+
62///  | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
63///  +-----+-------+------+----------+----------+
64///  |  1  | X'00' |  1   | Variable |    2     |
65///  +-----+-------+------+----------+----------+
66/// ```
67///
68#[derive(Debug, Clone, PartialEq)]
69pub enum Request {
70    Bind(Address),
71    Connect(Address),
72    Associate(Address),
73}
74
75#[rustfmt::skip]
76impl Request {
77    const SOCKS5_CMD_CONNECT:   u8 = 0x01;
78    const SOCKS5_CMD_BIND:      u8 = 0x02;
79    const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
80}
81
82impl Request {
83    pub async fn from_async_read<R: AsyncRead + Unpin>(
84        reader: &mut BufReader<R>,
85    ) -> io::Result<Self> {
86        let mut buf = [0u8; 2];
87        reader.read_exact(&mut buf).await?;
88
89        let command = buf[0];
90
91        let request = match command {
92            Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
93            Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
94            Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
95            command => {
96                return Err(io::Error::new(
97                    io::ErrorKind::InvalidData,
98                    format!("Invalid request command: {}", command),
99                ));
100            }
101        };
102
103        Ok(request)
104    }
105}
106
107/// # Address
108///
109/// ```text
110///  +------+----------+----------+
111///  | ATYP | DST.ADDR | DST.PORT |
112///  +------+----------+----------+
113///  |  1   | Variable |    2     |
114///  +------+----------+----------+
115/// ```
116///
117/// ## DST.ADDR BND.ADDR
118///   In an address field (DST.ADDR, BND.ADDR), the ATYP field specifies
119///   the type of address contained within the field:
120///   
121/// o ATYP: X'01'
122///   the address is a version-4 IP address, with a length of 4 octets
123///   
124/// o ATYP: X'03'
125///   the address field contains a fully-qualified domain name.  The first
126///   octet of the address field contains the number of octets of name that
127///   follow, there is no terminating NUL octet.
128///   
129/// o ATYP: X'04'  
130///   the address is a version-6 IP address, with a length of 16 octets.
131///
132#[derive(Debug, Clone, PartialEq)]
133pub enum Address {
134    IPv4(SocketAddrV4),
135    IPv6(SocketAddrV6),
136    Domain(Domain, u16),
137}
138
139static UNSPECIFIED_ADDRESS: LazyLock<Address> =
140    LazyLock::new(|| Address::IPv4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
141
142#[rustfmt::skip]
143impl Address {
144    const PORT_LENGTH:         usize = 2;
145    const IPV4_ADDRESS_LENGTH: usize = 4;
146    const IPV6_ADDRESS_LENGTH: usize = 16;
147
148    const SOCKS5_ADDRESS_TYPE_IPV4:        u8 = 0x01;
149    const SOCKS5_ADDRESS_TYPE_DOMAIN_NAME: u8 = 0x03;
150    const SOCKS5_ADDRESS_TYPE_IPV6:        u8 = 0x04;
151}
152
153impl Address {
154    #[inline]
155    pub fn unspecified() -> &'static Self {
156        &UNSPECIFIED_ADDRESS
157    }
158
159    pub fn from_socket_addr(addr: SocketAddr) -> Self {
160        match addr {
161            SocketAddr::V4(addr) => Self::IPv4(addr),
162            SocketAddr::V6(addr) => Self::IPv6(addr),
163        }
164    }
165
166    pub async fn from_async_read<R: AsyncRead + Unpin>(
167        reader: &mut BufReader<R>,
168    ) -> io::Result<Self> {
169        let address_type = reader.read_u8().await?;
170
171        match address_type {
172            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
173                let mut buf = [0u8; Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH];
174                reader.read_exact(&mut buf).await?;
175
176                let ip = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
177                let port = u16::from_be_bytes([buf[4], buf[5]]);
178
179                Ok(Address::IPv4(SocketAddrV4::new(ip, port)))
180            }
181
182            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
183                let mut buf = [0u8; Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH];
184                reader.read_exact(&mut buf).await?;
185
186                let ip = Ipv6Addr::from([
187                    buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
188                    buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
189                ]);
190                let port = u16::from_be_bytes([buf[16], buf[17]]);
191
192                Ok(Address::IPv6(SocketAddrV6::new(ip, port, 0, 0)))
193            }
194
195            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
196                let domain_len = reader.read_u8().await? as usize;
197
198                let mut buf = vec![0u8; domain_len + Self::PORT_LENGTH];
199                reader.read_exact(&mut buf).await?;
200
201                let domain = Bytes::copy_from_slice(&buf[..domain_len]);
202                let port = u16::from_be_bytes([buf[domain_len], buf[domain_len + 1]]);
203
204                Ok(Address::Domain(Domain(domain), port))
205            }
206
207            n => Err(io::Error::new(
208                io::ErrorKind::InvalidData,
209                format!("Invalid address type: {}", n),
210            )),
211        }
212    }
213
214    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
215        if buf.remaining() < 1 {
216            return Err(io::Error::new(
217                io::ErrorKind::InvalidData,
218                "Insufficient data for address",
219            ));
220        }
221
222        let address_type = buf.get_u8();
223
224        match address_type {
225            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
226                if buf.remaining() < Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH {
227                    return Err(io::Error::new(
228                        io::ErrorKind::InvalidData,
229                        "Insufficient data for IPv4 address",
230                    ));
231                }
232
233                let mut ip = [0u8; Self::IPV4_ADDRESS_LENGTH];
234                buf.copy_to_slice(&mut ip);
235
236                let port = buf.get_u16();
237
238                Ok(Address::IPv4(SocketAddrV4::new(Ipv4Addr::from(ip), port)))
239            }
240
241            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
242                if buf.remaining() < Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH {
243                    return Err(io::Error::new(
244                        io::ErrorKind::InvalidData,
245                        "Insufficient data for IPv6 address",
246                    ));
247                }
248
249                let mut ip = [0u8; Self::IPV6_ADDRESS_LENGTH];
250                buf.copy_to_slice(&mut ip);
251
252                let port = buf.get_u16();
253
254                Ok(Address::IPv6(SocketAddrV6::new(
255                    Ipv6Addr::from(ip),
256                    port,
257                    0,
258                    0,
259                )))
260            }
261
262            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
263                if buf.remaining() < 1 {
264                    return Err(io::Error::new(
265                        io::ErrorKind::InvalidData,
266                        "Insufficient data for domain length",
267                    ));
268                }
269
270                let domain_len = buf.get_u8() as usize;
271
272                if buf.remaining() < domain_len + Self::PORT_LENGTH {
273                    return Err(io::Error::new(
274                        io::ErrorKind::InvalidData,
275                        "Insufficient data for domain name",
276                    ));
277                }
278
279                let mut domain = vec![0u8; domain_len];
280                buf.copy_to_slice(&mut domain);
281
282                let port = buf.get_u16();
283
284                Ok(Address::Domain(Domain(Bytes::from(domain)), port))
285            }
286
287            n => Err(io::Error::new(
288                io::ErrorKind::InvalidData,
289                format!("Invalid address type: {}", n),
290            )),
291        }
292    }
293
294    #[inline]
295    pub fn to_bytes(&self) -> Bytes {
296        let mut bytes = BytesMut::new();
297
298        match self {
299            Self::Domain(domain, port) => {
300                let domain_bytes = domain.as_bytes();
301                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
302                bytes.put_u8(domain_bytes.len() as u8);
303                bytes.extend_from_slice(domain_bytes);
304                bytes.extend_from_slice(&port.to_be_bytes());
305            }
306            Self::IPv4(addr) => {
307                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV4);
308                bytes.extend_from_slice(&addr.ip().octets());
309                bytes.extend_from_slice(&addr.port().to_be_bytes());
310            }
311            Self::IPv6(addr) => {
312                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV6);
313                bytes.extend_from_slice(&addr.ip().octets());
314                bytes.extend_from_slice(&addr.port().to_be_bytes());
315            }
316        }
317
318        bytes.freeze()
319    }
320
321    #[inline]
322    pub fn port(&self) -> u16 {
323        match self {
324            Self::IPv4(addr) => addr.port(),
325            Self::IPv6(addr) => addr.port(),
326            Self::Domain(_, port) => *port,
327        }
328    }
329
330    #[inline]
331    pub fn format_as_string(&self) -> io::Result<String> {
332        match self {
333            Self::IPv4(addr) => Ok(addr.to_string()),
334            Self::IPv6(addr) => Ok(addr.to_string()),
335            Self::Domain(domain, port) => Ok(format!("{}:{}", domain.format_as_str()?, port)),
336        }
337    }
338}
339
340#[derive(Debug, Clone, PartialEq)]
341pub struct Domain(Bytes);
342
343impl Into<Domain> for String {
344    fn into(self) -> Domain {
345        Domain(Bytes::from(self))
346    }
347}
348
349impl Into<Domain> for &[u8] {
350    fn into(self) -> Domain {
351        Domain(Bytes::copy_from_slice(self))
352    }
353}
354
355impl Into<Domain> for &str {
356    fn into(self) -> Domain {
357        Domain(Bytes::copy_from_slice(self.as_bytes()))
358    }
359}
360
361impl Domain {
362    #[inline]
363    pub fn format_as_str(&self) -> io::Result<&str> {
364        use std::str::from_utf8;
365
366        from_utf8(&self.0).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))
367    }
368
369    #[inline]
370    pub fn as_bytes(&self) -> &[u8] {
371        &self.0
372    }
373
374    #[inline]
375    pub fn to_bytes(self) -> Bytes {
376        self.0
377    }
378}
379
380impl AsRef<[u8]> for Domain {
381    #[inline]
382    fn as_ref(&self) -> &[u8] {
383        self.as_bytes()
384    }
385}
386
387/// # Response
388///
389/// ```text
390///  +-----+-------+------+----------+----------+
391///  | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
392///  +-----+-------+------+----------+----------+
393///  |  1  | X'00' |  1   | Variable |    2     |
394///  +-----+-------+------+----------+----------+
395/// ```
396///
397#[derive(Debug, Clone)]
398pub enum Response<'a> {
399    Success(&'a Address),
400    GeneralFailure,
401    ConnectionNotAllowed,
402    NetworkUnreachable,
403    HostUnreachable,
404    ConnectionRefused,
405    TTLExpired,
406    CommandNotSupported,
407    AddressTypeNotSupported,
408    Unassigned(u8),
409}
410
411#[rustfmt::skip]
412impl Response<'_> {
413    const SOCKS5_REPLY_SUCCEEDED:                  u8 = 0x00;
414    const SOCKS5_REPLY_GENERAL_FAILURE:            u8 = 0x01;
415    const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED:     u8 = 0x02;
416    const SOCKS5_REPLY_NETWORK_UNREACHABLE:        u8 = 0x03;
417    const SOCKS5_REPLY_HOST_UNREACHABLE:           u8 = 0x04;
418    const SOCKS5_REPLY_CONNECTION_REFUSED:         u8 = 0x05;
419    const SOCKS5_REPLY_TTL_EXPIRED:                u8 = 0x06;
420    const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED:      u8 = 0x07;
421    const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
422}
423
424impl Response<'_> {
425    #[inline]
426    pub fn to_bytes(&self) -> BytesMut {
427        let mut bytes = BytesMut::new();
428
429        let (reply, address) = match &self {
430            Self::GeneralFailure
431            | Self::ConnectionNotAllowed
432            | Self::NetworkUnreachable
433            | Self::HostUnreachable
434            | Self::ConnectionRefused
435            | Self::TTLExpired
436            | Self::CommandNotSupported
437            | Self::AddressTypeNotSupported => (self.as_u8(), Address::unspecified()),
438            Self::Unassigned(code) => (*code, Address::unspecified()),
439            Self::Success(address) => (self.as_u8(), *address),
440        };
441
442        bytes.put_u8(reply);
443        bytes.put_u8(0x00);
444        bytes.extend(address.to_bytes());
445
446        bytes
447    }
448
449    #[rustfmt::skip]
450    #[inline]
451    fn as_u8(&self) -> u8 {
452        match self {
453            Self::Success(_)                 => Self::SOCKS5_REPLY_SUCCEEDED,
454            Self::GeneralFailure             => Self::SOCKS5_REPLY_GENERAL_FAILURE,
455            Self::ConnectionNotAllowed       => Self::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
456            Self::NetworkUnreachable         => Self::SOCKS5_REPLY_NETWORK_UNREACHABLE,
457            Self::HostUnreachable            => Self::SOCKS5_REPLY_HOST_UNREACHABLE,
458            Self::ConnectionRefused          => Self::SOCKS5_REPLY_CONNECTION_REFUSED,
459            Self::TTLExpired                 => Self::SOCKS5_REPLY_TTL_EXPIRED,
460            Self::CommandNotSupported        => Self::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
461            Self::AddressTypeNotSupported    => Self::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
462            Self::Unassigned(code)           => *code
463        }
464    }
465}
466
467/// # UDP Packet
468///
469///
470/// ```text
471///  +-----+------+------+----------+----------+----------+
472///  | RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
473///  +-----+------+------+----------+----------+----------+
474///  |  2  |  1   |  1   | Variable |    2     | Variable |
475///  +-----+------+------+----------+----------+----------+
476/// ```
477///
478#[derive(Debug)]
479pub struct UdpPacket {
480    pub frag: u8,
481    pub address: Address,
482    pub data: Bytes,
483}
484
485impl UdpPacket {
486    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
487        if buf.remaining() < 2 {
488            return Err(io::Error::new(
489                io::ErrorKind::InvalidData,
490                "Insufficient data for RSV",
491            ));
492        }
493        buf.advance(2);
494
495        if buf.remaining() < 1 {
496            return Err(io::Error::new(
497                io::ErrorKind::InvalidData,
498                "Insufficient data for FRAG",
499            ));
500        }
501        let frag = buf.get_u8();
502
503        let address = Address::from_bytes(buf)?;
504
505        let data = buf.copy_to_bytes(buf.remaining());
506
507        Ok(Self {
508            frag,
509            address,
510            data,
511        })
512    }
513
514    pub fn to_bytes(&self) -> Bytes {
515        let mut bytes = BytesMut::new();
516
517        bytes.put_u8(0x00);
518        bytes.put_u8(0x00);
519
520        bytes.put_u8(self.frag);
521        bytes.extend(self.address.to_bytes());
522        bytes.extend_from_slice(&self.data);
523
524        bytes.freeze()
525    }
526
527    pub fn un_frag(address: Address, data: Bytes) -> Self {
528        Self {
529            frag: 0,
530            address,
531            data,
532        }
533    }
534}
535
536pub struct Stream<T> {
537    version: u8,
538    from: SocketAddr,
539    inner: BufReader<T>,
540}
541
542impl<T> Stream<T> {
543    pub fn version(&self) -> u8 {
544        self.version
545    }
546
547    pub fn from_addr(&self) -> SocketAddr {
548        self.from
549    }
550}
551
552mod async_impl {
553    use std::io;
554    use std::pin::Pin;
555    use std::task::{Context, Poll};
556
557    use tokio::io::{AsyncRead, AsyncWrite};
558
559    use super::Stream;
560
561    impl<T> AsyncRead for Stream<T>
562    where
563        T: AsyncRead + AsyncWrite + Unpin,
564    {
565        fn poll_read(
566            mut self: Pin<&mut Self>,
567            cx: &mut Context<'_>,
568            buf: &mut tokio::io::ReadBuf<'_>,
569        ) -> Poll<io::Result<()>> {
570            AsyncRead::poll_read(Pin::new(&mut self.inner.get_mut()), cx, buf)
571        }
572    }
573
574    impl<T> AsyncWrite for Stream<T>
575    where
576        T: AsyncRead + AsyncWrite + Unpin,
577    {
578        fn poll_write(
579            mut self: Pin<&mut Self>,
580            cx: &mut Context<'_>,
581            buf: &[u8],
582        ) -> Poll<Result<usize, io::Error>> {
583            AsyncWrite::poll_write(Pin::new(&mut self.inner.get_mut()), cx, buf)
584        }
585
586        fn poll_flush(
587            mut self: Pin<&mut Self>,
588            cx: &mut Context<'_>,
589        ) -> Poll<Result<(), io::Error>> {
590            AsyncWrite::poll_flush(Pin::new(&mut self.inner.get_mut()), cx)
591        }
592
593        fn poll_shutdown(
594            mut self: Pin<&mut Self>,
595            cx: &mut Context<'_>,
596        ) -> Poll<Result<(), io::Error>> {
597            AsyncWrite::poll_shutdown(Pin::new(&mut self.inner.get_mut()), cx)
598        }
599    }
600}
601
602#[cfg(test)]
603mod tests {
604    mod test_request {
605        use crate::v5::{Address, Request};
606
607        use bytes::{BufMut, BytesMut};
608        use std::io::Cursor;
609        use tokio::io::BufReader;
610
611        #[tokio::test]
612        async fn test_request_from_async_read_connect_ipv4() {
613            let mut buffer = BytesMut::new();
614
615            // Command + Reserved
616            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
617            buffer.put_u8(0x00); // Reserved
618
619            // Address type + Address + Port
620            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
621            buffer.put_slice(&[192, 168, 1, 1]); // IP
622            buffer.put_u16(80); // Port
623
624            let bytes = buffer.freeze();
625            let mut cursor = Cursor::new(bytes);
626            let mut reader = BufReader::new(&mut cursor);
627
628            let request = Request::from_async_read(&mut reader).await.unwrap();
629
630            match request {
631                Request::Connect(addr) => match addr {
632                    Address::IPv4(socket_addr) => {
633                        assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
634                        assert_eq!(socket_addr.port(), 80);
635                    }
636                    _ => panic!("Should be IPv4 address"),
637                },
638                _ => panic!("Should be Connect request"),
639            }
640        }
641
642        #[tokio::test]
643        async fn test_request_from_async_read_bind_ipv6() {
644            let mut buffer = BytesMut::new();
645
646            // Command + Reserved
647            buffer.put_u8(Request::SOCKS5_CMD_BIND);
648            buffer.put_u8(0x00); // Reserved
649
650            // Address type + Address + Port
651            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
652            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
653            buffer.put_u16(443); // Port
654
655            let bytes = buffer.freeze();
656            let mut cursor = Cursor::new(bytes);
657            let mut reader = BufReader::new(&mut cursor);
658
659            let request = Request::from_async_read(&mut reader).await.unwrap();
660
661            match request {
662                Request::Bind(addr) => match addr {
663                    Address::IPv6(socket_addr) => {
664                        assert_eq!(
665                            socket_addr.ip().octets(),
666                            [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
667                        );
668                        assert_eq!(socket_addr.port(), 443);
669                    }
670                    _ => panic!("Should be IPv6 address"),
671                },
672                _ => panic!("Should be Bind request"),
673            }
674        }
675
676        #[tokio::test]
677        async fn test_request_from_async_read_associate_domain() {
678            let mut buffer = BytesMut::new();
679
680            // Command + Reserved
681            buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
682            buffer.put_u8(0x00); // Reserved
683
684            // Address type + Address + Port
685            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
686            buffer.put_u8(11); // Length of domain name
687            buffer.put_slice(b"example.com"); // Domain name
688            buffer.put_u16(8080); // Port
689
690            let bytes = buffer.freeze();
691            let mut cursor = Cursor::new(bytes);
692            let mut reader = BufReader::new(&mut cursor);
693
694            let request = Request::from_async_read(&mut reader).await.unwrap();
695
696            match request {
697                Request::Associate(addr) => match addr {
698                    Address::Domain(domain, port) => {
699                        assert_eq!(domain.as_bytes(), b"example.com");
700                        assert_eq!(port, 8080);
701                    }
702                    _ => panic!("Should be domain address"),
703                },
704                _ => panic!("Should be Associate request"),
705            }
706        }
707
708        #[tokio::test]
709        async fn test_request_from_async_read_invalid_command() {
710            let mut buffer = BytesMut::new();
711
712            // Invalid Command + Reserved
713            buffer.put_u8(0xFF); // Invalid command
714            buffer.put_u8(0x00); // Reserved
715
716            let bytes = buffer.freeze();
717            let mut cursor = Cursor::new(bytes);
718            let mut reader = BufReader::new(&mut cursor);
719
720            let result = Request::from_async_read(&mut reader).await;
721
722            assert!(result.is_err());
723            if let Err(e) = result {
724                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
725            }
726        }
727
728        #[tokio::test]
729        async fn test_request_from_async_read_incomplete_data() {
730            let mut buffer = BytesMut::new();
731
732            // Command only, missing reserved byte
733            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
734
735            let bytes = buffer.freeze();
736            let mut cursor = Cursor::new(bytes);
737            let mut reader = BufReader::new(&mut cursor);
738
739            let result = Request::from_async_read(&mut reader).await;
740
741            assert!(result.is_err());
742        }
743    }
744
745    mod test_address {
746        use crate::v5::{Address, Domain};
747
748        use bytes::{BufMut, Bytes, BytesMut};
749        use std::io::Cursor;
750        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
751        use tokio::io::BufReader;
752
753        #[tokio::test]
754        async fn test_address_unspecified() {
755            let unspecified = Address::unspecified();
756            match unspecified {
757                Address::IPv4(addr) => {
758                    assert_eq!(addr.ip(), &Ipv4Addr::UNSPECIFIED);
759                    assert_eq!(addr.port(), 0);
760                }
761                _ => panic!("Unspecified address should be IPv4"),
762            }
763        }
764
765        #[tokio::test]
766        async fn test_address_from_socket_addr_ipv4() {
767            let socket = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
768            let address = Address::from_socket_addr(socket);
769
770            match address {
771                Address::IPv4(addr) => {
772                    assert_eq!(addr.ip().octets(), [127, 0, 0, 1]);
773                    assert_eq!(addr.port(), 8080);
774                }
775                _ => panic!("Should be IPv4 address"),
776            }
777        }
778
779        #[tokio::test]
780        async fn test_address_from_socket_addr_ipv6() {
781            let socket = SocketAddr::V6(SocketAddrV6::new(
782                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
783                8080,
784                0,
785                0,
786            ));
787            let address = Address::from_socket_addr(socket);
788
789            match address {
790                Address::IPv6(addr) => {
791                    assert_eq!(
792                        addr.ip().octets(),
793                        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
794                    );
795                    assert_eq!(addr.port(), 8080);
796                }
797                _ => panic!("Should be IPv6 address"),
798            }
799        }
800
801        #[tokio::test]
802        async fn test_address_to_bytes_ipv4() {
803            let addr = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 80));
804            let bytes = addr.to_bytes();
805
806            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV4);
807            assert_eq!(bytes[1..5], [192, 168, 1, 1]);
808            assert_eq!(bytes[5..7], [0, 80]); // Port 80 in big-endian
809        }
810
811        #[tokio::test]
812        async fn test_address_to_bytes_ipv6() {
813            let addr = Address::IPv6(SocketAddrV6::new(
814                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
815                443,
816                0,
817                0,
818            ));
819            let bytes = addr.to_bytes();
820
821            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV6);
822            assert_eq!(
823                bytes[1..17],
824                [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
825            );
826            assert_eq!(bytes[17..19], [1, 187]); // Port 443 in big-endian
827        }
828
829        #[tokio::test]
830        async fn test_address_to_bytes_domain() {
831            let domain = Domain(Bytes::from("example.com"));
832            let addr = Address::Domain(domain, 8080);
833            let bytes = addr.to_bytes();
834
835            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
836            assert_eq!(bytes[1], 11); // Length of "example.com"
837            assert_eq!(&bytes[2..13], b"example.com");
838            assert_eq!(bytes[13..15], [31, 144]); // Port 8080 in big-endian
839        }
840
841        #[tokio::test]
842        async fn test_address_from_bytes_ipv4() {
843            let mut buffer = BytesMut::new();
844            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
845            buffer.put_slice(&[192, 168, 1, 1]); // IP
846            buffer.put_u16(80); // Port
847
848            let mut bytes = buffer.freeze();
849            let addr = Address::from_bytes(&mut bytes).unwrap();
850
851            match addr {
852                Address::IPv4(socket_addr) => {
853                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
854                    assert_eq!(socket_addr.port(), 80);
855                }
856                _ => panic!("Should be IPv4 address"),
857            }
858        }
859
860        #[tokio::test]
861        async fn test_address_from_bytes_ipv6() {
862            let mut buffer = BytesMut::new();
863            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
864            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
865            buffer.put_u16(443); // Port
866
867            let mut bytes = buffer.freeze();
868            let addr = Address::from_bytes(&mut bytes).unwrap();
869
870            match addr {
871                Address::IPv6(socket_addr) => {
872                    assert_eq!(
873                        socket_addr.ip().octets(),
874                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
875                    );
876                    assert_eq!(socket_addr.port(), 443);
877                }
878                _ => panic!("Should be IPv6 address"),
879            }
880        }
881
882        #[tokio::test]
883        async fn test_address_from_bytes_domain() {
884            let mut buffer = BytesMut::new();
885            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
886            buffer.put_u8(11); // Length of domain name
887            buffer.put_slice(b"example.com"); // Domain name
888            buffer.put_u16(8080); // Port
889
890            let mut bytes = buffer.freeze();
891            let addr = Address::from_bytes(&mut bytes).unwrap();
892
893            match addr {
894                Address::Domain(domain, port) => {
895                    assert_eq!(domain.as_bytes(), b"example.com");
896                    assert_eq!(port, 8080);
897                }
898                _ => panic!("Should be domain address"),
899            }
900        }
901
902        #[tokio::test]
903        async fn test_address_from_async_read_ipv4() {
904            let mut buffer = BytesMut::new();
905            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
906            buffer.put_slice(&[192, 168, 1, 1]); // IP
907            buffer.put_u16(80); // Port
908
909            let bytes = buffer.freeze();
910            let mut cursor = Cursor::new(bytes);
911            let mut reader = BufReader::new(&mut cursor);
912
913            let addr = Address::from_async_read(&mut reader).await.unwrap();
914
915            match addr {
916                Address::IPv4(socket_addr) => {
917                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
918                    assert_eq!(socket_addr.port(), 80);
919                }
920                _ => panic!("Should be IPv4 address"),
921            }
922        }
923
924        #[tokio::test]
925        async fn test_address_from_async_read_ipv6() {
926            let mut buffer = BytesMut::new();
927            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
928            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
929            buffer.put_u16(443); // Port
930
931            let bytes = buffer.freeze();
932            let mut cursor = Cursor::new(bytes);
933            let mut reader = BufReader::new(&mut cursor);
934
935            let addr = Address::from_async_read(&mut reader).await.unwrap();
936
937            match addr {
938                Address::IPv6(socket_addr) => {
939                    assert_eq!(
940                        socket_addr.ip().octets(),
941                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
942                    );
943                    assert_eq!(socket_addr.port(), 443);
944                }
945                _ => panic!("Should be IPv6 address"),
946            }
947        }
948
949        #[tokio::test]
950        async fn test_address_from_async_read_domain() {
951            let mut buffer = BytesMut::new();
952            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
953            buffer.put_u8(11); // Length of domain name
954            buffer.put_slice(b"example.com"); // Domain name
955            buffer.put_u16(8080); // Port
956
957            let bytes = buffer.freeze();
958            let mut cursor = Cursor::new(bytes);
959            let mut reader = BufReader::new(&mut cursor);
960
961            let addr = Address::from_async_read(&mut reader).await.unwrap();
962
963            match addr {
964                Address::Domain(domain, port) => {
965                    assert_eq!(domain.as_bytes(), b"example.com");
966                    assert_eq!(port, 8080);
967                }
968                _ => panic!("Should be domain address"),
969            }
970        }
971
972        #[tokio::test]
973        async fn test_address_from_bytes_invalid_type() {
974            let mut buffer = BytesMut::new();
975            buffer.put_u8(0xFF); // Invalid address type
976
977            let mut bytes = buffer.freeze();
978            let result = Address::from_bytes(&mut bytes);
979
980            assert!(result.is_err());
981        }
982
983        #[tokio::test]
984        async fn test_address_from_bytes_insufficient_data() {
985            // IPv4 with incomplete data
986            let mut buffer = BytesMut::new();
987            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
988            buffer.put_slice(&[192, 168]); // Incomplete IP
989
990            let mut bytes = buffer.freeze();
991            let result = Address::from_bytes(&mut bytes);
992
993            assert!(result.is_err());
994        }
995
996        #[tokio::test]
997        async fn test_address_port() {
998            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
999            assert_eq!(addr1.port(), 8080);
1000
1001            let addr2 = Address::IPv6(SocketAddrV6::new(
1002                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1003                443,
1004                0,
1005                0,
1006            ));
1007            assert_eq!(addr2.port(), 443);
1008
1009            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1010            assert_eq!(addr3.port(), 80);
1011        }
1012
1013        #[tokio::test]
1014        async fn test_address_format_as_string() {
1015            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
1016            assert_eq!(addr1.format_as_string().unwrap(), "127.0.0.1:8080");
1017
1018            let addr2 = Address::IPv6(SocketAddrV6::new(
1019                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1020                443,
1021                0,
1022                0,
1023            ));
1024            assert_eq!(addr2.format_as_string().unwrap(), "[::1]:443");
1025
1026            // This test assumes Domain::domain_str() returns Ok with the domain string
1027            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1028            assert_eq!(addr3.format_as_string().unwrap(), "example.com:80");
1029        }
1030    }
1031}