tokio_rustls/
client.rs

1use std::future::Future;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10use std::{
11    io::{self, BufRead as _},
12    sync::Arc,
13};
14
15use rustls::{pki_types::ServerName, ClientConfig, ClientConnection};
16use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
17
18use crate::common::{IoSession, MidHandshake, Stream, TlsState};
19
20/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
21#[derive(Clone)]
22pub struct TlsConnector {
23    inner: Arc<ClientConfig>,
24    #[cfg(feature = "early-data")]
25    early_data: bool,
26}
27
28impl TlsConnector {
29    /// Enable 0-RTT.
30    ///
31    /// If you want to use 0-RTT,
32    /// You must also set `ClientConfig.enable_early_data` to `true`.
33    #[cfg(feature = "early-data")]
34    pub fn early_data(mut self, flag: bool) -> Self {
35        self.early_data = flag;
36        self
37    }
38
39    #[inline]
40    pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
41    where
42        IO: AsyncRead + AsyncWrite + Unpin,
43    {
44        self.connect_impl(domain, stream, None, |_| ())
45    }
46
47    #[inline]
48    pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
49    where
50        IO: AsyncRead + AsyncWrite + Unpin,
51        F: FnOnce(&mut ClientConnection),
52    {
53        self.connect_impl(domain, stream, None, f)
54    }
55
56    fn connect_impl<IO, F>(
57        &self,
58        domain: ServerName<'static>,
59        stream: IO,
60        alpn_protocols: Option<Vec<Vec<u8>>>,
61        f: F,
62    ) -> Connect<IO>
63    where
64        IO: AsyncRead + AsyncWrite + Unpin,
65        F: FnOnce(&mut ClientConnection),
66    {
67        let alpn = alpn_protocols.unwrap_or_else(|| self.inner.alpn_protocols.clone());
68        let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) {
69            Ok(session) => session,
70            Err(error) => {
71                return Connect(MidHandshake::Error {
72                    io: stream,
73                    // TODO(eliza): should this really return an `io::Error`?
74                    // Probably not...
75                    error: io::Error::new(io::ErrorKind::Other, error),
76                });
77            }
78        };
79        f(&mut session);
80
81        Connect(MidHandshake::Handshaking(TlsStream {
82            io: stream,
83
84            #[cfg(not(feature = "early-data"))]
85            state: TlsState::Stream,
86
87            #[cfg(feature = "early-data")]
88            state: if self.early_data && session.early_data().is_some() {
89                TlsState::EarlyData(0, Vec::new())
90            } else {
91                TlsState::Stream
92            },
93
94            need_flush: false,
95
96            #[cfg(feature = "early-data")]
97            early_waker: None,
98
99            session,
100        }))
101    }
102
103    pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
104        TlsConnectorWithAlpn {
105            inner: self,
106            alpn_protocols,
107        }
108    }
109
110    /// Get a read-only reference to underlying config
111    pub fn config(&self) -> &Arc<ClientConfig> {
112        &self.inner
113    }
114}
115
116impl From<Arc<ClientConfig>> for TlsConnector {
117    fn from(inner: Arc<ClientConfig>) -> Self {
118        Self {
119            inner,
120            #[cfg(feature = "early-data")]
121            early_data: false,
122        }
123    }
124}
125
126pub struct TlsConnectorWithAlpn<'c> {
127    inner: &'c TlsConnector,
128    alpn_protocols: Vec<Vec<u8>>,
129}
130
131impl TlsConnectorWithAlpn<'_> {
132    #[inline]
133    pub fn connect<IO>(self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
134    where
135        IO: AsyncRead + AsyncWrite + Unpin,
136    {
137        self.inner
138            .connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
139    }
140
141    #[inline]
142    pub fn connect_with<IO, F>(self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
143    where
144        IO: AsyncRead + AsyncWrite + Unpin,
145        F: FnOnce(&mut ClientConnection),
146    {
147        self.inner
148            .connect_impl(domain, stream, Some(self.alpn_protocols), f)
149    }
150}
151
152/// Future returned from `TlsConnector::connect` which will resolve
153/// once the connection handshake has finished.
154pub struct Connect<IO>(MidHandshake<TlsStream<IO>>);
155
156impl<IO> Connect<IO> {
157    #[inline]
158    pub fn into_fallible(self) -> FallibleConnect<IO> {
159        FallibleConnect(self.0)
160    }
161
162    pub fn get_ref(&self) -> Option<&IO> {
163        match &self.0 {
164            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
165            MidHandshake::SendAlert { io, .. } => Some(io),
166            MidHandshake::Error { io, .. } => Some(io),
167            MidHandshake::End => None,
168        }
169    }
170
171    pub fn get_mut(&mut self) -> Option<&mut IO> {
172        match &mut self.0 {
173            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
174            MidHandshake::SendAlert { io, .. } => Some(io),
175            MidHandshake::Error { io, .. } => Some(io),
176            MidHandshake::End => None,
177        }
178    }
179}
180
181impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
182    type Output = io::Result<TlsStream<IO>>;
183
184    #[inline]
185    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
187    }
188}
189
190impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
191    type Output = Result<TlsStream<IO>, (io::Error, IO)>;
192
193    #[inline]
194    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
195        Pin::new(&mut self.0).poll(cx)
196    }
197}
198
199/// Like [Connect], but returns `IO` on failure.
200pub struct FallibleConnect<IO>(MidHandshake<TlsStream<IO>>);
201
202/// A wrapper around an underlying raw stream which implements the TLS or SSL
203/// protocol.
204#[derive(Debug)]
205pub struct TlsStream<IO> {
206    pub(crate) io: IO,
207    pub(crate) session: ClientConnection,
208    pub(crate) state: TlsState,
209    pub(crate) need_flush: bool,
210
211    #[cfg(feature = "early-data")]
212    pub(crate) early_waker: Option<Waker>,
213}
214
215impl<IO> TlsStream<IO> {
216    #[inline]
217    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
218        (&self.io, &self.session)
219    }
220
221    #[inline]
222    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
223        (&mut self.io, &mut self.session)
224    }
225
226    #[inline]
227    pub fn into_inner(self) -> (IO, ClientConnection) {
228        (self.io, self.session)
229    }
230}
231
232#[cfg(unix)]
233impl<S> AsRawFd for TlsStream<S>
234where
235    S: AsRawFd,
236{
237    fn as_raw_fd(&self) -> RawFd {
238        self.get_ref().0.as_raw_fd()
239    }
240}
241
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Expose the raw socket handle of the wrapped IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let transport = &self.io;
        transport.as_raw_socket()
    }
}
251
252impl<IO> IoSession for TlsStream<IO> {
253    type Io = IO;
254    type Session = ClientConnection;
255
256    #[inline]
257    fn skip_handshake(&self) -> bool {
258        self.state.is_early_data()
259    }
260
261    #[inline]
262    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
263        (
264            &mut self.state,
265            &mut self.io,
266            &mut self.session,
267            &mut self.need_flush,
268        )
269    }
270
271    #[inline]
272    fn into_io(self) -> Self::Io {
273        self.io
274    }
275}
276
#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Remember the waker of a reader that must wait while in the early-data state.
    ///
    /// No TLS connection is really established yet, so reads are answered with
    /// `Pending` until a write or flush completes the handshake. To avoid losing
    /// the wake-up, the current task's waker is stored so it can be woken once
    /// that happens. The stored waker is only replaced when it would not wake
    /// the current task.
    fn poll_early_data(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let still_valid = match &self.early_waker {
            Some(stored) => current.will_wake(stored),
            None => false,
        };
        if !still_valid {
            self.early_waker = Some(current.clone());
        }
    }
}
299
300impl<IO> AsyncRead for TlsStream<IO>
301where
302    IO: AsyncRead + AsyncWrite + Unpin,
303{
304    fn poll_read(
305        mut self: Pin<&mut Self>,
306        cx: &mut Context<'_>,
307        buf: &mut ReadBuf<'_>,
308    ) -> Poll<io::Result<()>> {
309        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
310        let len = data.len().min(buf.remaining());
311        buf.put_slice(&data[..len]);
312        self.consume(len);
313        Poll::Ready(Ok(()))
314    }
315}
316
317impl<IO> AsyncBufRead for TlsStream<IO>
318where
319    IO: AsyncRead + AsyncWrite + Unpin,
320{
321    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
322        match self.state {
323            #[cfg(feature = "early-data")]
324            TlsState::EarlyData(..) => {
325                self.get_mut().poll_early_data(cx);
326                Poll::Pending
327            }
328            TlsState::Stream | TlsState::WriteShutdown => {
329                let this = self.get_mut();
330                let stream =
331                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
332
333                match stream.poll_fill_buf(cx) {
334                    Poll::Ready(Ok(buf)) => {
335                        if buf.is_empty() {
336                            this.state.shutdown_read();
337                        }
338
339                        Poll::Ready(Ok(buf))
340                    }
341                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
342                        this.state.shutdown_read();
343                        Poll::Ready(Err(err))
344                    }
345                    output => output,
346                }
347            }
348            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
349        }
350    }
351
352    fn consume(mut self: Pin<&mut Self>, amt: usize) {
353        self.session.reader().consume(amt);
354    }
355}
356
impl<IO> AsyncWrite for TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Write `buf` into the TLS session.
    ///
    /// Note: this does not guarantee that the data has been sent to the peer.
    /// To be sure, call `flush` afterwards.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // Rebuild the `Stream` view, restoring the flush flag saved by an earlier `Pending`.
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        // In the early-data state, try to send `buf` as 0-RTT data or finish the
        // handshake. `Ready(0)` means the state is no longer `EarlyData`, so fall
        // through to a regular write.
        #[cfg(feature = "early-data")]
        {
            let bufs = [io::IoSlice::new(buf)];
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                &bufs,
            )?;
            match written {
                Poll::Ready(0) => {}
                Poll::Ready(written) => return Poll::Ready(Ok(written)),
                Poll::Pending => {
                    // Persist the flag so the next poll starts from the same point.
                    this.need_flush = stream.need_flush;
                    return Poll::Pending;
                }
            }
        }

        stream.as_mut_pin().poll_write(cx, buf)
    }

    /// Vectored variant of `poll_write`.
    ///
    /// Note: this does not guarantee that the data has been sent to the peer.
    /// To be sure, call `flush` afterwards.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // Same setup as `poll_write`: restore the saved flush flag.
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        // Early-data handling identical to `poll_write`, but over all `bufs`.
        #[cfg(feature = "early-data")]
        {
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                bufs,
            )?;
            match written {
                Poll::Ready(0) => {}
                Poll::Ready(written) => return Poll::Ready(Ok(written)),
                Poll::Pending => {
                    this.need_flush = stream.need_flush;
                    return Poll::Pending;
                }
            }
        }

        stream.as_mut_pin().poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Flush the session to the transport; in the early-data state this first
    /// completes the handshake (and any early-data fallback).
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let mut stream = Stream::new(&mut this.io, &mut this.session)
            .set_eof(!this.state.readable())
            .set_need_flush(this.need_flush);

        // `bufs` is empty, so nothing is written as early data: a `Ready` result
        // is always `Ready(0)` (errors are returned by `?`), and only `Pending`
        // needs handling.
        #[cfg(feature = "early-data")]
        {
            let written = poll_handle_early_data(
                &mut this.state,
                &mut stream,
                &mut this.early_waker,
                cx,
                &[],
            )?;
            if written.is_pending() {
                this.need_flush = stream.need_flush;
                return Poll::Pending;
            }
        }

        stream.as_mut_pin().poll_flush(cx)
    }

    /// Send `close_notify` (once) and shut down the write side of the transport.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        #[cfg(feature = "early-data")]
        {
            // complete handshake
            // (still in the early-data state: flushing drives it to completion first)
            if matches!(self.state, TlsState::EarlyData(..)) {
                ready!(self.as_mut().poll_flush(cx))?;
            }
        }

        // Queue close_notify only the first time; shutdown_write() makes
        // `writeable()` false for subsequent polls.
        if self.state.writeable() {
            self.session.send_close_notify();
            self.state.shutdown_write();
        }

        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
        stream.as_mut_pin().poll_shutdown(cx)
    }
}
479
/// Drive the client through the 0-RTT (early data) state; shared by the write paths.
///
/// Does nothing unless `state` is `TlsState::EarlyData(pos, data)`, where
/// `data` records every byte written as early data and `pos` is how many of
/// those bytes have already been re-sent by the fallback below. In that state:
/// 1. As much of `bufs` as rustls accepts is written as early data and copied
///    into `data`; if anything was written, `Ready(Ok(n))` is returned.
/// 2. Otherwise the handshake is driven to completion.
/// 3. If the server did not accept the early data, the recorded bytes are
///    re-sent from `pos` onwards as ordinary application data.
/// 4. `state` becomes `TlsState::Stream` and any reader parked by
///    `poll_early_data` is woken.
///
/// Returns:
/// - `Ready(Ok(0))` when the caller should continue with a regular operation;
/// - `Ready(Ok(n))` with `n > 0` when `n` bytes of `bufs` were taken as early data;
/// - `Pending` while the handshake or the fallback write cannot finish yet
///   (progress is kept in `pos`, so the call can be resumed).
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // write early data
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    // rustls took nothing more (presumably the early-data limit is reached).
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Keep a copy so it can be re-sent if the server rejects 0-RTT.
                data.extend_from_slice(&buf[..len]);

                // Partial write: stop so `written` covers a contiguous prefix of `bufs`.
                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // complete handshake
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // write early data (fallback)
        // Resumable: `pos` advances as bytes are accepted, so a `Pending` here
        // continues from the same offset on the next call.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                *pos += len;
            }
        }

        // end
        *state = TlsState::Stream;

        // Wake a reader that was parked in `poll_early_data`.
        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}
542
543    Poll::Ready(Ok(0))
544}