Skip to main content

realtime/server/
axum.rs

use std::sync::Arc;

use axum::{
    Router,
    extract::{
        Query, State,
        rejection::QueryRejection,
        ws::{WebSocketUpgrade, rejection::WebSocketUpgradeRejection},
    },
    http::{HeaderMap, StatusCode, header},
    response::{IntoResponse, Response},
    routing::get,
};
use serde::Deserialize;

use super::{RealtimeError, SocketAppState};
16
/// Configuration for the realtime WebSocket route.
#[derive(Debug, Clone)]
pub struct RealtimeRouteOptions {
    /// Path the socket handler is mounted on.
    pub path: &'static str,
    /// When `true`, a non-empty `token` query parameter is accepted as a fallback
    /// access token if no usable `Authorization: Bearer` header was found.
    pub allow_query_token: bool,
    /// When `true`, a present `Authorization` header that does not yield a bearer
    /// token rejects the request instead of falling back to the query parameter.
    pub strict_header_precedence: bool,
}

impl Default for RealtimeRouteOptions {
    /// Mounts at `/realtime/socket`, accepts query tokens, and gives a present
    /// `Authorization` header strict precedence.
    fn default() -> Self {
        let path = "/realtime/socket";
        Self {
            strict_header_precedence: true,
            allow_query_token: true,
            path,
        }
    }
}
33
34struct SocketRouteState {
35    socket_server_handle: Arc<SocketAppState>,
36    options: RealtimeRouteOptions,
37}
38
39impl Clone for SocketRouteState {
40    fn clone(&self) -> Self {
41        Self {
42            socket_server_handle: Arc::clone(&self.socket_server_handle),
43            options: self.options.clone(),
44        }
45    }
46}
47
48#[derive(Debug, Deserialize, Default)]
49struct SocketQuery {
50    token: Option<String>,
51}
52
53#[derive(Debug)]
54enum RealtimeHttpError {
55    MissingToken,
56    InvalidToken,
57    UpgradeRequired,
58    RealtimeDisabled,
59    VerifyFailed(RealtimeError),
60}
61
62impl RealtimeHttpError {
63    fn status(&self) -> StatusCode {
64        match self {
65            Self::MissingToken | Self::InvalidToken => StatusCode::UNAUTHORIZED,
66            Self::UpgradeRequired => StatusCode::BAD_REQUEST,
67            Self::RealtimeDisabled => StatusCode::NOT_FOUND,
68            Self::VerifyFailed(err) => match err {
69                RealtimeError::BadRequest(_) => StatusCode::BAD_REQUEST,
70                RealtimeError::Unauthorized(_) => StatusCode::UNAUTHORIZED,
71                RealtimeError::Forbidden(_) => StatusCode::FORBIDDEN,
72                RealtimeError::NotFound(_) => StatusCode::NOT_FOUND,
73                RealtimeError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
74            },
75        }
76    }
77
78    fn message(&self) -> String {
79        match self {
80            Self::MissingToken => {
81                "Missing access token (use Authorization Bearer or token query param)".to_string()
82            }
83            Self::InvalidToken => "Missing/invalid Authorization header".to_string(),
84            Self::UpgradeRequired => "WebSocket upgrade required".to_string(),
85            Self::RealtimeDisabled => "Realtime is disabled".to_string(),
86            Self::VerifyFailed(err) => err.message().to_string(),
87        }
88    }
89}
90
91impl IntoResponse for RealtimeHttpError {
92    fn into_response(self) -> Response {
93        (self.status(), self.message()).into_response()
94    }
95}
96
97pub fn router(socket_server_handle: Arc<SocketAppState>) -> Router {
98    router_with_options(socket_server_handle, RealtimeRouteOptions::default())
99}
100
101pub fn router_with_options(
102    socket_server_handle: Arc<SocketAppState>,
103    options: RealtimeRouteOptions,
104) -> Router {
105    let path = options.path;
106    Router::new()
107        .route(path, get(socket_handler))
108        .with_state(SocketRouteState {
109            socket_server_handle,
110            options,
111        })
112}
113
114async fn socket_handler(
115    State(handler_state): State<SocketRouteState>,
116    upgrade: Result<WebSocketUpgrade, WebSocketUpgradeRejection>,
117    headers: HeaderMap,
118    Query(query): Query<SocketQuery>,
119) -> Response {
120    let realtime = handler_state.socket_server_handle.handle.clone();
121
122    if !realtime.is_enabled() {
123        return RealtimeHttpError::RealtimeDisabled.into_response();
124    }
125
126    let upgrade = match upgrade {
127        Ok(upgrade) => upgrade,
128        Err(_) => return RealtimeHttpError::UpgradeRequired.into_response(),
129    };
130
131    let token = match extract_access_token(&headers, &query, &handler_state.options) {
132        Ok(token) => token,
133        Err(err) => return err.into_response(),
134    };
135
136    let auth = match handler_state
137        .socket_server_handle
138        .verifier
139        .verify_token(&token)
140        .await
141    {
142        Ok(auth) => auth,
143        Err(err) => return RealtimeHttpError::VerifyFailed(err).into_response(),
144    };
145
146    upgrade
147        .max_message_size(realtime.max_message_bytes())
148        .max_frame_size(realtime.max_message_bytes())
149        .on_upgrade(move |socket| async move {
150            realtime.serve_socket(socket, auth).await;
151        })
152        .into_response()
153}
154
155fn extract_access_token(
156    headers: &HeaderMap,
157    query: &SocketQuery,
158    options: &RealtimeRouteOptions,
159) -> Result<String, RealtimeHttpError> {
160    let auth_header = headers
161        .get(header::AUTHORIZATION)
162        .and_then(|value| value.to_str().ok());
163
164    if let Some(auth_header) = auth_header {
165        let header_token = auth_header
166            .strip_prefix("Bearer ")
167            .map(str::trim)
168            .filter(|value| !value.is_empty());
169
170        if let Some(token) = header_token {
171            return Ok(token.to_string());
172        }
173
174        if options.strict_header_precedence {
175            return Err(RealtimeHttpError::InvalidToken);
176        }
177    }
178
179    if options.allow_query_token
180        && let Some(token) = query
181            .token
182            .as_deref()
183            .map(str::trim)
184            .filter(|value| !value.is_empty())
185    {
186        return Ok(token.to_string());
187    }
188
189    Err(RealtimeHttpError::MissingToken)
190}
191
#[cfg(test)]
mod tests {
    use axum::http::header;

    use super::*;

    /// Builds a header map carrying a single `Authorization` value.
    fn authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::AUTHORIZATION, value.parse().expect("valid header"));
        headers
    }

    /// Builds a query carrying the given `token` parameter.
    fn query_token(token: &str) -> SocketQuery {
        SocketQuery {
            token: Some(token.to_string()),
        }
    }

    #[test]
    fn extract_access_token_prefers_authorization_header() {
        let headers = authorization("Bearer header-token");
        let query = query_token("query-token");

        let token = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect("token should parse");
        assert_eq!(token, "header-token");
    }

    #[test]
    fn extract_access_token_falls_back_to_query_token() {
        let headers = HeaderMap::new();
        let query = query_token("query-token");

        let token = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect("token should parse");
        assert_eq!(token, "query-token");
    }

    #[test]
    fn extract_access_token_rejects_invalid_header_when_strict() {
        let headers = authorization("Token abc");
        let query = query_token("query-token");

        let err = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect_err("invalid header should fail");
        assert!(matches!(err, RealtimeHttpError::InvalidToken));
    }

    #[test]
    fn extract_access_token_falls_back_to_query_when_not_strict() {
        let headers = authorization("Token abc");
        let query = query_token("query-token");
        let options = RealtimeRouteOptions {
            strict_header_precedence: false,
            ..RealtimeRouteOptions::default()
        };

        let token = extract_access_token(&headers, &query, &options)
            .expect("query token should be used");
        assert_eq!(token, "query-token");
    }

    #[test]
    fn extract_access_token_ignores_query_token_when_disabled() {
        let headers = HeaderMap::new();
        let query = query_token("query-token");
        let options = RealtimeRouteOptions {
            allow_query_token: false,
            ..RealtimeRouteOptions::default()
        };

        let err = extract_access_token(&headers, &query, &options)
            .expect_err("query token must be ignored");
        assert!(matches!(err, RealtimeHttpError::MissingToken));
    }

    #[test]
    fn extract_access_token_rejects_blank_bearer_when_strict() {
        let headers = authorization("Bearer   ");
        let query = query_token("query-token");

        let err = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect_err("blank bearer should fail");
        assert!(matches!(err, RealtimeHttpError::InvalidToken));
    }

    #[test]
    fn extract_access_token_ignores_blank_query_token() {
        let headers = HeaderMap::new();
        let query = query_token("   ");

        let err = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect_err("blank query token should fail");
        assert!(matches!(err, RealtimeHttpError::MissingToken));
    }
}