1use crate::crypto_provider;
2use futures_rustls::{
3 client::TlsStream,
4 rustls::{
5 client::danger::ServerCertVerifier, crypto::CryptoProvider, pki_types::ServerName,
6 ClientConfig, ClientConnection,
7 },
8 TlsConnector,
9};
10use std::{
11 fmt::{self, Debug, Formatter},
12 future::Future,
13 io::{Error, ErrorKind, IoSlice, Result},
14 net::SocketAddr,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18};
19use trillium_server_common::{async_trait, AsyncRead, AsyncWrite, Connector, Transport, Url};
20use RustlsClientTransportInner::{Tcp, Tls};
21
22#[derive(Clone, Debug)]
23pub struct RustlsClientConfig(Arc<ClientConfig>);
24
25#[derive(Clone, Default)]
29pub struct RustlsConfig<Config> {
30 pub rustls_config: RustlsClientConfig,
32
33 pub tcp_config: Config,
35}
36
37impl<C: Connector> RustlsConfig<C> {
38 pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
40 Self {
41 rustls_config: rustls_config.into(),
42 tcp_config,
43 }
44 }
45}
46
47impl Default for RustlsClientConfig {
48 fn default() -> Self {
49 Self(Arc::new(default_client_config()))
50 }
51}
52
/// Builds a server certificate verifier from `rustls_platform_verifier`
/// (presumably backed by the operating system's trust store — confirm against that
/// crate's docs), using `provider` for its cryptography.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new().with_provider(provider);
    Arc::new(platform_verifier)
}
57
58#[cfg(not(feature = "platform-verifier"))]
59fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
60 let roots = Arc::new(futures_rustls::rustls::RootCertStore::from_iter(
61 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
62 ));
63 futures_rustls::rustls::client::WebPkiServerVerifier::builder_with_provider(roots, provider)
64 .build()
65 .unwrap()
66}
67
68fn default_client_config() -> ClientConfig {
69 let provider = crypto_provider();
70 let verifier = verifier(Arc::clone(&provider));
71
72 ClientConfig::builder_with_provider(provider)
73 .with_safe_default_protocol_versions()
74 .expect("crypto provider did not support safe default protocol versions")
75 .dangerous()
76 .with_custom_certificate_verifier(verifier)
77 .with_no_client_auth()
78}
79
80impl From<ClientConfig> for RustlsClientConfig {
81 fn from(rustls_config: ClientConfig) -> Self {
82 Self(Arc::new(rustls_config))
83 }
84}
85
86impl From<Arc<ClientConfig>> for RustlsClientConfig {
87 fn from(rustls_config: Arc<ClientConfig>) -> Self {
88 Self(rustls_config)
89 }
90}
91
92impl<C: Connector> RustlsConfig<C> {
93 pub fn with_tcp_config(mut self, config: C) -> Self {
95 self.tcp_config = config;
96 self
97 }
98}
99
100impl<Config: Debug> Debug for RustlsConfig<Config> {
101 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102 f.debug_struct("RustlsConfig")
103 .field("rustls_config", &"..")
104 .field("tcp_config", &self.tcp_config)
105 .finish()
106 }
107}
108
109#[async_trait]
110impl<C: Connector> Connector for RustlsConfig<C> {
111 type Transport = RustlsClientTransport<C::Transport>;
112
113 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
114 match url.scheme() {
115 "https" => {
116 let mut http = url.clone();
117 http.set_scheme("http").ok();
118 http.set_port(url.port_or_known_default()).ok();
119
120 let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
121 let domain = url
122 .domain()
123 .and_then(|dns_name| ServerName::try_from(dns_name.to_string()).ok())
124 .ok_or_else(|| Error::new(ErrorKind::Other, "missing domain"))?;
125
126 connector
127 .connect(domain, self.tcp_config.connect(&http).await?)
128 .await
129 .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
130 .map(Into::into)
131 }
132
133 "http" => self.tcp_config.connect(url).await.map(Into::into),
134
135 unknown => Err(Error::new(
136 ErrorKind::InvalidInput,
137 format!("unknown scheme {unknown}"),
138 )),
139 }
140 }
141
142 fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
143 self.tcp_config.spawn(fut)
144 }
145}
146
147#[derive(Debug)]
148enum RustlsClientTransportInner<T> {
149 Tcp(T),
150 Tls(Box<TlsStream<T>>),
151}
152
153#[derive(Debug)]
160pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
161impl<T> From<T> for RustlsClientTransport<T> {
162 fn from(value: T) -> Self {
163 Self(Tcp(value))
164 }
165}
166
167impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
168 fn from(value: TlsStream<T>) -> Self {
169 Self(Tls(Box::new(value)))
170 }
171}
172
173impl<C> AsyncRead for RustlsClientTransport<C>
174where
175 C: AsyncWrite + AsyncRead + Unpin,
176{
177 fn poll_read(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &mut [u8],
181 ) -> Poll<Result<usize>> {
182 match &mut self.0 {
183 Tcp(c) => Pin::new(c).poll_read(cx, buf),
184 Tls(c) => Pin::new(c).poll_read(cx, buf),
185 }
186 }
187}
188
189impl<C> AsyncWrite for RustlsClientTransport<C>
190where
191 C: AsyncRead + AsyncWrite + Unpin,
192{
193 fn poll_write(
194 mut self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 buf: &[u8],
197 ) -> Poll<Result<usize>> {
198 match &mut self.0 {
199 Tcp(c) => Pin::new(c).poll_write(cx, buf),
200 Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
201 }
202 }
203
204 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
205 match &mut self.0 {
206 Tcp(c) => Pin::new(c).poll_flush(cx),
207 Tls(c) => Pin::new(&mut *c).poll_flush(cx),
208 }
209 }
210
211 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
212 match &mut self.0 {
213 Tcp(c) => Pin::new(c).poll_close(cx),
214 Tls(c) => Pin::new(&mut *c).poll_close(cx),
215 }
216 }
217
218 fn poll_write_vectored(
219 mut self: Pin<&mut Self>,
220 cx: &mut Context<'_>,
221 bufs: &[IoSlice<'_>],
222 ) -> Poll<Result<usize>> {
223 match &mut self.0 {
224 Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
225 Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
226 }
227 }
228}
229
230impl<T: Transport> Transport for RustlsClientTransport<T> {
231 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232 self.as_ref().peer_addr()
233 }
234}
235
236impl<T> AsRef<T> for RustlsClientTransport<T> {
237 fn as_ref(&self) -> &T {
238 match &self.0 {
239 Tcp(x) => x,
240 Tls(x) => x.get_ref().0,
241 }
242 }
243}
244
245impl<T> RustlsClientTransport<T> {
246 pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
248 match &mut self.0 {
249 Tls(x) => Some(x.get_mut().1),
250 _ => None,
251 }
252 }
253
254 pub fn tls_state(&self) -> Option<&ClientConnection> {
256 match &self.0 {
257 Tls(x) => Some(x.get_ref().1),
258 _ => None,
259 }
260 }
261}