spacetimedb_client_api/util/
websocket.rs1use axum::extract::FromRequestParts;
4use axum::response::{IntoResponse, Response};
5use axum_extra::TypedHeader;
6use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Upgrade};
7use http::{HeaderName, HeaderValue, Method, StatusCode};
8use hyper::upgrade::{OnUpgrade, Upgraded};
9use hyper_util::rt::TokioIo;
10
11use super::flat_csv::FlatCsv;
12
13pub use tokio_tungstenite::tungstenite;
14pub use tungstenite::{
15 error::Error as WsError,
16 protocol::{frame::coding::CloseCode, CloseFrame, Message, WebSocketConfig},
17};
18
19pub type WebSocketStream = tokio_tungstenite::WebSocketStream<TokioIo<Upgraded>>;
20
21pub struct RequestSecWebsocketProtocol(FlatCsv);
22
23impl headers::Header for RequestSecWebsocketProtocol {
24 fn name() -> &'static HeaderName {
25 &http::header::SEC_WEBSOCKET_PROTOCOL
26 }
27 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, headers::Error> {
28 Ok(Self(values.collect()))
29 }
30 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
31 values.extend([self.0.value.clone()])
32 }
33}
34
35impl RequestSecWebsocketProtocol {
36 pub fn iter(&self) -> impl Iterator<Item = &str> {
37 self.0.iter()
38 }
39
40 pub fn select<S, P>(&self, protocols: impl IntoIterator<Item = (S, P)>) -> Option<(ResponseSecWebsocketProtocol, P)>
41 where
42 S: for<'a> PartialEq<&'a str> + TryInto<HeaderValue>,
43 {
44 protocols
45 .into_iter()
46 .find(|(protoname, _)| self.iter().any(|x| *protoname == x))
47 .map(|(protoname, proto)| {
48 let proto_header = protoname.try_into().unwrap_or_else(|_| unreachable!());
49 (ResponseSecWebsocketProtocol(proto_header), proto)
50 })
51 }
52}
53
54pub struct ResponseSecWebsocketProtocol(pub HeaderValue);
55
56impl headers::Header for ResponseSecWebsocketProtocol {
57 fn name() -> &'static HeaderName {
58 &http::header::SEC_WEBSOCKET_PROTOCOL
59 }
60 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, headers::Error> {
61 values.next().cloned().map(Self).ok_or_else(headers::Error::invalid)
62 }
63 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
64 values.extend([self.0.clone()])
65 }
66}
67
68pub struct WebSocketUpgrade {
69 key: SecWebsocketKey,
70 requested_protocol: Option<RequestSecWebsocketProtocol>,
71 upgrade: OnUpgrade,
72}
73
/// Reasons a request was rejected as a websocket handshake.
///
/// Each variant is turned into an HTTP error response by its `IntoResponse` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketUpgradeRejection {
    /// The request method was not `GET`.
    MethodNotGet,
    /// The request did not ask to upgrade to `websocket`, or the connection
    /// offered no upgrade handle.
    BadUpgrade,
    /// `Sec-WebSocket-Version` was missing, unparsable, or not `13`.
    BadVersion,
    /// `Sec-WebSocket-Key` was missing or could not be parsed.
    KeyMissing,
}
80
81#[async_trait::async_trait]
82impl<S> FromRequestParts<S> for WebSocketUpgrade {
83 type Rejection = WebSocketUpgradeRejection;
84 async fn from_request_parts(parts: &mut http::request::Parts, _state: &S) -> Result<Self, Self::Rejection> {
85 use WebSocketUpgradeRejection::*;
86
87 if parts.method != Method::GET {
88 return Err(MethodNotGet);
89 }
90
91 let upgrade = parts
92 .extensions
93 .remove::<OnUpgrade>()
94 .filter(|_| {
95 parts
96 .headers
97 .typed_get::<Connection>()
98 .is_some_and(|conn| conn.contains("upgrade"))
99 && parts.headers.typed_get::<Upgrade>() == Some(Upgrade::websocket())
100 })
101 .ok_or(BadUpgrade)?;
102
103 if parts.headers.typed_get::<SecWebsocketVersion>() != Some(SecWebsocketVersion::V13) {
104 return Err(BadVersion);
105 }
106
107 let key = parts.headers.typed_get::<SecWebsocketKey>().ok_or(KeyMissing)?;
108
109 let requested_protocol = parts.headers.typed_get::<RequestSecWebsocketProtocol>();
110
111 Ok(WebSocketUpgrade {
112 key,
113 requested_protocol,
114 upgrade,
115 })
116 }
117}
118
119impl IntoResponse for WebSocketUpgradeRejection {
120 fn into_response(self) -> Response {
121 match self {
122 Self::MethodNotGet => (StatusCode::METHOD_NOT_ALLOWED, "Request method must be `GET`").into_response(),
123 Self::BadUpgrade => (
124 StatusCode::UPGRADE_REQUIRED,
125 TypedHeader(Connection::upgrade()),
126 TypedHeader(Upgrade::websocket()),
127 "This service requires use of the websocket protocol",
128 )
129 .into_response(),
130 Self::BadVersion => (
131 StatusCode::BAD_REQUEST,
132 "`Sec-WebSocket-Version` header did not include '13'",
133 )
134 .into_response(),
135 Self::KeyMissing => (StatusCode::BAD_REQUEST, "`Sec-WebSocket-Key` header missing").into_response(),
136 }
137 }
138}
139
140impl WebSocketUpgrade {
141 #[inline]
142 pub fn protocol(&self) -> Option<&RequestSecWebsocketProtocol> {
143 self.requested_protocol.as_ref()
144 }
145
146 pub fn select_protocol<S, P>(
148 self,
149 protocols: impl IntoIterator<Item = (S, P)>,
150 ) -> (WebSocketResponse, PendingWebSocket, Option<P>)
151 where
152 S: for<'a> PartialEq<&'a str> + TryInto<HeaderValue>,
153 {
154 let (proto_header, proto) = self
155 .requested_protocol
156 .as_ref()
157 .and_then(|proto| proto.select(protocols))
158 .unzip();
159 let (resp, ws) = self.into_response(proto_header);
160 (resp, ws, proto)
161 }
162
163 #[inline]
165 pub fn ignore_protocol(self) -> (WebSocketResponse, PendingWebSocket) {
166 self.into_response(None)
167 }
168
169 #[inline]
171 pub fn into_response(
172 self,
173 protocol: Option<ResponseSecWebsocketProtocol>,
174 ) -> (WebSocketResponse, PendingWebSocket) {
175 let resp = WebSocketResponse {
176 accept: self.key.into(),
177 protocol,
178 };
179 (resp, PendingWebSocket(self.upgrade))
180 }
181}
182
183pub struct PendingWebSocket(OnUpgrade);
184
185impl PendingWebSocket {
186 #[inline]
187 pub async fn upgrade(self, config: WebSocketConfig) -> hyper::Result<WebSocketStream> {
188 let stream = TokioIo::new(self.0.await?);
189 Ok(WebSocketStream::from_raw_socket(stream, tungstenite::protocol::Role::Server, Some(config)).await)
190 }
191
192 #[inline]
193 pub fn into_inner(self) -> OnUpgrade {
194 self.0
195 }
196}
197
198pub struct WebSocketResponse {
201 accept: SecWebsocketAccept,
202 protocol: Option<ResponseSecWebsocketProtocol>,
203}
204
205impl IntoResponse for WebSocketResponse {
206 #[inline]
207 fn into_response(self) -> Response {
208 (
209 StatusCode::SWITCHING_PROTOCOLS,
210 TypedHeader(Connection::upgrade()),
211 TypedHeader(Upgrade::websocket()),
212 TypedHeader(self.accept),
213 self.protocol.map(TypedHeader),
214 (),
215 )
216 .into_response()
217 }
218}