1mod certificates;
2mod upgrade;
3
4pub use upgrade::upgrade_tcp_stream;
8
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{CertificateParams, KeyPair};
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{ClientConfig, DigitallySignedStruct, Error as TlsError, SignatureScheme};
    use std::fs;
    use std::net::SocketAddr;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::sync::oneshot::{channel, Sender};
    use tokio_rustls::TlsConnector;

    /// Scratch directory that is deleted (with its contents) when dropped.
    ///
    /// Holding the generated key material here instead of fixed names in the
    /// current working directory avoids clobbering real files, and the `Drop`
    /// impl cleans up even when an assertion in the test panics.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates `<temp_dir>/<label>-<pid>`, unique per test process.
        fn new(label: &str) -> io::Result<Self> {
            let path = std::env::temp_dir().join(format!("{}-{}", label, std::process::id()));
            fs::create_dir_all(&path)?;
            Ok(Self(path))
        }

        /// Path of the scratch directory.
        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test's own outcome.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Accepts one TCP connection on `addr` and upgrades it to TLS via
    /// `upgrade_tcp_stream`. Reads a single message and answers with a fixed
    /// greeting.
    ///
    /// The address actually bound is sent through `tx` once the listener is
    /// ready. This lets `addr` use port 0 so the OS picks a free port.
    ///
    /// # Errors
    /// Propagates I/O errors from binding, accepting, the TLS upgrade and the
    /// read/write on the upgraded stream.
    async fn start_tls_server(
        certificate_path: PathBuf,
        key_path: PathBuf,
        addr: &str,
        tx: Sender<SocketAddr>,
    ) -> io::Result<()> {
        let listener = TcpListener::bind(addr).await?;

        // The listener is already bound, so a client that connects right after
        // this signal is queued by the OS until `accept` runs.
        tx.send(listener.local_addr()?)
            .expect("test body dropped the readiness receiver");
        let (stream, _) = listener.accept().await?;

        let mut tls_stream = upgrade_tcp_stream(stream, certificate_path, key_path).await?;

        let mut buffer = [0u8; 1024];
        let n = tls_stream.read(&mut buffer).await?;
        if n > 0 {
            log::info!(
                "Received from client: {:?}",
                String::from_utf8_lossy(&buffer[..n])
            );
            tls_stream.write_all(b"Hello TLS client!").await?;
            // Push the record out and send close_notify explicitly instead of
            // relying on the stream being dropped.
            tls_stream.flush().await?;
            tls_stream.shutdown().await?;
        }

        Ok(())
    }

    /// Server certificate verifier that accepts every certificate and signature.
    ///
    /// Test-only: the server presents a freshly generated self-signed
    /// certificate. This test exercises the TLS upgrade path, not PKI validation.
    #[derive(Debug)]
    struct NoVerification;

    impl ServerCertVerifier for NoVerification {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer,
            _intermediates: &[CertificateDer],
            _server_name: &ServerName,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, TlsError> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, TlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, TlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            // Advertise the common schemes so the handshake does not depend on
            // which key algorithm the test certificate was generated with.
            vec![
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ED25519,
            ]
        }
    }

    /// End-to-end check of `upgrade_tcp_stream`. A client performs a TLS
    /// handshake against the upgraded server stream and exchanges one
    /// request/response pair.
    #[tokio::test]
    async fn test_tls_upgrade() {
        // Throwaway self-signed certificate for "localhost".
        let key_pair = KeyPair::generate_for(&rcgen::PKCS_RSA_SHA512).unwrap();
        let cert = CertificateParams::new(vec!["localhost".to_string()])
            .unwrap()
            .self_signed(&key_pair)
            .unwrap();

        // `scratch` lives until the end of the test (or until unwinding), so the
        // PEM files are removed on both success and failure.
        let scratch = ScratchDir::new("tls-upgrade-test").unwrap();
        let cert_path = scratch.path().join("cert.pem");
        let key_path = scratch.path().join("key.pem");
        fs::write(&cert_path, format!("{}\n", cert.pem())).unwrap();
        fs::write(&key_path, format!("{}\n", key_pair.serialize_pem())).unwrap();

        let (tx, rx) = channel();
        // Port 0: let the OS pick a free port rather than a fixed one that may
        // already be in use.
        let server = tokio::spawn(start_tls_server(cert_path, key_path, "127.0.0.1:0", tx));

        let server_addr = match rx.await {
            Ok(addr) => addr,
            // The sender was dropped: the server failed before it started
            // listening. Surface its error instead of an opaque RecvError.
            Err(_) => panic!("TLS server exited before listening: {:?}", server.await),
        };

        // Verification is deliberately disabled (see `NoVerification`), so no
        // root store is needed.
        let config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerification))
            .with_no_client_auth();
        let connector = TlsConnector::from(Arc::new(config));
        let dns_name = ServerName::try_from("localhost").unwrap();

        let stream = TcpStream::connect(server_addr).await.unwrap();
        let mut tls_stream = connector.connect(dns_name, stream).await.unwrap();

        tls_stream.write_all(b"Hello Server!").await.unwrap();
        tls_stream.flush().await.unwrap();

        let mut buffer = [0u8; 1024];
        let n = tls_stream.read(&mut buffer).await.unwrap();
        log::info!(
            "Received from server: {:?}",
            String::from_utf8_lossy(&buffer[..n])
        );
        assert_eq!(&buffer[..n], b"Hello TLS client!");

        server
            .await
            .expect("server task panicked")
            .expect("server returned an error");
    }
}