// socks5_impl/protocol/address.rs

1#[cfg(feature = "tokio")]
2use crate::protocol::AsyncStreamOperation;
3use crate::protocol::StreamOperation;
4use bytes::BufMut;
5use std::{
6    io::Cursor,
7    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
8};
9#[cfg(feature = "tokio")]
10use tokio::io::{AsyncRead, AsyncReadExt};
11
/// The SOCKS5 ATYP field: which kind of DST.ADDR follows it on the wire.
///
/// Each discriminant is the on-wire byte value. `IPv4` is the `Default`.
#[repr(u8)]
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddressType {
    /// Four-byte IPv4 address (ATYP 0x01).
    #[default]
    IPv4 = 0x01,
    /// Length-prefixed domain name (ATYP 0x03).
    Domain = 0x03,
    /// Sixteen-byte IPv6 address (ATYP 0x04).
    IPv6 = 0x04,
}
20
21impl TryFrom<u8> for AddressType {
22    type Error = std::io::Error;
23    fn try_from(code: u8) -> core::result::Result<Self, Self::Error> {
24        let err = format!("Unsupported address type code {code:#x}");
25        match code {
26            0x01 => Ok(AddressType::IPv4),
27            0x03 => Ok(AddressType::Domain),
28            0x04 => Ok(AddressType::IPv6),
29            _ => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err)),
30        }
31    }
32}
33
34impl From<AddressType> for u8 {
35    fn from(addr_type: AddressType) -> Self {
36        match addr_type {
37            AddressType::IPv4 => 0x01,
38            AddressType::Domain => 0x03,
39            AddressType::IPv6 => 0x04,
40        }
41    }
42}
43
/// A SOCKS5 destination address: ATYP, DST.ADDR and DST.PORT.
///
/// ```plain
/// +------+----------+----------+
/// | ATYP | DST.ADDR | DST.PORT |
/// +------+----------+----------+
/// |  1   | Variable |    2     |
/// +------+----------+----------+
/// ```
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Address {
    /// A literal IPv4 or IPv6 socket address (ATYP 0x01 or 0x04).
    SocketAddress(SocketAddr),
    /// A domain name plus port (ATYP 0x03).
    DomainAddress(String, u16),
}
58
59impl Address {
60    pub fn unspecified() -> Self {
61        Address::SocketAddress(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
62    }
63
64    pub fn get_type(&self) -> AddressType {
65        match self {
66            Self::SocketAddress(SocketAddr::V4(_)) => AddressType::IPv4,
67            Self::SocketAddress(SocketAddr::V6(_)) => AddressType::IPv6,
68            Self::DomainAddress(_, _) => AddressType::Domain,
69        }
70    }
71
72    pub fn port(&self) -> u16 {
73        match self {
74            Self::SocketAddress(addr) => addr.port(),
75            Self::DomainAddress(_, port) => *port,
76        }
77    }
78
79    pub fn domain(&self) -> String {
80        match self {
81            Self::SocketAddress(addr) => addr.ip().to_string(),
82            Self::DomainAddress(addr, _) => addr.clone(),
83        }
84    }
85
86    pub const fn max_serialized_len() -> usize {
87        1 + 1 + u8::MAX as usize + 2
88    }
89}
90
91impl StreamOperation for Address {
92    fn retrieve_from_stream<R: std::io::Read>(stream: &mut R) -> std::io::Result<Self> {
93        let mut atyp = [0; 1];
94        stream.read_exact(&mut atyp)?;
95        match AddressType::try_from(atyp[0])? {
96            AddressType::IPv4 => {
97                let mut buf = [0; 6];
98                stream.read_exact(&mut buf)?;
99                let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
100                let port = u16::from_be_bytes([buf[4], buf[5]]);
101                Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
102            }
103            AddressType::Domain => {
104                let mut len = [0; 1];
105                stream.read_exact(&mut len)?;
106                let len = len[0] as usize;
107                let mut buf = vec![0; len + 2];
108                stream.read_exact(&mut buf)?;
109
110                let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
111                buf.truncate(len);
112
113                let addr = match String::from_utf8(buf) {
114                    Ok(addr) => addr,
115                    Err(err) => {
116                        let err = format!("Invalid address encoding: {err}");
117                        return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err));
118                    }
119                };
120                Ok(Self::DomainAddress(addr, port))
121            }
122            AddressType::IPv6 => {
123                let mut buf = [0; 18];
124                stream.read_exact(&mut buf)?;
125                let port = u16::from_be_bytes([buf[16], buf[17]]);
126                let mut addr_bytes = [0; 16];
127                addr_bytes.copy_from_slice(&buf[..16]);
128                Ok(Self::SocketAddress(SocketAddr::from((Ipv6Addr::from(addr_bytes), port))))
129            }
130        }
131    }
132
133    fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
134        match self {
135            Self::SocketAddress(SocketAddr::V4(addr)) => {
136                buf.put_u8(AddressType::IPv4.into());
137                buf.put_slice(&addr.ip().octets());
138                buf.put_u16(addr.port());
139            }
140            Self::SocketAddress(SocketAddr::V6(addr)) => {
141                buf.put_u8(AddressType::IPv6.into());
142                buf.put_slice(&addr.ip().octets());
143                buf.put_u16(addr.port());
144            }
145            Self::DomainAddress(addr, port) => {
146                let addr = addr.as_bytes();
147                buf.put_u8(AddressType::Domain.into());
148                buf.put_u8(addr.len() as u8);
149                buf.put_slice(addr);
150                buf.put_u16(*port);
151            }
152        }
153    }
154
155    fn len(&self) -> usize {
156        match self {
157            Address::SocketAddress(SocketAddr::V4(_)) => 1 + 4 + 2,
158            Address::SocketAddress(SocketAddr::V6(_)) => 1 + 16 + 2,
159            Address::DomainAddress(addr, _) => 1 + 1 + addr.len() + 2,
160        }
161    }
162}
163
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncStreamOperation for Address {
    /// Decodes an address (ATYP, DST.ADDR, DST.PORT) from an async reader.
    ///
    /// The wire layout matches the blocking `retrieve_from_stream`.
    ///
    /// # Errors
    /// - `InvalidInput` for an unknown ATYP byte.
    /// - `InvalidData` for a domain name that is not valid UTF-8.
    /// - Any I/O error from the underlying reads.
    async fn retrieve_from_async_stream<R>(stream: &mut R) -> std::io::Result<Self>
    where
        R: AsyncRead + Unpin + Send + ?Sized,
    {
        let code = stream.read_u8().await?;
        let address = match AddressType::try_from(code)? {
            AddressType::IPv4 => {
                let mut octets = [0u8; 4];
                stream.read_exact(&mut octets).await?;
                let mut port_be = [0u8; 2];
                stream.read_exact(&mut port_be).await?;
                let ip = Ipv4Addr::from(octets);
                Self::SocketAddress(SocketAddr::from((ip, u16::from_be_bytes(port_be))))
            }
            AddressType::Domain => {
                let name_len = usize::from(stream.read_u8().await?);
                // The name and its trailing big-endian port are fetched in a single read.
                let mut raw = vec![0u8; name_len + 2];
                stream.read_exact(&mut raw).await?;

                let port = u16::from_be_bytes([raw[name_len], raw[name_len + 1]]);
                raw.truncate(name_len);

                let host = String::from_utf8(raw).map_err(|err| {
                    let err = format!("Invalid address encoding: {err}");
                    std::io::Error::new(std::io::ErrorKind::InvalidData, err)
                })?;
                Self::DomainAddress(host, port)
            }
            AddressType::IPv6 => {
                let mut octets = [0u8; 16];
                stream.read_exact(&mut octets).await?;
                let mut port_be = [0u8; 2];
                stream.read_exact(&mut port_be).await?;
                let ip = Ipv6Addr::from(octets);
                Self::SocketAddress(SocketAddr::from((ip, u16::from_be_bytes(port_be))))
            }
        };
        Ok(address)
    }
}
210
211impl ToSocketAddrs for Address {
212    type Iter = std::vec::IntoIter<SocketAddr>;
213
214    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
215        match self {
216            Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
217            Address::DomainAddress(addr, port) => Ok((addr.as_str(), *port).to_socket_addrs()?),
218        }
219    }
220}
221
222impl std::fmt::Display for Address {
223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        match self {
225            Address::DomainAddress(hostname, port) => write!(f, "{hostname}:{port}"),
226            Address::SocketAddress(socket_addr) => write!(f, "{socket_addr}"),
227        }
228    }
229}
230
231impl TryFrom<Address> for SocketAddr {
232    type Error = std::io::Error;
233
234    fn try_from(address: Address) -> std::result::Result<Self, Self::Error> {
235        match address {
236            Address::SocketAddress(addr) => Ok(addr),
237            Address::DomainAddress(addr, port) => {
238                if let Ok(addr) = addr.parse::<Ipv4Addr>() {
239                    Ok(SocketAddr::from((addr, port)))
240                } else if let Ok(addr) = addr.parse::<Ipv6Addr>() {
241                    Ok(SocketAddr::from((addr, port)))
242                } else {
243                    let err = format!("domain address {addr} is not supported");
244                    Err(Self::Error::new(std::io::ErrorKind::Unsupported, err))
245                }
246            }
247        }
248    }
249}
250
251impl TryFrom<&Address> for SocketAddr {
252    type Error = std::io::Error;
253
254    fn try_from(address: &Address) -> std::result::Result<Self, Self::Error> {
255        TryFrom::<Address>::try_from(address.clone())
256    }
257}
258
259impl From<Address> for Vec<u8> {
260    fn from(addr: Address) -> Self {
261        let mut buf = Vec::with_capacity(addr.len());
262        addr.write_to_buf(&mut buf);
263        buf
264    }
265}
266
267impl TryFrom<Vec<u8>> for Address {
268    type Error = std::io::Error;
269
270    fn try_from(data: Vec<u8>) -> std::result::Result<Self, Self::Error> {
271        let mut rdr = Cursor::new(data);
272        Self::retrieve_from_stream(&mut rdr)
273    }
274}
275
276impl TryFrom<&[u8]> for Address {
277    type Error = std::io::Error;
278
279    fn try_from(data: &[u8]) -> std::result::Result<Self, Self::Error> {
280        let mut rdr = Cursor::new(data);
281        Self::retrieve_from_stream(&mut rdr)
282    }
283}
284
285impl From<SocketAddr> for Address {
286    fn from(addr: SocketAddr) -> Self {
287        Address::SocketAddress(addr)
288    }
289}
290
291impl From<&SocketAddr> for Address {
292    fn from(addr: &SocketAddr) -> Self {
293        Address::SocketAddress(*addr)
294    }
295}
296
297impl From<(Ipv4Addr, u16)> for Address {
298    fn from((addr, port): (Ipv4Addr, u16)) -> Self {
299        Address::SocketAddress(SocketAddr::from((addr, port)))
300    }
301}
302
303impl From<(Ipv6Addr, u16)> for Address {
304    fn from((addr, port): (Ipv6Addr, u16)) -> Self {
305        Address::SocketAddress(SocketAddr::from((addr, port)))
306    }
307}
308
309impl From<(IpAddr, u16)> for Address {
310    fn from((addr, port): (IpAddr, u16)) -> Self {
311        Address::SocketAddress(SocketAddr::from((addr, port)))
312    }
313}
314
315impl From<(String, u16)> for Address {
316    fn from((addr, port): (String, u16)) -> Self {
317        Address::DomainAddress(addr, port)
318    }
319}
320
321impl From<(&str, u16)> for Address {
322    fn from((addr, port): (&str, u16)) -> Self {
323        Address::DomainAddress(addr.to_owned(), port)
324    }
325}
326
327impl From<&Address> for Address {
328    fn from(addr: &Address) -> Self {
329        addr.clone()
330    }
331}
332
333impl TryFrom<&str> for Address {
334    type Error = crate::Error;
335
336    fn try_from(addr: &str) -> std::result::Result<Self, Self::Error> {
337        if let Ok(addr) = addr.parse::<SocketAddr>() {
338            Ok(Address::SocketAddress(addr))
339        } else {
340            let (addr, port) = if let Some(pos) = addr.rfind(':') {
341                (&addr[..pos], &addr[pos + 1..])
342            } else {
343                (addr, "0")
344            };
345            let port = port.parse::<u16>()?;
346            Ok(Address::DomainAddress(addr.to_owned(), port))
347        }
348    }
349}
350
/// Round-trips one address of each ATYP through the blocking codec.
///
/// The test also checks that `len()` matches the bytes actually written, and that the
/// `Vec<u8>` / `&[u8]` conversions agree with `write_to_buf` / `retrieve_from_stream`.
#[test]
fn test_address() {
    // (address, expected wire bytes); 0x1f90 == 8080 in big-endian.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com".to_owned(), 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut buf: Vec<u8> = Vec::new();
        addr.write_to_buf(&mut buf);
        assert_eq!(buf, expected);
        // `From<Address> for Vec<u8>` preallocates with `len()`, so it must be exact.
        assert_eq!(addr.len(), buf.len());
        assert_eq!(Vec::<u8>::from(addr.clone()), buf);

        let decoded = Address::retrieve_from_stream(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(addr, decoded);
        assert_eq!(Address::try_from(buf.as_slice()).unwrap(), addr);
    }
}
374
/// Round-trips one address of each ATYP through the async codec.
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_address_async() {
    // (address, expected wire bytes); 0x1f90 == 8080 in big-endian.
    let cases = [
        (
            Address::from((Ipv4Addr::new(127, 0, 0, 1), 8080)),
            vec![0x01, 127, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from((Ipv6Addr::new(0x45, 0xff89, 0, 0, 0, 0, 0, 1), 8080)),
            vec![0x04, 0, 0x45, 0xff, 0x89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x1f, 0x90],
        ),
        (
            Address::from(("sex.com".to_owned(), 8080)),
            vec![0x03, 0x07, b's', b'e', b'x', b'.', b'c', b'o', b'm', 0x1f, 0x90],
        ),
    ];

    for (addr, expected) in cases {
        let mut buf: Vec<u8> = Vec::new();
        addr.write_to_async_stream(&mut buf).await.unwrap();
        assert_eq!(buf, expected);
        let decoded = Address::retrieve_from_async_stream(&mut Cursor::new(&buf)).await.unwrap();
        assert_eq!(addr, decoded);
    }
}