1#![warn(missing_docs)]
7
8use futures_util::future;
9use openssl::error::ErrorStack;
10use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef};
11use std::fmt;
12use std::io::{self, Read, Write};
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17#[cfg(test)]
18mod test;
19
/// Adapter that lets OpenSSL's blocking `Read`/`Write` calls drive a Tokio stream.
///
/// OpenSSL only speaks `std::io`. While an `SslStream` operation is running,
/// the current task `Context` is parked in `context` as a raw address; it is
/// reset to `0` afterwards. The `Read`/`Write` impls recover it to poll `stream`.
struct StreamWrapper<S> {
    /// The wrapped asynchronous transport.
    stream: S,
    /// Address of the active `Context`, or `0` when no poll is in progress.
    context: usize,
}
24
25impl<S> fmt::Debug for StreamWrapper<S>
26where
27 S: fmt::Debug,
28{
29 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
30 fmt::Debug::fmt(&self.stream, fmt)
31 }
32}
33
34impl<S> StreamWrapper<S> {
35 unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
40 debug_assert_ne!(self.context, 0);
41 let stream = Pin::new_unchecked(&mut self.stream);
42 let context = &mut *(self.context as *mut _);
43 (stream, context)
44 }
45}
46
47impl<S> Read for StreamWrapper<S>
48where
49 S: AsyncRead,
50{
51 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52 let (stream, cx) = unsafe { self.parts() };
53 let mut buf = ReadBuf::new(buf);
54 match stream.poll_read(cx, &mut buf)? {
55 Poll::Ready(()) => Ok(buf.filled().len()),
56 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
57 }
58 }
59}
60
61impl<S> Write for StreamWrapper<S>
62where
63 S: AsyncWrite,
64{
65 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
66 let (stream, cx) = unsafe { self.parts() };
67 match stream.poll_write(cx, buf) {
68 Poll::Ready(r) => r,
69 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
70 }
71 }
72
73 fn flush(&mut self) -> io::Result<()> {
74 let (stream, cx) = unsafe { self.parts() };
75 match stream.poll_flush(cx) {
76 Poll::Ready(r) => r,
77 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
78 }
79 }
80}
81
/// Maps a blocking-style I/O result onto `Poll`.
///
/// `io::ErrorKind::WouldBlock` becomes `Poll::Pending`. Every other outcome,
/// success or any other error, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        ready => Poll::Ready(ready),
    }
}
89
90fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
91 match r {
92 Ok(v) => Poll::Ready(Ok(v)),
93 Err(e) => match e.code() {
94 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
95 _ => Poll::Ready(Err(e)),
96 },
97 }
98}
99
100#[derive(Debug)]
102pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
103
104impl<S> SslStream<S>
105where
106 S: AsyncRead + AsyncWrite,
107{
108 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
110 ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
111 }
112
113 pub fn poll_connect(
115 self: Pin<&mut Self>,
116 cx: &mut Context<'_>,
117 ) -> Poll<Result<(), ssl::Error>> {
118 self.with_context(cx, |s| cvt_ossl(s.connect()))
119 }
120
121 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
123 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
124 }
125
126 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
128 self.with_context(cx, |s| cvt_ossl(s.accept()))
129 }
130
131 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
133 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
134 }
135
136 pub fn poll_do_handshake(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Result<(), ssl::Error>> {
141 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
142 }
143
144 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
146 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
147 }
148
149 #[cfg(ossl111)]
151 pub fn poll_read_early_data(
152 self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 buf: &mut [u8],
155 ) -> Poll<Result<usize, ssl::Error>> {
156 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
157 }
158
159 #[cfg(ossl111)]
161 pub async fn read_early_data(
162 mut self: Pin<&mut Self>,
163 buf: &mut [u8],
164 ) -> Result<usize, ssl::Error> {
165 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
166 }
167
168 #[cfg(ossl111)]
170 pub fn poll_write_early_data(
171 self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 buf: &[u8],
174 ) -> Poll<Result<usize, ssl::Error>> {
175 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
176 }
177
178 #[cfg(ossl111)]
180 pub async fn write_early_data(
181 mut self: Pin<&mut Self>,
182 buf: &[u8],
183 ) -> Result<usize, ssl::Error> {
184 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
185 }
186}
187
188impl<S> SslStream<S> {
189 pub fn ssl(&self) -> &SslRef {
191 self.0.ssl()
192 }
193
194 pub fn get_ref(&self) -> &S {
196 &self.0.get_ref().stream
197 }
198
199 pub fn get_mut(&mut self) -> &mut S {
201 &mut self.0.get_mut().stream
202 }
203
204 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
206 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
207 }
208
209 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
210 where
211 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
212 {
213 let this = unsafe { self.get_unchecked_mut() };
214 this.0.get_mut().context = ctx as *mut _ as usize;
215 let r = f(&mut this.0);
216 this.0.get_mut().context = 0;
217 r
218 }
219}
220
221impl<S> AsyncRead for SslStream<S>
222where
223 S: AsyncRead + AsyncWrite,
224{
225 fn poll_read(
226 self: Pin<&mut Self>,
227 ctx: &mut Context<'_>,
228 buf: &mut ReadBuf<'_>,
229 ) -> Poll<io::Result<()>> {
230 self.with_context(ctx, |s| {
231 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
233 Poll::Ready(nread) => {
234 unsafe { buf.assume_init(nread) };
236 buf.advance(nread);
237 Poll::Ready(Ok(()))
238 }
239 Poll::Pending => Poll::Pending,
240 }
241 })
242 }
243}
244
245impl<S> AsyncWrite for SslStream<S>
246where
247 S: AsyncRead + AsyncWrite,
248{
249 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
250 self.with_context(ctx, |s| cvt(s.write(buf)))
251 }
252
253 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
254 self.with_context(ctx, |s| cvt(s.flush()))
255 }
256
257 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
258 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
259 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
260 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
261 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
262 return Poll::Pending;
263 }
264 Err(e) => {
265 return Poll::Ready(Err(e
266 .into_io_error()
267 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
268 }
269 }
270
271 self.get_pin_mut().poll_shutdown(ctx)
272 }
273}