rustapi_ws/extractor.rs
1//! WebSocket extractor
2
3use crate::upgrade::{validate_upgrade_request, WebSocketUpgrade};
4use hyper::upgrade::OnUpgrade;
5use rustapi_core::{ApiError, FromRequest, Request, Result};
6use rustapi_openapi::{Operation, OperationModifier};
7
8/// WebSocket extractor for upgrading HTTP connections to WebSocket
9///
10/// Use this extractor in your handler to initiate a WebSocket upgrade.
11/// The extractor validates the upgrade request and returns a `WebSocket`
12/// that can be used to set up the connection handler.
13///
14/// # Example
15///
16/// ```rust,ignore
17/// use rustapi_ws::{WebSocket, Message};
18///
19/// async fn ws_handler(ws: WebSocket) -> impl IntoResponse {
20/// ws.on_upgrade(|socket| async move {
21/// let (mut sender, mut receiver) = socket.split();
22///
23/// while let Some(Ok(msg)) = receiver.next().await {
24/// match msg {
25/// Message::Text(text) => {
26/// // Echo back
27/// let _ = sender.send(Message::text(format!("Echo: {}", text))).await;
28/// }
29/// Message::Close(_) => break,
30/// _ => {}
31/// }
32/// }
33/// })
34/// }
35/// ```
36pub struct WebSocket {
37 sec_key: String,
38 protocols: Vec<String>,
39 extensions: Option<String>,
40 on_upgrade: Option<OnUpgrade>,
41}
42
43impl WebSocket {
44 /// Create a WebSocket upgrade response with a handler
45 ///
46 /// The provided callback will be called with the established WebSocket
47 /// stream once the upgrade is complete.
48 pub fn on_upgrade<F, Fut>(mut self, callback: F) -> WebSocketUpgrade
49 where
50 F: FnOnce(crate::WebSocketStream) -> Fut + Send + 'static,
51 Fut: std::future::Future<Output = ()> + Send + 'static,
52 {
53 let upgrade = WebSocketUpgrade::new(self.sec_key, self.extensions, self.on_upgrade.take());
54
55 // If protocols were requested, select the first one
56 let upgrade = if let Some(protocol) = self.protocols.first() {
57 upgrade.protocol(protocol)
58 } else {
59 upgrade
60 };
61
62 upgrade.on_upgrade(callback)
63 }
64
65 /// Get the requested protocols
66 pub fn protocols(&self) -> &[String] {
67 &self.protocols
68 }
69
70 /// Check if a specific protocol was requested
71 pub fn has_protocol(&self, protocol: &str) -> bool {
72 self.protocols.iter().any(|p| p == protocol)
73 }
74}
75
76impl FromRequest for WebSocket {
77 async fn from_request(req: &mut Request) -> Result<Self> {
78 let headers = req.headers();
79 let method = req.method();
80
81 // Validate the upgrade request
82 // Note: we clone sec_key to avoid keeping borrow of headers
83 let sec_key = validate_upgrade_request(method, headers)
84 .map_err(ApiError::from)?
85 .to_string();
86
87 // Parse requested protocols
88 let protocols = headers
89 .get("Sec-WebSocket-Protocol")
90 .and_then(|v| v.to_str().ok())
91 .map(|s| s.split(',').map(|p| p.trim().to_string()).collect())
92 .unwrap_or_default();
93
94 // Get extensions
95 let extensions = headers
96 .get("Sec-WebSocket-Extensions")
97 .and_then(|v| v.to_str().ok())
98 .map(|s| s.to_string());
99
100 // Capture OnUpgrade future
101 let on_upgrade = req.extensions_mut().remove::<OnUpgrade>();
102
103 // IMPORTANT: Consume the request body to ensure hyper allows the upgrade.
104 if let Some(stream) = req.take_stream() {
105 use http_body_util::BodyExt;
106 let _ = stream.collect().await;
107 }
108
109 Ok(Self {
110 sec_key,
111 protocols,
112 extensions,
113 on_upgrade,
114 })
115 }
116}
117
118impl OperationModifier for WebSocket {
119 fn update_operation(_op: &mut Operation) {
120 // WebSocket endpoints don't have regular request body parameters
121 // The upgrade is indicated by the response
122 }
123}