1use std::{
2 io::{ErrorKind, Read, Write},
3 net::SocketAddr,
4};
5
6use mio::net::{TcpListener, TcpStream};
7use rustls::{ProtocolVersion, ServerConnection};
8use socket2::{Domain, Protocol, Socket, Type};
9use sozu_command::config::MAX_LOOP_ITERATIONS;
10
11#[derive(thiserror::Error, Debug)]
12pub enum ServerBindError {
13 #[error("could not set bind to socket: {0}")]
14 BindError(std::io::Error),
15 #[error("could not listen on socket: {0}")]
16 Listen(std::io::Error),
17 #[error("could not set socket to nonblocking: {0}")]
18 SetNonBlocking(std::io::Error),
19 #[error("could not set reuse address: {0}")]
20 SetReuseAddress(std::io::Error),
21 #[error("could not set reuse address: {0}")]
22 SetReusePort(std::io::Error),
23 #[error("Could not create socket: {0}")]
24 SocketCreationError(std::io::Error),
25 #[error("Invalid socket address '{address}': {error}")]
26 InvalidSocketAddress { address: String, error: String },
27}
28
/// Outcome of a socket I/O attempt, returned together with the number of bytes
/// transferred by the `SocketHandler` methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// The operation stopped without hitting a blocking, closing or error condition.
    Continue,
    /// The peer closed the connection, or it was reset, aborted or broken.
    Closed,
    /// The socket is not ready; the operation should be retried later.
    WouldBlock,
    /// Any other I/O or TLS failure.
    Error,
}
36
/// Transport-level protocol spoken on a socket: plain TCP, or the SSL/TLS
/// version negotiated on top of it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Plain TCP (also reported for a TLS session with no negotiated version yet).
    Tcp,
    /// SSL 2.0.
    Ssl2,
    /// SSL 3.0.
    Ssl3,
    /// TLS 1.0.
    Tls1_0,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
47
48pub trait SocketHandler {
49 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
50 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
51 fn socket_write_vectored(&mut self, _buf: &[std::io::IoSlice]) -> (usize, SocketResult);
52 fn socket_wants_write(&self) -> bool {
53 false
54 }
55 fn socket_close(&mut self) {}
56 fn socket_ref(&self) -> &TcpStream;
57 fn socket_mut(&mut self) -> &mut TcpStream;
58 fn protocol(&self) -> TransportProtocol;
59 fn read_error(&self);
60 fn write_error(&self);
61}
62
63impl SocketHandler for TcpStream {
64 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
65 let mut size = 0usize;
66 let mut counter = 0;
67 loop {
68 counter += 1;
69 if counter > MAX_LOOP_ITERATIONS {
70 error!("MAX_LOOP_ITERATION reached in TcpStream::socket_read");
71 incr!("socket.read.infinite_loop.error");
72 }
73 if size == buf.len() {
74 return (size, SocketResult::Continue);
75 }
76 match self.read(&mut buf[size..]) {
77 Ok(0) => return (size, SocketResult::Closed),
78 Ok(sz) => size += sz,
79 Err(e) => match e.kind() {
80 ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
81 ErrorKind::ConnectionReset
82 | ErrorKind::ConnectionAborted
83 | ErrorKind::BrokenPipe => return (size, SocketResult::Closed),
84 _ => {
85 error!("SOCKET\tsocket_read error={:?}", e);
86 return (size, SocketResult::Error);
87 }
88 },
89 }
90 }
91 }
92
93 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
94 let mut size = 0usize;
95 let mut counter = 0;
96 loop {
97 counter += 1;
98 if counter > MAX_LOOP_ITERATIONS {
99 error!("MAX_LOOP_ITERATION reached in TcpStream::socket_write");
100 incr!("socket.write.infinite_loop.error");
101 }
102 if size == buf.len() {
103 return (size, SocketResult::Continue);
104 }
105 match self.write(&buf[size..]) {
106 Ok(0) => return (size, SocketResult::Continue),
107 Ok(sz) => size += sz,
108 Err(e) => match e.kind() {
109 ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
110 ErrorKind::ConnectionReset
111 | ErrorKind::ConnectionAborted
112 | ErrorKind::BrokenPipe
113 | ErrorKind::ConnectionRefused => {
114 incr!("tcp.write.error");
115 return (size, SocketResult::Closed);
116 }
117 _ => {
118 error!("SOCKET\tsocket_write error={:?}", e);
120 incr!("tcp.write.error");
121 return (size, SocketResult::Error);
122 }
123 },
124 }
125 }
126 }
127
128 fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
129 match self.write_vectored(bufs) {
130 Ok(sz) => (sz, SocketResult::Continue),
131 Err(e) => match e.kind() {
132 ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
133 ErrorKind::ConnectionReset
134 | ErrorKind::ConnectionAborted
135 | ErrorKind::BrokenPipe
136 | ErrorKind::ConnectionRefused => {
137 incr!("tcp.write.error");
138 (0, SocketResult::Closed)
139 }
140 _ => {
141 error!("SOCKET\tsocket_write error={:?}", e);
143 incr!("tcp.write.error");
144 (0, SocketResult::Error)
145 }
146 },
147 }
148 }
149
150 fn socket_ref(&self) -> &TcpStream {
151 self
152 }
153
154 fn socket_mut(&mut self) -> &mut TcpStream {
155 self
156 }
157
158 fn protocol(&self) -> TransportProtocol {
159 TransportProtocol::Tcp
160 }
161
162 fn read_error(&self) {
163 incr!("tcp.read.error");
164 }
165
166 fn write_error(&self) {
167 incr!("tcp.write.error");
168 }
169}
170
171pub struct FrontRustls {
172 pub stream: TcpStream,
173 pub session: ServerConnection,
174}
175
176impl SocketHandler for FrontRustls {
177 fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
178 let mut size = 0usize;
179 let mut can_read = true;
180 let mut is_error = false;
181 let mut is_closed = false;
182
183 let mut counter = 0;
184 loop {
185 counter += 1;
186 if counter > MAX_LOOP_ITERATIONS {
187 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_read");
188 incr!("rustls.read.infinite_loop.error");
189 }
190
191 if size == buf.len() {
192 break;
193 }
194
195 if !can_read | is_error | is_closed {
196 break;
197 }
198
199 match self.session.read_tls(&mut self.stream) {
200 Ok(0) => {
201 can_read = false;
202 is_closed = true;
203 }
204 Ok(_sz) => {}
205 Err(e) => match e.kind() {
206 ErrorKind::WouldBlock => {
207 can_read = false;
208 }
209 ErrorKind::ConnectionReset
210 | ErrorKind::ConnectionAborted
211 | ErrorKind::BrokenPipe => {
212 is_closed = true;
213 }
214 ErrorKind::Other => {
216 warn!("rustls buffer is full, we will consume it, before processing new incoming packets, to mitigate this issue, you could try to increase the buffer size, {:?}", e);
217 }
218 _ => {
219 error!("could not read TLS stream from socket: {:?}", e);
220 is_error = true;
221 break;
222 }
223 },
224 }
225
226 if let Err(e) = self.session.process_new_packets() {
227 error!("could not process read TLS packets: {:?}", e);
228 is_error = true;
229 break;
230 }
231
232 while !self.session.wants_read() {
233 match self.session.reader().read(&mut buf[size..]) {
234 Ok(0) => break,
235 Ok(sz) => {
236 size += sz;
237 }
238 Err(e) => match e.kind() {
239 ErrorKind::WouldBlock => {
240 break;
241 }
242 ErrorKind::ConnectionReset
243 | ErrorKind::ConnectionAborted
244 | ErrorKind::BrokenPipe => {
245 is_closed = true;
246 break;
247 }
248 _ => {
249 error!("could not read data from TLS stream: {:?}", e);
250 is_error = true;
251 break;
252 }
253 },
254 }
255 }
256 }
257
258 if is_error {
259 (size, SocketResult::Error)
260 } else if is_closed {
261 (size, SocketResult::Closed)
262 } else if !can_read {
263 (size, SocketResult::WouldBlock)
264 } else {
265 (size, SocketResult::Continue)
266 }
267 }
268
269 fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
270 let mut buffered_size = 0usize;
271 let mut can_write = true;
272 let mut is_error = false;
273 let mut is_closed = false;
274
275 let mut counter = 0;
276 loop {
277 counter += 1;
278 if counter > MAX_LOOP_ITERATIONS {
279 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write");
280 incr!("rustls.write.infinite_loop.error");
281 }
282 if buffered_size == buf.len() {
283 break;
284 }
285
286 if !can_write | is_error | is_closed {
287 break;
288 }
289
290 match self.session.writer().write(&buf[buffered_size..]) {
291 Ok(0) => {} Ok(sz) => {
293 buffered_size += sz;
294 }
295 Err(e) => match e.kind() {
296 ErrorKind::WouldBlock => {
297 }
300 ErrorKind::ConnectionReset
301 | ErrorKind::ConnectionAborted
302 | ErrorKind::BrokenPipe => {
303 incr!("rustls.write.error");
305 is_closed = true;
306 break;
307 }
308 _ => {
309 error!("could not write data to TLS stream: {:?}", e);
310 incr!("rustls.write.error");
311 is_error = true;
312 break;
313 }
314 },
315 }
316
317 loop {
318 match self.session.write_tls(&mut self.stream) {
319 Ok(0) => {
320 break;
322 }
323 Ok(_sz) => {}
324 Err(e) => match e.kind() {
325 ErrorKind::WouldBlock => {
326 can_write = false;
327 break;
328 }
329 ErrorKind::ConnectionReset
330 | ErrorKind::ConnectionAborted
331 | ErrorKind::BrokenPipe => {
332 incr!("rustls.write.error");
333 is_closed = true;
334 break;
335 }
336 _ => {
337 error!("could not write TLS stream to socket: {:?}", e);
338 incr!("rustls.write.error");
339 is_error = true;
340 break;
341 }
342 },
343 }
344 }
345 }
346
347 if is_error {
348 (buffered_size, SocketResult::Error)
349 } else if is_closed {
350 (buffered_size, SocketResult::Closed)
351 } else if !can_write {
352 (buffered_size, SocketResult::WouldBlock)
353 } else {
354 (buffered_size, SocketResult::Continue)
355 }
356 }
357
358 fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
359 let mut buffered_size = 0usize;
360 let mut can_write = true;
361 let mut is_error = false;
362 let mut is_closed = false;
363
364 match self.session.writer().write_vectored(bufs) {
365 Ok(0) => {} Ok(sz) => {
367 buffered_size += sz;
368 }
369 Err(e) => match e.kind() {
370 ErrorKind::WouldBlock => {
371 }
374 ErrorKind::ConnectionReset
375 | ErrorKind::ConnectionAborted
376 | ErrorKind::BrokenPipe => {
377 incr!("rustls.write.error");
379 is_closed = true;
380 }
381 _ => {
382 error!("could not write data to TLS stream: {:?}", e);
383 incr!("rustls.write.error");
384 is_error = true;
385 }
386 },
387 }
388
389 let mut counter = 0;
390 loop {
391 counter += 1;
392 if counter > MAX_LOOP_ITERATIONS {
393 error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored");
394 incr!("rustls.write.infinite_loop.error");
395 }
396 match self.session.write_tls(&mut self.stream) {
397 Ok(0) => {
398 break;
399 }
400 Ok(_sz) => {}
401 Err(e) => match e.kind() {
402 ErrorKind::WouldBlock => {
403 can_write = false;
404 break;
405 }
406 ErrorKind::ConnectionReset
407 | ErrorKind::ConnectionAborted
408 | ErrorKind::BrokenPipe => {
409 incr!("rustls.write.error");
410 is_closed = true;
411 break;
412 }
413 _ => {
414 error!("could not write TLS stream to socket: {:?}", e);
415 incr!("rustls.write.error");
416 is_error = true;
417 break;
418 }
419 },
420 }
421 }
422
423 if is_error {
424 (buffered_size, SocketResult::Error)
425 } else if is_closed {
426 (buffered_size, SocketResult::Closed)
427 } else if !can_write {
428 (buffered_size, SocketResult::WouldBlock)
429 } else {
430 (buffered_size, SocketResult::Continue)
431 }
432 }
433
434 fn socket_close(&mut self) {
435 self.session.send_close_notify();
436 }
437
438 fn socket_wants_write(&self) -> bool {
439 self.session.wants_write()
440 }
441
442 fn socket_ref(&self) -> &TcpStream {
443 &self.stream
444 }
445
446 fn socket_mut(&mut self) -> &mut TcpStream {
447 &mut self.stream
448 }
449
450 fn protocol(&self) -> TransportProtocol {
451 self.session
452 .protocol_version()
453 .map(|version| match version {
454 ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
455 ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
456 ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
457 ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
458 ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
459 ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
460 _ => TransportProtocol::Tls1_3,
461 })
462 .unwrap_or(TransportProtocol::Tcp)
463 }
464
465 fn read_error(&self) {
466 incr!("rustls.read.error");
467 }
468
469 fn write_error(&self) {
470 incr!("rustls.write.error");
471 }
472}
473
474pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
475 let sock = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
476 .map_err(ServerBindError::SocketCreationError)?;
477
478 if cfg!(unix) {
480 sock.set_reuse_address(true)
481 .map_err(ServerBindError::SetReuseAddress)?;
482 }
483
484 sock.set_reuse_port(true)
485 .map_err(ServerBindError::SetReusePort)?;
486
487 sock.bind(&addr.into())
488 .map_err(ServerBindError::BindError)?;
489
490 sock.set_nonblocking(true)
491 .map_err(ServerBindError::SetNonBlocking)?;
492
493 sock.listen(1024).map_err(ServerBindError::Listen)?;
496
497 Ok(TcpListener::from_std(sock.into()))
498}
499
500pub mod stats {
502 use std::{os::fd::AsRawFd, time::Duration};
503
504 use internal::{TcpInfo, OPT_LEVEL, OPT_NAME};
505
506 pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
508 socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(info.rtt() as u64))
509 }
510
511 #[cfg(unix)]
512 pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
513 let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
514 let mut len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
515 let status = unsafe {
516 libc::getsockopt(
517 fd,
518 OPT_LEVEL,
519 OPT_NAME,
520 &mut tcp_info as *mut _ as *mut _,
521 &mut len,
522 )
523 };
524 if status != 0 {
525 None
526 } else {
527 Some(tcp_info)
528 }
529 }
530 #[cfg(not(unix))]
531 pub fn socketinfo(fd: libc::c_int) -> Option<TcpInfo> {
532 None
533 }
534
535 #[cfg(unix)]
536 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
537 mod internal {
538 pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
539 pub const OPT_NAME: libc::c_int = libc::TCP_INFO;
540
541 #[derive(Clone, Debug)]
542 #[repr(C)]
543 pub struct TcpInfo {
544 tcpi_state: u8,
546 tcpi_ca_state: u8,
547 tcpi_retransmits: u8,
548 tcpi_probes: u8,
549 tcpi_backoff: u8,
550 tcpi_options: u8,
551 tcpi_snd_rcv_wscale: u8, tcpi_rto: u32,
554 tcpi_ato: u32,
555 tcpi_snd_mss: u32,
556 tcpi_rcv_mss: u32,
557
558 tcpi_unacked: u32,
559 tcpi_sacked: u32,
560 tcpi_lost: u32,
561 tcpi_retrans: u32,
562 tcpi_fackets: u32,
563
564 tcpi_last_data_sent: u32,
566 tcpi_last_ack_sent: u32, tcpi_last_data_recv: u32,
568 tcpi_last_ack_recv: u32,
569
570 tcpi_pmtu: u32,
572 tcpi_rcv_ssthresh: u32,
573 tcpi_rtt: u32,
574 tcpi_rttvar: u32,
575 tcpi_snd_ssthresh: u32,
576 tcpi_snd_cwnd: u32,
577 tcpi_advmss: u32,
578 tcpi_reordering: u32,
579 }
580 impl TcpInfo {
581 pub fn rtt(&self) -> u32 {
582 self.tcpi_rtt
583 }
584 }
585 }
586
587 #[cfg(unix)]
588 #[cfg(any(target_os = "macos", target_os = "ios"))]
589 mod internal {
590 pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
591 pub const OPT_NAME: libc::c_int = 0x106;
592
593 #[derive(Clone, Debug)]
594 #[repr(C)]
595 pub struct TcpInfo {
596 tcpi_state: u8,
597 tcpi_snd_wscale: u8,
598 tcpi_rcv_wscale: u8,
599 __pad1: u8,
600 tcpi_options: u32,
601 tcpi_flags: u32,
602 tcpi_rto: u32,
603 tcpi_maxseg: u32,
604 tcpi_snd_ssthresh: u32,
605 tcpi_snd_cwnd: u32,
606 tcpi_snd_wnd: u32,
607 tcpi_snd_sbbytes: u32,
608 tcpi_rcv_wnd: u32,
609 tcpi_rttcur: u32,
610 tcpi_srtt: u32,
611 tcpi_rttvar: u32,
612 tcpi_tfo: u32,
613 tcpi_txpackets: u64,
614 tcpi_txbytes: u64,
615 tcpi_txretransmitbytes: u64,
616 tcpi_rxpackets: u64,
617 tcpi_rxbytes: u64,
618 tcpi_rxoutoforderbytes: u64,
619 tcpi_txretransmitpackets: u64,
620 }
621 impl TcpInfo {
622 pub fn rtt(&self) -> u32 {
623 self.tcpi_srtt * 1000
625 }
626 }
627 }
628
629 #[cfg(not(unix))]
630 #[derive(Clone, Debug)]
631 struct TcpInfo {}
632
633 #[test]
634 #[serial_test::serial]
635 fn test_rtt() {
636 let sock = std::net::TcpStream::connect("google.com:80").unwrap();
637 let fd = sock.as_raw_fd();
638 let info = socket_info(fd);
639 assert!(info.is_some());
640 println!("{:#?}", info);
641 println!(
642 "rtt: {}",
643 sozu_command::logging::LogDuration(socket_rtt(&sock))
644 );
645 }
646}