transformable/impls/net/
socket_addr.rs1use std::net::SocketAddr;
2
3use super::Transformable;
4
5#[cfg(feature = "std")]
6use crate::utils::invalid_data;
7
8#[derive(Debug, thiserror::Error)]
10pub enum SocketAddrTransformError {
11 #[error(
13 "buffer is too small, use `SocketAddr::encoded_len` to pre-allocate a buffer with enough space"
14 )]
15 EncodeBufferTooSmall,
16 #[error("invalid address family: {0}, only IPv4 and IPv6 are supported")]
18 UnknownAddressFamily(u8),
19 #[error("not enough bytes to decode")]
21 NotEnoughBytes,
22}
23
/// Size in bytes of the address-family tag that prefixes every encoded address.
const TAG_SIZE: usize = 1;
/// Size in bytes of the IPv4 address octets.
const V4_SIZE: usize = 4;
/// Size in bytes of the IPv6 address octets.
const V6_SIZE: usize = 16;
/// Size in bytes of the encoded port (`u16`).
const PORT_SIZE: usize = core::mem::size_of::<u16>();

/// Encoded length of an IPv4 socket address; also the shortest valid encoding.
const MIN_ENCODED_LEN: usize = TAG_SIZE + V4_SIZE + PORT_SIZE;
/// Encoded length of an IPv6 socket address.
const V6_ENCODED_LEN: usize = TAG_SIZE + V6_SIZE + PORT_SIZE;
30
31impl Transformable for SocketAddr {
32 type Error = SocketAddrTransformError;
33
34 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
35 let encoded_len = self.encoded_len();
36 if dst.len() < encoded_len {
37 return Err(Self::Error::EncodeBufferTooSmall);
38 }
39 dst[0] = match self {
40 SocketAddr::V4(_) => 4,
41 SocketAddr::V6(_) => 6,
42 };
43 match self {
44 SocketAddr::V4(addr) => {
45 dst[1..5].copy_from_slice(&addr.ip().octets());
46 dst[5..7].copy_from_slice(&addr.port().to_be_bytes());
47 }
48 SocketAddr::V6(addr) => {
49 dst[1..17].copy_from_slice(&addr.ip().octets());
50 dst[17..19].copy_from_slice(&addr.port().to_be_bytes());
51 }
52 }
53
54 Ok(encoded_len)
55 }
56
57 #[cfg(feature = "std")]
58 fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
59 match self {
60 SocketAddr::V4(addr) => {
61 let mut buf = [0u8; 7];
62 buf[0] = 4;
63 buf[TAG_SIZE..5].copy_from_slice(&addr.ip().octets());
64 buf[5..MIN_ENCODED_LEN].copy_from_slice(&addr.port().to_be_bytes());
65 writer.write_all(&buf).map(|_| 7)
66 }
67 SocketAddr::V6(addr) => {
68 let mut buf = [0u8; 19];
69 buf[0] = 6;
70 buf[1..17].copy_from_slice(&addr.ip().octets());
71 buf[17..19].copy_from_slice(&addr.port().to_be_bytes());
72 writer.write_all(&buf).map(|_| 19)
73 }
74 }
75 }
76
77 #[cfg(feature = "async")]
78 async fn encode_to_async_writer<W: futures_util::io::AsyncWrite + Send + Unpin>(
79 &self,
80 writer: &mut W,
81 ) -> std::io::Result<usize> {
82 use futures_util::AsyncWriteExt;
83
84 match self {
85 SocketAddr::V4(addr) => {
86 let mut buf = [0u8; 7];
87 buf[0] = 4;
88 buf[1..5].copy_from_slice(&addr.ip().octets());
89 buf[5..7].copy_from_slice(&addr.port().to_be_bytes());
90 writer.write_all(&buf).await.map(|_| 7)
91 }
92 SocketAddr::V6(addr) => {
93 let mut buf = [0u8; 19];
94 buf[0] = 6;
95 buf[1..17].copy_from_slice(&addr.ip().octets());
96 buf[17..19].copy_from_slice(&addr.port().to_be_bytes());
97 writer.write_all(&buf).await.map(|_| 19)
98 }
99 }
100 }
101
102 fn encoded_len(&self) -> usize {
103 1 + match self {
104 SocketAddr::V4(_) => 4,
105 SocketAddr::V6(_) => 16,
106 } + core::mem::size_of::<u16>()
107 }
108
109 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
110 where
111 Self: Sized,
112 {
113 match src[0] {
114 4 => {
115 if src.len() < 7 {
116 return Err(SocketAddrTransformError::NotEnoughBytes);
117 }
118
119 let ip = std::net::Ipv4Addr::new(src[1], src[2], src[3], src[4]);
120 let port = u16::from_be_bytes([src[5], src[6]]);
121 Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
122 }
123 6 => {
124 if src.len() < 19 {
125 return Err(SocketAddrTransformError::NotEnoughBytes);
126 }
127
128 let mut buf = [0u8; 16];
129 buf.copy_from_slice(&src[1..17]);
130 let ip = std::net::Ipv6Addr::from(buf);
131 let port = u16::from_be_bytes([src[17], src[18]]);
132 Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
133 }
134 val => Err(SocketAddrTransformError::UnknownAddressFamily(val)),
135 }
136 }
137
138 #[cfg(feature = "std")]
139 fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
140 where
141 Self: Sized,
142 {
143 use std::net::{Ipv4Addr, Ipv6Addr};
144
145 let mut buf = [0; MIN_ENCODED_LEN];
146 reader.read_exact(&mut buf)?;
147 match buf[0] {
148 4 => {
149 let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
150 let port = u16::from_be_bytes([buf[5], buf[6]]);
151 Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
152 }
153 6 => {
154 let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
155 reader.read_exact(&mut remaining)?;
156 let mut ipv6 = [0; V6_SIZE];
157 ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
158 ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
159 .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN - 2]);
160 let ip = Ipv6Addr::from(ipv6);
161 let port = u16::from_be_bytes([
162 remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 2],
163 remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 1],
164 ]);
165 Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
166 }
167 val => Err(invalid_data(
168 SocketAddrTransformError::UnknownAddressFamily(val),
169 )),
170 }
171 }
172
173 #[cfg(feature = "async")]
174 async fn decode_from_async_reader<R: futures_util::io::AsyncRead + Send + Unpin>(
175 reader: &mut R,
176 ) -> std::io::Result<(usize, Self)>
177 where
178 Self: Sized,
179 {
180 use futures_util::AsyncReadExt;
181 use std::net::{Ipv4Addr, Ipv6Addr};
182
183 let mut buf = [0; MIN_ENCODED_LEN];
184 reader.read_exact(&mut buf).await?;
185 match buf[0] {
186 4 => {
187 let ip = Ipv4Addr::new(buf[1], buf[2], buf[3], buf[4]);
188 let port = u16::from_be_bytes([buf[5], buf[6]]);
189 Ok((MIN_ENCODED_LEN, SocketAddr::from((ip, port))))
190 }
191 6 => {
192 let mut remaining = [0; V6_ENCODED_LEN - MIN_ENCODED_LEN];
193 reader.read_exact(&mut remaining).await?;
194 let mut ipv6 = [0; V6_SIZE];
195 ipv6[..MIN_ENCODED_LEN - TAG_SIZE].copy_from_slice(&buf[TAG_SIZE..]);
196 ipv6[MIN_ENCODED_LEN - TAG_SIZE..]
197 .copy_from_slice(&remaining[..V6_ENCODED_LEN - MIN_ENCODED_LEN - 2]);
198 let ip = Ipv6Addr::from(ipv6);
199 let port = u16::from_be_bytes([
200 remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 2],
201 remaining[V6_ENCODED_LEN - MIN_ENCODED_LEN - 1],
202 ]);
203 Ok((V6_ENCODED_LEN, SocketAddr::from((ip, port))))
204 }
205 val => Err(invalid_data(
206 SocketAddrTransformError::UnknownAddressFamily(val),
207 )),
208 }
209 }
210}
211
212test_transformable!(SocketAddr => test_socket_addr_v4_transformable(
213 SocketAddr::V4(std::net::SocketAddrV4::new(
214 std::net::Ipv4Addr::new(127, 0, 0, 1),
215 8080
216 ))
217));
218
219test_transformable!(SocketAddr => test_socket_addr_v6_transformable(
220 SocketAddr::V6(std::net::SocketAddrV6::new(
221 std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
222 8080,
223 0,
224 0
225 ))
226));