Skip to main content

tokio_tls_helper/
lib.rs

1mod client_config;
2mod connected;
3mod error;
4mod identity;
5// mod io;
6mod server_config;
7mod tls;
8
9use std::sync::Arc;
10use tokio_rustls::rustls::ClientConfig;
11
12pub use client_config::ClientTlsConfig;
13pub use error::Error;
14pub use identity::{Certificate, Identity};
15pub use server_config::ServerTlsConfig;
16
17// re-exports
18pub use tokio_rustls::client::TlsStream as TlsClientStream;
19pub use tokio_rustls::server::TlsStream as TlsServerStream;
20
21pub(crate) use connected::Connected;
22
23#[derive(Clone)]
24pub struct TlsConnector {
25    pub config: Arc<ClientConfig>,
26    pub domain: Arc<String>,
27}
28#[derive(Clone)]
29pub struct TlsAcceptor {
30    pub inner: Arc<tokio_rustls::rustls::ServerConfig>,
31}
32
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use http::Uri;
    use test_helper::start_logger;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        task::JoinHandle,
    };
    use tracing::{error, info};

    use super::*;

    /// Payload the client sends and expects the echo server to return unchanged.
    const MSG: &[u8] = b"Hello world\n";

    /// Loopback address with an OS-assigned port: tests never collide on a fixed port and
    /// the listener is not exposed on every interface.
    const LOOPBACK_ANY_PORT: &str = "127.0.0.1:0";

    #[tokio::test]
    async fn tls_server_build_config_should_work() {
        start_logger();

        let cert = include_str!("fixtures/server.cert");
        let key = include_str!("fixtures/server.key");
        let identity = Identity::from_pem(cert, key);
        let config = ServerTlsConfig::new().identity(identity);
        let (addr, server) = start_server(config).await;

        let cert = Certificate::from_pem(include_str!("fixtures/ca.cert"));
        let config = ClientTlsConfig::new().ca_certificate(cert);
        let mut buf = [0u8; MSG.len()];
        start_client(config, addr, MSG, &mut buf).await;
        assert_eq!(&buf[..], MSG);

        // Surface any panic raised inside the server task.
        server.await.unwrap();
    }

    #[tokio::test]
    async fn tls_server_load_config_file_should_work() {
        start_logger();

        let config = toml::from_str(include_str!("fixtures/server.toml")).unwrap();
        let (addr, server) = start_server(config).await;

        let config = toml::from_str(include_str!("fixtures/client.toml")).unwrap();
        let mut buf = [0u8; MSG.len()];
        start_client(config, addr, MSG, &mut buf).await;
        assert_eq!(&buf[..], MSG);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn tls_server_verify_client_cert_should_work() {
        start_logger();

        let config =
            toml::from_str(include_str!("fixtures/server_verify_client_cert.toml")).unwrap();
        let (addr, server) = start_server(config).await;

        let config = toml::from_str(include_str!("fixtures/client_with_cert.toml")).unwrap();
        let mut buf = [0u8; MSG.len()];
        start_client(config, addr, MSG, &mut buf).await;
        assert_eq!(&buf[..], MSG);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn tls_server_invalid_client_cert_should_fail() {
        start_logger();

        let config: ServerTlsConfig =
            toml::from_str(include_str!("fixtures/server_verify_client_cert.toml")).unwrap();
        let acceptor = config.tls_acceptor().unwrap();
        let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (stream, _peer_addr) = listener.accept().await.unwrap();
            let result = acceptor.accept(stream).await;
            assert!(
                result.is_err(),
                "server must reject the invalid client certificate"
            );
            error!("server: failed client auth");
        });

        let config: ClientTlsConfig =
            toml::from_str(include_str!("fixtures/client_with_invalid_cert.toml")).unwrap();
        let connector = config.tls_connector(Uri::from_static("localhost")).unwrap();

        let stream = TcpStream::connect(addr).await.unwrap();
        let mut stream = connector.connect(stream).await.unwrap();
        info!("client: TLS conn established");

        stream.write_all(MSG).await.unwrap();

        info!("client: send data");

        let mut buf = [0u8; MSG.len()];
        let result = stream.read_exact(&mut buf).await;
        assert!(result.is_err());

        // The server-side assertion runs in a spawned task; its panic is only observable
        // through the JoinHandle. Without this, the test would pass even if the server
        // accepted the invalid certificate (the client read fails either way once the
        // task exits and the connection closes).
        server.await.unwrap();
    }

    /// Bind a one-shot TLS echo server on an ephemeral loopback port.
    ///
    /// The spawned task accepts one connection, completes the TLS handshake, reads exactly
    /// `MSG.len()` bytes, writes them back and flushes. Returns the bound address and the
    /// task's handle; callers should await the handle so server-side panics fail the test.
    async fn start_server(config: ServerTlsConfig) -> (SocketAddr, JoinHandle<()>) {
        let acceptor = config.tls_acceptor().unwrap();
        let listener = TcpListener::bind(LOOPBACK_ANY_PORT).await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let (stream, _peer_addr) = listener.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            info!("server: Accepted client conn with TLS");

            let mut buf = [0u8; MSG.len()];
            stream.read_exact(&mut buf).await.unwrap();
            info!("server: got data: {:?}", buf);
            stream.write_all(&buf).await.unwrap();
            // Push any bytes still buffered in the TLS session before the stream is dropped.
            stream.flush().await.unwrap();
            info!("server: flush the data out");
        });
        (addr, handle)
    }

    /// Connect to `addr` over TLS (server name `localhost`), send `msg`, then read exactly
    /// `buf.len()` bytes of the reply into `buf`. Panics on any I/O or TLS error.
    async fn start_client(config: ClientTlsConfig, addr: SocketAddr, msg: &[u8], buf: &mut [u8]) {
        let connector = config.tls_connector(Uri::from_static("localhost")).unwrap();

        let stream = TcpStream::connect(addr).await.unwrap();
        let mut stream = connector.connect(stream).await.unwrap();
        info!("client: TLS conn established");

        stream.write_all(msg).await.unwrap();
        // Make sure the request actually leaves the TLS buffer before waiting for the echo.
        stream.flush().await.unwrap();

        info!("client: send data");

        stream.read_exact(buf).await.unwrap();

        info!("client: read echoed data");
    }
}