// Exactly one rustls crypto provider feature must be selected for this module.
// Reject builds that enable both providers at once...
#[cfg(all(feature = "http2-ring", feature = "http2-fips"))]
compile_error!(
    "Only one TLS crypto provider may be enabled. Choose one of: http2-ring, http2-fips"
);

// ...and builds that enable neither of them.
#[cfg(not(any(feature = "http2-ring", feature = "http2-fips")))]
compile_error!("http2 requires a crypto provider: enable http2-ring or http2-fips");
17
18use futures_util::ready;
21use hyper::server::accept::Accept;
22use hyper::server::conn::{AddrIncoming, AddrStream};
23use rustls_pki_types::pem::PemObject;
24use rustls_pki_types::{CertificateDer, PrivateKeyDer};
25use std::fs::File;
26use std::future::Future;
27use std::io::{self, BufReader, Cursor, Read};
28use std::net::SocketAddr;
29use std::path::{Path, PathBuf};
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
34use tokio_rustls::rustls::{Error as TlsError, ServerConfig};
35
36use crate::transport::Transport;
37
/// Errors returned by [`TlsConfigBuilder::build`] while turning PEM
/// certificate/key data into a rustls [`ServerConfig`].
#[derive(Debug)]
pub enum TlsConfigError {
    /// Reading the certificate or private key source failed.
    Io(io::Error),
    /// The certificate PEM data could not be parsed.
    CertParseError,
    /// The private key PEM data is malformed (e.g. bad base64 or an
    /// unrecognised PEM error).
    InvalidIdentityPem,
    /// The private key source was empty, or contained no PEM item.
    EmptyKey,
    /// NOTE(review): not constructed anywhere in this file; presumably kept
    /// for API compatibility — confirm before relying on it.
    UnknownPrivateKeyFormat,
    /// rustls rejected the certificate chain / private key pair.
    InvalidKey(TlsError),
    /// A PEM `-----BEGIN` line was malformed; holds the raw offending line.
    IllegalSectionStart(Vec<u8>),
    /// A PEM section was never closed; holds the expected end marker
    /// (built from the PEM reader's "missing section end" error).
    IllegalSectionEnd(Vec<u8>),
}
58
59impl std::fmt::Display for TlsConfigError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 TlsConfigError::Io(err) => err.fmt(f),
63 TlsConfigError::CertParseError => write!(f, "failed to parse certificate"),
64 TlsConfigError::InvalidIdentityPem => write!(f, "the identity PEM provided is invalid"),
65 TlsConfigError::UnknownPrivateKeyFormat => {
66 write!(f, "the private key format is unknown")
67 }
68 TlsConfigError::EmptyKey => write!(f, "the key provided is probably missing or empty"),
69 TlsConfigError::InvalidKey(err) => write!(f, "the key provided is invalid, {err}"),
70 TlsConfigError::IllegalSectionStart(line) => {
71 let line = String::from_utf8(line.clone()).unwrap_or_default();
72 write!(f, "illegal section start in PEM at '{line}'")
73 }
74 TlsConfigError::IllegalSectionEnd(end_marker) => {
75 let end_marker = String::from_utf8(end_marker.clone()).unwrap_or_default();
76 write!(f, "illegal section end in PEM at '{end_marker}'")
77 }
78 }
79 }
80}
81
// `source()` is left at its default (`None`): the `Display` impl already
// embeds the wrapped I/O / rustls error text, so also exposing it as a source
// would repeat it in error-chain reports.
impl std::error::Error for TlsConfigError {}
83
/// Builder for a rustls [`ServerConfig`] from a PEM certificate chain and a
/// PEM private key, each supplied either as a file path or as bytes.
pub struct TlsConfigBuilder {
    /// Source of the PEM certificate chain (lazily-opened file or in-memory buffer).
    cert: Box<dyn Read + Send + Sync>,
    /// Source of the PEM private key (lazily-opened file or in-memory buffer).
    key: Box<dyn Read + Send + Sync>,
}
89
90impl std::fmt::Debug for TlsConfigBuilder {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
92 f.debug_struct("TlsConfigBuilder").finish()
93 }
94}
95
96impl TlsConfigBuilder {
97 pub fn new() -> TlsConfigBuilder {
99 TlsConfigBuilder {
100 key: Box::new(io::empty()),
101 cert: Box::new(io::empty()),
102 }
103 }
104
105 pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
107 self.key = Box::new(LazyFile {
108 path: path.as_ref().into(),
109 file: None,
110 });
111 self
112 }
113
114 pub fn key(mut self, key: &[u8]) -> Self {
116 self.key = Box::new(Cursor::new(Vec::from(key)));
117 self
118 }
119
120 pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
122 self.cert = Box::new(LazyFile {
123 path: path.as_ref().into(),
124 file: None,
125 });
126 self
127 }
128
129 pub fn cert(mut self, cert: &[u8]) -> Self {
131 self.cert = Box::new(Cursor::new(Vec::from(cert)));
132 self
133 }
134
135 pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
137 let mut cert_rdr = BufReader::new(self.cert);
138 let cert = CertificateDer::pem_reader_iter(&mut cert_rdr)
139 .collect::<Result<Vec<_>, _>>()
140 .map_err(|_e| TlsConfigError::CertParseError)?;
141
142 let mut key_buf = Vec::new();
144 self.key
145 .read_to_end(&mut key_buf)
146 .map_err(TlsConfigError::Io)?;
147
148 if key_buf.is_empty() {
149 return Err(TlsConfigError::EmptyKey);
150 }
151
152 let reader = Cursor::new(key_buf);
153 let key = PrivateKeyDer::from_pem_reader(reader).map_err(|err| match err {
154 rustls_pki_types::pem::Error::Base64Decode(_) => TlsConfigError::InvalidIdentityPem,
155 rustls_pki_types::pem::Error::NoItemsFound => TlsConfigError::EmptyKey,
156 rustls_pki_types::pem::Error::IllegalSectionStart { line } => {
157 TlsConfigError::IllegalSectionStart(line)
158 }
159 rustls_pki_types::pem::Error::MissingSectionEnd { end_marker } => {
160 TlsConfigError::IllegalSectionEnd(end_marker)
161 }
162 rustls_pki_types::pem::Error::Io(err) => {
163 TlsConfigError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
164 }
165 _ => TlsConfigError::InvalidIdentityPem,
166 })?;
167
168 let mut config = ServerConfig::builder()
169 .with_no_client_auth()
170 .with_single_cert(cert, key)
171 .map_err(TlsConfigError::InvalidKey)?;
172 config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
173 Ok(config)
174 }
175}
176
177impl Default for TlsConfigBuilder {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
/// A file reader that defers opening `path` until the first read.
struct LazyFile {
    /// File to open on the first read.
    path: PathBuf,
    /// Open handle; `None` until a read has successfully opened `path`.
    file: Option<File>,
}

impl LazyFile {
    /// Reads from the underlying file, opening it first if necessary.
    ///
    /// If opening fails, `file` stays `None`, so a later read retries the open.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let file = match self.file.take() {
            Some(file) => file,
            None => File::open(&self.path)?,
        };
        self.file.insert(file).read(buf)
    }
}

impl Read for LazyFile {
    /// Delegates to `lazy_read`. On failure, the error message is prefixed
    /// with the file path and the original `io::ErrorKind` is kept.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.lazy_read(buf).map_err(|err| {
            io::Error::new(
                err.kind(),
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )
        })
    }
}
212
213impl Transport for TlsStream {
214 fn remote_addr(&self) -> Option<SocketAddr> {
215 Some(self.remote_addr)
216 }
217}
218
/// Connection phase of a [`TlsStream`].
enum State {
    /// The server-side TLS handshake is still pending; this future resolves to
    /// the established stream.
    Handshaking(tokio_rustls::Accept<AddrStream>),
    /// The handshake completed; reads and writes go through the TLS stream.
    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
223
/// A server-side TLS connection whose handshake is driven by the first
/// `poll_read` or `poll_write` on it.
pub struct TlsStream {
    /// Handshake in progress, or the established TLS stream.
    state: State,
    /// Peer address, captured from the TCP stream before the handshake starts.
    remote_addr: SocketAddr,
}
232
233impl TlsStream {
234 fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
235 let remote_addr = stream.remote_addr();
236 let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
237 TlsStream {
238 state: State::Handshaking(accept),
239 remote_addr,
240 }
241 }
242}
243
244impl AsyncRead for TlsStream {
245 fn poll_read(
246 self: Pin<&mut Self>,
247 cx: &mut Context<'_>,
248 buf: &mut ReadBuf<'_>,
249 ) -> Poll<io::Result<()>> {
250 let pin = self.get_mut();
251 match pin.state {
252 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
253 Ok(mut stream) => {
254 let result = Pin::new(&mut stream).poll_read(cx, buf);
255 pin.state = State::Streaming(stream);
256 result
257 }
258 Err(err) => Poll::Ready(Err(err)),
259 },
260 State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
261 }
262 }
263}
264
265impl AsyncWrite for TlsStream {
266 fn poll_write(
267 self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 buf: &[u8],
270 ) -> Poll<io::Result<usize>> {
271 let pin = self.get_mut();
272 match pin.state {
273 State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
274 Ok(mut stream) => {
275 let result = Pin::new(&mut stream).poll_write(cx, buf);
276 pin.state = State::Streaming(stream);
277 result
278 }
279 Err(err) => Poll::Ready(Err(err)),
280 },
281 State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
282 }
283 }
284
285 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286 match self.state {
287 State::Handshaking(_) => Poll::Ready(Ok(())),
288 State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
289 }
290 }
291
292 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293 match self.state {
294 State::Handshaking(_) => Poll::Ready(Ok(())),
295 State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
296 }
297 }
298}
299
/// A hyper `Accept` implementation that wraps every TCP connection from an
/// [`AddrIncoming`] in a [`TlsStream`].
pub struct TlsAcceptor {
    /// rustls configuration shared (via `Arc` clones) with every connection.
    config: Arc<ServerConfig>,
    /// Source of plain TCP connections.
    incoming: AddrIncoming,
}
305
306impl TlsAcceptor {
307 pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
309 TlsAcceptor {
310 config: Arc::new(config),
311 incoming,
312 }
313 }
314}
315
316impl Accept for TlsAcceptor {
317 type Conn = TlsStream;
318 type Error = io::Error;
319
320 fn poll_accept(
321 self: Pin<&mut Self>,
322 cx: &mut Context<'_>,
323 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
324 let pin = self.get_mut();
325 match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
326 Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
327 Some(Err(e)) => Poll::Ready(Some(Err(e))),
328 None => Poll::Ready(None),
329 }
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from PEM files on disk, panicking on any error.
    fn assert_builds_from_files(cert_path: &str, key_path: &str) {
        TlsConfigBuilder::new()
            .cert_path(cert_path)
            .key_path(key_path)
            .build()
            .unwrap();
    }

    /// Builds a config from in-memory PEM text, panicking on any error.
    fn assert_builds_from_bytes(cert: &str, key: &str) {
        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        assert_builds_from_files(
            "tests/tls/local.dev_cert.rsa_pkcs1.pem",
            "tests/tls/local.dev_key.rsa_pkcs1.pem",
        );
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem"),
            include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem"),
        );
    }

    #[test]
    fn file_cert_key_pkcs8() {
        assert_builds_from_files(
            "tests/tls/local.dev_cert.pkcs8.pem",
            "tests/tls/local.dev_key.pkcs8.pem",
        );
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.pkcs8.pem"),
            include_str!("../tests/tls/local.dev_key.pkcs8.pem"),
        );
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        assert_builds_from_files(
            "tests/tls/local.dev_cert.sec1_ec.pem",
            "tests/tls/local.dev_key.sec1_ec.pem",
        );
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        assert_builds_from_bytes(
            include_str!("../tests/tls/local.dev_cert.sec1_ec.pem"),
            include_str!("../tests/tls/local.dev_key.sec1_ec.pem"),
        );
    }

    // A builder that was never given a key must report `EmptyKey`.
    #[test]
    fn missing_key_is_reported_as_empty_key() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let result = TlsConfigBuilder::new().cert(cert.as_bytes()).build();
        assert!(matches!(result, Err(TlsConfigError::EmptyKey)));
    }

    // A key path that does not exist surfaces as an I/O `NotFound` error.
    #[test]
    fn missing_key_file_is_reported_as_not_found() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let result = TlsConfigBuilder::new()
            .cert(cert.as_bytes())
            .key_path("tests/tls/does_not_exist.key.pem")
            .build();
        assert!(matches!(
            result,
            Err(TlsConfigError::Io(ref err)) if err.kind() == std::io::ErrorKind::NotFound
        ));
    }

    #[cfg(feature = "http2-fips")]
    #[test]
    fn fips_mode_is_active() {
        aws_lc_rs::try_fips_mode()
            .expect("FIPS mode should be active when http2-fips feature is enabled");
    }
}