wintun_bindings/
async_session.rs

1use crate::{handle::UnsafeHandle, session::Session};
2use futures::{AsyncRead, AsyncWrite};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use windows_sys::Win32::{
8    Foundation::{FALSE, HANDLE, WAIT_ABANDONED_0, WAIT_EVENT, WAIT_OBJECT_0},
9    System::Threading::{INFINITE, WaitForMultipleObjects},
10};
11
/// Why a background wait for incoming packets returned.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WaitingStopReason {
    /// The session's shutdown event was signaled; no more packets will be read.
    Shutdown,
    /// The session's read event was signaled; a packet may now be available.
    Ready,
}
17
18#[derive(Debug, Clone)]
19enum ReadState {
20    Waiting(Option<Arc<Mutex<blocking::Task<WaitingStopReason>>>>),
21    Idle,
22    Closed,
23}
24
25#[derive(Clone)]
26pub struct AsyncSession {
27    session: Arc<Session>,
28    read_state: ReadState,
29}
30
31impl std::ops::Deref for AsyncSession {
32    type Target = Session;
33
34    fn deref(&self) -> &Self::Target {
35        &self.session
36    }
37}
38
39impl From<Arc<Session>> for AsyncSession {
40    fn from(session: Arc<Session>) -> Self {
41        Self {
42            session,
43            read_state: ReadState::Idle,
44        }
45    }
46}
47
48impl AsyncSession {
49    fn wait_for_read(read_event: UnsafeHandle<HANDLE>, shutdown_event: UnsafeHandle<HANDLE>) -> WaitingStopReason {
50        const WAIT_OBJECT_1: WAIT_EVENT = WAIT_OBJECT_0 + 1;
51        const WAIT_ABANDONED_1: WAIT_EVENT = WAIT_ABANDONED_0 + 1;
52        let handles = [shutdown_event.0, read_event.0];
53        match unsafe { WaitForMultipleObjects(handles.len() as u32, &handles as _, FALSE, INFINITE) } {
54            WAIT_OBJECT_0 | WAIT_ABANDONED_0 => WaitingStopReason::Shutdown,
55            WAIT_OBJECT_1 => WaitingStopReason::Ready,
56            WAIT_ABANDONED_1 => panic!("Read event deleted unexpectedly"),
57            e => panic!("WaitForMultipleObjects returned unexpected value {:?}", e),
58        }
59    }
60
61    pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
62        loop {
63            match self.session.try_receive() {
64                Ok(Some(packet)) => {
65                    let size = packet.bytes.len();
66                    if buf.len() < size {
67                        return Err(std::io::Error::other("Buffer too small"));
68                    }
69                    buf[..size].copy_from_slice(&packet.bytes[..size]);
70                    return Ok(size);
71                }
72                Ok(None) => {
73                    let read_event = self.session.get_read_wait_event()?;
74                    let shutdown_event = self.session.shutdown_event.get_handle();
75                    match blocking::unblock(move || Self::wait_for_read(read_event, shutdown_event)).await {
76                        WaitingStopReason::Shutdown => {
77                            return Err(crate::Error::ShuttingDown.into());
78                        }
79                        WaitingStopReason::Ready => continue,
80                    }
81                }
82                Err(err) => return Err(std::io::Error::other(err)),
83            }
84        }
85    }
86
87    pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
88        self.internal_send(buf)
89    }
90
91    fn internal_send(&self, buf: &[u8]) -> std::io::Result<usize> {
92        let packet = self.session.allocate_send_packet(buf.len() as _)?;
93        packet.bytes.copy_from_slice(buf);
94        self.session.send_packet(packet);
95        Ok(buf.len())
96    }
97}
98
99impl AsyncRead for AsyncSession {
100    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
101        loop {
102            match &mut self.read_state {
103                ReadState::Idle => match self.session.try_receive() {
104                    Ok(Some(packet)) => {
105                        let size = packet.bytes.len();
106                        if buf.len() < size {
107                            return Poll::Ready(Err(std::io::Error::other("Buffer too small")));
108                        }
109                        buf[..size].copy_from_slice(&packet.bytes[..size]);
110                        return Poll::Ready(Ok(size));
111                    }
112                    Ok(None) => {
113                        let read_event = self.session.get_read_wait_event()?;
114                        let shutdown_event = self.session.shutdown_event.get_handle();
115                        let task = Arc::new(Mutex::new(blocking::unblock(move || {
116                            Self::wait_for_read(read_event, shutdown_event)
117                        })));
118                        self.read_state = ReadState::Waiting(Some(task));
119                    }
120                    Err(err) => return Poll::Ready(Err(std::io::Error::other(err))),
121                },
122                ReadState::Waiting(task) => {
123                    let task = match task.take() {
124                        Some(task) => task,
125                        None => return Poll::Pending,
126                    };
127                    let task_clone = task.clone();
128                    let mut task_guard = match task_clone.lock() {
129                        Ok(guard) => guard,
130                        Err(e) => {
131                            self.read_state = ReadState::Waiting(Some(task));
132                            return Poll::Ready(Err(std::io::Error::other(format!("Lock task failed: {}", e))));
133                        }
134                    };
135                    self.read_state = match Pin::new(&mut *task_guard).poll(cx) {
136                        Poll::Ready(WaitingStopReason::Shutdown) => ReadState::Closed,
137                        Poll::Ready(WaitingStopReason::Ready) => ReadState::Idle,
138                        Poll::Pending => ReadState::Waiting(Some(task)),
139                    };
140                    if let ReadState::Waiting(_) = self.read_state {
141                        return Poll::Pending;
142                    }
143                }
144                ReadState::Closed => return Poll::Ready(Ok(0)),
145            }
146        }
147    }
148}
149
150impl AsyncWrite for AsyncSession {
151    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
152        Poll::Ready(Ok(self.internal_send(buf)?))
153    }
154
155    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
156        Poll::Ready(Ok(()))
157    }
158
159    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
160        self.session.shutdown()?;
161        Poll::Ready(Ok(()))
162    }
163}