1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
use futures::{try_ready, Future, Poll}; use postgres_protocol::message::frontend; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use tokio_io::io::{self, ReadExact, WriteAll}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::config::SslMode; use crate::proto::MaybeTlsStream; use crate::tls::private::ForcePrivateApi; use crate::tls::ChannelBinding; use crate::{Error, TlsConnect}; #[derive(StateMachineFuture)] pub enum Tls<S, T> where T: TlsConnect<S>, S: AsyncRead + AsyncWrite, { #[state_machine_future(start, transitions(SendingTls, Ready))] Start { stream: S, mode: SslMode, tls: T }, #[state_machine_future(transitions(ReadingTls))] SendingTls { future: WriteAll<S, Vec<u8>>, mode: SslMode, tls: T, }, #[state_machine_future(transitions(ConnectingTls, Ready))] ReadingTls { future: ReadExact<S, [u8; 1]>, mode: SslMode, tls: T, }, #[state_machine_future(transitions(Ready))] ConnectingTls { future: T::Future }, #[state_machine_future(ready)] Ready((MaybeTlsStream<S, T::Stream>, ChannelBinding)), #[state_machine_future(error)] Failed(Error), } impl<S, T> PollTls<S, T> for Tls<S, T> where T: TlsConnect<S>, S: AsyncRead + AsyncWrite, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> { let state = state.take(); match state.mode { SslMode::Disable => transition!(Ready(( MaybeTlsStream::Raw(state.stream), ChannelBinding::none() ))), SslMode::Prefer if !state.tls.can_connect(ForcePrivateApi) => transition!(Ready(( MaybeTlsStream::Raw(state.stream), ChannelBinding::none() ))), SslMode::Prefer | SslMode::Require => { let mut buf = vec![]; frontend::ssl_request(&mut buf); transition!(SendingTls { future: io::write_all(state.stream, buf), mode: state.mode, tls: state.tls, }) } SslMode::__NonExhaustive => unreachable!(), } } fn poll_sending_tls<'a>( state: &'a mut RentToOwn<'a, SendingTls<S, T>>, ) -> Poll<AfterSendingTls<S, T>, Error> { let (stream, _) = try_ready!(state.future.poll().map_err(Error::io)); let state = state.take(); transition!(ReadingTls { future: io::read_exact(stream, [0]), mode: state.mode, tls: state.tls, }) } fn poll_reading_tls<'a>( state: &'a mut RentToOwn<'a, ReadingTls<S, T>>, ) -> Poll<AfterReadingTls<S, T>, Error> { let (stream, buf) = try_ready!(state.future.poll().map_err(Error::io)); let state = state.take(); if buf[0] == b'S' { transition!(ConnectingTls { future: state.tls.connect(stream), }) } else if state.mode == SslMode::Require { Err(Error::tls("server does not support TLS".into())) } else { transition!(Ready((MaybeTlsStream::Raw(stream), ChannelBinding::none()))) } } fn poll_connecting_tls<'a>( state: &'a mut RentToOwn<'a, ConnectingTls<S, T>>, ) -> Poll<AfterConnectingTls<S, T>, Error> { let (stream, channel_binding) = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into()))); transition!(Ready((MaybeTlsStream::Tls(stream), channel_binding))) } } impl<S, T> TlsFuture<S, T> where T: TlsConnect<S>, S: AsyncRead + AsyncWrite, { pub fn new(stream: S, mode: SslMode, tls: T) -> TlsFuture<S, T> { Tls::start(stream, mode, tls) } }