1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
//! Utilities shared by the different implementations of the TLS API.

use std::io::Read;
use std::io::Write;
use std::marker;
use std::pin::Pin;
use std::ptr;
use std::task::Context;
use std::task::Poll;
use std::{error, fmt, io};
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;

/// Async IO object as sync IO.
///
/// Wraps an async stream `S` so it can be driven through the blocking
/// [`Read`] / [`Write`] traits: each call polls `S` once with the async context
/// stored by [`set_context`](Self::set_context), and `Poll::Pending` surfaces
/// as an `io::ErrorKind::WouldBlock` error.
///
/// Used in API implementations.
#[derive(Debug)]
pub struct AsyncIoAsSyncIo<S: Unpin> {
    /// The wrapped async stream.
    inner: S,
    /// Type-erased `*mut Context<'_>` stored by `set_context`; null while no
    /// async context is installed.
    context: *mut (),
}

// SAFETY: `context` is the only field that is not automatically `Sync`. It is
// read and dereferenced only through `&mut self`; through a shared reference the
// wrapper exposes nothing but `&S` (`get_inner_ref`) and the pointer's address
// (`Debug`), so sharing it across threads is sound whenever `S: Sync`.
unsafe impl<S: Unpin + Sync> Sync for AsyncIoAsSyncIo<S> {}
// SAFETY: `context` is non-null only between `set_context` and `unset_context`,
// and the `set_context` contract forbids moving the wrapper to another thread in
// that window; outside it the pointer is null, so sending the wrapper is
// equivalent to sending `S`.
unsafe impl<S: Unpin + Send> Send for AsyncIoAsSyncIo<S> {}

impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Get a mutable reference to the wrapped stream.
    pub fn get_inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Get a shared reference to the wrapped stream.
    pub fn get_inner_ref(&self) -> &S {
        &self.inner
    }

    /// Wrap an async stream in this adapter. No async context is stored yet.
    pub fn new(inner: S) -> AsyncIoAsSyncIo<S> {
        AsyncIoAsSyncIo {
            inner,
            context: ptr::null_mut(),
        }
    }

    /// Store an async context inside this object, so that subsequent blocking
    /// `Read` / `Write` calls can poll the wrapped stream with it.
    ///
    /// # Safety
    ///
    /// Only the address of `cx` is kept; its lifetime is erased. The caller must
    /// call [`unset_context`](Self::unset_context) before `cx` is dropped, must
    /// not use `cx` directly until then, and must not move this object to
    /// another thread in between.
    /// [`AsyncIoAsSyncIoWrapper::with_context`] follows this protocol.
    ///
    /// # Panics
    ///
    /// Panics if a context is already stored.
    pub unsafe fn set_context(&mut self, cx: &mut Context<'_>) {
        assert!(self.context.is_null(), "async context is already set");
        self.context = cx as *mut Context<'_> as *mut ();
    }

    /// Clear the async context stored by [`set_context`](Self::set_context).
    ///
    /// # Safety
    ///
    /// This only resets the stored pointer to null; there are no requirements
    /// beyond those already stated for `set_context`.
    ///
    /// # Panics
    ///
    /// Panics if no context is stored.
    pub unsafe fn unset_context(&mut self) {
        assert!(!self.context.is_null(), "async context is not set");
        self.context = ptr::null_mut();
    }
}

/// Gives access to an [`AsyncIoAsSyncIo`] held by `Self`, plus helpers that run
/// blocking code while an async context is temporarily installed in it.
///
/// Implementations of [`get_mut`](Self::get_mut) must return the same
/// `AsyncIoAsSyncIo` on every call. `with_context` installs the context through
/// one call and removes it through another, when its guard is dropped.
pub trait AsyncIoAsSyncIoWrapper<S: Unpin>: Sized {
    /// Get the wrapped async-to-sync adapter.
    fn get_mut(&mut self) -> &mut AsyncIoAsSyncIo<S>;

    /// Run `f` with `cx` installed in the adapter returned by `get_mut`.
    ///
    /// The context is cleared again when `f` returns or unwinds.
    ///
    /// # Panics
    ///
    /// Panics if a context is already installed, for example on reentrant use.
    /// Also panics if, when `f` finishes, `get_mut` yields an adapter that has
    /// no context installed.
    fn with_context<F, R>(&mut self, cx: &mut Context<'_>, f: F) -> R
    where
        F: FnOnce(&mut Self) -> R,
    {
        // SAFETY: `guard` below clears the stored pointer in its `Drop` impl, which
        // runs when this function returns or unwinds, and `cx` is not used directly
        // while the pointer is stored.
        // NOTE(review): `f` receives `&mut Self` and could `mem::replace` the
        // adapter from `get_mut`, carrying the erased pointer out of this scope;
        // that case is not guarded against here.
        unsafe { self.get_mut().set_context(cx) };
        let guard = Guard(self, marker::PhantomData);
        f(guard.0)
    }

    /// Like [`with_context`](Self::with_context), but converts the blocking result
    /// of `f` into a poll result.
    ///
    /// An error of kind `io::ErrorKind::WouldBlock` becomes `Poll::Pending`.
    /// Any other outcome is returned as `Poll::Ready`.
    fn with_context_sync_to_async<F, R>(
        &mut self,
        cx: &mut Context<'_>,
        f: F,
    ) -> Poll<io::Result<R>>
    where
        F: FnOnce(&mut Self) -> io::Result<R>,
    {
        result_to_poll(self.with_context(cx, f))
    }
}

impl<S> AsyncIoAsSyncIoWrapper<S> for AsyncIoAsSyncIo<S>
where
    S: Unpin,
{
    /// The plain adapter is its own `AsyncIoAsSyncIo`, so this returns `self`.
    fn get_mut(&mut self) -> &mut AsyncIoAsSyncIo<S> {
        self
    }
}

/// Scope guard used by `with_context`: dropping it, on return or on unwind,
/// clears the async context from the adapter reachable through the wrapper.
struct Guard<'a, S, W>(&'a mut W, marker::PhantomData<S>)
where
    S: Unpin,
    W: AsyncIoAsSyncIoWrapper<S>;

impl<'a, S, W> Drop for Guard<'a, S, W>
where
    S: Unpin,
    W: AsyncIoAsSyncIoWrapper<S>,
{
    fn drop(&mut self) {
        let adapter = self.0.get_mut();
        // SAFETY: `unset_context` only resets the stored pointer to null.
        unsafe { adapter.unset_context() }
    }
}

impl<S: Unpin> AsyncIoAsSyncIo<S> {
    /// Call `f` with the async context stored by `set_context` and the wrapped
    /// stream pinned in place.
    ///
    /// # Panics
    ///
    /// Panics if no async context is stored, i.e. when blocking IO is attempted
    /// outside a `set_context` / `unset_context` pair.
    fn with_context_inner<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
    {
        assert!(
            !self.context.is_null(),
            "async context must be set before using the sync IO wrapper"
        );
        // SAFETY: a non-null `context` was stored by `set_context` from a live
        // `&mut Context<'_>`; that function's contract keeps the context valid and
        // otherwise unused until `unset_context` resets the pointer.
        let context = unsafe { &mut *(self.context as *mut Context<'_>) };
        f(context, Pin::new(&mut self.inner))
    }
}

impl<S> Read for AsyncIoAsSyncIo<S>
where
    S: AsyncRead + Unpin,
{
    /// Polls `poll_read` once with the stored async context; `Poll::Pending`
    /// is reported as an `io::ErrorKind::WouldBlock` error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_read(cx, buf));
        poll_to_result(polled)
    }
}

impl<S> Write for AsyncIoAsSyncIo<S>
where
    S: AsyncWrite + Unpin,
{
    /// Polls `poll_write` once with the stored async context; `Poll::Pending`
    /// is reported as an `io::ErrorKind::WouldBlock` error.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_write(cx, buf));
        poll_to_result(polled)
    }

    /// Polls `poll_flush` once with the stored async context, mapping
    /// `Poll::Pending` the same way as `write`.
    fn flush(&mut self) -> io::Result<()> {
        let polled = self.with_context_inner(|cx, stream| stream.poll_flush(cx));
        poll_to_result(polled)
    }
}

/// Convert a blocking-API result into an async poll result.
///
/// An error of kind `io::ErrorKind::WouldBlock` means "not ready yet" and maps
/// to `Poll::Pending` (the error value is dropped). Every other outcome, success
/// or failure, is returned unchanged inside `Poll::Ready`.
pub fn result_to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);
    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}

/// Error payload used when an async API reports `WouldBlock` as
/// `Poll::Ready(Err(..))` instead of returning `Poll::Pending`.
#[derive(Debug)]
struct ShouldNotReturnWouldBlockFromAsync(io::Error);

impl fmt::Display for ShouldNotReturnWouldBlockFromAsync {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ShouldNotReturnWouldBlockFromAsync(original) = self;
        write!(f, "should not return WouldBlock from async API: {}", original)
    }
}

// `source()` keeps its default (`None`); the wrapped error is already part of
// the `Display` output above.
impl error::Error for ShouldNotReturnWouldBlockFromAsync {}

/// Convert nonblocking API to sync result
pub fn poll_to_result<T>(r: Poll<io::Result<T>>) -> io::Result<T> {
    match r {
        Poll::Ready(Ok(r)) => Ok(r),
        Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => Err(io::Error::new(
            io::ErrorKind::Other,
            ShouldNotReturnWouldBlockFromAsync(e),
        )),
        Poll::Ready(Err(e)) => Err(e),
        Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
    }
}