ss_light/
handshake.rs

1use std::{
2    fmt::{self, Formatter},
3    io,
4    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
5};
6
7use bytes::BufMut;
8use tokio::{
9    io::{AsyncRead, AsyncReadExt},
10    net::TcpStream,
11};
12
13use crate::consts::*;
14
/// A destination address in the SOCKS5-style encoding decoded by [`Address::read_from`].
///
/// Either an already-resolved socket address (IPv4 or IPv6) or a domain name paired with a
/// port. The domain name is stored as-is; resolving it is left to the consumer (for example
/// [`Address::connect`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Address {
    /// A concrete IPv4 or IPv6 socket address.
    SocketAddress(SocketAddr),
    /// An unresolved domain name and the port to use with it.
    DomainNameAddress(String, u16),
}
20
21impl Address {
22    pub async fn read_from<R>(stream: &mut R) -> Result<Address, Error>
23    where
24        R: AsyncRead + Unpin,
25    {
26        let mut addr_type_buf = [0u8; 1];
27        stream.read_exact(&mut addr_type_buf).await?;
28
29        match addr_type_buf[0] {
30            SOCKS5_ADDR_TYPE_IPV4 => {
31                let mut buf = [0u8; 6];
32                stream.read_exact(&mut buf).await?;
33                let ip = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
34                let port = u16::from_be_bytes([buf[4], buf[5]]);
35                Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(
36                    ip, port,
37                ))))
38            }
39            SOCKS5_ADDR_TYPE_IPV6 => {
40                let mut buf = [0u8; 18];
41                stream.read_exact(&mut buf).await?;
42                let ip = Ipv6Addr::from([
43                    buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7], buf[8], buf[9],
44                    buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
45                ]);
46                let port = u16::from_be_bytes([buf[16], buf[17]]);
47                Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(
48                    ip, port, 0, 0,
49                ))))
50            }
51            SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
52                let mut length_buf = [0u8; 1];
53                stream.read_exact(&mut length_buf).await?;
54                let length = length_buf[0] as usize;
55
56                let buf_length = length + 2; // domain + port
57                let mut buf = vec![0u8; buf_length];
58                stream.read_exact(&mut buf).await?;
59
60                let port = u16::from_be_bytes([buf[length], buf[length + 1]]);
61                buf.truncate(length);
62                let addr = String::from_utf8(buf)?;
63
64                Ok(Address::DomainNameAddress(addr, port))
65            }
66            _ => Err(Error::UnknownAddressType(addr_type_buf[0])),
67        }
68    }
69
70    pub fn write_socket_addr_to_buf<B: BufMut>(addr: &SocketAddr, buf: &mut B) {
71        match *addr {
72            SocketAddr::V4(ref addr) => {
73                buf.put_u8(SOCKS5_ADDR_TYPE_IPV4);
74                buf.put_slice(&addr.ip().octets());
75                buf.put_u16(addr.port());
76            }
77            SocketAddr::V6(ref addr) => {
78                buf.put_u8(SOCKS5_ADDR_TYPE_IPV6);
79                for seg in &addr.ip().segments() {
80                    buf.put_u16(*seg);
81                }
82                buf.put_u16(addr.port());
83            }
84        }
85    }
86
87    pub fn port(&self) -> u16 {
88        match *self {
89            Address::SocketAddress(addr) => addr.port(),
90            Address::DomainNameAddress(.., port) => port,
91        }
92    }
93
94    pub fn host(&self) -> String {
95        match *self {
96            Address::SocketAddress(ref addr) => addr.ip().to_string(),
97            Address::DomainNameAddress(ref domain, ..) => domain.to_owned(),
98        }
99    }
100
101    pub async fn connect(&self) -> io::Result<TcpStream> {
102        let stream = match *self {
103            Address::SocketAddress(ref sa) => TcpStream::connect(sa).await?,
104            Address::DomainNameAddress(ref dname, port) => {
105                TcpStream::connect((dname.as_str(), port)).await?
106            }
107        };
108        Ok(stream)
109    }
110}
111
112impl fmt::Display for Address {
113    #[inline]
114    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
115        match *self {
116            Address::SocketAddress(ref addr) => write!(f, "{}", addr),
117            Address::DomainNameAddress(ref addr, ref port) => write!(f, "{}:{}", addr, port),
118        }
119    }
120}