1use std::{
9 fmt,
10 net::{SocketAddr, SocketAddrV4, SocketAddrV6},
11};
12
13#[derive(PartialEq, Debug)]
14pub enum ProxyProtocolHeader {
15 V1(HeaderV1),
16 V2(HeaderV2),
17}
18
19impl ProxyProtocolHeader {
20 pub fn into_bytes(&self) -> Vec<u8> {
22 match *self {
23 ProxyProtocolHeader::V1(ref header) => header.into_bytes(),
24 ProxyProtocolHeader::V2(ref header) => header.into_bytes(),
25 }
26 }
27}
28
/// Family token written after the `PROXY` identifier in a version 1 header.
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolSupportedV1 {
    /// TCP over IPv4.
    TCP4,
    /// TCP over IPv6.
    TCP6,
    /// Unknown/unsupported family; the header line then carries no addresses.
    UNKNOWN,
}
36
37impl fmt::Display for ProtocolSupportedV1 {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 match *self {
40 ProtocolSupportedV1::TCP4 => write!(f, "TCP4"),
41 ProtocolSupportedV1::TCP6 => write!(f, "TCP6"),
42 ProtocolSupportedV1::UNKNOWN => write!(f, "UNKNOWN"),
43 }
44 }
45}
46
47#[derive(Debug, PartialEq, Eq)]
56pub struct HeaderV1 {
57 pub protocol: ProtocolSupportedV1,
58 pub addr_src: SocketAddr,
59 pub addr_dst: SocketAddr,
60}
61
/// Literal that opens every version 1 header line.
const PROXY_PROTO_IDENTIFIER: &str = "PROXY";
63
64impl HeaderV1 {
65 pub fn new(addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
66 let protocol = if addr_dst.is_ipv6() {
67 ProtocolSupportedV1::TCP6
68 } else if addr_dst.is_ipv4() {
69 ProtocolSupportedV1::TCP4
70 } else {
71 ProtocolSupportedV1::UNKNOWN
72 };
73
74 HeaderV1 {
75 protocol,
76 addr_src,
77 addr_dst,
78 }
79 }
80
81 pub fn into_bytes(&self) -> Vec<u8> {
82 let bytes = if self.protocol.eq(&ProtocolSupportedV1::UNKNOWN) {
83 format!("{} {}\r\n", PROXY_PROTO_IDENTIFIER, self.protocol,).into_bytes()
84 } else {
85 format!(
86 "{} {} {} {} {} {}\r\n",
87 PROXY_PROTO_IDENTIFIER,
88 self.protocol,
89 self.addr_src.ip(),
90 self.addr_dst.ip(),
91 self.addr_src.port(),
92 self.addr_dst.port(),
93 )
94 .into_bytes()
95 };
96 debug_assert!(
99 bytes.starts_with(PROXY_PROTO_IDENTIFIER.as_bytes()),
100 "v1 header must start with the PROXY identifier"
101 );
102 debug_assert!(
103 bytes.ends_with(b"\r\n"),
104 "v1 header must be CRLF-terminated"
105 );
106 bytes
107 }
108}
109
/// Command carried in the low nibble of a version 2 header's version/command byte.
#[derive(Debug, PartialEq, Eq)]
pub enum Command {
    /// Encoded as `0x0`.
    Local,
    /// Encoded as `0x1`.
    Proxy,
}
148
149#[derive(Debug, PartialEq)]
150pub struct HeaderV2 {
151 pub command: Command,
152 pub family: u8, pub addr: ProxyAddr,
154}
155
156impl HeaderV2 {
157 pub fn new(command: Command, addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
158 let addr = ProxyAddr::from(addr_src, addr_dst);
159 let family = get_family(&addr);
160
161 debug_assert_eq!(
165 family,
166 get_family(&addr),
167 "cached family must match the address it describes"
168 );
169 debug_assert!(
170 matches!(addr, ProxyAddr::AfUnspec) == (family == 0x00),
171 "AfUnspec iff zero family byte"
172 );
173
174 HeaderV2 {
175 command,
176 family,
177 addr,
178 }
179 }
180
181 pub fn into_bytes(&self) -> Vec<u8> {
182 let expected_len = self.len();
183 let addr_len = self.addr.len() as usize;
184 let mut header = Vec::with_capacity(expected_len);
185
186 let signature = [
187 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
188 ];
189 header.extend_from_slice(&signature);
190 debug_assert_eq!(
191 header.len(),
192 signature.len(),
193 "v2 header must open with exactly the 12-byte signature"
194 );
195
196 let command = match self.command {
197 Command::Local => 0,
198 Command::Proxy => 1,
199 };
200 let ver_and_cmd = 0x20 | command;
201 header.push(ver_and_cmd);
202
203 header.push(self.family);
204 header.extend_from_slice(&u16_to_array_of_u8(self.addr.len()));
205 debug_assert_eq!(
208 header.len(),
209 16,
210 "v2 fixed prefix (signature + ver/cmd + family + length) must be 16 bytes"
211 );
212 self.addr.write_bytes_to(&mut header);
213 debug_assert_eq!(
216 header.len(),
217 expected_len,
218 "serialized v2 header length must match HeaderV2::len()"
219 );
220 debug_assert_eq!(
221 header.len(),
222 16 + addr_len,
223 "serialized v2 header must be the 16-byte prefix plus the address block"
224 );
225 header
226 }
227
228 pub fn len(&self) -> usize {
229 let total = 12 + 1 + 1 + 2 + self.addr.len() as usize;
231 debug_assert!(
234 total >= 16,
235 "v2 header is at least its 16-byte fixed prefix"
236 );
237 debug_assert!(
238 total <= 16 + 216,
239 "v2 header never exceeds the 16-byte prefix plus the largest (unix) address block"
240 );
241 total
242 }
243
244 pub fn is_empty(&self) -> bool {
245 0 == self.len()
246 }
247}
248
/// Address block of a version 2 header.
pub enum ProxyAddr {
    /// IPv4 source/destination pair (12-byte block).
    Ipv4Addr {
        src_addr: SocketAddrV4,
        dst_addr: SocketAddrV4,
    },
    /// IPv6 source/destination pair (36-byte block).
    Ipv6Addr {
        src_addr: SocketAddrV6,
        dst_addr: SocketAddrV6,
    },
    /// Unix socket addresses as two raw 108-byte buffers (216-byte block).
    UnixAddr {
        src_addr: [u8; 108],
        dst_addr: [u8; 108],
    },
    /// No address information (empty block).
    AfUnspec,
}
264
265impl ProxyAddr {
266 pub fn from(addr_src: SocketAddr, addr_dst: SocketAddr) -> Self {
267 let addr = match (addr_src, addr_dst) {
268 (SocketAddr::V4(addr_ipv4_src), SocketAddr::V4(addr_ipv4_dst)) => ProxyAddr::Ipv4Addr {
269 src_addr: addr_ipv4_src,
270 dst_addr: addr_ipv4_dst,
271 },
272 (SocketAddr::V6(addr_ipv6_src), SocketAddr::V6(addr_ipv6_dst)) => ProxyAddr::Ipv6Addr {
273 src_addr: addr_ipv6_src,
274 dst_addr: addr_ipv6_dst,
275 },
276 _ => ProxyAddr::AfUnspec,
277 };
278 debug_assert_eq!(
282 matches!(addr, ProxyAddr::Ipv4Addr { .. }),
283 addr_src.is_ipv4() && addr_dst.is_ipv4(),
284 "Ipv4Addr variant iff both endpoints are IPv4"
285 );
286 debug_assert_eq!(
287 matches!(addr, ProxyAddr::Ipv6Addr { .. }),
288 addr_src.is_ipv6() && addr_dst.is_ipv6(),
289 "Ipv6Addr variant iff both endpoints are IPv6"
290 );
291 addr
292 }
293
294 fn len(&self) -> u16 {
295 match *self {
296 ProxyAddr::Ipv4Addr { .. } => 12,
297 ProxyAddr::Ipv6Addr { .. } => 36,
298 ProxyAddr::UnixAddr { .. } => 216,
299 ProxyAddr::AfUnspec => 0,
300 }
301 }
302
303 pub fn source(&self) -> Option<SocketAddr> {
304 match self {
305 ProxyAddr::Ipv4Addr { src_addr: src, .. } => Some(SocketAddr::V4(*src)),
306 ProxyAddr::Ipv6Addr { src_addr: src, .. } => Some(SocketAddr::V6(*src)),
307 _ => None,
308 }
309 }
310
311 pub fn destination(&self) -> Option<SocketAddr> {
312 match self {
313 ProxyAddr::Ipv4Addr { dst_addr: dst, .. } => Some(SocketAddr::V4(*dst)),
314 ProxyAddr::Ipv6Addr { dst_addr: dst, .. } => Some(SocketAddr::V6(*dst)),
315 _ => None,
316 }
317 }
318
319 fn write_bytes_to(&self, buf: &mut Vec<u8>) {
321 let before = buf.len();
322 let declared = self.len() as usize;
323 match *self {
324 ProxyAddr::Ipv4Addr { src_addr, dst_addr } => {
325 buf.extend_from_slice(&src_addr.ip().octets());
326 buf.extend_from_slice(&dst_addr.ip().octets());
327 buf.extend_from_slice(&u16_to_array_of_u8(src_addr.port()));
328 buf.extend_from_slice(&u16_to_array_of_u8(dst_addr.port()));
329 }
330 ProxyAddr::Ipv6Addr { src_addr, dst_addr } => {
331 buf.extend_from_slice(&src_addr.ip().octets());
332 buf.extend_from_slice(&dst_addr.ip().octets());
333 buf.extend_from_slice(&u16_to_array_of_u8(src_addr.port()));
334 buf.extend_from_slice(&u16_to_array_of_u8(dst_addr.port()));
335 }
336 ProxyAddr::UnixAddr { src_addr, dst_addr } => {
337 buf.extend_from_slice(&src_addr);
338 buf.extend_from_slice(&dst_addr);
339 }
340 ProxyAddr::AfUnspec => {}
341 };
342 debug_assert!(
347 buf.len() >= before,
348 "write_bytes_to must never shrink the buffer"
349 );
350 debug_assert_eq!(
351 buf.len() - before,
352 declared,
353 "appended address bytes must equal the declared ProxyAddr::len()"
354 );
355 }
356}
357
358impl fmt::Debug for ProxyAddr {
360 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
361 match *self {
362 ProxyAddr::Ipv4Addr { src_addr, dst_addr } => {
363 write!(f, "{dst_addr:?} {src_addr:?}")
364 }
365 ProxyAddr::Ipv6Addr { src_addr, dst_addr } => {
366 write!(f, "{dst_addr:?} {src_addr:?}")
367 }
368 ProxyAddr::UnixAddr { src_addr, dst_addr } => {
369 write!(f, "{:?} {:?}", &dst_addr[..], &src_addr[..])
370 }
371 ProxyAddr::AfUnspec => write!(f, "AFUNSPEC"),
372 }
373 }
374}
375
376impl PartialEq for ProxyAddr {
378 fn eq(&self, other: &ProxyAddr) -> bool {
379 match *self {
380 ProxyAddr::Ipv4Addr { src_addr, dst_addr } => match other {
381 ProxyAddr::Ipv4Addr {
382 src_addr: src_other,
383 dst_addr: dst_other,
384 } => *src_other == src_addr && *dst_other == dst_addr,
385 _ => false,
386 },
387 ProxyAddr::Ipv6Addr { src_addr, dst_addr } => match other {
388 ProxyAddr::Ipv6Addr {
389 src_addr: src_other,
390 dst_addr: dst_other,
391 } => *src_other == src_addr && *dst_other == dst_addr,
392 _ => false,
393 },
394 ProxyAddr::UnixAddr { src_addr, dst_addr } => match other {
395 ProxyAddr::UnixAddr {
396 src_addr: src_other,
397 dst_addr: dst_other,
398 } => src_other[..] == src_addr[..] && dst_other[..] == dst_addr[..],
399 _ => false,
400 },
401 ProxyAddr::AfUnspec => {
402 if let ProxyAddr::AfUnspec = other {
403 return true;
404 }
405 false
406 }
407 }
408 }
409}
410
411fn get_family(addr: &ProxyAddr) -> u8 {
412 let family = match *addr {
413 ProxyAddr::Ipv4Addr { .. } => 0x10 | 0x01, ProxyAddr::Ipv6Addr { .. } => 0x20 | 0x01, ProxyAddr::UnixAddr { .. } => 0x30 | 0x01, ProxyAddr::AfUnspec => 0x00, };
418 debug_assert!(
422 (family >> 4) <= 0x03,
423 "address family nibble must be one of AF_UNSPEC/INET/INET6/UNIX"
424 );
425 debug_assert!(
426 matches!(addr, ProxyAddr::AfUnspec) == (family == 0x00),
427 "only AfUnspec maps to the all-zero family byte"
428 );
429 debug_assert!(
430 matches!(addr, ProxyAddr::AfUnspec) || (family & 0x0f) == 0x01,
431 "concrete address families must advertise the STREAM transport"
432 );
433 family
434}
435
/// Splits `x` into its two bytes in network (big-endian) order: high byte first.
fn u16_to_array_of_u8(x: u16) -> [u8; 2] {
    x.to_be_bytes()
}
450
#[cfg(test)]
mod test_v2 {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;

    /// "\r\n\r\n\0\r\nQUIT\n": the signature every v2 header starts with.
    const V2_SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    /// Builds an expected v2 header: the signature followed by `rest`.
    fn with_signature(rest: &[u8]) -> Vec<u8> {
        let mut bytes = V2_SIGNATURE.to_vec();
        bytes.extend_from_slice(rest);
        bytes
    }

    #[test]
    fn test_u16_to_array_of_u8() {
        assert_eq!(u16_to_array_of_u8(0xfffe), [0xff, 0xfe]);
    }

    #[test]
    fn test_serialize_tcp_ipv4_proxy_protocol_header() {
        let src = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
        let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);

        let expected = with_signature(&[
            0x20, 0x11, 0x00, 0x0C, // version 2 | LOCAL, INET | STREAM, 12-byte block
            0x7D, 0x19, 0x0A, 0x01, // source 125.25.10.1
            0x0A, 0x04, 0x05, 0x08, // destination 10.4.5.8
            0x1F, 0x90, 0x10, 0x68, // ports 8080 and 4200
        ]);

        assert_eq!(HeaderV2::new(Command::Local, src, dst).into_bytes(), expected);
    }

    #[test]
    fn test_serialize_tcp_ipv6_proxy_protocol_header() {
        let loopback = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
        let src = SocketAddr::new(IpAddr::V6(loopback), 8080);
        let dst = SocketAddr::new(IpAddr::V6(loopback), 4200);

        // ::1 in network order: fifteen zero bytes followed by 0x01.
        let mut loopback_bytes = [0u8; 16];
        loopback_bytes[15] = 0x01;

        // version 2 | PROXY, INET6 | STREAM, 36-byte block
        let mut rest = vec![0x21, 0x21, 0x00, 0x24];
        rest.extend_from_slice(&loopback_bytes); // source
        rest.extend_from_slice(&loopback_bytes); // destination
        rest.extend_from_slice(&[0x1F, 0x90, 0x10, 0x68]); // ports 8080 and 4200

        assert_eq!(
            HeaderV2::new(Command::Proxy, src, dst).into_bytes(),
            with_signature(&rest)
        );
    }
}