// socks_lib/v5/mod.rs

1pub mod server;
2
3use std::io;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::LazyLock;
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
9
/// # Method
///
/// An authentication method identifier used during SOCKS5 method negotiation.
///
/// ```text
///  +--------+
///  | METHOD |
///  +--------+
///  |   1    |
///  +--------+
/// ```
///
/// Wire values (RFC 1928, section 3):
///
/// - `X'00'` no authentication required
/// - `X'01'` GSSAPI
/// - `X'02'` username/password
/// - `X'03'`..=`X'7F'` IANA assigned
/// - `X'80'`..=`X'FE'` reserved for private methods
/// - `X'FF'` no acceptable methods
///
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Method {
    NoAuthentication,
    GSSAPI,
    UsernamePassword,
    IanaAssigned(u8),
    ReservedPrivate(u8),
    NoAcceptableMethod,
}

impl Method {
    /// Encodes this method as its single wire octet.
    ///
    /// This is the inverse of `from_u8`: `Method::from_u8(v).as_u8() == v` for
    /// every `v: u8`.
    #[rustfmt::skip]
    #[inline]
    fn as_u8(&self) -> u8 {
        match self {
            Self::NoAuthentication            => 0x00,
            Self::GSSAPI                      => 0x01,
            // RFC 1928 assigns X'02' to username/password. Emitting X'03' here
            // (as before) broke the round trip with `from_u8`, which decodes
            // X'03' as `IanaAssigned(3)`.
            Self::UsernamePassword            => 0x02,
            Self::IanaAssigned(value)         => *value,
            Self::ReservedPrivate(value)      => *value,
            Self::NoAcceptableMethod          => 0xFF,
        }
    }

    /// Decodes a wire octet into a method; every `u8` maps to some variant.
    #[rustfmt::skip]
    #[inline]
    fn from_u8(value: u8) -> Self {
        match value {
            0x00        => Self::NoAuthentication,
            0x01        => Self::GSSAPI,
            0x02        => Self::UsernamePassword,
            0x03..=0x7F => Self::IanaAssigned(value),
            0x80..=0xFE => Self::ReservedPrivate(value),
            0xFF        => Self::NoAcceptableMethod,
        }
    }
}
57
58/// # Request
59///
60/// ```text
61///  +-----+-------+------+----------+----------+
62///  | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
63///  +-----+-------+------+----------+----------+
64///  |  1  | X'00' |  1   | Variable |    2     |
65///  +-----+-------+------+----------+----------+
66/// ```
67///
68#[derive(Debug, Clone, PartialEq)]
69pub enum Request {
70    Bind(Address),
71    Connect(Address),
72    Associate(Address),
73}
74
75#[rustfmt::skip]
76impl Request {
77    const SOCKS5_CMD_CONNECT:   u8 = 0x01;
78    const SOCKS5_CMD_BIND:      u8 = 0x02;
79    const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
80}
81
82impl Request {
83    pub async fn from_async_read<R: AsyncRead + Unpin>(
84        reader: &mut BufReader<R>,
85    ) -> io::Result<Self> {
86        let mut buf = [0u8; 2];
87        reader.read_exact(&mut buf).await?;
88
89        let command = buf[0];
90
91        let request = match command {
92            Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
93            Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
94            Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
95            command => {
96                return Err(io::Error::new(
97                    io::ErrorKind::InvalidData,
98                    format!("Invalid request command: {}", command),
99                ));
100            }
101        };
102
103        Ok(request)
104    }
105}
106
107/// # Address
108///
109/// ```text
110///  +------+----------+----------+
111///  | ATYP | DST.ADDR | DST.PORT |
112///  +------+----------+----------+
113///  |  1   | Variable |    2     |
114///  +------+----------+----------+
115/// ```
116///
117/// ## DST.ADDR BND.ADDR
118///   In an address field (DST.ADDR, BND.ADDR), the ATYP field specifies
119///   the type of address contained within the field:
120///   
121/// o ATYP: X'01'
122///   the address is a version-4 IP address, with a length of 4 octets
123///   
124/// o ATYP: X'03'
125///   the address field contains a fully-qualified domain name.  The first
126///   octet of the address field contains the number of octets of name that
127///   follow, there is no terminating NUL octet.
128///   
129/// o ATYP: X'04'  
130///   the address is a version-6 IP address, with a length of 16 octets.
131///
132#[derive(Debug, Clone, PartialEq)]
133pub enum Address {
134    IPv4(SocketAddrV4),
135    IPv6(SocketAddrV6),
136    Domain(Domain, u16),
137}
138
139static UNSPECIFIED_ADDRESS: LazyLock<Address> =
140    LazyLock::new(|| Address::IPv4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
141
142#[rustfmt::skip]
143impl Address {
144    const PORT_LENGTH:         usize = 2;
145    const IPV4_ADDRESS_LENGTH: usize = 4;
146    const IPV6_ADDRESS_LENGTH: usize = 16;
147
148    const SOCKS5_ADDRESS_TYPE_IPV4:        u8 = 0x01;
149    const SOCKS5_ADDRESS_TYPE_DOMAIN_NAME: u8 = 0x03;
150    const SOCKS5_ADDRESS_TYPE_IPV6:        u8 = 0x04;
151}
152
153impl Address {
154    #[inline]
155    pub fn unspecified() -> &'static Self {
156        &UNSPECIFIED_ADDRESS
157    }
158
159    pub fn from_socket_addr(addr: SocketAddr) -> Self {
160        match addr {
161            SocketAddr::V4(addr) => Self::IPv4(addr),
162            SocketAddr::V6(addr) => Self::IPv6(addr),
163        }
164    }
165
166    pub async fn from_async_read<R: AsyncRead + Unpin>(
167        reader: &mut BufReader<R>,
168    ) -> io::Result<Self> {
169        let address_type = reader.read_u8().await?;
170
171        match address_type {
172            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
173                let mut buf = [0u8; Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH];
174                reader.read_exact(&mut buf).await?;
175
176                let ip = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
177                let port = u16::from_be_bytes([buf[4], buf[5]]);
178
179                Ok(Address::IPv4(SocketAddrV4::new(ip, port)))
180            }
181
182            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
183                let mut buf = [0u8; Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH];
184                reader.read_exact(&mut buf).await?;
185
186                let ip = Ipv6Addr::from([
187                    buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
188                    buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
189                ]);
190                let port = u16::from_be_bytes([buf[16], buf[17]]);
191
192                Ok(Address::IPv6(SocketAddrV6::new(ip, port, 0, 0)))
193            }
194
195            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
196                let domain_len = reader.read_u8().await? as usize;
197
198                let mut buf = vec![0u8; domain_len + Self::PORT_LENGTH];
199                reader.read_exact(&mut buf).await?;
200
201                let domain = Bytes::copy_from_slice(&buf[..domain_len]);
202                let port = u16::from_be_bytes([buf[domain_len], buf[domain_len + 1]]);
203
204                Ok(Address::Domain(Domain(domain), port))
205            }
206
207            n => Err(io::Error::new(
208                io::ErrorKind::InvalidData,
209                format!("Invalid address type: {}", n),
210            )),
211        }
212    }
213
214    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
215        if buf.remaining() < 1 {
216            return Err(io::Error::new(
217                io::ErrorKind::InvalidData,
218                "Insufficient data for address",
219            ));
220        }
221
222        let address_type = buf.get_u8();
223
224        match address_type {
225            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
226                if buf.remaining() < Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH {
227                    return Err(io::Error::new(
228                        io::ErrorKind::InvalidData,
229                        "Insufficient data for IPv4 address",
230                    ));
231                }
232
233                let mut ip = [0u8; Self::IPV4_ADDRESS_LENGTH];
234                buf.copy_to_slice(&mut ip);
235
236                let port = buf.get_u16();
237
238                Ok(Address::IPv4(SocketAddrV4::new(Ipv4Addr::from(ip), port)))
239            }
240
241            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
242                if buf.remaining() < Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH {
243                    return Err(io::Error::new(
244                        io::ErrorKind::InvalidData,
245                        "Insufficient data for IPv6 address",
246                    ));
247                }
248
249                let mut ip = [0u8; Self::IPV6_ADDRESS_LENGTH];
250                buf.copy_to_slice(&mut ip);
251
252                let port = buf.get_u16();
253
254                Ok(Address::IPv6(SocketAddrV6::new(
255                    Ipv6Addr::from(ip),
256                    port,
257                    0,
258                    0,
259                )))
260            }
261
262            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
263                if buf.remaining() < 1 {
264                    return Err(io::Error::new(
265                        io::ErrorKind::InvalidData,
266                        "Insufficient data for domain length",
267                    ));
268                }
269
270                let domain_len = buf.get_u8() as usize;
271
272                if buf.remaining() < domain_len + Self::PORT_LENGTH {
273                    return Err(io::Error::new(
274                        io::ErrorKind::InvalidData,
275                        "Insufficient data for domain name",
276                    ));
277                }
278
279                let mut domain = vec![0u8; domain_len];
280                buf.copy_to_slice(&mut domain);
281
282                let port = buf.get_u16();
283
284                Ok(Address::Domain(Domain(Bytes::from(domain)), port))
285            }
286
287            n => Err(io::Error::new(
288                io::ErrorKind::InvalidData,
289                format!("Invalid address type: {}", n),
290            )),
291        }
292    }
293
294    #[inline]
295    pub fn to_bytes(&self) -> Bytes {
296        let mut bytes = BytesMut::new();
297
298        match self {
299            Self::Domain(domain, port) => {
300                let domain_bytes = domain.as_bytes();
301                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
302                bytes.put_u8(domain_bytes.len() as u8);
303                bytes.extend_from_slice(domain_bytes);
304                bytes.extend_from_slice(&port.to_be_bytes());
305            }
306            Self::IPv4(addr) => {
307                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV4);
308                bytes.extend_from_slice(&addr.ip().octets());
309                bytes.extend_from_slice(&addr.port().to_be_bytes());
310            }
311            Self::IPv6(addr) => {
312                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV6);
313                bytes.extend_from_slice(&addr.ip().octets());
314                bytes.extend_from_slice(&addr.port().to_be_bytes());
315            }
316        }
317
318        bytes.freeze()
319    }
320
321    #[inline]
322    pub fn port(&self) -> u16 {
323        match self {
324            Self::IPv4(addr) => addr.port(),
325            Self::IPv6(addr) => addr.port(),
326            Self::Domain(_, port) => *port,
327        }
328    }
329
330    #[inline]
331    pub fn format_as_string(&self) -> io::Result<String> {
332        match self {
333            Self::IPv4(addr) => Ok(addr.to_string()),
334            Self::IPv6(addr) => Ok(addr.to_string()),
335            Self::Domain(domain, port) => Ok(format!("{}:{}", domain.domain_str()?, port)),
336        }
337    }
338}
339
340#[derive(Debug, Clone, PartialEq)]
341pub struct Domain(Bytes);
342
343impl Domain {
344    #[inline]
345    pub fn domain_str(&self) -> io::Result<&str> {
346        use std::str::from_utf8;
347
348        from_utf8(&self.0).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))
349    }
350
351    #[inline]
352    pub fn as_bytes(&self) -> &[u8] {
353        &self.0
354    }
355
356    #[inline]
357    pub fn to_bytes(self) -> Bytes {
358        self.0
359    }
360}
361
362impl AsRef<[u8]> for Domain {
363    #[inline]
364    fn as_ref(&self) -> &[u8] {
365        self.as_bytes()
366    }
367}
368
369/// # Response
370///
371/// ```text
372///  +-----+-------+------+----------+----------+
373///  | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
374///  +-----+-------+------+----------+----------+
375///  |  1  | X'00' |  1   | Variable |    2     |
376///  +-----+-------+------+----------+----------+
377/// ```
378///
379#[derive(Debug, Clone)]
380pub enum Response<'a> {
381    Success(&'a Address),
382    GeneralFailure,
383    ConnectionNotAllowed,
384    NetworkUnreachable,
385    HostUnreachable,
386    ConnectionRefused,
387    TTLExpired,
388    CommandNotSupported,
389    AddressTypeNotSupported,
390    Unassigned(u8),
391}
392
393#[rustfmt::skip]
394impl Response<'_> {
395    const SOCKS5_REPLY_SUCCEEDED:                  u8 = 0x00;
396    const SOCKS5_REPLY_GENERAL_FAILURE:            u8 = 0x01;
397    const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED:     u8 = 0x02;
398    const SOCKS5_REPLY_NETWORK_UNREACHABLE:        u8 = 0x03;
399    const SOCKS5_REPLY_HOST_UNREACHABLE:           u8 = 0x04;
400    const SOCKS5_REPLY_CONNECTION_REFUSED:         u8 = 0x05;
401    const SOCKS5_REPLY_TTL_EXPIRED:                u8 = 0x06;
402    const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED:      u8 = 0x07;
403    const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
404}
405
406impl Response<'_> {
407    #[inline]
408    pub fn to_bytes(&self) -> BytesMut {
409        let mut bytes = BytesMut::new();
410
411        let (reply, address) = match &self {
412            Self::GeneralFailure
413            | Self::ConnectionNotAllowed
414            | Self::NetworkUnreachable
415            | Self::HostUnreachable
416            | Self::ConnectionRefused
417            | Self::TTLExpired
418            | Self::CommandNotSupported
419            | Self::AddressTypeNotSupported => (self.as_u8(), Address::unspecified()),
420            Self::Unassigned(code) => (*code, Address::unspecified()),
421            Self::Success(address) => (self.as_u8(), *address),
422        };
423
424        bytes.put_u8(reply);
425        bytes.put_u8(0x00);
426        bytes.extend(address.to_bytes());
427
428        bytes
429    }
430
431    #[rustfmt::skip]
432    #[inline]
433    fn as_u8(&self) -> u8 {
434        match self {
435            Self::Success(_)                 => Self::SOCKS5_REPLY_SUCCEEDED,
436            Self::GeneralFailure             => Self::SOCKS5_REPLY_GENERAL_FAILURE,
437            Self::ConnectionNotAllowed       => Self::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
438            Self::NetworkUnreachable         => Self::SOCKS5_REPLY_NETWORK_UNREACHABLE,
439            Self::HostUnreachable            => Self::SOCKS5_REPLY_HOST_UNREACHABLE,
440            Self::ConnectionRefused          => Self::SOCKS5_REPLY_CONNECTION_REFUSED,
441            Self::TTLExpired                 => Self::SOCKS5_REPLY_TTL_EXPIRED,
442            Self::CommandNotSupported        => Self::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
443            Self::AddressTypeNotSupported    => Self::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
444            Self::Unassigned(code)           => *code
445        }
446    }
447}
448
449/// # UDP Packet
450///
451///
452/// ```text
453///  +-----+------+------+----------+----------+----------+
454///  | RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
455///  +-----+------+------+----------+----------+----------+
456///  |  2  |  1   |  1   | Variable |    2     | Variable |
457///  +-----+------+------+----------+----------+----------+
458/// ```
459///
460#[derive(Debug)]
461pub struct UdpPacket {
462    pub frag: u8,
463    pub address: Address,
464    pub data: Bytes,
465}
466
467impl UdpPacket {
468    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
469        if buf.remaining() < 2 {
470            return Err(io::Error::new(
471                io::ErrorKind::InvalidData,
472                "Insufficient data for RSV",
473            ));
474        }
475        buf.advance(2);
476
477        if buf.remaining() < 1 {
478            return Err(io::Error::new(
479                io::ErrorKind::InvalidData,
480                "Insufficient data for FRAG",
481            ));
482        }
483        let frag = buf.get_u8();
484
485        let address = Address::from_bytes(buf)?;
486
487        let data = buf.copy_to_bytes(buf.remaining());
488
489        Ok(Self {
490            frag,
491            address,
492            data,
493        })
494    }
495
496    pub fn to_bytes(&self) -> Bytes {
497        let mut bytes = BytesMut::new();
498
499        bytes.put_u8(0x00);
500        bytes.put_u8(0x00);
501
502        bytes.put_u8(self.frag);
503        bytes.extend(self.address.to_bytes());
504        bytes.extend_from_slice(&self.data);
505
506        bytes.freeze()
507    }
508
509    pub fn un_frag(address: Address, data: Bytes) -> Self {
510        Self {
511            frag: 0,
512            address,
513            data,
514        }
515    }
516}
517
518pub struct Stream<T> {
519    version: u8,
520    from: SocketAddr,
521    inner: BufReader<T>,
522}
523
524impl<T> Stream<T> {
525    pub fn version(&self) -> u8 {
526        self.version
527    }
528
529    pub fn from_addr(&self) -> SocketAddr {
530        self.from
531    }
532}
533
534mod async_impl {
535    use std::io;
536    use std::pin::Pin;
537    use std::task::{Context, Poll};
538
539    use tokio::io::{AsyncRead, AsyncWrite};
540
541    use super::Stream;
542
543    impl<T> AsyncRead for Stream<T>
544    where
545        T: AsyncRead + AsyncWrite + Unpin,
546    {
547        fn poll_read(
548            mut self: Pin<&mut Self>,
549            cx: &mut Context<'_>,
550            buf: &mut tokio::io::ReadBuf<'_>,
551        ) -> Poll<io::Result<()>> {
552            AsyncRead::poll_read(Pin::new(&mut self.inner.get_mut()), cx, buf)
553        }
554    }
555
556    impl<T> AsyncWrite for Stream<T>
557    where
558        T: AsyncRead + AsyncWrite + Unpin,
559    {
560        fn poll_write(
561            mut self: Pin<&mut Self>,
562            cx: &mut Context<'_>,
563            buf: &[u8],
564        ) -> Poll<Result<usize, io::Error>> {
565            AsyncWrite::poll_write(Pin::new(&mut self.inner.get_mut()), cx, buf)
566        }
567
568        fn poll_flush(
569            mut self: Pin<&mut Self>,
570            cx: &mut Context<'_>,
571        ) -> Poll<Result<(), io::Error>> {
572            AsyncWrite::poll_flush(Pin::new(&mut self.inner.get_mut()), cx)
573        }
574
575        fn poll_shutdown(
576            mut self: Pin<&mut Self>,
577            cx: &mut Context<'_>,
578        ) -> Poll<Result<(), io::Error>> {
579            AsyncWrite::poll_shutdown(Pin::new(&mut self.inner.get_mut()), cx)
580        }
581    }
582}
583
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use bytes::{Bytes, BytesMut};
    use tokio::io::BufReader;

    /// Freezes an encoded test buffer and wraps it in the `BufReader` type the
    /// async parsers take.
    fn reader(buffer: BytesMut) -> BufReader<Cursor<Bytes>> {
        BufReader::new(Cursor::new(buffer.freeze()))
    }

    mod test_request {
        use super::reader;
        use crate::v5::{Address, Request};

        use bytes::{BufMut, BytesMut};

        #[tokio::test]
        async fn test_request_from_async_read_connect_ipv4() {
            let mut buffer = BytesMut::new();

            // Command + Reserved
            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
            buffer.put_u8(0x00); // Reserved

            // Address type + Address + Port
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
            buffer.put_slice(&[192, 168, 1, 1]); // IP
            buffer.put_u16(80); // Port

            let request = Request::from_async_read(&mut reader(buffer)).await.unwrap();

            match request {
                Request::Connect(Address::IPv4(socket_addr)) => {
                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
                    assert_eq!(socket_addr.port(), 80);
                }
                other => panic!("Should be Connect request with IPv4 address, got {other:?}"),
            }
        }

        #[tokio::test]
        async fn test_request_from_async_read_bind_ipv6() {
            let mut buffer = BytesMut::new();

            // Command + Reserved
            buffer.put_u8(Request::SOCKS5_CMD_BIND);
            buffer.put_u8(0x00); // Reserved

            // Address type + Address + Port
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
            buffer.put_u16(443); // Port

            let request = Request::from_async_read(&mut reader(buffer)).await.unwrap();

            match request {
                Request::Bind(Address::IPv6(socket_addr)) => {
                    assert_eq!(
                        socket_addr.ip().octets(),
                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
                    );
                    assert_eq!(socket_addr.port(), 443);
                }
                other => panic!("Should be Bind request with IPv6 address, got {other:?}"),
            }
        }

        #[tokio::test]
        async fn test_request_from_async_read_associate_domain() {
            let mut buffer = BytesMut::new();

            // Command + Reserved
            buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
            buffer.put_u8(0x00); // Reserved

            // Address type + Address + Port
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
            buffer.put_u8(11); // Length of domain name
            buffer.put_slice(b"example.com"); // Domain name
            buffer.put_u16(8080); // Port

            let request = Request::from_async_read(&mut reader(buffer)).await.unwrap();

            match request {
                Request::Associate(Address::Domain(domain, port)) => {
                    assert_eq!(domain.as_bytes(), b"example.com");
                    assert_eq!(port, 8080);
                }
                other => panic!("Should be Associate request with domain address, got {other:?}"),
            }
        }

        #[tokio::test]
        async fn test_request_from_async_read_invalid_command() {
            let mut buffer = BytesMut::new();

            // Invalid Command + Reserved
            buffer.put_u8(0xFF); // Invalid command
            buffer.put_u8(0x00); // Reserved

            let err = Request::from_async_read(&mut reader(buffer))
                .await
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[tokio::test]
        async fn test_request_from_async_read_incomplete_data() {
            let mut buffer = BytesMut::new();

            // Command only, missing reserved byte
            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);

            let err = Request::from_async_read(&mut reader(buffer))
                .await
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    mod test_address {
        use super::reader;
        use crate::v5::{Address, Domain};

        use bytes::{BufMut, Bytes, BytesMut};
        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

        #[test]
        fn test_address_unspecified() {
            let unspecified = Address::unspecified();
            match unspecified {
                Address::IPv4(addr) => {
                    assert_eq!(addr.ip(), &Ipv4Addr::UNSPECIFIED);
                    assert_eq!(addr.port(), 0);
                }
                _ => panic!("Unspecified address should be IPv4"),
            }
        }

        #[test]
        fn test_address_from_socket_addr_ipv4() {
            let socket = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
            let address = Address::from_socket_addr(socket);

            match address {
                Address::IPv4(addr) => {
                    assert_eq!(addr.ip().octets(), [127, 0, 0, 1]);
                    assert_eq!(addr.port(), 8080);
                }
                _ => panic!("Should be IPv4 address"),
            }
        }

        #[test]
        fn test_address_from_socket_addr_ipv6() {
            let socket = SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
                8080,
                0,
                0,
            ));
            let address = Address::from_socket_addr(socket);

            match address {
                Address::IPv6(addr) => {
                    assert_eq!(
                        addr.ip().octets(),
                        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
                    );
                    assert_eq!(addr.port(), 8080);
                }
                _ => panic!("Should be IPv6 address"),
            }
        }

        #[test]
        fn test_address_to_bytes_ipv4() {
            let addr = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 80));
            let bytes = addr.to_bytes();

            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV4);
            assert_eq!(bytes[1..5], [192, 168, 1, 1]);
            assert_eq!(bytes[5..7], [0, 80]); // Port 80 in big-endian
        }

        #[test]
        fn test_address_to_bytes_ipv6() {
            let addr = Address::IPv6(SocketAddrV6::new(
                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
                443,
                0,
                0,
            ));
            let bytes = addr.to_bytes();

            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV6);
            assert_eq!(
                bytes[1..17],
                [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
            );
            assert_eq!(bytes[17..19], [1, 187]); // Port 443 in big-endian
        }

        #[test]
        fn test_address_to_bytes_domain() {
            let domain = Domain(Bytes::from("example.com"));
            let addr = Address::Domain(domain, 8080);
            let bytes = addr.to_bytes();

            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
            assert_eq!(bytes[1], 11); // Length of "example.com"
            assert_eq!(&bytes[2..13], b"example.com");
            assert_eq!(bytes[13..15], [31, 144]); // Port 8080 in big-endian
        }

        #[test]
        fn test_address_from_bytes_ipv4() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
            buffer.put_slice(&[192, 168, 1, 1]); // IP
            buffer.put_u16(80); // Port

            let mut bytes = buffer.freeze();
            let addr = Address::from_bytes(&mut bytes).unwrap();

            match addr {
                Address::IPv4(socket_addr) => {
                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
                    assert_eq!(socket_addr.port(), 80);
                }
                _ => panic!("Should be IPv4 address"),
            }
        }

        #[test]
        fn test_address_from_bytes_ipv6() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
            buffer.put_u16(443); // Port

            let mut bytes = buffer.freeze();
            let addr = Address::from_bytes(&mut bytes).unwrap();

            match addr {
                Address::IPv6(socket_addr) => {
                    assert_eq!(
                        socket_addr.ip().octets(),
                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
                    );
                    assert_eq!(socket_addr.port(), 443);
                }
                _ => panic!("Should be IPv6 address"),
            }
        }

        #[test]
        fn test_address_from_bytes_domain() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
            buffer.put_u8(11); // Length of domain name
            buffer.put_slice(b"example.com"); // Domain name
            buffer.put_u16(8080); // Port

            let mut bytes = buffer.freeze();
            let addr = Address::from_bytes(&mut bytes).unwrap();

            match addr {
                Address::Domain(domain, port) => {
                    assert_eq!(domain.as_bytes(), b"example.com");
                    assert_eq!(port, 8080);
                }
                _ => panic!("Should be domain address"),
            }
        }

        #[tokio::test]
        async fn test_address_from_async_read_ipv4() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
            buffer.put_slice(&[192, 168, 1, 1]); // IP
            buffer.put_u16(80); // Port

            let addr = Address::from_async_read(&mut reader(buffer)).await.unwrap();

            match addr {
                Address::IPv4(socket_addr) => {
                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
                    assert_eq!(socket_addr.port(), 80);
                }
                _ => panic!("Should be IPv4 address"),
            }
        }

        #[tokio::test]
        async fn test_address_from_async_read_ipv6() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
            buffer.put_u16(443); // Port

            let addr = Address::from_async_read(&mut reader(buffer)).await.unwrap();

            match addr {
                Address::IPv6(socket_addr) => {
                    assert_eq!(
                        socket_addr.ip().octets(),
                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
                    );
                    assert_eq!(socket_addr.port(), 443);
                }
                _ => panic!("Should be IPv6 address"),
            }
        }

        #[tokio::test]
        async fn test_address_from_async_read_domain() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
            buffer.put_u8(11); // Length of domain name
            buffer.put_slice(b"example.com"); // Domain name
            buffer.put_u16(8080); // Port

            let addr = Address::from_async_read(&mut reader(buffer)).await.unwrap();

            match addr {
                Address::Domain(domain, port) => {
                    assert_eq!(domain.as_bytes(), b"example.com");
                    assert_eq!(port, 8080);
                }
                _ => panic!("Should be domain address"),
            }
        }

        /// Every variant survives an encode/decode cycle through both parsers,
        /// and `from_bytes` consumes exactly the encoded bytes.
        #[tokio::test]
        async fn test_address_round_trip() {
            let samples = [
                Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 1080)),
                Address::IPv6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8443, 0, 0)),
                Address::Domain(Domain(Bytes::from_static(b"example.com")), 443),
            ];

            for addr in samples {
                let mut encoded = addr.to_bytes();

                let mut input = reader(BytesMut::from(&encoded[..]));
                assert_eq!(Address::from_async_read(&mut input).await.unwrap(), addr);

                assert_eq!(Address::from_bytes(&mut encoded).unwrap(), addr);
                assert!(encoded.is_empty());
            }
        }

        #[test]
        fn test_address_from_bytes_invalid_type() {
            let mut buffer = BytesMut::new();
            buffer.put_u8(0xFF); // Invalid address type

            let mut bytes = buffer.freeze();
            let result = Address::from_bytes(&mut bytes);

            assert!(result.is_err());
        }

        #[test]
        fn test_address_from_bytes_insufficient_data() {
            // IPv4 with incomplete data
            let mut buffer = BytesMut::new();
            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
            buffer.put_slice(&[192, 168]); // Incomplete IP

            let mut bytes = buffer.freeze();
            let result = Address::from_bytes(&mut bytes);

            assert!(result.is_err());
        }

        #[test]
        fn test_address_port() {
            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
            assert_eq!(addr1.port(), 8080);

            let addr2 = Address::IPv6(SocketAddrV6::new(
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
                443,
                0,
                0,
            ));
            assert_eq!(addr2.port(), 443);

            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
            assert_eq!(addr3.port(), 80);
        }

        #[test]
        fn test_address_format_as_string() {
            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
            assert_eq!(addr1.format_as_string().unwrap(), "127.0.0.1:8080");

            let addr2 = Address::IPv6(SocketAddrV6::new(
                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
                443,
                0,
                0,
            ));
            assert_eq!(addr2.format_as_string().unwrap(), "[::1]:443");

            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
            assert_eq!(addr3.format_as_string().unwrap(), "example.com:80");
        }
    }

    mod test_response {
        use crate::v5::{Address, Response};

        use std::net::{Ipv4Addr, SocketAddrV4};

        #[test]
        fn test_response_success_encodes_bound_address() {
            let bound = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1080));
            let bytes = Response::Success(&bound).to_bytes();

            // REP, RSV, ATYP, 127.0.0.1, port 1080 (0x0438)
            assert_eq!(&bytes[..], &[0x00u8, 0x00, 0x01, 127, 0, 0, 1, 0x04, 0x38][..]);
        }

        #[test]
        fn test_response_failure_uses_unspecified_address() {
            let bytes = Response::HostUnreachable.to_bytes();

            assert_eq!(&bytes[..], &[0x04u8, 0x00, 0x01, 0, 0, 0, 0, 0, 0][..]);
        }

        #[test]
        fn test_response_unassigned_passes_code_through() {
            let bytes = Response::Unassigned(0x09).to_bytes();

            assert_eq!(&bytes[..], &[0x09u8, 0x00, 0x01, 0, 0, 0, 0, 0, 0][..]);
        }
    }

    mod test_udp_packet {
        use crate::v5::{Address, Domain, UdpPacket};

        use bytes::{BufMut, Bytes, BytesMut};

        #[test]
        fn test_udp_packet_round_trip() {
            let address = Address::Domain(Domain(Bytes::from_static(b"example.com")), 53);
            let packet = UdpPacket::un_frag(address.clone(), Bytes::from_static(b"payload"));

            let mut encoded = packet.to_bytes();
            assert_eq!(&encoded[..3], &[0x00u8, 0x00, 0x00][..]); // RSV + FRAG

            let decoded = UdpPacket::from_bytes(&mut encoded).unwrap();
            assert_eq!(decoded.frag, 0);
            assert_eq!(decoded.address, address);
            assert_eq!(decoded.data, Bytes::from_static(b"payload"));
        }

        #[test]
        fn test_udp_packet_from_bytes_insufficient_data() {
            let mut buffer = BytesMut::new();
            buffer.put_u16(0); // RSV only; FRAG missing

            let mut bytes = buffer.freeze();
            assert!(UdpPacket::from_bytes(&mut bytes).is_err());
        }
    }
}