Skip to main content

socks5_impl/protocol/
address.rs

1#[cfg(feature = "tokio")]
2use crate::protocol::AsyncStreamOperation;
3use crate::protocol::StreamOperation;
4use bytes::BufMut;
5use std::{
6    io::Cursor,
7    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
8};
9#[cfg(feature = "tokio")]
10use tokio::io::{AsyncRead, AsyncReadExt};
11
/// The `ATYP` field of a SOCKS5 address (RFC 1928, section 5).
///
/// Each discriminant is the byte value used on the wire. The type is a plain
/// fieldless enum, so it is `Copy` and can be passed around by value.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)]
#[repr(u8)]
pub enum AddressType {
    /// `0x01`: a 4-byte IPv4 address. This is the [`Default`] variant.
    #[default]
    IPv4 = 0x01,
    /// `0x03`: a length-prefixed domain name.
    Domain = 0x03,
    /// `0x04`: a 16-byte IPv6 address.
    IPv6 = 0x04,
}
20
21impl TryFrom<u8> for AddressType {
22    type Error = std::io::Error;
23    fn try_from(code: u8) -> core::result::Result<Self, Self::Error> {
24        let err = format!("Unsupported address type code {code:#x}");
25        match code {
26            0x01 => Ok(AddressType::IPv4),
27            0x03 => Ok(AddressType::Domain),
28            0x04 => Ok(AddressType::IPv6),
29            _ => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err)),
30        }
31    }
32}
33
34impl From<AddressType> for u8 {
35    fn from(addr_type: AddressType) -> Self {
36        match addr_type {
37            AddressType::IPv4 => 0x01,
38            AddressType::Domain => 0x03,
39            AddressType::IPv6 => 0x04,
40        }
41    }
42}
43
/// A SOCKS5 address as carried in requests and replies.
///
/// On the wire it is an address-type byte, a variable-length address and a
/// big-endian port:
///
/// ```plain
/// +------+----------+----------+
/// | ATYP | DST.ADDR | DST.PORT |
/// +------+----------+----------+
/// |  1   | Variable |    2     |
/// +------+----------+----------+
/// ```
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum Address {
    /// A literal IPv4 or IPv6 socket address (IP plus port).
    SocketAddress(SocketAddr),
    /// A domain name together with a port.
    DomainAddress(Box<str>, u16),
}
61
62impl Address {
63    /// Returns an unspecified IPv4 address (0.0.0.0:0).
64    pub fn unspecified() -> Self {
65        Address::SocketAddress(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
66    }
67
68    /// Returns the type of the address (IPv4, IPv6, or Domain).
69    pub fn get_type(&self) -> AddressType {
70        match self {
71            Self::SocketAddress(SocketAddr::V4(_)) => AddressType::IPv4,
72            Self::SocketAddress(SocketAddr::V6(_)) => AddressType::IPv6,
73            Self::DomainAddress(_, _) => AddressType::Domain,
74        }
75    }
76
77    /// Returns the port number.
78    pub fn port(&self) -> u16 {
79        match self {
80            Self::SocketAddress(addr) => addr.port(),
81            Self::DomainAddress(_, port) => *port,
82        }
83    }
84
85    /// Returns the domain name or IP address as a string.
86    pub fn domain(&self) -> String {
87        match self {
88            Self::SocketAddress(addr) => addr.ip().to_string(),
89            Self::DomainAddress(addr, _) => addr.to_string(),
90        }
91    }
92
93    /// Returns `true` if it is an IPv4 address.
94    pub fn is_ipv4(&self) -> bool {
95        matches!(self, Self::SocketAddress(SocketAddr::V4(_)))
96    }
97
98    /// Returns `true` if it is an IPv6 address.
99    pub fn is_ipv6(&self) -> bool {
100        matches!(self, Self::SocketAddress(SocketAddr::V6(_)))
101    }
102
103    /// Returns `true` if it is a domain address.
104    pub fn is_domain(&self) -> bool {
105        matches!(self, Self::DomainAddress(_, _))
106    }
107
108    /// Returns the maximum possible serialized length of a SOCKS5 address.
109    pub const fn max_serialized_len() -> usize {
110        1 + 1 + u8::MAX as usize + 2
111    }
112}
113
114impl StreamOperation for Address {
115    fn retrieve_from_stream<R: std::io::Read>(stream: &mut R) -> std::io::Result<Self> {
116        let mut atyp_buf = [0; 1];
117        stream.read_exact(&mut atyp_buf)?;
118        match AddressType::try_from(atyp_buf[0])? {
119            AddressType::IPv4 => {
120                let mut buf = [0; 6];
121                stream.read_exact(&mut buf)?;
122                let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
123                let port = u16::from_be_bytes([buf[4], buf[5]]);
124                Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
125            }
126            AddressType::Domain => {
127                let mut len_buf = [0; 1];
128                stream.read_exact(&mut len_buf)?;
129                let len = len_buf[0] as usize;
130                let mut domain_buf = vec![0; len];
131                stream.read_exact(&mut domain_buf)?;
132
133                let mut port_buf = [0; 2];
134                stream.read_exact(&mut port_buf)?;
135                let port = u16::from_be_bytes(port_buf);
136
137                let addr = String::from_utf8(domain_buf)
138                    .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Invalid address encoding: {err}")))?;
139                Ok(Self::DomainAddress(addr.into_boxed_str(), port))
140            }
141            AddressType::IPv6 => {
142                let mut buf = [0; 18];
143                stream.read_exact(&mut buf)?;
144                let port = u16::from_be_bytes([buf[16], buf[17]]);
145                let mut addr_bytes = [0; 16];
146                addr_bytes.copy_from_slice(&buf[..16]);
147                Ok(Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(addr_bytes), port))))
148            }
149        }
150    }
151
152    fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
153        match self {
154            Self::SocketAddress(SocketAddr::V4(addr)) => {
155                buf.put_u8(AddressType::IPv4.into());
156                buf.put_slice(&addr.ip().octets());
157                buf.put_u16(addr.port());
158            }
159            Self::SocketAddress(SocketAddr::V6(addr)) => {
160                buf.put_u8(AddressType::IPv6.into());
161                buf.put_slice(&addr.ip().octets());
162                buf.put_u16(addr.port());
163            }
164            Self::DomainAddress(addr, port) => {
165                let addr = addr.as_bytes();
166                buf.put_u8(AddressType::Domain.into());
167                buf.put_u8(u8::try_from(addr.len()).expect("domain address too long for SOCKS5"));
168                buf.put_slice(addr);
169                buf.put_u16(*port);
170            }
171        }
172    }
173
174    fn len(&self) -> usize {
175        match self {
176            Address::SocketAddress(SocketAddr::V4(_)) => 1 + 4 + 2,
177            Address::SocketAddress(SocketAddr::V6(_)) => 1 + 16 + 2,
178            Address::DomainAddress(addr, _) => 1 + 1 + addr.len() + 2,
179        }
180    }
181}
182
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncStreamOperation for Address {
    /// Asynchronously decodes one SOCKS5 address (`ATYP`, address bytes, big-endian port).
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `stream`, rejects unknown `ATYP` codes, and returns
    /// `InvalidData` when a domain name is not valid UTF-8.
    async fn retrieve_from_async_stream<R>(stream: &mut R) -> std::io::Result<Self>
    where
        R: AsyncRead + Unpin + Send + ?Sized,
    {
        let address_type = AddressType::try_from(stream.read_u8().await?)?;
        let address = match address_type {
            AddressType::IPv4 => {
                let mut octets = [0u8; 4];
                stream.read_exact(&mut octets).await?;
                let port = stream.read_u16().await?;
                Self::SocketAddress(SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port))
            }
            AddressType::Domain => {
                // One length byte, that many name bytes, then the port.
                let name_len = usize::from(stream.read_u8().await?);
                let mut name_bytes = vec![0u8; name_len];
                stream.read_exact(&mut name_bytes).await?;
                let port = stream.read_u16().await?;

                let name = String::from_utf8(name_bytes).map_err(|err| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Invalid address encoding: {err}"))
                })?;
                Self::DomainAddress(name.into_boxed_str(), port)
            }
            AddressType::IPv6 => {
                let mut octets = [0u8; 16];
                stream.read_exact(&mut octets).await?;
                let port = stream.read_u16().await?;
                Self::SocketAddress(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(octets)), port))
            }
        };
        Ok(address)
    }
}
218
219impl ToSocketAddrs for Address {
220    type Iter = std::vec::IntoIter<SocketAddr>;
221
222    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
223        match self {
224            Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
225            Address::DomainAddress(addr, port) => Ok((&**addr, *port).to_socket_addrs()?),
226        }
227    }
228}
229
230impl std::fmt::Display for Address {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        match self {
233            Address::DomainAddress(hostname, port) => write!(f, "{hostname}:{port}"),
234            Address::SocketAddress(socket_addr) => write!(f, "{socket_addr}"),
235        }
236    }
237}
238
239impl TryFrom<Address> for SocketAddr {
240    type Error = std::io::Error;
241
242    fn try_from(address: Address) -> std::result::Result<Self, Self::Error> {
243        match address {
244            Address::SocketAddress(addr) => Ok(addr),
245            Address::DomainAddress(addr, port) => {
246                if let Ok(addr) = addr.parse::<Ipv4Addr>() {
247                    Ok(SocketAddr::from((addr, port)))
248                } else if let Ok(addr) = addr.parse::<Ipv6Addr>() {
249                    Ok(SocketAddr::from((addr, port)))
250                } else if let Ok(addr) = addr.parse::<SocketAddr>() {
251                    Ok(addr)
252                } else {
253                    let err = format!("domain address {addr} is not supported");
254                    Err(Self::Error::new(std::io::ErrorKind::Unsupported, err))
255                }
256            }
257        }
258    }
259}
260
261impl TryFrom<&Address> for SocketAddr {
262    type Error = std::io::Error;
263
264    fn try_from(address: &Address) -> std::result::Result<Self, Self::Error> {
265        TryFrom::<Address>::try_from(address.clone())
266    }
267}
268
269impl From<Address> for Vec<u8> {
270    fn from(addr: Address) -> Self {
271        let mut buf = Vec::with_capacity(addr.len());
272        addr.write_to_buf(&mut buf);
273        buf
274    }
275}
276
277impl TryFrom<Vec<u8>> for Address {
278    type Error = std::io::Error;
279
280    fn try_from(data: Vec<u8>) -> std::result::Result<Self, Self::Error> {
281        let mut rdr = Cursor::new(data);
282        Self::retrieve_from_stream(&mut rdr)
283    }
284}
285
286impl TryFrom<&[u8]> for Address {
287    type Error = std::io::Error;
288
289    fn try_from(data: &[u8]) -> std::result::Result<Self, Self::Error> {
290        let mut rdr = Cursor::new(data);
291        Self::retrieve_from_stream(&mut rdr)
292    }
293}
294
295impl From<SocketAddr> for Address {
296    fn from(addr: SocketAddr) -> Self {
297        Address::SocketAddress(addr)
298    }
299}
300
301impl From<&SocketAddr> for Address {
302    fn from(addr: &SocketAddr) -> Self {
303        Address::SocketAddress(*addr)
304    }
305}
306
307impl From<(Ipv4Addr, u16)> for Address {
308    fn from((addr, port): (Ipv4Addr, u16)) -> Self {
309        Address::SocketAddress(SocketAddr::from((addr, port)))
310    }
311}
312
313impl From<(Ipv6Addr, u16)> for Address {
314    fn from((addr, port): (Ipv6Addr, u16)) -> Self {
315        Address::SocketAddress(SocketAddr::from((addr, port)))
316    }
317}
318
319impl From<(IpAddr, u16)> for Address {
320    fn from((addr, port): (IpAddr, u16)) -> Self {
321        Address::SocketAddress(SocketAddr::from((addr, port)))
322    }
323}
324
325impl From<(String, u16)> for Address {
326    fn from((addr, port): (String, u16)) -> Self {
327        Address::DomainAddress(addr.into_boxed_str(), port)
328    }
329}
330
331impl From<(&str, u16)> for Address {
332    fn from((addr, port): (&str, u16)) -> Self {
333        Address::DomainAddress(addr.into(), port)
334    }
335}
336
337impl From<&Address> for Address {
338    fn from(addr: &Address) -> Self {
339        addr.clone()
340    }
341}
342
343impl TryFrom<&str> for Address {
344    type Error = crate::Error;
345
346    fn try_from(addr: &str) -> std::result::Result<Self, Self::Error> {
347        if let Ok(addr) = addr.parse::<SocketAddr>() {
348            Ok(Address::SocketAddress(addr))
349        } else {
350            let (addr, port) = if let Some(pos) = addr.rfind(':') {
351                (&addr[..pos], &addr[pos + 1..])
352            } else {
353                (addr, "0")
354            };
355            let port = port.parse::<u16>()?;
356            Ok(Address::DomainAddress(addr.into(), port))
357        }
358    }
359}
360
361impl std::str::FromStr for Address {
362    type Err = crate::Error;
363
364    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
365        Self::try_from(s)
366    }
367}
368
#[test]
fn test_address() {
    // (address, expected wire encoding) pairs covering every ATYP.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com", 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut buf = Vec::new();
        addr.write_to_buf(&mut buf);
        assert_eq!(buf, expected);
        // `len()` must agree with what `write_to_buf` actually emits.
        assert_eq!(addr.len(), buf.len());
        // The `Vec<u8>` conversion goes through the same encoder.
        assert_eq!(Vec::<u8>::from(addr.clone()), expected);

        let decoded = Address::retrieve_from_stream(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(addr, decoded);
        // The byte-slice conversion goes through the same decoder.
        assert_eq!(Address::try_from(buf.as_slice()).unwrap(), addr);
    }
}
392
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_address_async() {
    // (address, expected wire encoding) pairs covering every ATYP.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com", 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut encoded = Vec::new();
        addr.write_to_async_stream(&mut encoded).await.unwrap();
        assert_eq!(encoded, expected);

        let decoded = Address::retrieve_from_async_stream(&mut Cursor::new(&encoded))
            .await
            .unwrap();
        assert_eq!(addr, decoded);
    }
}