zvault_server/
middleware.rs1use std::sync::Arc;
8
9use axum::extract::{Request, State};
10use axum::http::StatusCode;
11use axum::middleware::Next;
12use axum::response::{IntoResponse, Response};
13
14use crate::state::AppState;
15
/// Identity of an authenticated caller, attached to the request extensions by
/// `auth_middleware` once the `X-Vault-Token` header has been validated.
///
/// Every field is copied from the token-store entry returned by the lookup;
/// the raw token itself is not stored here.
#[derive(Clone, Debug)]
pub struct AuthContext {
    /// Token hash as recorded in the token-store entry.
    pub token_hash: String,
    /// Policies attached to the token, in the order the store returned them.
    pub policies: Vec<String>,
    /// Display name recorded for the token in the store.
    pub display_name: String,
}
26
27pub async fn auth_middleware(
31 State(state): State<Arc<AppState>>,
32 mut req: Request,
33 next: Next,
34) -> Response {
35 let path = req.uri().path().to_owned();
36
37 if path == "/v1/sys/health"
39 || path == "/v1/sys/seal-status"
40 || path == "/v1/sys/init"
41 || path == "/v1/sys/unseal"
42 || path.starts_with("/app/")
43 || path == "/app"
44 || path == "/"
45 {
46 return next.run(req).await;
47 }
48
49 let token = req
50 .headers()
51 .get("X-Vault-Token")
52 .and_then(|v| v.to_str().ok())
53 .map(String::from);
54
55 let Some(token) = token else {
56 return (
57 StatusCode::UNAUTHORIZED,
58 axum::Json(serde_json::json!({"error": "unauthorized", "message": "missing X-Vault-Token header"})),
59 ).into_response();
60 };
61
62 match state.token_store.lookup(&token).await {
63 Ok(entry) => {
64 let ctx = AuthContext {
65 token_hash: entry.token_hash.clone(),
66 policies: entry.policies.clone(),
67 display_name: entry.display_name.clone(),
68 };
69 req.extensions_mut().insert(ctx);
70 next.run(req).await
71 }
72 Err(_) => (
73 StatusCode::UNAUTHORIZED,
74 axum::Json(serde_json::json!({"error": "unauthorized", "message": "invalid or expired token"})),
75 ).into_response(),
76 }
77}