Skip to main content

victauri_core/
middleware.rs

1//! Shared axum middleware for Victauri's localhost HTTP servers.
2//!
3//! Gated behind the `middleware` feature flag.  Provides thin middleware
4//! wrappers around the pure-logic security primitives in [`super::security`].
5
6use std::sync::Arc;
7
8use axum::extract::Request;
9use axum::http::StatusCode;
10use axum::middleware::Next;
11use axum::response::Response;
12
13use crate::security::{self, RateLimiter, constant_time_eq, is_allowed_origin, is_localhost_host};
14
/// Byte length of the `"Bearer "` scheme prefix (scheme name plus the single
/// separating space) that precedes the token in an `Authorization` header.
const BEARER_PREFIX_LEN: usize = "Bearer ".len();
16
/// Authentication configuration shared with the [`require_auth`] middleware.
///
/// Wraps the optional Bearer token that incoming requests to the MCP server
/// must present.
#[derive(Clone)]
pub struct AuthState {
    /// Expected Bearer token.  `None` means authentication is disabled and
    /// requests are forwarded without inspecting the `Authorization` header.
    pub token: Option<String>,
}
24
25/// Axum middleware that validates the `Authorization: Bearer <token>` header
26/// against [`AuthState`].
27///
28/// Case-insensitive prefix matching per RFC 7235.  Constant-time token
29/// comparison via [`constant_time_eq`].
30///
31/// # Errors
32///
33/// Returns [`StatusCode::UNAUTHORIZED`] if the token is missing or invalid.
34pub async fn require_auth(
35    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
36    request: Request,
37    next: Next,
38) -> Result<Response, StatusCode> {
39    let Some(expected) = &auth.token else {
40        return Ok(next.run(request).await);
41    };
42
43    let provided = request
44        .headers()
45        .get("authorization")
46        .and_then(|v| v.to_str().ok())
47        .and_then(|v| {
48            let lower = v.to_lowercase();
49            if lower.starts_with("bearer ") {
50                Some(v[BEARER_PREFIX_LEN..].to_string())
51            } else {
52                None
53            }
54        });
55
56    match provided {
57        Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
58            Ok(next.run(request).await)
59        }
60        _ => {
61            tracing::warn!("Victauri: rejected request — invalid or missing auth token");
62            Err(StatusCode::UNAUTHORIZED)
63        }
64    }
65}
66
67/// Create a rate limiter with the default capacity of
68/// [`DEFAULT_RATE_LIMIT`](security::DEFAULT_RATE_LIMIT) requests per second.
69#[must_use]
70pub fn default_rate_limiter() -> Arc<RateLimiter> {
71    Arc::new(RateLimiter::new(security::DEFAULT_RATE_LIMIT))
72}
73
74/// Axum middleware that rejects requests with 429 when the token bucket is
75/// exhausted.
76///
77/// # Errors
78///
79/// Returns [`StatusCode::TOO_MANY_REQUESTS`] with `Retry-After: 1` header when
80/// the rate limit is exceeded.
81pub async fn rate_limit(
82    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiter>>,
83    request: Request,
84    next: Next,
85) -> Result<
86    Response,
87    (
88        StatusCode,
89        [(axum::http::HeaderName, axum::http::HeaderValue); 1],
90    ),
91> {
92    if limiter.try_acquire() {
93        Ok(next.run(request).await)
94    } else {
95        Err((
96            StatusCode::TOO_MANY_REQUESTS,
97            [(
98                axum::http::header::RETRY_AFTER,
99                axum::http::HeaderValue::from_static("1"),
100            )],
101        ))
102    }
103}
104
105/// Axum middleware that blocks DNS rebinding attacks.
106///
107/// Rejects any request where the `Host` header is not a localhost address.
108///
109/// # Errors
110///
111/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`,
112/// `127.0.0.1`, or `::1`.
113pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
114    let host = request
115        .headers()
116        .get("host")
117        .and_then(|v| v.to_str().ok())
118        .unwrap_or("");
119    if !is_localhost_host(host) {
120        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
121        return Err(StatusCode::FORBIDDEN);
122    }
123    Ok(next.run(request).await)
124}
125
126/// Axum middleware that blocks cross-origin requests from browsers.
127///
128/// # Errors
129///
130/// Returns [`StatusCode::FORBIDDEN`] if the `Origin` header is present and does
131/// not match a localhost or `tauri://` origin.
132pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
133    if let Some(origin) = request
134        .headers()
135        .get("origin")
136        .and_then(|v| v.to_str().ok())
137        && !is_allowed_origin(origin)
138    {
139        tracing::warn!("Cross-origin request blocked: Origin={origin}");
140        return Err(StatusCode::FORBIDDEN);
141    }
142    Ok(next.run(request).await)
143}
144
145/// Axum middleware that sets security-hardening response headers on every
146/// response.
147pub async fn security_headers(request: Request, next: Next) -> Response {
148    let mut response = next.run(request).await;
149    let headers = response.headers_mut();
150    headers.insert(
151        axum::http::header::X_CONTENT_TYPE_OPTIONS,
152        axum::http::HeaderValue::from_static("nosniff"),
153    );
154    headers.insert(
155        axum::http::header::CACHE_CONTROL,
156        axum::http::HeaderValue::from_static("no-store"),
157    );
158    headers.insert(
159        axum::http::header::HeaderName::from_static("x-frame-options"),
160        axum::http::HeaderValue::from_static("DENY"),
161    );
162    headers.insert(
163        axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
164        axum::http::HeaderValue::from_static("null"),
165    );
166    headers.insert(
167        axum::http::header::HeaderName::from_static("content-security-policy"),
168        axum::http::HeaderValue::from_static("default-src 'none'"),
169    );
170    response
171}