1use std::sync::Arc;
2
3use axum::{
4 Router,
5 extract::{
6 Query, State,
7 ws::{WebSocketUpgrade, rejection::WebSocketUpgradeRejection},
8 },
9 http::{HeaderMap, StatusCode, header},
10 response::{IntoResponse, Response},
11 routing::get,
12};
13use serde::Deserialize;
14
15use super::{RealtimeError, SocketAppState};
16
/// Configuration for mounting the realtime WebSocket route.
///
/// Implements `PartialEq`/`Eq` so that two configurations can be compared
/// directly (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeRouteOptions {
    /// Path the socket handler is registered under.
    pub path: &'static str,
    /// When `true`, the access token may be supplied through the `token`
    /// query parameter.
    pub allow_query_token: bool,
    /// When `true`, a malformed `Authorization` header is rejected instead of
    /// falling back to the query token.
    pub strict_header_precedence: bool,
}
23
24impl Default for RealtimeRouteOptions {
25 fn default() -> Self {
26 Self {
27 path: "/realtime/socket",
28 allow_query_token: true,
29 strict_header_precedence: true,
30 }
31 }
32}
33
34struct SocketRouteState {
35 socket_server_handle: Arc<SocketAppState>,
36 options: RealtimeRouteOptions,
37}
38
39impl Clone for SocketRouteState {
40 fn clone(&self) -> Self {
41 Self {
42 socket_server_handle: Arc::clone(&self.socket_server_handle),
43 options: self.options.clone(),
44 }
45 }
46}
47
48#[derive(Debug, Deserialize, Default)]
49struct SocketQuery {
50 token: Option<String>,
51}
52
53#[derive(Debug)]
54enum RealtimeHttpError {
55 MissingToken,
56 InvalidToken,
57 UpgradeRequired,
58 RealtimeDisabled,
59 VerifyFailed(RealtimeError),
60}
61
62impl RealtimeHttpError {
63 fn status(&self) -> StatusCode {
64 match self {
65 Self::MissingToken | Self::InvalidToken => StatusCode::UNAUTHORIZED,
66 Self::UpgradeRequired => StatusCode::BAD_REQUEST,
67 Self::RealtimeDisabled => StatusCode::NOT_FOUND,
68 Self::VerifyFailed(err) => match err {
69 RealtimeError::BadRequest(_) => StatusCode::BAD_REQUEST,
70 RealtimeError::Unauthorized(_) => StatusCode::UNAUTHORIZED,
71 RealtimeError::Forbidden(_) => StatusCode::FORBIDDEN,
72 RealtimeError::NotFound(_) => StatusCode::NOT_FOUND,
73 RealtimeError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
74 },
75 }
76 }
77
78 fn message(&self) -> String {
79 match self {
80 Self::MissingToken => {
81 "Missing access token (use Authorization Bearer or token query param)".to_string()
82 }
83 Self::InvalidToken => "Missing/invalid Authorization header".to_string(),
84 Self::UpgradeRequired => "WebSocket upgrade required".to_string(),
85 Self::RealtimeDisabled => "Realtime is disabled".to_string(),
86 Self::VerifyFailed(err) => err.message().to_string(),
87 }
88 }
89}
90
91impl IntoResponse for RealtimeHttpError {
92 fn into_response(self) -> Response {
93 (self.status(), self.message()).into_response()
94 }
95}
96
97pub fn router(socket_server_handle: Arc<SocketAppState>) -> Router {
98 router_with_options(socket_server_handle, RealtimeRouteOptions::default())
99}
100
101pub fn router_with_options(
102 socket_server_handle: Arc<SocketAppState>,
103 options: RealtimeRouteOptions,
104) -> Router {
105 let path = options.path;
106 Router::new()
107 .route(path, get(socket_handler))
108 .with_state(SocketRouteState {
109 socket_server_handle,
110 options,
111 })
112}
113
114async fn socket_handler(
115 State(handler_state): State<SocketRouteState>,
116 upgrade: Result<WebSocketUpgrade, WebSocketUpgradeRejection>,
117 headers: HeaderMap,
118 Query(query): Query<SocketQuery>,
119) -> Response {
120 let realtime = handler_state.socket_server_handle.handle.clone();
121
122 if !realtime.is_enabled() {
123 return RealtimeHttpError::RealtimeDisabled.into_response();
124 }
125
126 let upgrade = match upgrade {
127 Ok(upgrade) => upgrade,
128 Err(_) => return RealtimeHttpError::UpgradeRequired.into_response(),
129 };
130
131 let token = match extract_access_token(&headers, &query, &handler_state.options) {
132 Ok(token) => token,
133 Err(err) => return err.into_response(),
134 };
135
136 let auth = match handler_state
137 .socket_server_handle
138 .verifier
139 .verify_token(&token)
140 .await
141 {
142 Ok(auth) => auth,
143 Err(err) => return RealtimeHttpError::VerifyFailed(err).into_response(),
144 };
145
146 upgrade
147 .max_message_size(realtime.max_message_bytes())
148 .max_frame_size(realtime.max_message_bytes())
149 .on_upgrade(move |socket| async move {
150 realtime.serve_socket(socket, auth).await;
151 })
152 .into_response()
153}
154
155fn extract_access_token(
156 headers: &HeaderMap,
157 query: &SocketQuery,
158 options: &RealtimeRouteOptions,
159) -> Result<String, RealtimeHttpError> {
160 let auth_header = headers
161 .get(header::AUTHORIZATION)
162 .and_then(|value| value.to_str().ok());
163
164 if let Some(auth_header) = auth_header {
165 let header_token = auth_header
166 .strip_prefix("Bearer ")
167 .map(str::trim)
168 .filter(|value| !value.is_empty());
169
170 if let Some(token) = header_token {
171 return Ok(token.to_string());
172 }
173
174 if options.strict_header_precedence {
175 return Err(RealtimeHttpError::InvalidToken);
176 }
177 }
178
179 if options.allow_query_token
180 && let Some(token) = query
181 .token
182 .as_deref()
183 .map(str::trim)
184 .filter(|value| !value.is_empty())
185 {
186 return Ok(token.to_string());
187 }
188
189 Err(RealtimeHttpError::MissingToken)
190}
191
#[cfg(test)]
mod tests {
    use axum::http::header;

    use super::*;

    /// Builds a header map containing only the given `Authorization` value.
    fn headers_with_authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, value.parse().expect("valid header"));
        headers
    }

    /// Query fixture shared by all tests: `?token=query-token`.
    fn query_with_token() -> SocketQuery {
        SocketQuery {
            token: Some("query-token".to_string()),
        }
    }

    #[test]
    fn extract_access_token_prefers_authorization_header() {
        let headers = headers_with_authorization("Bearer header-token");

        let token = extract_access_token(
            &headers,
            &query_with_token(),
            &RealtimeRouteOptions::default(),
        )
        .expect("token should parse");

        assert_eq!(token, "header-token");
    }

    #[test]
    fn extract_access_token_falls_back_to_query_token() {
        let headers = HeaderMap::new();

        let token = extract_access_token(
            &headers,
            &query_with_token(),
            &RealtimeRouteOptions::default(),
        )
        .expect("token should parse");

        assert_eq!(token, "query-token");
    }

    #[test]
    fn extract_access_token_rejects_invalid_header_when_strict() {
        let headers = headers_with_authorization("Token abc");

        let err = extract_access_token(
            &headers,
            &query_with_token(),
            &RealtimeRouteOptions::default(),
        )
        .expect_err("invalid header should fail");

        assert!(matches!(err, RealtimeHttpError::InvalidToken));
    }
}