xitca_postgres/
driver.rs

1//! client driver module.
2
3pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16#[cfg(feature = "io-uring")]
17pub(crate) mod io_uring;
18
19use core::{
20    future::{Future, IntoFuture},
21    net::SocketAddr,
22    pin::Pin,
23};
24
25use std::io;
26
27use postgres_protocol::message::{backend, frontend};
28use xitca_io::{
29    bytes::{Buf, BytesMut},
30    io::{AsyncIo, AsyncIoDyn, Interest},
31    net::TcpStream,
32};
33
34use super::{
35    client::Client,
36    config::{Config, SslMode, SslNegotiation},
37    error::{ConfigError, Error, unexpected_eof_err},
38    iter::AsyncLendingIterator,
39    session::{ConnectInfo, Session},
40};
41
42use self::generic::GenericDriver;
43
44#[cfg(feature = "tls")]
45use xitca_tls::rustls::{ClientConnection, TlsStream};
46
47#[cfg(unix)]
48use xitca_io::net::UnixStream;
49
50pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
51    if cfg.get_hosts().is_empty() {
52        return Err(ConfigError::EmptyHost.into());
53    }
54
55    if cfg.get_ports().is_empty() {
56        return Err(ConfigError::EmptyPort.into());
57    }
58
59    let mut err = None;
60    let hosts = cfg.get_hosts().to_vec();
61    for host in hosts {
62        match self::connect::connect_host(host, cfg).await {
63            Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
64            Err(e) => err = Some(e),
65        }
66    }
67
68    Err(err.unwrap())
69}
70
71pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
72where
73    Io: AsyncIo + Send + 'static,
74{
75    let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
76    Ok((Client::new(tx, session), Driver::Dynamic(drv)))
77}
78
79pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
80    self::connect::connect_info(info).await
81}
82
83async fn prepare_driver<Io>(
84    info: ConnectInfo,
85    io: Io,
86    cfg: &mut Config,
87) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
88where
89    Io: AsyncIo + Send + 'static,
90{
91    let (mut drv, tx) = GenericDriver::new(io);
92    let session = Session::prepare_session(info, &mut drv, cfg).await?;
93    Ok((tx, session, drv))
94}
95
96async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
97where
98    Io: AsyncIo,
99{
100    async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
101    where
102        Io: AsyncIo,
103    {
104        let mut buf = BytesMut::new();
105        frontend::ssl_request(&mut buf);
106
107        while !buf.is_empty() {
108            match io.write(&buf) {
109                Ok(0) => return Err(unexpected_eof_err()),
110                Ok(n) => buf.advance(n),
111                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
112                    io.ready(Interest::WRITABLE).await?;
113                }
114                Err(e) => return Err(e),
115            }
116        }
117
118        let mut buf = [0];
119        loop {
120            match io.read(&mut buf) {
121                Ok(0) => return Err(unexpected_eof_err()),
122                Ok(_) => return Ok(buf[0] == b'S'),
123                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
124                    io.ready(Interest::READABLE).await?;
125                }
126                Err(e) => return Err(e),
127            }
128        }
129    }
130
131    match ssl_mode {
132        SslMode::Disable => Ok(false),
133        _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
134        mode => match (query_tls_availability(io).await?, mode) {
135            (false, SslMode::Require) => Err(Error::todo()),
136            (bool, _) => Ok(bool),
137        },
138    }
139}
140
141async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
142    let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
143        ports.iter().map(move |port| {
144            addr.set_port(*port);
145            addr
146        })
147    });
148    Ok(addrs)
149}
150
/// Async driver of [`Client`].
///
/// It handles IO and emits server-sent messages that do not belong to any query through its
/// [`AsyncLendingIterator`] trait impl.
///
/// # Examples
/// ```
/// use std::future::IntoFuture;
/// use xitca_postgres::{iter::AsyncLendingIterator, Driver};
///
/// // drive the client and listen to server notify at the same time.
/// fn drive_with_server_notify(mut drv: Driver) {
///     tokio::spawn(async move {
///         while let Ok(Some(msg)) = drv.try_next().await {
///             // *Note:
///             // message handling must be non-blocking to prevent starvation of the driver.
///         }
///     });
/// }
///
/// // drive client without handling notify.
/// fn drive_only(drv: Driver) {
///     tokio::spawn(drv.into_future());
/// }
/// ```
///
/// # Lifetime
/// Driver and [`Client`] have dependent lifetimes: either side can trigger the shutdown of the other.
/// From the Driver side this is done by dropping ownership of it.
/// ## Examples
/// ```
/// # use xitca_postgres::{error::Error, Config, Execute, Postgres};
/// # async fn shut_down(cfg: Config) -> Result<(), Error> {
/// // connect to a database
/// let (cli, drv) = Postgres::new(cfg).connect().await?;
///
/// // drop driver
/// drop(drv);
///
/// // client will always return error when its driver is gone.
/// let e = "SELECT 1".query(&cli).await.unwrap_err();
/// // a shortcut method can be used to determine if the error is caused by a shutdown driver.
/// assert!(e.is_driver_down());
///
/// # Ok(())
/// # }
/// ```
///
// TODO: use Box<dyn AsyncIterator> when lifetime GATs are object safe.
pub enum Driver {
    /// Unencrypted TCP connection.
    Tcp(GenericDriver<TcpStream>),
    /// Type-erased, caller-supplied IO (e.g. produced by `connect_io`).
    Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
    /// TLS over TCP. Requires the `tls` feature.
    #[cfg(feature = "tls")]
    Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
    /// Unencrypted Unix domain socket. Unix targets only.
    #[cfg(unix)]
    Unix(GenericDriver<UnixStream>),
    /// TLS over a Unix domain socket. Unix targets with the `tls` feature only.
    #[cfg(all(unix, feature = "tls"))]
    UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
    /// QUIC stream. Requires the `quic` feature.
    #[cfg(feature = "quic")]
    Quic(GenericDriver<crate::driver::quic::QuicStream>),
}
212
213impl Driver {
214    #[inline]
215    pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
216        match self {
217            Self::Tcp(drv) => drv.send(buf).await,
218            Self::Dynamic(drv) => drv.send(buf).await,
219            #[cfg(feature = "tls")]
220            Self::Tls(drv) => drv.send(buf).await,
221            #[cfg(unix)]
222            Self::Unix(drv) => drv.send(buf).await,
223            #[cfg(all(unix, feature = "tls"))]
224            Self::UnixTls(drv) => drv.send(buf).await,
225            #[cfg(feature = "quic")]
226            Self::Quic(drv) => drv.send(buf).await,
227        }
228    }
229
230    // try to unwrap driver that using unencrypted tcp connection
231    pub fn try_into_tcp(self) -> Option<GenericDriver<TcpStream>> {
232        match self {
233            Self::Tcp(drv) => Some(drv),
234            _ => None,
235        }
236    }
237
238    #[cfg(feature = "io-uring")]
239    pub fn try_into_uring(self) -> Option<io_uring::UringDriver> {
240        self.try_into_tcp().map(io_uring::UringDriver::from_tcp)
241    }
242}
243
244impl AsyncLendingIterator for Driver {
245    type Ok<'i>
246        = backend::Message
247    where
248        Self: 'i;
249    type Err = Error;
250
251    #[inline]
252    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
253        match self {
254            Self::Tcp(drv) => drv.try_next().await,
255            Self::Dynamic(drv) => drv.try_next().await,
256            #[cfg(feature = "tls")]
257            Self::Tls(drv) => drv.try_next().await,
258            #[cfg(unix)]
259            Self::Unix(drv) => drv.try_next().await,
260            #[cfg(all(unix, feature = "tls"))]
261            Self::UnixTls(drv) => drv.try_next().await,
262            #[cfg(feature = "quic")]
263            Self::Quic(drv) => drv.try_next().await,
264        }
265    }
266}
267
268impl IntoFuture for Driver {
269    type Output = Result<(), Error>;
270    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
271
272    fn into_future(mut self) -> Self::IntoFuture {
273        Box::pin(async move {
274            while self.try_next().await?.is_some() {}
275            Ok(())
276        })
277    }
278}