victauri_core/
middleware.rs1use std::sync::Arc;
7
8use axum::extract::Request;
9use axum::http::StatusCode;
10use axum::middleware::Next;
11use axum::response::Response;
12
13use crate::security::{self, RateLimiter, constant_time_eq, is_allowed_origin, is_localhost_host};
14
15const BEARER_PREFIX_LEN: usize = "Bearer ".len();
16
/// Shared state for the [`require_auth`] middleware.
///
/// `token` is the bearer token every request must present; `None` disables
/// authentication entirely and all requests are forwarded.
#[derive(Clone)]
pub struct AuthState {
    /// Expected bearer token, or `None` to turn authentication off.
    pub token: Option<String>,
}

// Hand-written so the secret never ends up in logs or panic messages via `{:?}`;
// only whether a token is configured is revealed.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
24
25pub async fn require_auth(
35 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
36 request: Request,
37 next: Next,
38) -> Result<Response, StatusCode> {
39 let Some(expected) = &auth.token else {
40 return Ok(next.run(request).await);
41 };
42
43 let provided = request
44 .headers()
45 .get("authorization")
46 .and_then(|v| v.to_str().ok())
47 .and_then(|v| {
48 let lower = v.to_lowercase();
49 if lower.starts_with("bearer ") {
50 Some(v[BEARER_PREFIX_LEN..].to_string())
51 } else {
52 None
53 }
54 });
55
56 match provided {
57 Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
58 Ok(next.run(request).await)
59 }
60 _ => {
61 tracing::warn!("Victauri: rejected request — invalid or missing auth token");
62 Err(StatusCode::UNAUTHORIZED)
63 }
64 }
65}
66
67#[must_use]
70pub fn default_rate_limiter() -> Arc<RateLimiter> {
71 Arc::new(RateLimiter::new(security::DEFAULT_RATE_LIMIT))
72}
73
74pub async fn rate_limit(
82 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
83 request: Request,
84 next: Next,
85) -> Result<
86 Response,
87 (
88 StatusCode,
89 [(axum::http::HeaderName, axum::http::HeaderValue); 1],
90 ),
91> {
92 if limiter.try_acquire() {
93 Ok(next.run(request).await)
94 } else {
95 Err((
96 StatusCode::TOO_MANY_REQUESTS,
97 [(
98 axum::http::header::RETRY_AFTER,
99 axum::http::HeaderValue::from_static("1"),
100 )],
101 ))
102 }
103}
104
105pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
114 let host = request
115 .headers()
116 .get("host")
117 .and_then(|v| v.to_str().ok())
118 .unwrap_or("");
119 if !is_localhost_host(host) {
120 tracing::warn!("DNS rebinding attempt blocked: Host={host}");
121 return Err(StatusCode::FORBIDDEN);
122 }
123 Ok(next.run(request).await)
124}
125
126pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
133 if let Some(origin) = request
134 .headers()
135 .get("origin")
136 .and_then(|v| v.to_str().ok())
137 && !is_allowed_origin(origin)
138 {
139 tracing::warn!("Cross-origin request blocked: Origin={origin}");
140 return Err(StatusCode::FORBIDDEN);
141 }
142 Ok(next.run(request).await)
143}
144
145pub async fn security_headers(request: Request, next: Next) -> Response {
148 let mut response = next.run(request).await;
149 let headers = response.headers_mut();
150 headers.insert(
151 axum::http::header::X_CONTENT_TYPE_OPTIONS,
152 axum::http::HeaderValue::from_static("nosniff"),
153 );
154 headers.insert(
155 axum::http::header::CACHE_CONTROL,
156 axum::http::HeaderValue::from_static("no-store"),
157 );
158 headers.insert(
159 axum::http::header::HeaderName::from_static("x-frame-options"),
160 axum::http::HeaderValue::from_static("DENY"),
161 );
162 headers.insert(
163 axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
164 axum::http::HeaderValue::from_static("null"),
165 );
166 headers.insert(
167 axum::http::header::HeaderName::from_static("content-security-policy"),
168 axum::http::HeaderValue::from_static("default-src 'none'"),
169 );
170 response
171}