1use std::sync::Arc;
9
10use axum::extract::Request;
11use axum::http::header::{HeaderValue, LOCATION};
12use axum::http::{HeaderMap, StatusCode};
13use axum::response::{IntoResponse, Response};
14use axum::routing::any;
15use axum::Router;
16
17use super::{forbidden, unauthorized, Auth, AuthDecision};
18use crate::config::AuthConfig;
19
20pub struct ForwardAuth {
22 auth: Arc<Auth>,
23 path: String,
24 login_url: Option<String>,
25}
26
27impl ForwardAuth {
28 pub fn build(config: &AuthConfig, auth: Arc<Auth>) -> Option<Arc<Self>> {
33 let fa = config.forward_auth.as_ref()?;
34 if !fa.enabled {
35 return None;
36 }
37 Some(Arc::new(Self {
38 auth,
39 path: fa.path.clone(),
40 login_url: fa.login_url.clone(),
41 }))
42 }
43
44 pub fn routes<S>(self: &Arc<Self>) -> Router<S>
46 where
47 S: Clone + Send + Sync + 'static,
48 {
49 let fa = self.clone();
50 Router::new().route(
53 &self.path,
54 any(move |req: Request| {
55 let fa = fa.clone();
56 async move { fa.verify(req).await }
57 }),
58 )
59 }
60
61 async fn verify(&self, request: Request) -> Response {
62 let headers = request.headers();
63 let method = original_method(headers)
64 .unwrap_or_else(|| request.method().as_str().to_ascii_uppercase());
65 let path = original_path(headers).unwrap_or_else(|| request.uri().path().to_string());
66
67 match self.auth.decide(headers, &path, &method).await {
68 AuthDecision::Allow(claim_headers) => (StatusCode::OK, claim_headers).into_response(),
71 AuthDecision::Unauthenticated(msg) => self.deny(msg),
72 AuthDecision::Forbidden(msg) => forbidden(msg),
73 }
74 }
75
76 fn deny(&self, msg: &'static str) -> Response {
79 let mut response = unauthorized(msg);
80 if let Some(url) = &self.login_url {
81 if let Ok(value) = HeaderValue::try_from(url.as_str()) {
82 response.headers_mut().insert(LOCATION, value);
83 }
84 }
85 response
86 }
87}
88
89fn original_method(headers: &HeaderMap) -> Option<String> {
91 forwarded(headers, &["x-forwarded-method", "x-original-method"]).map(|m| m.to_ascii_uppercase())
92}
93
94fn original_path(headers: &HeaderMap) -> Option<String> {
96 let raw = forwarded(headers, &["x-forwarded-uri", "x-original-uri"])?;
97 let path = raw.split_once('?').map_or(raw.as_str(), |(p, _)| p);
98 Some(path.to_string())
99}
100
101fn forwarded(headers: &HeaderMap, names: &[&str]) -> Option<String> {
103 names
104 .iter()
105 .filter_map(|n| headers.get(*n).and_then(|v| v.to_str().ok()))
106 .find(|v| !v.is_empty())
107 .map(str::to_string)
108}
109
#[cfg(test)]
mod tests {
    // End-to-end tests. Each test builds a JWT-mode `ForwardAuth` and mounts
    // its router. It then sends a single request to `/auth/verify` carrying
    // the forwarded method/URI headers and checks the response.
    use super::*;
    use crate::config::{AuthConfig, ForwardAuthConfig, JwtConfig, RoutePolicyConfig};
    use axum::body::Body;
    use axum::http::Request as HttpRequest;
    use ed25519_dalek::{Signer, SigningKey};
    use std::collections::HashMap;
    use tower::ServiceExt;

    /// Returns a deterministic Ed25519 signing key (fixed 32-byte seed) and its
    /// public key as a PEM `PUBLIC KEY` block (DER SubjectPublicKeyInfo).
    fn keypair() -> (SigningKey, String) {
        let sk = SigningKey::from_bytes(&[7u8; 32]);
        // DER prefix of an Ed25519 SubjectPublicKeyInfo, decoded:
        //   30 2a            SEQUENCE, 42 bytes
        //   30 05            SEQUENCE, 5 bytes (AlgorithmIdentifier)
        //   06 03 2b 65 70   OID 1.3.101.112 (Ed25519)
        //   03 21 00         BIT STRING, 33 bytes, 0 unused bits
        // The 32 raw public-key bytes are appended below.
        let spki_prefix: [u8; 12] = [
            0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00,
        ];
        let mut der = spki_prefix.to_vec();
        der.extend_from_slice(sk.verifying_key().as_bytes());
        use base64::Engine;
        // 44 DER bytes encode to 60 base64 chars, so one line fits PEM's 64-column limit.
        let b64 = base64::engine::general_purpose::STANDARD.encode(&der);
        let pem = format!("-----BEGIN PUBLIC KEY-----\n{b64}\n-----END PUBLIC KEY-----\n");
        (sk, pem)
    }

    /// Produces a compact JWS (`header.payload.signature`, each part base64url
    /// without padding) over `claims`, signed with EdDSA by `sk`.
    fn sign(sk: &SigningKey, claims: &serde_json::Value) -> String {
        use base64::engine::general_purpose::URL_SAFE_NO_PAD;
        use base64::Engine;
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"EdDSA","typ":"JWT"}"#);
        let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).unwrap());
        let signing_input = format!("{header}.{payload}");
        let sig = sk.sign(signing_input.as_bytes());
        format!("{signing_input}.{}", URL_SAFE_NO_PAD.encode(sig.to_bytes()))
    }

    /// Writes `pem` to a new file in the system temp dir and returns its path.
    /// The process id plus an atomic counter keeps names unique across tests
    /// running in parallel.
    /// NOTE(review): the files are never deleted.
    fn write_pem(pem: &str) -> std::path::PathBuf {
        use std::sync::atomic::{AtomicU32, Ordering};
        static N: AtomicU32 = AtomicU32::new(0);
        let p = std::env::temp_dir().join(format!(
            "sp_fa_{}_{}.pem",
            std::process::id(),
            N.fetch_add(1, Ordering::Relaxed)
        ));
        std::fs::write(&p, pem).unwrap();
        p
    }

    /// Builds a JWT-mode `ForwardAuth` mounted at `/auth/verify`.
    ///
    /// The config maps the `sub` claim to `x-forwarded-user`, reads roles from
    /// the `roles` claim, and adds one policy: `/v1/admin/**` (any method)
    /// requires authentication and the `admin` role.
    fn forward_auth(pem_path: std::path::PathBuf, login_url: Option<String>) -> Arc<ForwardAuth> {
        let mut claims_headers = HashMap::new();
        claims_headers.insert("sub".to_string(), "x-forwarded-user".to_string());
        let config = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(JwtConfig {
                issuer: None,
                audience: None,
                jwks_uri: None,
                public_key_pem_file: Some(pem_path),
                claims_headers,
                roles_claim: "roles".into(),
            }),
            forward_auth: Some(ForwardAuthConfig {
                enabled: true,
                path: "/auth/verify".into(),
                policies: vec![RoutePolicyConfig {
                    path: "/v1/admin/**".into(),
                    methods: vec!["*".into()],
                    require_auth: true,
                    required_roles: vec!["admin".into()],
                }],
                login_url,
                applications_path: None,
            }),
            authz: None,
        };
        let auth = Auth::build(&config).unwrap().unwrap();
        ForwardAuth::build(&config, auth).unwrap()
    }

    /// Sends `req` once through a freshly built router and returns the response.
    async fn call(fa: &Arc<ForwardAuth>, req: HttpRequest<Body>) -> Response {
        let app: Router = fa.routes();
        app.oneshot(req).await.unwrap()
    }

    /// Builds a GET to `/auth/verify` with `x-forwarded-method` and
    /// `x-forwarded-uri` set. When `token` is given, an
    /// `authorization: Bearer <token>` header is added.
    fn verify_request(method: &str, uri: &str, token: Option<&str>) -> HttpRequest<Body> {
        let mut b = HttpRequest::get("/auth/verify")
            .header("x-forwarded-method", method)
            .header("x-forwarded-uri", uri);
        if let Some(t) = token {
            b = b.header("authorization", format!("Bearer {t}"));
        }
        b.body(Body::empty()).unwrap()
    }

    // Admin token on a protected path: 200, and `sub` is echoed as `x-forwarded-user`.
    #[tokio::test]
    async fn allows_and_echoes_claim_header() {
        let (sk, pem) = keypair();
        let fa = forward_auth(write_pem(&pem), None);
        // `exp` = 9999999999 is far in the future, so the token does not expire.
        let token = sign(
            &sk,
            &serde_json::json!({ "sub": "alice", "roles": ["admin"], "exp": 9999999999u64 }),
        );
        let resp = call(&fa, verify_request("GET", "/v1/admin/things", Some(&token))).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers()["x-forwarded-user"], "alice");
    }

    // No token on a protected path: 401, with the configured login URL in `Location`.
    #[tokio::test]
    async fn denies_without_token_and_sets_login_location() {
        let (_sk, pem) = keypair();
        let fa = forward_auth(write_pem(&pem), Some("https://login.example.com".into()));
        let resp = call(&fa, verify_request("GET", "/v1/admin/things", None)).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(resp.headers()[LOCATION], "https://login.example.com");
    }

    // Valid token without the `admin` role on a protected path: 403.
    #[tokio::test]
    async fn forbids_when_role_missing() {
        let (sk, pem) = keypair();
        let fa = forward_auth(write_pem(&pem), None);
        let token = sign(
            &sk,
            &serde_json::json!({ "sub": "bob", "roles": ["user"], "exp": 9999999999u64 }),
        );
        let resp = call(&fa, verify_request("GET", "/v1/admin/things", Some(&token))).await;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // Token signed by a key other than the configured one: 401.
    #[tokio::test]
    async fn denies_invalid_token() {
        let (_sk, pem) = keypair();
        let fa = forward_auth(write_pem(&pem), None);
        let wrong_key = SigningKey::from_bytes(&[9u8; 32]);
        let token = sign(
            &wrong_key,
            &serde_json::json!({ "sub": "mallory", "roles": ["admin"], "exp": 9999999999u64 }),
        );
        let resp = call(&fa, verify_request("GET", "/v1/admin/things", Some(&token))).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // A path matched by no policy is allowed without a token (pins the
    // unmatched-path-is-allowed behavior).
    #[tokio::test]
    async fn allows_unprotected_original_path() {
        let (_sk, pem) = keypair();
        let fa = forward_auth(write_pem(&pem), None);
        let resp = call(&fa, verify_request("GET", "/v1/public/info", None)).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}