sozu_lib/protocol/proxy_protocol/
send.rs1use std::{
2 cell::RefCell,
3 io::{ErrorKind, Write},
4 rc::Rc,
5};
6
7use mio::{net::TcpStream, Token};
8use rusty_ulid::Ulid;
9
10use crate::{
11 pool::Checkout,
12 protocol::{
13 pipe::{Pipe, WebSocketContext},
14 proxy_protocol::header::{Command, HeaderV2, ProxyProtocolHeader},
15 },
16 socket::SocketHandler,
17 sozu_command::ready::Ready,
18 tcp::TcpListener,
19 BackendConnectionStatus, Protocol, Readiness, SessionMetrics, SessionResult,
20};
21
22pub struct SendProxyProtocol<Front: SocketHandler> {
23 cursor_header: usize,
24 pub backend_readiness: Readiness,
25 pub backend_token: Option<Token>,
26 pub backend: Option<TcpStream>,
27 pub frontend_readiness: Readiness,
28 pub frontend_token: Token,
29 pub frontend: Front,
30 pub header: Option<Vec<u8>>,
31 pub request_id: Ulid,
32}
33
34impl<Front: SocketHandler> SendProxyProtocol<Front> {
35 pub fn new(
41 frontend: Front,
42 frontend_token: Token,
43 request_id: Ulid,
44 backend: Option<TcpStream>,
45 ) -> Self {
46 SendProxyProtocol {
47 header: None,
48 frontend,
49 request_id,
50 backend,
51 frontend_token,
52 backend_token: None,
53 frontend_readiness: Readiness {
54 interest: Ready::HUP | Ready::ERROR,
55 event: Ready::EMPTY,
56 },
57 backend_readiness: Readiness {
58 interest: Ready::HUP | Ready::ERROR,
59 event: Ready::EMPTY,
60 },
61 cursor_header: 0,
62 }
63 }
64
65 pub fn back_writable(&mut self, metrics: &mut SessionMetrics) -> SessionResult {
68 debug!("Trying to write proxy protocol header");
69
70 if self.header.is_none() {
72 if let Ok(local_addr) = self.front_socket().local_addr() {
73 if let Ok(frontend_addr) = self.front_socket().peer_addr() {
74 self.header = Some(
75 ProxyProtocolHeader::V2(HeaderV2::new(
76 Command::Proxy,
77 frontend_addr,
78 local_addr,
79 ))
80 .into_bytes(),
81 );
82 } else {
83 return SessionResult::Close;
84 }
85 };
86 }
87
88 if let Some(ref mut socket) = self.backend {
89 if let Some(ref mut header) = self.header {
90 loop {
91 match socket.write(&header[self.cursor_header..]) {
92 Ok(sz) => {
93 self.cursor_header += sz;
94 metrics.backend_bout += sz;
95
96 if self.cursor_header == header.len() {
97 debug!("Proxy protocol sent, upgrading");
98 return SessionResult::Upgrade;
99 }
100 }
101 Err(e) => match e.kind() {
102 ErrorKind::WouldBlock => {
103 self.backend_readiness.event.remove(Ready::WRITABLE);
104 return SessionResult::Continue;
105 }
106 e => {
107 incr!("proxy_protocol.errors");
108 debug!("send proxy protocol write error {:?}", e);
109 return SessionResult::Close;
110 }
111 },
112 }
113 }
114 }
115 }
116
117 error!("started Send proxy protocol with no header or backend socket");
118 SessionResult::Close
119 }
120
121 pub fn front_socket(&self) -> &TcpStream {
122 self.frontend.socket_ref()
123 }
124
125 pub fn front_socket_mut(&mut self) -> &mut TcpStream {
126 self.frontend.socket_mut()
127 }
128
129 pub fn back_socket(&self) -> Option<&TcpStream> {
130 self.backend.as_ref()
131 }
132
133 pub fn back_socket_mut(&mut self) -> Option<&mut TcpStream> {
134 self.backend.as_mut()
135 }
136
137 pub fn set_back_socket(&mut self, socket: TcpStream) {
138 self.backend = Some(socket);
139 }
140
141 pub fn back_token(&self) -> Option<Token> {
142 self.backend_token
143 }
144
145 pub fn set_back_token(&mut self, token: Token) {
146 self.backend_token = Some(token);
147 }
148
149 pub fn set_back_connected(&mut self, status: BackendConnectionStatus) {
150 if status == BackendConnectionStatus::Connected {
151 self.backend_readiness.interest.insert(Ready::WRITABLE);
152 }
153 }
154
155 pub fn into_pipe(
156 mut self,
157 front_buf: Checkout,
158 back_buf: Checkout,
159 listener: Rc<RefCell<TcpListener>>,
160 ) -> Pipe<Front, TcpListener> {
161 let backend_socket = self.backend.take().unwrap();
162 let addr = self.front_socket().peer_addr().ok();
163
164 let mut pipe = Pipe::new(
165 back_buf,
166 None,
167 Some(backend_socket),
168 None,
169 None,
170 None,
171 None,
172 front_buf,
173 self.frontend_token,
174 self.frontend,
175 listener,
176 Protocol::TCP,
177 self.request_id,
178 addr,
179 WebSocketContext::Tcp,
180 );
181
182 pipe.frontend_readiness = self.frontend_readiness;
183 pipe.backend_readiness = self.backend_readiness;
184
185 pipe.frontend_readiness.interest.insert(Ready::READABLE);
186 pipe.backend_readiness.interest.insert(Ready::READABLE);
187
188 if let Some(back_token) = self.backend_token {
189 pipe.set_back_token(back_token);
190 }
191
192 pipe
193 }
194}
195
#[cfg(test)]
mod send_test {
    use std::{
        io::Read,
        net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
        os::unix::io::{FromRawFd, IntoRawFd},
        sync::{Arc, Barrier},
        thread::{self, JoinHandle},
    };

    use mio::net::{TcpListener, TcpStream};
    use rusty_ulid::Ulid;

    use super::{
        super::parser::parse_v2_header, BackendConnectionStatus, ErrorKind, SendProxyProtocol,
        SessionMetrics, SessionResult, Token,
    };

    /// Length of a PROXY protocol v2 header for an IPv4 connection: the 16-byte
    /// fixed part followed by a 12-byte address block.
    const HEADER_V2_IPV4_LEN: usize = 28;

    /// End-to-end check: a client connects to a "middleware" that uses
    /// `SendProxyProtocol` to write a v2 header to a backend, and the backend
    /// verifies that the received bytes parse as a v2 header.
    #[test]
    fn it_should_send_a_proxy_protocol_header_to_the_upstream_backend() {
        setup_test_logger!();
        let addr_client: SocketAddr = "127.0.0.1:6666".parse().expect("parse address error");
        let addr_backend: SocketAddr = "127.0.0.1:2001".parse().expect("parse address error");
        // Three parties: the client thread, the backend thread and the middleware (this thread).
        let barrier = Arc::new(Barrier::new(3));
        // Keeps the client connection open until the backend has finished its checks.
        let end_barrier = Arc::new(Barrier::new(2));

        start_client(addr_client, barrier.clone(), end_barrier.clone());
        let backend = start_backend(addr_backend, barrier.clone(), end_barrier);
        start_middleware(addr_client, addr_backend, barrier);

        backend
            .join()
            .expect("Couldn't join on the associated backend");
    }

    /// Accepts the client connection, connects to the backend and drives
    /// `SendProxyProtocol::back_writable` until it asks for an upgrade.
    fn start_middleware(addr_client: SocketAddr, addr_backend: SocketAddr, barrier: Arc<Barrier>) {
        let listener = TcpListener::bind(addr_client).expect("could not accept session connection");

        let client_stream;
        barrier.wait();

        // The mio listener does not block: spin until the client connection arrives.
        loop {
            if let Ok((stream, _addr)) = listener.accept() {
                client_stream = stream;
                break;
            }
        }

        let backend_stream =
            StdTcpStream::connect(addr_backend).expect("could not connect to the backend");
        let fd = backend_stream.into_raw_fd();
        // SAFETY: `into_raw_fd` released ownership of a valid, connected TCP socket
        // descriptor and nothing else uses or closes `fd`, so the mio stream becomes
        // its sole owner.
        let backend_stream = unsafe { TcpStream::from_raw_fd(fd) };

        let mut send_pp = SendProxyProtocol::new(
            client_stream,
            Token(0),
            Ulid::generate(),
            Some(backend_stream),
        );
        let mut session_metrics = SessionMetrics::new(None);

        send_pp.set_back_connected(BackendConnectionStatus::Connected);

        loop {
            let result = send_pp.back_writable(&mut session_metrics);
            if result == SessionResult::Upgrade {
                break;
            }

            if result != SessionResult::Continue {
                panic!("state machine error: result = {result:?}");
            }
        }
    }

    /// Spawns a client that connects to the middleware and keeps the connection
    /// open until `end_barrier` is reached.
    fn start_client(addr: SocketAddr, barrier: Arc<Barrier>, end_barrier: Arc<Barrier>) {
        thread::spawn(move || {
            barrier.wait();

            let _stream = StdTcpStream::connect(addr).unwrap();

            end_barrier.wait();
        });
    }

    /// Spawns a backend that accepts one connection, reads a full IPv4 v2 header
    /// and checks that it parses. Every failure path waits on `end_barrier` first
    /// so the client thread is released before the panic.
    fn start_backend(
        addr: SocketAddr,
        barrier: Arc<Barrier>,
        end_barrier: Arc<Barrier>,
    ) -> JoinHandle<()> {
        let listener = StdTcpListener::bind(addr).expect("could not start backend");

        thread::spawn(move || {
            barrier.wait();

            let mut buf = [0u8; HEADER_V2_IPV4_LEN];
            let (mut conn, _) = listener
                .accept()
                .expect("could not accept connection from light middleware");
            println!("backend got a connection from the middleware");

            let mut index = 0usize;
            while index < HEADER_V2_IPV4_LEN {
                match conn.read(&mut buf[index..]) {
                    // EOF before the full header: fail instead of spinning forever.
                    Ok(0) => {
                        end_barrier.wait();
                        panic!("connection closed after {index} bytes, before the full header");
                    }
                    Ok(sz) => {
                        println!("backend read {sz} bytes");
                        index += sz;
                    }
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock | ErrorKind::Interrupted => continue,
                        e => {
                            end_barrier.wait();
                            panic!("read error: {e:?}");
                        }
                    },
                }
            }

            match parse_v2_header(&buf) {
                Ok((_, _)) => println!("complete header received"),
                err => {
                    end_barrier.wait();
                    panic!("incorrect proxy protocol header received: {err:?}");
                }
            };

            end_barrier.wait();
        })
    }
}