1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use crate::handle::UnsafeHandle;
use futures::{AsyncRead, AsyncWrite};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use windows_sys::Win32::{
    Foundation::{FALSE, HANDLE, WAIT_ABANDONED_0, WAIT_EVENT, WAIT_OBJECT_0},
    System::Threading::{WaitForMultipleObjects, INFINITE},
};

/// Outcome of a blocking wait on the session's shutdown and read-wait events.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WaitingStopReason {
    /// The shutdown event was signalled (or reported as abandoned); reading should stop.
    Shutdown,
    /// The read-wait event was signalled, so a packet may now be available.
    Ready,
}

/// Progress of the read side of an [`AsyncSession`].
#[derive(Debug)]
enum ReadState {
    /// A background wait is in flight. The `Option` lets `poll_read` move the task out
    /// while polling it and put it back if the wait is still pending.
    Waiting(Option<blocking::Task<WaitingStopReason>>),
    /// No wait is outstanding; the next poll tries to receive a packet directly.
    Idle,
    /// Shutdown was observed; every further read reports end-of-stream (`Ok(0)`).
    Closed,
}

/// Async (`futures` [`AsyncRead`]/[`AsyncWrite`]) adapter over a shared
/// `crate::session::Session`. Dropping the adapter shuts the session down.
pub struct AsyncSession {
    /// The wrapped session, shared through an `Arc`.
    session: Arc<crate::session::Session>,
    /// State machine that drives `poll_read`.
    read_state: ReadState,
}

impl From<Arc<crate::session::Session>> for AsyncSession {
    fn from(session: Arc<crate::session::Session>) -> Self {
        Self {
            session,
            read_state: ReadState::Idle,
        }
    }
}

impl Drop for AsyncSession {
    /// Shuts the underlying session down when the adapter goes away.
    fn drop(&mut self) {
        // Best effort: `drop` has no way to report failure, so the result is discarded.
        // NOTE(review): presumably this signals the shutdown event that `wait_for_read`
        // blocks on, releasing any in-flight background wait — confirm in `Session::shutdown`.
        let _ = self.session.shutdown();
    }
}

impl AsyncSession {
    /// Blocks the calling thread until `shutdown_event` or `read_event` is signalled.
    ///
    /// The shutdown handle is placed first in the wait list. When several objects are signalled,
    /// `WaitForMultipleObjects` reports the lowest index, so shutdown wins a tie.
    ///
    /// # Panics
    ///
    /// Panics if the read event is reported as abandoned, or if the wait returns any other
    /// unexpected status (for example `WAIT_FAILED`).
    fn wait_for_read(read_event: UnsafeHandle<HANDLE>, shutdown_event: UnsafeHandle<HANDLE>) -> WaitingStopReason {
        const WAIT_OBJECT_1: WAIT_EVENT = WAIT_OBJECT_0 + 1;
        const WAIT_ABANDONED_1: WAIT_EVENT = WAIT_ABANDONED_0 + 1;

        // Index 0 = shutdown, index 1 = read; the status mapping below depends on this order.
        let wait_list = [shutdown_event.0, read_event.0];
        // SAFETY: `wait_list` is a live local array and the count passed equals its length.
        // Handle validity is assumed to be guaranteed by the caller (not checked here).
        let status = unsafe { WaitForMultipleObjects(wait_list.len() as u32, wait_list.as_ptr(), FALSE, INFINITE) };

        match status {
            WAIT_OBJECT_1 => WaitingStopReason::Ready,
            WAIT_OBJECT_0 | WAIT_ABANDONED_0 => WaitingStopReason::Shutdown,
            WAIT_ABANDONED_1 => panic!("Read event deleted unexpectedly"),
            other => panic!("WaitForMultipleObjects returned unexpected value {:?}", other),
        }
    }
}

impl AsyncRead for AsyncSession {
    /// Reads a single packet into `buf` and returns the number of bytes copied.
    ///
    /// Each successful call consumes exactly one packet from the session. If the packet is
    /// larger than `buf`, only the first `buf.len()` bytes are copied and the remainder is
    /// lost with the packet.
    ///
    /// When no packet is queued, a blocking wait on the session's read-wait and shutdown
    /// events runs on the `blocking` thread pool. The call returns `Poll::Pending` until that
    /// wait finishes. Once shutdown is observed the reader is closed, and every later call
    /// returns `Ok(0)`.
    ///
    /// A zero-length `buf` returns `Ok(0)` immediately and does not touch the session.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining the read-wait event. Failures from `try_receive`
    /// are returned wrapped as [`std::io::ErrorKind::Other`].
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
        // A zero-length read must not dequeue a packet. Otherwise the packet would be dropped
        // unread and the caller would get `Ok(0)`, which readers treat as end-of-stream.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        loop {
            match &mut self.read_state {
                ReadState::Idle => match self.session.try_receive() {
                    Ok(Some(packet)) => {
                        let size = packet.bytes.len().min(buf.len());
                        buf[..size].copy_from_slice(&packet.bytes[..size]);
                        return Poll::Ready(Ok(size));
                    }
                    Ok(None) => {
                        // Nothing queued: block off the executor until data or shutdown arrives.
                        let read_event = self.session.get_read_wait_event()?;
                        let shutdown_event = self.session.shutdown_event.get_handle();
                        self.read_state = ReadState::Waiting(Some(blocking::unblock(move || {
                            Self::wait_for_read(read_event, shutdown_event)
                        })));
                    }
                    Err(err) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err))),
                },
                ReadState::Waiting(task) => {
                    // Move the task out so it can be polled while `self.read_state` is reassigned.
                    let mut task = task
                        .take()
                        .expect("ReadState::Waiting always holds a task between polls");
                    self.read_state = match Pin::new(&mut task).poll(cx) {
                        Poll::Ready(WaitingStopReason::Shutdown) => ReadState::Closed,
                        Poll::Ready(WaitingStopReason::Ready) => ReadState::Idle,
                        Poll::Pending => ReadState::Waiting(Some(task)),
                    };
                    if let ReadState::Waiting(_) = self.read_state {
                        return Poll::Pending;
                    }
                }
                ReadState::Closed => return Poll::Ready(Ok(0)),
            }
        }
    }
}

impl AsyncWrite for AsyncSession {
    /// Sends all of `buf` as one packet and reports it as fully written.
    ///
    /// Never returns `Poll::Pending`: allocation and submission both happen synchronously.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidInput`] if `buf.len()` does not fit the size type
    /// that `allocate_send_packet` accepts. Errors from `allocate_send_packet` are propagated.
    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
        // Checked conversion; the target integer type is inferred from `allocate_send_packet`'s
        // parameter. An `as` cast would wrap oversized lengths and allocate a smaller packet,
        // and the `copy_from_slice` below would then panic on the length mismatch.
        let size = std::convert::TryInto::try_into(buf.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "buffer is too large to fit in a single packet",
            )
        })?;
        let packet = self.session.allocate_send_packet(size)?;
        packet.bytes.copy_from_slice(buf);
        self.session.send_packet(packet);
        Poll::Ready(Ok(buf.len()))
    }

    /// No-op: packets are handed to the session as soon as they are written.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts the underlying session down, propagating any shutdown error.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.session.shutdown()?;
        Poll::Ready(Ok(()))
    }
}