socks_http_kit/
socks5.rs

1use std::{
2    fmt::{self, Display, Formatter},
3    io::{Error, ErrorKind, Result},
4};
5
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7
8use crate::{Address, AuthMethod};
9
/// Protocol version byte that opens every SOCKS5 message (RFC 1928).
const SOCKS5_VER: u8 = 0x05;
/// Version byte of the username/password sub-negotiation (RFC 1929).
const SOCKS5_AUTH_VER: u8 = 0x01;
12
/// SOCKS5 request commands as defined in RFC 1928 section 4.
///
/// A command tells the proxy which kind of operation the client is asking
/// for. The discriminant of each variant is the `CMD` byte used on the wire.
///
/// Reference: <https://datatracker.ietf.org/doc/html/rfc1928#section-4>
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Socks5Command {
    /// `CONNECT` (`0x01`): establish a TCP/IP connection to the target.
    Connect = 0x01,
    /// `BIND` (`0x02`): request the server to bind to a port for incoming connections.
    Bind = 0x02,
    /// `UDP ASSOCIATE` (`0x03`): establish a UDP relay.
    UdpAssociate = 0x03,
}
30
31impl TryFrom<u8> for Socks5Command {
32    type Error = Error;
33
34    fn try_from(value: u8) -> Result<Self> {
35        match value {
36            0x01 => Ok(Socks5Command::Connect),
37            0x02 => Ok(Socks5Command::Bind),
38            0x03 => Ok(Socks5Command::UdpAssociate),
39            _ => Err(Socks5Error::InvalidCommand.into()),
40        }
41    }
42}
43
/// SOCKS5 server reply codes as defined in RFC 1928 section 6.
///
/// A reply code reports the outcome of a client's request. The discriminant
/// of each variant is the `REP` byte used on the wire.
///
/// Reference: <https://datatracker.ietf.org/doc/html/rfc1928#section-6>
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Socks5Reply {
    /// `0x00`: the request succeeded.
    Succeeded = 0x00,
    /// `0x01`: general SOCKS server failure.
    GeneralFailure = 0x01,
    /// `0x02`: connection not allowed by ruleset.
    ConnectionNotAllowed = 0x02,
    /// `0x03`: network unreachable.
    NetworkUnreachable = 0x03,
    /// `0x04`: host unreachable.
    HostUnreachable = 0x04,
    /// `0x05`: connection refused.
    ConnectionRefused = 0x05,
    /// `0x06`: TTL expired.
    TTLExpired = 0x06,
    /// `0x07`: command not supported.
    CommandNotSupported = 0x07,
    /// `0x08`: address type not supported.
    AddressTypeNotSupported = 0x08,
}
72
73impl TryFrom<u8> for Socks5Reply {
74    type Error = Error;
75
76    fn try_from(value: u8) -> Result<Self> {
77        match value {
78            0x00 => Ok(Socks5Reply::Succeeded),
79            0x01 => Ok(Socks5Reply::GeneralFailure),
80            0x02 => Ok(Socks5Reply::ConnectionNotAllowed),
81            0x03 => Ok(Socks5Reply::NetworkUnreachable),
82            0x04 => Ok(Socks5Reply::HostUnreachable),
83            0x05 => Ok(Socks5Reply::ConnectionRefused),
84            0x06 => Ok(Socks5Reply::TTLExpired),
85            0x07 => Ok(Socks5Reply::CommandNotSupported),
86            0x08 => Ok(Socks5Reply::AddressTypeNotSupported),
87            _ => Err(Socks5Error::InvalidReply.into()),
88        }
89    }
90}
91
92/// Accepts a SOCKS5 proxy connection request from a client.
93///
94/// This function reads, responses and processes an SOCKS5 handshake from the client,
95/// validates authentication if required, and extracts the command and target address.
96///
97/// # Arguments
98/// * `stream` - A mutable reference to an asynchronous stream.
99/// * `auth_method` - The authentication method required for this connection.
100///
101/// # Returns
102/// * `Result<(Socks5Command, Address)>` - The requested command and target address on success,
103///   or an error if the handshake fails, authentication fails, or the request is invalid.
104pub async fn socks5_accept<T>(
105    stream: &mut T,
106    auth_method: &AuthMethod,
107) -> Result<(Socks5Command, Address)>
108where
109    T: AsyncRead + AsyncWrite + Unpin,
110{
111    // Read client greeting
112    let client_auth_method = read_client_hello(stream).await?;
113
114    if !client_auth_method.contains(&auth_method.into()) {
115        // Write server greeting
116        write_server_hello(stream, &Socks5AuthOption::NoAcceptable).await?;
117        return Err(Socks5Error::NoAcceptableAuthMethod.into());
118    }
119
120    // Write server greeting
121    write_server_hello(stream, &auth_method.into()).await?;
122
123    // Handle authentication
124    match auth_method {
125        AuthMethod::NoAuth => (), // No authentication required
126        AuthMethod::UserPass { username, password } => {
127            let auth = read_auth_request(stream).await?;
128            if &auth.username != username || &auth.password != password {
129                write_auth_response(stream, false).await?;
130                return Err(Socks5Error::AuthenticationFailed.into());
131            } else {
132                write_auth_response(stream, true).await?;
133            }
134        }
135    }
136
137    // Read connection request
138    let (command, address) = read_connection_request(stream).await?;
139    Ok((command, address))
140}
141
142/// Completes a SOCKS5 proxy connection by sending a reply to the client.
143///
144/// After processing a client's SOCKS5 connection request with `socks5_accept`,
145/// this function sends the appropriate response to indicate success or failure.
146///
147/// # Arguments
148/// * `stream` - A mutable reference to an asynchronous stream.
149/// * `reply` - The SOCKS5 reply code to send to the client.
150/// * `address` - The bound address to include in the response.
151///
152/// # Returns
153/// * `Result<()>` - Success if the response is sent, or an IO error if writing fails.
154pub async fn socks5_finalize_accept<T>(
155    stream: &mut T,
156    reply: &Socks5Reply,
157    address: &Address,
158) -> Result<()>
159where
160    T: AsyncWrite + Unpin,
161{
162    // Write connection response
163    write_connection_response(stream, reply, address).await?;
164
165    Ok(())
166}
167
168/// Establishes a SOCKS5 proxy connection to a target server.
169///
170/// This function sends an SOCKS5 handshake to a proxy server with the specified
171/// target address and authentication credentials, then verifies the response.
172///
173/// # Arguments
174/// * `stream` - A mutable reference to an asynchronous stream.
175/// * `command` - The SOCKS5 command to execute (Connect, Bind, or UdpAssociate).
176/// * `address` - The target address to connect to.
177/// * `auth` - An array of supported authentication methods.
178///
179/// # Returns
180/// * `Result<Address>` - The bound address returned by the server on success,
181///   or an error if the connection fails or is rejected by the server.
182pub async fn socks5_connect<T>(
183    stream: &mut T,
184    command: &Socks5Command,
185    address: &Address,
186    auth: &[AuthMethod],
187) -> Result<Address>
188where
189    T: AsyncRead + AsyncWrite + Unpin,
190{
191    let client_auth_methods = auth.iter().map(|a| a.into()).collect::<Vec<_>>();
192    if client_auth_methods.len() > 255 {
193        return Err(Socks5Error::TooManyAuthMethods.into());
194    }
195
196    // Write client greeting
197    write_client_hello(stream, &client_auth_methods).await?;
198
199    // Read server greeting
200    let server_auth_method = read_server_hello(stream).await?;
201
202    let auth_method = match client_auth_methods
203        .iter()
204        .position(|c| c == &server_auth_method)
205    {
206        Some(i) => auth[i].clone(),
207        None => {
208            return Err(Socks5Error::NoAcceptableAuthMethod.into());
209        }
210    };
211
212    // Handle authentication
213    match auth_method {
214        AuthMethod::NoAuth => (), // No authentication required
215        AuthMethod::UserPass { username, password } => {
216            write_auth_request(stream, &UserPassAuth { username, password }).await?;
217            read_auth_response(stream).await?;
218        }
219    }
220
221    // Write connection request
222    write_connection_request(stream, command, address).await?;
223
224    // Read connection response
225    let (reply, address) = read_connection_response(stream).await?;
226
227    // Handle connection response
228    if reply != Socks5Reply::Succeeded {
229        return Err(Socks5Error::ConnectionFailed.into());
230    }
231
232    Ok(address)
233}
234
235/// Reads and parses the SOCKS5 UDP header information and destination address from a UDP packet buffer.
236///
237/// Per RFC 1928 Section 7, SOCKS5 UDP requests/responses contain a header in the following format:
238/// ```text
239/// +----+------+------+----------+----------+----------+
240/// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
241/// +----+------+------+----------+----------+----------+
242/// | 2  |  1   |  1   | Variable |    2     | Variable |
243/// +----+------+------+----------+----------+----------+
244/// ```
245/// RSV: Reserved field, must be 0
246/// FRAG: Fragment number (currently unused, set to 0)
247/// ATYP/DST.ADDR/DST.PORT: Destination address encoded in the same format as SOCKS5 requests
248///
249/// # Arguments
250/// * `buf` - A buffer slice containing the UDP packet
251///
252/// # Returns
253/// * `Result<(Address, usize)>` - On success, returns the parsed destination
254///   address and the total header length in bytes, or an IO error on failure
255///
256pub fn socks5_read_udp_header(buf: &[u8]) -> Result<(Address, usize)> {
257    let first = buf
258        .first_chunk::<3>()
259        .ok_or(Error::new(ErrorKind::UnexpectedEof, "buffer too short"))?;
260    if first != &[0, 0, 0] {
261        return Err(Error::new(ErrorKind::InvalidData, "invalid UDP header"));
262    }
263    let (address, len) = Address::decode_from_buf(&buf[3..])?;
264    Ok((address, 2 + 1 + len))
265}
266
267/// Writes the SOCKS5 UDP header and destination address into a buffer.
268///
269/// # Arguments
270/// * `address` - The destination address to encode
271/// * `buf` - A mutable buffer slice to write the UDP header into
272///
273/// # Returns
274/// * `Result<usize>` - On success, returns the total header length written in bytes,
275///   or an IO error on failure
276///
277pub fn socks5_write_udp_header(address: &Address, buf: &mut [u8]) -> Result<usize> {
278    let first = buf
279        .first_chunk_mut::<3>()
280        .ok_or(Error::new(ErrorKind::UnexpectedEof, "buffer too short"))?;
281    *first = [0, 0, 0];
282    let len = address.encode_to_buf(&mut buf[3..])?;
283    Ok(2 + 1 + len)
284}
285
/// Authentication method identifiers exchanged in the SOCKS5 greeting.
///
/// The discriminant of each variant is the `METHOD` byte used on the wire.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Socks5AuthOption {
    /// `0x00`: no authentication required.
    NoAuth = 0x00,
    /// `0x01`: GSSAPI.
    GssApi = 0x01,
    /// `0x02`: username/password.
    UserPass = 0x02,
    /// `0xFF`: no acceptable methods.
    NoAcceptable = 0xFF,
}
293
294impl TryFrom<u8> for Socks5AuthOption {
295    type Error = Error;
296
297    fn try_from(value: u8) -> Result<Self> {
298        match value {
299            0x00 => Ok(Socks5AuthOption::NoAuth),
300            0x01 => Ok(Socks5AuthOption::GssApi),
301            0x02 => Ok(Socks5AuthOption::UserPass),
302            0xFF => Ok(Socks5AuthOption::NoAcceptable),
303            _ => Err(Socks5Error::InvalidAuthMethod.into()),
304        }
305    }
306}
307
308impl From<&AuthMethod> for Socks5AuthOption {
309    fn from(value: &AuthMethod) -> Self {
310        match value {
311            AuthMethod::NoAuth => Socks5AuthOption::NoAuth,
312            AuthMethod::UserPass { .. } => Socks5AuthOption::UserPass,
313        }
314    }
315}
316
/// Credentials carried by a username/password authentication request.
#[derive(Debug)]
struct UserPassAuth {
    /// The `UNAME` field of the request.
    username: String,
    /// The `PASSWD` field of the request.
    password: String,
}
322
323/// According to RFC 1928, client hello format is:
324/// ```text
325/// +----+----------+----------+
326/// |VER | NMETHODS | METHODS  |
327/// +----+----------+----------+
328/// | 1  |    1     | 1 to 255 |
329/// +----+----------+----------+
330/// ```
331/// VER: SOCKS protocol version, must be 0x05
332/// NMETHODS: Number of authentication methods supported by client
333/// METHODS: List of authentication methods supported by client
334async fn read_client_hello<T>(reader: &mut T) -> Result<Vec<Socks5AuthOption>>
335where
336    T: AsyncRead + Unpin,
337{
338    // Read version number
339    let ver = reader.read_u8().await?;
340    if ver != SOCKS5_VER {
341        return Err(Socks5Error::InvalidSocksVersion.into());
342    }
343
344    // Read number of authentication methods
345    let nmethods = reader.read_u8().await?;
346    if nmethods == 0 {
347        return Err(Socks5Error::NoAuthMethods.into());
348    }
349
350    // Read authentication methods list
351    let mut methods = Vec::with_capacity(nmethods as usize);
352    for _ in 0..nmethods {
353        let method_byte = reader.read_u8().await?;
354        match Socks5AuthOption::try_from(method_byte) {
355            Ok(method) => methods.push(method),
356            Err(_) => continue, // Ignore unsupported authentication methods
357        }
358    }
359
360    if methods.is_empty() {
361        return Err(Socks5Error::NoSupportedAuthMethods.into());
362    }
363
364    Ok(methods)
365}
366
367async fn write_client_hello<T>(writer: &mut T, auth_method: &[Socks5AuthOption]) -> Result<()>
368where
369    T: AsyncWrite + Unpin,
370{
371    if auth_method.is_empty() {
372        return Err(Socks5Error::NoAuthMethods.into());
373    }
374
375    // Write version number
376    writer.write_u8(SOCKS5_VER).await?;
377
378    // Write number of authentication methods
379    writer.write_u8(auth_method.len() as u8).await?;
380
381    // Write authentication methods list
382    for method in auth_method {
383        writer.write_u8(*method as u8).await?;
384    }
385
386    writer.flush().await?;
387    Ok(())
388}
389
390/// According to RFC 1928, server hello format is:
391/// ```text
392/// +----+--------+
393/// |VER | METHOD |
394/// +----+--------+
395/// | 1  |   1    |
396/// +----+--------+
397/// ```
398/// VER: SOCKS protocol version, must be 0x05
399/// METHOD: Server's selected authentication method, 0xFF means none of the client's methods are acceptable
400async fn read_server_hello<T>(reader: &mut T) -> Result<Socks5AuthOption>
401where
402    T: AsyncRead + Unpin,
403{
404    // Read version number
405    let ver = reader.read_u8().await?;
406    if ver != SOCKS5_VER {
407        return Err(Socks5Error::InvalidSocksVersion.into());
408    }
409
410    // Read server's selected authentication method
411    let method_byte = reader.read_u8().await?;
412    Socks5AuthOption::try_from(method_byte)
413}
414
415async fn write_server_hello<T>(writer: &mut T, auth_method: &Socks5AuthOption) -> Result<()>
416where
417    T: AsyncWrite + Unpin,
418{
419    // Write version number
420    writer.write_u8(SOCKS5_VER).await?;
421
422    // Write selected authentication method
423    writer.write_u8(*auth_method as u8).await?;
424
425    writer.flush().await?;
426    Ok(())
427}
428
429/// According to RFC 1929, username/password authentication request format is:
430/// ```text
431/// +----+------+----------+------+----------+
432/// |VER | ULEN |  UNAME   | PLEN |  PASSWD  |
433/// +----+------+----------+------+----------+
434/// | 1  |  1   | 1 to 255 |  1   | 1 to 255 |
435/// +----+------+----------+------+----------+
436/// ```
437/// VER: Authentication sub-protocol version, must be 0x01.
438/// ULEN: Username length (1-255 bytes).
439/// UNAME: Username.
440/// PLEN: Password length (1-255 bytes).
441/// PASSWD: Password.
442async fn read_auth_request<T>(reader: &mut T) -> Result<UserPassAuth>
443where
444    T: AsyncRead + Unpin,
445{
446    // Read authentication sub-protocol version number
447    let ver = reader.read_u8().await?;
448    if ver != SOCKS5_AUTH_VER {
449        return Err(Socks5Error::InvalidAuthVersion.into());
450    }
451
452    // Read username
453    let ulen = reader.read_u8().await? as usize;
454    let mut uname = vec![0u8; ulen];
455    reader.read_exact(&mut uname).await?;
456    let username = String::from_utf8(uname).map_err(|_| Socks5Error::InvalidUsernameEncoding)?;
457
458    // Read password
459    let plen = reader.read_u8().await? as usize;
460    let mut passwd = vec![0u8; plen];
461    reader.read_exact(&mut passwd).await?;
462    let password = String::from_utf8(passwd).map_err(|_| Socks5Error::InvalidPasswordEncoding)?;
463
464    Ok(UserPassAuth { username, password })
465}
466
467async fn write_auth_request<T>(writer: &mut T, auth: &UserPassAuth) -> Result<()>
468where
469    T: AsyncWrite + Unpin,
470{
471    // Write authentication sub-protocol version number
472    writer.write_u8(SOCKS5_AUTH_VER).await?;
473
474    // Write username
475    let username_bytes = auth.username.as_bytes();
476    if username_bytes.len() > 255 {
477        return Err(Socks5Error::UsernameTooLong.into());
478    }
479    writer.write_u8(username_bytes.len() as u8).await?;
480    writer.write_all(username_bytes).await?;
481
482    // Write password
483    let password_bytes = auth.password.as_bytes();
484    if password_bytes.len() > 255 {
485        return Err(Socks5Error::PasswordTooLong.into());
486    }
487    writer.write_u8(password_bytes.len() as u8).await?;
488    writer.write_all(password_bytes).await?;
489
490    writer.flush().await?;
491    Ok(())
492}
493
494/// According to RFC 1929, username/password authentication response format is:
495/// ```text
496/// +----+--------+
497/// |VER | STATUS |
498/// +----+--------+
499/// | 1  |   1    |
500/// +----+--------+
501/// ```
502/// VER: Authentication sub-protocol version, must be 0x01.
503/// STATUS: Authentication result, 0x00 means success, other values mean failure.
504async fn read_auth_response<T>(reader: &mut T) -> Result<()>
505where
506    T: AsyncRead + Unpin,
507{
508    // Read authentication sub-protocol version number
509    let ver = reader.read_u8().await?;
510    if ver != SOCKS5_AUTH_VER {
511        return Err(Socks5Error::InvalidAuthVersion.into());
512    }
513
514    // Read authentication result
515    let status = reader.read_u8().await?;
516    if status != 0 {
517        return Err(Socks5Error::AuthenticationFailed.into());
518    }
519
520    Ok(())
521}
522
523async fn write_auth_response<T>(writer: &mut T, is_ok: bool) -> Result<()>
524where
525    T: AsyncWrite + Unpin,
526{
527    // Write authentication sub-protocol version number
528    writer.write_u8(SOCKS5_AUTH_VER).await?;
529
530    // Write authentication result
531    writer.write_u8(if is_ok { 0 } else { 1 }).await?;
532
533    writer.flush().await?;
534    Ok(())
535}
536
537/// According to RFC 1928, connection request format is:
538/// ```text
539/// +----+-----+-------+------+----------+----------+
540/// |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
541/// +----+-----+-------+------+----------+----------+
542/// | 1  |  1  | X'00' |  1   | Variable |    2     |
543/// +----+-----+-------+------+----------+----------+
544/// ```
545/// VER: SOCKS protocol version, must be 0x05.
546/// CMD: Command code - 0x01 (CONNECT), 0x02 (BIND), 0x03 (UDP ASSOCIATE).
547/// RSV: Reserved field, must be 0x00.
548/// ATYP: Address type - 0x01 (IPv4), 0x03 (domain name), 0x04 (IPv6).
549/// DST.ADDR: Destination address, format depends on ATYP.
550/// DST.PORT: Destination port, network byte order (big-endian).
551async fn read_connection_request<T>(reader: &mut T) -> Result<(Socks5Command, Address)>
552where
553    T: AsyncRead + Unpin,
554{
555    // Read version number
556    let ver = reader.read_u8().await?;
557    if ver != SOCKS5_VER {
558        return Err(Socks5Error::InvalidSocksVersion.into());
559    }
560
561    // Read command
562    let cmd = Socks5Command::try_from(reader.read_u8().await?)?;
563
564    // Read reserved field
565    let rsv = reader.read_u8().await?;
566    if rsv != 0 {
567        return Err(Socks5Error::InvalidRsvValue.into());
568    }
569
570    // Read address type and address
571    let (address, _) = Address::decode_from_reader(reader).await?;
572
573    Ok((cmd, address))
574}
575
576async fn write_connection_request<T>(
577    writer: &mut T,
578    command: &Socks5Command,
579    address: &Address,
580) -> Result<()>
581where
582    T: AsyncWrite + Unpin,
583{
584    // Write version number
585    writer.write_u8(SOCKS5_VER).await?;
586
587    // Write command
588    writer.write_u8(*command as u8).await?;
589
590    // Write reserved field
591    writer.write_u8(0).await?;
592
593    // Write address type and address
594    address.encode_to_writer(writer).await?;
595
596    writer.flush().await?;
597    Ok(())
598}
599
600/// According to RFC 1928, the connection response format is:
601/// ```text
602/// +----+-----+-------+------+----------+----------+
603/// |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
604/// +----+-----+-------+------+----------+----------+
605/// | 1  |  1  | X'00' |  1   | Variable |    2     |
606/// +----+-----+-------+------+----------+----------+
607/// ```
608/// VER: SOCKS protocol version, must be 0x05.
609/// REP: Reply code - 0x00 (succeeded), 0x01-0x08 (various errors).
610/// RSV: Reserved field, must be 0x00.
611/// ATYP: Address type - 0x01 (IPv4), 0x03 (domain name), 0x04 (IPv6).
612/// BND.ADDR: Server bound address, format depends on ATYP.
613/// BND.PORT: Server bound port, network byte order (big-endian).
614async fn read_connection_response<T>(reader: &mut T) -> Result<(Socks5Reply, Address)>
615where
616    T: AsyncRead + Unpin,
617{
618    // Read version number
619    let ver = reader.read_u8().await?;
620    if ver != SOCKS5_VER {
621        return Err(Socks5Error::InvalidSocksVersion.into());
622    }
623
624    // Read reply code
625    let reply = Socks5Reply::try_from(reader.read_u8().await?)?;
626
627    // Read reserved field
628    let rsv = reader.read_u8().await?;
629    if rsv != 0 {
630        return Err(Socks5Error::InvalidRsvValue.into());
631    }
632
633    // Read address type and address (though we might not need to use them)
634    let (address, _) = Address::decode_from_reader(reader).await?;
635
636    Ok((reply, address))
637}
638
639async fn write_connection_response<T>(
640    writer: &mut T,
641    reply: &Socks5Reply,
642    address: &Address,
643) -> Result<()>
644where
645    T: AsyncWrite + Unpin,
646{
647    // Write version number
648    writer.write_u8(SOCKS5_VER).await?;
649
650    // Write reply code
651    writer.write_u8(*reply as u8).await?;
652
653    // Write reserved field
654    writer.write_u8(0).await?;
655
656    // Write address type and address
657    address.encode_to_writer(writer).await?;
658
659    writer.flush().await?;
660    Ok(())
661}
662
/// Failures specific to the SOCKS5 protocol (RFC 1928) and its
/// username/password sub-negotiation (RFC 1929).
///
/// Every variant converts into a [`std::io::Error`], which is how the
/// functions in this module report it.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Socks5Error {
    /// Client and server could not agree on an authentication method.
    NoAcceptableAuthMethod,
    /// The presented username/password was rejected.
    AuthenticationFailed,
    /// The server answered the request with a reply other than "succeeded".
    ConnectionFailed,
    /// The peer sent a SOCKS version byte other than 0x05.
    InvalidSocksVersion,
    /// The peer sent an authentication sub-protocol version byte other than 0x01.
    InvalidAuthVersion,
    /// The list of authentication methods is empty.
    NoAuthMethods,
    /// None of the client's offered authentication methods are recognised.
    NoSupportedAuthMethods,
    /// The authentication method byte value is not recognised.
    InvalidAuthMethod,
    /// The command byte is not a valid SOCKS5 command.
    InvalidCommand,
    /// The reply byte is not a valid SOCKS5 reply code.
    InvalidReply,
    /// The reserved field contains a non-zero value.
    InvalidRsvValue,
    /// The username is not valid UTF-8.
    InvalidUsernameEncoding,
    /// The password is not valid UTF-8.
    InvalidPasswordEncoding,
    /// The username exceeds the maximum allowed length (255 bytes).
    UsernameTooLong,
    /// The password exceeds the maximum allowed length (255 bytes).
    PasswordTooLong,
    /// More than 255 authentication methods were supplied.
    TooManyAuthMethods,
}
703
704impl Display for Socks5Error {
705    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
706        match self {
707            Self::NoAcceptableAuthMethod => write!(f, "No acceptable authentication method"),
708            Self::AuthenticationFailed => write!(f, "Authentication failed"),
709            Self::ConnectionFailed => write!(f, "Connection failed"),
710            Self::InvalidSocksVersion => write!(f, "Invalid SOCKS version"),
711            Self::InvalidAuthVersion => write!(f, "Invalid auth version"),
712            Self::NoAuthMethods => write!(f, "No authentication methods provided"),
713            Self::NoSupportedAuthMethods => write!(f, "No supported authentication methods"),
714            Self::InvalidAuthMethod => write!(f, "Invalid AuthMethod"),
715            Self::InvalidCommand => write!(f, "Invalid Command"),
716
717            Self::InvalidReply => write!(f, "Invalid Reply"),
718            Self::InvalidRsvValue => write!(f, "Invalid RSV value"),
719            Self::InvalidUsernameEncoding => write!(f, "Invalid username encoding"),
720            Self::InvalidPasswordEncoding => write!(f, "Invalid password encoding"),
721            Self::UsernameTooLong => write!(f, "Username too long"),
722            Self::PasswordTooLong => write!(f, "Password too long"),
723            Self::TooManyAuthMethods => write!(f, "Too many authentication methods"),
724        }
725    }
726}
727
728impl std::error::Error for Socks5Error {}
729
730impl From<Socks5Error> for Error {
731    fn from(e: Socks5Error) -> Self {
732        match e {
733            Socks5Error::NoAcceptableAuthMethod => Error::new(ErrorKind::PermissionDenied, e),
734            Socks5Error::AuthenticationFailed => Error::new(ErrorKind::PermissionDenied, e),
735            Socks5Error::ConnectionFailed => Error::new(ErrorKind::ConnectionRefused, e),
736            Socks5Error::InvalidSocksVersion => Error::new(ErrorKind::InvalidData, e),
737            Socks5Error::InvalidAuthVersion => Error::new(ErrorKind::InvalidData, e),
738            Socks5Error::NoAuthMethods => Error::new(ErrorKind::InvalidInput, e),
739            Socks5Error::NoSupportedAuthMethods => Error::new(ErrorKind::InvalidData, e),
740            Socks5Error::InvalidAuthMethod => Error::new(ErrorKind::InvalidData, e),
741            Socks5Error::InvalidCommand => Error::new(ErrorKind::InvalidData, e),
742            Socks5Error::InvalidReply => Error::new(ErrorKind::InvalidData, e),
743            Socks5Error::InvalidRsvValue => Error::new(ErrorKind::InvalidData, e),
744            Socks5Error::InvalidUsernameEncoding => Error::new(ErrorKind::InvalidData, e),
745            Socks5Error::InvalidPasswordEncoding => Error::new(ErrorKind::InvalidData, e),
746            Socks5Error::UsernameTooLong => Error::new(ErrorKind::InvalidInput, e),
747            Socks5Error::PasswordTooLong => Error::new(ErrorKind::InvalidInput, e),
748            Socks5Error::TooManyAuthMethods => Error::new(ErrorKind::InvalidInput, e),
749        }
750    }
751}
752
753#[cfg(test)]
754mod test {
755    use std::net::{Ipv4Addr, Ipv6Addr};
756
757    use super::*;
758    use crate::test_utils::create_mock_stream;
759
760    #[tokio::test]
761    async fn test_client_hello_write_read() {
762        let all_methods = [
763            vec![Socks5AuthOption::NoAuth],
764            vec![
765                Socks5AuthOption::NoAuth,
766                Socks5AuthOption::UserPass,
767                Socks5AuthOption::GssApi,
768            ],
769        ];
770        for methods in all_methods {
771            let (mut stream1, mut stream2) = create_mock_stream();
772            write_client_hello(&mut stream1, &methods).await.unwrap();
773            let recevied_methods = read_client_hello(&mut stream2).await.unwrap();
774            assert_eq!(methods.as_slice(), recevied_methods.as_slice());
775        }
776    }
777
778    #[tokio::test]
779    async fn test_server_hello_write_read() {
780        let (mut stream1, mut stream2) = create_mock_stream();
781        write_server_hello(&mut stream1, &Socks5AuthOption::NoAuth)
782            .await
783            .unwrap();
784        let method = read_server_hello(&mut stream2).await.unwrap();
785        assert_eq!(Socks5AuthOption::NoAuth, method);
786    }
787
788    #[tokio::test]
789    async fn test_auth_request_write_read() {
790        let (mut stream1, mut stream2) = create_mock_stream();
791        let auth = UserPassAuth {
792            username: "test_user".to_string(),
793            password: "test_pass".to_string(),
794        };
795        write_auth_request(&mut stream1, &auth).await.unwrap();
796        let received_auth = read_auth_request(&mut stream2).await.unwrap();
797        assert_eq!(auth.username, received_auth.username);
798        assert_eq!(auth.password, received_auth.password);
799    }
800
801    #[tokio::test]
802    async fn test_auth_response_write_read() {
803        // Authentication success
804        let (mut stream1, mut stream2) = create_mock_stream();
805        write_auth_response(&mut stream1, true).await.unwrap();
806        read_auth_response(&mut stream2).await.unwrap();
807
808        // Authentication failure
809        let (mut stream1, mut stream2) = create_mock_stream();
810        write_auth_response(&mut stream1, false).await.unwrap();
811        let err = read_auth_response(&mut stream2).await.unwrap_err();
812        assert_eq!(
813            err.downcast::<Socks5Error>().unwrap(),
814            Socks5Error::AuthenticationFailed
815        );
816    }
817
818    #[tokio::test]
819    async fn test_connection_request_write_read() {
820        let all_commands = [
821            Socks5Command::Connect,
822            Socks5Command::Bind,
823            Socks5Command::UdpAssociate,
824        ];
825        let all_addresses = [
826            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
827            Address::DomainName(("example.com".to_string(), 443)),
828            Address::IPv6((
829                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
830                8080,
831            )),
832        ];
833        for command in all_commands {
834            for address in all_addresses.iter() {
835                let (mut stream1, mut stream2) = create_mock_stream();
836                write_connection_request(&mut stream1, &command, address)
837                    .await
838                    .unwrap();
839                let (received_command, received_address) =
840                    read_connection_request(&mut stream2).await.unwrap();
841                assert_eq!(command, received_command);
842                assert_eq!(address, &received_address);
843            }
844        }
845    }
846
847    #[tokio::test]
848    async fn test_connection_response_write_read() {
849        let all_replies = [
850            Socks5Reply::Succeeded,
851            Socks5Reply::GeneralFailure,
852            Socks5Reply::ConnectionNotAllowed,
853            Socks5Reply::NetworkUnreachable,
854            Socks5Reply::HostUnreachable,
855            Socks5Reply::ConnectionRefused,
856            Socks5Reply::TTLExpired,
857            Socks5Reply::CommandNotSupported,
858            Socks5Reply::AddressTypeNotSupported,
859        ];
860        let all_addresses = [
861            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
862            Address::DomainName(("example.com".to_string(), 443)),
863            Address::IPv6((
864                Ipv6Addr::new(0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x01),
865                8080,
866            )),
867        ];
868        for reply in all_replies {
869            for address in all_addresses.iter() {
870                let (mut stream1, mut stream2) = create_mock_stream();
871                write_connection_response(&mut stream1, &reply, address)
872                    .await
873                    .unwrap();
874                let (received_reply, received_address) =
875                    read_connection_response(&mut stream2).await.unwrap();
876                assert_eq!(reply, received_reply);
877                assert_eq!(address, &received_address);
878            }
879        }
880    }
881
882    #[tokio::test]
883    async fn test_read_client_hello_invalid_version() {
884        let (mut client, server) = create_mock_stream();
885
886        // Invalid SOCKS version: 0x04 instead of 0x05
887        server.write_immediate(&[0x04, 0x01, 0x00]).unwrap();
888
889        let result = read_client_hello(&mut client).await;
890
891        let err = result.unwrap_err();
892        assert_eq!(err.kind(), ErrorKind::InvalidData);
893        assert_eq!(
894            err.downcast::<Socks5Error>().unwrap(),
895            Socks5Error::InvalidSocksVersion
896        );
897    }
898
899    #[tokio::test]
900    async fn test_read_client_hello_no_auth_method() {
901        let (mut client, server) = create_mock_stream();
902
903        // No authentication methods: NMETHODS is 0
904        server.write_immediate(&[0x05, 0x00]).unwrap();
905
906        let result = read_client_hello(&mut client).await;
907
908        let err = result.unwrap_err();
909        assert_eq!(err.kind(), ErrorKind::InvalidInput);
910        assert_eq!(
911            err.downcast::<Socks5Error>().unwrap(),
912            Socks5Error::NoAuthMethods
913        );
914    }
915
916    #[tokio::test]
917    async fn test_read_client_hello_unsupported_auth_methods() {
918        let (mut client, server) = create_mock_stream();
919
920        // Only unsupported auth method: 0x80 is not a valid SOCKS5 auth method
921        server.write_immediate(&[0x05, 0x01, 0x80]).unwrap();
922
923        let result = read_client_hello(&mut client).await;
924
925        let err = result.unwrap_err();
926        assert_eq!(err.kind(), ErrorKind::InvalidData);
927        assert_eq!(
928            err.downcast::<Socks5Error>().unwrap(),
929            Socks5Error::NoSupportedAuthMethods
930        );
931    }
932
933    #[tokio::test]
934    async fn test_write_client_hello_no_auth_method() {
935        let (mut client, _server) = create_mock_stream();
936
937        let result = write_client_hello(&mut client, &[]).await;
938
939        let err = result.unwrap_err();
940        assert_eq!(err.kind(), ErrorKind::InvalidInput);
941        assert_eq!(
942            err.downcast::<Socks5Error>().unwrap(),
943            Socks5Error::NoAuthMethods
944        );
945    }
946
947    #[tokio::test]
948    async fn test_read_server_hello_invalid_version() {
949        let (mut client, server) = create_mock_stream();
950
951        // Invalid SOCKS version: 0x04 instead of 0x05
952        server.write_immediate(&[0x04, 0x00]).unwrap();
953
954        let result = read_server_hello(&mut client).await;
955
956        let err = result.unwrap_err();
957        assert_eq!(err.kind(), ErrorKind::InvalidData);
958        assert_eq!(
959            err.downcast::<Socks5Error>().unwrap(),
960            Socks5Error::InvalidSocksVersion
961        );
962    }
963
964    #[tokio::test]
965    async fn test_read_auth_request_invalid_version() {
966        let (mut client, server) = create_mock_stream();
967
968        // Invalid auth version: 0x02 instead of 0x01
969        // Format: [version, username length, username, password length, password]
970        server
971            .write_immediate(&[
972                0x02, 0x04, b'u', b's', b'e', b'r', 0x04, b'p', b'a', b's', b's',
973            ])
974            .unwrap();
975
976        let result = read_auth_request(&mut client).await;
977
978        let err = result.unwrap_err();
979        assert_eq!(err.kind(), ErrorKind::InvalidData);
980        assert_eq!(
981            err.downcast::<Socks5Error>().unwrap(),
982            Socks5Error::InvalidAuthVersion
983        );
984    }
985
986    #[tokio::test]
987    async fn test_read_auth_request_invalid_username_encoding() {
988        let (mut client, server) = create_mock_stream();
989
990        // Invalid UTF-8 sequence for username
991        server
992            .write_immediate(&[
993                0x01, 0x04, 0xFF, 0xFF, 0xFF, 0xFF, // Invalid UTF-8 sequence
994                0x04, b'p', b'a', b's', b's',
995            ])
996            .unwrap();
997
998        let result = read_auth_request(&mut client).await;
999
1000        let err = result.unwrap_err();
1001        assert_eq!(err.kind(), ErrorKind::InvalidData);
1002        assert_eq!(
1003            err.downcast::<Socks5Error>().unwrap(),
1004            Socks5Error::InvalidUsernameEncoding
1005        );
1006    }
1007
1008    #[tokio::test]
1009    async fn test_read_auth_request_invalid_password_encoding() {
1010        let (mut client, server) = create_mock_stream();
1011
1012        // Invalid UTF-8 sequence for password
1013        server
1014            .write_immediate(&[
1015                0x01, 0x04, b'u', b's', b'e', b'r', 0x04, 0xFF, 0xFF, 0xFF,
1016                0xFF, // Invalid UTF-8 sequence
1017            ])
1018            .unwrap();
1019
1020        let result = read_auth_request(&mut client).await;
1021
1022        let err = result.unwrap_err();
1023        assert_eq!(err.kind(), ErrorKind::InvalidData);
1024        assert_eq!(
1025            err.downcast::<Socks5Error>().unwrap(),
1026            Socks5Error::InvalidPasswordEncoding
1027        );
1028    }
1029
1030    #[tokio::test]
1031    async fn test_write_auth_request_username_too_long() {
1032        let (mut client, _server) = create_mock_stream();
1033
1034        // Username length of 256 bytes (exceeds max of 255)
1035        let long_username = "a".repeat(256);
1036        let auth = UserPassAuth {
1037            username: long_username,
1038            password: "password".to_string(),
1039        };
1040
1041        let result = write_auth_request(&mut client, &auth).await;
1042
1043        let err = result.unwrap_err();
1044        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1045        assert_eq!(
1046            err.downcast::<Socks5Error>().unwrap(),
1047            Socks5Error::UsernameTooLong
1048        );
1049    }
1050
1051    #[tokio::test]
1052    async fn test_write_auth_request_password_too_long() {
1053        let (mut client, _server) = create_mock_stream();
1054
1055        // Password length of 256 bytes (exceeds max of 255)
1056        let long_password = "a".repeat(256);
1057        let auth = UserPassAuth {
1058            username: "username".to_string(),
1059            password: long_password,
1060        };
1061
1062        let result = write_auth_request(&mut client, &auth).await;
1063
1064        let err = result.unwrap_err();
1065        assert_eq!(err.kind(), ErrorKind::InvalidInput);
1066        assert_eq!(
1067            err.downcast::<Socks5Error>().unwrap(),
1068            Socks5Error::PasswordTooLong
1069        );
1070    }
1071
1072    #[tokio::test]
1073    async fn test_read_auth_response_invalid_auth_version() {
1074        let (mut client, server) = create_mock_stream();
1075
1076        // Invalid auth version: 0x02 instead of 0x01
1077        server.write_immediate(&[0x02, 0x00]).unwrap();
1078
1079        let result = read_auth_response(&mut client).await;
1080
1081        let err = result.unwrap_err();
1082        assert_eq!(err.kind(), ErrorKind::InvalidData);
1083        assert_eq!(
1084            err.downcast::<Socks5Error>().unwrap(),
1085            Socks5Error::InvalidAuthVersion
1086        );
1087    }
1088
1089    #[tokio::test]
1090    async fn test_read_auth_response_auth_failed() {
1091        let (mut client, server) = create_mock_stream();
1092
1093        // Status 0x01 indicates authentication failure
1094        server.write_immediate(&[0x01, 0x01]).unwrap();
1095
1096        let result = read_auth_response(&mut client).await;
1097
1098        let err = result.unwrap_err();
1099        assert_eq!(err.kind(), ErrorKind::PermissionDenied);
1100        assert_eq!(
1101            err.downcast::<Socks5Error>().unwrap(),
1102            Socks5Error::AuthenticationFailed
1103        );
1104    }
1105
1106    #[tokio::test]
1107    async fn test_read_connection_request_invalid_version() {
1108        let (mut client, server) = create_mock_stream();
1109
1110        // Invalid SOCKS version: 0x04 instead of 0x05
1111        server
1112            .write_immediate(&[
1113                0x04, 0x01, // CONNECT command
1114                0x00, // Reserved field
1115                0x01, // IPv4 address type
1116                0x7F, 0x00, 0x00, 0x01, // 127.0.0.1
1117                0x00, 0x50, // Port 80
1118            ])
1119            .unwrap();
1120
1121        let result = read_connection_request(&mut client).await;
1122
1123        let err = result.unwrap_err();
1124        assert_eq!(err.kind(), ErrorKind::InvalidData);
1125        assert_eq!(
1126            err.downcast::<Socks5Error>().unwrap(),
1127            Socks5Error::InvalidSocksVersion
1128        );
1129    }
1130
1131    #[tokio::test]
1132    async fn test_read_connection_request_invalid_rsv() {
1133        let (mut client, server) = create_mock_stream();
1134
1135        // Invalid reserved field: 0x01 instead of 0x00
1136        server
1137            .write_immediate(&[
1138                0x05, // SOCKS5 version
1139                0x01, // CONNECT command
1140                0x01, // Invalid reserved field
1141                0x01, // IPv4 address type
1142                0x7F, 0x00, 0x00, 0x01, // 127.0.0.1
1143                0x00, 0x50, // Port 80
1144            ])
1145            .unwrap();
1146
1147        let result = read_connection_request(&mut client).await;
1148
1149        let err = result.unwrap_err();
1150        assert_eq!(err.kind(), ErrorKind::InvalidData);
1151        assert_eq!(
1152            err.downcast::<Socks5Error>().unwrap(),
1153            Socks5Error::InvalidRsvValue
1154        );
1155    }
1156
1157    #[tokio::test]
1158    async fn test_read_connection_response_invalid_version() {
1159        let (mut client, server) = create_mock_stream();
1160
1161        // Invalid SOCKS version: 0x04 instead of 0x05
1162        server
1163            .write_immediate(&[
1164                0x04, // Invalid SOCKS version
1165                0x00, // Success reply
1166                0x00, // Reserved field
1167                0x01, // IPv4 address type
1168                0x7F, 0x00, 0x00, 0x01, // 127.0.0.1
1169                0x00, 0x50, // Port 80
1170            ])
1171            .unwrap();
1172
1173        let result = read_connection_response(&mut client).await;
1174
1175        let err = result.unwrap_err();
1176        assert_eq!(err.kind(), ErrorKind::InvalidData);
1177        assert_eq!(
1178            err.downcast::<Socks5Error>().unwrap(),
1179            Socks5Error::InvalidSocksVersion
1180        );
1181    }
1182
1183    #[tokio::test]
1184    async fn test_read_connection_response_invalid_rsv() {
1185        let (mut client, server) = create_mock_stream();
1186
1187        // Invalid reserved field: 0x01 instead of 0x00
1188        server
1189            .write_immediate(&[
1190                0x05, // SOCKS5 version
1191                0x00, // Success reply
1192                0x01, // Invalid reserved field
1193                0x01, // IPv4 address type
1194                0x7F, 0x00, 0x00, 0x01, // 127.0.0.1
1195                0x00, 0x50, // Port 80
1196            ])
1197            .unwrap();
1198
1199        let result = read_connection_response(&mut client).await;
1200
1201        let err = result.unwrap_err();
1202        assert_eq!(err.kind(), ErrorKind::InvalidData);
1203        assert_eq!(
1204            err.downcast::<Socks5Error>().unwrap(),
1205            Socks5Error::InvalidRsvValue
1206        );
1207    }
1208}