#![allow(clippy::large_enum_variant)]

use futures::{try_ready, Async, Future, Poll, Stream};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;

use crate::config::TargetSessionAttrs;
use crate::proto::{
    Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream,
};
use crate::{Config, Error, SimpleQueryMessage, Socket, TlsConnect};

#[derive(StateMachineFuture)]
pub enum ConnectOnce<T>
where
    T: TlsConnect<Socket>,
{
    #[state_machine_future(start, transitions(ConnectingSocket))]
    Start { idx: usize, tls: T, config: Config },
    #[state_machine_future(transitions(ConnectingRaw))]
    ConnectingSocket {
        future: ConnectSocketFuture,
        idx: usize,
        tls: T,
        config: Config,
    },
    #[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
    ConnectingRaw {
        future: ConnectRawFuture<Socket, T>,
        target_session_attrs: TargetSessionAttrs,
    },
    #[state_machine_future(transitions(Finished))]
    CheckingSessionAttrs {
        stream: SimpleQueryStream,
        client: Client,
        connection: Connection<MaybeTlsStream<Socket, T::Stream>>,
    },
    #[state_machine_future(ready)]
    Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
    #[state_machine_future(error)]
    Failed(Error),
}

impl<T> PollConnectOnce<T> for ConnectOnce<T>
where
    T: TlsConnect<Socket>,
{
    fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
        let state = state.take();

        transition!(ConnectingSocket {
            future: ConnectSocketFuture::new(state.config.clone(), state.idx),
            idx: state.idx,
            tls: state.tls,
            config: state.config,
        })
    }

    fn poll_connecting_socket<'a>(
        state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
    ) -> Poll<AfterConnectingSocket<T>, Error> {
        let socket = try_ready!(state.future.poll());
        let state = state.take();

        transition!(ConnectingRaw {
            target_session_attrs: state.config.0.target_session_attrs,
            future: ConnectRawFuture::new(socket, state.tls, state.config, Some(state.idx)),
        })
    }

    fn poll_connecting_raw<'a>(
        state: &'a mut RentToOwn<'a, ConnectingRaw<T>>,
    ) -> Poll<AfterConnectingRaw<T>, Error> {
        let (client, connection) = try_ready!(state.future.poll());

        if let TargetSessionAttrs::ReadWrite = state.target_session_attrs {
            transition!(CheckingSessionAttrs {
                stream: client.simple_query("SHOW transaction_read_only"),
                client,
                connection,
            })
        } else {
            transition!(Finished((client, connection)))
        }
    }

    fn poll_checking_session_attrs<'a>(
        state: &'a mut RentToOwn<'a, CheckingSessionAttrs<T>>,
    ) -> Poll<AfterCheckingSessionAttrs<T>, Error> {
        loop {
            if let Async::Ready(()) = state.connection.poll()? {
                return Err(Error::closed());
            }

            match try_ready!(state.stream.poll()) {
                Some(SimpleQueryMessage::Row(row)) => {
                    if row.try_get(0)? == Some("on") {
                        return Err(Error::connect(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "database does not allow writes",
                        )));
                    } else {
                        let state = state.take();
                        transition!(Finished((state.client, state.connection)))
                    }
                }
                Some(_) => {}
                None => return Err(Error::closed()),
            }
        }
    }
}

impl<T> ConnectOnceFuture<T>
where
    T: TlsConnect<Socket>,
{
    pub fn new(idx: usize, tls: T, config: Config) -> ConnectOnceFuture<T> {
        ConnectOnce::start(idx, tls, config)
    }
}