1use std::{cell::RefCell, rc::Rc};
10
11use mio::{net::TcpStream, *};
12use nom::{Err, HexDisplay};
13use rusty_ulid::Ulid;
14use sozu_command::{
15 config::MAX_LOOP_ITERATIONS,
16 logging::{LogContext, ansi_palette},
17};
18
19use super::{header::ProxyAddr, parser::parse_v2_header};
20use crate::metrics::names;
21use crate::{
22 Protocol, Readiness, SessionMetrics, StateResult,
23 pool::Checkout,
24 protocol::{
25 SessionResult, SessionState,
26 pipe::{Pipe, WebSocketContext},
27 },
28 socket::{SocketHandler, SocketResult},
29 sozu_command::ready::Ready,
30 tcp::TcpListener,
31 timer::TimeoutContainer,
32};
33
/// Builds the log prefix used where no session is available (for example when a
/// timeout arrives for an unknown token): the colored `PROXY-EXPECT` tag.
macro_rules! log_module_context {
    () => {{
        // Only the "open" (index 0) and "reset" (index 1) colors of the palette are needed.
        let palette = ansi_palette();
        format!("{}PROXY-EXPECT{}\t >>>", palette.0, palette.1)
    }};
}
48
/// Builds the per-session log prefix.
///
/// The prefix contains the session's `LogContext`, the `PROXY-EXPECT` tag, and the
/// frontend token, read cursor (`index`) and frontend readiness of `$self`.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        // Snapshot the session fields first so the format call below only deals with locals.
        let session_ctx = $self.log_context();
        let frontend_id = $self.frontend_token.0;
        let cursor = $self.index;
        let frontend_readiness = &$self.frontend_readiness;
        format!(
            "{gray}{ctx}{reset}\t{open}PROXY-EXPECT{reset}\t{grey}Session{reset}({gray}frontend{reset}={white}{frontend}{reset}, {gray}index{reset}={white}{index}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>",
            open = open,
            reset = reset,
            grey = grey,
            gray = gray,
            white = white,
            ctx = session_ctx,
            frontend = frontend_id,
            index = cursor,
            readiness = frontend_readiness,
        )
    }};
}
71
/// Read stage of the PROXY protocol v2 header.
///
/// The header is accumulated in stages so that the reader only asks the socket for as
/// many bytes as the current stage allows: `V4` reads up to 28 bytes, `V6` up to 52 bytes
/// and `Unix` up to 232 bytes. The stage advances when it has been filled and the header
/// still parses as incomplete.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HeaderLen {
    /// First stage: 28 bytes (16-byte fixed header + 12-byte IPv4 address block).
    V4,
    /// Second stage: 52 bytes (16-byte fixed header + 36-byte IPv6 address block).
    V6,
    /// Last stage: 232 bytes (16-byte fixed header + 216-byte Unix address block).
    Unix,
}
78
/// Session state that reads a PROXY protocol v2 header from a frontend connection.
///
/// Once the header is parsed, the state can be turned into a TCP `Pipe` with `into_pipe`.
pub struct ExpectProxyProtocol<Front: SocketHandler> {
    /// Addresses from the parsed PROXY header; `None` until a header has been parsed.
    pub addresses: Option<ProxyAddr>,
    /// Timeout attached to the frontend connection; handed to the pipe by `into_pipe`.
    pub container_frontend_timeout: TimeoutContainer,
    /// Fixed-size accumulation buffer for header bytes; 232 bytes is the largest read stage.
    frontend_buffer: [u8; 232],
    /// Interest/event readiness of the frontend socket.
    pub frontend_readiness: Readiness,
    /// Token identifying the frontend socket; events for it are merged by `update_readiness`.
    pub frontend_token: Token,
    /// The frontend socket the header is read from.
    pub frontend: Front,
    /// Current read stage: decides how many bytes `readable` tries to accumulate.
    header_len: HeaderLen,
    /// Number of header bytes accumulated so far in `frontend_buffer`.
    index: usize,
    /// Identifier used as the session id in log contexts and passed on to the pipe.
    pub request_id: Ulid,
}
91
92impl<Front: SocketHandler> ExpectProxyProtocol<Front> {
93 pub fn new(
97 container_frontend_timeout: TimeoutContainer,
98 frontend: Front,
99 frontend_token: Token,
100 request_id: Ulid,
101 ) -> Self {
102 ExpectProxyProtocol {
103 addresses: None,
104 container_frontend_timeout,
105 frontend_buffer: [0; 232],
106 frontend_readiness: Readiness {
107 interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
108 event: Ready::EMPTY,
109 },
110 frontend_token,
111 frontend,
112 header_len: HeaderLen::V4,
113 index: 0,
114 request_id,
115 }
116 }
117
118 pub fn readable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
119 let total_len = match self.header_len {
120 HeaderLen::V4 => 28,
121 HeaderLen::V6 => 52,
122 HeaderLen::Unix => 232,
123 };
124
125 debug_assert!(
131 self.index <= total_len,
132 "read cursor must not exceed the current stage target"
133 );
134 debug_assert!(
135 total_len <= self.frontend_buffer.len(),
136 "stage target must fit the fixed proxy-protocol buffer"
137 );
138
139 let index_before = self.index;
140 let (sz, socket_result) = self
141 .frontend
142 .socket_read(&mut self.frontend_buffer[self.index..total_len]);
143 debug_assert!(
146 sz <= total_len - index_before,
147 "socket_read cannot return more bytes than the slice it was given"
148 );
149 trace!(
150 "{} read {} bytes and res={:?}, total_len = {}",
151 log_context!(self),
152 sz,
153 socket_result,
154 total_len
155 );
156
157 if sz > 0 {
158 self.index += sz;
159 debug_assert_eq!(
163 self.index,
164 index_before + sz,
165 "read cursor advances by exactly the bytes just read"
166 );
167 debug_assert!(
168 self.index <= self.frontend_buffer.len(),
169 "accumulated bytes must never exceed the fixed buffer bound"
170 );
171
172 count!(names::backend::BYTES_IN, sz as i64);
173 metrics.bin += sz;
174
175 if self.index == self.frontend_buffer.len() {
176 self.frontend_readiness.interest.remove(Ready::READABLE);
177 }
178 } else {
179 debug_assert_eq!(
180 self.index, index_before,
181 "a non-positive read must leave the cursor unchanged"
182 );
183 self.frontend_readiness.event.remove(Ready::READABLE);
184 }
185
186 match socket_result {
187 SocketResult::Error => {
188 error!(
189 "{} front socket error, closing the connection (read {}, wrote {})",
190 log_context!(self),
191 metrics.bin,
192 metrics.bout
193 );
194 incr!(names::proxy_protocol::ERRORS);
195 self.frontend_readiness.reset();
196 return SessionResult::Close;
197 }
198 SocketResult::WouldBlock => {
199 self.frontend_readiness.event.remove(Ready::READABLE);
200 }
201 SocketResult::Closed => {
202 if self.index == 0 {
208 trace!(
209 "{} socket closed with 0 bytes, closing session",
210 log_context!(self)
211 );
212 return SessionResult::Close;
213 }
214 }
215 SocketResult::Continue => {}
216 }
217
218 match parse_v2_header(&self.frontend_buffer[..self.index]) {
219 Ok((rest, header)) => {
220 debug_assert!(
225 rest.len() <= self.index,
226 "parser remainder cannot exceed the accumulated input"
227 );
228 trace!(
229 "{} got expect header: {:?}, rest.len() = {}",
230 log_context!(self),
231 header,
232 rest.len()
233 );
234 self.addresses = Some(header.addr);
235 SessionResult::Upgrade
236 }
237 Err(Err::Incomplete(_)) => {
238 match self.header_len {
239 HeaderLen::V4 => {
240 if self.index == 28 {
241 self.header_len = HeaderLen::V6;
242 }
243 }
244 HeaderLen::V6 => {
245 if self.index == 52 {
246 self.header_len = HeaderLen::Unix;
247 }
248 }
249 HeaderLen::Unix => {
250 if self.index == 232 {
251 error!(
252 "{} proxy protocol header exceeds maximum size (232 bytes), closing",
253 log_context!(self)
254 );
255 incr!(names::proxy_protocol::ERRORS);
256 self.frontend_readiness.reset();
257 return SessionResult::Close;
258 }
259 }
260 };
261 SessionResult::Continue
262 }
263 Err(Err::Error(e)) | Err(Err::Failure(e)) => {
264 error!(
265 "{} parse error, closing the connection:\n{}",
266 log_context!(self),
267 e.input.to_hex(16)
268 );
269 incr!(names::proxy_protocol::ERRORS);
270 self.frontend_readiness.reset();
271 SessionResult::Close
272 }
273 }
274 }
275
276 pub fn front_socket(&self) -> &TcpStream {
277 self.frontend.socket_ref()
278 }
279
280 pub fn into_pipe(
281 self,
282 front_buf: Checkout,
283 back_buf: Checkout,
284 backend_socket: Option<TcpStream>,
285 backend_token: Option<Token>,
286 listener: Rc<RefCell<TcpListener>>,
287 ) -> Pipe<Front, TcpListener> {
288 let addr = self
296 .addresses
297 .as_ref()
298 .and_then(|pa| pa.source())
299 .or_else(|| self.front_socket().peer_addr().ok());
300
301 let mut pipe = Pipe::new(
302 back_buf,
303 None,
304 backend_socket,
305 None,
306 None,
307 Some(self.container_frontend_timeout),
308 None,
309 front_buf,
310 self.frontend_token,
311 self.frontend,
312 listener,
313 Protocol::TCP,
314 self.request_id,
315 self.request_id,
316 addr,
317 WebSocketContext::Tcp,
318 );
319
320 pipe.frontend_readiness.event = self.frontend_readiness.event;
321
322 if let Some(backend_token) = backend_token {
323 pipe.set_back_token(backend_token);
324 }
325
326 pipe
327 }
328
329 pub fn log_context(&self) -> LogContext<'_> {
330 LogContext {
331 session_id: self.request_id,
332 request_id: None,
333 cluster_id: None,
334 backend_id: None,
335 }
336 }
337}
338
339impl<Front: SocketHandler> SessionState for ExpectProxyProtocol<Front> {
340 fn ready(
341 &mut self,
342 _session: Rc<RefCell<dyn crate::ProxySession>>,
343 _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
344 metrics: &mut SessionMetrics,
345 ) -> SessionResult {
346 let mut counter = 0;
347
348 if self.frontend_readiness.event.is_hup() {
349 return SessionResult::Close;
350 }
351
352 while counter < MAX_LOOP_ITERATIONS {
353 let frontend_interest = self.frontend_readiness.filter_interest();
354
355 trace!(
356 "{} {:?} -> None",
357 log_context!(self),
358 self.frontend_readiness
359 );
360
361 if frontend_interest.is_empty() {
362 break;
363 }
364
365 if frontend_interest.is_readable() {
366 let session_result = self.readable(metrics);
367 if session_result != SessionResult::Continue {
368 return session_result;
369 }
370 }
371
372 if frontend_interest.is_error() {
373 error!("{} front error, disconnecting", log_context!(self));
374 self.frontend_readiness.interest = Ready::EMPTY;
375
376 return SessionResult::Close;
377 }
378
379 let counter_before = counter;
380 counter += 1;
381 debug_assert_eq!(counter, counter_before + 1, "loop counter advances by one");
385 debug_assert!(
386 counter <= MAX_LOOP_ITERATIONS,
387 "loop counter must stay within the iteration cap"
388 );
389 }
390
391 if counter >= MAX_LOOP_ITERATIONS {
392 error!(
393 "{} handling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
394 log_context!(self),
395 MAX_LOOP_ITERATIONS
396 );
397 incr!(names::http::INFINITE_LOOP_ERROR);
398
399 self.print_state("");
400
401 return SessionResult::Close;
402 }
403
404 SessionResult::Continue
405 }
406
407 fn update_readiness(&mut self, token: Token, events: Ready) {
408 if self.frontend_token == token {
409 self.frontend_readiness.event |= events;
410 }
411 }
412
413 fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
414 if self.frontend_token == token {
415 self.container_frontend_timeout.triggered();
416 return StateResult::CloseSession;
417 }
418
419 error!(
420 "{} got timeout for an invalid token: {:?}",
421 log_module_context!(),
422 token
423 );
424 StateResult::CloseSession
425 }
426
427 fn cancel_timeouts(&mut self) {
428 self.container_frontend_timeout.cancel();
429 }
430
431 fn print_state(&self, context: &str) {
432 error!(
433 "{} {} Session(Expect)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
434 log_context!(self),
435 context,
436 self.frontend_token,
437 self.frontend_readiness
438 );
439 }
440}
441
#[cfg(test)]
mod expect_test {
    use std::{
        io::Write,
        net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream as StdTcpStream},
        sync::{Arc, Barrier},
        thread::{self, JoinHandle},
        time::Duration,
    };

    use mio::net::TcpListener;
    use rusty_ulid::Ulid;

    use super::*;
    use crate::protocol::proxy_protocol::header::*;

    /// An upfront middleware sends a PROXY v2 header over TCP. The `ExpectProxyProtocol`
    /// state reading from the accepted socket must finish with `SessionResult::Upgrade`.
    #[test]
    fn middleware_should_receive_proxy_protocol_header_from_an_upfront_middleware() {
        setup_test_logger!();
        // Let the OS pick a free port: a hard-coded port can collide with another process
        // or a concurrently running test and make this test flaky.
        let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse address error");
        let listener =
            TcpListener::bind(bind_addr).expect("could not bind the middleware listener");
        let middleware_addr = listener
            .local_addr()
            .expect("could not read the middleware listener address");
        let barrier = Arc::new(Barrier::new(2));

        let upfront = start_upfront_middleware(middleware_addr, barrier.clone());
        start_middleware(listener, barrier);

        upfront.join().expect("should join");
    }

    /// Accepts the upfront middleware's connection on `listener`, then calls `readable`
    /// until it returns something other than `Continue`.
    ///
    /// Panics unless the final result is `SessionResult::Upgrade`.
    fn start_middleware(listener: TcpListener, barrier: Arc<Barrier>) {
        barrier.wait();

        // The mio listener is non-blocking: spin until the connection is accepted.
        let session_stream = loop {
            if let Ok((stream, _addr)) = listener.accept() {
                break stream;
            }
        };

        let mut session_metrics = SessionMetrics::new(None);
        let container_frontend_timeout = TimeoutContainer::new(Duration::from_secs(10), Token(0));
        let mut expect_pp = ExpectProxyProtocol::new(
            container_frontend_timeout,
            session_stream,
            Token(0),
            Ulid::generate(),
        );

        let mut res = SessionResult::Continue;
        while res == SessionResult::Continue {
            res = expect_pp.readable(&mut session_metrics);
        }

        if res != SessionResult::Upgrade {
            panic!("Should receive a complete proxy protocol header, res = {res:?}");
        };
    }

    /// Spawns a thread that connects to `next_middleware_addr` once `barrier` is released
    /// and writes a PROXY v2 header (LOCAL command, IPv4 addresses) to it.
    fn start_upfront_middleware(
        next_middleware_addr: SocketAddr,
        barrier: Arc<Barrier>,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
            let dst_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);
            let proxy_protocol = HeaderV2::new(Command::Local, src_addr, dst_addr).into_bytes();

            barrier.wait();
            match StdTcpStream::connect(next_middleware_addr) {
                Ok(mut stream) => {
                    stream.write_all(&proxy_protocol).unwrap();
                }
                Err(e) => panic!("could not connect to the next middleware: {e}"),
            };
        })
    }
}