1pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16#[cfg(feature = "io-uring")]
17pub(crate) mod io_uring;
18
19use core::{
20 future::{Future, IntoFuture},
21 net::SocketAddr,
22 pin::Pin,
23};
24
25use std::io;
26
27use postgres_protocol::message::{backend, frontend};
28use xitca_io::{
29 bytes::{Buf, BytesMut},
30 io::{AsyncIo, AsyncIoDyn, Interest},
31 net::TcpStream,
32};
33
34use super::{
35 client::Client,
36 config::{Config, SslMode, SslNegotiation},
37 error::{ConfigError, Error, unexpected_eof_err},
38 iter::AsyncLendingIterator,
39 session::{ConnectInfo, Session},
40};
41
42use self::generic::GenericDriver;
43
44#[cfg(feature = "tls")]
45use xitca_tls::rustls::{ClientConnection, TlsStream};
46
47#[cfg(unix)]
48use xitca_io::net::UnixStream;
49
50pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
51 if cfg.get_hosts().is_empty() {
52 return Err(ConfigError::EmptyHost.into());
53 }
54
55 if cfg.get_ports().is_empty() {
56 return Err(ConfigError::EmptyPort.into());
57 }
58
59 let mut err = None;
60 let hosts = cfg.get_hosts().to_vec();
61 for host in hosts {
62 match self::connect::connect_host(host, cfg).await {
63 Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
64 Err(e) => err = Some(e),
65 }
66 }
67
68 Err(err.unwrap())
69}
70
71pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
72where
73 Io: AsyncIo + Send + 'static,
74{
75 let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
76 Ok((Client::new(tx, session), Driver::Dynamic(drv)))
77}
78
79pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
80 self::connect::connect_info(info).await
81}
82
83async fn prepare_driver<Io>(
84 info: ConnectInfo,
85 io: Io,
86 cfg: &mut Config,
87) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
88where
89 Io: AsyncIo + Send + 'static,
90{
91 let (mut drv, tx) = GenericDriver::new(io);
92 let session = Session::prepare_session(info, &mut drv, cfg).await?;
93 Ok((tx, session, drv))
94}
95
96async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
97where
98 Io: AsyncIo,
99{
100 async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
101 where
102 Io: AsyncIo,
103 {
104 let mut buf = BytesMut::new();
105 frontend::ssl_request(&mut buf);
106
107 while !buf.is_empty() {
108 match io.write(&buf) {
109 Ok(0) => return Err(unexpected_eof_err()),
110 Ok(n) => buf.advance(n),
111 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
112 io.ready(Interest::WRITABLE).await?;
113 }
114 Err(e) => return Err(e),
115 }
116 }
117
118 let mut buf = [0];
119 loop {
120 match io.read(&mut buf) {
121 Ok(0) => return Err(unexpected_eof_err()),
122 Ok(_) => return Ok(buf[0] == b'S'),
123 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
124 io.ready(Interest::READABLE).await?;
125 }
126 Err(e) => return Err(e),
127 }
128 }
129 }
130
131 match ssl_mode {
132 SslMode::Disable => Ok(false),
133 _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
134 mode => match (query_tls_availability(io).await?, mode) {
135 (false, SslMode::Require) => Err(Error::todo()),
136 (bool, _) => Ok(bool),
137 },
138 }
139}
140
141async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
142 let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
143 ports.iter().map(move |port| {
144 addr.set_port(*port);
145 addr
146 })
147 });
148 Ok(addrs)
149}
150
/// Connection driver: one variant per supported transport.
///
/// Every variant wraps a [`GenericDriver`] over a concrete IO type. Which variants exist
/// depends on enabled cargo features (`tls`, `quic`) and on the target (`unix`).
/// In this module a `Driver` is either awaited directly (it implements `IntoFuture`)
/// or stepped message by message through its `AsyncLendingIterator` implementation.
pub enum Driver {
    /// Plain TCP transport.
    Tcp(GenericDriver<TcpStream>),
    /// Type-erased IO, boxed as `dyn AsyncIoDyn + Send` (produced by `connect_io`).
    Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
    /// rustls TLS over TCP.
    #[cfg(feature = "tls")]
    Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
    /// Unix domain socket transport.
    #[cfg(unix)]
    Unix(GenericDriver<UnixStream>),
    /// rustls TLS over a Unix domain socket.
    #[cfg(all(unix, feature = "tls"))]
    UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
    /// QUIC stream transport.
    #[cfg(feature = "quic")]
    Quic(GenericDriver<crate::driver::quic::QuicStream>),
}
212
213impl Driver {
214 #[inline]
215 pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
216 match self {
217 Self::Tcp(drv) => drv.send(buf).await,
218 Self::Dynamic(drv) => drv.send(buf).await,
219 #[cfg(feature = "tls")]
220 Self::Tls(drv) => drv.send(buf).await,
221 #[cfg(unix)]
222 Self::Unix(drv) => drv.send(buf).await,
223 #[cfg(all(unix, feature = "tls"))]
224 Self::UnixTls(drv) => drv.send(buf).await,
225 #[cfg(feature = "quic")]
226 Self::Quic(drv) => drv.send(buf).await,
227 }
228 }
229
230 pub fn try_into_tcp(self) -> Option<GenericDriver<TcpStream>> {
232 match self {
233 Self::Tcp(drv) => Some(drv),
234 _ => None,
235 }
236 }
237
238 #[cfg(feature = "io-uring")]
239 pub fn try_into_uring(self) -> Option<io_uring::UringDriver> {
240 self.try_into_tcp().map(io_uring::UringDriver::from_tcp)
241 }
242}
243
244impl AsyncLendingIterator for Driver {
245 type Ok<'i>
246 = backend::Message
247 where
248 Self: 'i;
249 type Err = Error;
250
251 #[inline]
252 async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
253 match self {
254 Self::Tcp(drv) => drv.try_next().await,
255 Self::Dynamic(drv) => drv.try_next().await,
256 #[cfg(feature = "tls")]
257 Self::Tls(drv) => drv.try_next().await,
258 #[cfg(unix)]
259 Self::Unix(drv) => drv.try_next().await,
260 #[cfg(all(unix, feature = "tls"))]
261 Self::UnixTls(drv) => drv.try_next().await,
262 #[cfg(feature = "quic")]
263 Self::Quic(drv) => drv.try_next().await,
264 }
265 }
266}
267
268impl IntoFuture for Driver {
269 type Output = Result<(), Error>;
270 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
271
272 fn into_future(mut self) -> Self::IntoFuture {
273 Box::pin(async move {
274 while self.try_next().await?.is_some() {}
275 Ok(())
276 })
277 }
278}