1use std::{
2 io::{ErrorKind, Read, Write},
3 net::SocketAddr,
4};
5
6use mio::net::{TcpListener, TcpStream};
7use rustls::{ProtocolVersion, ServerConnection};
8use socket2::{Domain, Protocol, Socket, Type};
9use sozu_command::config::MAX_LOOP_ITERATIONS;
10
11#[derive(thiserror::Error, Debug)]
12pub enum ServerBindError {
13 #[error("could not set bind to socket: {0}")]
14 BindError(std::io::Error),
15 #[error("could not listen on socket: {0}")]
16 Listen(std::io::Error),
17 #[error("could not set socket to nonblocking: {0}")]
18 SetNonBlocking(std::io::Error),
19 #[error("could not set reuse address: {0}")]
20 SetReuseAddress(std::io::Error),
21 #[error("could not set reuse address: {0}")]
22 SetReusePort(std::io::Error),
23 #[error("Could not create socket: {0}")]
24 SocketCreationError(std::io::Error),
25 #[error("Invalid socket address '{address}': {error}")]
26 InvalidSocketAddress { address: String, error: String },
27}
28
/// State a socket was left in after a read or write; returned next to the
/// number of bytes transferred by the [`SocketHandler`] I/O methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// The call ended without blocking, closing or failing (e.g. the whole
    /// buffer was transferred).
    Continue,
    /// The connection was closed, reset, aborted or hit a broken pipe.
    Closed,
    /// The socket reported `WouldBlock`; no more progress is possible right now.
    WouldBlock,
    /// An unexpected I/O or TLS error occurred.
    Error,
}
36
/// Transport used by a connection: plain TCP or a negotiated SSL/TLS version.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Unencrypted TCP (TLS sessions also report this before a version is known).
    Tcp,
    /// SSL 2.0.
    Ssl2,
    /// SSL 3.0.
    Ssl3,
    /// TLS 1.0.
    Tls1_0,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
47
48pub trait SocketHandler {
49 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
50 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
51 fn socket_write_vectored(&mut self, _buf: &[std::io::IoSlice]) -> (usize, SocketResult);
52 fn socket_wants_write(&self) -> bool {
53 false
54 }
55 fn socket_close(&mut self) {}
56 fn socket_ref(&self) -> &TcpStream;
57 fn socket_mut(&mut self) -> &mut TcpStream;
58 fn protocol(&self) -> TransportProtocol;
59 fn read_error(&self);
60 fn write_error(&self);
61}
62
63impl SocketHandler for TcpStream {
64 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
65 let mut size = 0usize;
66 let mut counter = 0;
67 loop {
68 counter += 1;
69 if counter > MAX_LOOP_ITERATIONS {
70 error!("MAX_LOOP_ITERATION reached in TcpStream::socket_read");
71 incr!("socket.read.infinite_loop.error");
72 }
73 if size == buf.len() {
74 return (size, SocketResult::Continue);
75 }
76 match self.read(&mut buf[size..]) {
77 Ok(0) => return (size, SocketResult::Closed),
78 Ok(sz) => size += sz,
79 Err(e) => match e.kind() {
80 ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
81 ErrorKind::ConnectionReset
82 | ErrorKind::ConnectionAborted
83 | ErrorKind::BrokenPipe => return (size, SocketResult::Closed),
84 _ => {
85 error!("SOCKET\tsocket_read error={:?}", e);
86 return (size, SocketResult::Error);
87 }
88 },
89 }
90 }
91 }
92
93 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
94 let mut size = 0usize;
95 let mut counter = 0;
96 loop {
97 counter += 1;
98 if counter > MAX_LOOP_ITERATIONS {
99 error!("MAX_LOOP_ITERATION reached in TcpStream::socket_write");
100 incr!("socket.write.infinite_loop.error");
101 }
102 if size == buf.len() {
103 return (size, SocketResult::Continue);
104 }
105 match self.write(&buf[size..]) {
106 Ok(0) => return (size, SocketResult::Continue),
107 Ok(sz) => size += sz,
108 Err(e) => match e.kind() {
109 ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
110 ErrorKind::ConnectionReset
111 | ErrorKind::ConnectionAborted
112 | ErrorKind::BrokenPipe
113 | ErrorKind::ConnectionRefused => {
114 incr!("tcp.write.error");
115 return (size, SocketResult::Closed);
116 }
117 _ => {
118 error!("SOCKET\tsocket_write error={:?}", e);
120 incr!("tcp.write.error");
121 return (size, SocketResult::Error);
122 }
123 },
124 }
125 }
126 }
127
128 fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
129 match self.write_vectored(bufs) {
130 Ok(sz) => (sz, SocketResult::Continue),
131 Err(e) => match e.kind() {
132 ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
133 ErrorKind::ConnectionReset
134 | ErrorKind::ConnectionAborted
135 | ErrorKind::BrokenPipe
136 | ErrorKind::ConnectionRefused => {
137 incr!("tcp.write.error");
138 (0, SocketResult::Closed)
139 }
140 _ => {
141 error!("SOCKET\tsocket_write error={:?}", e);
143 incr!("tcp.write.error");
144 (0, SocketResult::Error)
145 }
146 },
147 }
148 }
149
150 fn socket_ref(&self) -> &TcpStream {
151 self
152 }
153
154 fn socket_mut(&mut self) -> &mut TcpStream {
155 self
156 }
157
158 fn protocol(&self) -> TransportProtocol {
159 TransportProtocol::Tcp
160 }
161
162 fn read_error(&self) {
163 incr!("tcp.read.error");
164 }
165
166 fn write_error(&self) {
167 incr!("tcp.write.error");
168 }
169}
170
171pub struct FrontRustls {
172 pub stream: TcpStream,
173 pub session: ServerConnection,
174}
175
176impl SocketHandler for FrontRustls {
177 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
178 let mut size = 0usize;
179 let mut can_read = true;
180 let mut is_error = false;
181 let mut is_closed = false;
182
183 let mut counter = 0;
184 loop {
185 counter += 1;
186 if counter > MAX_LOOP_ITERATIONS {
187 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_read");
188 incr!("rustls.read.infinite_loop.error");
189 }
190
191 if size == buf.len() {
192 break;
193 }
194
195 if !can_read | is_error | is_closed {
196 break;
197 }
198
199 match self.session.read_tls(&mut self.stream) {
200 Ok(0) => {
201 can_read = false;
202 is_closed = true;
203 }
204 Ok(_sz) => {}
205 Err(e) => match e.kind() {
206 ErrorKind::WouldBlock => {
207 can_read = false;
208 }
209 ErrorKind::ConnectionReset
210 | ErrorKind::ConnectionAborted
211 | ErrorKind::BrokenPipe => {
212 is_closed = true;
213 }
214 ErrorKind::Other => {
216 warn!(
217 "rustls buffer is full, we will consume it, before processing new incoming packets, to mitigate this issue, you could try to increase the buffer size, {:?}",
218 e
219 );
220 }
221 _ => {
222 error!("could not read TLS stream from socket: {:?}", e);
223 is_error = true;
224 break;
225 }
226 },
227 }
228
229 if let Err(e) = self.session.process_new_packets() {
230 error!("could not process read TLS packets: {:?}", e);
231 is_error = true;
232 break;
233 }
234
235 while !self.session.wants_read() {
236 match self.session.reader().read(&mut buf[size..]) {
237 Ok(0) => break,
238 Ok(sz) => {
239 size += sz;
240 }
241 Err(e) => match e.kind() {
242 ErrorKind::WouldBlock => {
243 break;
244 }
245 ErrorKind::ConnectionReset
246 | ErrorKind::ConnectionAborted
247 | ErrorKind::BrokenPipe => {
248 is_closed = true;
249 break;
250 }
251 _ => {
252 error!("could not read data from TLS stream: {:?}", e);
253 is_error = true;
254 break;
255 }
256 },
257 }
258 }
259 }
260
261 if is_error {
262 (size, SocketResult::Error)
263 } else if is_closed {
264 (size, SocketResult::Closed)
265 } else if !can_read {
266 (size, SocketResult::WouldBlock)
267 } else {
268 (size, SocketResult::Continue)
269 }
270 }
271
272 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
273 let mut buffered_size = 0usize;
274 let mut can_write = true;
275 let mut is_error = false;
276 let mut is_closed = false;
277
278 let mut counter = 0;
279 loop {
280 counter += 1;
281 if counter > MAX_LOOP_ITERATIONS {
282 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write");
283 incr!("rustls.write.infinite_loop.error");
284 }
285 if buffered_size == buf.len() {
286 break;
287 }
288
289 if !can_write | is_error | is_closed {
290 break;
291 }
292
293 match self.session.writer().write(&buf[buffered_size..]) {
294 Ok(0) => {} Ok(sz) => {
296 buffered_size += sz;
297 }
298 Err(e) => match e.kind() {
299 ErrorKind::WouldBlock => {
300 }
303 ErrorKind::ConnectionReset
304 | ErrorKind::ConnectionAborted
305 | ErrorKind::BrokenPipe => {
306 incr!("rustls.write.error");
308 is_closed = true;
309 break;
310 }
311 _ => {
312 error!("could not write data to TLS stream: {:?}", e);
313 incr!("rustls.write.error");
314 is_error = true;
315 break;
316 }
317 },
318 }
319
320 loop {
321 match self.session.write_tls(&mut self.stream) {
322 Ok(0) => {
323 break;
325 }
326 Ok(_sz) => {}
327 Err(e) => match e.kind() {
328 ErrorKind::WouldBlock => {
329 can_write = false;
330 break;
331 }
332 ErrorKind::ConnectionReset
333 | ErrorKind::ConnectionAborted
334 | ErrorKind::BrokenPipe => {
335 incr!("rustls.write.error");
336 is_closed = true;
337 break;
338 }
339 _ => {
340 error!("could not write TLS stream to socket: {:?}", e);
341 incr!("rustls.write.error");
342 is_error = true;
343 break;
344 }
345 },
346 }
347 }
348 }
349
350 if is_error {
351 (buffered_size, SocketResult::Error)
352 } else if is_closed {
353 (buffered_size, SocketResult::Closed)
354 } else if !can_write {
355 (buffered_size, SocketResult::WouldBlock)
356 } else {
357 (buffered_size, SocketResult::Continue)
358 }
359 }
360
361 fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
362 let mut buffered_size = 0usize;
363 let mut can_write = true;
364 let mut is_error = false;
365 let mut is_closed = false;
366
367 match self.session.writer().write_vectored(bufs) {
368 Ok(0) => {} Ok(sz) => {
370 buffered_size += sz;
371 }
372 Err(e) => match e.kind() {
373 ErrorKind::WouldBlock => {
374 }
377 ErrorKind::ConnectionReset
378 | ErrorKind::ConnectionAborted
379 | ErrorKind::BrokenPipe => {
380 incr!("rustls.write.error");
382 is_closed = true;
383 }
384 _ => {
385 error!("could not write data to TLS stream: {:?}", e);
386 incr!("rustls.write.error");
387 is_error = true;
388 }
389 },
390 }
391
392 let mut counter = 0;
393 loop {
394 counter += 1;
395 if counter > MAX_LOOP_ITERATIONS {
396 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored");
397 incr!("rustls.write.infinite_loop.error");
398 }
399 match self.session.write_tls(&mut self.stream) {
400 Ok(0) => {
401 break;
402 }
403 Ok(_sz) => {}
404 Err(e) => match e.kind() {
405 ErrorKind::WouldBlock => {
406 can_write = false;
407 break;
408 }
409 ErrorKind::ConnectionReset
410 | ErrorKind::ConnectionAborted
411 | ErrorKind::BrokenPipe => {
412 incr!("rustls.write.error");
413 is_closed = true;
414 break;
415 }
416 _ => {
417 error!("could not write TLS stream to socket: {:?}", e);
418 incr!("rustls.write.error");
419 is_error = true;
420 break;
421 }
422 },
423 }
424 }
425
426 if is_error {
427 (buffered_size, SocketResult::Error)
428 } else if is_closed {
429 (buffered_size, SocketResult::Closed)
430 } else if !can_write {
431 (buffered_size, SocketResult::WouldBlock)
432 } else {
433 (buffered_size, SocketResult::Continue)
434 }
435 }
436
437 fn socket_close(&mut self) {
438 self.session.send_close_notify();
439 }
440
441 fn socket_wants_write(&self) -> bool {
442 self.session.wants_write()
443 }
444
445 fn socket_ref(&self) -> &TcpStream {
446 &self.stream
447 }
448
449 fn socket_mut(&mut self) -> &mut TcpStream {
450 &mut self.stream
451 }
452
453 fn protocol(&self) -> TransportProtocol {
454 self.session
455 .protocol_version()
456 .map(|version| match version {
457 ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
458 ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
459 ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
460 ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
461 ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
462 ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
463 _ => TransportProtocol::Tls1_3,
464 })
465 .unwrap_or(TransportProtocol::Tcp)
466 }
467
468 fn read_error(&self) {
469 incr!("rustls.read.error");
470 }
471
472 fn write_error(&self) {
473 incr!("rustls.write.error");
474 }
475}
476
477pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
478 let sock = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
479 .map_err(ServerBindError::SocketCreationError)?;
480
481 if cfg!(unix) {
483 sock.set_reuse_address(true)
484 .map_err(ServerBindError::SetReuseAddress)?;
485 }
486
487 sock.set_reuse_port(true)
488 .map_err(ServerBindError::SetReusePort)?;
489
490 sock.bind(&addr.into())
491 .map_err(ServerBindError::BindError)?;
492
493 sock.set_nonblocking(true)
494 .map_err(ServerBindError::SetNonBlocking)?;
495
496 sock.listen(1024).map_err(ServerBindError::Listen)?;
499
500 Ok(TcpListener::from_std(sock.into()))
501}
502
503pub mod stats {
505 use std::{os::fd::AsRawFd, time::Duration};
506
507 use internal::{OPT_LEVEL, OPT_NAME, TcpInfo};
508
509 pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
511 socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(info.rtt() as u64))
512 }
513
514 #[cfg(unix)]
515 pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
516 let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
517 let mut len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
518 let status = unsafe {
519 libc::getsockopt(
520 fd,
521 OPT_LEVEL,
522 OPT_NAME,
523 &mut tcp_info as *mut _ as *mut _,
524 &mut len,
525 )
526 };
527 if status != 0 { None } else { Some(tcp_info) }
528 }
529 #[cfg(not(unix))]
530 pub fn socketinfo(fd: libc::c_int) -> Option<TcpInfo> {
531 None
532 }
533
534 #[cfg(unix)]
535 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
536 mod internal {
537 pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
538 pub const OPT_NAME: libc::c_int = libc::TCP_INFO;
539
540 #[derive(Clone, Debug)]
541 #[repr(C)]
542 pub struct TcpInfo {
543 tcpi_state: u8,
545 tcpi_ca_state: u8,
546 tcpi_retransmits: u8,
547 tcpi_probes: u8,
548 tcpi_backoff: u8,
549 tcpi_options: u8,
550 tcpi_snd_rcv_wscale: u8, tcpi_rto: u32,
553 tcpi_ato: u32,
554 tcpi_snd_mss: u32,
555 tcpi_rcv_mss: u32,
556
557 tcpi_unacked: u32,
558 tcpi_sacked: u32,
559 tcpi_lost: u32,
560 tcpi_retrans: u32,
561 tcpi_fackets: u32,
562
563 tcpi_last_data_sent: u32,
565 tcpi_last_ack_sent: u32, tcpi_last_data_recv: u32,
567 tcpi_last_ack_recv: u32,
568
569 tcpi_pmtu: u32,
571 tcpi_rcv_ssthresh: u32,
572 tcpi_rtt: u32,
573 tcpi_rttvar: u32,
574 tcpi_snd_ssthresh: u32,
575 tcpi_snd_cwnd: u32,
576 tcpi_advmss: u32,
577 tcpi_reordering: u32,
578 }
579 impl TcpInfo {
580 pub fn rtt(&self) -> u32 {
581 self.tcpi_rtt
582 }
583 }
584 }
585
586 #[cfg(unix)]
587 #[cfg(any(target_os = "macos", target_os = "ios"))]
588 mod internal {
589 pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
590 pub const OPT_NAME: libc::c_int = 0x106;
591
592 #[derive(Clone, Debug)]
593 #[repr(C)]
594 pub struct TcpInfo {
595 tcpi_state: u8,
596 tcpi_snd_wscale: u8,
597 tcpi_rcv_wscale: u8,
598 __pad1: u8,
599 tcpi_options: u32,
600 tcpi_flags: u32,
601 tcpi_rto: u32,
602 tcpi_maxseg: u32,
603 tcpi_snd_ssthresh: u32,
604 tcpi_snd_cwnd: u32,
605 tcpi_snd_wnd: u32,
606 tcpi_snd_sbbytes: u32,
607 tcpi_rcv_wnd: u32,
608 tcpi_rttcur: u32,
609 tcpi_srtt: u32,
610 tcpi_rttvar: u32,
611 tcpi_tfo: u32,
612 tcpi_txpackets: u64,
613 tcpi_txbytes: u64,
614 tcpi_txretransmitbytes: u64,
615 tcpi_rxpackets: u64,
616 tcpi_rxbytes: u64,
617 tcpi_rxoutoforderbytes: u64,
618 tcpi_txretransmitpackets: u64,
619 }
620 impl TcpInfo {
621 pub fn rtt(&self) -> u32 {
622 self.tcpi_srtt * 1000
624 }
625 }
626 }
627
628 #[cfg(not(unix))]
629 #[derive(Clone, Debug)]
630 struct TcpInfo {}
631
632 #[test]
633 #[serial_test::serial]
634 fn test_rtt() {
635 let sock = std::net::TcpStream::connect("google.com:80").unwrap();
636 let fd = sock.as_raw_fd();
637 let info = socket_info(fd);
638 assert!(info.is_some());
639 println!("{:#?}", info);
640 println!(
641 "rtt: {}",
642 sozu_command::logging::LogDuration(socket_rtt(&sock))
643 );
644 }
645}