Skip to main content

sozu_lib/protocol/proxy_protocol/
expect.rs

1//! Inbound PROXY-v2 expectation state.
2//!
3//! Reads bytes from the freshly accepted front-end socket until a complete
4//! PROXY v2 header has been parsed (`parse_v2_header`), captures the peer
5//! address pair, and transitions the session to the configured downstream
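//! protocol (typically `Pipe` for TCP listeners). The readiness loop is
//! bounded by `MAX_LOOP_ITERATIONS` and header accumulation is capped at
//! 232 bytes, so malformed, empty or oversized headers can neither make the
//! session spin forever nor grow its buffer without bound.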
8
9use std::{cell::RefCell, rc::Rc};
10
11use mio::{net::TcpStream, *};
12use nom::{Err, HexDisplay};
13use rusty_ulid::Ulid;
14use sozu_command::{
15    config::MAX_LOOP_ITERATIONS,
16    logging::{LogContext, ansi_palette},
17};
18
19use super::{header::ProxyAddr, parser::parse_v2_header};
20use crate::metrics::names;
21use crate::{
22    Protocol, Readiness, SessionMetrics, StateResult,
23    pool::Checkout,
24    protocol::{
25        SessionResult, SessionState,
26        pipe::{Pipe, WebSocketContext},
27    },
28    socket::{SocketHandler, SocketResult},
29    sozu_command::ready::Ready,
30    tcp::TcpListener,
31    timer::TimeoutContainer,
32};
33
/// Log prefix for lines emitted from this module when no session state is
/// available: the `PROXY-EXPECT` label wrapped in the colors returned by
/// `ansi_palette()` (empty strings when the logger is not colored), followed
/// by a tab and ` >>>`.
macro_rules! log_module_context {
    () => {{
        let (label_on, label_off, ..) = ansi_palette();
        format!("{label_on}PROXY-EXPECT{label_off}\t >>>")
    }};
}
48
/// Per-session log prefix for code that has an [`ExpectProxyProtocol`] in
/// scope.
///
/// Renders `[ulid - - -]\tPROXY-EXPECT\tSession(frontend=.., index=..,
/// readiness=..)\t >>>`, colored through `ansi_palette()`. Operators can grep
/// these lines next to the `MUX-*`, `RUSTLS` and `PIPE` lines of the same
/// session.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        // Pull every displayed value into a named local so the template
        // below can capture them directly.
        let ctx = $self.log_context();
        let frontend = $self.frontend_token.0;
        let index = $self.index;
        let readiness = &$self.frontend_readiness;
        format!(
            "{gray}{ctx}{reset}\t{open}PROXY-EXPECT{reset}\t{grey}Session{reset}({gray}frontend{reset}={white}{frontend}{reset}, {gray}index{reset}={white}{index}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>"
        )
    }};
}
71
/// Read stage used while accumulating an inbound PROXY v2 header.
///
/// Each stage fixes how many bytes `ExpectProxyProtocol::readable` tries to
/// accumulate before asking the parser again. The stage only grows when the
/// current target was filled and the parser still reports an incomplete
/// header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HeaderLen {
    /// 28-byte target: the size of a header carrying an IPv4 address pair.
    V4,
    /// 52-byte target: the size of a header carrying an IPv6 address pair.
    V6,
    /// 232-byte target: the size of a header carrying UNIX socket addresses,
    /// and the largest header this state accepts.
    Unix,
}
78
// TODO: should have a backend
/// Session state that waits for an inbound PROXY v2 header on a freshly
/// accepted frontend socket before the connection is turned into a [`Pipe`]
/// (see `into_pipe`).
pub struct ExpectProxyProtocol<Front: SocketHandler> {
    /// Addresses carried by the parsed header; `None` until `readable` has
    /// parsed a complete header and returned `SessionResult::Upgrade`.
    pub addresses: Option<ProxyAddr>,
    /// Frontend timeout; moved into the pipe by `into_pipe`.
    pub container_frontend_timeout: TimeoutContainer,
    /// Staging buffer for header bytes. Its 232-byte length is the largest
    /// read target (`HeaderLen::Unix`).
    frontend_buffer: [u8; 232],
    /// Interest/event readiness of the frontend socket.
    pub frontend_readiness: Readiness,
    /// Token identifying the frontend socket in readiness and timeout events.
    pub frontend_token: Token,
    /// Frontend socket the header is read from.
    pub frontend: Front,
    /// Current read stage; decides how many bytes `readable` targets.
    header_len: HeaderLen,
    /// Number of header bytes accumulated in `frontend_buffer` so far.
    index: usize,
    /// Session identifier, used as the log-context session id and passed on
    /// to the pipe.
    pub request_id: Ulid,
}
91
92impl<Front: SocketHandler> ExpectProxyProtocol<Front> {
93    /// Instantiate a new ExpectProxyProtocol SessionState with:
94    /// - frontend_interest: READABLE | HUP | ERROR
95    /// - frontend_event: EMPTY
96    pub fn new(
97        container_frontend_timeout: TimeoutContainer,
98        frontend: Front,
99        frontend_token: Token,
100        request_id: Ulid,
101    ) -> Self {
102        ExpectProxyProtocol {
103            addresses: None,
104            container_frontend_timeout,
105            frontend_buffer: [0; 232],
106            frontend_readiness: Readiness {
107                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
108                event: Ready::EMPTY,
109            },
110            frontend_token,
111            frontend,
112            header_len: HeaderLen::V4,
113            index: 0,
114            request_id,
115        }
116    }
117
118    pub fn readable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
119        let total_len = match self.header_len {
120            HeaderLen::V4 => 28,
121            HeaderLen::V6 => 52,
122            HeaderLen::Unix => 232,
123        };
124
125        // Anti-oversized-header / partial-read invariant: the accumulation
126        // cursor never runs past the staging window, and the per-stage target
127        // never exceeds the fixed 232-byte buffer (the absolute upper bound on
128        // a PROXY-v2 header). A violation here would slice-panic on the read
129        // below; the asserts name it as a logic bug.
130        debug_assert!(
131            self.index <= total_len,
132            "read cursor must not exceed the current stage target"
133        );
134        debug_assert!(
135            total_len <= self.frontend_buffer.len(),
136            "stage target must fit the fixed proxy-protocol buffer"
137        );
138
139        let index_before = self.index;
140        let (sz, socket_result) = self
141            .frontend
142            .socket_read(&mut self.frontend_buffer[self.index..total_len]);
143        // The socket may only fill the slice it was handed, so a successful
144        // read advances the cursor by at most the remaining stage capacity.
145        debug_assert!(
146            sz <= total_len - index_before,
147            "socket_read cannot return more bytes than the slice it was given"
148        );
149        trace!(
150            "{} read {} bytes and res={:?}, total_len = {}",
151            log_context!(self),
152            sz,
153            socket_result,
154            total_len
155        );
156
157        if sz > 0 {
158            self.index += sz;
159            // Partial-read accumulation is strictly monotonic and stays within
160            // the bounded buffer: bytes-consumed advances by exactly `sz` and
161            // never exceeds the buffer length (anti-oversized-header bound).
162            debug_assert_eq!(
163                self.index,
164                index_before + sz,
165                "read cursor advances by exactly the bytes just read"
166            );
167            debug_assert!(
168                self.index <= self.frontend_buffer.len(),
169                "accumulated bytes must never exceed the fixed buffer bound"
170            );
171
172            count!(names::backend::BYTES_IN, sz as i64);
173            metrics.bin += sz;
174
175            if self.index == self.frontend_buffer.len() {
176                self.frontend_readiness.interest.remove(Ready::READABLE);
177            }
178        } else {
179            debug_assert_eq!(
180                self.index, index_before,
181                "a non-positive read must leave the cursor unchanged"
182            );
183            self.frontend_readiness.event.remove(Ready::READABLE);
184        }
185
186        match socket_result {
187            SocketResult::Error => {
188                error!(
189                    "{} front socket error, closing the connection (read {}, wrote {})",
190                    log_context!(self),
191                    metrics.bin,
192                    metrics.bout
193                );
194                incr!(names::proxy_protocol::ERRORS);
195                self.frontend_readiness.reset();
196                return SessionResult::Close;
197            }
198            SocketResult::WouldBlock => {
199                self.frontend_readiness.event.remove(Ready::READABLE);
200            }
201            SocketResult::Closed => {
202                // Socket closed before any proxy-protocol bytes were received.
203                // This is the typical HAProxy bare TCP healthcheck pattern
204                // (SYN/ACK/FIN without send-proxy). Close immediately instead
205                // of waiting for request_timeout (default 10s), which would
206                // create zombie sessions consuming nb_connections quota.
207                if self.index == 0 {
208                    trace!(
209                        "{} socket closed with 0 bytes, closing session",
210                        log_context!(self)
211                    );
212                    return SessionResult::Close;
213                }
214            }
215            SocketResult::Continue => {}
216        }
217
218        match parse_v2_header(&self.frontend_buffer[..self.index]) {
219            Ok((rest, header)) => {
220                // Completion postcondition: the parser consumed a prefix of the
221                // accumulated bytes, so the unparsed remainder is no larger than
222                // what we fed it (a complete header was recognized within the
223                // bound).
224                debug_assert!(
225                    rest.len() <= self.index,
226                    "parser remainder cannot exceed the accumulated input"
227                );
228                trace!(
229                    "{} got expect header: {:?}, rest.len() = {}",
230                    log_context!(self),
231                    header,
232                    rest.len()
233                );
234                self.addresses = Some(header.addr);
235                SessionResult::Upgrade
236            }
237            Err(Err::Incomplete(_)) => {
238                match self.header_len {
239                    HeaderLen::V4 => {
240                        if self.index == 28 {
241                            self.header_len = HeaderLen::V6;
242                        }
243                    }
244                    HeaderLen::V6 => {
245                        if self.index == 52 {
246                            self.header_len = HeaderLen::Unix;
247                        }
248                    }
249                    HeaderLen::Unix => {
250                        if self.index == 232 {
251                            error!(
252                                "{} proxy protocol header exceeds maximum size (232 bytes), closing",
253                                log_context!(self)
254                            );
255                            incr!(names::proxy_protocol::ERRORS);
256                            self.frontend_readiness.reset();
257                            return SessionResult::Close;
258                        }
259                    }
260                };
261                SessionResult::Continue
262            }
263            Err(Err::Error(e)) | Err(Err::Failure(e)) => {
264                error!(
265                    "{} parse error, closing the connection:\n{}",
266                    log_context!(self),
267                    e.input.to_hex(16)
268                );
269                incr!(names::proxy_protocol::ERRORS);
270                self.frontend_readiness.reset();
271                SessionResult::Close
272            }
273        }
274    }
275
276    pub fn front_socket(&self) -> &TcpStream {
277        self.frontend.socket_ref()
278    }
279
280    pub fn into_pipe(
281        self,
282        front_buf: Checkout,
283        back_buf: Checkout,
284        backend_socket: Option<TcpStream>,
285        backend_token: Option<Token>,
286        listener: Rc<RefCell<TcpListener>>,
287    ) -> Pipe<Front, TcpListener> {
288        // Prefer the source address parsed from the PROXY-v2 header over
289        // the TCP `peer_addr` so the pipe phase records the real client
290        // — `peer_addr` here is the upstream PROXY-emitter (an LB / edge
291        // proxy / health-check probe), not the originating client.
292        // Falls back to `peer_addr` when the header carried `Command::Local`
293        // (no encapsulated addresses) or when the parser ran with
294        // `AddressFamily::Unspec`.
295        let addr = self
296            .addresses
297            .as_ref()
298            .and_then(|pa| pa.source())
299            .or_else(|| self.front_socket().peer_addr().ok());
300
301        let mut pipe = Pipe::new(
302            back_buf,
303            None,
304            backend_socket,
305            None,
306            None,
307            Some(self.container_frontend_timeout),
308            None,
309            front_buf,
310            self.frontend_token,
311            self.frontend,
312            listener,
313            Protocol::TCP,
314            self.request_id,
315            self.request_id,
316            addr,
317            WebSocketContext::Tcp,
318        );
319
320        pipe.frontend_readiness.event = self.frontend_readiness.event;
321
322        if let Some(backend_token) = backend_token {
323            pipe.set_back_token(backend_token);
324        }
325
326        pipe
327    }
328
329    pub fn log_context(&self) -> LogContext<'_> {
330        LogContext {
331            session_id: self.request_id,
332            request_id: None,
333            cluster_id: None,
334            backend_id: None,
335        }
336    }
337}
338
339impl<Front: SocketHandler> SessionState for ExpectProxyProtocol<Front> {
340    fn ready(
341        &mut self,
342        _session: Rc<RefCell<dyn crate::ProxySession>>,
343        _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
344        metrics: &mut SessionMetrics,
345    ) -> SessionResult {
346        let mut counter = 0;
347
348        if self.frontend_readiness.event.is_hup() {
349            return SessionResult::Close;
350        }
351
352        while counter < MAX_LOOP_ITERATIONS {
353            let frontend_interest = self.frontend_readiness.filter_interest();
354
355            trace!(
356                "{} {:?} -> None",
357                log_context!(self),
358                self.frontend_readiness
359            );
360
361            if frontend_interest.is_empty() {
362                break;
363            }
364
365            if frontend_interest.is_readable() {
366                let session_result = self.readable(metrics);
367                if session_result != SessionResult::Continue {
368                    return session_result;
369                }
370            }
371
372            if frontend_interest.is_error() {
373                error!("{} front error, disconnecting", log_context!(self));
374                self.frontend_readiness.interest = Ready::EMPTY;
375
376                return SessionResult::Close;
377            }
378
379            let counter_before = counter;
380            counter += 1;
381            // The readiness loop is bounded by MAX_LOOP_ITERATIONS; the counter
382            // advances by exactly one per turn and stays within the cap, so the
383            // loop cannot spin unbounded on a stuck readiness state.
384            debug_assert_eq!(counter, counter_before + 1, "loop counter advances by one");
385            debug_assert!(
386                counter <= MAX_LOOP_ITERATIONS,
387                "loop counter must stay within the iteration cap"
388            );
389        }
390
391        if counter >= MAX_LOOP_ITERATIONS {
392            error!(
393                "{} handling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
394                log_context!(self),
395                MAX_LOOP_ITERATIONS
396            );
397            incr!(names::http::INFINITE_LOOP_ERROR);
398
399            self.print_state("");
400
401            return SessionResult::Close;
402        }
403
404        SessionResult::Continue
405    }
406
407    fn update_readiness(&mut self, token: Token, events: Ready) {
408        if self.frontend_token == token {
409            self.frontend_readiness.event |= events;
410        }
411    }
412
413    fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
414        if self.frontend_token == token {
415            self.container_frontend_timeout.triggered();
416            return StateResult::CloseSession;
417        }
418
419        error!(
420            "{} got timeout for an invalid token: {:?}",
421            log_module_context!(),
422            token
423        );
424        StateResult::CloseSession
425    }
426
427    fn cancel_timeouts(&mut self) {
428        self.container_frontend_timeout.cancel();
429    }
430
431    fn print_state(&self, context: &str) {
432        error!(
433            "{} {} Session(Expect)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
434            log_context!(self),
435            context,
436            self.frontend_token,
437            self.frontend_readiness
438        );
439    }
440}
441
#[cfg(test)]
mod expect_test {
    use std::{
        io::{ErrorKind, Write},
        net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream as StdTcpStream},
        thread::{self, JoinHandle},
        time::{Duration, Instant},
    };

    use mio::net::TcpListener;
    use rusty_ulid::Ulid;

    use super::*;
    use crate::protocol::proxy_protocol::header::*;

    /// Upper bound on how long the test waits for the connection and for the
    /// header, so a broken peer fails the test instead of hanging it.
    const TEST_DEADLINE: Duration = Duration::from_secs(10);

    // Flow diagram of the test below
    //                [connect]   [send proxy protocol]
    //upfront proxy  ----------------------X
    //              /     |           |
    //  sozu     ---------v-----------v----X
    #[test]
    fn middleware_should_receive_proxy_protocol_header_from_an_upfront_middleware() {
        setup_test_logger!();
        // Bind to an ephemeral port so the test never collides with another
        // process or a concurrently running test on a fixed port.
        let bind_addr: SocketAddr = "127.0.0.1:0".parse().expect("parse address error");
        let listener =
            TcpListener::bind(bind_addr).expect("could not bind the middleware listener");
        let middleware_addr = listener
            .local_addr()
            .expect("could not read the middleware listener address");

        // The listener is bound before the upfront thread starts, so its
        // connect cannot race the bind and no barrier is needed.
        let upfront = start_upfront_middleware(middleware_addr);
        start_middleware(listener);

        upfront.join().expect("should join");
    }

    // Accept connection from an upfront proxy and expect to read a proxy protocol header in this stream.
    fn start_middleware(upfront_middleware_conn_listener: TcpListener) {
        let deadline = Instant::now() + TEST_DEADLINE;

        // mio::net::TcpListener is non-blocking: retry while the connection is
        // not there yet, but surface any other accept error.
        let session_stream = loop {
            match upfront_middleware_conn_listener.accept() {
                Ok((stream, _addr)) => break stream,
                Err(e) if matches!(e.kind(), ErrorKind::WouldBlock | ErrorKind::Interrupted) => {
                    assert!(
                        Instant::now() < deadline,
                        "timed out waiting for the upfront middleware connection"
                    );
                    thread::yield_now();
                }
                Err(e) => panic!("could not accept upfront middleware connection: {e}"),
            }
        };

        let mut session_metrics = SessionMetrics::new(None);
        let container_frontend_timeout = TimeoutContainer::new(Duration::from_secs(10), Token(0));
        let mut expect_pp = ExpectProxyProtocol::new(
            container_frontend_timeout,
            session_stream,
            Token(0),
            Ulid::generate(),
        );

        let mut res = SessionResult::Continue;
        while res == SessionResult::Continue {
            assert!(
                Instant::now() < deadline,
                "timed out waiting for a complete proxy protocol header"
            );
            res = expect_pp.readable(&mut session_metrics);
        }

        assert!(
            res == SessionResult::Upgrade,
            "Should receive a complete proxy protocol header, res = {res:?}"
        );
    }

    // Connect to the next middleware and send a proxy protocol header
    fn start_upfront_middleware(next_middleware_addr: SocketAddr) -> JoinHandle<()> {
        thread::spawn(move || {
            let src_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(125, 25, 10, 1)), 8080);
            let dst_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 4, 5, 8)), 4200);
            let proxy_protocol = HeaderV2::new(Command::Local, src_addr, dst_addr).into_bytes();

            let mut stream = StdTcpStream::connect(next_middleware_addr)
                .unwrap_or_else(|e| panic!("could not connect to the next middleware: {e}"));
            stream
                .write_all(&proxy_protocol)
                .expect("could not send the proxy protocol header");
        })
    }
}