tuitbot_server/auth/
middleware.rs1use std::sync::Arc;
12
13use axum::extract::{Request, State};
14use axum::http::{HeaderMap, Method, StatusCode};
15use axum::middleware::Next;
16use axum::response::{IntoResponse, Response};
17use serde_json::json;
18use tuitbot_core::auth::session;
19
20use crate::state::AppState;
21
22fn extract_session_cookie(headers: &HeaderMap) -> Option<String> {
24 headers
25 .get("cookie")
26 .and_then(|v| v.to_str().ok())
27 .and_then(|cookies| {
28 cookies.split(';').find_map(|c| {
29 let c = c.trim();
30 c.strip_prefix("tuitbot_session=").map(|v| v.to_string())
31 })
32 })
33}
34
/// Request paths that bypass authentication entirely.
///
/// Matching is exact string equality on the request path, so there is no prefix matching and
/// trailing-slash variants are *not* exempt. Every route appears twice: once bare and once under
/// `/api`. Presumably this is because the router can be reached with or without the `/api`
/// prefix — NOTE(review): confirm against the router setup.
const AUTH_EXEMPT_PATHS: &[&str] = &[
    // Liveness probe.
    "/health",
    "/api/health",
    // First-run setup / onboarding endpoints.
    "/settings/status",
    "/api/settings/status",
    "/settings/init",
    "/api/settings/init",
    "/settings/test-llm",
    "/api/settings/test-llm",
    // WebSocket upgrade.
    "/ws",
    "/api/ws",
    // Login flow: these must be reachable before a session exists.
    "/auth/login",
    "/api/auth/login",
    "/auth/status",
    "/api/auth/status",
    // OAuth redirect target for the Google Drive connector.
    "/connectors/google-drive/callback",
    "/api/connectors/google-drive/callback",
    // Media file serving.
    "/media/file",
    "/api/media/file",
];
58
59pub async fn auth_middleware(
61 State(state): State<Arc<AppState>>,
62 headers: HeaderMap,
63 request: Request,
64 next: Next,
65) -> Response {
66 let path = request.uri().path();
67
68 if AUTH_EXEMPT_PATHS.contains(&path) {
70 return next.run(request).await;
71 }
72
73 let bearer_ok = headers
75 .get("authorization")
76 .and_then(|v| v.to_str().ok())
77 .and_then(|v| v.strip_prefix("Bearer "))
78 .is_some_and(|token| token == state.api_token);
79
80 if bearer_ok {
81 return next.run(request).await;
82 }
83
84 if let Some(session_token) = extract_session_cookie(&headers) {
86 match session::validate_session(&state.db, &session_token).await {
87 Ok(Some(sess)) => {
88 let method = request.method().clone();
90 if method == Method::POST
91 || method == Method::PATCH
92 || method == Method::DELETE
93 || method == Method::PUT
94 {
95 let csrf_ok = headers
96 .get("x-csrf-token")
97 .and_then(|v| v.to_str().ok())
98 .is_some_and(|t| t == sess.csrf_token);
99
100 if !csrf_ok {
101 return (
102 StatusCode::FORBIDDEN,
103 axum::Json(json!({"error": "missing or invalid CSRF token"})),
104 )
105 .into_response();
106 }
107 }
108 return next.run(request).await;
109 }
110 Ok(None) => { }
111 Err(e) => {
112 tracing::error!(error = %e, "Session validation failed");
113 }
114 }
115 }
116
117 (
119 StatusCode::UNAUTHORIZED,
120 axum::Json(json!({"error": "unauthorized"})),
121 )
122 .into_response()
123}