socks5_protocol/
lib.rs

//! # Async tokio protocol
//!
//! `socks5-protocol` provides types that can be read from `AsyncRead` and written to `AsyncWrite`.
//!
//! You can create a SOCKS5 server or a SOCKS5 client using this library.
6
7#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
8
9use std::{
10    convert::TryInto,
11    fmt, io,
12    net::{IpAddr, Ipv4Addr, SocketAddr},
13    str::FromStr,
14};
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
17pub use error::{Error, Result};
18
19mod error;
20#[cfg(feature = "sync")]
21/// Sync version.
22pub mod sync;
23
24/// Version conatins one byte. In socks5, it should be `5`, other value will return `Error::InvalidVersion`.
25///
26/// ```
27/// use std::io::Cursor;
28/// use socks5_protocol::Version;
29///
30/// #[tokio::main]
31/// async fn main() {
32///    let mut buf = Cursor::new([5u8]);
33///    let version = Version::read(&mut buf).await.unwrap();
34///    assert_eq!(version, Version::V5);
35/// }
36/// ```
37#[derive(Debug, PartialEq, Eq)]
38pub enum Version {
39    /// SOCKS Version 5
40    V5,
41}
42impl Version {
43    /// Read `Version` from AsyncRead.
44    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Version> {
45        let version = &mut [0u8];
46        reader.read_exact(version).await?;
47        match version[0] {
48            5 => Ok(Version::V5),
49            other => Err(Error::InvalidVersion(other)),
50        }
51    }
52    /// Write `Version` to AsyncWrite.
53    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
54        let v = match self {
55            Version::V5 => 5u8,
56        };
57        writer.write_all(&[v]).await?;
58        Ok(())
59    }
60}
61
/// `AuthMethod` is defined in RFC 1928.
///
/// Conversion from `u8` is total: unknown bytes become [`AuthMethod::Other`]. Conversion back
/// to `u8` (via `u8::from` or `.into()`) returns the wire byte.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub enum AuthMethod {
    /// NO AUTHENTICATION REQUIRED. (`0x00`)
    Noauth,
    /// GSSAPI. (`0x01`)
    Gssapi,
    /// USERNAME/PASSWORD. (`0x02`)
    UsernamePassword,
    /// NO ACCEPTABLE METHODS. (`0xFF`)
    NoAcceptableMethod,
    /// Other values
    Other(u8),
}

impl From<u8> for AuthMethod {
    fn from(n: u8) -> Self {
        match n {
            0x00 => AuthMethod::Noauth,
            0x01 => AuthMethod::Gssapi,
            0x02 => AuthMethod::UsernamePassword,
            0xff => AuthMethod::NoAcceptableMethod,
            other => AuthMethod::Other(other),
        }
    }
}

// Implementing `From` (rather than `Into`) also provides `Into<u8> for AuthMethod`
// through the standard blanket impl, so existing `.into()` callers keep working.
impl From<AuthMethod> for u8 {
    fn from(method: AuthMethod) -> u8 {
        match method {
            AuthMethod::Noauth => 0x00,
            AuthMethod::Gssapi => 0x01,
            AuthMethod::UsernamePassword => 0x02,
            AuthMethod::NoAcceptableMethod => 0xff,
            AuthMethod::Other(other) => other,
        }
    }
}
100
101/// `AuthRequest` message:
102///
103/// <pre><code>
104/// +----------+----------+
105/// | NMETHODS | METHODS  |
106/// +----------+----------+
107/// |    1     | 1 to 255 |
108/// +----------+----------+
109/// </code></pre>
110#[derive(Debug, Eq, PartialEq, Clone)]
111pub struct AuthRequest(pub Vec<AuthMethod>);
112
113impl AuthRequest {
114    /// Create an `AuthRequest`.
115    pub fn new(methods: impl Into<Vec<AuthMethod>>) -> AuthRequest {
116        AuthRequest(methods.into())
117    }
118    /// Read `AuthRequest` from AsyncRead.
119    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthRequest> {
120        let count = &mut [0u8];
121        reader.read_exact(count).await?;
122        let mut methods = vec![0u8; count[0] as usize];
123        reader.read_exact(&mut methods).await?;
124
125        Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
126    }
127    /// Write `AuthRequest` to AsyncWrite.
128    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
129        let count = self.0.len();
130        if count > 255 {
131            return Err(Error::TooManyMethods);
132        }
133
134        writer.write_all(&[count as u8]).await?;
135        writer
136            .write_all(
137                &self
138                    .0
139                    .iter()
140                    .map(|i| Into::<u8>::into(*i))
141                    .collect::<Vec<_>>(),
142            )
143            .await?;
144
145        Ok(())
146    }
147    /// Select one `AuthMethod` from give slice.
148    pub fn select_from(&self, auth: &[AuthMethod]) -> AuthMethod {
149        self.0
150            .iter()
151            .enumerate()
152            .find(|(_, m)| auth.contains(*m))
153            .map(|(v, _)| AuthMethod::from(v as u8))
154            .unwrap_or(AuthMethod::NoAcceptableMethod)
155    }
156}
157
158/// `AuthResponse` message:
159///
160/// <pre><code>
161/// +--------+
162/// | METHOD |
163/// +--------+
164/// |   1    |
165/// +--------+
166/// </code></pre>
167#[derive(Debug, Eq, PartialEq, Clone)]
168pub struct AuthResponse(AuthMethod);
169
170impl AuthResponse {
171    /// Create an `AuthMethod`.
172    pub fn new(method: AuthMethod) -> AuthResponse {
173        AuthResponse(method)
174    }
175    /// Read `AuthResponse` from AsyncRead.
176    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthResponse> {
177        let method = &mut [0u8];
178        reader.read_exact(method).await?;
179        Ok(AuthResponse(method[0].into()))
180    }
181    /// Write `AuthResponse` to AsyncWrite.
182    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
183        writer.write_all(&[self.0.into()]).await?;
184        Ok(())
185    }
186    /// Get method.
187    pub fn method(&self) -> AuthMethod {
188        self.0
189    }
190}
191
/// `Command` type (the `CMD` field of a [`CommandRequest`]).
///
/// It has 3 commands: `Connect`, `Bind` and `UdpAssociate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Command {
    /// Connect (`0x01`)
    Connect,
    /// Bind (`0x02`)
    Bind,
    /// Udp Associate (`0x03`)
    UdpAssociate,
}
204
205/// `CommandRequest` message:
206///
207/// <pre><code>
208/// +-----+-------+------+----------+----------+
209/// | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
210/// +-----+-------+------+----------+----------+
211/// |  1  | X'00' |  1   | Variable |    2     |
212/// +-----+-------+------+----------+----------+
213/// </code></pre>
214#[derive(Debug)]
215pub struct CommandRequest {
216    /// command (CMD).
217    pub command: Command,
218    /// Address (ATYP, DST.ADDR, DST.PORT).
219    pub address: Address,
220}
221
222impl CommandRequest {
223    /// Create a `CommandRequest` with `Connect` to `address`.
224    pub fn connect(address: Address) -> CommandRequest {
225        CommandRequest {
226            command: Command::Connect,
227            address,
228        }
229    }
230    /// Create a `CommandRequest` with `UdpAssociate` to `address`.
231    pub fn udp_associate(address: Address) -> CommandRequest {
232        CommandRequest {
233            command: Command::UdpAssociate,
234            address,
235        }
236    }
237    /// Read `CommandRequest` from `AsyncRead`.
238    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandRequest> {
239        let buf = &mut [0u8; 3];
240        reader.read_exact(buf).await?;
241        if buf[0] != 5 {
242            return Err(Error::InvalidVersion(buf[0]));
243        }
244        if buf[2] != 0 {
245            return Err(Error::InvalidHandshake);
246        }
247        let cmd = match buf[1] {
248            1 => Command::Connect,
249            2 => Command::Bind,
250            3 => Command::UdpAssociate,
251            _ => return Err(Error::InvalidCommand(buf[1])),
252        };
253
254        let address = Address::read(reader).await?;
255
256        Ok(CommandRequest {
257            command: cmd,
258            address,
259        })
260    }
261    /// Write `CommandRequest` to `AsyncWrite`.
262    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
263        let cmd = match self.command {
264            Command::Connect => 1u8,
265            Command::Bind => 2,
266            Command::UdpAssociate => 3,
267        };
268        writer.write_all(&[0x05, cmd, 0x00]).await?;
269        self.address.write(writer).await?;
270        Ok(())
271    }
272}
273
/// Reply to `CommandRequest` (the `REP` field of a [`CommandResponse`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum CommandReply {
    /// succeeded (0x00)
    Succeeded,
    /// general SOCKS server failure (0x01)
    GeneralSocksServerFailure,
    /// connection not allowed by ruleset (0x02)
    ConnectionNotAllowedByRuleset,
    /// Network unreachable (0x03)
    NetworkUnreachable,
    /// Host unreachable (0x04)
    HostUnreachable,
    /// Connection refused (0x05)
    ConnectionRefused,
    /// TTL expired (0x06)
    TtlExpired,
    /// Command not supported (0x07)
    CommandNotSupported,
    /// Address type not supported (0x08)
    AddressTypeNotSupported,
}
296
297impl CommandReply {
298    /// From `u8` to `CommandReply`.
299    pub fn from_u8(n: u8) -> Result<CommandReply> {
300        Ok(match n {
301            0 => CommandReply::Succeeded,
302            1 => CommandReply::GeneralSocksServerFailure,
303            2 => CommandReply::ConnectionNotAllowedByRuleset,
304            3 => CommandReply::NetworkUnreachable,
305            4 => CommandReply::HostUnreachable,
306            5 => CommandReply::ConnectionRefused,
307            6 => CommandReply::TtlExpired,
308            7 => CommandReply::CommandNotSupported,
309            8 => CommandReply::AddressTypeNotSupported,
310            _ => return Err(Error::InvalidCommandReply(n)),
311        })
312    }
313    /// From `CommandReply` to `u8`.
314    pub fn to_u8(&self) -> u8 {
315        match self {
316            CommandReply::Succeeded => 0,
317            CommandReply::GeneralSocksServerFailure => 1,
318            CommandReply::ConnectionNotAllowedByRuleset => 2,
319            CommandReply::NetworkUnreachable => 3,
320            CommandReply::HostUnreachable => 4,
321            CommandReply::ConnectionRefused => 5,
322            CommandReply::TtlExpired => 6,
323            CommandReply::CommandNotSupported => 7,
324            CommandReply::AddressTypeNotSupported => 8,
325        }
326    }
327}
328
329/// `CommandResponse` message:
330///
331/// <pre><code>
332/// +-----+-------+------+----------+----------+
333/// | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
334/// +-----+-------+------+----------+----------+
335/// |  1  | X'00' |  1   | Variable |    2     |
336/// +-----+-------+------+----------+----------+
337/// </code></pre>
338#[derive(Debug)]
339pub struct CommandResponse {
340    /// Reply (REP).
341    pub reply: CommandReply,
342    /// Address (ATYP, BND.ADDR, BND.PORT).
343    pub address: Address,
344}
345
346impl CommandResponse {
347    /// Create a success `CommandResponse` with bind `address`.
348    pub fn success(address: Address) -> CommandResponse {
349        CommandResponse {
350            reply: CommandReply::Succeeded,
351            address,
352        }
353    }
354    /// Create a error `CommandResponse` with `reply`.
355    pub fn reply_error(reply: CommandReply) -> CommandResponse {
356        CommandResponse {
357            reply,
358            address: Default::default(),
359        }
360    }
361    /// Create a error `CommandResponse` with any `io::error`.
362    pub fn error(e: impl TryInto<io::Error>) -> CommandResponse {
363        match e.try_into() {
364            Ok(v) => {
365                use io::ErrorKind;
366                let reply = match v.kind() {
367                    ErrorKind::ConnectionRefused => CommandReply::ConnectionRefused,
368                    _ => CommandReply::GeneralSocksServerFailure,
369                };
370                CommandResponse {
371                    reply,
372                    address: Default::default(),
373                }
374            }
375            Err(_) => CommandResponse {
376                reply: CommandReply::GeneralSocksServerFailure,
377                address: Default::default(),
378            },
379        }
380    }
381    /// Read `CommandResponse` from `AsyncRead`.
382    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandResponse> {
383        let buf = &mut [0u8; 3];
384        reader.read_exact(buf).await?;
385        if buf[0] != 5 {
386            return Err(Error::InvalidVersion(buf[0]));
387        }
388        if buf[2] != 0 {
389            return Err(Error::InvalidHandshake);
390        }
391        let reply = CommandReply::from_u8(buf[1])?;
392
393        let address = Address::read(reader).await?;
394
395        if reply != CommandReply::Succeeded {
396            return Err(Error::CommandReply(reply));
397        }
398
399        Ok(CommandResponse { reply, address })
400    }
401    /// Write `CommandResponse` to `AsyncWrite`.
402    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
403        writer.write_all(&[0x05, self.reply.to_u8(), 0x00]).await?;
404        self.address.write(writer).await?;
405        Ok(())
406    }
407}
408
/// Address type in socks5.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Address {
    /// An IP socket address (serialized with ATYP `0x01` for IPv4, `0x04` for IPv6).
    SocketAddr(SocketAddr),
    /// A domain name and port (serialized with ATYP `0x03`; the name is written as-is,
    /// prefixed by its one-byte length).
    Domain(String, u16),
}

/// Formats as `ip:port` (IPv6 in brackets, as `SocketAddr` does) or `domain:port`.
impl fmt::Display for Address {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Address::SocketAddr(s) => fmt::Display::fmt(s, f),
            Address::Domain(domain, port) => write!(f, "{}:{}", domain, port),
        }
    }
}

/// The default address is `0.0.0.0:0`, used by error responses that carry no bound address.
impl Default for Address {
    fn default() -> Self {
        Address::SocketAddr(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0))
    }
}

impl From<SocketAddr> for Address {
    fn from(addr: SocketAddr) -> Self {
        Address::SocketAddr(addr)
    }
}
438
/// Remove one pair of surrounding `[` `]` (as in `[::1]`), leaving any other input untouched.
fn strip_brackets(host: &str) -> &str {
    match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']').unwrap_or(host),
        None => host,
    }
}
444
445fn host_to_address(host: &str, port: u16) -> Address {
446    match strip_brackets(host).parse::<IpAddr>() {
447        Ok(ip) => {
448            let addr = SocketAddr::new(ip, port);
449            addr.into()
450        }
451        Err(_) => Address::Domain(host.to_string(), port),
452    }
453}
/// Builds the error returned when a `host:port` string cannot be parsed into an `Address`.
fn no_addr() -> io::Error {
    io::Error::from(io::ErrorKind::AddrNotAvailable)
}
457
458impl FromStr for Address {
459    type Err = Error;
460
461    fn from_str(s: &str) -> Result<Self, Self::Err> {
462        let mut parts = s.rsplitn(2, ':');
463        let port: u16 = parts
464            .next()
465            .ok_or_else(no_addr)?
466            .parse()
467            .map_err(|_| no_addr())?;
468        let host = parts.next().ok_or_else(no_addr)?;
469        Ok(host_to_address(host, port))
470    }
471}
472
473impl Address {
474    /// Convert `Address` to `SocketAddr`. If `Address` is a domain, return `std::io::ErrorKind::InvalidInput`
475    pub fn to_socket_addr(self) -> Result<SocketAddr> {
476        match self {
477            Address::SocketAddr(s) => Ok(s),
478            _ => Err(Error::Io(io::ErrorKind::InvalidInput.into())),
479        }
480    }
481    async fn read_port<R>(mut reader: R) -> Result<u16>
482    where
483        R: AsyncRead + Unpin,
484    {
485        let mut buf = [0u8; 2];
486        reader.read_exact(&mut buf).await?;
487        let port = u16::from_be_bytes(buf);
488        Ok(port)
489    }
490    async fn write_port<W>(mut writer: W, port: u16) -> Result<()>
491    where
492        W: AsyncWrite + Unpin,
493    {
494        writer.write_all(&port.to_be_bytes()).await?;
495        Ok(())
496    }
497    /// Length of `Address` in bytes after serialized.
498    pub fn serialized_len(&self) -> Result<usize> {
499        Ok(match self {
500            Address::SocketAddr(SocketAddr::V4(_)) => {
501                // 1 byte for ATYP, 4 bytes for IPV4 address, 2 bytes for port
502                1 + 4 + 2
503            }
504            Address::SocketAddr(SocketAddr::V6(_)) => {
505                // 1 byte for ATYP, 16 bytes for IPV6 address, 2 bytes for port
506                1 + 16 + 2
507            }
508            Address::Domain(domain, _) => {
509                if domain.len() >= 256 {
510                    return Err(Error::DomainTooLong(domain.len()));
511                }
512                // 1 byte for ATYP, 1 byte for domain length, domain, 2 bytes for port
513                1 + 1 + domain.len() + 2
514            }
515        })
516    }
517    /// Write `Address` to `AsyncWrite`.
518    pub async fn write<W>(&self, mut writer: W) -> Result<()>
519    where
520        W: AsyncWrite + Unpin,
521    {
522        match self {
523            Address::SocketAddr(SocketAddr::V4(addr)) => {
524                writer.write_all(&[0x01]).await?;
525                writer.write_all(&addr.ip().octets()).await?;
526                Self::write_port(writer, addr.port()).await?;
527            }
528            Address::SocketAddr(SocketAddr::V6(addr)) => {
529                writer.write_all(&[0x04]).await?;
530                writer.write_all(&addr.ip().octets()).await?;
531                Self::write_port(writer, addr.port()).await?;
532            }
533            Address::Domain(domain, port) => {
534                if domain.len() >= 256 {
535                    return Err(Error::DomainTooLong(domain.len()));
536                }
537                let header = [0x03, domain.len() as u8];
538                writer.write_all(&header).await?;
539                writer.write_all(domain.as_bytes()).await?;
540                Self::write_port(writer, *port).await?;
541            }
542        };
543        Ok(())
544    }
545    /// Read `Address` from `AsyncRead`.
546    pub async fn read<R>(mut reader: R) -> Result<Self>
547    where
548        R: AsyncRead + Unpin,
549    {
550        let mut atyp = [0u8; 1];
551        reader.read_exact(&mut atyp).await?;
552
553        Ok(match atyp[0] {
554            1 => {
555                let mut ip = [0u8; 4];
556                reader.read_exact(&mut ip).await?;
557                Address::SocketAddr(SocketAddr::new(
558                    ip.into(),
559                    Self::read_port(&mut reader).await?,
560                ))
561            }
562            3 => {
563                let mut len = [0u8; 1];
564                reader.read_exact(&mut len).await?;
565                let len = len[0] as usize;
566                let mut domain = vec![0u8; len];
567                reader.read_exact(&mut domain).await?;
568
569                let domain =
570                    String::from_utf8(domain).map_err(|e| Error::InvalidDomain(e.into_bytes()))?;
571
572                Address::Domain(domain, Self::read_port(&mut reader).await?)
573            }
574            4 => {
575                let mut ip = [0u8; 16];
576                reader.read_exact(&mut ip).await?;
577                Address::SocketAddr(SocketAddr::new(
578                    ip.into(),
579                    Self::read_port(&mut reader).await?,
580                ))
581            }
582            _ => return Err(Error::InvalidAddressType(atyp[0])),
583        })
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_address_display() {
        let cases = [
            (
                Address::SocketAddr("1.2.3.4:56789".parse().unwrap()),
                "1.2.3.4:56789",
            ),
            (
                Address::Domain("example.com".to_string(), 80),
                "example.com:80",
            ),
        ];
        for (addr, expected) in cases.iter() {
            assert_eq!(addr.to_string(), *expected);
        }
    }

    #[test]
    fn test_address_from_str() {
        for text in &["1.2.3.4:56789", "example.com:80"] {
            let addr: Address = text.parse().unwrap();
            assert_eq!(addr.to_string(), *text);
        }

        assert!("example.com".parse::<Address>().is_err());
    }

    #[test]
    fn test_address_serialized_len() {
        let expectations = [
            ("1.2.3.4:56789", 7),
            ("[::1]:56789", 19),
            ("example.com:80", 15),
        ];
        for &(text, len) in expectations.iter() {
            let addr: Address = text.parse().unwrap();
            assert_eq!(addr.serialized_len().unwrap(), len);
        }
    }
}