Skip to main content

winhttp/
async_websocket.rs

1use crate::async_request::AsyncResponse;
2use crate::handle::WinHttpHandle;
3use crossfire::mpsc;
4use parking_lot::Mutex;
5use std::ffi::c_void;
6use std::future::Future;
7use std::mem::ManuallyDrop;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll, Waker};
11use windows::Win32::Foundation::WIN32_ERROR;
12use windows::Win32::Networking::WinHttp::*;
13use windows::core::{Error, Result};
14
/// Win32 `ERROR_IO_PENDING` (997 / `0x3E5`): an overlapped operation was
/// started and will finish later.
///
/// Compared against the raw `u32` status codes returned by the
/// `WinHttpWebSocket*` functions used in this module.
const ERROR_IO_PENDING_VALUE: u32 = 0x3E5;
18
/// Send-side completion notification, produced by the WinHTTP status
/// callback and consumed by [`WsSendFuture`].
#[derive(Debug)]
enum WsSendEvent {
    /// A send (or a shutdown) finished successfully.
    WriteComplete,
    /// A send-side operation failed with this Win32 error code.
    Error(u32),
}
24
/// Receive-side completion notification (reads and the close handshake),
/// produced by the WinHTTP status callback.
#[derive(Debug)]
enum WsRecvEvent {
    /// A receive finished: `bytes_read` bytes were written into the buffer
    /// passed to `WinHttpWebSocketReceive`, and `buffer_type` is the raw
    /// `WINHTTP_WEB_SOCKET_BUFFER_TYPE` value reported for them.
    ReadComplete { bytes_read: u32, buffer_type: i32 },
    /// `WinHttpWebSocketClose` finished.
    CloseComplete,
    /// A receive-side operation failed with this Win32 error code.
    Error(u32),
}
31
32type SendSender = crossfire::MTx<mpsc::List<WsSendEvent>>;
33type SendReceiver = crossfire::Rx<mpsc::List<WsSendEvent>>;
34type RecvSender = crossfire::MTx<mpsc::List<WsRecvEvent>>;
35type RecvReceiver = crossfire::Rx<mpsc::List<WsRecvEvent>>;
36
37struct WebSocketContext {
38    send_waker: Mutex<Option<Waker>>,
39    recv_waker: Mutex<Option<Waker>>,
40    send_sender: SendSender,
41    recv_sender: RecvSender,
42}
43
44impl WebSocketContext {
45    fn new(send_sender: SendSender, recv_sender: RecvSender) -> Pin<Arc<Self>> {
46        Arc::pin(Self {
47            send_waker: Mutex::new(None),
48            recv_waker: Mutex::new(None),
49            send_sender,
50            recv_sender,
51        })
52    }
53
54    fn wake_send(&self) {
55        if let Some(waker) = self.send_waker.lock().take() {
56            waker.wake();
57        }
58    }
59
60    fn wake_recv(&self) {
61        if let Some(waker) = self.recv_waker.lock().take() {
62            waker.wake();
63        }
64    }
65
66    fn set_send_waker(&self, waker: &Waker) {
67        let mut guard = self.send_waker.lock();
68        match guard.as_ref() {
69            Some(existing) if existing.will_wake(waker) => {}
70            _ => *guard = Some(waker.clone()),
71        }
72    }
73
74    fn set_recv_waker(&self, waker: &Waker) {
75        let mut guard = self.recv_waker.lock();
76        match guard.as_ref() {
77            Some(existing) if existing.will_wake(waker) => {}
78            _ => *guard = Some(waker.clone()),
79        }
80    }
81}
82
83/// # Safety
84///
85/// This function is called by WinHTTP on its own thread pool. The `context`
86/// parameter is a raw pointer to a pinned `WebSocketContext` whose lifetime is
87/// guaranteed by the owning `AsyncWebSocket`.
88unsafe extern "system" fn async_websocket_callback(
89    _hinternet: *mut c_void,
90    context: usize,
91    status: u32,
92    status_info: *mut c_void,
93    _status_info_length: u32,
94) {
95    let Some(ctx) = (unsafe { (context as *const WebSocketContext).as_ref() }) else {
96        return;
97    };
98
99    match status {
100        WINHTTP_CALLBACK_STATUS_WRITE_COMPLETE => {
101            let _ = ctx.send_sender.send(WsSendEvent::WriteComplete);
102            ctx.wake_send();
103        }
104        WINHTTP_CALLBACK_STATUS_READ_COMPLETE => {
105            if !status_info.is_null() {
106                let ws_status = unsafe { &*(status_info as *const WINHTTP_WEB_SOCKET_STATUS) };
107                let _ = ctx.recv_sender.send(WsRecvEvent::ReadComplete {
108                    bytes_read: ws_status.dwBytesTransferred,
109                    buffer_type: ws_status.eBufferType.0,
110                });
111            }
112            ctx.wake_recv();
113        }
114        WINHTTP_CALLBACK_STATUS_CLOSE_COMPLETE => {
115            let _ = ctx.recv_sender.send(WsRecvEvent::CloseComplete);
116            ctx.wake_recv();
117        }
118        WINHTTP_CALLBACK_STATUS_SHUTDOWN_COMPLETE => {
119            let _ = ctx.send_sender.send(WsSendEvent::WriteComplete);
120            ctx.wake_send();
121        }
122        WINHTTP_CALLBACK_STATUS_REQUEST_ERROR => {
123            if !status_info.is_null() {
124                let ws_err = unsafe { &*(status_info as *const WINHTTP_WEB_SOCKET_ASYNC_RESULT) };
125                let error_code = ws_err.AsyncResult.dwError;
126                // Operation.0: 0 = Send, else = Receive/Close/Shutdown
127                if ws_err.Operation.0 == 0 {
128                    let _ = ctx.send_sender.send(WsSendEvent::Error(error_code));
129                    ctx.wake_send();
130                } else {
131                    let _ = ctx.recv_sender.send(WsRecvEvent::Error(error_code));
132                    ctx.wake_recv();
133                }
134            }
135        }
136        _ => {}
137    }
138}
139
/// One fully reassembled WebSocket message received from the server.
///
/// Fragmented frames are joined before a value of this type is produced.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WebSocketMessage {
    /// Text payload. The receivers in this module decode it with
    /// `String::from_utf8_lossy`, so invalid sequences become U+FFFD.
    Text(String),
    /// Binary payload, as raw bytes.
    Binary(Vec<u8>),
    /// The peer sent a close frame; the connection is shutting down.
    Close,
}
150
151/// An async WebSocket connection built on top of WinHTTP's async machinery.
152///
153/// Two independent crossfire channels (send and receive) allow concurrent
154/// `send` + `receive` operations without contention.
155///
156/// Created from an [`AsyncResponse`] whose HTTP 101 upgrade has completed.
157pub struct AsyncWebSocket {
158    handle: WinHttpHandle,
159    context: Pin<Arc<WebSocketContext>>,
160    send_receiver: SendReceiver,
161    recv_receiver: RecvReceiver,
162}
163
164impl AsyncWebSocket {
165    /// Upgrade an HTTP 101 response into an async WebSocket connection.
166    ///
167    /// Internally calls `WinHttpWebSocketCompleteUpgrade`, closes the
168    /// original request handle, then installs a new callback and context
169    /// on the WebSocket handle.
170    pub fn from_response(response: AsyncResponse<'_>) -> Result<Self> {
171        let request = response.into_request();
172
173        // Prevent `Request`'s Drop from closing the handle — we need it for
174        // the upgrade call and will close it ourselves afterward.
175        let request = ManuallyDrop::new(request);
176        let request_raw = request.handle.as_raw();
177
178        let ws_raw = unsafe { WinHttpWebSocketCompleteUpgrade(request_raw, None) };
179        if ws_raw.is_null() {
180            return Err(Error::from_thread());
181        }
182
183        let handle = unsafe { WinHttpHandle::from_raw(ws_raw) }
184            .expect("WinHttpWebSocketCompleteUpgrade returned non-null");
185
186        // Create two independent channels.
187        let (send_tx, send_rx) = mpsc::unbounded_blocking();
188        let (recv_tx, recv_rx) = mpsc::unbounded_blocking();
189        let context = WebSocketContext::new(send_tx, recv_tx);
190
191        // Install our callback on the WebSocket handle.
192        unsafe {
193            WinHttpSetStatusCallback(
194                handle.as_raw(),
195                Some(async_websocket_callback),
196                WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS,
197                0,
198            )
199        };
200
201        // Set the context pointer so the callback can reach our channels.
202        let ctx_ptr: usize = &*context as *const WebSocketContext as usize;
203        unsafe {
204            WinHttpSetOption(
205                Some(handle.as_raw()),
206                WINHTTP_OPTION_CONTEXT_VALUE,
207                Some(&ctx_ptr.to_ne_bytes()),
208            )?;
209        }
210
211        // Now safe to close the original request handle.
212        unsafe {
213            let _ = WinHttpCloseHandle(request_raw);
214        }
215
216        Ok(Self {
217            handle,
218            context,
219            send_receiver: send_rx,
220            recv_receiver: recv_rx,
221        })
222    }
223
224    /// Send raw data with the given buffer type.
225    pub fn send(
226        &self,
227        data: &[u8],
228        buffer_type: WINHTTP_WEB_SOCKET_BUFFER_TYPE,
229    ) -> WsSendFuture<'_> {
230        WsSendFuture {
231            ws: self,
232            data: data.to_vec(),
233            buffer_type,
234            initiated: false,
235        }
236    }
237
238    /// Send a UTF-8 text message.
239    pub fn send_text(&self, text: &str) -> WsSendFuture<'_> {
240        self.send(text.as_bytes(), WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE)
241    }
242
243    /// Send a binary message.
244    pub fn send_binary(&self, data: &[u8]) -> WsSendFuture<'_> {
245        self.send(data, WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE)
246    }
247
248    /// Receive a single complete [`WebSocketMessage`].
249    ///
250    /// Fragments are reassembled automatically — the future resolves only
251    /// when a complete text, binary, or close message has been collected.
252    pub fn receive(&self) -> WsReceiveFuture<'_> {
253        WsReceiveFuture {
254            ws: self,
255            buffer: vec![0u8; 8192],
256            fragments: Vec::new(),
257            is_text: None,
258            initiated: false,
259        }
260    }
261
262    /// Initiate a graceful close of the WebSocket connection.
263    pub fn close(&self, status: u16, reason: &str) -> WsCloseFuture<'_> {
264        WsCloseFuture {
265            ws: self,
266            status,
267            reason: reason.as_bytes().to_vec(),
268            initiated: false,
269        }
270    }
271
272    // Stream adapter
273
274    /// Consume the `AsyncWebSocket` and return a [`WebSocketStream`] that
275    /// implements [`futures_core::Stream`].
276    pub fn into_stream(self) -> WebSocketStream {
277        WebSocketStream {
278            ws: self,
279            buffer: vec![0u8; 8192],
280            fragments: Vec::new(),
281            is_text: None,
282            initiated: false,
283            closed: false,
284        }
285    }
286}
287
288// WsSendFuture
289
290/// Future returned by [`AsyncWebSocket::send`], [`send_text`](AsyncWebSocket::send_text),
291/// and [`send_binary`](AsyncWebSocket::send_binary).
292pub struct WsSendFuture<'ws> {
293    ws: &'ws AsyncWebSocket,
294    data: Vec<u8>,
295    buffer_type: WINHTTP_WEB_SOCKET_BUFFER_TYPE,
296    initiated: bool,
297}
298
299impl Future for WsSendFuture<'_> {
300    type Output = Result<()>;
301
302    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
303        let this = &mut *self;
304        this.ws.context.set_send_waker(cx.waker());
305
306        if !this.initiated {
307            this.initiated = true;
308            let status = unsafe {
309                WinHttpWebSocketSend(this.ws.handle.as_raw(), this.buffer_type, Some(&this.data))
310            };
311            if status == 0 {
312                // Synchronous completion (unusual on async sessions).
313                return Poll::Ready(Ok(()));
314            }
315            if status != ERROR_IO_PENDING_VALUE {
316                return Poll::Ready(Err(Error::from_thread()));
317            }
318            // ERROR_IO_PENDING — fall through to drain the channel.
319        }
320
321        match this.ws.send_receiver.try_recv() {
322            Ok(WsSendEvent::WriteComplete) => Poll::Ready(Ok(())),
323            Ok(WsSendEvent::Error(code)) => Poll::Ready(Err(Error::from(WIN32_ERROR(code)))),
324            Err(crossfire::TryRecvError::Empty) => Poll::Pending,
325            Err(crossfire::TryRecvError::Disconnected) => Poll::Ready(Err(Error::empty())),
326        }
327    }
328}
329
330// WsReceiveFuture
331
332/// Future returned by [`AsyncWebSocket::receive`].
333///
334/// Automatically reassembles fragments into a single complete message.
335pub struct WsReceiveFuture<'ws> {
336    ws: &'ws AsyncWebSocket,
337    buffer: Vec<u8>,
338    fragments: Vec<u8>,
339    is_text: Option<bool>,
340    initiated: bool,
341}
342
343impl Future for WsReceiveFuture<'_> {
344    type Output = Result<WebSocketMessage>;
345
346    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
347        let this = &mut *self;
348        this.ws.context.set_recv_waker(cx.waker());
349
350        if !this.initiated {
351            this.initiated = true;
352            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
353                return Poll::Ready(Err(e));
354            }
355        }
356
357        loop {
358            match this.ws.recv_receiver.try_recv() {
359                Ok(WsRecvEvent::ReadComplete {
360                    bytes_read,
361                    buffer_type,
362                }) => {
363                    let data = &this.buffer[..bytes_read as usize];
364                    match buffer_type {
365                        // BinaryFragment
366                        1 => {
367                            if this.is_text.is_none() {
368                                this.is_text = Some(false);
369                            }
370                            this.fragments.extend_from_slice(data);
371                            // Initiate another receive for the next fragment.
372                            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
373                                return Poll::Ready(Err(e));
374                            }
375                            this.ws.context.set_recv_waker(cx.waker());
376                            continue;
377                        }
378                        // Utf8Fragment
379                        3 => {
380                            if this.is_text.is_none() {
381                                this.is_text = Some(true);
382                            }
383                            this.fragments.extend_from_slice(data);
384                            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
385                                return Poll::Ready(Err(e));
386                            }
387                            this.ws.context.set_recv_waker(cx.waker());
388                            continue;
389                        }
390                        // BinaryMessage (complete or final fragment)
391                        0 => {
392                            let mut full = std::mem::take(&mut this.fragments);
393                            full.extend_from_slice(data);
394                            return Poll::Ready(Ok(WebSocketMessage::Binary(full)));
395                        }
396                        // Utf8Message (complete or final fragment)
397                        2 => {
398                            let mut full = std::mem::take(&mut this.fragments);
399                            full.extend_from_slice(data);
400                            let text = String::from_utf8_lossy(&full).into_owned();
401                            return Poll::Ready(Ok(WebSocketMessage::Text(text)));
402                        }
403                        // Close
404                        4 => {
405                            return Poll::Ready(Ok(WebSocketMessage::Close));
406                        }
407                        _ => {
408                            // Unknown buffer type — treat as binary.
409                            let mut full = std::mem::take(&mut this.fragments);
410                            full.extend_from_slice(data);
411                            return Poll::Ready(Ok(WebSocketMessage::Binary(full)));
412                        }
413                    }
414                }
415                Ok(WsRecvEvent::CloseComplete) => {
416                    return Poll::Ready(Ok(WebSocketMessage::Close));
417                }
418                Ok(WsRecvEvent::Error(code)) => {
419                    return Poll::Ready(Err(Error::from(WIN32_ERROR(code))));
420                }
421                Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
422                Err(crossfire::TryRecvError::Disconnected) => {
423                    return Poll::Ready(Err(Error::empty()));
424                }
425            }
426        }
427    }
428}
429
430/// Kick off a single `WinHttpWebSocketReceive` call.
431fn initiate_receive(handle: &WinHttpHandle, buffer: &mut [u8]) -> Result<()> {
432    let mut bytes_read = 0u32;
433    let mut buffer_type = WINHTTP_WEB_SOCKET_BUFFER_TYPE::default();
434
435    let status = unsafe {
436        WinHttpWebSocketReceive(
437            handle.as_raw(),
438            buffer.as_mut_ptr() as *mut _,
439            buffer.len() as u32,
440            &mut bytes_read,
441            &mut buffer_type,
442        )
443    };
444
445    if status != 0 && status != ERROR_IO_PENDING_VALUE {
446        return Err(Error::from_thread());
447    }
448    Ok(())
449}
450
451// WsCloseFuture
452
453/// Future returned by [`AsyncWebSocket::close`].
454pub struct WsCloseFuture<'ws> {
455    ws: &'ws AsyncWebSocket,
456    status: u16,
457    reason: Vec<u8>,
458    initiated: bool,
459}
460
461impl Future for WsCloseFuture<'_> {
462    type Output = Result<()>;
463
464    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
465        let this = &mut *self;
466        this.ws.context.set_recv_waker(cx.waker());
467
468        if !this.initiated {
469            this.initiated = true;
470            let reason_ptr = if this.reason.is_empty() {
471                None
472            } else {
473                Some(this.reason.as_ptr() as *const c_void)
474            };
475            let status = unsafe {
476                WinHttpWebSocketClose(
477                    this.ws.handle.as_raw(),
478                    this.status,
479                    reason_ptr,
480                    this.reason.len() as u32,
481                )
482            };
483            if status == 0 {
484                return Poll::Ready(Ok(()));
485            }
486            if status != ERROR_IO_PENDING_VALUE {
487                return Poll::Ready(Err(Error::from_thread()));
488            }
489        }
490
491        loop {
492            match this.ws.recv_receiver.try_recv() {
493                Ok(WsRecvEvent::CloseComplete) => return Poll::Ready(Ok(())),
494                Ok(WsRecvEvent::Error(code)) => {
495                    return Poll::Ready(Err(Error::from(WIN32_ERROR(code))));
496                }
497                Ok(_) => continue,
498                Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
499                Err(crossfire::TryRecvError::Disconnected) => {
500                    return Poll::Ready(Err(Error::empty()));
501                }
502            }
503        }
504    }
505}
506
507// WebSocketStream — futures_core::Stream adapter
508
509/// A [`futures_core::Stream`] of [`WebSocketMessage`]s.
510///
511/// Created by [`AsyncWebSocket::into_stream`]. Yields `Some(Ok(msg))` for
512/// each received message and `None` when the connection is closed.
513pub struct WebSocketStream {
514    ws: AsyncWebSocket,
515    buffer: Vec<u8>,
516    fragments: Vec<u8>,
517    is_text: Option<bool>,
518    initiated: bool,
519    closed: bool,
520}
521
522impl futures_core::Stream for WebSocketStream {
523    type Item = Result<WebSocketMessage>;
524
525    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
526        let this = &mut *self;
527
528        if this.closed {
529            return Poll::Ready(None);
530        }
531
532        this.ws.context.set_recv_waker(cx.waker());
533
534        if !this.initiated {
535            this.initiated = true;
536            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
537                this.closed = true;
538                return Poll::Ready(Some(Err(e)));
539            }
540        }
541
542        loop {
543            match this.ws.recv_receiver.try_recv() {
544                Ok(WsRecvEvent::ReadComplete {
545                    bytes_read,
546                    buffer_type,
547                }) => {
548                    let data = &this.buffer[..bytes_read as usize];
549                    match buffer_type {
550                        // BinaryFragment
551                        1 => {
552                            if this.is_text.is_none() {
553                                this.is_text = Some(false);
554                            }
555                            this.fragments.extend_from_slice(data);
556                            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
557                                this.closed = true;
558                                return Poll::Ready(Some(Err(e)));
559                            }
560                            this.ws.context.set_recv_waker(cx.waker());
561                            continue;
562                        }
563                        // Utf8Fragment
564                        3 => {
565                            if this.is_text.is_none() {
566                                this.is_text = Some(true);
567                            }
568                            this.fragments.extend_from_slice(data);
569                            if let Err(e) = initiate_receive(&this.ws.handle, &mut this.buffer) {
570                                this.closed = true;
571                                return Poll::Ready(Some(Err(e)));
572                            }
573                            this.ws.context.set_recv_waker(cx.waker());
574                            continue;
575                        }
576                        // BinaryMessage (complete or final fragment)
577                        0 => {
578                            let mut full = std::mem::take(&mut this.fragments);
579                            full.extend_from_slice(data);
580                            this.is_text = None;
581                            // Initiate next receive for the stream.
582                            this.initiated = false;
583                            return Poll::Ready(Some(Ok(WebSocketMessage::Binary(full))));
584                        }
585                        // Utf8Message (complete or final fragment)
586                        2 => {
587                            let mut full = std::mem::take(&mut this.fragments);
588                            full.extend_from_slice(data);
589                            let text = String::from_utf8_lossy(&full).into_owned();
590                            this.is_text = None;
591                            this.initiated = false;
592                            return Poll::Ready(Some(Ok(WebSocketMessage::Text(text))));
593                        }
594                        // Close
595                        4 => {
596                            this.closed = true;
597                            return Poll::Ready(None);
598                        }
599                        _ => {
600                            let mut full = std::mem::take(&mut this.fragments);
601                            full.extend_from_slice(data);
602                            this.is_text = None;
603                            this.initiated = false;
604                            return Poll::Ready(Some(Ok(WebSocketMessage::Binary(full))));
605                        }
606                    }
607                }
608                Ok(WsRecvEvent::CloseComplete) => {
609                    this.closed = true;
610                    return Poll::Ready(None);
611                }
612                Ok(WsRecvEvent::Error(code)) => {
613                    this.closed = true;
614                    return Poll::Ready(Some(Err(Error::from(WIN32_ERROR(code)))));
615                }
616                Err(crossfire::TryRecvError::Empty) => return Poll::Pending,
617                Err(crossfire::TryRecvError::Disconnected) => {
618                    this.closed = true;
619                    return Poll::Ready(None);
620                }
621            }
622        }
623    }
624}
625
// Unit tests

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a text message.
    fn text(s: &str) -> WebSocketMessage {
        WebSocketMessage::Text(s.to_owned())
    }

    #[test]
    fn websocket_message_text_equality() {
        let hello = text("hello");
        assert_eq!(hello, text("hello"));
        assert_ne!(hello, text("world"));
    }

    #[test]
    fn websocket_message_binary_equality() {
        let bytes = WebSocketMessage::Binary(vec![1, 2, 3]);
        assert_eq!(bytes, WebSocketMessage::Binary(vec![1, 2, 3]));
        assert_ne!(bytes, WebSocketMessage::Binary(vec![4, 5, 6]));
    }

    #[test]
    fn websocket_message_close() {
        assert_eq!(WebSocketMessage::Close, WebSocketMessage::Close);
        assert_ne!(WebSocketMessage::Close, text(""));
    }

    #[test]
    fn websocket_message_debug_format() {
        let cases: [(WebSocketMessage, &[&str]); 3] = [
            (text("hi"), &["Text", "hi"]),
            (WebSocketMessage::Binary(vec![0xDE, 0xAD]), &["Binary"]),
            (WebSocketMessage::Close, &["Close"]),
        ];
        for (msg, needles) in cases.iter() {
            let rendered = format!("{msg:?}");
            for needle in needles.iter() {
                assert!(rendered.contains(*needle), "{rendered:?} lacks {needle:?}");
            }
        }
    }

    #[test]
    fn websocket_message_clone() {
        for original in [text("test"), WebSocketMessage::Binary(vec![1, 2, 3])].iter() {
            assert_eq!(original, &original.clone());
        }
    }

    #[test]
    fn websocket_message_variants_are_distinct() {
        let all = [
            text(""),
            WebSocketMessage::Binary(Vec::new()),
            WebSocketMessage::Close,
        ];
        for (i, a) in all.iter().enumerate() {
            for b in all.iter().skip(i + 1) {
                assert_ne!(a, b);
            }
        }
    }
}