1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3#[cfg(feature = "dangerous")]
4use futures_rustls::rustls::{
5 DigitallySignedStruct, SignatureScheme,
6 client::danger::{HandshakeSignatureValid, ServerCertVerified},
7 crypto::{verify_tls12_signature, verify_tls13_signature},
8 pki_types::{CertificateDer, UnixTime},
9};
10use futures_rustls::{
11 TlsConnector,
12 client::TlsStream,
13 rustls::{
14 ClientConfig, ClientConnection, RootCertStore,
15 client::{WebPkiServerVerifier, danger::ServerCertVerifier},
16 crypto::CryptoProvider,
17 pki_types::ServerName,
18 },
19};
20use std::{
21 fmt::{self, Debug, Formatter},
22 io::{Error, ErrorKind, IoSlice, Result},
23 net::SocketAddr,
24 pin::Pin,
25 sync::Arc,
26 task::{Context, Poll},
27};
28use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
29
30#[derive(Clone, Debug)]
37pub struct RustlsClientConfig(Arc<ClientConfig>);
38
39#[derive(Clone, Default)]
41pub struct RustlsConfig<Config> {
42 pub rustls_config: RustlsClientConfig,
44
45 pub tcp_config: Config,
47}
48
49impl<C: Connector> RustlsConfig<C> {
50 pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
52 Self {
53 rustls_config: rustls_config.into(),
54 tcp_config,
55 }
56 }
57}
58
59impl Default for RustlsClientConfig {
60 fn default() -> Self {
61 Self(Arc::new(default_client_config()))
62 }
63}
64
/// Builds the default server certificate verifier from `rustls_platform_verifier`.
///
/// # Panics
///
/// Panics if `rustls_platform_verifier::Verifier::new` does not succeed.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new(provider).unwrap();
    Arc::new(platform_verifier)
}
69
70#[cfg(not(feature = "platform-verifier"))]
71fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
72 let roots = Arc::new(RootCertStore::from_iter(
73 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
74 ));
75 WebPkiServerVerifier::builder_with_provider(roots, provider)
76 .build()
77 .unwrap()
78}
79
80fn client_config_with_verifier(verifier: Arc<dyn ServerCertVerifier>) -> ClientConfig {
81 let mut config = ClientConfig::builder_with_provider(crypto_provider())
82 .with_safe_default_protocol_versions()
83 .expect("crypto provider did not support safe default protocol versions")
84 .dangerous()
85 .with_custom_certificate_verifier(verifier)
86 .with_no_client_auth();
87
88 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
89
90 config
91}
92
93fn default_client_config() -> ClientConfig {
94 client_config_with_verifier(verifier(crypto_provider()))
95}
96
97impl RustlsClientConfig {
98 pub fn from_root_cert_pem(pem: &[u8]) -> Result<Self> {
114 let mut roots = RootCertStore::empty();
115 let mut reader = pem;
116 for cert in rustls_pemfile::certs(&mut reader) {
117 roots.add(cert?).map_err(Error::other)?;
118 }
119
120 if roots.is_empty() {
121 return Err(Error::new(
122 ErrorKind::InvalidInput,
123 "no certificates found in pem",
124 ));
125 }
126
127 let verifier =
128 WebPkiServerVerifier::builder_with_provider(Arc::new(roots), crypto_provider())
129 .build()
130 .map_err(Error::other)?;
131
132 Ok(Self(Arc::new(client_config_with_verifier(verifier))))
133 }
134}
135
136impl From<ClientConfig> for RustlsClientConfig {
137 fn from(rustls_config: ClientConfig) -> Self {
138 Self(Arc::new(rustls_config))
139 }
140}
141
142impl From<Arc<ClientConfig>> for RustlsClientConfig {
143 fn from(rustls_config: Arc<ClientConfig>) -> Self {
144 Self(rustls_config)
145 }
146}
147
/// A [`ServerCertVerifier`] that accepts every server certificate.
///
/// The wrapped [`CryptoProvider`] supplies the signature verification
/// algorithms used for handshake signature checks.
#[cfg(feature = "dangerous")]
#[derive(Debug)]
struct AcceptAnyServerCert(Arc<CryptoProvider>);
151
#[cfg(feature = "dangerous")]
impl ServerCertVerifier for AcceptAnyServerCert {
    /// Unconditionally reports the presented certificate as verified; none of
    /// the arguments are inspected.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, futures_rustls::rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature using the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature using the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Lists the signature schemes supported by the provider's algorithms.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
197
#[cfg(feature = "dangerous")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous")))]
impl RustlsClientConfig {
    /// Builds a client configuration whose verifier accepts any server
    /// certificate.
    ///
    /// A warning is logged each time this constructor is called, because the
    /// resulting configuration performs no server authentication.
    pub fn dangerously_accept_any_cert() -> Self {
        log::warn!(
            "constructing a rustls client config that accepts any server certificate; server \
             authentication is disabled and connections are vulnerable to interception"
        );
        let accept_all = AcceptAnyServerCert(crypto_provider());
        let config = client_config_with_verifier(Arc::new(accept_all));
        Self(Arc::new(config))
    }
}
221
222impl<C: Connector> RustlsConfig<C> {
223 pub fn with_tcp_config(mut self, config: C) -> Self {
225 self.tcp_config = config;
226 self
227 }
228
229 #[must_use]
235 pub fn without_http2(mut self) -> Self {
236 let config = Arc::make_mut(&mut self.rustls_config.0);
237 config.alpn_protocols.retain(|p| p != b"h2");
238 self
239 }
240}
241
242impl<Config: Debug> Debug for RustlsConfig<Config> {
243 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244 f.debug_struct("RustlsConfig")
245 .field("rustls_config", &format_args!(".."))
246 .field("tcp_config", &self.tcp_config)
247 .finish()
248 }
249}
250
251impl<C: Connector> Connector for RustlsConfig<C> {
252 type Runtime = C::Runtime;
253 type Transport = RustlsClientTransport<C::Transport>;
254 type Udp = C::Udp;
255
256 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
257 self.connect_to(Destination::from_url(url)?).await
258 }
259
260 async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
261 if !destination.secure() {
262 return self
263 .tcp_config
264 .connect_to(destination)
265 .await
266 .map(Into::into);
267 }
268
269 let rustls_config = if let Some(alpn) = destination.alpn() {
273 let mut config = (*self.rustls_config.0).clone();
274 config.alpn_protocols = alpn.iter().map(|p| p.to_vec()).collect();
275 Arc::new(config)
276 } else {
277 Arc::clone(&self.rustls_config.0)
278 };
279 let connector: TlsConnector = rustls_config.into();
280
281 let domain_server_name = destination
286 .host()
287 .map(|domain| {
288 ServerName::try_from(domain.to_owned())
289 .map_err(|e| Error::other(format!("invalid server name {domain:?}: {e}")))
290 })
291 .transpose()?;
292
293 let stream = self
294 .tcp_config
295 .connect_to(destination.with_secure(false))
296 .await?;
297
298 let server_name = match domain_server_name {
299 Some(server_name) => server_name,
300 None => {
301 let ip = stream
302 .peer_addr()?
303 .ok_or_else(|| Error::other("no peer address for bare-ip destination"))?
304 .ip();
305 ServerName::IpAddress(ip.into())
306 }
307 };
308
309 connector
310 .connect(server_name, stream)
311 .await
312 .map_err(|e| Error::other(e.to_string()))
313 .map(Into::into)
314 }
315
316 fn runtime(&self) -> Self::Runtime {
317 self.tcp_config.runtime()
318 }
319
320 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
321 self.tcp_config.resolve(host, port).await
322 }
323}
324
325#[derive(Debug)]
326enum RustlsClientTransportInner<T> {
327 Tcp(T),
328 Tls(Box<TlsStream<T>>),
329}
330
331#[derive(Debug)]
336pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
337impl<T> From<T> for RustlsClientTransport<T> {
338 fn from(value: T) -> Self {
339 Self(Tcp(value))
340 }
341}
342
343impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
344 fn from(value: TlsStream<T>) -> Self {
345 Self(Tls(Box::new(value)))
346 }
347}
348
349impl<C> AsyncRead for RustlsClientTransport<C>
350where
351 C: AsyncWrite + AsyncRead + Unpin,
352{
353 fn poll_read(
354 mut self: Pin<&mut Self>,
355 cx: &mut Context<'_>,
356 buf: &mut [u8],
357 ) -> Poll<Result<usize>> {
358 match &mut self.0 {
359 Tcp(c) => Pin::new(c).poll_read(cx, buf),
360 Tls(c) => Pin::new(c).poll_read(cx, buf),
361 }
362 }
363
364 fn poll_read_vectored(
365 mut self: Pin<&mut Self>,
366 cx: &mut Context<'_>,
367 bufs: &mut [std::io::IoSliceMut<'_>],
368 ) -> Poll<Result<usize>> {
369 match &mut self.0 {
370 Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
371 Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
372 }
373 }
374}
375
376impl<C> AsyncWrite for RustlsClientTransport<C>
377where
378 C: AsyncRead + AsyncWrite + Unpin,
379{
380 fn poll_write(
381 mut self: Pin<&mut Self>,
382 cx: &mut Context<'_>,
383 buf: &[u8],
384 ) -> Poll<Result<usize>> {
385 match &mut self.0 {
386 Tcp(c) => Pin::new(c).poll_write(cx, buf),
387 Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
388 }
389 }
390
391 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
392 match &mut self.0 {
393 Tcp(c) => Pin::new(c).poll_flush(cx),
394 Tls(c) => Pin::new(&mut *c).poll_flush(cx),
395 }
396 }
397
398 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
399 match &mut self.0 {
400 Tcp(c) => Pin::new(c).poll_close(cx),
401 Tls(c) => Pin::new(&mut *c).poll_close(cx),
402 }
403 }
404
405 fn poll_write_vectored(
406 mut self: Pin<&mut Self>,
407 cx: &mut Context<'_>,
408 bufs: &[IoSlice<'_>],
409 ) -> Poll<Result<usize>> {
410 match &mut self.0 {
411 Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
412 Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
413 }
414 }
415}
416
417impl<T: Transport> Transport for RustlsClientTransport<T> {
418 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
419 self.as_ref().peer_addr()
420 }
421
422 fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
423 self.tls_state()
424 .and_then(|conn| conn.alpn_protocol())
425 .map(std::borrow::Cow::Borrowed)
426 }
427}
428
429impl<T> AsRef<T> for RustlsClientTransport<T> {
430 fn as_ref(&self) -> &T {
431 match &self.0 {
432 Tcp(x) => x,
433 Tls(x) => x.get_ref().0,
434 }
435 }
436}
437
438impl<T> RustlsClientTransport<T> {
439 pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
441 match &mut self.0 {
442 Tls(x) => Some(x.get_mut().1),
443 _ => None,
444 }
445 }
446
447 pub fn tls_state(&self) -> Option<&ClientConnection> {
449 match &self.0 {
450 Tls(x) => Some(x.get_ref().1),
451 _ => None,
452 }
453 }
454}