spacetimedb_client_api/util/
websocket.rs

1//! A more flexible version of axum::extract::ws. This could probably get pulled out into its own crate at some point.
2
3use axum::extract::FromRequestParts;
4use axum::response::{IntoResponse, Response};
5use axum_extra::TypedHeader;
6use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Upgrade};
7use http::{HeaderName, HeaderValue, Method, StatusCode};
8use hyper::upgrade::{OnUpgrade, Upgraded};
9use hyper_util::rt::TokioIo;
10
11use super::flat_csv::FlatCsv;
12
13pub use tokio_tungstenite::tungstenite;
14pub use tungstenite::{
15    error::Error as WsError,
16    protocol::{frame::coding::CloseCode, CloseFrame, Message, WebSocketConfig},
17};
18
19pub type WebSocketStream = tokio_tungstenite::WebSocketStream<TokioIo<Upgraded>>;
20
21pub struct RequestSecWebsocketProtocol(FlatCsv);
22
23impl headers::Header for RequestSecWebsocketProtocol {
24    fn name() -> &'static HeaderName {
25        &http::header::SEC_WEBSOCKET_PROTOCOL
26    }
27    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, headers::Error> {
28        Ok(Self(values.collect()))
29    }
30    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
31        values.extend([self.0.value.clone()])
32    }
33}
34
35impl RequestSecWebsocketProtocol {
36    pub fn iter(&self) -> impl Iterator<Item = &str> {
37        self.0.iter()
38    }
39
40    pub fn select<S, P>(&self, protocols: impl IntoIterator<Item = (S, P)>) -> Option<(ResponseSecWebsocketProtocol, P)>
41    where
42        S: for<'a> PartialEq<&'a str> + TryInto<HeaderValue>,
43    {
44        protocols
45            .into_iter()
46            .find(|(protoname, _)| self.iter().any(|x| *protoname == x))
47            .map(|(protoname, proto)| {
48                let proto_header = protoname.try_into().unwrap_or_else(|_| unreachable!());
49                (ResponseSecWebsocketProtocol(proto_header), proto)
50            })
51    }
52}
53
54pub struct ResponseSecWebsocketProtocol(pub HeaderValue);
55
56impl headers::Header for ResponseSecWebsocketProtocol {
57    fn name() -> &'static HeaderName {
58        &http::header::SEC_WEBSOCKET_PROTOCOL
59    }
60    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, headers::Error> {
61        values.next().cloned().map(Self).ok_or_else(headers::Error::invalid)
62    }
63    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
64        values.extend([self.0.clone()])
65    }
66}
67
68pub struct WebSocketUpgrade {
69    key: SecWebsocketKey,
70    requested_protocol: Option<RequestSecWebsocketProtocol>,
71    upgrade: OnUpgrade,
72}
73
/// Why a request could not be extracted as a [`WebSocketUpgrade`].
///
/// Each variant is rendered as an HTTP error response by its `IntoResponse` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketUpgradeRejection {
    /// The request method was not `GET`.
    MethodNotGet,
    /// The request did not ask to upgrade to `websocket` via its `Connection`/`Upgrade`
    /// headers, or the underlying connection offers no upgrade handle.
    BadUpgrade,
    /// `Sec-WebSocket-Version` was missing or not `13`.
    BadVersion,
    /// `Sec-WebSocket-Key` was missing or could not be read.
    KeyMissing,
}
80
81#[async_trait::async_trait]
82impl<S> FromRequestParts<S> for WebSocketUpgrade {
83    type Rejection = WebSocketUpgradeRejection;
84    async fn from_request_parts(parts: &mut http::request::Parts, _state: &S) -> Result<Self, Self::Rejection> {
85        use WebSocketUpgradeRejection::*;
86
87        if parts.method != Method::GET {
88            return Err(MethodNotGet);
89        }
90
91        let upgrade = parts
92            .extensions
93            .remove::<OnUpgrade>()
94            .filter(|_| {
95                parts
96                    .headers
97                    .typed_get::<Connection>()
98                    .is_some_and(|conn| conn.contains("upgrade"))
99                    && parts.headers.typed_get::<Upgrade>() == Some(Upgrade::websocket())
100            })
101            .ok_or(BadUpgrade)?;
102
103        if parts.headers.typed_get::<SecWebsocketVersion>() != Some(SecWebsocketVersion::V13) {
104            return Err(BadVersion);
105        }
106
107        let key = parts.headers.typed_get::<SecWebsocketKey>().ok_or(KeyMissing)?;
108
109        let requested_protocol = parts.headers.typed_get::<RequestSecWebsocketProtocol>();
110
111        Ok(WebSocketUpgrade {
112            key,
113            requested_protocol,
114            upgrade,
115        })
116    }
117}
118
119impl IntoResponse for WebSocketUpgradeRejection {
120    fn into_response(self) -> Response {
121        match self {
122            Self::MethodNotGet => (StatusCode::METHOD_NOT_ALLOWED, "Request method must be `GET`").into_response(),
123            Self::BadUpgrade => (
124                StatusCode::UPGRADE_REQUIRED,
125                TypedHeader(Connection::upgrade()),
126                TypedHeader(Upgrade::websocket()),
127                "This service requires use of the websocket protocol",
128            )
129                .into_response(),
130            Self::BadVersion => (
131                StatusCode::BAD_REQUEST,
132                "`Sec-WebSocket-Version` header did not include '13'",
133            )
134                .into_response(),
135            Self::KeyMissing => (StatusCode::BAD_REQUEST, "`Sec-WebSocket-Key` header missing").into_response(),
136        }
137    }
138}
139
140impl WebSocketUpgrade {
141    #[inline]
142    pub fn protocol(&self) -> Option<&RequestSecWebsocketProtocol> {
143        self.requested_protocol.as_ref()
144    }
145
146    /// Select a subprotocol from the ones provided, and prepare a response for the client.
147    pub fn select_protocol<S, P>(
148        self,
149        protocols: impl IntoIterator<Item = (S, P)>,
150    ) -> (WebSocketResponse, PendingWebSocket, Option<P>)
151    where
152        S: for<'a> PartialEq<&'a str> + TryInto<HeaderValue>,
153    {
154        let (proto_header, proto) = self
155            .requested_protocol
156            .as_ref()
157            .and_then(|proto| proto.select(protocols))
158            .unzip();
159        let (resp, ws) = self.into_response(proto_header);
160        (resp, ws, proto)
161    }
162
163    /// Prepare a response with no subprotocol selected.
164    #[inline]
165    pub fn ignore_protocol(self) -> (WebSocketResponse, PendingWebSocket) {
166        self.into_response(None)
167    }
168
169    /// Prepare a response with the given subprotocol.
170    #[inline]
171    pub fn into_response(
172        self,
173        protocol: Option<ResponseSecWebsocketProtocol>,
174    ) -> (WebSocketResponse, PendingWebSocket) {
175        let resp = WebSocketResponse {
176            accept: self.key.into(),
177            protocol,
178        };
179        (resp, PendingWebSocket(self.upgrade))
180    }
181}
182
183pub struct PendingWebSocket(OnUpgrade);
184
185impl PendingWebSocket {
186    #[inline]
187    pub async fn upgrade(self, config: WebSocketConfig) -> hyper::Result<WebSocketStream> {
188        let stream = TokioIo::new(self.0.await?);
189        Ok(WebSocketStream::from_raw_socket(stream, tungstenite::protocol::Role::Server, Some(config)).await)
190    }
191
192    #[inline]
193    pub fn into_inner(self) -> OnUpgrade {
194        self.0
195    }
196}
197
198/// An type representing an http response for a successful websocket upgrade. Note that this response
199/// must be returned to the client for [`PendingWebSocket::upgrade`] to succeed.
200pub struct WebSocketResponse {
201    accept: SecWebsocketAccept,
202    protocol: Option<ResponseSecWebsocketProtocol>,
203}
204
205impl IntoResponse for WebSocketResponse {
206    #[inline]
207    fn into_response(self) -> Response {
208        (
209            StatusCode::SWITCHING_PROTOCOLS,
210            TypedHeader(Connection::upgrade()),
211            TypedHeader(Upgrade::websocket()),
212            TypedHeader(self.accept),
213            self.protocol.map(TypedHeader),
214            (),
215        )
216            .into_response()
217    }
218}