transformable/impls/net/
ip_addr.rs

1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2
3use super::Transformable;
4
5#[cfg(feature = "std")]
6use crate::utils::invalid_data;
7
8/// The wire error type for [`IpAddr`].
9#[derive(Debug, thiserror::Error)]
10pub enum IpAddrTransformError {
11  /// Returned when the buffer is too small to encode the [`IpAddr`].
12  #[error(
13    "buffer is too small, use `IpAddr::encoded_len` to pre-allocate a buffer with enough space"
14  )]
15  EncodeBufferTooSmall,
16  /// Returned when the address family is unknown.
17  #[error("invalid address family: {0}, only IPv4 and IPv6 are supported")]
18  UnknownAddressFamily(u8),
19  /// Returned when the address is corrupted.
20  #[error("{0}")]
21  NotEnoughBytes(&'static str),
22}
23
/// Size of the one-byte address-family tag that prefixes every encoding.
const TAG_SIZE: usize = 1;
/// Number of octets in an IPv4 address.
const V4_SIZE: usize = 4;
/// Number of octets in an IPv6 address.
const V6_SIZE: usize = 16;
/// Encoded length of an IPv4 address (tag + 4 octets). This is also the
/// shortest possible encoding of any [`IpAddr`].
const MIN_ENCODED_LEN: usize = TAG_SIZE + V4_SIZE;
/// Encoded length of an IPv6 address (tag + 16 octets).
const V6_ENCODED_LEN: usize = TAG_SIZE + V6_SIZE;

30impl Transformable for IpAddr {
31  type Error = IpAddrTransformError;
32
33  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
34    let encoded_len = self.encoded_len();
35    if dst.len() < encoded_len {
36      return Err(Self::Error::EncodeBufferTooSmall);
37    }
38    dst[0] = match self {
39      IpAddr::V4(_) => 4,
40      IpAddr::V6(_) => 6,
41    };
42    match self {
43      IpAddr::V4(addr) => {
44        dst[1..5].copy_from_slice(&addr.octets());
45      }
46      IpAddr::V6(addr) => {
47        dst[1..17].copy_from_slice(&addr.octets());
48      }
49    }
50
51    Ok(encoded_len)
52  }
53
54  #[cfg(feature = "std")]
55  #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
56  fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
57    match self {
58      IpAddr::V4(addr) => {
59        let mut buf = [0u8; 7];
60        buf[0] = 4;
61        buf[1..5].copy_from_slice(&addr.octets());
62        writer.write_all(&buf).map(|_| 7)
63      }
64      IpAddr::V6(addr) => {
65        let mut buf = [0u8; 19];
66        buf[0] = 6;
67        buf[1..17].copy_from_slice(&addr.octets());
68        writer.write_all(&buf).map(|_| 19)
69      }
70    }
71  }
72
73  #[cfg(feature = "async")]
74  #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
75  async fn encode_to_async_writer<W: futures_util::io::AsyncWrite + Send + Unpin>(
76    &self,
77    writer: &mut W,
78  ) -> std::io::Result<usize> {
79    use futures_util::AsyncWriteExt;
80
81    match self {
82      IpAddr::V4(addr) => {
83        let mut buf = [0u8; 7];
84        buf[0] = 4;
85        buf[1..5].copy_from_slice(&addr.octets());
86        writer.write_all(&buf).await.map(|_| 7)
87      }
88      IpAddr::V6(addr) => {
89        let mut buf = [0u8; 19];
90        buf[0] = 6;
91        buf[1..17].copy_from_slice(&addr.octets());
92        writer.write_all(&buf).await.map(|_| 19)
93      }
94    }
95  }
96
97  fn encoded_len(&self) -> usize {
98    1 + match self {
99      IpAddr::V4(_) => 4,
100      IpAddr::V6(_) => 16,
101    } + core::mem::size_of::<u16>()
102  }
103
104  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
105  where
106    Self: Sized,
107  {
108    match src[0] {
109      4 => {
110        if src.len() < 7 {
111          return Err(IpAddrTransformError::NotEnoughBytes(
112            "not enough bytes to decode socket v4 address",
113          ));
114        }
115
116        let ip = std::net::Ipv4Addr::new(src[1], src[2], src[3], src[4]);
117        Ok((MIN_ENCODED_LEN, IpAddr::from(ip)))
118      }
119      6 => {
120        if src.len() < 19 {
121          return Err(IpAddrTransformError::NotEnoughBytes(
122            "not enough bytes to decode socket v6 address",
123          ));
124        }
125
126        let mut buf = [0u8; 16];
127        buf.copy_from_slice(&src[1..17]);
128        let ip = std::net::Ipv6Addr::from(buf);
129        Ok((V6_ENCODED_LEN, IpAddr::from(ip)))
130      }
131      val => Err(IpAddrTransformError::UnknownAddressFamily(val)),
132    }
133  }
134
135  #[cfg(feature = "std")]
136  #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
137  fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
138  where
139    Self: Sized,
140  {
141    let mut buf = [0; MIN_ENCODED_LEN];
142    reader.read_exact(&mut buf)?;
143    match buf[0] {
144      4 => {
145        let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
146        Ok((MIN_ENCODED_LEN, IpAddr::from(ip)))
147      }
148      6 => {
149        let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
150        reader.read_exact(&mut remaining)?;
151        let mut ipv6 = [0; V6_SIZE];
152        ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
153        ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
154          .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN]);
155        let ip = Ipv6Addr::from(ipv6);
156        Ok((V6_ENCODED_LEN, IpAddr::from(ip)))
157      }
158      val => Err(invalid_data(IpAddrTransformError::UnknownAddressFamily(
159        val,
160      ))),
161    }
162  }
163
164  #[cfg(feature = "async")]
165  #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
166  async fn decode_from_async_reader<R: futures_util::io::AsyncRead + Send + Unpin>(
167    reader: &mut R,
168  ) -> std::io::Result<(usize, Self)>
169  where
170    Self: Sized,
171  {
172    use futures_util::AsyncReadExt;
173
174    let mut buf = [0; MIN_ENCODED_LEN];
175    reader.read_exact(&mut buf).await?;
176    match buf[0] {
177      4 => {
178        let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
179        Ok((MIN_ENCODED_LEN, IpAddr::from(ip)))
180      }
181      6 => {
182        let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
183        reader.read_exact(&mut remaining).await?;
184        let mut ipv6 = [0; V6_SIZE];
185        ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
186        ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
187          .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN]);
188        let ip = Ipv6Addr::from(ipv6);
189        Ok((V6_ENCODED_LEN, IpAddr::from(ip)))
190      }
191      val => Err(invalid_data(IpAddrTransformError::UnknownAddressFamily(
192        val,
193      ))),
194    }
195  }
196}
197
198test_transformable!(IpAddr => test_socket_addr_v4_transformable(
199  IpAddr::V4(
200    Ipv4Addr::new(127, 0, 0, 1),
201  )
202));
203
204test_transformable!(IpAddr => test_socket_addr_v6_transformable(
205  IpAddr::V6(
206    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
207  )
208));