1use bytes::BufMut;
2use std::{
3 fmt::{Display, Formatter, Result as FmtResult},
4 io::Error as IoError,
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
6 vec,
7};
8use thiserror::Error;
9use tokio::io::{AsyncRead, AsyncReadExt};
10
/// A network destination: either an already-resolved socket address, or a domain name
/// (kept as raw bytes, not necessarily valid UTF-8) together with a port.
///
/// The derived `Ord`/`PartialOrd` follow variant declaration order, so every
/// `SocketAddress` compares less than every `DomainAddress`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Address {
    /// An IPv4 or IPv6 address with a port.
    SocketAddress(SocketAddr),
    /// A domain name (raw bytes) and a port.
    DomainAddress(Vec<u8>, u16),
}
17
18impl Address {
19 const ATYP_IPV4: u8 = 0x01;
20 const ATYP_FQDN: u8 = 0x03;
21 const ATYP_IPV6: u8 = 0x04;
22
23 pub(crate) async fn read_from<R>(stream: &mut R) -> Result<Self, AddressError>
24 where
25 R: AsyncRead + Unpin,
26 {
27 let atyp = stream.read_u8().await?;
28
29 match atyp {
30 Self::ATYP_IPV4 => {
31 let mut buf = [0; 6];
32 stream.read_exact(&mut buf).await?;
33
34 let addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
35
36 let port = u16::from_be_bytes([buf[4], buf[5]]);
37
38 Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
39 }
40 Self::ATYP_FQDN => {
41 let len = stream.read_u8().await? as usize;
42
43 let mut buf = vec![0; len + 2];
44 stream.read_exact(&mut buf).await?;
45
46 let port = u16::from_be_bytes([buf[len], buf[len + 1]]);
47 buf.truncate(len);
48
49 Ok(Self::DomainAddress(buf, port))
50 }
51 Self::ATYP_IPV6 => {
52 let mut buf = [0; 18];
53 stream.read_exact(&mut buf).await?;
54
55 let addr = Ipv6Addr::new(
56 u16::from_be_bytes([buf[0], buf[1]]),
57 u16::from_be_bytes([buf[2], buf[3]]),
58 u16::from_be_bytes([buf[4], buf[5]]),
59 u16::from_be_bytes([buf[6], buf[7]]),
60 u16::from_be_bytes([buf[8], buf[9]]),
61 u16::from_be_bytes([buf[10], buf[11]]),
62 u16::from_be_bytes([buf[12], buf[13]]),
63 u16::from_be_bytes([buf[14], buf[15]]),
64 );
65
66 let port = u16::from_be_bytes([buf[16], buf[17]]);
67
68 Ok(Self::SocketAddress(SocketAddr::from((addr, port))))
69 }
70 atyp => Err(AddressError::InvalidType(atyp)),
71 }
72 }
73
74 pub(crate) fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
75 match self {
76 Self::SocketAddress(SocketAddr::V4(addr)) => {
77 buf.put_u8(Self::ATYP_IPV4);
78 buf.put_slice(&addr.ip().octets());
79 buf.put_u16(addr.port());
80 }
81 Self::SocketAddress(SocketAddr::V6(addr)) => {
82 buf.put_u8(Self::ATYP_IPV6);
83 for seg in addr.ip().segments() {
84 buf.put_u16(seg);
85 }
86 buf.put_u16(addr.port());
87 }
88 Self::DomainAddress(addr, port) => {
89 buf.put_u8(Self::ATYP_FQDN);
90 buf.put_u8(addr.len() as u8);
91 buf.put_slice(addr);
92 buf.put_u16(*port);
93 }
94 }
95 }
96
97 pub fn unspecified() -> Self {
98 Address::SocketAddress(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
99 }
100
101 pub fn serialized_len(&self) -> usize {
102 1 + match self {
103 Address::SocketAddress(SocketAddr::V4(_)) => 6,
104 Address::SocketAddress(SocketAddr::V6(_)) => 18,
105 Address::DomainAddress(addr, _) => 1 + addr.len() + 2,
106 }
107 }
108}
109
110impl Display for Address {
111 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
112 match self {
113 Address::DomainAddress(hostname, port) => write!(
114 f,
115 "{hostname}:{port}",
116 hostname = String::from_utf8_lossy(hostname),
117 ),
118 Address::SocketAddress(addr) => write!(f, "{addr}"),
119 }
120 }
121}
122
123#[derive(Debug, Error)]
124pub(crate) enum AddressError {
125 #[error(transparent)]
126 Io(#[from] IoError),
127 #[error("Invalid address type {0:#04x}")]
128 InvalidType(u8),
129}