1#![warn(missing_docs)]
14#![cfg_attr(docsrs, feature(doc_auto_cfg))]
15
16use boring::ssl::{
17 self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
18 SslRef,
19};
20use boring_sys as ffi;
21use std::error::Error;
22use std::fmt;
23use std::future::Future;
24use std::io::{self, Write};
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
28
29mod async_callbacks;
30mod bridge;
31
32use self::bridge::AsyncStreamBridge;
33
34pub use crate::async_callbacks::SslContextBuilderExt;
35pub use boring::ssl::{
36 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
37 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
38 BoxSelectCertFuture, ExDataFuture,
39};
40
41pub async fn connect<S>(
46 config: ConnectConfiguration,
47 domain: &str,
48 stream: S,
49) -> Result<SslStream<S>, HandshakeError<S>>
50where
51 S: AsyncRead + AsyncWrite + Unpin,
52{
53 let mid_handshake = config
54 .setup_connect(domain, AsyncStreamBridge::new(stream))
55 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
56
57 HandshakeFuture(Some(mid_handshake)).await
58}
59
60pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
65where
66 S: AsyncRead + AsyncWrite + Unpin,
67{
68 let mid_handshake = acceptor
69 .setup_accept(AsyncStreamBridge::new(stream))
70 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
71
72 HandshakeFuture(Some(mid_handshake)).await
73}
74
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An error of kind [`io::ErrorKind::WouldBlock`] becomes `Poll::Pending`;
/// any other result (success or error) is returned as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
82
/// A builder for an asynchronous TLS stream.
///
/// It allows the [`SslRef`] to be inspected or configured (via `ssl` /
/// `ssl_mut`) before the handshake is started with `accept` or `connect`.
pub struct SslStreamBuilder<S> {
    // The synchronous builder, operating on the async-to-sync I/O bridge.
    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
}
87
88impl<S> SslStreamBuilder<S>
89where
90 S: AsyncRead + AsyncWrite + Unpin,
91{
92 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
94 Self {
95 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
96 }
97 }
98
99 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
101 let mid_handshake = self.inner.setup_accept();
102
103 HandshakeFuture(Some(mid_handshake)).await
104 }
105
106 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
108 let mid_handshake = self.inner.setup_connect();
109
110 HandshakeFuture(Some(mid_handshake)).await
111 }
112}
113
114impl<S> SslStreamBuilder<S> {
115 #[must_use]
117 pub fn ssl(&self) -> &SslRef {
118 self.inner.ssl()
119 }
120
121 pub fn ssl_mut(&mut self) -> &mut SslRef {
123 self.inner.ssl_mut()
124 }
125}
126
/// An asynchronous TLS stream over an underlying stream `S`.
///
/// It wraps a synchronous `ssl::SslStream` whose I/O is routed to `S` through
/// an `AsyncStreamBridge`; it implements tokio's `AsyncRead` and `AsyncWrite`.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
136
137impl<S> SslStream<S> {
138 #[must_use]
140 pub fn ssl(&self) -> &SslRef {
141 self.0.ssl()
142 }
143
144 pub fn ssl_mut(&mut self) -> &mut SslRef {
146 self.0.ssl_mut()
147 }
148
149 #[must_use]
151 pub fn get_ref(&self) -> &S {
152 &self.0.get_ref().stream
153 }
154
155 pub fn get_mut(&mut self) -> &mut S {
157 &mut self.0.get_mut().stream
158 }
159
160 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
161 where
162 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
163 {
164 self.0.get_mut().set_waker(Some(ctx));
165
166 let result = f(&mut self.0);
167
168 self.0.get_mut().set_waker(None);
173
174 result
175 }
176}
177
178impl<S> SslStream<S>
179where
180 S: AsyncRead + AsyncWrite + Unpin,
181{
182 pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
190 Self(ssl::SslStream::from_raw_parts(
191 ssl,
192 AsyncStreamBridge::new(stream),
193 ))
194 }
195}
196
197impl<S> AsyncRead for SslStream<S>
198where
199 S: AsyncRead + AsyncWrite + Unpin,
200{
201 fn poll_read(
202 mut self: Pin<&mut Self>,
203 ctx: &mut Context<'_>,
204 buf: &mut ReadBuf,
205 ) -> Poll<io::Result<()>> {
206 self.run_in_context(ctx, |s| {
207 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
209 Poll::Ready(nread) => {
210 unsafe {
211 buf.assume_init(nread);
212 }
213 buf.advance(nread);
214 Poll::Ready(Ok(()))
215 }
216 Poll::Pending => Poll::Pending,
217 }
218 })
219 }
220}
221
222impl<S> AsyncWrite for SslStream<S>
223where
224 S: AsyncRead + AsyncWrite + Unpin,
225{
226 fn poll_write(
227 mut self: Pin<&mut Self>,
228 ctx: &mut Context,
229 buf: &[u8],
230 ) -> Poll<io::Result<usize>> {
231 self.run_in_context(ctx, |s| cvt(s.write(buf)))
232 }
233
234 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
235 self.run_in_context(ctx, |s| cvt(s.flush()))
236 }
237
238 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
239 match self.run_in_context(ctx, |s| s.shutdown()) {
240 Ok(ShutdownResult::Sent | ShutdownResult::Received) => {}
241 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
242 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
243 return Poll::Pending;
244 }
245 Err(e) => {
246 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
247 }
248 }
249
250 Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
251 }
252}
253
/// The error returned by the asynchronous handshake entry points
/// (`connect`, `accept`, and the `SslStreamBuilder` methods).
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
256
257impl<S> HandshakeError<S> {
258 #[must_use]
260 pub fn ssl(&self) -> Option<&SslRef> {
261 match &self.0 {
262 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
263 _ => None,
264 }
265 }
266
267 #[must_use]
269 pub fn into_source_stream(self) -> Option<S> {
270 match self.0 {
271 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
272 _ => None,
273 }
274 }
275
276 #[must_use]
278 pub fn as_source_stream(&self) -> Option<&S> {
279 match &self.0 {
280 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
281 _ => None,
282 }
283 }
284
285 #[must_use]
287 pub fn code(&self) -> Option<ErrorCode> {
288 match &self.0 {
289 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
290 _ => None,
291 }
292 }
293
294 #[must_use]
296 pub fn as_io_error(&self) -> Option<&io::Error> {
297 match &self.0 {
298 ssl::HandshakeError::Failure(s) => s.error().io_error(),
299 _ => None,
300 }
301 }
302}
303
304impl<S> fmt::Debug for HandshakeError<S>
305where
306 S: fmt::Debug,
307{
308 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
309 fmt::Debug::fmt(&self.0, fmt)
310 }
311}
312
313impl<S> fmt::Display for HandshakeError<S> {
314 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
315 fmt::Display::fmt(&self.0, fmt)
316 }
317}
318
319impl<S> Error for HandshakeError<S>
320where
321 S: fmt::Debug,
322{
323 fn source(&self) -> Option<&(dyn Error + 'static)> {
324 self.0.source()
325 }
326}
327
/// A future that drives a TLS handshake to completion.
///
/// It resolves to the established [`SslStream`] or to a [`HandshakeError`].
/// The inner `Option` holds the in-progress handshake between polls and is
/// left `None` once the future has produced its output.
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
332
333impl<S> Future for HandshakeFuture<S>
334where
335 S: AsyncRead + AsyncWrite + Unpin,
336{
337 type Output = Result<SslStream<S>, HandshakeError<S>>;
338
339 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
340 let mut mid_handshake = self.0.take().expect("future polled after completion");
341
342 mid_handshake.get_mut().set_waker(Some(ctx));
343 mid_handshake
344 .ssl_mut()
345 .set_task_waker(Some(ctx.waker().clone()));
346
347 match mid_handshake.handshake() {
348 Ok(mut stream) => {
349 stream.get_mut().set_waker(None);
350 stream.ssl_mut().set_task_waker(None);
351
352 Poll::Ready(Ok(SslStream(stream)))
353 }
354 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
355 mid_handshake.get_mut().set_waker(None);
356 mid_handshake.ssl_mut().set_task_waker(None);
357
358 self.0 = Some(mid_handshake);
359
360 Poll::Pending
361 }
362 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
363 mid_handshake.get_mut().set_waker(None);
364
365 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
366 mid_handshake,
367 ))))
368 }
369 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
370 Poll::Ready(Err(HandshakeError(err)))
371 }
372 }
373 }
374}