1#![doc(
58 html_logo_url = "https://www.libressl.org/images/libressl.jpg",
59 html_favicon_url = "https://www.libressl.org/favicon.ico"
60)]
61#![warn(missing_docs)]
62
63pub mod error;
65
66pub mod prelude;
68
69use error::Error;
70use libtls::{config::Config, error::Error as TlsError, tls::Tls};
71use mio::{event::Evented, unix::EventedFd, PollOpt, Ready, Token};
72use prelude::*;
73use std::{
74 io::{self, Read, Write},
75 net::ToSocketAddrs,
76 ops::{Deref, DerefMut},
77 os::unix::io::{AsRawFd, RawFd},
78 pin::Pin,
79 task::{Context, Poll},
80 time::Duration,
81};
82use tokio::{
83 io::{AsyncRead, AsyncWrite, PollEvented},
84 net::{TcpListener, TcpStream},
85 time::timeout,
86};
87
// Adapts a blocking-style libtls/io call into a `Poll`:
// - `Ok(v)` becomes `Poll::Ready(Ok(v))`;
// - an error whose kind is `WouldBlock` becomes `Poll::Pending`;
// - any other error becomes `Poll::Ready(Err(e))`.
// The error only needs to be convertible into `io::Error`.
// No waker is registered here; callers own that responsibility.
macro_rules! try_async_tls {
    ($call:expr) => {
        match $call {
            Err(e) => {
                let e: io::Error = e.into();
                match e.kind() {
                    io::ErrorKind::WouldBlock => Poll::Pending,
                    _ => Poll::Ready(Err(e)),
                }
            }
            Ok(value) => Poll::Ready(Ok(value)),
        }
    };
}
103
104#[derive(Debug)]
106pub struct TlsStream {
107 tls: Tls,
108 tcp: TcpStream,
109}
110
111impl TlsStream {
112 pub fn new(tls: Tls, tcp: TcpStream) -> Self {
114 Self { tls, tcp }
115 }
116}
117
118impl Deref for TlsStream {
119 type Target = Tls;
120
121 fn deref(&self) -> &Self::Target {
122 &self.tls
123 }
124}
125
126impl DerefMut for TlsStream {
127 fn deref_mut(&mut self) -> &mut Self::Target {
128 &mut self.tls
129 }
130}
131
132impl AsRawFd for TlsStream {
133 fn as_raw_fd(&self) -> RawFd {
134 self.tcp.as_raw_fd()
135 }
136}
137
138impl io::Read for TlsStream {
139 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
140 self.tls.read(buf)
141 }
142}
143
144impl io::Write for TlsStream {
145 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
146 self.tls.write(buf)
147 }
148
149 fn flush(&mut self) -> io::Result<()> {
150 self.tls.flush()
151 }
152}
153
154impl AsyncRead for TlsStream {
155 fn poll_read(
156 mut self: Pin<&mut Self>,
157 _cx: &mut Context<'_>,
158 buf: &mut [u8],
159 ) -> Poll<Result<usize, io::Error>> {
160 try_async_tls!(self.tls.read(buf))
161 }
162}
163
164impl AsyncWrite for TlsStream {
165 fn poll_write(
166 mut self: Pin<&mut Self>,
167 _cx: &mut Context<'_>,
168 buf: &[u8],
169 ) -> Poll<Result<usize, io::Error>> {
170 try_async_tls!(self.tls.write(buf))
171 }
172
173 fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
174 try_async_tls!(self.tls.close()).map(|_| Ok(()))
175 }
176
177 fn poll_shutdown(
178 mut self: Pin<&mut Self>,
179 _cx: &mut Context<'_>,
180 ) -> Poll<Result<(), io::Error>> {
181 try_async_tls!(self.tls.close()).map(|_| Ok(()))
182 }
183}
184
185impl Evented for TlsStream {
186 fn register(
187 &self,
188 poll: &mio::Poll,
189 token: Token,
190 interest: Ready,
191 opts: PollOpt,
192 ) -> io::Result<()> {
193 match EventedFd(&self.as_raw_fd()).register(poll, token, interest, opts) {
194 Err(ref err) if err.kind() == io::ErrorKind::AlreadyExists => {
195 self.reregister(poll, token, interest, opts)
196 }
197 Err(err) => Err(err),
198 Ok(_) => Ok(()),
199 }
200 }
201
202 fn reregister(
203 &self,
204 poll: &mio::Poll,
205 token: Token,
206 interest: Ready,
207 opts: PollOpt,
208 ) -> io::Result<()> {
209 EventedFd(&self.as_raw_fd()).reregister(poll, token, interest, opts)
210 }
211
212 fn deregister(&self, poll: &mio::Poll) -> io::Result<()> {
213 EventedFd(&self.as_raw_fd()).deregister(poll)
214 }
215}
216
// SAFETY: NOTE(review): no justification is visible for these impls. They are
// presumably needed because `Tls` (a libtls context, defined outside this file) is
// not automatically `Send`/`Sync`.
// - `Send` is sound only if a libtls context may be used from a thread other than
//   the one that created it.
// - `Sync` additionally requires that no `&self` method reachable through `Deref`
//   touches the context in a way that races with another thread.
// Confirm both against the libtls crate before relying on them.
unsafe impl Send for TlsStream {}
unsafe impl Sync for TlsStream {}
219
/// A [`TlsStream`] wrapped in tokio's `PollEvented`.
///
/// This is the stream type returned by `accept`, `accept_stream`,
/// `connect_stream` and `connect`.
pub type AsyncTlsStream = PollEvented<TlsStream>;
222
/// Future that drives a TLS handshake and resolves to the ready [`AsyncTlsStream`].
pub struct AsyncTls {
    // Handshake state machine:
    // - `Ok(stream)` once the handshake has finished;
    // - otherwise an `Error` variant: `Readable`/`Writeable`/`Handshake` when still
    //   in progress, `Error` on failure.
    // The value is `None` while `poll` owns it, and stays `None` after the future
    // has resolved (polling again yields a "cannot take inner" error).
    inner: Option<Result<AsyncTlsStream, Error>>,
}
227
228impl AsyncTls {
229 #[deprecated(since = "1.1.1", note = "Please use module function `accept_stream`")]
231 pub async fn accept_stream(
232 tcp: TcpStream,
233 config: &Config,
234 options: Option<Options>,
235 ) -> io::Result<AsyncTlsStream> {
236 accept_stream(tcp, config, options).await
237 }
238
239 #[deprecated(since = "1.1.1", note = "Please use module function `connect_stream`")]
241 pub async fn connect_stream(
242 tcp: TcpStream,
243 config: &Config,
244 options: Option<Options>,
245 ) -> io::Result<AsyncTlsStream> {
246 connect_stream(tcp, config, options).await
247 }
248
249 #[deprecated(since = "1.1.1", note = "Please use module function `connect`")]
251 pub async fn connect(
252 host: &str,
253 config: &Config,
254 options: Option<Options>,
255 ) -> io::Result<AsyncTlsStream> {
256 connect(host, config, options).await
257 }
258}
259
260impl Future for AsyncTls {
261 type Output = Result<AsyncTlsStream, io::Error>;
262
263 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
264 let inner = self
265 .inner
266 .take()
267 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cannot take inner"))?;
268 match inner {
269 Ok(tls) => {
270 cx.waker().wake_by_ref();
271 Poll::Ready(Ok(tls))
272 }
273 Err(Error::Readable(stream)) => {
274 self.inner = match stream.poll_read_ready(cx, Ready::readable()) {
275 Poll::Ready(_) => Some(Err(Error::Handshake(stream))),
276 _ => Some(Err(Error::Handshake(stream))),
277 };
278 cx.waker().wake_by_ref();
279 Poll::Pending
280 }
281 Err(Error::Writeable(stream)) => {
282 self.inner = match stream.poll_write_ready(cx) {
283 Poll::Ready(_) => Some(Err(Error::Handshake(stream))),
284 _ => Some(Err(Error::Writeable(stream))),
285 };
286 cx.waker().wake_by_ref();
287 Poll::Pending
288 }
289 Err(Error::Handshake(mut stream)) => {
290 let tls = &mut *stream.get_mut();
291 let res = match tls.tls_handshake() {
292 Ok(res) => {
293 if res == libtls::TLS_WANT_POLLIN as isize {
294 Err(Error::Readable(stream))
295 } else if res == libtls::TLS_WANT_POLLOUT as isize {
296 Err(Error::Writeable(stream))
297 } else {
298 Ok(stream)
299 }
300 }
301 Err(err) => Err(err.into()),
302 };
303 self.inner = Some(res);
304 cx.waker().wake_by_ref();
305 Poll::Pending
306 }
307 Err(Error::Error(TlsError::IoError(err))) => Poll::Ready(Err(err)),
308 Err(Error::Error(err)) => {
309 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err.to_string())))
310 }
311 }
312 }
313}
314
// SAFETY: NOTE(review): `AsyncTls` only holds an `Option<Result<AsyncTlsStream, Error>>`.
// These impls therefore inherit whatever soundness the `TlsStream` impls above
// (and the crate `Error` type, defined elsewhere) actually have. No independent
// justification is visible here; confirm before relying on cross-thread use.
unsafe impl Send for AsyncTls {}
unsafe impl Sync for AsyncTls {}
317
318pub async fn accept(
320 listener: &mut TcpListener,
321 config: &Config,
322 options: Option<Options>,
323) -> io::Result<AsyncTlsStream> {
324 let options = options.unwrap_or_else(Options::new);
325
326 let (tcp, _) = listener.accept().await?;
327 let mut server = Tls::server()?;
328 server.configure(config)?;
329 let client = server.accept_raw_fd(&tcp)?;
330
331 let async_tls = TlsStream::new(client, tcp);
332 let stream = PollEvented::new(async_tls)?;
333 let fut = AsyncTls {
334 inner: Some(Err(Error::Readable(stream))),
335 };
336
337 let tls = match options.timeout {
339 Some(tm) => match timeout(tm, fut).await {
340 Ok(res) => res,
341 Err(err) => Err(err.into()),
342 },
343 None => fut.await,
344 }?;
345
346 Ok(tls)
347}
348
349pub async fn accept_stream(
351 tcp: TcpStream,
352 config: &Config,
353 options: Option<Options>,
354) -> io::Result<AsyncTlsStream> {
355 let options = options.unwrap_or_else(Options::new);
356
357 let mut server = Tls::server()?;
358 server.configure(config)?;
359 let client = server.accept_raw_fd(&tcp)?;
360
361 let async_tls = TlsStream::new(client, tcp);
362 let stream = PollEvented::new(async_tls)?;
363 let fut = AsyncTls {
364 inner: Some(Err(Error::Readable(stream))),
365 };
366
367 let tls = match options.timeout {
369 Some(tm) => match timeout(tm, fut).await {
370 Ok(res) => res,
371 Err(err) => Err(err.into()),
372 },
373 None => fut.await,
374 }?;
375
376 Ok(tls)
377}
378
379pub async fn connect_stream(
381 tcp: TcpStream,
382 config: &Config,
383 options: Option<Options>,
384) -> io::Result<AsyncTlsStream> {
385 let options = options.unwrap_or_else(Options::new);
386 let servername = match options.servername {
387 Some(name) => name,
388 None => tcp.peer_addr()?.to_string(),
389 };
390
391 let mut tls = Tls::client()?;
392
393 tls.configure(config)?;
394 tls.connect_raw_fd(&tcp, &servername)?;
395
396 let async_tls = TlsStream::new(tls, tcp);
397 let stream = PollEvented::new(async_tls)?;
398 let fut = AsyncTls {
399 inner: Some(Err(Error::Readable(stream))),
400 };
401
402 let tls = match options.timeout {
404 Some(tm) => match timeout(tm, fut).await {
405 Ok(res) => res,
406 Err(err) => Err(err.into()),
407 },
408 None => fut.await,
409 }?;
410
411 Ok(tls)
412}
413
414pub async fn connect(
416 host: &str,
417 config: &Config,
418 options: Option<Options>,
419) -> io::Result<AsyncTlsStream> {
420 let mut options = options.unwrap_or_else(Options::new);
421
422 if options.servername.is_none() {
424 match host.rfind(':') {
425 None => return Err(io::ErrorKind::InvalidInput.into()),
426 Some(index) => options.servername(&host[0..index]),
427 };
428 };
429
430 let mut last_error = io::ErrorKind::ConnectionRefused.into();
431
432 for addr in host.to_socket_addrs()? {
433 let res = match options.timeout {
435 Some(tm) => match timeout(tm, TcpStream::connect(&addr)).await {
436 Ok(res) => res,
437 Err(err) => Err(err.into()),
438 },
439 None => TcpStream::connect(&addr).await,
440 };
441
442 match res {
444 Ok(tcp) => {
445 return connect_stream(tcp, config, Some(options)).await;
446 }
447 Err(err) => last_error = err,
448 }
449 }
450
451 Err(last_error)
452}
453
/// Optional settings for `accept`, `accept_stream`, `connect_stream` and `connect`.
///
/// Create with [`Options::new`], chain the setters, and finish with
/// [`Options::build`]. `build` yields `None` when nothing was set.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Options {
    // Bounds the TLS handshake; `connect` also applies it to each TCP connect attempt.
    timeout: Option<Duration>,
    // Server name handed to libtls by `connect_stream` (via `connect_raw_fd`).
    servername: Option<String>,
}

/// Former name of [`Options`].
#[deprecated(
    since = "1.1.1",
    note = "Please use `Options` instead of `AsyncTlsOptions`"
)]
pub type AsyncTlsOptions = Options;

impl Options {
    /// Returns an empty option set: no timeout and no server name.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the timeout and returns `self` for chaining.
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the server name (copied into an owned `String`) and returns `self`
    /// for chaining.
    pub fn servername(&mut self, servername: &str) -> &mut Self {
        self.servername = Some(String::from(servername));
        self
    }

    /// Finishes a builder chain: `None` if every field is still at its default,
    /// otherwise `Some` with a copy of the current settings.
    pub fn build(&mut self) -> Option<Self> {
        if *self == Self::default() {
            None
        } else {
            Some(self.clone())
        }
    }
}
503
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use std::{io, time::Duration};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Sends an HTTP/1.1 GET to `servername:443` over TLS and checks that the
    /// response starts with a `200 OK` status line. Requires network access.
    async fn async_https_connect(servername: String) -> io::Result<()> {
        let request = format!(
            "GET / HTTP/1.1\r\n\
             Host: {}\r\n\
             Connection: close\r\n\r\n",
            servername
        );

        let config = Builder::new().build()?;
        let options = Options::new()
            .servername(&servername)
            .timeout(Duration::from_secs(60))
            .build();
        let mut tls = connect(&(servername + ":443"), &config, options).await?;
        tls.write_all(request.as_bytes()).await?;

        // Only the status line is asserted, so read exactly that many bytes.
        // A fixed 1024-byte `read_exact` would fail with `UnexpectedEof` whenever
        // the whole response is shorter than that.
        let ok = b"HTTP/1.1 200 OK\r\n";
        let mut buf = vec![0u8; ok.len()];
        tls.read_exact(&mut buf).await?;
        assert_eq!(&buf[..], &ok[..]);

        Ok(())
    }

    #[tokio::test]
    async fn test_async_https_connect() {
        async_https_connect("www.example.com".to_owned())
            .await
            .unwrap();
    }
}