1pub(crate) mod codec;
4pub(crate) mod generic;
5
6mod connect;
7
8pub(crate) use generic::DriverTx;
9
10#[cfg(feature = "tls")]
11mod tls;
12
13#[cfg(feature = "quic")]
14pub(crate) mod quic;
15
16#[cfg(feature = "io-uring")]
17pub(crate) mod io_uring;
18
19#[cfg(feature = "compio")]
20pub(crate) mod compio;
21
22use core::{
23 future::{Future, IntoFuture},
24 net::SocketAddr,
25 pin::Pin,
26};
27
28use std::io;
29
30use xitca_io::{
31 bytes::{Buf, BytesMut},
32 io::{AsyncIo, AsyncIoDyn, Interest},
33 net::TcpStream,
34};
35
36use super::{
37 client::Client,
38 config::{Config, SslMode, SslNegotiation},
39 error::{ConfigError, Error, unexpected_eof_err},
40 iter::AsyncLendingIterator,
41 protocol::message::{backend, frontend},
42 session::{ConnectInfo, Session},
43};
44
45use self::generic::GenericDriver;
46
47#[cfg(feature = "tls")]
48use xitca_tls::rustls::{ClientConnection, TlsStream};
49
50#[cfg(unix)]
51use xitca_io::net::UnixStream;
52
53pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
54 if cfg.get_hosts().is_empty() {
55 return Err(ConfigError::EmptyHost.into());
56 }
57
58 if cfg.get_ports().is_empty() {
59 return Err(ConfigError::EmptyPort.into());
60 }
61
62 let mut err = None;
63 let hosts = cfg.get_hosts().to_vec();
64 for host in hosts {
65 match self::connect::connect_host(host, cfg).await {
66 Ok((tx, session, drv)) => return Ok((Client::new(tx, session), drv)),
67 Err(e) => err = Some(e),
68 }
69 }
70
71 Err(err.unwrap())
72}
73
74pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
75where
76 Io: AsyncIo + Send + 'static,
77{
78 let (tx, session, drv) = prepare_driver(ConnectInfo::default(), Box::new(io) as _, cfg).await?;
79 Ok((Client::new(tx, session), Driver::Dynamic(drv)))
80}
81
82pub(super) async fn connect_info(info: ConnectInfo) -> Result<(DriverTx, Driver), Error> {
83 self::connect::connect_info(info).await
84}
85
86async fn prepare_driver<Io>(
87 info: ConnectInfo,
88 io: Io,
89 cfg: &mut Config,
90) -> Result<(DriverTx, Session, GenericDriver<Io>), Error>
91where
92 Io: AsyncIo + Send + 'static,
93{
94 let (mut drv, tx) = GenericDriver::new(io);
95 let session = Session::prepare_session(info, &mut drv, cfg).await?;
96 Ok((tx, session, drv))
97}
98
99async fn should_connect_tls<Io>(io: &mut Io, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Result<bool, Error>
100where
101 Io: AsyncIo,
102{
103 async fn query_tls_availability<Io>(io: &mut Io) -> std::io::Result<bool>
104 where
105 Io: AsyncIo,
106 {
107 let mut buf = BytesMut::new();
108 frontend::ssl_request(&mut buf);
109
110 while !buf.is_empty() {
111 match io.write(&buf) {
112 Ok(0) => return Err(unexpected_eof_err()),
113 Ok(n) => buf.advance(n),
114 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
115 io.ready(Interest::WRITABLE).await?;
116 }
117 Err(e) => return Err(e),
118 }
119 }
120
121 let mut buf = [0];
122 loop {
123 match io.read(&mut buf) {
124 Ok(0) => return Err(unexpected_eof_err()),
125 Ok(_) => return Ok(buf[0] == b'S'),
126 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
127 io.ready(Interest::READABLE).await?;
128 }
129 Err(e) => return Err(e),
130 }
131 }
132 }
133
134 match ssl_mode {
135 SslMode::Disable => Ok(false),
136 _ if matches!(ssl_negotiation, SslNegotiation::Direct) => Ok(true),
137 mode => match (query_tls_availability(io).await?, mode) {
138 (false, SslMode::Require) => Err(Error::todo()),
139 (bool, _) => Ok(bool),
140 },
141 }
142}
143
144async fn dns_resolve<'p>(host: &'p str, ports: &'p [u16]) -> Result<impl Iterator<Item = SocketAddr> + 'p, Error> {
145 let addrs = tokio::net::lookup_host((host, 0)).await?.flat_map(|mut addr| {
146 ports.iter().map(move |port| {
147 addr.set_port(*port);
148 addr
149 })
150 });
151 Ok(addrs)
152}
153
154pub enum Driver {
204 Tcp(GenericDriver<TcpStream>),
205 Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
206 #[cfg(feature = "tls")]
207 Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
208 #[cfg(unix)]
209 Unix(GenericDriver<UnixStream>),
210 #[cfg(all(unix, feature = "tls"))]
211 UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
212 #[cfg(feature = "quic")]
213 Quic(GenericDriver<crate::driver::quic::QuicStream>),
214}
215
216impl Driver {
217 #[inline]
218 pub(crate) async fn send(&mut self, buf: BytesMut) -> Result<(), Error> {
219 match self {
220 Self::Tcp(drv) => drv.send(buf).await,
221 Self::Dynamic(drv) => drv.send(buf).await,
222 #[cfg(feature = "tls")]
223 Self::Tls(drv) => drv.send(buf).await,
224 #[cfg(unix)]
225 Self::Unix(drv) => drv.send(buf).await,
226 #[cfg(all(unix, feature = "tls"))]
227 Self::UnixTls(drv) => drv.send(buf).await,
228 #[cfg(feature = "quic")]
229 Self::Quic(drv) => drv.send(buf).await,
230 }
231 }
232
233 pub fn try_into_tcp(self) -> Option<GenericDriver<TcpStream>> {
235 match self {
236 Self::Tcp(drv) => Some(drv),
237 _ => None,
238 }
239 }
240
241 #[cfg(feature = "io-uring")]
242 pub fn try_into_uring(self) -> Option<io_uring::UringDriver> {
243 self.try_into_tcp().map(io_uring::UringDriver::from_tcp)
244 }
245}
246
247impl AsyncLendingIterator for Driver {
248 type Ok<'i>
249 = backend::Message
250 where
251 Self: 'i;
252 type Err = Error;
253
254 #[inline]
255 async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
256 match self {
257 Self::Tcp(drv) => drv.try_next().await,
258 Self::Dynamic(drv) => drv.try_next().await,
259 #[cfg(feature = "tls")]
260 Self::Tls(drv) => drv.try_next().await,
261 #[cfg(unix)]
262 Self::Unix(drv) => drv.try_next().await,
263 #[cfg(all(unix, feature = "tls"))]
264 Self::UnixTls(drv) => drv.try_next().await,
265 #[cfg(feature = "quic")]
266 Self::Quic(drv) => drv.try_next().await,
267 }
268 }
269}
270
271impl IntoFuture for Driver {
272 type Output = Result<(), Error>;
273 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
274
275 fn into_future(mut self) -> Self::IntoFuture {
276 Box::pin(async move {
277 while self.try_next().await?.is_some() {}
278 Ok(())
279 })
280 }
281}