tokio_tls_upgrade/
lib.rs

1mod certificates;
2mod upgrade;
3
4/**
5 * This function upgrades a `tokio::net::TcpStream` to a `tokio_rustls::server::TlsStream` using a provided certificate chain/certificate and key file.
6 */
7pub use upgrade::upgrade_tcp_stream;
8
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{CertificateParams, KeyPair};
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{ClientConfig, DigitallySignedStruct, Error as TlsError, SignatureScheme};
    use std::fs;
    use std::path::PathBuf;
    use std::sync::Arc;
    use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_rustls::TlsConnector;

    /// Reply the test server sends after it receives a non-empty message.
    const SERVER_RESPONSE: &[u8] = b"Hello TLS client!";

    /// Deletes the wrapped files when dropped, so the certificate/key fixtures
    /// are cleaned up even if an assertion panics part-way through the test.
    struct TempFiles(Vec<PathBuf>);

    impl Drop for TempFiles {
        fn drop(&mut self) {
            for path in &self.0 {
                // Best effort: the file may not exist if writing it failed.
                let _ = fs::remove_file(path);
            }
        }
    }

    /// Builds a path in the system temp directory that is unique to this test
    /// process, so concurrent runs neither collide nor litter the working directory.
    fn temp_path(file_name: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "tokio_tls_upgrade_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    /// Accepts a single connection on `listener` and upgrades it with
    /// [`upgrade_tcp_stream`]. If the client sends anything, it replies with
    /// [`SERVER_RESPONSE`]. Finally it shuts the TLS session down cleanly.
    ///
    /// The caller binds `listener` before spawning this task, so the client can
    /// connect right away without a separate "server ready" signal.
    async fn start_tls_server(
        listener: TcpListener,
        certificate_path: PathBuf,
        key_path: PathBuf,
    ) -> io::Result<()> {
        // Blocks (asynchronously) until a client connects.
        let (stream, _) = listener.accept().await?;

        // Upgrade the connection to TLS using the library function under test.
        let mut tls_stream = upgrade_tcp_stream(stream, certificate_path, key_path).await?;

        let mut buffer = [0u8; 1024];
        let n = tls_stream.read(&mut buffer).await?;
        if n > 0 {
            log::info!(
                "Received from client: {:?}",
                String::from_utf8_lossy(&buffer[..n])
            );
            tls_stream.write_all(SERVER_RESPONSE).await?;
        }

        // `write_all` can leave encrypted records buffered in the TLS session.
        // Shutting down flushes them and sends close_notify before the socket
        // is dropped.
        tls_stream.shutdown().await?;

        Ok(())
    }

    /// Server-certificate verifier that accepts every certificate and every
    /// handshake signature.
    ///
    /// TEST USE ONLY: it disables all server authentication. The test's
    /// certificate is self-signed and is not what is being tested, so the
    /// handshake only needs to complete.
    #[derive(Debug)]
    struct NoVerification;

    impl ServerCertVerifier for NoVerification {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, TlsError> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, TlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, TlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            // Every signature is accepted anyway. Advertising the common schemes
            // lets the handshake succeed whatever key type the server certificate
            // uses (RSA, ECDSA or Ed25519).
            vec![
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::ECDSA_NISTP256_SHA256,
                SignatureScheme::ECDSA_NISTP384_SHA384,
                SignatureScheme::ED25519,
            ]
        }
    }

    #[tokio::test]
    async fn test_tls_upgrade() {
        // Generate a self-signed certificate for "localhost".
        let subject_alt_names = vec!["localhost".to_string()];
        let key_pair = KeyPair::generate_for(&rcgen::PKCS_RSA_SHA512).unwrap();
        let cert = CertificateParams::new(subject_alt_names)
            .unwrap()
            .self_signed(&key_pair)
            .unwrap();

        // Write the PEM fixtures to process-unique temp paths. The guard deletes
        // them when it goes out of scope, including when an assertion panics.
        let cert_path = temp_path("cert.pem");
        let key_path = temp_path("key.pem");
        let _fixtures = TempFiles(vec![cert_path.clone(), key_path.clone()]);
        fs::write(&cert_path, cert.pem()).unwrap();
        fs::write(&key_path, key_pair.serialize_pem()).unwrap();

        // Bind to port 0 so the OS assigns a free port; a fixed port fails when
        // it is already in use. Binding here, before the server task starts,
        // also guarantees the listener exists when the client connects.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // Run the TLS server in the background. Panicking inside the task makes
        // the panic hook print any server error, even if the client fails first.
        let server = tokio::spawn(async move {
            start_tls_server(listener, cert_path, key_path)
                .await
                .expect("TLS server failed");
        });

        // The certificate is self-signed, so server verification is disabled;
        // this test exercises the upgrade and data exchange, not PKI validation.
        let config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerification))
            .with_no_client_auth();
        let connector = TlsConnector::from(Arc::new(config));
        let dns_name = ServerName::try_from("localhost").unwrap();

        // Connect and upgrade the client side of the connection.
        let stream = TcpStream::connect(server_addr).await.unwrap();
        let mut tls_stream = connector.connect(dns_name, stream).await.unwrap();

        // Send a message and read the server's reply.
        tls_stream.write_all(b"Hello Server!").await.unwrap();
        tls_stream.flush().await.unwrap();

        let mut buffer = [0u8; 1024];
        let n = tls_stream.read(&mut buffer).await.unwrap();
        assert!(n > 0, "server closed the connection without replying");

        log::info!(
            "Received from server: {:?}",
            String::from_utf8_lossy(&buffer[..n])
        );

        assert_eq!(&buffer[..n], SERVER_RESPONSE);

        // Propagate any panic from the server task.
        server.await.unwrap();
    }
}