socks5_impl/protocol/address.rs

1#[cfg(feature = "tokio")]
2use crate::protocol::AsyncStreamOperation;
3use crate::protocol::StreamOperation;
4use bytes::BufMut;
5use std::{
6    io::Cursor,
7    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
8};
9#[cfg(feature = "tokio")]
10use tokio::io::{AsyncRead, AsyncReadExt};
11
/// The SOCKS5 `ATYP` field, which says how the following `DST.ADDR` is encoded.
///
/// The explicit discriminants are the on-the-wire codes.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)]
#[repr(u8)]
pub enum AddressType {
    /// Four-byte IPv4 address (code `0x01`). This is the `Default` variant.
    #[default]
    IPv4 = 0x01,
    /// Length-prefixed domain name (code `0x03`).
    Domain = 0x03,
    /// Sixteen-byte IPv6 address (code `0x04`).
    IPv6 = 0x04,
}

impl TryFrom<u8> for AddressType {
    type Error = std::io::Error;

    /// Decodes an `ATYP` byte.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` I/O error for any code other than `0x01`, `0x03`
    /// or `0x04`.
    fn try_from(code: u8) -> core::result::Result<Self, Self::Error> {
        match code {
            0x01 => Ok(AddressType::IPv4),
            0x03 => Ok(AddressType::Domain),
            0x04 => Ok(AddressType::IPv6),
            // The message is formatted only here so successful decodes never allocate.
            _ => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Unsupported address type code {code:#x}"),
            )),
        }
    }
}

impl From<AddressType> for u8 {
    /// Encodes the address type as its `ATYP` byte.
    fn from(addr_type: AddressType) -> Self {
        match addr_type {
            AddressType::IPv4 => 0x01,
            AddressType::Domain => 0x03,
            AddressType::IPv6 => 0x04,
        }
    }
}
43
/// A destination address in SOCKS5 wire layout.
///
/// On the wire it is a one-byte type tag, a variable-length address and a
/// two-byte port:
///
/// ```plain
/// +------+----------+----------+
/// | ATYP | DST.ADDR | DST.PORT |
/// +------+----------+----------+
/// |  1   | Variable |    2     |
/// +------+----------+----------+
/// ```
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Address {
    /// A concrete IP endpoint (either IPv4 or IPv6) with its port.
    SocketAddress(SocketAddr),
    /// An unresolved host name together with the port to connect to.
    DomainAddress(Box<str>, u16),
}
60
61impl Address {
62    /// Returns an unspecified IPv4 address (0.0.0.0:0).
63    pub fn unspecified() -> Self {
64        Address::SocketAddress(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
65    }
66
67    /// Returns the type of the address (IPv4, IPv6, or Domain).
68    pub fn get_type(&self) -> AddressType {
69        match self {
70            Self::SocketAddress(SocketAddr::V4(_)) => AddressType::IPv4,
71            Self::SocketAddress(SocketAddr::V6(_)) => AddressType::IPv6,
72            Self::DomainAddress(_, _) => AddressType::Domain,
73        }
74    }
75
76    /// Returns the port number.
77    pub fn port(&self) -> u16 {
78        match self {
79            Self::SocketAddress(addr) => addr.port(),
80            Self::DomainAddress(_, port) => *port,
81        }
82    }
83
84    /// Returns the domain name or IP address as a string.
85    pub fn domain(&self) -> String {
86        match self {
87            Self::SocketAddress(addr) => addr.ip().to_string(),
88            Self::DomainAddress(addr, _) => addr.to_string(),
89        }
90    }
91
92    /// Returns `true` if it is an IPv4 address.
93    pub fn is_ipv4(&self) -> bool {
94        matches!(self, Self::SocketAddress(SocketAddr::V4(_)))
95    }
96
97    /// Returns `true` if it is an IPv6 address.
98    pub fn is_ipv6(&self) -> bool {
99        matches!(self, Self::SocketAddress(SocketAddr::V6(_)))
100    }
101
102    /// Returns `true` if it is a domain address.
103    pub fn is_domain(&self) -> bool {
104        matches!(self, Self::DomainAddress(_, _))
105    }
106
107    /// Returns the maximum possible serialized length of a SOCKS5 address.
108    pub const fn max_serialized_len() -> usize {
109        1 + 1 + u8::MAX as usize + 2
110    }
111}
112
113impl StreamOperation for Address {
114    fn retrieve_from_stream<R: std::io::Read>(stream: &mut R) -> std::io::Result<Self> {
115        let mut atyp_buf = [0; 1];
116        stream.read_exact(&mut atyp_buf)?;
117        match AddressType::try_from(atyp_buf[0])? {
118            AddressType::IPv4 => {
119                let mut buf = [0; 6];
120                stream.read_exact(&mut buf)?;
121                let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
122                let port = u16::from_be_bytes([buf[4], buf[5]]);
123                Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
124            }
125            AddressType::Domain => {
126                let mut len_buf = [0; 1];
127                stream.read_exact(&mut len_buf)?;
128                let len = len_buf[0] as usize;
129                let mut domain_buf = vec![0; len];
130                stream.read_exact(&mut domain_buf)?;
131
132                let mut port_buf = [0; 2];
133                stream.read_exact(&mut port_buf)?;
134                let port = u16::from_be_bytes(port_buf);
135
136                let addr = String::from_utf8(domain_buf)
137                    .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Invalid address encoding: {err}")))?;
138                Ok(Self::DomainAddress(addr.into_boxed_str(), port))
139            }
140            AddressType::IPv6 => {
141                let mut buf = [0; 18];
142                stream.read_exact(&mut buf)?;
143                let port = u16::from_be_bytes([buf[16], buf[17]]);
144                let mut addr_bytes = [0; 16];
145                addr_bytes.copy_from_slice(&buf[..16]);
146                Ok(Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(addr_bytes), port))))
147            }
148        }
149    }
150
151    fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
152        match self {
153            Self::SocketAddress(SocketAddr::V4(addr)) => {
154                buf.put_u8(AddressType::IPv4.into());
155                buf.put_slice(&addr.ip().octets());
156                buf.put_u16(addr.port());
157            }
158            Self::SocketAddress(SocketAddr::V6(addr)) => {
159                buf.put_u8(AddressType::IPv6.into());
160                buf.put_slice(&addr.ip().octets());
161                buf.put_u16(addr.port());
162            }
163            Self::DomainAddress(addr, port) => {
164                let addr = addr.as_bytes();
165                buf.put_u8(AddressType::Domain.into());
166                buf.put_u8(addr.len() as u8);
167                buf.put_slice(addr);
168                buf.put_u16(*port);
169            }
170        }
171    }
172
173    fn len(&self) -> usize {
174        match self {
175            Address::SocketAddress(SocketAddr::V4(_)) => 1 + 4 + 2,
176            Address::SocketAddress(SocketAddr::V6(_)) => 1 + 16 + 2,
177            Address::DomainAddress(addr, _) => 1 + 1 + addr.len() + 2,
178        }
179    }
180}
181
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncStreamOperation for Address {
    /// Asynchronously reads one SOCKS5 address (`ATYP | DST.ADDR | DST.PORT`).
    ///
    /// Multi-byte integers (the port) are read big-endian. Unknown `ATYP` bytes are
    /// rejected via `AddressType::try_from`, and a non-UTF-8 domain name yields an
    /// `InvalidData` error.
    async fn retrieve_from_async_stream<R>(stream: &mut R) -> std::io::Result<Self>
    where
        R: AsyncRead + Unpin + Send + ?Sized,
    {
        let type_code = stream.read_u8().await?;
        let parsed = match AddressType::try_from(type_code)? {
            AddressType::IPv6 => {
                let mut octets = [0; 16];
                stream.read_exact(&mut octets).await?;
                let port = stream.read_u16().await?;
                Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(octets), port)))
            }
            AddressType::IPv4 => {
                let mut octets = [0; 4];
                stream.read_exact(&mut octets).await?;
                let port = stream.read_u16().await?;
                Self::SocketAddress(SocketAddr::from((Ipv4Addr::from(octets), port)))
            }
            AddressType::Domain => {
                let name_len = usize::from(stream.read_u8().await?);
                let mut raw_name = vec![0; name_len];
                stream.read_exact(&mut raw_name).await?;
                // The port is consumed before the name is validated.
                let port = stream.read_u16().await?;

                let host = String::from_utf8(raw_name).map_err(|err| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Invalid address encoding: {err}"))
                })?;
                Self::DomainAddress(host.into_boxed_str(), port)
            }
        };
        Ok(parsed)
    }
}
217
218impl ToSocketAddrs for Address {
219    type Iter = std::vec::IntoIter<SocketAddr>;
220
221    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
222        match self {
223            Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
224            Address::DomainAddress(addr, port) => Ok((&**addr, *port).to_socket_addrs()?),
225        }
226    }
227}
228
229impl std::fmt::Display for Address {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        match self {
232            Address::DomainAddress(hostname, port) => write!(f, "{hostname}:{port}"),
233            Address::SocketAddress(socket_addr) => write!(f, "{socket_addr}"),
234        }
235    }
236}
237
238impl TryFrom<Address> for SocketAddr {
239    type Error = std::io::Error;
240
241    fn try_from(address: Address) -> std::result::Result<Self, Self::Error> {
242        match address {
243            Address::SocketAddress(addr) => Ok(addr),
244            Address::DomainAddress(addr, port) => {
245                if let Ok(addr) = addr.parse::<Ipv4Addr>() {
246                    Ok(SocketAddr::from((addr, port)))
247                } else if let Ok(addr) = addr.parse::<Ipv6Addr>() {
248                    Ok(SocketAddr::from((addr, port)))
249                } else if let Ok(addr) = addr.parse::<SocketAddr>() {
250                    Ok(addr)
251                } else {
252                    let err = format!("domain address {addr} is not supported");
253                    Err(Self::Error::new(std::io::ErrorKind::Unsupported, err))
254                }
255            }
256        }
257    }
258}
259
260impl TryFrom<&Address> for SocketAddr {
261    type Error = std::io::Error;
262
263    fn try_from(address: &Address) -> std::result::Result<Self, Self::Error> {
264        TryFrom::<Address>::try_from(address.clone())
265    }
266}
267
268impl From<Address> for Vec<u8> {
269    fn from(addr: Address) -> Self {
270        let mut buf = Vec::with_capacity(addr.len());
271        addr.write_to_buf(&mut buf);
272        buf
273    }
274}
275
276impl TryFrom<Vec<u8>> for Address {
277    type Error = std::io::Error;
278
279    fn try_from(data: Vec<u8>) -> std::result::Result<Self, Self::Error> {
280        let mut rdr = Cursor::new(data);
281        Self::retrieve_from_stream(&mut rdr)
282    }
283}
284
285impl TryFrom<&[u8]> for Address {
286    type Error = std::io::Error;
287
288    fn try_from(data: &[u8]) -> std::result::Result<Self, Self::Error> {
289        let mut rdr = Cursor::new(data);
290        Self::retrieve_from_stream(&mut rdr)
291    }
292}
293
294impl From<SocketAddr> for Address {
295    fn from(addr: SocketAddr) -> Self {
296        Address::SocketAddress(addr)
297    }
298}
299
300impl From<&SocketAddr> for Address {
301    fn from(addr: &SocketAddr) -> Self {
302        Address::SocketAddress(*addr)
303    }
304}
305
306impl From<(Ipv4Addr, u16)> for Address {
307    fn from((addr, port): (Ipv4Addr, u16)) -> Self {
308        Address::SocketAddress(SocketAddr::from((addr, port)))
309    }
310}
311
312impl From<(Ipv6Addr, u16)> for Address {
313    fn from((addr, port): (Ipv6Addr, u16)) -> Self {
314        Address::SocketAddress(SocketAddr::from((addr, port)))
315    }
316}
317
318impl From<(IpAddr, u16)> for Address {
319    fn from((addr, port): (IpAddr, u16)) -> Self {
320        Address::SocketAddress(SocketAddr::from((addr, port)))
321    }
322}
323
324impl From<(String, u16)> for Address {
325    fn from((addr, port): (String, u16)) -> Self {
326        Address::DomainAddress(addr.into_boxed_str(), port)
327    }
328}
329
330impl From<(&str, u16)> for Address {
331    fn from((addr, port): (&str, u16)) -> Self {
332        Address::DomainAddress(addr.into(), port)
333    }
334}
335
336impl From<&Address> for Address {
337    fn from(addr: &Address) -> Self {
338        addr.clone()
339    }
340}
341
342impl TryFrom<&str> for Address {
343    type Error = crate::Error;
344
345    fn try_from(addr: &str) -> std::result::Result<Self, Self::Error> {
346        if let Ok(addr) = addr.parse::<SocketAddr>() {
347            Ok(Address::SocketAddress(addr))
348        } else {
349            let (addr, port) = if let Some(pos) = addr.rfind(':') {
350                (&addr[..pos], &addr[pos + 1..])
351            } else {
352                (addr, "0")
353            };
354            let port = port.parse::<u16>()?;
355            Ok(Address::DomainAddress(addr.into(), port))
356        }
357    }
358}
359
#[test]
fn test_address() {
    // Each case: an address and its expected SOCKS5 encoding (port 8080 = 0x1f90).
    let cases: Vec<(Address, Vec<u8>)> = vec![
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com", 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut encoded: Vec<u8> = Vec::new();
        addr.write_to_buf(&mut encoded);
        assert_eq!(encoded, expected);

        let decoded = Address::retrieve_from_stream(&mut Cursor::new(&encoded)).unwrap();
        assert_eq!(addr, decoded);
    }
}
383
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_address_async() {
    // Each case: an address and its expected SOCKS5 encoding (port 8080 = 0x1f90).
    let cases: Vec<(Address, Vec<u8>)> = vec![
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com", 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut encoded: Vec<u8> = Vec::new();
        addr.write_to_async_stream(&mut encoded).await.unwrap();
        assert_eq!(encoded, expected);

        let decoded = Address::retrieve_from_async_stream(&mut Cursor::new(&encoded)).await.unwrap();
        assert_eq!(addr, decoded);
    }
}