1#![warn(missing_docs)]
14
15use boring::ssl::{
16 self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
17 SslRef,
18};
19use boring_sys as ffi;
20use std::error::Error;
21use std::fmt;
22use std::future::Future;
23use std::io::{self, Write};
24use std::pin::Pin;
25use std::task::{Context, Poll};
26use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
27
28mod async_callbacks;
29mod bridge;
30
31use self::bridge::AsyncStreamBridge;
32
33pub use crate::async_callbacks::SslContextBuilderExt;
34pub use boring::ssl::{
35 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
36 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
37 BoxSelectCertFuture, ExDataFuture,
38};
39
40pub async fn connect<S>(
45 config: ConnectConfiguration,
46 domain: &str,
47 stream: S,
48) -> Result<SslStream<S>, HandshakeError<S>>
49where
50 S: AsyncRead + AsyncWrite + Unpin,
51{
52 let mid_handshake = config
53 .setup_connect(domain, AsyncStreamBridge::new(stream))
54 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
55
56 HandshakeFuture(Some(mid_handshake)).await
57}
58
59pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
64where
65 S: AsyncRead + AsyncWrite + Unpin,
66{
67 let mid_handshake = acceptor
68 .setup_accept(AsyncStreamBridge::new(stream))
69 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
70
71 HandshakeFuture(Some(mid_handshake)).await
72}
73
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An error whose kind is [`io::ErrorKind::WouldBlock`] becomes [`Poll::Pending`]; every
/// other outcome, success or failure, is passed through unchanged inside [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }

    Poll::Ready(r)
}
81
/// Builder for an [`SslStream`] running over an asynchronous stream `S`.
///
/// Wraps `boring`'s own `SslStreamBuilder`, with the stream adapted through the crate's
/// `AsyncStreamBridge`.
pub struct SslStreamBuilder<S> {
    // The underlying `boring` builder operating on the bridged stream.
    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
}
86
87impl<S> SslStreamBuilder<S>
88where
89 S: AsyncRead + AsyncWrite + Unpin,
90{
91 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
93 Self {
94 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
95 }
96 }
97
98 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
100 let mid_handshake = self.inner.setup_accept();
101
102 HandshakeFuture(Some(mid_handshake)).await
103 }
104
105 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
107 let mid_handshake = self.inner.setup_connect();
108
109 HandshakeFuture(Some(mid_handshake)).await
110 }
111}
112
113impl<S> SslStreamBuilder<S> {
114 #[must_use]
116 pub fn ssl(&self) -> &SslRef {
117 self.inner.ssl()
118 }
119
120 pub fn ssl_mut(&mut self) -> &mut SslRef {
122 self.inner.ssl_mut()
123 }
124}
125
/// An asynchronous TLS stream layered over an inner stream `S`.
///
/// Wraps `boring`'s `SslStream`, with the inner stream adapted through the crate's
/// `AsyncStreamBridge`. Values are produced by a completed handshake or by
/// `SslStream::from_raw_parts`.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
135
136impl<S> SslStream<S> {
137 #[must_use]
139 pub fn ssl(&self) -> &SslRef {
140 self.0.ssl()
141 }
142
143 pub fn ssl_mut(&mut self) -> &mut SslRef {
145 self.0.ssl_mut()
146 }
147
148 #[must_use]
150 pub fn get_ref(&self) -> &S {
151 &self.0.get_ref().stream
152 }
153
154 pub fn get_mut(&mut self) -> &mut S {
156 &mut self.0.get_mut().stream
157 }
158
159 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
160 where
161 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
162 {
163 self.0.get_mut().set_waker(Some(ctx));
164
165 let result = f(&mut self.0);
166
167 self.0.get_mut().set_waker(None);
172
173 result
174 }
175}
176
177impl<S> SslStream<S>
178where
179 S: AsyncRead + AsyncWrite + Unpin,
180{
181 pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
189 Self(ssl::SslStream::from_raw_parts(
190 ssl,
191 AsyncStreamBridge::new(stream),
192 ))
193 }
194}
195
196impl<S> AsyncRead for SslStream<S>
197where
198 S: AsyncRead + AsyncWrite + Unpin,
199{
200 fn poll_read(
201 mut self: Pin<&mut Self>,
202 ctx: &mut Context<'_>,
203 buf: &mut ReadBuf,
204 ) -> Poll<io::Result<()>> {
205 self.run_in_context(ctx, |s| {
206 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
208 Poll::Ready(nread) => {
209 unsafe {
210 buf.assume_init(nread);
211 }
212 buf.advance(nread);
213 Poll::Ready(Ok(()))
214 }
215 Poll::Pending => Poll::Pending,
216 }
217 })
218 }
219}
220
221impl<S> AsyncWrite for SslStream<S>
222where
223 S: AsyncRead + AsyncWrite + Unpin,
224{
225 fn poll_write(
226 mut self: Pin<&mut Self>,
227 ctx: &mut Context,
228 buf: &[u8],
229 ) -> Poll<io::Result<usize>> {
230 self.run_in_context(ctx, |s| cvt(s.write(buf)))
231 }
232
233 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
234 self.run_in_context(ctx, |s| cvt(s.flush()))
235 }
236
237 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
238 match self.run_in_context(ctx, |s| s.shutdown()) {
239 Ok(ShutdownResult::Sent | ShutdownResult::Received) => {}
240 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
241 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
242 return Poll::Pending;
243 }
244 Err(e) => {
245 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
246 }
247 }
248
249 Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
250 }
251}
252
/// The error type returned when setting up or performing a TLS handshake fails.
///
/// Wraps `boring`'s `HandshakeError` over the crate's bridged stream type.
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
255
256impl<S> HandshakeError<S> {
257 #[must_use]
259 pub fn ssl(&self) -> Option<&SslRef> {
260 match &self.0 {
261 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
262 _ => None,
263 }
264 }
265
266 #[must_use]
268 pub fn into_source_stream(self) -> Option<S> {
269 match self.0 {
270 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
271 _ => None,
272 }
273 }
274
275 #[must_use]
277 pub fn as_source_stream(&self) -> Option<&S> {
278 match &self.0 {
279 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
280 _ => None,
281 }
282 }
283
284 #[must_use]
286 pub fn code(&self) -> Option<ErrorCode> {
287 match &self.0 {
288 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
289 _ => None,
290 }
291 }
292
293 #[must_use]
295 pub fn as_io_error(&self) -> Option<&io::Error> {
296 match &self.0 {
297 ssl::HandshakeError::Failure(s) => s.error().io_error(),
298 _ => None,
299 }
300 }
301}
302
303impl<S> fmt::Debug for HandshakeError<S>
304where
305 S: fmt::Debug,
306{
307 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
308 fmt::Debug::fmt(&self.0, fmt)
309 }
310}
311
312impl<S> fmt::Display for HandshakeError<S> {
313 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
314 fmt::Display::fmt(&self.0, fmt)
315 }
316}
317
318impl<S> Error for HandshakeError<S>
319where
320 S: fmt::Debug,
321{
322 fn source(&self) -> Option<&(dyn Error + 'static)> {
323 self.0.source()
324 }
325}
326
/// Future for an ongoing TLS handshake.
///
/// Resolves to an [`SslStream`] on success or a [`HandshakeError`] on failure. The inner
/// `Option` holds the mid-handshake state between polls; polling takes it out, so it is
/// `None` once the future has completed.
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
331
332impl<S> Future for HandshakeFuture<S>
333where
334 S: AsyncRead + AsyncWrite + Unpin,
335{
336 type Output = Result<SslStream<S>, HandshakeError<S>>;
337
338 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
339 let mut mid_handshake = self.0.take().expect("future polled after completion");
340
341 mid_handshake.get_mut().set_waker(Some(ctx));
342 mid_handshake
343 .ssl_mut()
344 .set_task_waker(Some(ctx.waker().clone()));
345
346 match mid_handshake.handshake() {
347 Ok(mut stream) => {
348 stream.get_mut().set_waker(None);
349 stream.ssl_mut().set_task_waker(None);
350
351 Poll::Ready(Ok(SslStream(stream)))
352 }
353 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
354 mid_handshake.get_mut().set_waker(None);
355 mid_handshake.ssl_mut().set_task_waker(None);
356
357 self.0 = Some(mid_handshake);
358
359 Poll::Pending
360 }
361 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
362 mid_handshake.get_mut().set_waker(None);
363
364 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
365 mid_handshake,
366 ))))
367 }
368 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
369 Poll::Ready(Err(HandshakeError(err)))
370 }
371 }
372 }
373}