transformable/impls/net/
socket_addr.rs

1use std::net::SocketAddr;
2
3use super::Transformable;
4
5#[cfg(feature = "std")]
6use crate::utils::invalid_data;
7
8/// The wire error type for [`SocketAddr`].
9#[derive(Debug, thiserror::Error)]
10pub enum SocketAddrTransformError {
11  /// Returned when the buffer is too small to encode the [`SocketAddr`].
12  #[error(
13    "buffer is too small, use `SocketAddr::encoded_len` to pre-allocate a buffer with enough space"
14  )]
15  EncodeBufferTooSmall,
16  /// Returned when the address family is unknown.
17  #[error("invalid address family: {0}, only IPv4 and IPv6 are supported")]
18  UnknownAddressFamily(u8),
19  /// Returned when the address is corrupted.
20  #[error("not enough bytes to decode")]
21  NotEnoughBytes,
22}
23
/// Size of the address-family tag that prefixes every encoding.
const TAG_SIZE: usize = 1;
/// Number of octets in an IPv4 address.
const V4_SIZE: usize = 4;
/// Number of octets in an IPv6 address.
const V6_SIZE: usize = 16;
/// Size of the big-endian port that terminates every encoding.
const PORT_SIZE: usize = core::mem::size_of::<u16>();

/// Length of an encoded IPv4 address, which is also the shortest possible encoding.
const MIN_ENCODED_LEN: usize = TAG_SIZE + V4_SIZE + PORT_SIZE;
/// Length of an encoded IPv6 address.
const V6_ENCODED_LEN: usize = TAG_SIZE + V6_SIZE + PORT_SIZE;
30
31impl Transformable for SocketAddr {
32  type Error = SocketAddrTransformError;
33
34  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
35    let encoded_len = self.encoded_len();
36    if dst.len() < encoded_len {
37      return Err(Self::Error::EncodeBufferTooSmall);
38    }
39    dst[0] = match self {
40      SocketAddr::V4(_) => 4,
41      SocketAddr::V6(_) => 6,
42    };
43    match self {
44      SocketAddr::V4(addr) => {
45        dst[1..5].copy_from_slice(&addr.ip().octets());
46        dst[5..7].copy_from_slice(&addr.port().to_be_bytes());
47      }
48      SocketAddr::V6(addr) => {
49        dst[1..17].copy_from_slice(&addr.ip().octets());
50        dst[17..19].copy_from_slice(&addr.port().to_be_bytes());
51      }
52    }
53
54    Ok(encoded_len)
55  }
56
57  #[cfg(feature = "std")]
58  fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
59    match self {
60      SocketAddr::V4(addr) => {
61        let mut buf = [0u8; 7];
62        buf[0] = 4;
63        buf[TAG_SIZE..5].copy_from_slice(&addr.ip().octets());
64        buf[5..MIN_ENCODED_LEN].copy_from_slice(&addr.port().to_be_bytes());
65        writer.write_all(&buf).map(|_| 7)
66      }
67      SocketAddr::V6(addr) => {
68        let mut buf = [0u8; 19];
69        buf[0] = 6;
70        buf[1..17].copy_from_slice(&addr.ip().octets());
71        buf[17..19].copy_from_slice(&addr.port().to_be_bytes());
72        writer.write_all(&buf).map(|_| 19)
73      }
74    }
75  }
76
77  #[cfg(feature = "async")]
78  async fn encode_to_async_writer<W: futures_util::io::AsyncWrite + Send + Unpin>(
79    &self,
80    writer: &mut W,
81  ) -> std::io::Result<usize> {
82    use futures_util::AsyncWriteExt;
83
84    match self {
85      SocketAddr::V4(addr) => {
86        let mut buf = [0u8; 7];
87        buf[0] = 4;
88        buf[1..5].copy_from_slice(&addr.ip().octets());
89        buf[5..7].copy_from_slice(&addr.port().to_be_bytes());
90        writer.write_all(&buf).await.map(|_| 7)
91      }
92      SocketAddr::V6(addr) => {
93        let mut buf = [0u8; 19];
94        buf[0] = 6;
95        buf[1..17].copy_from_slice(&addr.ip().octets());
96        buf[17..19].copy_from_slice(&addr.port().to_be_bytes());
97        writer.write_all(&buf).await.map(|_| 19)
98      }
99    }
100  }
101
102  fn encoded_len(&self) -> usize {
103    1 + match self {
104      SocketAddr::V4(_) => 4,
105      SocketAddr::V6(_) => 16,
106    } + core::mem::size_of::<u16>()
107  }
108
109  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
110  where
111    Self: Sized,
112  {
113    match src[0] {
114      4 => {
115        if src.len() < 7 {
116          return Err(SocketAddrTransformError::NotEnoughBytes);
117        }
118
119        let ip = std::net::Ipv4Addr::new(src[1], src[2], src[3], src[4]);
120        let port = u16::from_be_bytes([src[5], src[6]]);
121        Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
122      }
123      6 => {
124        if src.len() < 19 {
125          return Err(SocketAddrTransformError::NotEnoughBytes);
126        }
127
128        let mut buf = [0u8; 16];
129        buf.copy_from_slice(&src[1..17]);
130        let ip = std::net::Ipv6Addr::from(buf);
131        let port = u16::from_be_bytes([src[17], src[18]]);
132        Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
133      }
134      val => Err(SocketAddrTransformError::UnknownAddressFamily(val)),
135    }
136  }
137
138  #[cfg(feature = "std")]
139  fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
140  where
141    Self: Sized,
142  {
143    use std::net::{Ipv4Addr, Ipv6Addr};
144
145    let mut buf = [0; MIN_ENCODED_LEN];
146    reader.read_exact(&mut buf)?;
147    match buf[0] {
148      4 => {
149        let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
150        let port = u16::from_be_bytes([buf[5], buf[6]]);
151        Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
152      }
153      6 => {
154        let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
155        reader.read_exact(&mut remaining)?;
156        let mut ipv6 = [0; V6_SIZE];
157        ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
158        ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
159          .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN - 2]);
160        let ip = Ipv6Addr::from(ipv6);
161        let port = u16::from_be_bytes([
162          remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 2],
163          remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 1],
164        ]);
165        Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
166      }
167      val => Err(invalid_data(
168        SocketAddrTransformError::UnknownAddressFamily(val),
169      )),
170    }
171  }
172
173  #[cfg(feature = "async")]
174  async fn decode_from_async_reader<R: futures_util::io::AsyncRead + Send + Unpin>(
175    reader: &mut R,
176  ) -> std::io::Result<(usize, Self)>
177  where
178    Self: Sized,
179  {
180    use futures_util::AsyncReadExt;
181    use std::net::{Ipv4Addr, Ipv6Addr};
182
183    let mut buf = [0; MIN_ENCODED_LEN];
184    reader.read_exact(&mut buf).await?;
185    match buf[0] {
186      4 => {
187        let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
188        let port = u16::from_be_bytes([buf[5], buf[6]]);
189        Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
190      }
191      6 => {
192        let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
193        reader.read_exact(&mut remaining).await?;
194        let mut ipv6 = [0; V6_SIZE];
195        ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
196        ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
197          .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN - 2]);
198        let ip = Ipv6Addr::from(ipv6);
199        let port = u16::from_be_bytes([
200          remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 2],
201          remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 1],
202        ]);
203        Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
204      }
205      val => Err(invalid_data(
206        SocketAddrTransformError::UnknownAddressFamily(val),
207      )),
208    }
209  }
210}
211
212test_transformable!(SocketAddr => test_socket_addr_v4_transformable(
213  SocketAddr::V4(std::net::SocketAddrV4::new(
214    std::net::Ipv4Addr::new(127, 0, 0, 1),
215    8080
216  ))
217));
218
219test_transformable!(SocketAddr => test_socket_addr_v6_transformable(
220  SocketAddr::V6(std::net::SocketAddrV6::new(
221    std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
222    8080,
223    0,
224    0
225  ))
226));