1use std::fmt;
6use std::io;
7use std::io::Read;
8use std::io::Write;
9use std::marker::PhantomData;
10use std::pin::Pin;
11use std::task::Context;
12use std::task::Poll;
13
14use crate::runtime::AsyncRead;
15use crate::runtime::AsyncWrite;
16use crate::spi::restore_context;
17use crate::spi::save_context;
18use crate::spi::TlsStreamWithUpcastDyn;
19use crate::AsyncSocket;
20use crate::ImplInfo;
21use crate::TlsStreamDyn;
22use crate::TlsStreamWithSocketDyn;
23
/// Adapter exposing an async I/O object through the blocking `std::io::Read` /
/// `std::io::Write` traits.
///
/// The sync methods poll the wrapped object with the task context obtained via
/// `restore_context`; a `Poll::Pending` from the inner object is reported to the
/// sync caller as an `io::ErrorKind::WouldBlock` error.
#[derive(Debug)]
pub struct AsyncIoAsSyncIo<S: Unpin> {
    inner: S,
}

// No manual `unsafe impl Send` is needed: the only field is `inner: S`, so the
// compiler's auto-trait impl already makes `AsyncIoAsSyncIo<S>: Send` exactly
// when `S: Send`.

impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Returns a mutable reference to the wrapped async object.
    pub fn get_inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Returns a shared reference to the wrapped async object.
    pub fn get_inner_ref(&self) -> &S {
        &self.inner
    }

    /// Wraps `inner` so it can be used through the sync `Read`/`Write` traits.
    pub fn new(inner: S) -> AsyncIoAsSyncIo<S> {
        AsyncIoAsSyncIo { inner }
    }

    /// Pins the wrapped object for calling its `poll_*` methods.
    ///
    /// `Pin::new` is sufficient (no `unsafe`) because `S: Unpin`.
    fn get_inner_pin(&mut self) -> Pin<&mut S> {
        Pin::new(&mut self.inner)
    }
}
54
55impl<S: AsyncRead + Unpin> Read for AsyncIoAsSyncIo<S> {
56 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
57 restore_context_poll_to_result(|cx| {
58 #[cfg(feature = "runtime-tokio")]
59 {
60 let mut read_buf = tokio::io::ReadBuf::new(buf);
61 let p = self.get_inner_pin().poll_read(cx, &mut read_buf);
62 p.map_ok(|()| read_buf.filled().len())
63 }
64 #[cfg(feature = "runtime-async-std")]
65 {
66 self.get_inner_pin().poll_read(cx, buf)
67 }
68 })
69 }
70}
71
72impl<S: AsyncWrite + Unpin> Write for AsyncIoAsSyncIo<S> {
73 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
74 restore_context_poll_to_result(|cx| self.get_inner_pin().poll_write(cx, buf))
75 }
76
77 fn flush(&mut self) -> io::Result<()> {
78 restore_context_poll_to_result(|cx| self.get_inner_pin().poll_flush(cx))
79 }
80}
81
/// Converts the result of a sync operation into an async `Poll`.
///
/// An `io::ErrorKind::WouldBlock` error means "not ready yet" and becomes
/// `Poll::Pending`; every other result (success or error) is passed through as
/// `Poll::Ready`.
fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        ready => Poll::Ready(ready),
    }
}
90
91#[derive(Debug, thiserror::Error)]
92#[error("should not return WouldBlock from async API: {}", _0)]
93struct ShouldNotReturnWouldBlockFromAsync(io::Error);
94
95fn poll_to_result<T>(r: Poll<io::Result<T>>) -> io::Result<T> {
97 match r {
98 Poll::Ready(Ok(r)) => Ok(r),
99 Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => Err(io::Error::new(
100 io::ErrorKind::Other,
101 ShouldNotReturnWouldBlockFromAsync(e),
102 )),
103 Poll::Ready(Err(e)) => Err(e),
104 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
105 }
106}
107
108fn restore_context_poll_to_result<R>(
109 f: impl FnOnce(&mut Context<'_>) -> Poll<io::Result<R>>,
110) -> io::Result<R> {
111 restore_context(|cx| poll_to_result(f(cx)))
112}
113
114pub trait AsyncWrapperOps<A>: fmt::Debug + Unpin + Send + 'static
116where
117 A: Unpin,
118{
119 type SyncWrapper: Read + Write + WriteShutdown + Unpin + Send + 'static;
123
124 fn impl_info() -> ImplInfo;
126
127 fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug;
130
131 fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo<A>;
133 fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo<A>;
135
136 fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result<Option<Vec<u8>>>;
138}
139
/// A `Write` that can additionally be shut down.
pub trait WriteShutdown: Write {
    /// Shuts the stream down.
    ///
    /// The default implementation only flushes and returns the flush result;
    /// implementors may override it to do more.
    fn shutdown(&mut self) -> Result<(), io::Error> {
        self.flush()
    }
}
180
181pub struct TlsStreamOverSyncIo<A, O>
183where
184 A: Unpin,
185 O: AsyncWrapperOps<A>,
186{
187 pub stream: O::SyncWrapper,
189 _phantom: PhantomData<(A, O)>,
190}
191
192impl<A, O> fmt::Debug for TlsStreamOverSyncIo<A, O>
193where
194 A: Unpin,
195 O: AsyncWrapperOps<A>,
196{
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 f.debug_tuple("TlsStreamOverSyncIo")
199 .field(O::debug(&self.stream))
200 .finish()
201 }
202}
203
204impl<A, O> TlsStreamOverSyncIo<A, O>
205where
206 A: Unpin,
207 O: AsyncWrapperOps<A>,
208{
209 pub fn new(stream: O::SyncWrapper) -> TlsStreamOverSyncIo<A, O> {
211 TlsStreamOverSyncIo {
212 stream,
213 _phantom: PhantomData,
214 }
215 }
216
217 fn with_context_sync_to_async<F, R>(
218 &mut self,
219 cx: &mut Context<'_>,
220 f: F,
221 ) -> Poll<io::Result<R>>
222 where
223 F: FnOnce(&mut Self) -> io::Result<R>,
224 {
225 result_to_poll(save_context(cx, || f(self)))
226 }
227
228 #[cfg(feature = "runtime-tokio")]
229 fn with_context_sync_to_async_tokio<F>(
230 &mut self,
231 cx: &mut Context<'_>,
232 buf: &mut tokio::io::ReadBuf,
233 f: F,
234 ) -> Poll<io::Result<()>>
235 where
236 F: FnOnce(&mut Self, &mut [u8]) -> io::Result<usize>,
237 {
238 self.with_context_sync_to_async(cx, |s| {
239 let unfilled = buf.initialize_unfilled();
240 let read = f(s, unfilled)?;
241 buf.advance(read);
242 Ok(())
243 })
244 }
245}
246
247impl<A, O> AsyncRead for TlsStreamOverSyncIo<A, O>
248where
249 A: Unpin,
250 O: AsyncWrapperOps<A>,
251{
252 #[cfg(feature = "runtime-tokio")]
253 fn poll_read(
254 self: Pin<&mut Self>,
255 cx: &mut Context<'_>,
256 buf: &mut tokio::io::ReadBuf,
257 ) -> Poll<io::Result<()>> {
258 self.get_mut()
259 .with_context_sync_to_async_tokio(cx, buf, |s, buf| {
260 let result = s.stream.read(buf);
261 match result {
262 Ok(r) => Ok(r),
263 Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
264 Ok(0)
266 }
267 Err(e) => Err(e),
268 }
269 })
270 }
271
272 #[cfg(feature = "runtime-async-std")]
273 fn poll_read(
274 self: Pin<&mut Self>,
275 cx: &mut Context<'_>,
276 buf: &mut [u8],
277 ) -> Poll<io::Result<usize>> {
278 self.get_mut().with_context_sync_to_async(cx, |s| {
279 let result = s.stream.read(buf);
280 match result {
281 Ok(r) => Ok(r),
282 Err(e) if e.kind() == io::ErrorKind::ConnectionAborted => {
283 Ok(0)
285 }
286 Err(e) => Err(e),
287 }
288 })
289 }
290}
291
292impl<A, O> AsyncWrite for TlsStreamOverSyncIo<A, O>
293where
294 A: Unpin,
295 O: AsyncWrapperOps<A>,
296{
297 fn poll_write(
298 self: Pin<&mut Self>,
299 cx: &mut Context<'_>,
300 buf: &[u8],
301 ) -> Poll<io::Result<usize>> {
302 self.get_mut()
303 .with_context_sync_to_async(cx, |stream| stream.stream.write(buf))
304 }
305
306 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
307 self.get_mut()
308 .with_context_sync_to_async(cx, |stream| stream.stream.flush())
309 }
310
311 #[cfg(feature = "runtime-tokio")]
312 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
313 self.get_mut()
314 .with_context_sync_to_async(cx, |stream| stream.stream.shutdown())
315 }
316
317 #[cfg(feature = "runtime-async-std")]
318 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
319 self.get_mut()
320 .with_context_sync_to_async(cx, |stream| stream.stream.shutdown())
321 }
322}
323
324impl<A, O> TlsStreamDyn for TlsStreamOverSyncIo<A, O>
325where
326 A: AsyncSocket,
327 O: AsyncWrapperOps<A>,
328{
329 fn impl_info(&self) -> ImplInfo {
330 O::impl_info()
331 }
332
333 fn get_alpn_protocol(&self) -> anyhow::Result<Option<Vec<u8>>> {
334 O::get_alpn_protocol(&self.stream)
335 }
336
337 fn get_socket_dyn_mut(&mut self) -> &mut dyn AsyncSocket {
338 O::get_mut(&mut self.stream).get_inner_mut()
339 }
340
341 fn get_socket_dyn_ref(&self) -> &dyn AsyncSocket {
342 O::get_ref(&self.stream).get_inner_ref()
343 }
344}
345
346impl<A, O> TlsStreamWithSocketDyn<A> for TlsStreamOverSyncIo<A, O>
347where
348 A: AsyncSocket,
349 O: AsyncWrapperOps<A>,
350{
351 fn get_socket_mut(&mut self) -> &mut A {
352 O::get_mut(&mut self.stream).get_inner_mut()
353 }
354
355 fn get_socket_ref(&self) -> &A {
356 O::get_ref(&self.stream).get_inner_ref()
357 }
358}
359
360impl<A, O> TlsStreamWithUpcastDyn<A> for TlsStreamOverSyncIo<A, O>
361where
362 A: AsyncSocket,
363 O: AsyncWrapperOps<A>,
364{
365 fn upcast_box(self: Box<Self>) -> Box<dyn TlsStreamDyn> {
366 self
367 }
368}
369
/// Defines a public TLS stream newtype named `$t` wrapping `TlsStreamOverSyncIo`.
///
/// * `$t` — name of the generated stream type. The struct, its constructor, the
///   `spi_async_socket_impl_delegate!` call and all trait impls use this name.
/// * `$n` — the backend's sync stream type constructor, instantiated as
///   `$n<AsyncIoAsSyncIo<A>>`.
///
/// `TlsStreamOverSyncIo`, `AsyncIoAsSyncIo`, `AsyncWrapperOpsImpl` and
/// `spi_async_socket_impl_delegate!` are referenced unqualified, so they must be
/// in scope at the invocation site. Items of this crate that are reachable by a
/// stable path are referenced through `$crate`, so the caller needs no imports
/// for them and renaming the crate dependency does not break the expansion.
#[macro_export]
macro_rules! spi_tls_stream_over_sync_io_wrapper {
    ( $t:ident, $n:ident ) => {
        /// TLS stream over an async socket, implemented on top of a sync TLS stream.
        #[derive(Debug)]
        pub struct $t<A: $crate::AsyncSocket>(
            pub(crate) TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
        );

        impl<A: $crate::AsyncSocket> $t<A> {
            // Previously this declared a hard-coded `TlsStream`, which broke every
            // invocation with `$t != TlsStream`; `$t` is now used throughout.
            pub(crate) fn new(stream: $n<AsyncIoAsSyncIo<A>>) -> $t<A> {
                Self(TlsStreamOverSyncIo::new(stream))
            }

            // Accessors consumed by the `spi_async_socket_impl_delegate!` expansion.
            fn deref_pin_mut_for_impl_socket(
                self: std::pin::Pin<&mut Self>,
            ) -> std::pin::Pin<
                &mut TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>>,
            > {
                std::pin::Pin::new(&mut self.get_mut().0)
            }

            fn deref_for_impl_socket(
                &self,
            ) -> &TlsStreamOverSyncIo<A, AsyncWrapperOpsImpl<AsyncIoAsSyncIo<A>, A>> {
                &self.0
            }
        }

        // The delegate macro introduces its own generic parameter named `S`.
        spi_async_socket_impl_delegate!($t<S>);

        impl<A: $crate::AsyncSocket> $crate::TlsStreamDyn for $t<A> {
            fn get_alpn_protocol(&self) -> anyhow::Result<Option<Vec<u8>>> {
                self.0.get_alpn_protocol()
            }

            fn impl_info(&self) -> $crate::ImplInfo {
                self.0.impl_info()
            }

            fn get_socket_dyn_mut(&mut self) -> &mut dyn $crate::AsyncSocket {
                self.0.get_socket_dyn_mut()
            }

            fn get_socket_dyn_ref(&self) -> &dyn $crate::AsyncSocket {
                self.0.get_socket_dyn_ref()
            }
        }

        impl<A: $crate::AsyncSocket> $crate::TlsStreamWithSocketDyn<A> for $t<A> {
            fn get_socket_mut(&mut self) -> &mut A {
                self.0.get_socket_mut()
            }

            fn get_socket_ref(&self) -> &A {
                self.0.get_socket_ref()
            }
        }

        impl<A: $crate::AsyncSocket> $crate::spi::TlsStreamWithUpcastDyn<A> for $t<A> {
            fn upcast_box(self: Box<Self>) -> Box<dyn $crate::TlsStreamDyn> {
                self
            }
        }
    };
}