rustls_tokio_stream/
lib.rs

1// Copyright 2018-2023 the Deno authors. All rights reserved. MIT license.
2mod adapter;
3mod connection_stream;
4mod handshake;
5mod stream;
6
7///! An `async` wrapper around the `rustls` connection types and a `tokio` TCP socket.
8
9#[cfg(test)]
10mod system_test;
11
12pub use stream::ServerConfigProvider;
13pub use stream::TlsHandshake;
14pub use stream::TlsStream;
15pub use stream::TlsStreamRead;
16pub use stream::TlsStreamWrite;
17pub use stream::UnderlyingStream;
18
19/// Re-export the version of rustls we are built on
20pub use rustls;
21
/// Test-only knobs used to modify test timing to expose problems.
///
/// Every field is gated on `#[cfg(test)]`, so outside of test builds this
/// struct has no fields at all. `Default` leaves every flag `false`.
#[derive(Clone, Copy, Default)]
struct TestOptions {
  // NOTE(review): presumably delays the handshake; confirm against the code
  // that reads this flag.
  #[cfg(test)]
  delay_handshake: bool,
  // NOTE(review): presumably slows reads performed during the handshake.
  #[cfg(test)]
  slow_handshake_read: bool,
  // NOTE(review): presumably slows writes performed during the handshake.
  #[cfg(test)]
  slow_handshake_write: bool,
}
32
// NOTE(review): this flag is not read anywhere in lib.rs itself; presumably
// another module of the crate consults it — confirm what it gates.
/// Process-wide byte-tracing flag; starts out `false` and is set by
/// [`enable_byte_tracing`].
///
/// Only compiled when the `trace` feature is enabled.
#[cfg(feature = "trace")]
static ENABLE_BYTE_TRACING: std::sync::atomic::AtomicBool =
  std::sync::atomic::AtomicBool::new(false);

/// Sets the process-wide byte-tracing flag to `true` using `SeqCst` ordering.
///
/// Only available when the crate is built with the `trace` feature.
#[cfg(feature = "trace")]
pub fn enable_byte_tracing() {
  use std::sync::atomic::Ordering;

  ENABLE_BYTE_TRACING.store(true, Ordering::SeqCst);
}
41
/// Prints a timestamped diagnostic line to stdout when the crate is built with
/// the `trace` feature; otherwise it does nothing at runtime.
///
/// Takes the same arguments as [`println!`]: a format string followed by its
/// arguments, optionally with a trailing comma. Each line is prefixed with the
/// number of milliseconds since the Unix epoch, e.g. `[1700000000000] ...`.
///
/// Because `cfg!` expands to a plain `bool`, the arguments are always
/// type-checked even when the feature is off.
macro_rules! trace {
  ($($args:expr),+ $(,)?) => {
    if cfg!(feature = "trace") {
      // A system clock set before 1970 yields 0 rather than panicking inside
      // a diagnostic macro.
      let millis = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis())
        .unwrap_or_default();
      // Emit prefix and message with a single `println!` (one stdout lock),
      // so concurrent callers cannot interleave between the two halves.
      println!("[{}] {}", millis, format_args!($($args),+));
    }
  };
}

pub(crate) use trace;
53
#[cfg(test)]
mod tests {
  pub use super::stream::tests::tls_pair;
  pub use super::stream::tests::tls_pair_buffer_size;
  use rustls::client::danger::ServerCertVerified;
  use rustls::client::danger::ServerCertVerifier;
  use rustls::pki_types::CertificateDer;
  use rustls::pki_types::PrivateKeyDer;
  use rustls::pki_types::ServerName;
  use rustls::ClientConfig;
  use rustls::ServerConfig;
  use std::io;
  use std::io::BufRead;
  use std::net::Ipv4Addr;
  use std::net::SocketAddr;
  use std::net::SocketAddrV4;
  use std::sync::Arc;
  use tokio::net::TcpListener;
  use tokio::net::TcpSocket;
  use tokio::net::TcpStream;
  use tokio::spawn;

  /// Result type for tests that propagate arbitrary errors with `?`.
  pub type TestResult = Result<(), Box<dyn std::error::Error>>;

  /// A server-certificate verifier that accepts every certificate and every
  /// handshake signature without checking anything.
  ///
  /// **Insecure** — only for tests that talk to the self-signed certificate
  /// in `testdata/`.
  #[derive(Debug)]
  pub struct UnsafeVerifier {}

  impl ServerCertVerifier for UnsafeVerifier {
    // NOTE(review): only RSA-PSS-SHA256 is advertised; presumably this matches
    // `testdata/localhost.key`. Widen the list if the test key type changes.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
      vec![rustls::SignatureScheme::RSA_PSS_SHA256]
    }

    /// Accepts any TLS 1.2 handshake signature unconditionally.
    fn verify_tls12_signature(
      &self,
      _message: &[u8],
      _cert: &rustls::pki_types::CertificateDer<'_>,
      _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
    {
      Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature unconditionally.
    fn verify_tls13_signature(
      &self,
      _message: &[u8],
      _cert: &rustls::pki_types::CertificateDer<'_>,
      _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
    {
      Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any server certificate, name and OCSP response unconditionally.
    fn verify_server_cert(
      &self,
      _end_entity: &rustls::pki_types::CertificateDer<'_>,
      _intermediates: &[rustls::pki_types::CertificateDer<'_>],
      _server_name: &ServerName<'_>,
      _ocsp_response: &[u8],
      _now: rustls::pki_types::UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
      Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
  }

  /// Loads the first PEM item of `testdata/localhost.crt` as an X.509
  /// certificate.
  ///
  /// # Panics
  ///
  /// If the PEM data cannot be read, contains no item, or its first item is
  /// not a certificate.
  pub fn certificate() -> CertificateDer<'static> {
    let buf_read: &mut dyn BufRead =
      &mut &include_bytes!("testdata/localhost.crt")[..];
    let item = rustls_pemfile::read_one(buf_read)
      .expect("Failed to load test cert")
      .expect("Test cert file contains no PEM item");
    match item {
      rustls_pemfile::Item::X509Certificate(cert) => cert,
      _ => panic!("Unexpected item"),
    }
  }

  /// Loads the first PEM item of `testdata/localhost.key` as a private key.
  ///
  /// PKCS#8, PKCS#1 (RSA) and SEC1 (EC) PEM encodings are all accepted.
  ///
  /// # Panics
  ///
  /// If the PEM data cannot be read, contains no item, or its first item is
  /// not a private key.
  pub fn private_key() -> PrivateKeyDer<'static> {
    let buf_read: &mut dyn BufRead =
      &mut &include_bytes!("testdata/localhost.key")[..];
    let item = rustls_pemfile::read_one(buf_read)
      .expect("Failed to load test key")
      .expect("Test key file contains no PEM item");
    match item {
      rustls_pemfile::Item::Pkcs8Key(key) => key.into(),
      rustls_pemfile::Item::Pkcs1Key(key) => key.into(),
      rustls_pemfile::Item::Sec1Key(key) => key.into(),
      _ => panic!("Unexpected item"),
    }
  }

  /// Builds a server config that serves the test certificate and key and
  /// does not request client certificates.
  pub fn server_config() -> ServerConfig {
    ServerConfig::builder()
      .with_no_client_auth()
      .with_single_cert(vec![certificate()], private_key())
      .expect("Failed to build server config")
  }

  /// Builds a client config that trusts any server via [`UnsafeVerifier`]
  /// and presents no client certificate.
  pub fn client_config() -> ClientConfig {
    ClientConfig::builder()
      .dangerous()
      .with_custom_certificate_verifier(Arc::new(UnsafeVerifier {}))
      .with_no_client_auth()
  }

  /// The server name used by test clients (`example.com`).
  pub fn server_name() -> ServerName<'static> {
    "example.com".try_into().unwrap()
  }

  /// Creates a connected pair of TCP streams over an ephemeral port on
  /// `127.0.0.1`.
  ///
  /// Returns `(server, client)`: the accepted server-side stream first, then
  /// the client-side stream.
  pub async fn tcp_pair() -> (TcpStream, TcpStream) {
    let listener = TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(
      Ipv4Addr::LOCALHOST,
      0,
    )))
    .await
    .unwrap();
    let port = listener.local_addr().unwrap().port();
    // Accept and connect run as separate tasks and are joined below.
    let server = spawn(async move { listener.accept().await.unwrap().0 });
    let client = spawn(async move {
      TcpSocket::new_v4()
        .unwrap()
        .connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
        .await
        .unwrap()
    });

    (server.await.unwrap(), client.await.unwrap())
  }

  /// Asserts that `e` is an `Err` whose [`io::ErrorKind`] equals `kind`.
  ///
  /// # Panics
  ///
  /// If `e` is `Ok`, or if the error kind differs from `kind`.
  pub fn expect_io_error<T: std::fmt::Debug>(
    e: Result<T, io::Error>,
    kind: io::ErrorKind,
  ) {
    assert_eq!(e.expect_err("Expected error").kind(), kind);
  }
}