tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
2//!
3//! # Why do I need to call `poll_flush`?
4//!
5//! Most TLS implementations will have an internal buffer to improve throughput,
6//! and rustls is no exception.
7//!
//! When data is written to a `TlsStream`, it is always written into the rustls buffer first;
//! the encrypted TLS records are then taken out of rustls and written to the underlying
//! transport (such as a `TcpStream`). When the transport is not ready (returns `Pending`),
//! some data may remain buffered inside rustls.
11//!
//! To keep things simple and correct, `tokio-rustls` makes [`TlsStream`] behave like a `BufWriter`.
//! For a `TlsStream<TcpStream>`, this means that data accepted by `poll_write` is not guaranteed
//! to have been written to the `TcpStream` yet.
//! You must call `poll_flush` to ensure that it has been written to the `TcpStream`.
15//!
//! You should call `poll_flush` at an appropriate time, for example after a sequence of
//! `poll_write` calls has finished and there is no more data to write for now.
18//!
19//! ## Why don't we write during `poll_read`?
20//!
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! Those bugs can be worked around, but the workarounds degrade performance
//! (spurious wakeups caused by the read path driving writes).
23//!
//! Writing from the read path would also prevent us from implementing full-duplex operation in the future.
25//!
26//! see <https://github.com/tokio-rs/tls/issues/40>
27//!
28//! ## Why can't we handle it like `native-tls`?
29//!
//! When the underlying transport returns `Pending`, `native-tls` misreports the number of bytes
//! it has consumed: if data passed to `poll_write` has not actually been written to the
//! transport, `poll_write` does not return `Ready`. This avoids the need to call `poll_flush`.
33//!
//! However, this does not conform to the contract of the `AsyncWrite` trait: if you pass
//! different data to two successive `poll_write` calls, it may cause unexpected behavior.
36//!
37//! see <https://github.com/tokio-rs/tls/issues/41>
38
39#![warn(unreachable_pub, clippy::use_self)]
40
41use std::io;
42#[cfg(unix)]
43use std::os::unix::io::{AsRawFd, RawFd};
44#[cfg(windows)]
45use std::os::windows::io::{AsRawSocket, RawSocket};
46use std::pin::Pin;
47use std::task::{Context, Poll};
48
49pub use rustls;
50
51use rustls::CommonState;
52use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
53
/// Evaluates a `Poll<T>` expression: yields the inner `T` when it is `Poll::Ready(T)`,
/// otherwise returns `Poll::Pending` from the enclosing function.
///
/// Defined before the `mod` declarations so it is in textual scope for those modules.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
62
63pub mod client;
64pub use client::{Connect, FallibleConnect, TlsConnector, TlsConnectorWithAlpn};
65mod common;
66pub mod server;
67pub use server::{Accept, FallibleAccept, LazyConfigAcceptor, StartHandshake, TlsAcceptor};
68
69/// Unified TLS stream type
70///
71/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
72/// a single type to keep both client- and server-initiated TLS-encrypted connections.
73#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
74#[derive(Debug)]
75pub enum TlsStream<T> {
76    Client(client::TlsStream<T>),
77    Server(server::TlsStream<T>),
78}
79
80impl<T> TlsStream<T> {
81    pub fn get_ref(&self) -> (&T, &CommonState) {
82        use TlsStream::*;
83        match self {
84            Client(io) => {
85                let (io, session) = io.get_ref();
86                (io, session)
87            }
88            Server(io) => {
89                let (io, session) = io.get_ref();
90                (io, session)
91            }
92        }
93    }
94
95    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
96        use TlsStream::*;
97        match self {
98            Client(io) => {
99                let (io, session) = io.get_mut();
100                (io, &mut *session)
101            }
102            Server(io) => {
103                let (io, session) = io.get_mut();
104                (io, &mut *session)
105            }
106        }
107    }
108}
109
110impl<T> From<client::TlsStream<T>> for TlsStream<T> {
111    fn from(s: client::TlsStream<T>) -> Self {
112        Self::Client(s)
113    }
114}
115
116impl<T> From<server::TlsStream<T>> for TlsStream<T> {
117    fn from(s: server::TlsStream<T>) -> Self {
118        Self::Server(s)
119    }
120}
121
122#[cfg(unix)]
123impl<S> AsRawFd for TlsStream<S>
124where
125    S: AsRawFd,
126{
127    fn as_raw_fd(&self) -> RawFd {
128        self.get_ref().0.as_raw_fd()
129    }
130}
131
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the underlying IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let (inner, _) = self.get_ref();
        inner.as_raw_socket()
    }
}
141
142impl<T> AsyncRead for TlsStream<T>
143where
144    T: AsyncRead + AsyncWrite + Unpin,
145{
146    #[inline]
147    fn poll_read(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &mut ReadBuf<'_>,
151    ) -> Poll<io::Result<()>> {
152        match self.get_mut() {
153            Self::Client(x) => Pin::new(x).poll_read(cx, buf),
154            Self::Server(x) => Pin::new(x).poll_read(cx, buf),
155        }
156    }
157}
158
159impl<T> AsyncBufRead for TlsStream<T>
160where
161    T: AsyncRead + AsyncWrite + Unpin,
162{
163    #[inline]
164    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
165        match self.get_mut() {
166            Self::Client(x) => Pin::new(x).poll_fill_buf(cx),
167            Self::Server(x) => Pin::new(x).poll_fill_buf(cx),
168        }
169    }
170
171    #[inline]
172    fn consume(self: Pin<&mut Self>, amt: usize) {
173        match self.get_mut() {
174            Self::Client(x) => Pin::new(x).consume(amt),
175            Self::Server(x) => Pin::new(x).consume(amt),
176        }
177    }
178}
179
180impl<T> AsyncWrite for TlsStream<T>
181where
182    T: AsyncRead + AsyncWrite + Unpin,
183{
184    #[inline]
185    fn poll_write(
186        self: Pin<&mut Self>,
187        cx: &mut Context<'_>,
188        buf: &[u8],
189    ) -> Poll<io::Result<usize>> {
190        match self.get_mut() {
191            Self::Client(x) => Pin::new(x).poll_write(cx, buf),
192            Self::Server(x) => Pin::new(x).poll_write(cx, buf),
193        }
194    }
195
196    #[inline]
197    fn poll_write_vectored(
198        self: Pin<&mut Self>,
199        cx: &mut Context<'_>,
200        bufs: &[io::IoSlice<'_>],
201    ) -> Poll<io::Result<usize>> {
202        match self.get_mut() {
203            Self::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
204            Self::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
205        }
206    }
207
208    #[inline]
209    fn is_write_vectored(&self) -> bool {
210        match self {
211            Self::Client(x) => x.is_write_vectored(),
212            Self::Server(x) => x.is_write_vectored(),
213        }
214    }
215
216    #[inline]
217    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
218        match self.get_mut() {
219            Self::Client(x) => Pin::new(x).poll_flush(cx),
220            Self::Server(x) => Pin::new(x).poll_flush(cx),
221        }
222    }
223
224    #[inline]
225    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
226        match self.get_mut() {
227            Self::Client(x) => Pin::new(x).poll_shutdown(cx),
228            Self::Server(x) => Pin::new(x).poll_shutdown(cx),
229        }
230    }
231}