xitca_postgres/
driver.rs

1//! client driver module.
2
3pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16use core::{
17    future::{Future, IntoFuture},
18    net::SocketAddr,
19    pin::Pin,
20};
21
22use std::io;
23
24use postgres_protocol::message::{backend, frontend};
25use xitca_io::{
26    bytes::{Buf, BytesMut},
27    io::{AsyncIo, AsyncIoDyn, Interest},
28    net::TcpStream,
29};
30
31use super::{
32    client::Client,
33    config::{Config, SslMode, SslNegotiation},
34    error::{unexpected_eof_err, ConfigError, Error},
35    iter::AsyncLendingIterator,
36    session::{ConnectInfo, Session},
37};
38
39use self::generic::GenericDriver;
40
41#[cfg(feature = "tls")]
42use xitca_tls::rustls::{ClientConnection, TlsStream};
43
44#[cfg(unix)]
45use xitca_io::net::UnixStream;
46
47pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
48    if cfg.get_hosts().is_empty() {
49        return Err(ConfigError::EmptyHost.into());
50    }
51
52    if cfg.get_ports().is_empty() {
53        return Err(ConfigError::EmptyPort.into());
54    }
55
56    let mut err = None;
57    let hosts = cfg.get_hosts().to_vec();
58    for host in hosts {
59        match self::connect::connect_host(host, cfg).await {
60            Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
61            Err(e) => err = Some(e),
62        }
63    }
64
65    Err(err.unwrap())
66}
67
68pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
69where
70    Io: AsyncIo + Send + 'static,
71{
72    let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
73    Ok((Client::new(tx, session), Driver::Dynamic(drv)))
74}
75
76pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
77    self::connect::connect_info(info).await
78}
79
80async fn prepare_driver<Io>(
81    info: ConnectInfo,
82    io: Io,
83    cfg: &mut Config,
84) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
85where
86    Io: AsyncIo + Send + 'static,
87{
88    let (mut drv, tx) = GenericDriver::new(io);
89    let session = Session::prepare_session(info, &mut drv, cfg).await?;
90    Ok((tx, session, drv))
91}
92
93async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
94where
95    Io: AsyncIo,
96{
97    async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
98    where
99        Io: AsyncIo,
100    {
101        let mut buf = BytesMut::new();
102        frontend::ssl_request(&mut buf);
103
104        while !buf.is_empty() {
105            match io.write(&buf) {
106                Ok(0) => return Err(unexpected_eof_err()),
107                Ok(n) => buf.advance(n),
108                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
109                    io.ready(Interest::WRITABLE).await?;
110                }
111                Err(e) => return Err(e),
112            }
113        }
114
115        let mut buf = [0];
116        loop {
117            match io.read(&mut buf) {
118                Ok(0) => return Err(unexpected_eof_err()),
119                Ok(_) => return Ok(buf[0] == b'S'),
120                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
121                    io.ready(Interest::READABLE).await?;
122                }
123                Err(e) => return Err(e),
124            }
125        }
126    }
127
128    match ssl_mode {
129        SslMode::Disable => Ok(false),
130        _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
131        mode => match (query_tls_availability(io).await?, mode) {
132            (false, SslMode::Require) => Err(Error::todo()),
133            (bool, _) => Ok(bool),
134        },
135    }
136}
137
138async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
139    let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
140        ports.iter().map(move |port| {
141            addr.set_port(*port);
142            addr
143        })
144    });
145    Ok(addrs)
146}
147
148/// async driver of [`Client`]
149///
150/// it handles IO and emit server sent message that do not belong to any query with [`AsyncLendingIterator`]
151/// trait impl.
152///
153/// # Examples
154/// ```
155/// use std::future::IntoFuture;
156/// use xitca_postgres::{iter::AsyncLendingIterator, Driver};
157///
158/// // drive the client and listen to server notify at the same time.
159/// fn drive_with_server_notify(mut drv: Driver) {
160///     tokio::spawn(async move {
161///         while let Ok(Some(msg)) = drv.try_next().await {
162///             // *Note:
163///             // handle message must be non-blocking to prevent starvation of driver.
164///         }
165///     });
166/// }
167///
168/// // drive client without handling notify.
169/// fn drive_only(drv: Driver) {
170///     tokio::spawn(drv.into_future());
171/// }
172/// ```
173///
174/// # Lifetime
175/// Driver and [`Client`] have a dependent lifetime where either side can trigger the other part to shutdown.
176/// From Driver side it's in the form of dropping ownership.
177/// ## Examples
178/// ```
179/// # use xitca_postgres::{error::Error, Config, Execute, Postgres};
180/// # async fn shut_down(cfg: Config) -> Result<(), Error> {
181/// // connect to a database
182/// let (cli, drv) = Postgres::new(cfg).connect().await?;
183///
184/// // drop driver
185/// drop(drv);
186///
187/// // client will always return error when it's driver is gone.
188/// let e = "SELECT 1".query(&cli).await.unwrap_err();
189/// // a shortcut method can be used to determine if the error is caused by a shutdown driver.
190/// assert!(e.is_driver_down());
191///
192/// # Ok(())
193/// # }
194/// ```
195///
196// TODO: use Box<dyn AsyncIterator> when life time GAT is object safe.
197pub enum Driver {
198    Tcp(GenericDriver<TcpStream>),
199    Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
200    #[cfg(feature = "tls")]
201    Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
202    #[cfg(unix)]
203    Unix(GenericDriver<UnixStream>),
204    #[cfg(all(unix, feature = "tls"))]
205    UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
206    #[cfg(feature = "quic")]
207    Quic(GenericDriver<crate::driver::quic::QuicStream>),
208}
209
210impl Driver {
211    #[inline]
212    pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
213        match self {
214            Self::Tcp(ref mut drv) => drv.send(buf).await,
215            Self::Dynamic(ref mut drv) => drv.send(buf).await,
216            #[cfg(feature = "tls")]
217            Self::Tls(ref mut drv) => drv.send(buf).await,
218            #[cfg(unix)]
219            Self::Unix(ref mut drv) => drv.send(buf).await,
220            #[cfg(all(unix, feature = "tls"))]
221            Self::UnixTls(ref mut drv) => drv.send(buf).await,
222            #[cfg(feature = "quic")]
223            Self::Quic(ref mut drv) => drv.send(buf).await,
224        }
225    }
226}
227
228impl AsyncLendingIterator for Driver {
229    type Ok<'i>
230        = backend::Message
231    where
232        Self: 'i;
233    type Err = Error;
234
235    #[inline]
236    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
237        match self {
238            Self::Tcp(ref mut drv) => drv.try_next().await,
239            Self::Dynamic(ref mut drv) => drv.try_next().await,
240            #[cfg(feature = "tls")]
241            Self::Tls(ref mut drv) => drv.try_next().await,
242            #[cfg(unix)]
243            Self::Unix(ref mut drv) => drv.try_next().await,
244            #[cfg(all(unix, feature = "tls"))]
245            Self::UnixTls(ref mut drv) => drv.try_next().await,
246            #[cfg(feature = "quic")]
247            Self::Quic(ref mut drv) => drv.try_next().await,
248        }
249    }
250}
251
252impl IntoFuture for Driver {
253    type Output = Result<(), Error>;
254    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
255
256    fn into_future(mut self) -> Self::IntoFuture {
257        Box::pin(async move {
258            while self.try_next().await?.is_some() {}
259            Ok(())
260        })
261    }
262}