ttpkit_http/client/
connector.rs1use std::{
2 io::{self, IoSlice},
3 pin::Pin,
4 str::FromStr,
5 task::{Context, Poll},
6};
7
8use tokio::{
9 io::{AsyncRead, AsyncWrite, ReadBuf},
10 net::TcpStream,
11};
12
13#[cfg(feature = "tls-client")]
14use openssl::ssl::{SslConnector, SslMethod};
15
16use crate::{Error, Scheme, url::Url};
17
18#[cfg(feature = "tls-client")]
19use crate::tls::{TlsConnector, TlsStream};
20
/// Opens client-side transport connections for HTTP requests.
///
/// When the `tls-client` feature is enabled the connector also carries a TLS
/// connector, which is used to wrap TCP streams for `https` URLs.
#[derive(Clone)]
pub struct Connector {
    /// TLS connector applied to TCP streams whose URL scheme is HTTPS.
    #[cfg(feature = "tls-client")]
    tls: TlsConnector,
}
27
28impl Connector {
29 #[cfg(feature = "tls-client")]
31 pub async fn new() -> Result<Self, Error> {
32 let blocking = tokio::task::spawn_blocking(|| {
33 let connector = SslConnector::builder(SslMethod::tls())
34 .map_err(Error::from_other)?
35 .build()
36 .into();
37
38 Ok(connector) as Result<_, Error>
39 });
40
41 let tls = blocking
42 .await
43 .map_err(|_| Error::from_static_msg("interrupted"))??;
44
45 let res = Self { tls };
46
47 Ok(res)
48 }
49
50 #[cfg(not(feature = "tls-client"))]
52 #[inline]
53 pub async fn new() -> Result<Self, Error> {
54 Ok(Self {})
55 }
56
57 pub async fn connect(&self, url: &Url) -> Result<Connection, Error> {
59 let scheme = Scheme::from_str(url.scheme())?;
60
61 let host = url.host();
62 let port = url.port().unwrap_or_else(|| scheme.default_port());
63
64 let tcp_stream = TcpStream::connect((host, port)).await?;
65
66 match scheme {
67 Scheme::HTTP => Ok(tcp_stream.into()),
68
69 #[cfg(feature = "tls-client")]
70 Scheme::HTTPS => {
71 let tls_stream = self.tls.connect(host, tcp_stream).await?;
72
73 Ok(tls_stream.into())
74 }
75
76 #[cfg(not(feature = "tls-client"))]
77 Scheme::HTTPS => Err(Error::from_static_msg("TLS is not supported")),
78 }
79 }
80}
81
82pub struct Connection {
84 #[cfg(feature = "tls-client")]
85 inner: Pin<Box<dyn AsyncReadWrite + Send>>,
86
87 #[cfg(not(feature = "tls-client"))]
88 inner: TcpStream,
89}
90
91impl AsyncRead for Connection {
92 #[inline]
93 fn poll_read(
94 mut self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 buf: &mut ReadBuf<'_>,
97 ) -> Poll<io::Result<()>> {
98 AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf)
99 }
100}
101
102impl AsyncWrite for Connection {
103 #[inline]
104 fn poll_write(
105 mut self: Pin<&mut Self>,
106 cx: &mut Context<'_>,
107 buf: &[u8],
108 ) -> Poll<io::Result<usize>> {
109 AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
110 }
111
112 #[inline]
113 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
114 AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
115 }
116
117 #[inline]
118 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
120 }
121
122 #[inline]
123 fn poll_write_vectored(
124 mut self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 bufs: &[IoSlice<'_>],
127 ) -> Poll<io::Result<usize>> {
128 AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
129 }
130
131 #[inline]
132 fn is_write_vectored(&self) -> bool {
133 self.inner.is_write_vectored()
134 }
135}
136
/// Wraps a plain TCP stream (used for `http` URLs) into a boxed connection.
#[cfg(feature = "tls-client")]
impl From<TcpStream> for Connection {
    #[inline]
    fn from(tcp: TcpStream) -> Self {
        let inner: Pin<Box<dyn AsyncReadWrite + Send>> = Box::pin(tcp);

        Self { inner }
    }
}
146
147#[cfg(not(feature = "tls-client"))]
148impl From<TcpStream> for Connection {
149 #[inline]
150 fn from(stream: TcpStream) -> Self {
151 Self { inner: stream }
152 }
153}
154
/// Wraps a TLS-over-TCP stream (used for `https` URLs) into a boxed connection.
#[cfg(feature = "tls-client")]
impl From<TlsStream<TcpStream>> for Connection {
    #[inline]
    fn from(tls: TlsStream<TcpStream>) -> Self {
        let inner: Pin<Box<dyn AsyncReadWrite + Send>> = Box::pin(tls);

        Self { inner }
    }
}
164
/// Helper trait combining `AsyncRead` and `AsyncWrite`, so both can be named
/// in a single trait object (`dyn AsyncReadWrite + Send`).
#[cfg(feature = "tls-client")]
trait AsyncReadWrite: AsyncRead + AsyncWrite {}

/// Blanket implementation: every (sized) type that is both `AsyncRead` and
/// `AsyncWrite` is an `AsyncReadWrite`.
#[cfg(feature = "tls-client")]
impl<S: AsyncRead + AsyncWrite> AsyncReadWrite for S {}