1#![warn(missing_docs)]
14#![cfg_attr(docsrs, feature(doc_auto_cfg))]
15
16use boring::ssl::{
17 self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
18 SslRef,
19};
20use boring_sys as ffi;
21use std::error::Error;
22use std::fmt;
23use std::future::Future;
24use std::io::{self, Write};
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
28
29mod async_callbacks;
30mod bridge;
31
32use self::bridge::AsyncStreamBridge;
33
34pub use crate::async_callbacks::SslContextBuilderExt;
35pub use boring::ssl::{
36 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
37 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
38 BoxSelectCertFuture, ExDataFuture,
39};
40
41pub async fn connect<S>(
46 config: ConnectConfiguration,
47 domain: &str,
48 stream: S,
49) -> Result<SslStream<S>, HandshakeError<S>>
50where
51 S: AsyncRead + AsyncWrite + Unpin,
52{
53 let mid_handshake = config
54 .setup_connect(domain, AsyncStreamBridge::new(stream))
55 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
56
57 HandshakeFuture(Some(mid_handshake)).await
58}
59
60pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
65where
66 S: AsyncRead + AsyncWrite + Unpin,
67{
68 let mid_handshake = acceptor
69 .setup_accept(AsyncStreamBridge::new(stream))
70 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
71
72 HandshakeFuture(Some(mid_handshake)).await
73}
74
/// Translates a blocking-style I/O result into a [`Poll`].
///
/// An error of kind [`io::ErrorKind::WouldBlock`] becomes [`Poll::Pending`];
/// every other result, success or failure, is passed through as
/// [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }

    Poll::Ready(r)
}
82
83pub struct SslStreamBuilder<S> {
85 inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
86}
87
88impl<S> SslStreamBuilder<S>
89where
90 S: AsyncRead + AsyncWrite + Unpin,
91{
92 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
94 Self {
95 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
96 }
97 }
98
99 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
101 let mid_handshake = self.inner.setup_accept();
102
103 HandshakeFuture(Some(mid_handshake)).await
104 }
105
106 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
108 let mid_handshake = self.inner.setup_connect();
109
110 HandshakeFuture(Some(mid_handshake)).await
111 }
112}
113
114impl<S> SslStreamBuilder<S> {
115 pub fn ssl(&self) -> &SslRef {
117 self.inner.ssl()
118 }
119
120 pub fn ssl_mut(&mut self) -> &mut SslRef {
122 self.inner.ssl_mut()
123 }
124}
125
126#[derive(Debug)]
134pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
135
136impl<S> SslStream<S> {
137 pub fn ssl(&self) -> &SslRef {
139 self.0.ssl()
140 }
141
142 pub fn ssl_mut(&mut self) -> &mut SslRef {
144 self.0.ssl_mut()
145 }
146
147 pub fn get_ref(&self) -> &S {
149 &self.0.get_ref().stream
150 }
151
152 pub fn get_mut(&mut self) -> &mut S {
154 &mut self.0.get_mut().stream
155 }
156
157 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
158 where
159 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
160 {
161 self.0.get_mut().set_waker(Some(ctx));
162
163 let result = f(&mut self.0);
164
165 self.0.get_mut().set_waker(None);
170
171 result
172 }
173}
174
175impl<S> SslStream<S>
176where
177 S: AsyncRead + AsyncWrite + Unpin,
178{
179 pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
187 Self(ssl::SslStream::from_raw_parts(
188 ssl,
189 AsyncStreamBridge::new(stream),
190 ))
191 }
192}
193
194impl<S> AsyncRead for SslStream<S>
195where
196 S: AsyncRead + AsyncWrite + Unpin,
197{
198 fn poll_read(
199 mut self: Pin<&mut Self>,
200 ctx: &mut Context<'_>,
201 buf: &mut ReadBuf,
202 ) -> Poll<io::Result<()>> {
203 self.run_in_context(ctx, |s| {
204 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
206 Poll::Ready(nread) => {
207 unsafe {
208 buf.assume_init(nread);
209 }
210 buf.advance(nread);
211 Poll::Ready(Ok(()))
212 }
213 Poll::Pending => Poll::Pending,
214 }
215 })
216 }
217}
218
219impl<S> AsyncWrite for SslStream<S>
220where
221 S: AsyncRead + AsyncWrite + Unpin,
222{
223 fn poll_write(
224 mut self: Pin<&mut Self>,
225 ctx: &mut Context,
226 buf: &[u8],
227 ) -> Poll<io::Result<usize>> {
228 self.run_in_context(ctx, |s| cvt(s.write(buf)))
229 }
230
231 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
232 self.run_in_context(ctx, |s| cvt(s.flush()))
233 }
234
235 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
236 match self.run_in_context(ctx, |s| s.shutdown()) {
237 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
238 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
239 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
240 return Poll::Pending;
241 }
242 Err(e) => {
243 return Poll::Ready(Err(e
244 .into_io_error()
245 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
246 }
247 }
248
249 Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
250 }
251}
252
253pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
255
256impl<S> HandshakeError<S> {
257 pub fn ssl(&self) -> Option<&SslRef> {
259 match &self.0 {
260 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
261 _ => None,
262 }
263 }
264
265 pub fn into_source_stream(self) -> Option<S> {
267 match self.0 {
268 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
269 _ => None,
270 }
271 }
272
273 pub fn as_source_stream(&self) -> Option<&S> {
275 match &self.0 {
276 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
277 _ => None,
278 }
279 }
280
281 pub fn code(&self) -> Option<ErrorCode> {
283 match &self.0 {
284 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
285 _ => None,
286 }
287 }
288
289 pub fn as_io_error(&self) -> Option<&io::Error> {
291 match &self.0 {
292 ssl::HandshakeError::Failure(s) => s.error().io_error(),
293 _ => None,
294 }
295 }
296}
297
298impl<S> fmt::Debug for HandshakeError<S>
299where
300 S: fmt::Debug,
301{
302 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
303 fmt::Debug::fmt(&self.0, fmt)
304 }
305}
306
307impl<S> fmt::Display for HandshakeError<S> {
308 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
309 fmt::Display::fmt(&self.0, fmt)
310 }
311}
312
313impl<S> Error for HandshakeError<S>
314where
315 S: fmt::Debug,
316{
317 fn source(&self) -> Option<&(dyn Error + 'static)> {
318 self.0.source()
319 }
320}
321
322pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
326
327impl<S> Future for HandshakeFuture<S>
328where
329 S: AsyncRead + AsyncWrite + Unpin,
330{
331 type Output = Result<SslStream<S>, HandshakeError<S>>;
332
333 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
334 let mut mid_handshake = self.0.take().expect("future polled after completion");
335
336 mid_handshake.get_mut().set_waker(Some(ctx));
337 mid_handshake
338 .ssl_mut()
339 .set_task_waker(Some(ctx.waker().clone()));
340
341 match mid_handshake.handshake() {
342 Ok(mut stream) => {
343 stream.get_mut().set_waker(None);
344 stream.ssl_mut().set_task_waker(None);
345
346 Poll::Ready(Ok(SslStream(stream)))
347 }
348 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
349 mid_handshake.get_mut().set_waker(None);
350 mid_handshake.ssl_mut().set_task_waker(None);
351
352 self.0 = Some(mid_handshake);
353
354 Poll::Pending
355 }
356 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
357 mid_handshake.get_mut().set_waker(None);
358
359 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
360 mid_handshake,
361 ))))
362 }
363 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
364 Poll::Ready(Err(HandshakeError(err)))
365 }
366 }
367 }
368}