1pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16use core::{
17 future::{Future, IntoFuture},
18 net::SocketAddr,
19 pin::Pin,
20};
21
22use std::io;
23
24use postgres_protocol::message::{backend, frontend};
25use xitca_io::{
26 bytes::{Buf, BytesMut},
27 io::{AsyncIo, AsyncIoDyn, Interest},
28 net::TcpStream,
29};
30
31use super::{
32 client::Client,
33 config::{Config, SslMode, SslNegotiation},
34 error::{unexpected_eof_err, ConfigError, Error},
35 iter::AsyncLendingIterator,
36 session::{ConnectInfo, Session},
37};
38
39use self::generic::GenericDriver;
40
41#[cfg(feature = "tls")]
42use xitca_tls::rustls::{ClientConnection, TlsStream};
43
44#[cfg(unix)]
45use xitca_io::net::UnixStream;
46
47pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
48 if cfg.get_hosts().is_empty() {
49 return Err(ConfigError::EmptyHost.into());
50 }
51
52 if cfg.get_ports().is_empty() {
53 return Err(ConfigError::EmptyPort.into());
54 }
55
56 let mut err = None;
57 let hosts = cfg.get_hosts().to_vec();
58 for host in hosts {
59 match self::connect::connect_host(host, cfg).await {
60 Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
61 Err(e) => err = Some(e),
62 }
63 }
64
65 Err(err.unwrap())
66}
67
68pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
69where
70 Io: AsyncIo + Send + 'static,
71{
72 let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
73 Ok((Client::new(tx, session), Driver::Dynamic(drv)))
74}
75
76pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
77 self::connect::connect_info(info).await
78}
79
80async fn prepare_driver<Io>(
81 info: ConnectInfo,
82 io: Io,
83 cfg: &mut Config,
84) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
85where
86 Io: AsyncIo + Send + 'static,
87{
88 let (mut drv, tx) = GenericDriver::new(io);
89 let session = Session::prepare_session(info, &mut drv, cfg).await?;
90 Ok((tx, session, drv))
91}
92
93async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
94where
95 Io: AsyncIo,
96{
97 async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
98 where
99 Io: AsyncIo,
100 {
101 let mut buf = BytesMut::new();
102 frontend::ssl_request(&mut buf);
103
104 while !buf.is_empty() {
105 match io.write(&buf) {
106 Ok(0) => return Err(unexpected_eof_err()),
107 Ok(n) => buf.advance(n),
108 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
109 io.ready(Interest::WRITABLE).await?;
110 }
111 Err(e) => return Err(e),
112 }
113 }
114
115 let mut buf = [0];
116 loop {
117 match io.read(&mut buf) {
118 Ok(0) => return Err(unexpected_eof_err()),
119 Ok(_) => return Ok(buf[0] == b'S'),
120 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
121 io.ready(Interest::READABLE).await?;
122 }
123 Err(e) => return Err(e),
124 }
125 }
126 }
127
128 match ssl_mode {
129 SslMode::Disable => Ok(false),
130 _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
131 mode => match (query_tls_availability(io).await?, mode) {
132 (false, SslMode::Require) => Err(Error::todo()),
133 (bool, _) => Ok(bool),
134 },
135 }
136}
137
138async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
139 let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
140 ports.iter().map(move |port| {
141 addr.set_port(*port);
142 addr
143 })
144 });
145 Ok(addrs)
146}
147
/// Connection driver returned alongside a `Client` by the connect functions.
///
/// Each variant wraps a `GenericDriver` over one concrete transport. Which
/// variants exist depends on the enabled features (`tls`, `quic`) and the
/// target platform (`unix`).
///
/// Backend messages are pulled through the `AsyncLendingIterator` impl.
/// Awaiting the driver (via `IntoFuture`) keeps pulling and discarding messages
/// until the stream yields `None` or an error occurs.
pub enum Driver {
    /// Plain TCP stream.
    Tcp(GenericDriver<TcpStream>),
    /// Type-erased, caller-supplied IO (produced by `connect_io`).
    Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
    /// rustls client TLS over TCP.
    #[cfg(feature = "tls")]
    Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
    /// Unix domain socket.
    #[cfg(unix)]
    Unix(GenericDriver<UnixStream>),
    /// rustls client TLS over a Unix domain socket.
    #[cfg(all(unix, feature = "tls"))]
    UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
    /// QUIC stream.
    #[cfg(feature = "quic")]
    Quic(GenericDriver<crate::driver::quic::QuicStream>),
}
209
210impl Driver {
211 #[inline]
212 pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
213 match self {
214 Self::Tcp(ref mut drv) => drv.send(buf).await,
215 Self::Dynamic(ref mut drv) => drv.send(buf).await,
216 #[cfg(feature = "tls")]
217 Self::Tls(ref mut drv) => drv.send(buf).await,
218 #[cfg(unix)]
219 Self::Unix(ref mut drv) => drv.send(buf).await,
220 #[cfg(all(unix, feature = "tls"))]
221 Self::UnixTls(ref mut drv) => drv.send(buf).await,
222 #[cfg(feature = "quic")]
223 Self::Quic(ref mut drv) => drv.send(buf).await,
224 }
225 }
226}
227
228impl AsyncLendingIterator for Driver {
229 type Ok<'i>
230 = backend::Message
231 where
232 Self: 'i;
233 type Err = Error;
234
235 #[inline]
236 async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
237 match self {
238 Self::Tcp(ref mut drv) => drv.try_next().await,
239 Self::Dynamic(ref mut drv) => drv.try_next().await,
240 #[cfg(feature = "tls")]
241 Self::Tls(ref mut drv) => drv.try_next().await,
242 #[cfg(unix)]
243 Self::Unix(ref mut drv) => drv.try_next().await,
244 #[cfg(all(unix, feature = "tls"))]
245 Self::UnixTls(ref mut drv) => drv.try_next().await,
246 #[cfg(feature = "quic")]
247 Self::Quic(ref mut drv) => drv.try_next().await,
248 }
249 }
250}
251
252impl IntoFuture for Driver {
253 type Output = Result<(), Error>;
254 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
255
256 fn into_future(mut self) -> Self::IntoFuture {
257 Box::pin(async move {
258 while self.try_next().await?.is_some() {}
259 Ok(())
260 })
261 }
262}