sozu_lib/protocol/proxy_protocol/
send.rs

1use std::{
2    cell::RefCell,
3    io::{ErrorKind, Write},
4    rc::Rc,
5};
6
7use mio::{net::TcpStream, Token};
8use rusty_ulid::Ulid;
9
10use crate::{
11    pool::Checkout,
12    protocol::{
13        pipe::{Pipe, WebSocketContext},
14        proxy_protocol::header::{Command, HeaderV2, ProxyProtocolHeader},
15    },
16    socket::SocketHandler,
17    sozu_command::ready::Ready,
18    tcp::TcpListener,
19    BackendConnectionStatus, Protocol, Readiness, SessionMetrics, SessionResult,
20};
21
22pub struct SendProxyProtocol<Front: SocketHandler> {
23    cursor_header: usize,
24    pub backend_readiness: Readiness,
25    pub backend_token: Option<Token>,
26    pub backend: Option<TcpStream>,
27    pub frontend_readiness: Readiness,
28    pub frontend_token: Token,
29    pub frontend: Front,
30    pub header: Option<Vec<u8>>,
31    pub request_id: Ulid,
32}
33
34impl<Front: SocketHandler> SendProxyProtocol<Front> {
35    /// Instantiate a new SendProxyProtocol SessionState with:
36    /// - frontend_interest: HUP | ERROR
37    /// - frontend_event: EMPTY
38    /// - backend_interest: HUP | ERROR
39    /// - backend_event: EMPTY
40    pub fn new(
41        frontend: Front,
42        frontend_token: Token,
43        request_id: Ulid,
44        backend: Option<TcpStream>,
45    ) -> Self {
46        SendProxyProtocol {
47            header: None,
48            frontend,
49            request_id,
50            backend,
51            frontend_token,
52            backend_token: None,
53            frontend_readiness: Readiness {
54                interest: Ready::HUP | Ready::ERROR,
55                event: Ready::EMPTY,
56            },
57            backend_readiness: Readiness {
58                interest: Ready::HUP | Ready::ERROR,
59                event: Ready::EMPTY,
60            },
61            cursor_header: 0,
62        }
63    }
64
65    // The header is send immediately at once upon the connection is establish
66    // and prepended before any data.
67    pub fn back_writable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
68        debug!("Trying to write proxy protocol header");
69
70        // Generate the proxy protocol header if not already exist.
71        if self.header.is_none() {
72            if let Ok(local_addr) = self.front_socket().local_addr() {
73                if let Ok(frontend_addr) = self.front_socket().peer_addr() {
74                    self.header = Some(
75                        ProxyProtocolHeader::V2(HeaderV2::new(
76                            Command::Proxy,
77                            frontend_addr,
78                            local_addr,
79                        ))
80                        .into_bytes(),
81                    );
82                } else {
83                    return SessionResult::Close;
84                }
85            };
86        }
87
88        if let Some(ref mut socket) = self.backend {
89            if let Some(ref mut header) = self.header {
90                loop {
91                    match socket.write(&header[self.cursor_header..]) {
92                        Ok(sz) => {
93                            self.cursor_header += sz;
94                            metrics.backend_bout += sz;
95
96                            if self.cursor_header == header.len() {
97                                debug!("Proxy protocol sent, upgrading");
98                                return SessionResult::Upgrade;
99                            }
100                        }
101                        Err(e) => match e.kind() {
102                            ErrorKind::WouldBlock => {
103                                self.backend_readiness.event.remove(Ready::WRITABLE);
104                                return SessionResult::Continue;
105                            }
106                            e => {
107                                incr!("proxy_protocol.errors");
108                                debug!("send proxy protocol write error {:?}", e);
109                                return SessionResult::Close;
110                            }
111                        },
112                    }
113                }
114            }
115        }
116
117        error!("started Send proxy protocol with no header or backend socket");
118        SessionResult::Close
119    }
120
121    pub fn front_socket(&self) -> &TcpStream {
122        self.frontend.socket_ref()
123    }
124
125    pub fn front_socket_mut(&mut self) -> &mut TcpStream {
126        self.frontend.socket_mut()
127    }
128
129    pub fn back_socket(&self) -> Option<&TcpStream> {
130        self.backend.as_ref()
131    }
132
133    pub fn back_socket_mut(&mut self) -> Option<&mut TcpStream> {
134        self.backend.as_mut()
135    }
136
137    pub fn set_back_socket(&mut self, socket: TcpStream) {
138        self.backend = Some(socket);
139    }
140
141    pub fn back_token(&self) -> Option<Token> {
142        self.backend_token
143    }
144
145    pub fn set_back_token(&mut self, token: Token) {
146        self.backend_token = Some(token);
147    }
148
149    pub fn set_back_connected(&mut self, status: BackendConnectionStatus) {
150        if status == BackendConnectionStatus::Connected {
151            self.backend_readiness.interest.insert(Ready::WRITABLE);
152        }
153    }
154
155    pub fn into_pipe(
156        mut self,
157        front_buf: Checkout,
158        back_buf: Checkout,
159        listener: Rc<RefCell<TcpListener>>,
160    ) -> Pipe<Front, TcpListener> {
161        let backend_socket = self.backend.take().unwrap();
162        let addr = self.front_socket().peer_addr().ok();
163
164        let mut pipe = Pipe::new(
165            back_buf,
166            None,
167            Some(backend_socket),
168            None,
169            None,
170            None,
171            None,
172            front_buf,
173            self.frontend_token,
174            self.frontend,
175            listener,
176            Protocol::TCP,
177            self.request_id,
178            addr,
179            WebSocketContext::Tcp,
180        );
181
182        pipe.frontend_readiness = self.frontend_readiness;
183        pipe.backend_readiness = self.backend_readiness;
184
185        pipe.frontend_readiness.interest.insert(Ready::READABLE);
186        pipe.backend_readiness.interest.insert(Ready::READABLE);
187
188        if let Some(back_token) = self.backend_token {
189            pipe.set_back_token(back_token);
190        }
191
192        pipe
193    }
194}
195
#[cfg(test)]
mod send_test {
    use std::{
        io::Read,
        net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
        os::unix::io::{FromRawFd, IntoRawFd},
        sync::{Arc, Barrier},
        thread::{self, JoinHandle},
    };

    use mio::net::{TcpListener, TcpStream};
    use rusty_ulid::Ulid;

    use super::{
        super::parser::parse_v2_header, BackendConnectionStatus, ErrorKind, SendProxyProtocol,
        SessionMetrics, SessionResult, Token,
    };

    /// End-to-end check: a client connects to the middleware, the middleware
    /// connects to the backend, and the backend must receive a valid proxy
    /// protocol v2 header.
    #[test]
    fn it_should_send_a_proxy_protocol_header_to_the_upstream_backend() {
        setup_test_logger!();
        let addr_client: SocketAddr = "127.0.0.1:6666".parse().expect("parse address error");
        let addr_backend: SocketAddr = "127.0.0.1:2001".parse().expect("parse address error");
        // Start barrier: client thread, backend thread and the middleware.
        let barrier = Arc::new(Barrier::new(3));
        // End barrier: keeps the client connection open until the backend is done.
        let end_barrier = Arc::new(Barrier::new(2));

        start_client(addr_client, barrier.clone(), end_barrier.clone());
        let backend = start_backend(addr_backend, barrier.clone(), end_barrier);
        start_middleware(addr_client, addr_backend, barrier);

        backend
            .join()
            .expect("Couldn't join on the associated backend");
    }

    // Accept the client connection and connect to the backend.
    // Once both connections are established, send the proxy protocol header.
    fn start_middleware(addr_client: SocketAddr, addr_backend: SocketAddr, barrier: Arc<Barrier>) {
        let listener = TcpListener::bind(addr_client).expect("could not accept session connection");

        let client_stream;
        barrier.wait();

        // The mio listener does not block: poll until the client has connected.
        loop {
            if let Ok((stream, _addr)) = listener.accept() {
                client_stream = stream;
                break;
            }
        }

        // connect in blocking first, then convert to a mio socket
        let backend_stream =
            StdTcpStream::connect(addr_backend).expect("could not connect to the backend");
        let fd = backend_stream.into_raw_fd();
        // SAFETY: `fd` comes from `into_raw_fd` on a connected stream, so it is a
        // valid, open socket whose ownership is transferred to the new stream.
        let backend_stream = unsafe { TcpStream::from_raw_fd(fd) };

        let mut send_pp = SendProxyProtocol::new(
            client_stream,
            Token(0),
            Ulid::generate(),
            Some(backend_stream),
        );
        let mut session_metrics = SessionMetrics::new(None);

        send_pp.set_back_connected(BackendConnectionStatus::Connected);

        loop {
            let result = send_pp.back_writable(&mut session_metrics);
            if result == SessionResult::Upgrade {
                break;
            }

            if result != SessionResult::Continue {
                panic!("state machine error: result = {result:?}");
            }
        }
    }

    // Only connect to the middleware, and keep the connection open until the
    // backend has finished.
    fn start_client(addr: SocketAddr, barrier: Arc<Barrier>, end_barrier: Arc<Barrier>) {
        thread::spawn(move || {
            barrier.wait();

            let _stream = StdTcpStream::connect(addr).unwrap();

            end_barrier.wait();
        });
    }

    // Accept the connection from the middleware and read from the socket.
    // Check that a valid proxy protocol header is received.
    fn start_backend(
        addr: SocketAddr,
        barrier: Arc<Barrier>,
        end_barrier: Arc<Barrier>,
    ) -> JoinHandle<()> {
        let listener = StdTcpListener::bind(addr).expect("could not start backend");

        thread::spawn(move || {
            barrier.wait();

            let mut buf: [u8; 28] = [0; 28];
            let (mut conn, _) = listener
                .accept()
                .expect("could not accept connection from light middleware");
            println!("backend got a connection from the middleware");

            let mut index = 0usize;
            while index < buf.len() {
                match conn.read(&mut buf[index..]) {
                    // The buffer is never empty here, so 0 bytes means the
                    // middleware closed the connection: fail instead of
                    // spinning forever.
                    Ok(0) => {
                        end_barrier.wait();
                        panic!("connection closed after {index} bytes, before the full header");
                    }
                    Ok(sz) => {
                        println!("backend read {sz} bytes");
                        index += sz;
                    }
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock | ErrorKind::Interrupted => continue,
                        e => {
                            // Release the client thread before failing.
                            end_barrier.wait();
                            panic!("read error: {e:?}");
                        }
                    },
                }
            }

            match parse_v2_header(&buf) {
                Ok((_, _)) => println!("complete header received"),
                err => {
                    end_barrier.wait();
                    panic!("incorrect proxy protocol header received: {err:?}");
                }
            };

            end_barrier.wait();
        })
    }
}