sozu_lib/protocol/proxy_protocol/
send.rs1use std::{
9 cell::RefCell,
10 io::{ErrorKind, Write},
11 rc::Rc,
12};
13
14use mio::{Token, net::TcpStream};
15use rusty_ulid::Ulid;
16use sozu_command::logging::ansi_palette;
17
18use crate::metrics::names;
19use crate::{
20 BackendConnectionStatus, Protocol, Readiness, SessionMetrics, SessionResult,
21 pool::Checkout,
22 protocol::{
23 pipe::{Pipe, WebSocketContext},
24 proxy_protocol::header::{Command, HeaderV2, ProxyProtocolHeader},
25 },
26 socket::SocketHandler,
27 sozu_command::ready::Ready,
28 tcp::TcpListener,
29};
30
// Module-level log prefix: the coloured `PROXY-SEND` tag without any
// per-session details.
// This macro may have no caller in this module; the allowance silences the
// `unused_macros` lint in that case.
#[allow(unused_macros)]
macro_rules! log_module_context {
    () => {{
        let palette = ansi_palette();
        let (open, reset) = (palette.0, palette.1);
        format!("{open}PROXY-SEND{reset}\t >>>")
    }};
}
42
// Session-level log prefix: the `PROXY-SEND` tag followed by the frontend
// token, the backend token (or `<none>`), and both readiness states.
//
// Fields are read one by one (never `&$self` as a whole) so the macro can be
// used while another field of the session is mutably borrowed.
macro_rules! log_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        let frontend = $self.frontend_token.0;
        let backend = match $self.backend_token {
            Some(token) => token.0.to_string(),
            None => String::from("<none>"),
        };
        let front_readiness = &$self.frontend_readiness;
        let back_readiness = &$self.backend_readiness;
        format!(
            "{open}PROXY-SEND{reset}\t{grey}Session{reset}({gray}frontend{reset}={white}{frontend}{reset}, {gray}backend{reset}={white}{backend}{reset}, {gray}front_readiness{reset}={white}{front_readiness}{reset}, {gray}back_readiness{reset}={white}{back_readiness}{reset})\t >>>"
        )
    }};
}
64
65pub struct SendProxyProtocol<Front: SocketHandler> {
66 cursor_header: usize,
67 pub backend_readiness: Readiness,
68 pub backend_token: Option<Token>,
69 pub backend: Option<TcpStream>,
70 pub frontend_readiness: Readiness,
71 pub frontend_token: Token,
72 pub frontend: Front,
73 pub header: Option<Vec<u8>>,
74 pub request_id: Ulid,
75}
76
77impl<Front: SocketHandler> SendProxyProtocol<Front> {
78 pub fn new(
84 frontend: Front,
85 frontend_token: Token,
86 request_id: Ulid,
87 backend: Option<TcpStream>,
88 ) -> Self {
89 SendProxyProtocol {
90 header: None,
91 frontend,
92 request_id,
93 backend,
94 frontend_token,
95 backend_token: None,
96 frontend_readiness: Readiness {
97 interest: Ready::HUP | Ready::ERROR,
98 event: Ready::EMPTY,
99 },
100 backend_readiness: Readiness {
101 interest: Ready::HUP | Ready::ERROR,
102 event: Ready::EMPTY,
103 },
104 cursor_header: 0,
105 }
106 }
107
108 pub fn back_writable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
111 debug!(
112 "{} trying to write proxy protocol header",
113 log_context!(self)
114 );
115
116 if self.header.is_none() {
118 if let Ok(local_addr) = self.front_socket().local_addr() {
119 if let Ok(frontend_addr) = self.front_socket().peer_addr() {
120 let v2 = HeaderV2::new(Command::Proxy, frontend_addr, local_addr);
121 let declared_len = v2.len();
122 let serialized = ProxyProtocolHeader::V2(v2).into_bytes();
123 debug_assert_eq!(
128 serialized.len(),
129 declared_len,
130 "serialized send header length must match HeaderV2::len()"
131 );
132 debug_assert!(
133 serialized.len() >= 16,
134 "a v2 send header is at least its 16-byte fixed prefix"
135 );
136 self.header = Some(serialized);
137 } else {
138 return SessionResult::Close;
139 }
140 };
141 }
142
143 if let Some(ref mut socket) = self.backend {
144 if let Some(ref mut header) = self.header {
145 loop {
146 debug_assert!(
149 self.cursor_header <= header.len(),
150 "send cursor must stay within the serialized header"
151 );
152 let remaining = header.len() - self.cursor_header;
153 match socket.write(&header[self.cursor_header..]) {
154 Ok(sz) => {
155 debug_assert!(
156 sz <= remaining,
157 "socket.write cannot send more than the unsent header tail"
158 );
159 let cursor_before = self.cursor_header;
160 self.cursor_header += sz;
161 debug_assert_eq!(
164 self.cursor_header,
165 cursor_before + sz,
166 "send cursor advances by exactly the bytes written"
167 );
168 debug_assert!(
169 self.cursor_header <= header.len(),
170 "send cursor must not pass the header length"
171 );
172 count!(names::backend::BACK_BYTES_OUT, sz as i64);
173 metrics.backend_bout += sz;
174
175 if self.cursor_header == header.len() {
176 debug!("{} proxy protocol sent, upgrading", log_context!(self));
177 return SessionResult::Upgrade;
178 }
179 }
180 Err(e) => match e.kind() {
181 ErrorKind::WouldBlock => {
182 self.backend_readiness.event.remove(Ready::WRITABLE);
183 return SessionResult::Continue;
184 }
185 e => {
186 incr!(names::proxy_protocol::ERRORS);
187 debug!("{} write error: {:?}", log_context!(self), e);
188 return SessionResult::Close;
189 }
190 },
191 }
192 }
193 }
194 }
195
196 error!(
197 "{} started send proxy protocol with no header or backend socket",
198 log_context!(self)
199 );
200 SessionResult::Close
201 }
202
203 pub fn front_socket(&self) -> &TcpStream {
204 self.frontend.socket_ref()
205 }
206
207 pub fn front_socket_mut(&mut self) -> &mut TcpStream {
208 self.frontend.socket_mut()
209 }
210
211 pub fn back_socket(&self) -> Option<&TcpStream> {
212 self.backend.as_ref()
213 }
214
215 pub fn back_socket_mut(&mut self) -> Option<&mut TcpStream> {
216 self.backend.as_mut()
217 }
218
219 pub fn set_back_socket(&mut self, socket: TcpStream) {
220 self.backend = Some(socket);
221 }
222
223 pub fn back_token(&self) -> Option<Token> {
224 self.backend_token
225 }
226
227 pub fn set_back_token(&mut self, token: Token) {
228 self.backend_token = Some(token);
229 }
230
231 pub fn set_back_connected(&mut self, status: BackendConnectionStatus) {
232 if status == BackendConnectionStatus::Connected {
233 self.backend_readiness.interest.insert(Ready::WRITABLE);
234 }
235 }
236
237 pub fn into_pipe(
238 mut self,
239 front_buf: Checkout,
240 back_buf: Checkout,
241 listener: Rc<RefCell<TcpListener>>,
242 ) -> Pipe<Front, TcpListener> {
243 let backend_socket = self.backend.take().unwrap();
244 let addr = self.front_socket().peer_addr().ok();
245
246 let mut pipe = Pipe::new(
247 back_buf,
248 None,
249 Some(backend_socket),
250 None,
251 None,
252 None,
253 None,
254 front_buf,
255 self.frontend_token,
256 self.frontend,
257 listener,
258 Protocol::TCP,
259 self.request_id,
260 self.request_id,
261 addr,
262 WebSocketContext::Tcp,
263 );
264
265 pipe.frontend_readiness = self.frontend_readiness;
266 pipe.backend_readiness = self.backend_readiness;
267
268 pipe.frontend_readiness.interest.insert(Ready::READABLE);
269 pipe.backend_readiness.interest.insert(Ready::READABLE);
270
271 if let Some(back_token) = self.backend_token {
272 pipe.set_back_token(back_token);
273 }
274
275 pipe
276 }
277}
278
#[cfg(test)]
mod send_test {
    use std::{
        io::Read,
        net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
        sync::{Arc, Barrier},
        thread::{self, JoinHandle},
    };

    use mio::net::{TcpListener, TcpStream};
    use rusty_ulid::Ulid;

    use super::{
        super::parser::parse_v2_header, BackendConnectionStatus, SendProxyProtocol,
        SessionMetrics, SessionResult, Token,
    };

    /// Length of a PROXY protocol v2 header carrying two IPv4 endpoints:
    /// the 16-byte fixed prefix plus 12 bytes of address data.
    const IPV4_V2_HEADER_LEN: usize = 28;

    #[test]
    fn it_should_send_a_proxy_protocol_header_to_the_upstream_backend() {
        setup_test_logger!();
        // Bind both listeners on ephemeral ports before spawning anything, so
        // the test cannot collide with another process or a parallel test.
        let front_listener =
            TcpListener::bind("127.0.0.1:0".parse().expect("parse address error"))
                .expect("could not accept session connection");
        let addr_client = front_listener
            .local_addr()
            .expect("middleware listener has no local address");
        let backend_listener =
            StdTcpListener::bind("127.0.0.1:0").expect("could not start backend");
        let addr_backend = backend_listener
            .local_addr()
            .expect("backend listener has no local address");

        let barrier = Arc::new(Barrier::new(3));
        let end_barrier = Arc::new(Barrier::new(2));

        start_client(addr_client, barrier.clone(), end_barrier.clone());
        let backend = start_backend(backend_listener, barrier.clone(), end_barrier);
        start_middleware(front_listener, addr_backend, barrier);

        backend
            .join()
            .expect("Couldn't join on the associated backend");
    }

    /// Plays the proxy: accepts the client on `listener`, connects to the
    /// backend and drives `back_writable` until it reports `Upgrade`.
    fn start_middleware(listener: TcpListener, addr_backend: SocketAddr, barrier: Arc<Barrier>) {
        barrier.wait();

        // The mio listener is non-blocking: poll until the client shows up.
        let client_stream = loop {
            if let Ok((stream, _addr)) = listener.accept() {
                break stream;
            }
        };

        let backend_stream =
            StdTcpStream::connect(addr_backend).expect("could not connect to the backend");
        // Wraps the already-connected std stream without changing its
        // blocking mode.
        let backend_stream = TcpStream::from_std(backend_stream);

        let mut send_pp = SendProxyProtocol::new(
            client_stream,
            Token(0),
            Ulid::generate(),
            Some(backend_stream),
        );
        let mut session_metrics = SessionMetrics::new(None);

        send_pp.set_back_connected(BackendConnectionStatus::Connected);

        loop {
            let result = send_pp.back_writable(&mut session_metrics);
            if result == SessionResult::Upgrade {
                break;
            }

            if result != SessionResult::Continue {
                panic!("state machine error: result = {result:?}");
            }
        }
    }

    /// Connects a client to the middleware and keeps the connection open
    /// until the backend has checked the header.
    fn start_client(addr: SocketAddr, barrier: Arc<Barrier>, end_barrier: Arc<Barrier>) {
        thread::spawn(move || {
            barrier.wait();

            let _stream = StdTcpStream::connect(addr).unwrap();

            end_barrier.wait();
        });
    }

    /// Accepts the middleware's connection on `listener`, reads one IPv4 v2
    /// header and checks that it parses.
    fn start_backend(
        listener: StdTcpListener,
        barrier: Arc<Barrier>,
        end_barrier: Arc<Barrier>,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            barrier.wait();

            let mut buf = [0u8; IPV4_V2_HEADER_LEN];
            let (mut conn, _) = listener
                .accept()
                .expect("could not accept connection from light middleware");
            println!("backend got a connection from the middleware");

            // `read_exact` retries short reads and `Interrupted`, and reports a
            // premature EOF as an error instead of looping forever on `Ok(0)`.
            if let Err(e) = conn.read_exact(&mut buf) {
                end_barrier.wait();
                panic!("read error: {e:?}");
            }
            println!("backend read {IPV4_V2_HEADER_LEN} bytes");

            match parse_v2_header(&buf) {
                Ok((_, _)) => println!("complete header received"),
                err => {
                    end_barrier.wait();
                    panic!("incorrect proxy protocol header received: {err:?}");
                }
            };

            end_barrier.wait();
        })
    }
}