ttpkit_http/client/connector.rs

1use std::{
2    io::{self, IoSlice},
3    pin::Pin,
4    str::FromStr,
5    task::{Context, Poll},
6};
7
8use tokio::{
9    io::{AsyncRead, AsyncWrite, ReadBuf},
10    net::TcpStream,
11};
12
13#[cfg(feature = "tls-client")]
14use openssl::ssl::{SslConnector, SslMethod};
15
16use crate::{Error, Scheme, url::Url};
17
18#[cfg(feature = "tls-client")]
19use crate::tls::{TlsConnector, TlsStream};
20
/// Simple HTTP connector.
///
/// Establishes plain TCP connections for `http` URLs and, when the
/// `tls-client` feature is enabled, TLS-over-TCP connections for `https` URLs.
#[derive(Clone)]
pub struct Connector {
    /// TLS connector used to wrap TCP streams for `https` URLs.
    #[cfg(feature = "tls-client")]
    tls: TlsConnector,
}
27
28impl Connector {
29    /// Create a new connector.
30    #[cfg(feature = "tls-client")]
31    pub async fn new() -> Result<Self, Error> {
32        let blocking = tokio::task::spawn_blocking(|| {
33            let connector = SslConnector::builder(SslMethod::tls())
34                .map_err(Error::from_other)?
35                .build()
36                .into();
37
38            Ok(connector) as Result<_, Error>
39        });
40
41        let tls = blocking
42            .await
43            .map_err(|_| Error::from_static_msg("interrupted"))??;
44
45        let res = Self { tls };
46
47        Ok(res)
48    }
49
50    /// Create a new connector.
51    #[cfg(not(feature = "tls-client"))]
52    #[inline]
53    pub async fn new() -> Result<Self, Error> {
54        Ok(Self {})
55    }
56
57    /// Connect to a given server.
58    pub async fn connect(&self, url: &Url) -> Result<Connection, Error> {
59        let scheme = Scheme::from_str(url.scheme())?;
60
61        let host = url.host();
62        let port = url.port().unwrap_or_else(|| scheme.default_port());
63
64        let tcp_stream = TcpStream::connect((host, port)).await?;
65
66        match scheme {
67            Scheme::HTTP => Ok(tcp_stream.into()),
68
69            #[cfg(feature = "tls-client")]
70            Scheme::HTTPS => {
71                let tls_stream = self.tls.connect(host, tcp_stream).await?;
72
73                Ok(tls_stream.into())
74            }
75
76            #[cfg(not(feature = "tls-client"))]
77            Scheme::HTTPS => Err(Error::from_static_msg("TLS is not supported")),
78        }
79    }
80}
81
82/// Plain HTTP connection.
83pub struct Connection {
84    #[cfg(feature = "tls-client")]
85    inner: Pin<Box<dyn AsyncReadWrite + Send>>,
86
87    #[cfg(not(feature = "tls-client"))]
88    inner: TcpStream,
89}
90
91impl AsyncRead for Connection {
92    #[inline]
93    fn poll_read(
94        mut self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        buf: &mut ReadBuf<'_>,
97    ) -> Poll<io::Result<()>> {
98        AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf)
99    }
100}
101
102impl AsyncWrite for Connection {
103    #[inline]
104    fn poll_write(
105        mut self: Pin<&mut Self>,
106        cx: &mut Context<'_>,
107        buf: &[u8],
108    ) -> Poll<io::Result<usize>> {
109        AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
110    }
111
112    #[inline]
113    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
114        AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
115    }
116
117    #[inline]
118    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119        AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
120    }
121
122    #[inline]
123    fn poll_write_vectored(
124        mut self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        bufs: &[IoSlice<'_>],
127    ) -> Poll<io::Result<usize>> {
128        AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
129    }
130
131    #[inline]
132    fn is_write_vectored(&self) -> bool {
133        self.inner.is_write_vectored()
134    }
135}
136
#[cfg(feature = "tls-client")]
impl From<TcpStream> for Connection {
    /// Wraps a plain TCP stream (used for `http` URLs) behind the
    /// type-erased stream object.
    #[inline]
    fn from(tcp: TcpStream) -> Self {
        let inner: Pin<Box<dyn AsyncReadWrite + Send>> = Box::pin(tcp);
        Self { inner }
    }
}
146
147#[cfg(not(feature = "tls-client"))]
148impl From<TcpStream> for Connection {
149    #[inline]
150    fn from(stream: TcpStream) -> Self {
151        Self { inner: stream }
152    }
153}
154
#[cfg(feature = "tls-client")]
impl From<TlsStream<TcpStream>> for Connection {
    /// Wraps a TLS-over-TCP stream (used for `https` URLs) behind the
    /// type-erased stream object.
    #[inline]
    fn from(tls: TlsStream<TcpStream>) -> Self {
        let inner: Pin<Box<dyn AsyncReadWrite + Send>> = Box::pin(tls);
        Self { inner }
    }
}
164
/// Combines [`AsyncRead`] and [`AsyncWrite`] into a single trait so that
/// `Connection` can hold either a TCP or a TLS stream behind one trait object.
#[cfg(feature = "tls-client")]
trait AsyncReadWrite
where
    Self: AsyncRead + AsyncWrite,
{
}
168
// Any (sized) type that is both readable and writable qualifies automatically.
#[cfg(feature = "tls-client")]
impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}