socks_lib/v5/
mod.rs

1pub mod server;
2
3use std::io;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::sync::LazyLock;
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use tokio::io::{AsyncRead, AsyncReadExt};
9
/// # Method
///
/// An authentication method identifier exchanged during the SOCKS5 method
/// negotiation (RFC 1928, section 3).
///
/// ```text
///  +--------+
///  | METHOD |
///  +--------+
///  |   1    |
///  +--------+
/// ```
///
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum Method {
    /// `X'00'`: no authentication required.
    NoAuthentication,
    /// `X'01'`: GSSAPI.
    GSSAPI,
    /// `X'02'`: username/password.
    UsernamePassword,
    /// `X'03'`..=`X'7F'`: IANA-assigned method; holds the raw code.
    IanaAssigned(u8),
    /// `X'80'`..=`X'FE'`: method reserved for private use; holds the raw code.
    ReservedPrivate(u8),
    /// `X'FF'`: no acceptable methods.
    NoAcceptableMethod,
}

// Shared wire codes: `as_u8` and `from_u8` both go through these, so the two
// directions cannot disagree about a value.
#[rustfmt::skip]
impl Method {
    const SOCKS5_METHOD_NO_AUTHENTICATION:    u8 = 0x00;
    const SOCKS5_METHOD_GSSAPI:               u8 = 0x01;
    const SOCKS5_METHOD_USERNAME_PASSWORD:    u8 = 0x02;
    const SOCKS5_METHOD_NO_ACCEPTABLE_METHOD: u8 = 0xFF;
}

impl Method {
    /// Encodes the method as its single wire octet.
    ///
    /// For every octet `b`, `Method::from_u8(b).as_u8() == b`.
    #[rustfmt::skip]
    #[inline]
    fn as_u8(&self) -> u8 {
        match self {
            Self::NoAuthentication   => Self::SOCKS5_METHOD_NO_AUTHENTICATION,
            Self::GSSAPI             => Self::SOCKS5_METHOD_GSSAPI,
            Self::UsernamePassword   => Self::SOCKS5_METHOD_USERNAME_PASSWORD,
            Self::IanaAssigned(code)
            | Self::ReservedPrivate(code) => *code,
            Self::NoAcceptableMethod => Self::SOCKS5_METHOD_NO_ACCEPTABLE_METHOD,
        }
    }

    /// Decodes a wire octet into a method; every octet maps to some variant.
    #[rustfmt::skip]
    #[inline]
    fn from_u8(value: u8) -> Self {
        match value {
            Self::SOCKS5_METHOD_NO_AUTHENTICATION    => Self::NoAuthentication,
            Self::SOCKS5_METHOD_GSSAPI               => Self::GSSAPI,
            Self::SOCKS5_METHOD_USERNAME_PASSWORD    => Self::UsernamePassword,
            0x03..=0x7F                              => Self::IanaAssigned(value),
            0x80..=0xFE                              => Self::ReservedPrivate(value),
            Self::SOCKS5_METHOD_NO_ACCEPTABLE_METHOD => Self::NoAcceptableMethod,
        }
    }
}
57
58/// # Request
59///
60/// ```text
61///  +-----+-------+------+----------+----------+
62///  | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
63///  +-----+-------+------+----------+----------+
64///  |  1  | X'00' |  1   | Variable |    2     |
65///  +-----+-------+------+----------+----------+
66/// ```
67///
68#[derive(Debug, Clone, PartialEq)]
69pub enum Request {
70    Bind(Address),
71    Connect(Address),
72    Associate(Address),
73}
74
75#[rustfmt::skip]
76impl Request {
77    const SOCKS5_CMD_CONNECT:   u8 = 0x01;
78    const SOCKS5_CMD_BIND:      u8 = 0x02;
79    const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
80}
81
82impl Request {
83    pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
84        let mut buf = [0u8; 2];
85        reader.read_exact(&mut buf).await?;
86
87        let command = buf[0];
88
89        let request = match command {
90            Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
91            Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
92            Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
93            command => {
94                return Err(io::Error::new(
95                    io::ErrorKind::InvalidData,
96                    format!("Invalid request command: {}", command),
97                ));
98            }
99        };
100
101        Ok(request)
102    }
103}
104
105/// # Address
106///
107/// ```text
108///  +------+----------+----------+
109///  | ATYP | DST.ADDR | DST.PORT |
110///  +------+----------+----------+
111///  |  1   | Variable |    2     |
112///  +------+----------+----------+
113/// ```
114///
115/// ## DST.ADDR BND.ADDR
116///   In an address field (DST.ADDR, BND.ADDR), the ATYP field specifies
117///   the type of address contained within the field:
118///   
119/// o ATYP: X'01'
120///   the address is a version-4 IP address, with a length of 4 octets
121///   
122/// o ATYP: X'03'
123///   the address field contains a fully-qualified domain name.  The first
124///   octet of the address field contains the number of octets of name that
125///   follow, there is no terminating NUL octet.
126///   
127/// o ATYP: X'04'  
128///   the address is a version-6 IP address, with a length of 16 octets.
129///
130#[derive(Debug, Clone, PartialEq)]
131pub enum Address {
132    IPv4(SocketAddrV4),
133    IPv6(SocketAddrV6),
134    Domain(Domain, u16),
135}
136
137static UNSPECIFIED_ADDRESS: LazyLock<Address> =
138    LazyLock::new(|| Address::IPv4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)));
139
140#[rustfmt::skip]
141impl Address {
142    const PORT_LENGTH:         usize = 2;
143    const IPV4_ADDRESS_LENGTH: usize = 4;
144    const IPV6_ADDRESS_LENGTH: usize = 16;
145
146    const SOCKS5_ADDRESS_TYPE_IPV4:        u8 = 0x01;
147    const SOCKS5_ADDRESS_TYPE_DOMAIN_NAME: u8 = 0x03;
148    const SOCKS5_ADDRESS_TYPE_IPV6:        u8 = 0x04;
149}
150
151impl Address {
152    #[inline]
153    pub fn unspecified() -> &'static Self {
154        &UNSPECIFIED_ADDRESS
155    }
156
157    pub fn from_socket_addr(addr: SocketAddr) -> Self {
158        match addr {
159            SocketAddr::V4(addr) => Self::IPv4(addr),
160            SocketAddr::V6(addr) => Self::IPv6(addr),
161        }
162    }
163
164    pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
165        let address_type = reader.read_u8().await?;
166
167        match address_type {
168            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
169                let mut buf = [0u8; Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH];
170                reader.read_exact(&mut buf).await?;
171
172                let ip = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
173                let port = u16::from_be_bytes([buf[4], buf[5]]);
174
175                Ok(Address::IPv4(SocketAddrV4::new(ip, port)))
176            }
177
178            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
179                let mut buf = [0u8; Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH];
180                reader.read_exact(&mut buf).await?;
181
182                let ip = Ipv6Addr::from([
183                    buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
184                    buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
185                ]);
186                let port = u16::from_be_bytes([buf[16], buf[17]]);
187
188                Ok(Address::IPv6(SocketAddrV6::new(ip, port, 0, 0)))
189            }
190
191            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
192                let domain_len = reader.read_u8().await? as usize;
193
194                let mut buf = vec![0u8; domain_len + Self::PORT_LENGTH];
195                reader.read_exact(&mut buf).await?;
196
197                let domain = Bytes::copy_from_slice(&buf[..domain_len]);
198                let port = u16::from_be_bytes([buf[domain_len], buf[domain_len + 1]]);
199
200                Ok(Address::Domain(Domain(domain), port))
201            }
202
203            n => Err(io::Error::new(
204                io::ErrorKind::InvalidData,
205                format!("Invalid address type: {}", n),
206            )),
207        }
208    }
209
210    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
211        if buf.remaining() < 1 {
212            return Err(io::Error::new(
213                io::ErrorKind::InvalidData,
214                "Insufficient data for address",
215            ));
216        }
217
218        let address_type = buf.get_u8();
219
220        match address_type {
221            Self::SOCKS5_ADDRESS_TYPE_IPV4 => {
222                if buf.remaining() < Self::IPV4_ADDRESS_LENGTH + Self::PORT_LENGTH {
223                    return Err(io::Error::new(
224                        io::ErrorKind::InvalidData,
225                        "Insufficient data for IPv4 address",
226                    ));
227                }
228
229                let mut ip = [0u8; Self::IPV4_ADDRESS_LENGTH];
230                buf.copy_to_slice(&mut ip);
231
232                let port = buf.get_u16();
233
234                Ok(Address::IPv4(SocketAddrV4::new(Ipv4Addr::from(ip), port)))
235            }
236
237            Self::SOCKS5_ADDRESS_TYPE_IPV6 => {
238                if buf.remaining() < Self::IPV6_ADDRESS_LENGTH + Self::PORT_LENGTH {
239                    return Err(io::Error::new(
240                        io::ErrorKind::InvalidData,
241                        "Insufficient data for IPv6 address",
242                    ));
243                }
244
245                let mut ip = [0u8; Self::IPV6_ADDRESS_LENGTH];
246                buf.copy_to_slice(&mut ip);
247
248                let port = buf.get_u16();
249
250                Ok(Address::IPv6(SocketAddrV6::new(
251                    Ipv6Addr::from(ip),
252                    port,
253                    0,
254                    0,
255                )))
256            }
257
258            Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME => {
259                if buf.remaining() < 1 {
260                    return Err(io::Error::new(
261                        io::ErrorKind::InvalidData,
262                        "Insufficient data for domain length",
263                    ));
264                }
265
266                let domain_len = buf.get_u8() as usize;
267
268                if buf.remaining() < domain_len + Self::PORT_LENGTH {
269                    return Err(io::Error::new(
270                        io::ErrorKind::InvalidData,
271                        "Insufficient data for domain name",
272                    ));
273                }
274
275                let mut domain = vec![0u8; domain_len];
276                buf.copy_to_slice(&mut domain);
277
278                let port = buf.get_u16();
279
280                Ok(Address::Domain(Domain(Bytes::from(domain)), port))
281            }
282
283            n => Err(io::Error::new(
284                io::ErrorKind::InvalidData,
285                format!("Invalid address type: {}", n),
286            )),
287        }
288    }
289
290    #[inline]
291    pub fn to_bytes(&self) -> Bytes {
292        let mut bytes = BytesMut::new();
293
294        match self {
295            Self::Domain(domain, port) => {
296                let domain_bytes = domain.as_bytes();
297                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
298                bytes.put_u8(domain_bytes.len() as u8);
299                bytes.extend_from_slice(domain_bytes);
300                bytes.extend_from_slice(&port.to_be_bytes());
301            }
302            Self::IPv4(addr) => {
303                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV4);
304                bytes.extend_from_slice(&addr.ip().octets());
305                bytes.extend_from_slice(&addr.port().to_be_bytes());
306            }
307            Self::IPv6(addr) => {
308                bytes.put_u8(Self::SOCKS5_ADDRESS_TYPE_IPV6);
309                bytes.extend_from_slice(&addr.ip().octets());
310                bytes.extend_from_slice(&addr.port().to_be_bytes());
311            }
312        }
313
314        bytes.freeze()
315    }
316
317    #[inline]
318    pub fn port(&self) -> u16 {
319        match self {
320            Self::IPv4(addr) => addr.port(),
321            Self::IPv6(addr) => addr.port(),
322            Self::Domain(_, port) => *port,
323        }
324    }
325
326    #[inline]
327    pub fn format_as_string(&self) -> io::Result<String> {
328        match self {
329            Self::IPv4(addr) => Ok(addr.to_string()),
330            Self::IPv6(addr) => Ok(addr.to_string()),
331            Self::Domain(domain, port) => Ok(format!("{}:{}", domain.format_as_str()?, port)),
332        }
333    }
334}
335
336#[derive(Debug, Clone, PartialEq)]
337pub struct Domain(Bytes);
338
339impl From<String> for Domain {
340    fn from(value: String) -> Self {
341        Domain(Bytes::from(value))
342    }
343}
344
345impl From<&[u8]> for Domain {
346    fn from(value: &[u8]) -> Self {
347        Domain(Bytes::copy_from_slice(value))
348    }
349}
350
351impl From<&str> for Domain {
352    fn from(value: &str) -> Self {
353        Domain(Bytes::copy_from_slice(value.as_bytes()))
354    }
355}
356
357impl Domain {
358    #[inline]
359    pub fn format_as_str(&self) -> io::Result<&str> {
360        use std::str::from_utf8;
361
362        from_utf8(&self.0).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))
363    }
364
365    #[inline]
366    pub fn as_bytes(&self) -> &[u8] {
367        &self.0
368    }
369
370    #[inline]
371    pub fn to_bytes(self) -> Bytes {
372        self.0
373    }
374
375    #[inline]
376    pub fn from_bytes(bytes: Bytes) -> Self {
377        Self(bytes)
378    }
379
380    #[inline]
381    pub fn from_string(string: String) -> Self {
382        string.into()
383    }
384}
385
386impl AsRef<[u8]> for Domain {
387    #[inline]
388    fn as_ref(&self) -> &[u8] {
389        self.as_bytes()
390    }
391}
392
393/// # Response
394///
395/// ```text
396///  +-----+-------+------+----------+----------+
397///  | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
398///  +-----+-------+------+----------+----------+
399///  |  1  | X'00' |  1   | Variable |    2     |
400///  +-----+-------+------+----------+----------+
401/// ```
402///
403#[derive(Debug, Clone)]
404pub enum Response<'a> {
405    Success(&'a Address),
406    GeneralFailure,
407    ConnectionNotAllowed,
408    NetworkUnreachable,
409    HostUnreachable,
410    ConnectionRefused,
411    TTLExpired,
412    CommandNotSupported,
413    AddressTypeNotSupported,
414    Unassigned(u8),
415}
416
417#[rustfmt::skip]
418impl Response<'_> {
419    const SOCKS5_REPLY_SUCCEEDED:                  u8 = 0x00;
420    const SOCKS5_REPLY_GENERAL_FAILURE:            u8 = 0x01;
421    const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED:     u8 = 0x02;
422    const SOCKS5_REPLY_NETWORK_UNREACHABLE:        u8 = 0x03;
423    const SOCKS5_REPLY_HOST_UNREACHABLE:           u8 = 0x04;
424    const SOCKS5_REPLY_CONNECTION_REFUSED:         u8 = 0x05;
425    const SOCKS5_REPLY_TTL_EXPIRED:                u8 = 0x06;
426    const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED:      u8 = 0x07;
427    const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
428}
429
430impl Response<'_> {
431    #[inline]
432    pub fn to_bytes(&self) -> BytesMut {
433        let mut bytes = BytesMut::new();
434
435        let (reply, address) = match &self {
436            Self::GeneralFailure
437            | Self::ConnectionNotAllowed
438            | Self::NetworkUnreachable
439            | Self::HostUnreachable
440            | Self::ConnectionRefused
441            | Self::TTLExpired
442            | Self::CommandNotSupported
443            | Self::AddressTypeNotSupported => (self.as_u8(), Address::unspecified()),
444            Self::Unassigned(code) => (*code, Address::unspecified()),
445            Self::Success(address) => (self.as_u8(), *address),
446        };
447
448        bytes.put_u8(reply);
449        bytes.put_u8(0x00);
450        bytes.extend(address.to_bytes());
451
452        bytes
453    }
454
455    #[rustfmt::skip]
456    #[inline]
457    fn as_u8(&self) -> u8 {
458        match self {
459            Self::Success(_)                 => Self::SOCKS5_REPLY_SUCCEEDED,
460            Self::GeneralFailure             => Self::SOCKS5_REPLY_GENERAL_FAILURE,
461            Self::ConnectionNotAllowed       => Self::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
462            Self::NetworkUnreachable         => Self::SOCKS5_REPLY_NETWORK_UNREACHABLE,
463            Self::HostUnreachable            => Self::SOCKS5_REPLY_HOST_UNREACHABLE,
464            Self::ConnectionRefused          => Self::SOCKS5_REPLY_CONNECTION_REFUSED,
465            Self::TTLExpired                 => Self::SOCKS5_REPLY_TTL_EXPIRED,
466            Self::CommandNotSupported        => Self::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
467            Self::AddressTypeNotSupported    => Self::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
468            Self::Unassigned(code)           => *code
469        }
470    }
471}
472
473/// # UDP Packet
474///
475///
476/// ```text
477///  +-----+------+------+----------+----------+----------+
478///  | RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
479///  +-----+------+------+----------+----------+----------+
480///  |  2  |  1   |  1   | Variable |    2     | Variable |
481///  +-----+------+------+----------+----------+----------+
482/// ```
483///
484#[derive(Debug)]
485pub struct UdpPacket {
486    pub frag: u8,
487    pub address: Address,
488    pub data: Bytes,
489}
490
491impl UdpPacket {
492    pub fn from_bytes<B: Buf>(buf: &mut B) -> io::Result<Self> {
493        if buf.remaining() < 2 {
494            return Err(io::Error::new(
495                io::ErrorKind::InvalidData,
496                "Insufficient data for RSV",
497            ));
498        }
499        buf.advance(2);
500
501        if buf.remaining() < 1 {
502            return Err(io::Error::new(
503                io::ErrorKind::InvalidData,
504                "Insufficient data for FRAG",
505            ));
506        }
507        let frag = buf.get_u8();
508
509        let address = Address::from_bytes(buf)?;
510
511        let data = buf.copy_to_bytes(buf.remaining());
512
513        Ok(Self {
514            frag,
515            address,
516            data,
517        })
518    }
519
520    pub fn to_bytes(&self) -> Bytes {
521        let mut bytes = BytesMut::new();
522
523        bytes.put_u8(0x00);
524        bytes.put_u8(0x00);
525
526        bytes.put_u8(self.frag);
527        bytes.extend(self.address.to_bytes());
528        bytes.extend_from_slice(&self.data);
529
530        bytes.freeze()
531    }
532
533    pub fn un_frag(address: Address, data: Bytes) -> Self {
534        Self {
535            frag: 0,
536            address,
537            data,
538        }
539    }
540}
541
/// Wrapper around an inner value `T`. Its async I/O trait impls forward
/// reads and writes to `T`.
pub struct Stream<T>(T);

impl<T> Stream<T> {
    /// The SOCKS protocol version of this stream, always `5`.
    #[inline]
    pub fn version(&self) -> u8 {
        5
    }

    /// Wraps `inner` without performing any I/O.
    pub fn with(inner: T) -> Self {
        Stream(inner)
    }
}
554
555mod async_impl {
556    use std::io;
557    use std::pin::Pin;
558    use std::task::{Context, Poll};
559
560    use tokio::io::{AsyncRead, AsyncWrite};
561
562    use super::Stream;
563
564    impl<T> AsyncRead for Stream<T>
565    where
566        T: AsyncRead + AsyncWrite + Unpin,
567    {
568        fn poll_read(
569            mut self: Pin<&mut Self>,
570            cx: &mut Context<'_>,
571            buf: &mut tokio::io::ReadBuf<'_>,
572        ) -> Poll<io::Result<()>> {
573            AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
574        }
575    }
576
577    impl<T> AsyncWrite for Stream<T>
578    where
579        T: AsyncRead + AsyncWrite + Unpin,
580    {
581        fn poll_write(
582            mut self: Pin<&mut Self>,
583            cx: &mut Context<'_>,
584            buf: &[u8],
585        ) -> Poll<Result<usize, io::Error>> {
586            AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
587        }
588
589        fn poll_flush(
590            mut self: Pin<&mut Self>,
591            cx: &mut Context<'_>,
592        ) -> Poll<Result<(), io::Error>> {
593            AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
594        }
595
596        fn poll_shutdown(
597            mut self: Pin<&mut Self>,
598            cx: &mut Context<'_>,
599        ) -> Poll<Result<(), io::Error>> {
600            AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
601        }
602    }
603}
604
#[cfg(feature = "ombrac")]
mod ombrac {
    //! Lossless conversions between this module's address types and `ombrac`'s.

    use ombrac::address::{Address as OmbracAddress, Domain as OmbracDomain};

    use super::{Address, Domain};

    impl From<OmbracDomain> for Domain {
        #[inline]
        fn from(value: OmbracDomain) -> Self {
            Domain(value.to_bytes())
        }
    }

    impl From<Domain> for OmbracDomain {
        #[inline]
        fn from(value: Domain) -> Self {
            OmbracDomain::from_bytes(value.to_bytes())
        }
    }

    impl From<OmbracAddress> for Address {
        #[inline]
        fn from(value: OmbracAddress) -> Self {
            match value {
                OmbracAddress::IPv4(addr) => Address::IPv4(addr),
                OmbracAddress::IPv6(addr) => Address::IPv6(addr),
                OmbracAddress::Domain(domain, port) => {
                    Address::Domain(Domain::from(domain), port)
                }
            }
        }
    }

    impl From<Address> for OmbracAddress {
        #[inline]
        fn from(value: Address) -> Self {
            match value {
                Address::IPv4(addr) => OmbracAddress::IPv4(addr),
                Address::IPv6(addr) => OmbracAddress::IPv6(addr),
                Address::Domain(domain, port) => {
                    OmbracAddress::Domain(OmbracDomain::from(domain), port)
                }
            }
        }
    }
}
648
649#[cfg(test)]
650mod tests {
651    mod test_request {
652        use crate::v5::{Address, Request};
653
654        use bytes::{BufMut, BytesMut};
655        use std::io::Cursor;
656        use tokio::io::BufReader;
657
658        #[tokio::test]
659        async fn test_request_from_async_read_connect_ipv4() {
660            let mut buffer = BytesMut::new();
661
662            // Command + Reserved
663            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
664            buffer.put_u8(0x00); // Reserved
665
666            // Address type + Address + Port
667            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
668            buffer.put_slice(&[192, 168, 1, 1]); // IP
669            buffer.put_u16(80); // Port
670
671            let bytes = buffer.freeze();
672            let mut cursor = Cursor::new(bytes);
673            let mut reader = BufReader::new(&mut cursor);
674
675            let request = Request::from_async_read(&mut reader).await.unwrap();
676
677            match request {
678                Request::Connect(addr) => match addr {
679                    Address::IPv4(socket_addr) => {
680                        assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
681                        assert_eq!(socket_addr.port(), 80);
682                    }
683                    _ => panic!("Should be IPv4 address"),
684                },
685                _ => panic!("Should be Connect request"),
686            }
687        }
688
689        #[tokio::test]
690        async fn test_request_from_async_read_bind_ipv6() {
691            let mut buffer = BytesMut::new();
692
693            // Command + Reserved
694            buffer.put_u8(Request::SOCKS5_CMD_BIND);
695            buffer.put_u8(0x00); // Reserved
696
697            // Address type + Address + Port
698            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
699            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
700            buffer.put_u16(443); // Port
701
702            let bytes = buffer.freeze();
703            let mut cursor = Cursor::new(bytes);
704            let mut reader = BufReader::new(&mut cursor);
705
706            let request = Request::from_async_read(&mut reader).await.unwrap();
707
708            match request {
709                Request::Bind(addr) => match addr {
710                    Address::IPv6(socket_addr) => {
711                        assert_eq!(
712                            socket_addr.ip().octets(),
713                            [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
714                        );
715                        assert_eq!(socket_addr.port(), 443);
716                    }
717                    _ => panic!("Should be IPv6 address"),
718                },
719                _ => panic!("Should be Bind request"),
720            }
721        }
722
723        #[tokio::test]
724        async fn test_request_from_async_read_associate_domain() {
725            let mut buffer = BytesMut::new();
726
727            // Command + Reserved
728            buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
729            buffer.put_u8(0x00); // Reserved
730
731            // Address type + Address + Port
732            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
733            buffer.put_u8(11); // Length of domain name
734            buffer.put_slice(b"example.com"); // Domain name
735            buffer.put_u16(8080); // Port
736
737            let bytes = buffer.freeze();
738            let mut cursor = Cursor::new(bytes);
739            let mut reader = BufReader::new(&mut cursor);
740
741            let request = Request::from_async_read(&mut reader).await.unwrap();
742
743            match request {
744                Request::Associate(addr) => match addr {
745                    Address::Domain(domain, port) => {
746                        assert_eq!(domain.as_bytes(), b"example.com");
747                        assert_eq!(port, 8080);
748                    }
749                    _ => panic!("Should be domain address"),
750                },
751                _ => panic!("Should be Associate request"),
752            }
753        }
754
755        #[tokio::test]
756        async fn test_request_from_async_read_invalid_command() {
757            let mut buffer = BytesMut::new();
758
759            // Invalid Command + Reserved
760            buffer.put_u8(0xFF); // Invalid command
761            buffer.put_u8(0x00); // Reserved
762
763            let bytes = buffer.freeze();
764            let mut cursor = Cursor::new(bytes);
765            let mut reader = BufReader::new(&mut cursor);
766
767            let result = Request::from_async_read(&mut reader).await;
768
769            assert!(result.is_err());
770            if let Err(e) = result {
771                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
772            }
773        }
774
775        #[tokio::test]
776        async fn test_request_from_async_read_incomplete_data() {
777            let mut buffer = BytesMut::new();
778
779            // Command only, missing reserved byte
780            buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
781
782            let bytes = buffer.freeze();
783            let mut cursor = Cursor::new(bytes);
784            let mut reader = BufReader::new(&mut cursor);
785
786            let result = Request::from_async_read(&mut reader).await;
787
788            assert!(result.is_err());
789        }
790    }
791
792    mod test_address {
793        use crate::v5::{Address, Domain};
794
795        use bytes::{BufMut, Bytes, BytesMut};
796        use std::io::Cursor;
797        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
798        use tokio::io::BufReader;
799
800        #[tokio::test]
801        async fn test_address_unspecified() {
802            let unspecified = Address::unspecified();
803            match unspecified {
804                Address::IPv4(addr) => {
805                    assert_eq!(addr.ip(), &Ipv4Addr::UNSPECIFIED);
806                    assert_eq!(addr.port(), 0);
807                }
808                _ => panic!("Unspecified address should be IPv4"),
809            }
810        }
811
812        #[tokio::test]
813        async fn test_address_from_socket_addr_ipv4() {
814            let socket = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
815            let address = Address::from_socket_addr(socket);
816
817            match address {
818                Address::IPv4(addr) => {
819                    assert_eq!(addr.ip().octets(), [127, 0, 0, 1]);
820                    assert_eq!(addr.port(), 8080);
821                }
822                _ => panic!("Should be IPv4 address"),
823            }
824        }
825
826        #[tokio::test]
827        async fn test_address_from_socket_addr_ipv6() {
828            let socket = SocketAddr::V6(SocketAddrV6::new(
829                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
830                8080,
831                0,
832                0,
833            ));
834            let address = Address::from_socket_addr(socket);
835
836            match address {
837                Address::IPv6(addr) => {
838                    assert_eq!(
839                        addr.ip().octets(),
840                        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
841                    );
842                    assert_eq!(addr.port(), 8080);
843                }
844                _ => panic!("Should be IPv6 address"),
845            }
846        }
847
848        #[tokio::test]
849        async fn test_address_to_bytes_ipv4() {
850            let addr = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 1), 80));
851            let bytes = addr.to_bytes();
852
853            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV4);
854            assert_eq!(bytes[1..5], [192, 168, 1, 1]);
855            assert_eq!(bytes[5..7], [0, 80]); // Port 80 in big-endian
856        }
857
858        #[tokio::test]
859        async fn test_address_to_bytes_ipv6() {
860            let addr = Address::IPv6(SocketAddrV6::new(
861                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
862                443,
863                0,
864                0,
865            ));
866            let bytes = addr.to_bytes();
867
868            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_IPV6);
869            assert_eq!(
870                bytes[1..17],
871                [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
872            );
873            assert_eq!(bytes[17..19], [1, 187]); // Port 443 in big-endian
874        }
875
876        #[tokio::test]
877        async fn test_address_to_bytes_domain() {
878            let domain = Domain(Bytes::from("example.com"));
879            let addr = Address::Domain(domain, 8080);
880            let bytes = addr.to_bytes();
881
882            assert_eq!(bytes[0], Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
883            assert_eq!(bytes[1], 11); // Length of "example.com"
884            assert_eq!(&bytes[2..13], b"example.com");
885            assert_eq!(bytes[13..15], [31, 144]); // Port 8080 in big-endian
886        }
887
888        #[tokio::test]
889        async fn test_address_from_bytes_ipv4() {
890            let mut buffer = BytesMut::new();
891            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
892            buffer.put_slice(&[192, 168, 1, 1]); // IP
893            buffer.put_u16(80); // Port
894
895            let mut bytes = buffer.freeze();
896            let addr = Address::from_bytes(&mut bytes).unwrap();
897
898            match addr {
899                Address::IPv4(socket_addr) => {
900                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
901                    assert_eq!(socket_addr.port(), 80);
902                }
903                _ => panic!("Should be IPv4 address"),
904            }
905        }
906
907        #[tokio::test]
908        async fn test_address_from_bytes_ipv6() {
909            let mut buffer = BytesMut::new();
910            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
911            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
912            buffer.put_u16(443); // Port
913
914            let mut bytes = buffer.freeze();
915            let addr = Address::from_bytes(&mut bytes).unwrap();
916
917            match addr {
918                Address::IPv6(socket_addr) => {
919                    assert_eq!(
920                        socket_addr.ip().octets(),
921                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
922                    );
923                    assert_eq!(socket_addr.port(), 443);
924                }
925                _ => panic!("Should be IPv6 address"),
926            }
927        }
928
929        #[tokio::test]
930        async fn test_address_from_bytes_domain() {
931            let mut buffer = BytesMut::new();
932            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
933            buffer.put_u8(11); // Length of domain name
934            buffer.put_slice(b"example.com"); // Domain name
935            buffer.put_u16(8080); // Port
936
937            let mut bytes = buffer.freeze();
938            let addr = Address::from_bytes(&mut bytes).unwrap();
939
940            match addr {
941                Address::Domain(domain, port) => {
942                    assert_eq!(domain.as_bytes(), b"example.com");
943                    assert_eq!(port, 8080);
944                }
945                _ => panic!("Should be domain address"),
946            }
947        }
948
949        #[tokio::test]
950        async fn test_address_from_async_read_ipv4() {
951            let mut buffer = BytesMut::new();
952            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
953            buffer.put_slice(&[192, 168, 1, 1]); // IP
954            buffer.put_u16(80); // Port
955
956            let bytes = buffer.freeze();
957            let mut cursor = Cursor::new(bytes);
958            let mut reader = BufReader::new(&mut cursor);
959
960            let addr = Address::from_async_read(&mut reader).await.unwrap();
961
962            match addr {
963                Address::IPv4(socket_addr) => {
964                    assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
965                    assert_eq!(socket_addr.port(), 80);
966                }
967                _ => panic!("Should be IPv4 address"),
968            }
969        }
970
971        #[tokio::test]
972        async fn test_address_from_async_read_ipv6() {
973            let mut buffer = BytesMut::new();
974            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
975            buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); // IPv6
976            buffer.put_u16(443); // Port
977
978            let bytes = buffer.freeze();
979            let mut cursor = Cursor::new(bytes);
980            let mut reader = BufReader::new(&mut cursor);
981
982            let addr = Address::from_async_read(&mut reader).await.unwrap();
983
984            match addr {
985                Address::IPv6(socket_addr) => {
986                    assert_eq!(
987                        socket_addr.ip().octets(),
988                        [0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
989                    );
990                    assert_eq!(socket_addr.port(), 443);
991                }
992                _ => panic!("Should be IPv6 address"),
993            }
994        }
995
996        #[tokio::test]
997        async fn test_address_from_async_read_domain() {
998            let mut buffer = BytesMut::new();
999            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
1000            buffer.put_u8(11); // Length of domain name
1001            buffer.put_slice(b"example.com"); // Domain name
1002            buffer.put_u16(8080); // Port
1003
1004            let bytes = buffer.freeze();
1005            let mut cursor = Cursor::new(bytes);
1006            let mut reader = BufReader::new(&mut cursor);
1007
1008            let addr = Address::from_async_read(&mut reader).await.unwrap();
1009
1010            match addr {
1011                Address::Domain(domain, port) => {
1012                    assert_eq!(domain.as_bytes(), b"example.com");
1013                    assert_eq!(port, 8080);
1014                }
1015                _ => panic!("Should be domain address"),
1016            }
1017        }
1018
1019        #[tokio::test]
1020        async fn test_address_from_bytes_invalid_type() {
1021            let mut buffer = BytesMut::new();
1022            buffer.put_u8(0xFF); // Invalid address type
1023
1024            let mut bytes = buffer.freeze();
1025            let result = Address::from_bytes(&mut bytes);
1026
1027            assert!(result.is_err());
1028        }
1029
1030        #[tokio::test]
1031        async fn test_address_from_bytes_insufficient_data() {
1032            // IPv4 with incomplete data
1033            let mut buffer = BytesMut::new();
1034            buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
1035            buffer.put_slice(&[192, 168]); // Incomplete IP
1036
1037            let mut bytes = buffer.freeze();
1038            let result = Address::from_bytes(&mut bytes);
1039
1040            assert!(result.is_err());
1041        }
1042
1043        #[tokio::test]
1044        async fn test_address_port() {
1045            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
1046            assert_eq!(addr1.port(), 8080);
1047
1048            let addr2 = Address::IPv6(SocketAddrV6::new(
1049                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1050                443,
1051                0,
1052                0,
1053            ));
1054            assert_eq!(addr2.port(), 443);
1055
1056            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1057            assert_eq!(addr3.port(), 80);
1058        }
1059
1060        #[tokio::test]
1061        async fn test_address_format_as_string() {
1062            let addr1 = Address::IPv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
1063            assert_eq!(addr1.format_as_string().unwrap(), "127.0.0.1:8080");
1064
1065            let addr2 = Address::IPv6(SocketAddrV6::new(
1066                Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
1067                443,
1068                0,
1069                0,
1070            ));
1071            assert_eq!(addr2.format_as_string().unwrap(), "[::1]:443");
1072
1073            // This test assumes Domain::domain_str() returns Ok with the domain string
1074            let addr3 = Address::Domain(Domain(Bytes::from("example.com")), 80);
1075            assert_eq!(addr3.format_as_string().unwrap(), "example.com:80");
1076        }
1077    }
1078}