rustls_tokio_stream/
lib.rs1mod adapter;
3mod connection_stream;
4mod handshake;
5mod stream;
6
7#[cfg(test)]
10mod system_test;
11
12pub use stream::ServerConfigProvider;
13pub use stream::TlsHandshake;
14pub use stream::TlsStream;
15pub use stream::TlsStreamRead;
16pub use stream::TlsStreamWrite;
17pub use stream::UnderlyingStream;
18
19pub use rustls;
21
/// Knobs that tests use to perturb the TLS handshake.
///
/// Every field is gated on `cfg(test)`, so outside of test builds this is an
/// empty struct. `Default` leaves every knob switched off.
#[derive(Clone, Copy, Default)]
struct TestOptions {
  /// Test-only flag; by its name, asks for the handshake to be delayed.
  #[cfg(test)]
  delay_handshake: bool,
  /// Test-only flag; by its name, asks for slow reads during the handshake.
  #[cfg(test)]
  slow_handshake_read: bool,
  /// Test-only flag; by its name, asks for slow writes during the handshake.
  #[cfg(test)]
  slow_handshake_write: bool,
}
32
/// Process-wide flag, initially `false`, that `enable_byte_tracing` sets to
/// `true`. Only compiled when the `trace` feature is enabled.
// NOTE(review): presumably consulted by the stream code before dumping raw
// bytes -- the readers are not visible in this file; confirm.
#[cfg(feature = "trace")]
static ENABLE_BYTE_TRACING: std::sync::atomic::AtomicBool =
  std::sync::atomic::AtomicBool::new(false);
36
/// Sets the crate-wide `ENABLE_BYTE_TRACING` flag to `true`.
///
/// Only available when the crate is built with the `trace` feature. There is
/// no corresponding function in this file to switch the flag back off.
#[cfg(feature = "trace")]
pub fn enable_byte_tracing() {
  use std::sync::atomic::Ordering;

  ENABLE_BYTE_TRACING.store(true, Ordering::SeqCst);
}
41
/// Prints a timestamped diagnostic line to stdout when the crate is built with
/// the `trace` feature; otherwise the branch is dead and prints nothing.
///
/// Accepts the same arguments as `println!` (a format string literal followed
/// by optional values, with an optional trailing comma). The output format is
/// `[<milliseconds since the Unix epoch>] <message>`.
macro_rules! trace {
  ($($args:expr),+ $(,)?) => {
    if cfg!(feature = "trace") {
      // Emit the timestamp and the message with a single `println!` so that a
      // line written by one task cannot be split by output from another.
      // A clock set before the epoch is reported as 0 instead of panicking
      // inside a logging statement.
      println!(
        "[{:?}] {}",
        std::time::SystemTime::now()
          .duration_since(std::time::UNIX_EPOCH)
          .map(|elapsed| elapsed.as_millis())
          .unwrap_or_default(),
        format_args!($($args),+)
      );
    }
  };
}
51
52pub(crate) use trace;
53
54#[cfg(test)]
55mod tests {
56 pub use super::stream::tests::tls_pair;
57 pub use super::stream::tests::tls_pair_buffer_size;
58 use rustls::client::danger::ServerCertVerified;
59 use rustls::client::danger::ServerCertVerifier;
60 use rustls::pki_types::CertificateDer;
61 use rustls::pki_types::PrivateKeyDer;
62 use rustls::pki_types::ServerName;
63 use rustls::ClientConfig;
64 use rustls::ServerConfig;
65 use std::io;
66 use std::io::BufRead;
67 use std::net::Ipv4Addr;
68 use std::net::SocketAddr;
69 use std::net::SocketAddrV4;
70 use std::sync::Arc;
71 use tokio::net::TcpListener;
72 use tokio::net::TcpSocket;
73 use tokio::net::TcpStream;
74 use tokio::spawn;
75
76 pub type TestResult = Result<(), Box<dyn std::error::Error>>;
77
78 #[derive(Debug)]
79 pub struct UnsafeVerifier {}
80
81 impl ServerCertVerifier for UnsafeVerifier {
82 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
83 vec![rustls::SignatureScheme::RSA_PSS_SHA256]
84 }
85
86 fn verify_tls12_signature(
87 &self,
88 _message: &[u8],
89 _cert: &rustls::pki_types::CertificateDer<'_>,
90 _dss: &rustls::DigitallySignedStruct,
91 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
92 {
93 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
94 }
95
96 fn verify_tls13_signature(
97 &self,
98 _message: &[u8],
99 _cert: &rustls::pki_types::CertificateDer<'_>,
100 _dss: &rustls::DigitallySignedStruct,
101 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
102 {
103 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
104 }
105
106 fn verify_server_cert(
107 &self,
108 _end_entity: &rustls::pki_types::CertificateDer<'_>,
109 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
110 _server_name: &ServerName<'_>,
111 _ocsp_response: &[u8],
112 _now: rustls::pki_types::UnixTime,
113 ) -> Result<ServerCertVerified, rustls::Error> {
114 Ok(rustls::client::danger::ServerCertVerified::assertion())
115 }
116 }
117
118 pub fn certificate() -> CertificateDer<'static> {
119 let buf_read: &mut dyn BufRead =
120 &mut &include_bytes!("testdata/localhost.crt")[..];
121 let cert = rustls_pemfile::read_one(buf_read)
122 .expect("Failed to load test cert")
123 .unwrap();
124 match cert {
125 rustls_pemfile::Item::X509Certificate(cert) => cert,
126 _ => {
127 panic!("Unexpected item")
128 }
129 }
130 }
131
132 pub fn private_key() -> PrivateKeyDer<'static> {
133 let buf_read: &mut dyn BufRead =
134 &mut &include_bytes!("testdata/localhost.key")[..];
135 let cert = rustls_pemfile::read_one(buf_read)
136 .expect("Failed to load test key")
137 .unwrap();
138 match cert {
139 rustls_pemfile::Item::Pkcs8Key(key) => key.into(),
140 _ => {
141 panic!("Unexpected item")
142 }
143 }
144 }
145
146 pub fn server_config() -> ServerConfig {
147 ServerConfig::builder()
148 .with_no_client_auth()
149 .with_single_cert(vec![certificate()], private_key())
150 .expect("Failed to build server config")
151 }
152
153 pub fn client_config() -> ClientConfig {
154 ClientConfig::builder()
155 .dangerous()
156 .with_custom_certificate_verifier(Arc::new(UnsafeVerifier {}))
157 .with_no_client_auth()
158 }
159
160 pub fn server_name() -> ServerName<'static> {
161 "example.com".try_into().unwrap()
162 }
163
164 pub async fn tcp_pair() -> (TcpStream, TcpStream) {
165 let listener = TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(
166 Ipv4Addr::LOCALHOST,
167 0,
168 )))
169 .await
170 .unwrap();
171 let port = listener.local_addr().unwrap().port();
172 let server = spawn(async move { listener.accept().await.unwrap().0 });
173 let client = spawn(async move {
174 TcpSocket::new_v4()
175 .unwrap()
176 .connect(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)))
177 .await
178 .unwrap()
179 });
180
181 let (server, client) = (server.await.unwrap(), client.await.unwrap());
182 (server, client)
183 }
184
185 pub fn expect_io_error<T: std::fmt::Debug>(
186 e: Result<T, io::Error>,
187 kind: io::ErrorKind,
188 ) {
189 assert_eq!(e.expect_err("Expected error").kind(), kind);
190 }
191}