sozu_lib/
socket.rs

1use std::{
2    io::{ErrorKind, Read, Write},
3    net::SocketAddr,
4};
5
6use mio::net::{TcpListener, TcpStream};
7use rustls::{ProtocolVersion, ServerConnection};
8use socket2::{Domain, Protocol, Socket, Type};
9use sozu_command::config::MAX_LOOP_ITERATIONS;
10
11#[derive(thiserror::Error, Debug)]
12pub enum ServerBindError {
13    #[error("could not set bind to socket: {0}")]
14    BindError(std::io::Error),
15    #[error("could not listen on socket: {0}")]
16    Listen(std::io::Error),
17    #[error("could not set socket to nonblocking: {0}")]
18    SetNonBlocking(std::io::Error),
19    #[error("could not set reuse address: {0}")]
20    SetReuseAddress(std::io::Error),
21    #[error("could not set reuse address: {0}")]
22    SetReusePort(std::io::Error),
23    #[error("Could not create socket: {0}")]
24    SocketCreationError(std::io::Error),
25    #[error("Invalid socket address '{address}': {error}")]
26    InvalidSocketAddress { address: String, error: String },
27}
28
/// Status returned next to the byte count by the [`SocketHandler`] read and
/// write methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// The call ended without blocking, closing or failing, e.g. because the
    /// whole buffer was filled or written.
    Continue,
    /// The connection is closed: end of stream was read, or the peer reset,
    /// aborted or broke the connection.
    Closed,
    /// The underlying socket reported `WouldBlock`.
    WouldBlock,
    /// Any other failure.
    Error,
}
36
/// Transport-level protocol reported by [`SocketHandler::protocol`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Plain TCP, without TLS.
    Tcp,
    /// SSL 2
    Ssl2,
    /// SSL 3
    Ssl3,
    /// TLS 1.0
    Tls1_0,
    /// TLS 1.1
    Tls1_1,
    /// TLS 1.2
    Tls1_2,
    /// TLS 1.3
    Tls1_3,
}
47
48pub trait SocketHandler {
49    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
50    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
51    fn socket_write_vectored(&mut self, _buf: &[std::io::IoSlice]) -> (usize, SocketResult);
52    fn socket_wants_write(&self) -> bool {
53        false
54    }
55    fn socket_close(&mut self) {}
56    fn socket_ref(&self) -> &TcpStream;
57    fn socket_mut(&mut self) -> &mut TcpStream;
58    fn protocol(&self) -> TransportProtocol;
59    fn read_error(&self);
60    fn write_error(&self);
61}
62
63impl SocketHandler for TcpStream {
64    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
65        let mut size = 0usize;
66        let mut counter = 0;
67        loop {
68            counter += 1;
69            if counter > MAX_LOOP_ITERATIONS {
70                error!("MAX_LOOP_ITERATION reached in TcpStream::socket_read");
71                incr!("socket.read.infinite_loop.error");
72            }
73            if size == buf.len() {
74                return (size, SocketResult::Continue);
75            }
76            match self.read(&mut buf[size..]) {
77                Ok(0) => return (size, SocketResult::Closed),
78                Ok(sz) => size += sz,
79                Err(e) => match e.kind() {
80                    ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
81                    ErrorKind::ConnectionReset
82                    | ErrorKind::ConnectionAborted
83                    | ErrorKind::BrokenPipe => return (size, SocketResult::Closed),
84                    _ => {
85                        error!("SOCKET\tsocket_read error={:?}", e);
86                        return (size, SocketResult::Error);
87                    }
88                },
89            }
90        }
91    }
92
93    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
94        let mut size = 0usize;
95        let mut counter = 0;
96        loop {
97            counter += 1;
98            if counter > MAX_LOOP_ITERATIONS {
99                error!("MAX_LOOP_ITERATION reached in TcpStream::socket_write");
100                incr!("socket.write.infinite_loop.error");
101            }
102            if size == buf.len() {
103                return (size, SocketResult::Continue);
104            }
105            match self.write(&buf[size..]) {
106                Ok(0) => return (size, SocketResult::Continue),
107                Ok(sz) => size += sz,
108                Err(e) => match e.kind() {
109                    ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
110                    ErrorKind::ConnectionReset
111                    | ErrorKind::ConnectionAborted
112                    | ErrorKind::BrokenPipe
113                    | ErrorKind::ConnectionRefused => {
114                        incr!("tcp.write.error");
115                        return (size, SocketResult::Closed);
116                    }
117                    _ => {
118                        //FIXME: timeout and other common errors should be sent up
119                        error!("SOCKET\tsocket_write error={:?}", e);
120                        incr!("tcp.write.error");
121                        return (size, SocketResult::Error);
122                    }
123                },
124            }
125        }
126    }
127
128    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
129        match self.write_vectored(bufs) {
130            Ok(sz) => (sz, SocketResult::Continue),
131            Err(e) => match e.kind() {
132                ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
133                ErrorKind::ConnectionReset
134                | ErrorKind::ConnectionAborted
135                | ErrorKind::BrokenPipe
136                | ErrorKind::ConnectionRefused => {
137                    incr!("tcp.write.error");
138                    (0, SocketResult::Closed)
139                }
140                _ => {
141                    //FIXME: timeout and other common errors should be sent up
142                    error!("SOCKET\tsocket_write error={:?}", e);
143                    incr!("tcp.write.error");
144                    (0, SocketResult::Error)
145                }
146            },
147        }
148    }
149
150    fn socket_ref(&self) -> &TcpStream {
151        self
152    }
153
154    fn socket_mut(&mut self) -> &mut TcpStream {
155        self
156    }
157
158    fn protocol(&self) -> TransportProtocol {
159        TransportProtocol::Tcp
160    }
161
162    fn read_error(&self) {
163        incr!("tcp.read.error");
164    }
165
166    fn write_error(&self) {
167        incr!("tcp.write.error");
168    }
169}
170
171pub struct FrontRustls {
172    pub stream: TcpStream,
173    pub session: ServerConnection,
174}
175
176impl SocketHandler for FrontRustls {
177    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
178        let mut size = 0usize;
179        let mut can_read = true;
180        let mut is_error = false;
181        let mut is_closed = false;
182
183        let mut counter = 0;
184        loop {
185            counter += 1;
186            if counter > MAX_LOOP_ITERATIONS {
187                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_read");
188                incr!("rustls.read.infinite_loop.error");
189            }
190
191            if size == buf.len() {
192                break;
193            }
194
195            if !can_read | is_error | is_closed {
196                break;
197            }
198
199            match self.session.read_tls(&mut self.stream) {
200                Ok(0) => {
201                    can_read = false;
202                    is_closed = true;
203                }
204                Ok(_sz) => {}
205                Err(e) => match e.kind() {
206                    ErrorKind::WouldBlock => {
207                        can_read = false;
208                    }
209                    ErrorKind::ConnectionReset
210                    | ErrorKind::ConnectionAborted
211                    | ErrorKind::BrokenPipe => {
212                        is_closed = true;
213                    }
214                    // https://github.com/rustls/rustls/blob/main/rustls/src/conn.rs#L482-L500,
215                    ErrorKind::Other => {
216                        warn!(
217                            "rustls buffer is full, we will consume it, before processing new incoming packets, to mitigate this issue, you could try to increase the buffer size, {:?}",
218                            e
219                        );
220                    }
221                    _ => {
222                        error!("could not read TLS stream from socket: {:?}", e);
223                        is_error = true;
224                        break;
225                    }
226                },
227            }
228
229            if let Err(e) = self.session.process_new_packets() {
230                error!("could not process read TLS packets: {:?}", e);
231                is_error = true;
232                break;
233            }
234
235            while !self.session.wants_read() {
236                match self.session.reader().read(&mut buf[size..]) {
237                    Ok(0) => break,
238                    Ok(sz) => {
239                        size += sz;
240                    }
241                    Err(e) => match e.kind() {
242                        ErrorKind::WouldBlock => {
243                            break;
244                        }
245                        ErrorKind::ConnectionReset
246                        | ErrorKind::ConnectionAborted
247                        | ErrorKind::BrokenPipe => {
248                            is_closed = true;
249                            break;
250                        }
251                        _ => {
252                            error!("could not read data from TLS stream: {:?}", e);
253                            is_error = true;
254                            break;
255                        }
256                    },
257                }
258            }
259        }
260
261        if is_error {
262            (size, SocketResult::Error)
263        } else if is_closed {
264            (size, SocketResult::Closed)
265        } else if !can_read {
266            (size, SocketResult::WouldBlock)
267        } else {
268            (size, SocketResult::Continue)
269        }
270    }
271
272    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
273        let mut buffered_size = 0usize;
274        let mut can_write = true;
275        let mut is_error = false;
276        let mut is_closed = false;
277
278        let mut counter = 0;
279        loop {
280            counter += 1;
281            if counter > MAX_LOOP_ITERATIONS {
282                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write");
283                incr!("rustls.write.infinite_loop.error");
284            }
285            if buffered_size == buf.len() {
286                break;
287            }
288
289            if !can_write | is_error | is_closed {
290                break;
291            }
292
293            match self.session.writer().write(&buf[buffered_size..]) {
294                Ok(0) => {} // zero byte written means that the Rustls buffers are full, we will try to write on the socket and try again
295                Ok(sz) => {
296                    buffered_size += sz;
297                }
298                Err(e) => match e.kind() {
299                    ErrorKind::WouldBlock => {
300                        // we don't need to do anything, the session will return false in wants_write?
301                        //error!("rustls socket_write wouldblock");
302                    }
303                    ErrorKind::ConnectionReset
304                    | ErrorKind::ConnectionAborted
305                    | ErrorKind::BrokenPipe => {
306                        //FIXME: this should probably not happen here
307                        incr!("rustls.write.error");
308                        is_closed = true;
309                        break;
310                    }
311                    _ => {
312                        error!("could not write data to TLS stream: {:?}", e);
313                        incr!("rustls.write.error");
314                        is_error = true;
315                        break;
316                    }
317                },
318            }
319
320            loop {
321                match self.session.write_tls(&mut self.stream) {
322                    Ok(0) => {
323                        //can_write = false;
324                        break;
325                    }
326                    Ok(_sz) => {}
327                    Err(e) => match e.kind() {
328                        ErrorKind::WouldBlock => {
329                            can_write = false;
330                            break;
331                        }
332                        ErrorKind::ConnectionReset
333                        | ErrorKind::ConnectionAborted
334                        | ErrorKind::BrokenPipe => {
335                            incr!("rustls.write.error");
336                            is_closed = true;
337                            break;
338                        }
339                        _ => {
340                            error!("could not write TLS stream to socket: {:?}", e);
341                            incr!("rustls.write.error");
342                            is_error = true;
343                            break;
344                        }
345                    },
346                }
347            }
348        }
349
350        if is_error {
351            (buffered_size, SocketResult::Error)
352        } else if is_closed {
353            (buffered_size, SocketResult::Closed)
354        } else if !can_write {
355            (buffered_size, SocketResult::WouldBlock)
356        } else {
357            (buffered_size, SocketResult::Continue)
358        }
359    }
360
361    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
362        let mut buffered_size = 0usize;
363        let mut can_write = true;
364        let mut is_error = false;
365        let mut is_closed = false;
366
367        match self.session.writer().write_vectored(bufs) {
368            Ok(0) => {} // zero byte written means that the Rustls buffers are full, we will try to write on the socket and try again
369            Ok(sz) => {
370                buffered_size += sz;
371            }
372            Err(e) => match e.kind() {
373                ErrorKind::WouldBlock => {
374                    // we don't need to do anything, the session will return false in wants_write?
375                    //error!("rustls socket_write wouldblock");
376                }
377                ErrorKind::ConnectionReset
378                | ErrorKind::ConnectionAborted
379                | ErrorKind::BrokenPipe => {
380                    //FIXME: this should probably not happen here
381                    incr!("rustls.write.error");
382                    is_closed = true;
383                }
384                _ => {
385                    error!("could not write data to TLS stream: {:?}", e);
386                    incr!("rustls.write.error");
387                    is_error = true;
388                }
389            },
390        }
391
392        let mut counter = 0;
393        loop {
394            counter += 1;
395            if counter > MAX_LOOP_ITERATIONS {
396                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored");
397                incr!("rustls.write.infinite_loop.error");
398            }
399            match self.session.write_tls(&mut self.stream) {
400                Ok(0) => {
401                    break;
402                }
403                Ok(_sz) => {}
404                Err(e) => match e.kind() {
405                    ErrorKind::WouldBlock => {
406                        can_write = false;
407                        break;
408                    }
409                    ErrorKind::ConnectionReset
410                    | ErrorKind::ConnectionAborted
411                    | ErrorKind::BrokenPipe => {
412                        incr!("rustls.write.error");
413                        is_closed = true;
414                        break;
415                    }
416                    _ => {
417                        error!("could not write TLS stream to socket: {:?}", e);
418                        incr!("rustls.write.error");
419                        is_error = true;
420                        break;
421                    }
422                },
423            }
424        }
425
426        if is_error {
427            (buffered_size, SocketResult::Error)
428        } else if is_closed {
429            (buffered_size, SocketResult::Closed)
430        } else if !can_write {
431            (buffered_size, SocketResult::WouldBlock)
432        } else {
433            (buffered_size, SocketResult::Continue)
434        }
435    }
436
437    fn socket_close(&mut self) {
438        self.session.send_close_notify();
439    }
440
441    fn socket_wants_write(&self) -> bool {
442        self.session.wants_write()
443    }
444
445    fn socket_ref(&self) -> &TcpStream {
446        &self.stream
447    }
448
449    fn socket_mut(&mut self) -> &mut TcpStream {
450        &mut self.stream
451    }
452
453    fn protocol(&self) -> TransportProtocol {
454        self.session
455            .protocol_version()
456            .map(|version| match version {
457                ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
458                ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
459                ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
460                ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
461                ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
462                ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
463                _ => TransportProtocol::Tls1_3,
464            })
465            .unwrap_or(TransportProtocol::Tcp)
466    }
467
468    fn read_error(&self) {
469        incr!("rustls.read.error");
470    }
471
472    fn write_error(&self) {
473        incr!("rustls.write.error");
474    }
475}
476
477pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
478    let sock = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
479        .map_err(ServerBindError::SocketCreationError)?;
480
481    // set so_reuseaddr, but only on unix (mirrors what libstd does)
482    if cfg!(unix) {
483        sock.set_reuse_address(true)
484            .map_err(ServerBindError::SetReuseAddress)?;
485    }
486
487    sock.set_reuse_port(true)
488        .map_err(ServerBindError::SetReusePort)?;
489
490    sock.bind(&addr.into())
491        .map_err(ServerBindError::BindError)?;
492
493    sock.set_nonblocking(true)
494        .map_err(ServerBindError::SetNonBlocking)?;
495
496    // listen
497    // FIXME: make the backlog configurable?
498    sock.listen(1024).map_err(ServerBindError::Listen)?;
499
500    Ok(TcpListener::from_std(sock.into()))
501}
502
503/// Socket statistics
504pub mod stats {
505    use std::{os::fd::AsRawFd, time::Duration};
506
507    use internal::{OPT_LEVEL, OPT_NAME, TcpInfo};
508
509    /// Round trip time for a TCP socket
510    pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
511        socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(info.rtt() as u64))
512    }
513
514    #[cfg(unix)]
515    pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
516        let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
517        let mut len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
518        let status = unsafe {
519            libc::getsockopt(
520                fd,
521                OPT_LEVEL,
522                OPT_NAME,
523                &mut tcp_info as *mut _ as *mut _,
524                &mut len,
525            )
526        };
527        if status != 0 { None } else { Some(tcp_info) }
528    }
529    #[cfg(not(unix))]
530    pub fn socketinfo(fd: libc::c_int) -> Option<TcpInfo> {
531        None
532    }
533
534    #[cfg(unix)]
535    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
536    mod internal {
537        pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
538        pub const OPT_NAME: libc::c_int = libc::TCP_INFO;
539
540        #[derive(Clone, Debug)]
541        #[repr(C)]
542        pub struct TcpInfo {
543            // State
544            tcpi_state: u8,
545            tcpi_ca_state: u8,
546            tcpi_retransmits: u8,
547            tcpi_probes: u8,
548            tcpi_backoff: u8,
549            tcpi_options: u8,
550            tcpi_snd_rcv_wscale: u8, // 4bits|4bits
551
552            tcpi_rto: u32,
553            tcpi_ato: u32,
554            tcpi_snd_mss: u32,
555            tcpi_rcv_mss: u32,
556
557            tcpi_unacked: u32,
558            tcpi_sacked: u32,
559            tcpi_lost: u32,
560            tcpi_retrans: u32,
561            tcpi_fackets: u32,
562
563            // Times
564            tcpi_last_data_sent: u32,
565            tcpi_last_ack_sent: u32, // Not remembered
566            tcpi_last_data_recv: u32,
567            tcpi_last_ack_recv: u32,
568
569            // Metrics
570            tcpi_pmtu: u32,
571            tcpi_rcv_ssthresh: u32,
572            tcpi_rtt: u32,
573            tcpi_rttvar: u32,
574            tcpi_snd_ssthresh: u32,
575            tcpi_snd_cwnd: u32,
576            tcpi_advmss: u32,
577            tcpi_reordering: u32,
578        }
579        impl TcpInfo {
580            pub fn rtt(&self) -> u32 {
581                self.tcpi_rtt
582            }
583        }
584    }
585
586    #[cfg(unix)]
587    #[cfg(any(target_os = "macos", target_os = "ios"))]
588    mod internal {
589        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
590        pub const OPT_NAME: libc::c_int = 0x106;
591
592        #[derive(Clone, Debug)]
593        #[repr(C)]
594        pub struct TcpInfo {
595            tcpi_state: u8,
596            tcpi_snd_wscale: u8,
597            tcpi_rcv_wscale: u8,
598            __pad1: u8,
599            tcpi_options: u32,
600            tcpi_flags: u32,
601            tcpi_rto: u32,
602            tcpi_maxseg: u32,
603            tcpi_snd_ssthresh: u32,
604            tcpi_snd_cwnd: u32,
605            tcpi_snd_wnd: u32,
606            tcpi_snd_sbbytes: u32,
607            tcpi_rcv_wnd: u32,
608            tcpi_rttcur: u32,
609            tcpi_srtt: u32,
610            tcpi_rttvar: u32,
611            tcpi_tfo: u32,
612            tcpi_txpackets: u64,
613            tcpi_txbytes: u64,
614            tcpi_txretransmitbytes: u64,
615            tcpi_rxpackets: u64,
616            tcpi_rxbytes: u64,
617            tcpi_rxoutoforderbytes: u64,
618            tcpi_txretransmitpackets: u64,
619        }
620        impl TcpInfo {
621            pub fn rtt(&self) -> u32 {
622                // tcpi_srtt is in milliseconds not microseconds
623                self.tcpi_srtt * 1000
624            }
625        }
626    }
627
628    #[cfg(not(unix))]
629    #[derive(Clone, Debug)]
630    struct TcpInfo {}
631
632    #[test]
633    #[serial_test::serial]
634    fn test_rtt() {
635        let sock = std::net::TcpStream::connect("google.com:80").unwrap();
636        let fd = sock.as_raw_fd();
637        let info = socket_info(fd);
638        assert!(info.is_some());
639        println!("{:#?}", info);
640        println!(
641            "rtt: {}",
642            sozu_command::logging::LogDuration(socket_rtt(&sock))
643        );
644    }
645}