1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
use futures::{Async, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};

use crate::config::Host;
use crate::proto::{Client, ConnectOnceFuture, Connection, MaybeTlsStream};
use crate::{Config, Error, MakeTlsConnect, Socket};

/// State machine that establishes a connection to one of the hosts listed in a [`Config`].
///
/// `#[derive(StateMachineFuture)]` turns this enum into the `ConnectFuture` future. The states
/// progress `Start` -> `Connecting` -> `Finished`, and any state may end in `Failed`.
#[derive(StateMachineFuture)]
pub enum Connect<T>
where
    T: MakeTlsConnect<Socket>,
{
    /// Initial state, holding the inputs exactly as they were handed to the constructor.
    #[state_machine_future(start, transitions(Connecting))]
    Start {
        /// Factory used to build a TLS connector for each host that is attempted.
        tls: T,
        /// The configuration, or an error produced while building it. An `Err` is surfaced
        /// the first time the future is polled.
        config: Result<Config, Error>,
    },
    /// A connection attempt to `config`'s host at index `idx` is in progress.
    #[state_machine_future(transitions(Finished))]
    Connecting {
        /// The in-flight attempt against the current host.
        future: ConnectOnceFuture<T::TlsConnect>,
        /// Index into the configured host list of the host currently being attempted.
        idx: usize,
        /// Retained so a new TLS connector can be built if the next host must be tried.
        tls: T,
        /// Validated configuration, shared (by clone) with each attempt.
        config: Config,
    },
    /// Connection established: the client handle and the connection that drives it.
    #[state_machine_future(ready)]
    Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
    /// Terminal error state.
    #[state_machine_future(error)]
    Failed(Error),
}

impl<T> PollConnect<T> for Connect<T>
where
    T: MakeTlsConnect<Socket>,
{
    fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
        let mut state = state.take();

        let config = state.config?;

        if config.0.host.is_empty() {
            return Err(Error::config("host missing".into()));
        }

        if config.0.port.len() > 1 && config.0.port.len() != config.0.host.len() {
            return Err(Error::config("invalid number of ports".into()));
        }

        let hostname = match &config.0.host[0] {
            Host::Tcp(host) => &**host,
            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
            #[cfg(unix)]
            Host::Unix(_) => "",
        };
        let tls = state
            .tls
            .make_tls_connect(hostname)
            .map_err(|e| Error::tls(e.into()))?;

        transition!(Connecting {
            future: ConnectOnceFuture::new(0, tls, config.clone()),
            idx: 0,
            tls: state.tls,
            config,
        })
    }

    fn poll_connecting<'a>(
        state: &'a mut RentToOwn<'a, Connecting<T>>,
    ) -> Poll<AfterConnecting<T>, Error> {
        loop {
            match state.future.poll() {
                Ok(Async::Ready(r)) => transition!(Finished(r)),
                Ok(Async::NotReady) => return Ok(Async::NotReady),
                Err(e) => {
                    let state = &mut **state;
                    state.idx += 1;

                    let host = match state.config.0.host.get(state.idx) {
                        Some(host) => host,
                        None => return Err(e),
                    };

                    let hostname = match host {
                        Host::Tcp(host) => &**host,
                        #[cfg(unix)]
                        Host::Unix(_) => "",
                    };
                    let tls = state
                        .tls
                        .make_tls_connect(hostname)
                        .map_err(|e| Error::tls(e.into()))?;

                    state.future = ConnectOnceFuture::new(state.idx, tls, state.config.clone());
                }
            }
        }
    }
}

impl<T> ConnectFuture<T>
where
    T: MakeTlsConnect<Socket>,
{
    pub fn new(tls: T, config: Result<Config, Error>) -> ConnectFuture<T> {
        Connect::start(tls, config)
    }
}