wintun_bindings/
async_session.rs1use crate::{handle::UnsafeHandle, session::Session};
2use futures::{AsyncRead, AsyncWrite};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use windows_sys::Win32::{
8 Foundation::{FALSE, HANDLE, WAIT_ABANDONED_0, WAIT_EVENT, WAIT_OBJECT_0},
9 System::Threading::{INFINITE, WaitForMultipleObjects},
10};
11
/// Outcome of a blocking wait performed by `AsyncSession::wait_for_read`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WaitingStopReason {
    /// The session's shutdown event was the handle that ended the wait.
    Shutdown,
    /// The read event was signaled; a receive attempt should be retried.
    Ready,
}

18#[derive(Debug, Clone)]
19enum ReadState {
20 Waiting(Option<Arc<Mutex<blocking::Task<WaitingStopReason>>>>),
21 Idle,
22 Closed,
23}
24
25#[derive(Clone)]
26pub struct AsyncSession {
27 session: Arc<Session>,
28 read_state: ReadState,
29}
30
31impl std::ops::Deref for AsyncSession {
32 type Target = Session;
33
34 fn deref(&self) -> &Self::Target {
35 &self.session
36 }
37}
38
39impl From<Arc<Session>> for AsyncSession {
40 fn from(session: Arc<Session>) -> Self {
41 Self {
42 session,
43 read_state: ReadState::Idle,
44 }
45 }
46}
47
48impl AsyncSession {
49 fn wait_for_read(read_event: UnsafeHandle<HANDLE>, shutdown_event: UnsafeHandle<HANDLE>) -> WaitingStopReason {
50 const WAIT_OBJECT_1: WAIT_EVENT = WAIT_OBJECT_0 + 1;
51 const WAIT_ABANDONED_1: WAIT_EVENT = WAIT_ABANDONED_0 + 1;
52 let handles = [shutdown_event.0, read_event.0];
53 match unsafe { WaitForMultipleObjects(handles.len() as u32, &handles as _, FALSE, INFINITE) } {
54 WAIT_OBJECT_0 | WAIT_ABANDONED_0 => WaitingStopReason::Shutdown,
55 WAIT_OBJECT_1 => WaitingStopReason::Ready,
56 WAIT_ABANDONED_1 => panic!("Read event deleted unexpectedly"),
57 e => panic!("WaitForMultipleObjects returned unexpected value {:?}", e),
58 }
59 }
60
61 pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
62 loop {
63 match self.session.try_receive() {
64 Ok(Some(packet)) => {
65 let size = packet.bytes.len();
66 if buf.len() < size {
67 return Err(std::io::Error::other("Buffer too small"));
68 }
69 buf[..size].copy_from_slice(&packet.bytes[..size]);
70 return Ok(size);
71 }
72 Ok(None) => {
73 let read_event = self.session.get_read_wait_event()?;
74 let shutdown_event = self.session.shutdown_event.get_handle();
75 match blocking::unblock(move || Self::wait_for_read(read_event, shutdown_event)).await {
76 WaitingStopReason::Shutdown => {
77 return Err(crate::Error::ShuttingDown.into());
78 }
79 WaitingStopReason::Ready => continue,
80 }
81 }
82 Err(err) => return Err(std::io::Error::other(err)),
83 }
84 }
85 }
86
87 pub async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
88 self.internal_send(buf)
89 }
90
91 fn internal_send(&self, buf: &[u8]) -> std::io::Result<usize> {
92 let packet = self.session.allocate_send_packet(buf.len() as _)?;
93 packet.bytes.copy_from_slice(buf);
94 self.session.send_packet(packet);
95 Ok(buf.len())
96 }
97}
98
99impl AsyncRead for AsyncSession {
100 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
101 loop {
102 match &mut self.read_state {
103 ReadState::Idle => match self.session.try_receive() {
104 Ok(Some(packet)) => {
105 let size = packet.bytes.len();
106 if buf.len() < size {
107 return Poll::Ready(Err(std::io::Error::other("Buffer too small")));
108 }
109 buf[..size].copy_from_slice(&packet.bytes[..size]);
110 return Poll::Ready(Ok(size));
111 }
112 Ok(None) => {
113 let read_event = self.session.get_read_wait_event()?;
114 let shutdown_event = self.session.shutdown_event.get_handle();
115 let task = Arc::new(Mutex::new(blocking::unblock(move || {
116 Self::wait_for_read(read_event, shutdown_event)
117 })));
118 self.read_state = ReadState::Waiting(Some(task));
119 }
120 Err(err) => return Poll::Ready(Err(std::io::Error::other(err))),
121 },
122 ReadState::Waiting(task) => {
123 let task = match task.take() {
124 Some(task) => task,
125 None => return Poll::Pending,
126 };
127 let task_clone = task.clone();
128 let mut task_guard = match task_clone.lock() {
129 Ok(guard) => guard,
130 Err(e) => {
131 self.read_state = ReadState::Waiting(Some(task));
132 return Poll::Ready(Err(std::io::Error::other(format!("Lock task failed: {}", e))));
133 }
134 };
135 self.read_state = match Pin::new(&mut *task_guard).poll(cx) {
136 Poll::Ready(WaitingStopReason::Shutdown) => ReadState::Closed,
137 Poll::Ready(WaitingStopReason::Ready) => ReadState::Idle,
138 Poll::Pending => ReadState::Waiting(Some(task)),
139 };
140 if let ReadState::Waiting(_) = self.read_state {
141 return Poll::Pending;
142 }
143 }
144 ReadState::Closed => return Poll::Ready(Ok(0)),
145 }
146 }
147 }
148}
149
150impl AsyncWrite for AsyncSession {
151 fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
152 Poll::Ready(Ok(self.internal_send(buf)?))
153 }
154
155 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
156 Poll::Ready(Ok(()))
157 }
158
159 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
160 self.session.shutdown()?;
161 Poll::Ready(Ok(()))
162 }
163}