1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
9
10use nom::{
11 Err, IResult,
12 bytes::streaming::{tag, take},
13 error::{Error, ErrorKind, ParseError},
14 number::streaming::{be_u8, be_u16},
15};
16
17use crate::protocol::proxy_protocol::header::{Command, HeaderV2, ProxyAddr};
18
/// The fixed 12-byte signature that must open every v2 header parsed by
/// `parse_v2_header`: `\r\n\r\n\0\r\nQUIT\n`.
const PROTOCOL_SIGNATURE_V2: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
22
23fn parse_command(i: &[u8]) -> IResult<&[u8], Command> {
24 let i2 = i;
25 let (i, cmd) = be_u8(i)?;
26 debug_assert_eq!(
29 i.len() + 1,
30 i2.len(),
31 "command byte parse consumes exactly one byte"
32 );
33 match cmd {
34 0x20 => Ok((i, Command::Local)),
35 0x21 => Ok((i, Command::Proxy)),
36 _ => Err(Err::Error(Error::from_error_kind(i2, ErrorKind::Switch))),
37 }
38}
39
40pub fn parse_v2_header(i: &[u8]) -> IResult<&[u8], HeaderV2> {
41 let input_len = i.len();
42 let (i, _) = tag(&PROTOCOL_SIGNATURE_V2)(i)?;
43 let (i, command) = parse_command(i)?;
44 let (i, family) = be_u8(i)?;
45 let (i, len) = be_u16(i)?;
46 let (i, data) = take(len)(i)?;
47 debug_assert_eq!(
50 data.len(),
51 len as usize,
52 "address block must be exactly the declared length"
53 );
54 let (data_rest, addr) = parse_addr_v2(family)(data)?;
55 debug_assert!(
59 data_rest.len() <= data.len(),
60 "address parser cannot grow its input"
61 );
62
63 debug_assert!(i.len() <= input_len, "parser cannot grow its input");
67 debug_assert_eq!(
68 input_len - i.len(),
69 PROTOCOL_SIGNATURE_V2.len() + 1 + 1 + 2 + len as usize,
70 "consumed header length must reconcile with the signature, fixed fields, and declared address length"
71 );
72 debug_assert!(
76 matches!((family >> 4) & 0x0f, 0x00..=0x02),
77 "accepted family nibble must be AF_UNSPEC, AF_INET, or AF_INET6"
78 );
79
80 Ok((
81 i,
82 (HeaderV2 {
83 command,
84 family,
85 addr,
86 }),
87 ))
88}
89
90fn parse_addr_v2(family: u8) -> impl Fn(&[u8]) -> IResult<&[u8], ProxyAddr> {
91 move |i: &[u8]| match (family >> 4) & 0x0f {
92 0x00 => Ok((i, ProxyAddr::AfUnspec)),
93 0x01 => parse_ipv4_on_v2(i),
94 0x02 => parse_ipv6_on_v2(i),
95 _ => Err(Err::Error(Error::from_error_kind(i, ErrorKind::Switch))),
96 }
97}
98
99fn parse_ipv4_on_v2(i: &[u8]) -> IResult<&[u8], ProxyAddr> {
100 let in_len = i.len();
101 let (i, src_ip) = take(4u8)(i)?;
102 let (i, dest_ip) = take(4u8)(i)?;
103 let (i, src_port) = be_u16(i)?;
104 let (i, dest_port) = be_u16(i)?;
105 debug_assert_eq!(src_ip.len(), 4, "IPv4 source address is 4 bytes");
108 debug_assert_eq!(dest_ip.len(), 4, "IPv4 destination address is 4 bytes");
109 debug_assert!(i.len() <= in_len, "parser cannot grow its input");
110 debug_assert_eq!(
111 in_len - i.len(),
112 12,
113 "IPv4 v2 address block is exactly 12 bytes"
114 );
115
116 Ok((
117 i,
118 ProxyAddr::Ipv4Addr {
119 src_addr: SocketAddrV4::new(
120 Ipv4Addr::new(src_ip[0], src_ip[1], src_ip[2], src_ip[3]),
121 src_port,
122 ),
123 dst_addr: SocketAddrV4::new(
124 Ipv4Addr::new(dest_ip[0], dest_ip[1], dest_ip[2], dest_ip[3]),
125 dest_port,
126 ),
127 },
128 ))
129}
130
131fn parse_ipv6_on_v2(i: &[u8]) -> IResult<&[u8], ProxyAddr> {
132 let in_len = i.len();
133 let (i, src_ip) = take(16u8)(i)?;
134 let (i, dest_ip) = take(16u8)(i)?;
135 let (i, src_port) = be_u16(i)?;
136 let (i, dest_port) = be_u16(i)?;
137 debug_assert_eq!(src_ip.len(), 16, "IPv6 source address is 16 bytes");
141 debug_assert_eq!(dest_ip.len(), 16, "IPv6 destination address is 16 bytes");
142 debug_assert!(i.len() <= in_len, "parser cannot grow its input");
143 debug_assert_eq!(
144 in_len - i.len(),
145 36,
146 "IPv6 v2 address block is exactly 36 bytes"
147 );
148
149 Ok((
150 i,
151 ProxyAddr::Ipv6Addr {
152 src_addr: SocketAddrV6::new(slice_to_ipv6(src_ip), src_port, 0, 0),
153 dst_addr: SocketAddrV6::new(slice_to_ipv6(dest_ip), dest_port, 0, 0),
154 },
155 ))
156}
157
/// Converts a 16-byte slice into an [`Ipv6Addr`], taking the bytes as the
/// address octets in order (most significant first).
///
/// # Panics
///
/// Panics if `sl.len() != 16`. This happens in both debug and release builds,
/// always with the message `slice_to_ipv6 requires exactly 16 bytes`.
pub fn slice_to_ipv6(sl: &[u8]) -> Ipv6Addr {
    // A real assert rather than `debug_assert`: release builds would otherwise
    // panic inside the copy below with a generic length-mismatch message.
    assert_eq!(sl.len(), 16, "slice_to_ipv6 requires exactly 16 bytes");
    let mut octets = [0u8; 16];
    // `u8` is `Copy`, so `copy_from_slice` (a plain memcpy) is the idiomatic copy.
    octets.copy_from_slice(sl);
    Ipv6Addr::from(octets)
}
169
#[cfg(test)]
mod test {

    use std::net::{IpAddr, SocketAddr};

    use nom::{Err, Needed};

    use super::*;

    /// The 12-byte v2 signature, spelled out independently of
    /// `PROTOCOL_SIGNATURE_V2` so a mistake in that constant is caught here.
    const SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    /// Returns `SIGNATURE` followed by `rest`.
    fn with_signature(rest: &[u8]) -> Vec<u8> {
        [&SIGNATURE[..], rest].concat()
    }

    /// Asserts that parsing fails with a hard `Err::Error` of the given kind.
    ///
    /// A bare `is_err()` check is not enough: it also accepts
    /// `Err::Incomplete` ("need more bytes"). Several negative inputs below
    /// end right after the byte under test, so an `is_err()` check would
    /// still pass even if that byte were wrongly accepted.
    fn assert_rejected(input: &[u8], kind: ErrorKind) {
        match parse_v2_header(input) {
            Err(Err::Error(e)) => assert_eq!(e.code, kind, "unexpected error kind"),
            other => panic!("expected Err::Error({kind:?}), got {other:?}"),
        }
    }

    #[test]
    fn test_parse_proxy_protocol_v2_local_ipv4_addr_header() {
        let input = with_signature(&[
            0x20, // command byte: LOCAL
            0x11, // family: IPv4 (high nibble 1)
            0x00, 0x0C, // address block length: 12
            0x7D, 0x19, 0x0A, 0x01, // source 125.25.10.1
            0x0A, 0x04, 0x05, 0x08, // destination 10.4.5.8
            0x1F, 0x90, // source port 8080
            0x10, 0x68, // destination port 4200
        ]);

        let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
        let dst_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);
        let expected = HeaderV2::new(Command::Local, src_addr, dst_addr);

        assert_eq!(Ok((&[][..], expected)), parse_v2_header(&input));
    }

    #[test]
    fn test_parse_proxy_protocol_v2_proxy_ipv4_addr_header() {
        let input = with_signature(&[
            0x21, // command byte: PROXY
            0x11, // family: IPv4 (high nibble 1)
            0x00, 0x0C, // address block length: 12
            0x7D, 0x19, 0x0A, 0x01, // source 125.25.10.1
            0x0A, 0x04, 0x05, 0x08, // destination 10.4.5.8
            0x1F, 0x90, // source port 8080
            0x10, 0x68, // destination port 4200
        ]);

        let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
        let dst_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);
        let expected = HeaderV2::new(Command::Proxy, src_addr, dst_addr);

        assert_eq!(Ok((&[][..], expected)), parse_v2_header(&input));
    }

    #[test]
    fn it_should_parse_proxy_protocol_v2_ipv6_addr_header() {
        let input = with_signature(&[
            0x20, // command byte: LOCAL
            0x21, // family: IPv6 (high nibble 2)
            0x00, 0x24, // address block length: 36
            // source ::1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, //
            // destination ::2
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, //
            0x1F, 0x90, // source port 8080
            0x10, 0x68, // destination port 4200
        ]);

        let src_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080);
        let dst_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)), 4200);
        let expected = HeaderV2::new(Command::Local, src_addr, dst_addr);

        assert_eq!(Ok((&[][..], expected)), parse_v2_header(&input));
    }

    #[test]
    fn it_should_parse_proxy_protocol_v2_afunspec_header() {
        let input = with_signature(&[
            0x20, // command byte: LOCAL
            0x00, // family: unspecified
            0x00, 0x00, // address block length: 0
        ]);

        let expected = HeaderV2 {
            command: Command::Local,
            family: 0,
            addr: ProxyAddr::AfUnspec,
        };

        assert_eq!(Ok((&[][..], expected)), parse_v2_header(&input));
    }

    #[test]
    fn it_should_not_parse_proxy_protocol_v2_with_unknown_version() {
        let unknown_version = 0x30;
        assert_rejected(&with_signature(&[unknown_version]), ErrorKind::Switch);
    }

    #[test]
    fn it_should_not_parse_proxy_protocol_v2_with_unknown_command() {
        let unknown_command = 0x23;
        assert_rejected(&with_signature(&[unknown_command]), ErrorKind::Switch);
    }

    #[test]
    fn it_should_not_parse_proxy_protocol_with_unknown_family() {
        let unknown_family = 0x30;
        assert_rejected(
            &with_signature(&[0x20, unknown_family, 0x00, 0x00]),
            ErrorKind::Switch,
        );
    }

    #[test]
    fn it_should_not_parse_request_without_magic_header() {
        assert_rejected(&[0x0D; 12], ErrorKind::Tag);
    }

    #[test]
    fn it_should_not_parse_proxy_protocol_v2_ipv4_addr_header_with_missing_data() {
        // The two length bytes are missing entirely.
        let input = with_signature(&[0x20, 0x11]);

        assert_eq!(Err(Err::Incomplete(Needed::new(2))), parse_v2_header(&input));
    }

    #[test]
    fn it_should_not_parse_proxy_protocol_v2_with_invalid_length() {
        // The declared length (16) is shorter than an IPv6 block (36).
        // NOTE(review): the parser reports this as `Incomplete` (the
        // destination address is missing from the 16-byte block) instead of
        // a hard error, even though more input cannot fix it. This test pins
        // the current behaviour.
        let input = with_signature(&[
            0x20, // command byte: LOCAL
            0x21, // family: IPv6 (high nibble 2)
            0x00, 0x10, // address block length: 16 (too short)
            // source ::1
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, //
            // destination ::2 (outside the declared block)
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, //
            0x1F, 0x90, // source port 8080
            0x10, 0x68, // destination port 4200
        ]);

        assert_eq!(
            Err(Err::Incomplete(Needed::new(16))),
            parse_v2_header(&input)
        );
    }
}