1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
use std::io::Read;
use std::io::Write;
use std::marker;
use std::pin::Pin;
use std::ptr;
use std::task::Context;
use std::task::Poll;
use std::{error, fmt, io};
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
/// Adapter that exposes an async stream `S` through the blocking
/// `std::io::Read` / `std::io::Write` traits.
///
/// Those traits have no parameter for a task `Context`, so one has to be
/// installed into the adapter (see `set_context`) before an I/O call is made.
/// The installed context is stored as a type-erased raw pointer.
#[derive(Debug)]
pub struct AsyncIoAsSyncIo<S>
where
    S: Unpin,
{
    /// The wrapped async stream.
    inner: S,
    /// Type-erased `*mut Context<'_>`; null whenever no context is installed.
    context: *mut (),
}

// SAFETY: `context` is the only field that suppresses the auto-derived
// `Send`/`Sync` impls. It is written only through the `unsafe` methods
// `set_context`/`unset_context`, and `with_context` clears it (via a drop
// guard) before returning while `self` is still mutably borrowed. The rest
// of the state is `S`, so thread-safety is delegated to `S`.
unsafe impl<S> Sync for AsyncIoAsSyncIo<S> where S: Unpin + Sync {}
unsafe impl<S> Send for AsyncIoAsSyncIo<S> where S: Unpin + Send {}
impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Returns a mutable reference to the wrapped stream.
    pub fn get_inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_inner_ref(&self) -> &S {
        &self.inner
    }

    /// Wraps `inner` with no task context installed.
    pub fn new(inner: S) -> AsyncIoAsSyncIo<S> {
        AsyncIoAsSyncIo {
            inner,
            context: ptr::null_mut(),
        }
    }

    /// Installs `cx` as the task context used by subsequent `Read`/`Write`
    /// calls on this adapter.
    ///
    /// # Safety
    ///
    /// Only a type-erased raw pointer to `cx` is stored, so the borrow checker
    /// no longer tracks it. The caller must call
    /// [`unset_context`](Self::unset_context) before `cx` is moved or dropped.
    /// The caller also must not use `cx` through its original reference while
    /// it is installed here.
    ///
    /// # Panics
    ///
    /// Panics if a context is already installed.
    pub unsafe fn set_context(&mut self, cx: &mut Context<'_>) {
        assert!(self.context.is_null());
        self.context = cx as *mut Context<'_> as *mut ();
    }

    /// Clears the context installed by [`set_context`](Self::set_context).
    ///
    /// # Safety
    ///
    /// This only resets the stored pointer to null and never dereferences it.
    /// It imposes no requirement beyond the panic condition below.
    ///
    /// # Panics
    ///
    /// Panics if no context is currently installed.
    pub unsafe fn unset_context(&mut self) {
        assert!(!self.context.is_null());
        self.context = ptr::null_mut();
    }
}
/// Implemented by types that contain an [`AsyncIoAsSyncIo`] and want to drive
/// synchronous `Read`/`Write` code against it with a task `Context`.
pub trait AsyncIoAsSyncIoWrapper<S: Unpin>: Sized {
    /// Returns the adapter that receives the task context.
    fn get_mut(&mut self) -> &mut AsyncIoAsSyncIo<S>;

    /// Installs `cx` into the adapter, runs `f`, and then clears the context.
    ///
    /// A drop guard does the clearing, so the context is also cleared if `f`
    /// panics.
    ///
    /// # Panics
    ///
    /// Panics if a context is already installed, for example on a nested call.
    fn with_context<F, R>(&mut self, cx: &mut Context<'_>, f: F) -> R
    where
        F: FnOnce(&mut Self) -> R,
    {
        // SAFETY: the pointer to `cx` stored here is cleared by `Guard::drop`
        // before this function returns or unwinds, so it never outlives the
        // `cx` borrow. NOTE(review): this relies on `get_mut` returning the
        // same adapter on every call.
        unsafe { self.get_mut().set_context(cx) };
        let guard = Guard(self, marker::PhantomData);
        f(guard.0)
    }

    /// Like [`with_context`](Self::with_context), but for a closure written
    /// against blocking I/O.
    ///
    /// A `WouldBlock` error from `f` becomes `Poll::Pending`. Any other outcome
    /// is returned as `Poll::Ready`.
    fn with_context_sync_to_async<F, R>(
        &mut self,
        cx: &mut Context<'_>,
        f: F,
    ) -> Poll<io::Result<R>>
    where
        F: FnOnce(&mut Self) -> io::Result<R>,
    {
        result_to_poll(self.with_context(cx, f))
    }
}
// The adapter is its own wrapper: context installation targets `self`.
impl<S> AsyncIoAsSyncIoWrapper<S> for AsyncIoAsSyncIo<S>
where
    S: Unpin,
{
    fn get_mut(&mut self) -> &mut Self {
        self
    }
}
/// Drop guard that clears the context installed by
/// `AsyncIoAsSyncIoWrapper::with_context`. It also runs when the user closure
/// panics and the stack unwinds.
///
/// The `PhantomData<S>` only ties the otherwise-unused `S` parameter to the
/// type.
struct Guard<'a, S: Unpin, W: AsyncIoAsSyncIoWrapper<S>>(&'a mut W, marker::PhantomData<S>);

impl<'a, S: Unpin, W: AsyncIoAsSyncIoWrapper<S>> Drop for Guard<'a, S, W> {
    fn drop(&mut self) {
        let adapter = self.0.get_mut();
        // SAFETY: in this file the guard is created only immediately after a
        // successful `set_context`, and `unset_context` merely resets the
        // stored pointer. NOTE(review): assumes `get_mut` returns the same
        // adapter that received the context.
        unsafe { adapter.unset_context() };
    }
}
impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Runs `f` with the installed task context and a pinned reference to the
    /// wrapped stream.
    ///
    /// # Panics
    ///
    /// Panics if no context has been installed via `set_context`, for example
    /// when `Read`/`Write` is used outside of `with_context`.
    fn with_context_inner<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
    {
        assert!(!self.context.is_null());
        // SAFETY: a non-null `context` was stored by `set_context` from a
        // `&mut Context<'_>`. That method's contract requires it to be cleared
        // before the context is invalidated or used elsewhere, so the pointer
        // is valid and unaliased for the duration of this call.
        let context = unsafe { &mut *(self.context as *mut Context<'_>) };
        f(context, Pin::new(&mut self.inner))
    }
}
impl<S> Read for AsyncIoAsSyncIo<S>
where
    S: AsyncRead + Unpin,
{
    /// Polls the wrapped stream once with the installed context. A pending
    /// read surfaces as an `io::ErrorKind::WouldBlock` error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_read(cx, buf));
        poll_to_result(polled)
    }
}
impl<S> Write for AsyncIoAsSyncIo<S>
where
    S: AsyncWrite + Unpin,
{
    /// Polls a single write on the wrapped stream. A pending write surfaces as
    /// an `io::ErrorKind::WouldBlock` error.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_write(cx, buf));
        poll_to_result(polled)
    }

    /// Polls a single flush on the wrapped stream. A pending flush surfaces as
    /// an `io::ErrorKind::WouldBlock` error.
    fn flush(&mut self) -> io::Result<()> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_flush(cx));
        poll_to_result(polled)
    }
}
/// Converts a blocking-style `io::Result` into a poll result.
///
/// An `io::ErrorKind::WouldBlock` error becomes `Poll::Pending`. Every other
/// result, success or failure, is wrapped in `Poll::Ready` unchanged.
pub fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }
    Poll::Ready(r)
}
/// Error that `poll_to_result` wraps in an `io::ErrorKind::Other` error when an
/// async stream reports `WouldBlock` as a ready error instead of returning
/// `Poll::Pending`.
#[derive(Debug)]
struct ShouldNotReturnWouldBlockFromAsync(io::Error);

impl fmt::Display for ShouldNotReturnWouldBlockFromAsync {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ShouldNotReturnWouldBlockFromAsync(cause) = self;
        write!(f, "should not return WouldBlock from async API: {}", cause)
    }
}

// `source()` keeps its default (`None`). The wrapped error is already part of
// the `Display` text, so exposing it here as well would print it twice in
// error chains.
impl error::Error for ShouldNotReturnWouldBlockFromAsync {}
pub fn poll_to_result<T>(r: Poll<io::Result<T>>) -> io::Result<T> {
match r {
Poll::Ready(Ok(r)) => Ok(r),
Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => Err(io::Error::new(
io::ErrorKind::Other,
ShouldNotReturnWouldBlockFromAsync(e),
)),
Poll::Ready(Err(e)) => Err(e),
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}