1use crate::async_request::AsyncResponse;
2use crate::handle::WinHttpHandle;
3use crossfire::mpsc;
4use parking_lot::Mutex;
5use std::ffi::c_void;
6use std::future::Future;
7use std::mem::ManuallyDrop;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll, Waker};
11use windows::Win32::Foundation::WIN32_ERROR;
12use windows::Win32::Networking::WinHttp::*;
13use windows::core::{Error, Result};
14
/// Win32 `ERROR_IO_PENDING` (decimal 997), as returned directly by the WinHTTP WebSocket APIs.
const ERROR_IO_PENDING_VALUE: u32 = 0x3E5;
18
/// Completion notification for the send direction, produced by the WinHTTP status callback.
#[derive(Debug)]
enum WsSendEvent {
    /// The operation failed with the given Win32 error code.
    Error(u32),
    /// A send (or shutdown) operation finished successfully.
    WriteComplete,
}
24
/// Completion notification for the receive direction, produced by the WinHTTP status callback.
#[derive(Debug)]
enum WsRecvEvent {
    /// The operation failed with the given Win32 error code.
    Error(u32),
    /// A close handshake issued on this handle has completed.
    CloseComplete,
    /// A receive finished: `bytes_read` bytes were written into the caller's buffer and
    /// `buffer_type` is the raw `WINHTTP_WEB_SOCKET_BUFFER_TYPE` value.
    ReadComplete { bytes_read: u32, buffer_type: i32 },
}
31
32type SendSender = crossfire::MTx<mpsc::List<WsSendEvent>>;
33type SendReceiver = crossfire::Rx<mpsc::List<WsSendEvent>>;
34type RecvSender = crossfire::MTx<mpsc::List<WsRecvEvent>>;
35type RecvReceiver = crossfire::Rx<mpsc::List<WsRecvEvent>>;
36
/// Shared state whose address is handed to WinHTTP as the WebSocket handle's context value.
///
/// The status callback pushes completion events through the two senders and then wakes
/// the task parked in the matching waker slot. Instances are created behind
/// `Pin<Arc<_>>` (see `new`) so the registered address does not move.
struct WebSocketContext {
    // Waker of the task currently waiting for a send-direction completion.
    send_waker: Mutex<Option<Waker>>,
    // Waker of the task currently waiting for a receive-direction completion.
    recv_waker: Mutex<Option<Waker>>,
    // Producer for send-direction events (consumed via `AsyncWebSocket::send_receiver`).
    send_sender: SendSender,
    // Producer for receive-direction events (consumed via `AsyncWebSocket::recv_receiver`).
    recv_sender: RecvSender,
}
43
44impl WebSocketContext {
45 fn new(send_sender: SendSender, recv_sender: RecvSender) -> Pin<Arc<Self>> {
46 Arc::pin(Self {
47 send_waker: Mutex::new(None),
48 recv_waker: Mutex::new(None),
49 send_sender,
50 recv_sender,
51 })
52 }
53
54 fn wake_send(&self) {
55 if let Some(waker) = self.send_waker.lock().take() {
56 waker.wake();
57 }
58 }
59
60 fn wake_recv(&self) {
61 if let Some(waker) = self.recv_waker.lock().take() {
62 waker.wake();
63 }
64 }
65
66 fn set_send_waker(&self, waker: &Waker) {
67 let mut guard = self.send_waker.lock();
68 match guard.as_ref() {
69 Some(existing) if existing.will_wake(waker) => {}
70 _ => *guard = Some(waker.clone()),
71 }
72 }
73
74 fn set_recv_waker(&self, waker: &Waker) {
75 let mut guard = self.recv_waker.lock();
76 match guard.as_ref() {
77 Some(existing) if existing.will_wake(waker) => {}
78 _ => *guard = Some(waker.clone()),
79 }
80 }
81}
82
83unsafe extern "system" fn async_websocket_callback(
89 _hinternet: *mut c_void,
90 context: usize,
91 status: u32,
92 status_info: *mut c_void,
93 _status_info_length: u32,
94) {
95 let Some(ctx) = (unsafe { (context as *const WebSocketContext).as_ref() }) else {
96 return;
97 };
98
99 match status {
100 WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE => {
101 let _ = ctx.send_sender.send(WsSendEvent::WriteComplete);
102 ctx.wake_send();
103 }
104 WINHTTP_CALLBACK_STATUS_READ_COMPLETE => {
105 if !status_info.is_null() {
106 let ws_status = unsafe { &*(status_info as *const WINHTTP_WEB_SOCKET_STATUS) };
107 let _ = ctx.recv_sender.send(WsRecvEvent::ReadComplete {
108 bytes_read: ws_status.dwBytesTransferred,
109 buffer_type: ws_status.eBufferType.0,
110 });
111 }
112 ctx.wake_recv();
113 }
114 WINHTTP_CALLBACK_STATUS_CLOSE_COMPLETE => {
115 let _ = ctx.recv_sender.send(WsRecvEvent::CloseComplete);
116 ctx.wake_recv();
117 }
118 WINHTTP_CALLBACK_STATUS_SHUTDOWN_COMPLETE => {
119 let _ = ctx.send_sender.send(WsSendEvent::WriteComplete);
120 ctx.wake_send();
121 }
122 WINHTTP_CALLBACK_STATUS_REQUEST_ERROR => {
123 if !status_info.is_null() {
124 let ws_err = unsafe { &*(status_info as *const WINHTTP_WEB_SOCKET_ASYNC_RESULT) };
125 let error_code = ws_err.AsyncResult.dwError;
126 if ws_err.Operation.0 == 0 {
128 let _ = ctx.send_sender.send(WsSendEvent::Error(error_code));
129 ctx.wake_send();
130 } else {
131 let _ = ctx.recv_sender.send(WsRecvEvent::Error(error_code));
132 ctx.wake_recv();
133 }
134 }
135 }
136 _ => {}
137 }
138}
139
/// One complete WebSocket message, reassembled from any fragments.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WebSocketMessage {
    /// A text message; the payload has been decoded into a `String`.
    Text(String),
    /// A binary message, as raw bytes.
    Binary(Vec<u8>),
    /// The peer closed the connection, or a close handshake completed.
    Close,
}
150
/// An upgraded WinHTTP WebSocket connection whose operations complete through the
/// asynchronous status callback.
///
/// Obtain one with [`AsyncWebSocket::from_response`].
///
/// NOTE(review): fields drop in declaration order, so `handle` is closed before `context`
/// is freed. WinHTTP may still deliver notifications (e.g. HANDLE_CLOSING) after
/// `WinHttpCloseHandle` returns, and the callback dereferences the context pointer. Confirm
/// that `WinHttpHandle`'s drop prevents late callbacks; otherwise this is a use-after-free.
pub struct AsyncWebSocket {
    // WebSocket handle returned by `WinHttpWebSocketCompleteUpgrade`.
    handle: WinHttpHandle,
    // Callback state; its address is registered as the handle's context value.
    context: Pin<Arc<WebSocketContext>>,
    // Send-direction completions produced by the status callback.
    send_receiver: SendReceiver,
    // Receive-direction completions produced by the status callback.
    recv_receiver: RecvReceiver,
}
163
164impl AsyncWebSocket {
165 pub fn from_response(response: AsyncResponse<'_>) -> Result<Self> {
171 let request = response.into_request();
172
173 let request = ManuallyDrop::new(request);
176 let request_raw = request.handle.as_raw();
177
178 let ws_raw = unsafe { WinHttpWebSocketCompleteUpgrade(request_raw, None) };
179 if ws_raw.is_null() {
180 return Err(Error::from_thread());
181 }
182
183 let handle = unsafe { WinHttpHandle::from_raw(ws_raw) }
184 .expect("WinHttpWebSocketCompleteUpgrade returned non-null");
185
186 let (send_tx, send_rx) = mpsc::unbounded_blocking();
188 let (recv_tx, recv_rx) = mpsc::unbounded_blocking();
189 let context = WebSocketContext::new(send_tx, recv_tx);
190
191 unsafe {
193 WinHttpSetStatusCallback(
194 handle.as_raw(),
195 Some(async_websocket_callback),
196 WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS,
197 0,
198 )
199 };
200
201 let ctx_ptr: usize = &*context as *const WebSocketContext as usize;
203 unsafe {
204 WinHttpSetOption(
205 Some(handle.as_raw()),
206 WINHTTP_OPTION_CONTEXT_VALUE,
207 Some(&ctx_ptr.to_ne_bytes()),
208 )?;
209 }
210
211 unsafe {
213 let _ = WinHttpCloseHandle(request_raw);
214 }
215
216 Ok(Self {
217 handle,
218 context,
219 send_receiver: send_rx,
220 recv_receiver: recv_rx,
221 })
222 }
223
224 pub fn send(
226 &self,
227 data: &[u8],
228 buffer_type: WINHTTP_WEB_SOCKET_BUFFER_TYPE,
229 ) -> WsSendFuture<'_> {
230 WsSendFuture {
231 ws: self,
232 data: data.to_vec(),
233 buffer_type,
234 initiated: false,
235 }
236 }
237
238 pub fn send_text(&self, text: &str) -> WsSendFuture<'_> {
240 self.send(text.as_bytes(), WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE)
241 }
242
243 pub fn send_binary(&self, data: &[u8]) -> WsSendFuture<'_> {
245 self.send(data, WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE)
246 }
247
248 pub fn receive(&self) -> WsReceiveFuture<'_> {
253 WsReceiveFuture {
254 ws: self,
255 buffer: vec![0u8; 8192],
256 fragments: Vec::new(),
257 is_text: None,
258 initiated: false,
259 }
260 }
261
262 pub fn close(&self, status: u16, reason: &str) -> WsCloseFuture<'_> {
264 WsCloseFuture {
265 ws: self,
266 status,
267 reason: reason.as_bytes().to_vec(),
268 initiated: false,
269 }
270 }
271
272 pub fn into_stream(self) -> WebSocketStream {
277 WebSocketStream {
278 ws: self,
279 buffer: vec![0u8; 8192],
280 fragments: Vec::new(),
281 is_text: None,
282 initiated: false,
283 closed: false,
284 }
285 }
286}
287
/// Future returned by [`AsyncWebSocket::send`] and its text/binary helpers.
///
/// The first poll issues `WinHttpWebSocketSend` over `data`. Later polls wait for the
/// callback's send-direction event.
///
/// NOTE(review): WinHTTP may keep reading `data` until the operation completes. Dropping this
/// future while the send is in flight frees that buffer early — confirm the intended
/// cancellation story.
pub struct WsSendFuture<'ws> {
    // Socket the send is issued on.
    ws: &'ws AsyncWebSocket,
    // Owned copy of the payload handed to WinHTTP.
    data: Vec<u8>,
    // Message/fragment kind passed to `WinHttpWebSocketSend`.
    buffer_type: WINHTTP_WEB_SOCKET_BUFFER_TYPE,
    // Whether `WinHttpWebSocketSend` has already been called.
    initiated: bool,
}
298
299impl Future for WsSendFuture<'_> {
300 type Output = Result<()>;
301
302 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
303 let this = &mut *self;
304 this.ws.context.set_send_waker(cx.waker());
305
306 if !this.initiated {
307 this.initiated = true;
308 let status = unsafe {
309 WinHttpWebSocketSend(this.ws.handle.as_raw(), this.buffer_type, Some(&this.data))
310 };
311 if status == 0 {
312 return Poll::Ready(Ok(()));
314 }
315 if status != ERROR_IO_PENDING_VALUE {
316 return Poll::Ready(Err(Error::from_thread()));
317 }
318 }
320
321 match this.ws.send_receiver.try_recv() {
322 Ok(WsSendEvent::WriteComplete) => Poll::Ready(Ok(())),
323 Ok(WsSendEvent::Error(code)) => Poll::Ready(Err(Error::from(WIN32_ERROR(code)))),
324 Err(crossfire::TryRecvError::Empty) => Poll::Pending,
325 Err(crossfire::TryRecvError::Disconnected) => Poll::Ready(Err(Error::empty())),
326 }
327 }
328}
329
/// Future returned by [`AsyncWebSocket::receive`]; resolves to one complete message.
///
/// NOTE(review): WinHTTP writes into `buffer` asynchronously until READ_COMPLETE fires.
/// Dropping this future while a receive is in flight frees that buffer early — confirm the
/// intended cancellation story.
pub struct WsReceiveFuture<'ws> {
    // Socket the receive is issued on.
    ws: &'ws AsyncWebSocket,
    // Scratch buffer WinHTTP reads each chunk into.
    buffer: Vec<u8>,
    // Bytes from fragment chunks accumulated so far.
    fragments: Vec<u8>,
    // Kind of the first fragment seen (`Some(true)` for UTF-8), if any.
    is_text: Option<bool>,
    // Whether the first `WinHttpWebSocketReceive` has been issued.
    initiated: bool,
}
342
343impl Future for WsReceiveFuture<'_> {
344 type Output = Result<WebSocketMessage>;
345
346 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
347 let this = &mut *self;
348 this.ws.context.set_recv_waker(cx.waker());
349
350 if !this.initiated {
351 this.initiated = true;
352 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
353 return Poll::Ready(Err(e));
354 }
355 }
356
357 loop {
358 match this.ws.recv_receiver.try_recv() {
359 Ok(WsRecvEvent::ReadComplete {
360 bytes_read,
361 buffer_type,
362 }) => {
363 let data = &this.buffer[..bytes_read as usize];
364 match buffer_type {
365 1 => {
367 if this.is_text.is_none() {
368 this.is_text = Some(false);
369 }
370 this.fragments.extend_from_slice(data);
371 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
373 return Poll::Ready(Err(e));
374 }
375 this.ws.context.set_recv_waker(cx.waker());
376 continue;
377 }
378 3 => {
380 if this.is_text.is_none() {
381 this.is_text = Some(true);
382 }
383 this.fragments.extend_from_slice(data);
384 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
385 return Poll::Ready(Err(e));
386 }
387 this.ws.context.set_recv_waker(cx.waker());
388 continue;
389 }
390 0 => {
392 let mut full = std::mem::take(&mut this.fragments);
393 full.extend_from_slice(data);
394 return Poll::Ready(Ok(WebSocketMessage::Binary(full)));
395 }
396 2 => {
398 let mut full = std::mem::take(&mut this.fragments);
399 full.extend_from_slice(data);
400 let text = String::from_utf8_lossy(&full).into_owned();
401 return Poll::Ready(Ok(WebSocketMessage::Text(text)));
402 }
403 4 => {
405 return Poll::Ready(Ok(WebSocketMessage::Close));
406 }
407 _ => {
408 let mut full = std::mem::take(&mut this.fragments);
410 full.extend_from_slice(data);
411 return Poll::Ready(Ok(WebSocketMessage::Binary(full)));
412 }
413 }
414 }
415 Ok(WsRecvEvent::CloseComplete) => {
416 return Poll::Ready(Ok(WebSocketMessage::Close));
417 }
418 Ok(WsRecvEvent::Error(code)) => {
419 return Poll::Ready(Err(Error::from(WIN32_ERROR(code))));
420 }
421 Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
422 Err(crossfire::TryRecvError::Disconnected) => {
423 return Poll::Ready(Err(Error::empty()));
424 }
425 }
426 }
427 }
428}
429
430fn initiate_receive(handle: &WinHttpHandle, buffer: &mut [u8]) -> Result<()> {
432 let mut bytes_read = 0u32;
433 let mut buffer_type = WINHTTP_WEB_SOCKET_BUFFER_TYPE::default();
434
435 let status = unsafe {
436 WinHttpWebSocketReceive(
437 handle.as_raw(),
438 buffer.as_mut_ptr() as *mut _,
439 buffer.len() as u32,
440 &mut bytes_read,
441 &mut buffer_type,
442 )
443 };
444
445 if status != 0 && status != ERROR_IO_PENDING_VALUE {
446 return Err(Error::from_thread());
447 }
448 Ok(())
449}
450
/// Future returned by [`AsyncWebSocket::close`]; drives the close handshake.
///
/// Completion is observed on the receive-direction channel (CLOSE_COMPLETE), so this
/// future registers the receive waker.
/// NOTE(review): that can displace the waker of a concurrently pending receive.
pub struct WsCloseFuture<'ws> {
    // Socket being closed.
    ws: &'ws AsyncWebSocket,
    // Close status code sent to the peer.
    status: u16,
    // UTF-8 close reason bytes, owned so they outlive the WinHTTP call.
    reason: Vec<u8>,
    // Whether `WinHttpWebSocketClose` has already been called.
    initiated: bool,
}
460
461impl Future for WsCloseFuture<'_> {
462 type Output = Result<()>;
463
464 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
465 let this = &mut *self;
466 this.ws.context.set_recv_waker(cx.waker());
467
468 if !this.initiated {
469 this.initiated = true;
470 let reason_ptr = if this.reason.is_empty() {
471 None
472 } else {
473 Some(this.reason.as_ptr() as *const c_void)
474 };
475 let status = unsafe {
476 WinHttpWebSocketClose(
477 this.ws.handle.as_raw(),
478 this.status,
479 reason_ptr,
480 this.reason.len() as u32,
481 )
482 };
483 if status == 0 {
484 return Poll::Ready(Ok(()));
485 }
486 if status != ERROR_IO_PENDING_VALUE {
487 return Poll::Ready(Err(Error::from_thread()));
488 }
489 }
490
491 loop {
492 match this.ws.recv_receiver.try_recv() {
493 Ok(WsRecvEvent::CloseComplete) => return Poll::Ready(Ok(())),
494 Ok(WsRecvEvent::Error(code)) => {
495 return Poll::Ready(Err(Error::from(WIN32_ERROR(code))));
496 }
497 Ok(_) => continue,
498 Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
499 Err(crossfire::TryRecvError::Disconnected) => {
500 return Poll::Ready(Err(Error::empty()));
501 }
502 }
503 }
504 }
505}
506
/// `Stream` of incoming messages produced by [`AsyncWebSocket::into_stream`].
///
/// Yields `None` after a close frame, CLOSE_COMPLETE, or channel disconnect. It stops
/// (sets `closed`) after yielding any error.
pub struct WebSocketStream {
    // Owned socket the stream reads from.
    ws: AsyncWebSocket,
    // Scratch buffer WinHTTP reads each chunk into; reused across messages.
    buffer: Vec<u8>,
    // Bytes from fragment chunks of the message in progress.
    fragments: Vec<u8>,
    // Kind of the first fragment of the message in progress, if any.
    is_text: Option<bool>,
    // Whether a receive is outstanding for the message in progress.
    initiated: bool,
    // Set once the stream has terminated; later polls return `None`.
    closed: bool,
}
521
522impl futures_core::Stream for WebSocketStream {
523 type Item = Result<WebSocketMessage>;
524
525 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
526 let this = &mut *self;
527
528 if this.closed {
529 return Poll::Ready(None);
530 }
531
532 this.ws.context.set_recv_waker(cx.waker());
533
534 if !this.initiated {
535 this.initiated = true;
536 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
537 this.closed = true;
538 return Poll::Ready(Some(Err(e)));
539 }
540 }
541
542 loop {
543 match this.ws.recv_receiver.try_recv() {
544 Ok(WsRecvEvent::ReadComplete {
545 bytes_read,
546 buffer_type,
547 }) => {
548 let data = &this.buffer[..bytes_read as usize];
549 match buffer_type {
550 1 => {
552 if this.is_text.is_none() {
553 this.is_text = Some(false);
554 }
555 this.fragments.extend_from_slice(data);
556 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
557 this.closed = true;
558 return Poll::Ready(Some(Err(e)));
559 }
560 this.ws.context.set_recv_waker(cx.waker());
561 continue;
562 }
563 3 => {
565 if this.is_text.is_none() {
566 this.is_text = Some(true);
567 }
568 this.fragments.extend_from_slice(data);
569 if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
570 this.closed = true;
571 return Poll::Ready(Some(Err(e)));
572 }
573 this.ws.context.set_recv_waker(cx.waker());
574 continue;
575 }
576 0 => {
578 let mut full = std::mem::take(&mut this.fragments);
579 full.extend_from_slice(data);
580 this.is_text = None;
581 this.initiated = false;
583 return Poll::Ready(Some(Ok(WebSocketMessage::Binary(full))));
584 }
585 2 => {
587 let mut full = std::mem::take(&mut this.fragments);
588 full.extend_from_slice(data);
589 let text = String::from_utf8_lossy(&full).into_owned();
590 this.is_text = None;
591 this.initiated = false;
592 return Poll::Ready(Some(Ok(WebSocketMessage::Text(text))));
593 }
594 4 => {
596 this.closed = true;
597 return Poll::Ready(None);
598 }
599 _ => {
600 let mut full = std::mem::take(&mut this.fragments);
601 full.extend_from_slice(data);
602 this.is_text = None;
603 this.initiated = false;
604 return Poll::Ready(Some(Ok(WebSocketMessage::Binary(full))));
605 }
606 }
607 }
608 Ok(WsRecvEvent::CloseComplete) => {
609 this.closed = true;
610 return Poll::Ready(None);
611 }
612 Ok(WsRecvEvent::Error(code)) => {
613 this.closed = true;
614 return Poll::Ready(Some(Err(Error::from(WIN32_ERROR(code)))));
615 }
616 Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
617 Err(crossfire::TryRecvError::Disconnected) => {
618 this.closed = true;
619 return Poll::Ready(None);
620 }
621 }
622 }
623 }
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text message from a string slice.
    fn text(s: &str) -> WebSocketMessage {
        WebSocketMessage::Text(s.to_owned())
    }

    #[test]
    fn websocket_message_text_equality() {
        let hello = text("hello");
        assert_eq!(hello, text("hello"));
        assert_ne!(hello, text("world"));
    }

    #[test]
    fn websocket_message_binary_equality() {
        let bytes = WebSocketMessage::Binary([1u8, 2, 3].to_vec());
        assert_eq!(bytes, WebSocketMessage::Binary(vec![1, 2, 3]));
        assert_ne!(bytes, WebSocketMessage::Binary(vec![4, 5, 6]));
    }

    #[test]
    fn websocket_message_close() {
        assert_eq!(WebSocketMessage::Close, WebSocketMessage::Close);
        assert_ne!(WebSocketMessage::Close, text(""));
    }

    #[test]
    fn websocket_message_debug_format() {
        let rendered = format!("{:?}", text("hi"));
        assert!(rendered.contains("Text") && rendered.contains("hi"));

        let rendered = format!("{:?}", WebSocketMessage::Binary(vec![0xDE, 0xAD]));
        assert!(rendered.contains("Binary"));

        let rendered = format!("{:?}", WebSocketMessage::Close);
        assert!(rendered.contains("Close"));
    }

    #[test]
    fn websocket_message_clone() {
        for original in [text("test"), WebSocketMessage::Binary(vec![1, 2, 3])] {
            assert_eq!(original.clone(), original);
        }
    }

    #[test]
    fn websocket_message_variants_are_distinct() {
        let all = [
            text(""),
            WebSocketMessage::Binary(Vec::new()),
            WebSocketMessage::Close,
        ];
        for (i, a) in all.iter().enumerate() {
            for b in &all[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }
}