Skip to main content

sozu_lib/protocol/proxy_protocol/
send.rs

//! PROXY-v2 send state.
//!
//! Synthesises a PROXY-v2 header (`HeaderV2`) describing the original
//! client and emits it on a freshly opened backend `TcpStream` before the
//! TCP/TLS payload begins. Used when the front-end accepted a non-PROXY
//! connection but the configured backend expects PROXY-v2 metadata.

8use std::{
9    cell::RefCell,
10    io::{ErrorKind, Write},
11    rc::Rc,
12};
13
14use mio::{Token, net::TcpStream};
15use rusty_ulid::Ulid;
16use sozu_command::logging::ansi_palette;
17
18use crate::metrics::names;
19use crate::{
20    BackendConnectionStatus, Protocol, Readiness, SessionMetrics, SessionResult,
21    pool::Checkout,
22    protocol::{
23        pipe::{Pipe, WebSocketContext},
24        proxy_protocol::header::{Command, HeaderV2, ProxyProtocolHeader},
25    },
26    socket::SocketHandler,
27    sozu_command::ready::Ready,
28    tcp::TcpListener,
29};
30
/// Builds the log prefix used by this module when no session state is in
/// scope.
///
/// Expands to a `String` made of the `PROXY-SEND` label, wrapped in the first
/// two codes returned by `ansi_palette()` (the "open" and "reset" sequences),
/// followed by `\t >>>`.
// Not referenced by any code in this file; the allow keeps the unused-macro
// lint quiet while the helper stays available.
#[allow(unused_macros)]
macro_rules! log_module_context {
    () => {{
        let (open, reset, ..) = ansi_palette();
        format!("{open}PROXY-SEND{reset}\t >>>")
    }};
}
42
/// Per-session prefix for log lines emitted with a [`SendProxyProtocol`] in
/// scope.
///
/// Renders `PROXY-SEND\tSession(frontend=…, backend=…, front_readiness=…,
/// back_readiness=…)\t >>>`, with colors from `ansi_palette()` when the
/// logger uses them. The backend token renders as `<none>` until
/// `set_back_token` has been called. The bracket carries the front/back
/// tokens and readiness instead of a request id.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        format!(
            "{open}PROXY-SEND{reset}\t{grey}Session{reset}({gray}frontend{reset}={white}{frontend}{reset}, {gray}backend{reset}={white}{backend}{reset}, {gray}front_readiness{reset}={white}{front_readiness}{reset}, {gray}back_readiness{reset}={white}{back_readiness}{reset})\t >>>",
            open = open,
            reset = reset,
            grey = grey,
            gray = gray,
            white = white,
            frontend = $self.frontend_token.0,
            backend = $self.backend_token.map(|t| t.0.to_string()).unwrap_or_else(|| "<none>".to_string()),
            front_readiness = $self.frontend_readiness,
            back_readiness = $self.backend_readiness,
        )
    }};
}
64
/// Session state that writes a PROXY-v2 header to the backend socket before
/// the session is converted into a [`Pipe`] by `into_pipe`.
pub struct SendProxyProtocol<Front: SocketHandler> {
    /// Number of bytes of `header` already written to the backend socket.
    cursor_header: usize,
    /// Interest and event sets tracked for the backend socket.
    pub backend_readiness: Readiness,
    /// Token of the backend socket; `None` until `set_back_token` is called.
    pub backend_token: Option<Token>,
    /// Backend connection the header is written to; moved out by `into_pipe`.
    pub backend: Option<TcpStream>,
    /// Interest and event sets tracked for the frontend (client) socket.
    pub frontend_readiness: Readiness,
    /// Token of the frontend socket.
    pub frontend_token: Token,
    /// Client-side socket; its local and peer addresses feed the header.
    pub frontend: Front,
    /// Serialized PROXY-v2 header, built on the first `back_writable` call.
    pub header: Option<Vec<u8>>,
    /// Identifier handed on to the `Pipe` built by `into_pipe`.
    pub request_id: Ulid,
}
76
77impl<Front: SocketHandler> SendProxyProtocol<Front> {
78    /// Instantiate a new SendProxyProtocol SessionState with:
79    /// - frontend_interest: HUP | ERROR
80    /// - frontend_event: EMPTY
81    /// - backend_interest: HUP | ERROR
82    /// - backend_event: EMPTY
83    pub fn new(
84        frontend: Front,
85        frontend_token: Token,
86        request_id: Ulid,
87        backend: Option<TcpStream>,
88    ) -> Self {
89        SendProxyProtocol {
90            header: None,
91            frontend,
92            request_id,
93            backend,
94            frontend_token,
95            backend_token: None,
96            frontend_readiness: Readiness {
97                interest: Ready::HUP | Ready::ERROR,
98                event: Ready::EMPTY,
99            },
100            backend_readiness: Readiness {
101                interest: Ready::HUP | Ready::ERROR,
102                event: Ready::EMPTY,
103            },
104            cursor_header: 0,
105        }
106    }
107
108    // The header is send immediately at once upon the connection is establish
109    // and prepended before any data.
110    pub fn back_writable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
111        debug!(
112            "{} trying to write proxy protocol header",
113            log_context!(self)
114        );
115
116        // Generate the proxy protocol header if not already exist.
117        if self.header.is_none() {
118            if let Ok(local_addr) = self.front_socket().local_addr() {
119                if let Ok(frontend_addr) = self.front_socket().peer_addr() {
120                    let v2 = HeaderV2::new(Command::Proxy, frontend_addr, local_addr);
121                    let declared_len = v2.len();
122                    let serialized = ProxyProtocolHeader::V2(v2).into_bytes();
123                    // Send postcondition: the byte vector we will stream out is
124                    // exactly the length the header model declared. The cursor
125                    // logic below relies on `header.len()` being this serialized
126                    // size to detect completion.
127                    debug_assert_eq!(
128                        serialized.len(),
129                        declared_len,
130                        "serialized send header length must match HeaderV2::len()"
131                    );
132                    debug_assert!(
133                        serialized.len() >= 16,
134                        "a v2 send header is at least its 16-byte fixed prefix"
135                    );
136                    self.header = Some(serialized);
137                } else {
138                    return SessionResult::Close;
139                }
140            };
141        }
142
143        if let Some(ref mut socket) = self.backend {
144            if let Some(ref mut header) = self.header {
145                loop {
146                    // The cursor never overruns the serialized header: it only
147                    // advances by reported write sizes and stops at `len()`.
148                    debug_assert!(
149                        self.cursor_header <= header.len(),
150                        "send cursor must stay within the serialized header"
151                    );
152                    let remaining = header.len() - self.cursor_header;
153                    match socket.write(&header[self.cursor_header..]) {
154                        Ok(sz) => {
155                            debug_assert!(
156                                sz <= remaining,
157                                "socket.write cannot send more than the unsent header tail"
158                            );
159                            let cursor_before = self.cursor_header;
160                            self.cursor_header += sz;
161                            // Strictly monotonic: the cursor tracks exactly the
162                            // bytes emitted and never passes the header length.
163                            debug_assert_eq!(
164                                self.cursor_header,
165                                cursor_before + sz,
166                                "send cursor advances by exactly the bytes written"
167                            );
168                            debug_assert!(
169                                self.cursor_header <= header.len(),
170                                "send cursor must not pass the header length"
171                            );
172                            count!(names::backend::BACK_BYTES_OUT, sz as i64);
173                            metrics.backend_bout += sz;
174
175                            if self.cursor_header == header.len() {
176                                debug!("{} proxy protocol sent, upgrading", log_context!(self));
177                                return SessionResult::Upgrade;
178                            }
179                        }
180                        Err(e) => match e.kind() {
181                            ErrorKind::WouldBlock => {
182                                self.backend_readiness.event.remove(Ready::WRITABLE);
183                                return SessionResult::Continue;
184                            }
185                            e => {
186                                incr!(names::proxy_protocol::ERRORS);
187                                debug!("{} write error: {:?}", log_context!(self), e);
188                                return SessionResult::Close;
189                            }
190                        },
191                    }
192                }
193            }
194        }
195
196        error!(
197            "{} started send proxy protocol with no header or backend socket",
198            log_context!(self)
199        );
200        SessionResult::Close
201    }
202
203    pub fn front_socket(&self) -> &TcpStream {
204        self.frontend.socket_ref()
205    }
206
207    pub fn front_socket_mut(&mut self) -> &mut TcpStream {
208        self.frontend.socket_mut()
209    }
210
211    pub fn back_socket(&self) -> Option<&TcpStream> {
212        self.backend.as_ref()
213    }
214
215    pub fn back_socket_mut(&mut self) -> Option<&mut TcpStream> {
216        self.backend.as_mut()
217    }
218
219    pub fn set_back_socket(&mut self, socket: TcpStream) {
220        self.backend = Some(socket);
221    }
222
223    pub fn back_token(&self) -> Option<Token> {
224        self.backend_token
225    }
226
227    pub fn set_back_token(&mut self, token: Token) {
228        self.backend_token = Some(token);
229    }
230
231    pub fn set_back_connected(&mut self, status: BackendConnectionStatus) {
232        if status == BackendConnectionStatus::Connected {
233            self.backend_readiness.interest.insert(Ready::WRITABLE);
234        }
235    }
236
237    pub fn into_pipe(
238        mut self,
239        front_buf: Checkout,
240        back_buf: Checkout,
241        listener: Rc<RefCell<TcpListener>>,
242    ) -> Pipe<Front, TcpListener> {
243        let backend_socket = self.backend.take().unwrap();
244        let addr = self.front_socket().peer_addr().ok();
245
246        let mut pipe = Pipe::new(
247            back_buf,
248            None,
249            Some(backend_socket),
250            None,
251            None,
252            None,
253            None,
254            front_buf,
255            self.frontend_token,
256            self.frontend,
257            listener,
258            Protocol::TCP,
259            self.request_id,
260            self.request_id,
261            addr,
262            WebSocketContext::Tcp,
263        );
264
265        pipe.frontend_readiness = self.frontend_readiness;
266        pipe.backend_readiness = self.backend_readiness;
267
268        pipe.frontend_readiness.interest.insert(Ready::READABLE);
269        pipe.backend_readiness.interest.insert(Ready::READABLE);
270
271        if let Some(back_token) = self.backend_token {
272            pipe.set_back_token(back_token);
273        }
274
275        pipe
276    }
277}
278
#[cfg(test)]
mod send_test {
    use std::{
        io::Read,
        net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
        sync::{Arc, Barrier},
        thread::{self, JoinHandle},
    };

    use mio::net::TcpStream;
    use rusty_ulid::Ulid;

    use super::{
        super::parser::parse_v2_header, BackendConnectionStatus, SendProxyProtocol,
        SessionMetrics, SessionResult, Token,
    };

    /// Size of a PROXY-v2 header for a TCP-over-IPv4 connection: the 16-byte
    /// fixed prefix plus 12 address bytes (two IPv4 addresses, two ports).
    const IPV4_V2_HEADER_LEN: usize = 28;

    #[test]
    fn it_should_send_a_proxy_protocol_header_to_the_upstream_backend() {
        setup_test_logger!();
        // Both listeners are bound on ephemeral ports before any thread
        // starts. No hard-coded port can collide with another process, and
        // no connect can race its listener's bind.
        let (middleware_listener, addr_client) = bind_ephemeral();
        let (backend_listener, addr_backend) = bind_ephemeral();
        // Client and backend meet here once the backend has checked the
        // header, so the client keeps its connection open until then.
        let end_barrier = Arc::new(Barrier::new(2));

        start_client(addr_client, end_barrier.clone());
        let backend = start_backend(backend_listener, end_barrier);
        start_middleware(middleware_listener, addr_backend);

        backend
            .join()
            .expect("Couldn't join on the associated backend");
    }

    /// Binds a blocking listener on an ephemeral loopback port and returns it
    /// together with the address it is bound to.
    fn bind_ephemeral() -> (StdTcpListener, SocketAddr) {
        let listener = StdTcpListener::bind("127.0.0.1:0").expect("could not bind a test listener");
        let addr = listener
            .local_addr()
            .expect("could not read the test listener address");
        (listener, addr)
    }

    /// Switches a std stream to non-blocking mode, which mio requires of the
    /// streams it wraps, and converts it into a mio stream without `unsafe`.
    fn into_mio(stream: StdTcpStream) -> TcpStream {
        stream
            .set_nonblocking(true)
            .expect("could not set the test stream non-blocking");
        TcpStream::from_std(stream)
    }

    // Get connection from the session and connect to the backend.
    // Once both connections are established, send the proxy protocol header.
    fn start_middleware(listener: StdTcpListener, addr_backend: SocketAddr) {
        let (client_stream, _addr) = listener
            .accept()
            .expect("could not accept session connection");
        let client_stream = into_mio(client_stream);

        // Connect in blocking mode first, then convert to a mio socket.
        let backend_stream = into_mio(
            StdTcpStream::connect(addr_backend).expect("could not connect to the backend"),
        );

        let mut send_pp = SendProxyProtocol::new(
            client_stream,
            Token(0),
            Ulid::generate(),
            Some(backend_stream),
        );
        let mut session_metrics = SessionMetrics::new(None);

        send_pp.set_back_connected(BackendConnectionStatus::Connected);

        loop {
            let result = send_pp.back_writable(&mut session_metrics);
            if result == SessionResult::Upgrade {
                break;
            }

            if result != SessionResult::Continue {
                panic!("state machine error: result = {result:?}");
            }
        }
    }

    // Only connect to the middleware, and hold the connection until the
    // backend has checked the header.
    fn start_client(addr: SocketAddr, end_barrier: Arc<Barrier>) {
        thread::spawn(move || {
            let _stream = StdTcpStream::connect(addr).unwrap();

            end_barrier.wait();
        });
    }

    // Accept the connection from the middleware and read from the socket.
    // Check that a valid proxy protocol header is received.
    fn start_backend(listener: StdTcpListener, end_barrier: Arc<Barrier>) -> JoinHandle<()> {
        thread::spawn(move || {
            let (mut conn, _) = listener
                .accept()
                .expect("could not accept connection from light middleware");
            println!("backend got a connection from the middleware");

            let mut buf = [0u8; IPV4_V2_HEADER_LEN];
            // `read_exact` retries on `Interrupted` and reports a premature EOF
            // as an error, instead of spinning forever on `Ok(0)`.
            if let Err(e) = conn.read_exact(&mut buf) {
                end_barrier.wait();
                panic!("read error: {e:?}");
            }
            println!("backend read {IPV4_V2_HEADER_LEN} bytes");

            match parse_v2_header(&buf) {
                Ok((_, _)) => println!("complete header received"),
                err => {
                    end_barrier.wait();
                    panic!("incorrect proxy protocol header received: {err:?}");
                }
            };

            end_barrier.wait();
        })
    }
}