// sozu_lib/socket.rs

1use std::{
2    io::{ErrorKind, Read, Write},
3    net::SocketAddr,
4};
5
6use mio::net::{TcpListener, TcpStream};
7use rustls::{ProtocolVersion, ServerConnection};
8use socket2::{Domain, Protocol, Socket, Type};
9use sozu_command::config::MAX_LOOP_ITERATIONS;
10
11#[derive(thiserror::Error, Debug)]
12pub enum ServerBindError {
13    #[error("could not set bind to socket: {0}")]
14    BindError(std::io::Error),
15    #[error("could not listen on socket: {0}")]
16    Listen(std::io::Error),
17    #[error("could not set socket to nonblocking: {0}")]
18    SetNonBlocking(std::io::Error),
19    #[error("could not set reuse address: {0}")]
20    SetReuseAddress(std::io::Error),
21    #[error("could not set reuse address: {0}")]
22    SetReusePort(std::io::Error),
23    #[error("Could not create socket: {0}")]
24    SocketCreationError(std::io::Error),
25    #[error("Invalid socket address '{address}': {error}")]
26    InvalidSocketAddress { address: String, error: String },
27}
28
/// Status returned next to the byte count by the `SocketHandler` read and write methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// The operation went through without blocking, closing or failing.
    Continue,
    /// The connection was closed (end of stream, reset, abort or broken pipe).
    Closed,
    /// The socket is not ready; the operation would have blocked.
    WouldBlock,
    /// An I/O or TLS error other than the cases above occurred.
    Error,
}
36
/// Transport-level protocol reported by `SocketHandler::protocol`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Plain TCP, no TLS layer.
    Tcp,
    /// SSL 2.0
    Ssl2,
    /// SSL 3.0
    Ssl3,
    /// TLS 1.0
    Tls1_0,
    /// TLS 1.1
    Tls1_1,
    /// TLS 1.2
    Tls1_2,
    /// TLS 1.3
    Tls1_3,
}
47
48pub trait SocketHandler {
49    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
50    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
51    fn socket_write_vectored(&mut self, _buf: &[std::io::IoSlice]) -> (usize, SocketResult);
52    fn socket_wants_write(&self) -> bool {
53        false
54    }
55    fn socket_close(&mut self) {}
56    fn socket_ref(&self) -> &TcpStream;
57    fn socket_mut(&mut self) -> &mut TcpStream;
58    fn protocol(&self) -> TransportProtocol;
59    fn read_error(&self);
60    fn write_error(&self);
61}
62
63impl SocketHandler for TcpStream {
64    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
65        let mut size = 0usize;
66        let mut counter = 0;
67        loop {
68            counter += 1;
69            if counter > MAX_LOOP_ITERATIONS {
70                error!("MAX_LOOP_ITERATION reached in TcpStream::socket_read");
71                incr!("socket.read.infinite_loop.error");
72            }
73            if size == buf.len() {
74                return (size, SocketResult::Continue);
75            }
76            match self.read(&mut buf[size..]) {
77                Ok(0) => return (size, SocketResult::Closed),
78                Ok(sz) => size += sz,
79                Err(e) => match e.kind() {
80                    ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
81                    ErrorKind::ConnectionReset
82                    | ErrorKind::ConnectionAborted
83                    | ErrorKind::BrokenPipe => return (size, SocketResult::Closed),
84                    _ => {
85                        error!("SOCKET\tsocket_read error={:?}", e);
86                        return (size, SocketResult::Error);
87                    }
88                },
89            }
90        }
91    }
92
93    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
94        let mut size = 0usize;
95        let mut counter = 0;
96        loop {
97            counter += 1;
98            if counter > MAX_LOOP_ITERATIONS {
99                error!("MAX_LOOP_ITERATION reached in TcpStream::socket_write");
100                incr!("socket.write.infinite_loop.error");
101            }
102            if size == buf.len() {
103                return (size, SocketResult::Continue);
104            }
105            match self.write(&buf[size..]) {
106                Ok(0) => return (size, SocketResult::Continue),
107                Ok(sz) => size += sz,
108                Err(e) => match e.kind() {
109                    ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
110                    ErrorKind::ConnectionReset
111                    | ErrorKind::ConnectionAborted
112                    | ErrorKind::BrokenPipe
113                    | ErrorKind::ConnectionRefused => {
114                        incr!("tcp.write.error");
115                        return (size, SocketResult::Closed);
116                    }
117                    _ => {
118                        //FIXME: timeout and other common errors should be sent up
119                        error!("SOCKET\tsocket_write error={:?}", e);
120                        incr!("tcp.write.error");
121                        return (size, SocketResult::Error);
122                    }
123                },
124            }
125        }
126    }
127
128    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
129        match self.write_vectored(bufs) {
130            Ok(sz) => (sz, SocketResult::Continue),
131            Err(e) => match e.kind() {
132                ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
133                ErrorKind::ConnectionReset
134                | ErrorKind::ConnectionAborted
135                | ErrorKind::BrokenPipe
136                | ErrorKind::ConnectionRefused => {
137                    incr!("tcp.write.error");
138                    (0, SocketResult::Closed)
139                }
140                _ => {
141                    //FIXME: timeout and other common errors should be sent up
142                    error!("SOCKET\tsocket_write error={:?}", e);
143                    incr!("tcp.write.error");
144                    (0, SocketResult::Error)
145                }
146            },
147        }
148    }
149
150    fn socket_ref(&self) -> &TcpStream {
151        self
152    }
153
154    fn socket_mut(&mut self) -> &mut TcpStream {
155        self
156    }
157
158    fn protocol(&self) -> TransportProtocol {
159        TransportProtocol::Tcp
160    }
161
162    fn read_error(&self) {
163        incr!("tcp.read.error");
164    }
165
166    fn write_error(&self) {
167        incr!("tcp.write.error");
168    }
169}
170
171pub struct FrontRustls {
172    pub stream: TcpStream,
173    pub session: ServerConnection,
174}
175
176impl SocketHandler for FrontRustls {
177    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
178        let mut size = 0usize;
179        let mut can_read = true;
180        let mut is_error = false;
181        let mut is_closed = false;
182
183        let mut counter = 0;
184        loop {
185            counter += 1;
186            if counter > MAX_LOOP_ITERATIONS {
187                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_read");
188                incr!("rustls.read.infinite_loop.error");
189            }
190
191            if size == buf.len() {
192                break;
193            }
194
195            if !can_read | is_error | is_closed {
196                break;
197            }
198
199            match self.session.read_tls(&mut self.stream) {
200                Ok(0) => {
201                    can_read = false;
202                    is_closed = true;
203                }
204                Ok(_sz) => {}
205                Err(e) => match e.kind() {
206                    ErrorKind::WouldBlock => {
207                        can_read = false;
208                    }
209                    ErrorKind::ConnectionReset
210                    | ErrorKind::ConnectionAborted
211                    | ErrorKind::BrokenPipe => {
212                        is_closed = true;
213                    }
214                    // https://github.com/rustls/rustls/blob/main/rustls/src/conn.rs#L482-L500,
215                    ErrorKind::Other => {
216                        warn!("rustls buffer is full, we will consume it, before processing new incoming packets, to mitigate this issue, you could try to increase the buffer size, {:?}", e);
217                    }
218                    _ => {
219                        error!("could not read TLS stream from socket: {:?}", e);
220                        is_error = true;
221                        break;
222                    }
223                },
224            }
225
226            if let Err(e) = self.session.process_new_packets() {
227                error!("could not process read TLS packets: {:?}", e);
228                is_error = true;
229                break;
230            }
231
232            while !self.session.wants_read() {
233                match self.session.reader().read(&mut buf[size..]) {
234                    Ok(0) => break,
235                    Ok(sz) => {
236                        size += sz;
237                    }
238                    Err(e) => match e.kind() {
239                        ErrorKind::WouldBlock => {
240                            break;
241                        }
242                        ErrorKind::ConnectionReset
243                        | ErrorKind::ConnectionAborted
244                        | ErrorKind::BrokenPipe => {
245                            is_closed = true;
246                            break;
247                        }
248                        _ => {
249                            error!("could not read data from TLS stream: {:?}", e);
250                            is_error = true;
251                            break;
252                        }
253                    },
254                }
255            }
256        }
257
258        if is_error {
259            (size, SocketResult::Error)
260        } else if is_closed {
261            (size, SocketResult::Closed)
262        } else if !can_read {
263            (size, SocketResult::WouldBlock)
264        } else {
265            (size, SocketResult::Continue)
266        }
267    }
268
269    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
270        let mut buffered_size = 0usize;
271        let mut can_write = true;
272        let mut is_error = false;
273        let mut is_closed = false;
274
275        let mut counter = 0;
276        loop {
277            counter += 1;
278            if counter > MAX_LOOP_ITERATIONS {
279                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write");
280                incr!("rustls.write.infinite_loop.error");
281            }
282            if buffered_size == buf.len() {
283                break;
284            }
285
286            if !can_write | is_error | is_closed {
287                break;
288            }
289
290            match self.session.writer().write(&buf[buffered_size..]) {
291                Ok(0) => {} // zero byte written means that the Rustls buffers are full, we will try to write on the socket and try again
292                Ok(sz) => {
293                    buffered_size += sz;
294                }
295                Err(e) => match e.kind() {
296                    ErrorKind::WouldBlock => {
297                        // we don't need to do anything, the session will return false in wants_write?
298                        //error!("rustls socket_write wouldblock");
299                    }
300                    ErrorKind::ConnectionReset
301                    | ErrorKind::ConnectionAborted
302                    | ErrorKind::BrokenPipe => {
303                        //FIXME: this should probably not happen here
304                        incr!("rustls.write.error");
305                        is_closed = true;
306                        break;
307                    }
308                    _ => {
309                        error!("could not write data to TLS stream: {:?}", e);
310                        incr!("rustls.write.error");
311                        is_error = true;
312                        break;
313                    }
314                },
315            }
316
317            loop {
318                match self.session.write_tls(&mut self.stream) {
319                    Ok(0) => {
320                        //can_write = false;
321                        break;
322                    }
323                    Ok(_sz) => {}
324                    Err(e) => match e.kind() {
325                        ErrorKind::WouldBlock => {
326                            can_write = false;
327                            break;
328                        }
329                        ErrorKind::ConnectionReset
330                        | ErrorKind::ConnectionAborted
331                        | ErrorKind::BrokenPipe => {
332                            incr!("rustls.write.error");
333                            is_closed = true;
334                            break;
335                        }
336                        _ => {
337                            error!("could not write TLS stream to socket: {:?}", e);
338                            incr!("rustls.write.error");
339                            is_error = true;
340                            break;
341                        }
342                    },
343                }
344            }
345        }
346
347        if is_error {
348            (buffered_size, SocketResult::Error)
349        } else if is_closed {
350            (buffered_size, SocketResult::Closed)
351        } else if !can_write {
352            (buffered_size, SocketResult::WouldBlock)
353        } else {
354            (buffered_size, SocketResult::Continue)
355        }
356    }
357
358    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
359        let mut buffered_size = 0usize;
360        let mut can_write = true;
361        let mut is_error = false;
362        let mut is_closed = false;
363
364        match self.session.writer().write_vectored(bufs) {
365            Ok(0) => {} // zero byte written means that the Rustls buffers are full, we will try to write on the socket and try again
366            Ok(sz) => {
367                buffered_size += sz;
368            }
369            Err(e) => match e.kind() {
370                ErrorKind::WouldBlock => {
371                    // we don't need to do anything, the session will return false in wants_write?
372                    //error!("rustls socket_write wouldblock");
373                }
374                ErrorKind::ConnectionReset
375                | ErrorKind::ConnectionAborted
376                | ErrorKind::BrokenPipe => {
377                    //FIXME: this should probably not happen here
378                    incr!("rustls.write.error");
379                    is_closed = true;
380                }
381                _ => {
382                    error!("could not write data to TLS stream: {:?}", e);
383                    incr!("rustls.write.error");
384                    is_error = true;
385                }
386            },
387        }
388
389        let mut counter = 0;
390        loop {
391            counter += 1;
392            if counter > MAX_LOOP_ITERATIONS {
393                error!("MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored");
394                incr!("rustls.write.infinite_loop.error");
395            }
396            match self.session.write_tls(&mut self.stream) {
397                Ok(0) => {
398                    break;
399                }
400                Ok(_sz) => {}
401                Err(e) => match e.kind() {
402                    ErrorKind::WouldBlock => {
403                        can_write = false;
404                        break;
405                    }
406                    ErrorKind::ConnectionReset
407                    | ErrorKind::ConnectionAborted
408                    | ErrorKind::BrokenPipe => {
409                        incr!("rustls.write.error");
410                        is_closed = true;
411                        break;
412                    }
413                    _ => {
414                        error!("could not write TLS stream to socket: {:?}", e);
415                        incr!("rustls.write.error");
416                        is_error = true;
417                        break;
418                    }
419                },
420            }
421        }
422
423        if is_error {
424            (buffered_size, SocketResult::Error)
425        } else if is_closed {
426            (buffered_size, SocketResult::Closed)
427        } else if !can_write {
428            (buffered_size, SocketResult::WouldBlock)
429        } else {
430            (buffered_size, SocketResult::Continue)
431        }
432    }
433
434    fn socket_close(&mut self) {
435        self.session.send_close_notify();
436    }
437
438    fn socket_wants_write(&self) -> bool {
439        self.session.wants_write()
440    }
441
442    fn socket_ref(&self) -> &TcpStream {
443        &self.stream
444    }
445
446    fn socket_mut(&mut self) -> &mut TcpStream {
447        &mut self.stream
448    }
449
450    fn protocol(&self) -> TransportProtocol {
451        self.session
452            .protocol_version()
453            .map(|version| match version {
454                ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
455                ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
456                ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
457                ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
458                ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
459                ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
460                _ => TransportProtocol::Tls1_3,
461            })
462            .unwrap_or(TransportProtocol::Tcp)
463    }
464
465    fn read_error(&self) {
466        incr!("rustls.read.error");
467    }
468
469    fn write_error(&self) {
470        incr!("rustls.write.error");
471    }
472}
473
474pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
475    let sock = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
476        .map_err(ServerBindError::SocketCreationError)?;
477
478    // set so_reuseaddr, but only on unix (mirrors what libstd does)
479    if cfg!(unix) {
480        sock.set_reuse_address(true)
481            .map_err(ServerBindError::SetReuseAddress)?;
482    }
483
484    sock.set_reuse_port(true)
485        .map_err(ServerBindError::SetReusePort)?;
486
487    sock.bind(&addr.into())
488        .map_err(ServerBindError::BindError)?;
489
490    sock.set_nonblocking(true)
491        .map_err(ServerBindError::SetNonBlocking)?;
492
493    // listen
494    // FIXME: make the backlog configurable?
495    sock.listen(1024).map_err(ServerBindError::Listen)?;
496
497    Ok(TcpListener::from_std(sock.into()))
498}
499
500/// Socket statistics
501pub mod stats {
502    use std::{os::fd::AsRawFd, time::Duration};
503
504    use internal::{TcpInfo, OPT_LEVEL, OPT_NAME};
505
506    /// Round trip time for a TCP socket
507    pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
508        socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(info.rtt() as u64))
509    }
510
511    #[cfg(unix)]
512    pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
513        let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
514        let mut len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
515        let status = unsafe {
516            libc::getsockopt(
517                fd,
518                OPT_LEVEL,
519                OPT_NAME,
520                &mut tcp_info as *mut _ as *mut _,
521                &mut len,
522            )
523        };
524        if status != 0 {
525            None
526        } else {
527            Some(tcp_info)
528        }
529    }
530    #[cfg(not(unix))]
531    pub fn socketinfo(fd: libc::c_int) -> Option<TcpInfo> {
532        None
533    }
534
535    #[cfg(unix)]
536    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
537    mod internal {
538        pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
539        pub const OPT_NAME: libc::c_int = libc::TCP_INFO;
540
541        #[derive(Clone, Debug)]
542        #[repr(C)]
543        pub struct TcpInfo {
544            // State
545            tcpi_state: u8,
546            tcpi_ca_state: u8,
547            tcpi_retransmits: u8,
548            tcpi_probes: u8,
549            tcpi_backoff: u8,
550            tcpi_options: u8,
551            tcpi_snd_rcv_wscale: u8, // 4bits|4bits
552
553            tcpi_rto: u32,
554            tcpi_ato: u32,
555            tcpi_snd_mss: u32,
556            tcpi_rcv_mss: u32,
557
558            tcpi_unacked: u32,
559            tcpi_sacked: u32,
560            tcpi_lost: u32,
561            tcpi_retrans: u32,
562            tcpi_fackets: u32,
563
564            // Times
565            tcpi_last_data_sent: u32,
566            tcpi_last_ack_sent: u32, // Not remembered
567            tcpi_last_data_recv: u32,
568            tcpi_last_ack_recv: u32,
569
570            // Metrics
571            tcpi_pmtu: u32,
572            tcpi_rcv_ssthresh: u32,
573            tcpi_rtt: u32,
574            tcpi_rttvar: u32,
575            tcpi_snd_ssthresh: u32,
576            tcpi_snd_cwnd: u32,
577            tcpi_advmss: u32,
578            tcpi_reordering: u32,
579        }
580        impl TcpInfo {
581            pub fn rtt(&self) -> u32 {
582                self.tcpi_rtt
583            }
584        }
585    }
586
587    #[cfg(unix)]
588    #[cfg(any(target_os = "macos", target_os = "ios"))]
589    mod internal {
590        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
591        pub const OPT_NAME: libc::c_int = 0x106;
592
593        #[derive(Clone, Debug)]
594        #[repr(C)]
595        pub struct TcpInfo {
596            tcpi_state: u8,
597            tcpi_snd_wscale: u8,
598            tcpi_rcv_wscale: u8,
599            __pad1: u8,
600            tcpi_options: u32,
601            tcpi_flags: u32,
602            tcpi_rto: u32,
603            tcpi_maxseg: u32,
604            tcpi_snd_ssthresh: u32,
605            tcpi_snd_cwnd: u32,
606            tcpi_snd_wnd: u32,
607            tcpi_snd_sbbytes: u32,
608            tcpi_rcv_wnd: u32,
609            tcpi_rttcur: u32,
610            tcpi_srtt: u32,
611            tcpi_rttvar: u32,
612            tcpi_tfo: u32,
613            tcpi_txpackets: u64,
614            tcpi_txbytes: u64,
615            tcpi_txretransmitbytes: u64,
616            tcpi_rxpackets: u64,
617            tcpi_rxbytes: u64,
618            tcpi_rxoutoforderbytes: u64,
619            tcpi_txretransmitpackets: u64,
620        }
621        impl TcpInfo {
622            pub fn rtt(&self) -> u32 {
623                // tcpi_srtt is in milliseconds not microseconds
624                self.tcpi_srtt * 1000
625            }
626        }
627    }
628
629    #[cfg(not(unix))]
630    #[derive(Clone, Debug)]
631    struct TcpInfo {}
632
633    #[test]
634    #[serial_test::serial]
635    fn test_rtt() {
636        let sock = std::net::TcpStream::connect("google.com:80").unwrap();
637        let fd = sock.as_raw_fd();
638        let info = socket_info(fd);
639        assert!(info.is_some());
640        println!("{:#?}", info);
641        println!(
642            "rtt: {}",
643            sozu_command::logging::LogDuration(socket_rtt(&sock))
644        );
645    }
646}