Skip to main content

structured_proxy/auth/
mod.rs

1//! JWT authentication and route-level authorization.
2//!
3//! Validates `Authorization: Bearer` JWTs against a configured key source (an
4//! Ed25519 PEM file or a JWKS endpoint), enforces per-route policies
5//! (`require_auth` / `required_roles`), and forwards selected claims to the
6//! upstream as request headers. Active only when `auth.mode == "jwt"`.
7
8pub mod authz;
9pub mod forward;
10pub mod jwks;
11pub mod policy;
12
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16
17use axum::extract::State;
18use axum::http::header::{HeaderName, HeaderValue};
19use axum::http::{HeaderMap, StatusCode};
20use axum::middleware::Next;
21use axum::response::{IntoResponse, Response};
22use axum::Json;
23use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
24use serde_json::Value;
25
26use crate::config::AuthConfig;
27use jwks::JwksCache;
28use policy::Policies;
29
30/// Where verifying keys come from.
31enum KeySource {
32    /// A single Ed25519 public key (EdDSA).
33    Pem(Arc<DecodingKey>),
34    /// Keys discovered from a JWKS endpoint, selected by `kid`.
35    Jwks(JwksCache),
36}
37
38/// Compiled auth configuration: keys, expected claims, and route policies.
39pub struct Auth {
40    keys: KeySource,
41    issuer: Option<String>,
42    audience: Option<String>,
43    claims_headers: HashMap<String, String>,
44    roles_claim: String,
45    policies: Policies,
46}
47
48impl Auth {
49    /// Build auth from config, or `None` when `auth.mode` is not `"jwt"`.
50    ///
51    /// # Errors
52    /// Returns an error string when the JWT config is missing a key source, the
53    /// PEM file cannot be read, or a policy glob fails to compile.
54    pub fn build(config: &AuthConfig) -> Result<Option<Arc<Self>>, String> {
55        if config.mode != "jwt" {
56            return Ok(None);
57        }
58        let jwt = config
59            .jwt
60            .as_ref()
61            .ok_or("auth.mode is \"jwt\" but auth.jwt is not set")?;
62
63        let keys = if let Some(uri) = &jwt.jwks_uri {
64            KeySource::Jwks(JwksCache::new(uri.clone()))
65        } else if let Some(pem_path) = &jwt.public_key_pem_file {
66            let pem = std::fs::read(pem_path)
67                .map_err(|e| format!("failed to read auth.jwt.public_key_pem_file: {e}"))?;
68            let key = DecodingKey::from_ed_pem(&pem)
69                .map_err(|e| format!("invalid Ed25519 public key PEM: {e}"))?;
70            KeySource::Pem(Arc::new(key))
71        } else {
72            return Err("auth.jwt requires either jwks_uri or public_key_pem_file".to_string());
73        };
74
75        let policies = match &config.forward_auth {
76            Some(fa) => Policies::compile(&fa.policies)?,
77            None => Policies::default(),
78        };
79
80        Ok(Some(Arc::new(Self {
81            keys,
82            issuer: jwt.issuer.clone(),
83            audience: jwt.audience.clone(),
84            claims_headers: jwt.claims_headers.clone(),
85            roles_claim: jwt.roles_claim.clone(),
86            policies,
87        })))
88    }
89
90    /// Verify a token and return its claims, or `None` if invalid.
91    async fn verify(&self, token: &str) -> Option<Value> {
92        let header = decode_header(token).ok()?;
93        let (key, algorithm) = match &self.keys {
94            KeySource::Pem(k) => (k.clone(), Algorithm::EdDSA),
95            KeySource::Jwks(cache) => {
96                let kid = header.kid.as_deref()?;
97                let vk = cache.key_for(kid).await?;
98                (vk.key, vk.algorithm)
99            }
100        };
101        // Reject algorithm confusion: the token must use the key's algorithm.
102        if header.alg != algorithm {
103            return None;
104        }
105
106        let mut validation = Validation::new(algorithm);
107        if let Some(iss) = &self.issuer {
108            validation.set_issuer(&[iss]);
109        }
110        match &self.audience {
111            Some(aud) => validation.set_audience(&[aud]),
112            None => validation.validate_aud = false,
113        }
114
115        decode::<Value>(token, &key, &validation)
116            .ok()
117            .map(|data| data.claims)
118    }
119}
120
121/// The outcome of an auth check for a request.
122pub(crate) enum AuthDecision {
123    /// Allowed; forward these (verified) claim headers to the upstream.
124    Allow(HeaderMap),
125    /// Rejected: no/invalid credentials (HTTP 401).
126    Unauthenticated(&'static str),
127    /// Rejected: authenticated but lacking a required role (HTTP 403).
128    Forbidden(&'static str),
129}
130
131impl Auth {
132    /// Evaluate auth for a request: validate the bearer token, apply the route
133    /// policy, and render the claim headers to forward. This is the single
134    /// source of truth shared by the middleware and the forward-auth endpoint.
135    pub(crate) async fn decide(
136        &self,
137        headers: &HeaderMap,
138        path: &str,
139        method: &str,
140    ) -> AuthDecision {
141        // A token that is present but invalid is always a 401, regardless of policy.
142        let claims = match bearer_token(headers) {
143            Some(token) => match self.verify(&token).await {
144                Some(c) => Some(c),
145                None => return AuthDecision::Unauthenticated("invalid or expired token"),
146            },
147            None => None,
148        };
149
150        if let Some(policy) = self.policies.match_rule(path, method) {
151            if policy.require_auth && claims.is_none() {
152                return AuthDecision::Unauthenticated("authentication required");
153            }
154            if !policy.required_roles.is_empty() {
155                // An unauthenticated caller is told to authenticate (401), not
156                // that they lack a role (403).
157                let Some(claims) = claims.as_ref() else {
158                    return AuthDecision::Unauthenticated("authentication required");
159                };
160                let roles = extract_roles(claims, &self.roles_claim);
161                if !policy.required_roles.iter().all(|r| roles.contains(r)) {
162                    return AuthDecision::Forbidden("insufficient role");
163                }
164            }
165        }
166
167        let mut claim_headers = HeaderMap::new();
168        if let Some(claims) = &claims {
169            inject_claim_headers(&mut claim_headers, claims, &self.claims_headers);
170        }
171        AuthDecision::Allow(claim_headers)
172    }
173}
174
175/// Axum middleware enforcing JWT auth and route policies.
176pub async fn middleware(
177    State(auth): State<Arc<Auth>>,
178    mut request: axum::extract::Request,
179    next: Next,
180) -> Response {
181    let path = request.uri().path().to_string();
182    let method = request.method().as_str().to_ascii_uppercase();
183
184    // Strip any client-supplied values for proxy-controlled claim headers, so a
185    // client can never forge them onto the upstream (only verified claims set
186    // them below).
187    strip_claim_headers(request.headers_mut(), &auth.claims_headers);
188
189    match auth.decide(request.headers(), &path, &method).await {
190        AuthDecision::Unauthenticated(msg) => unauthorized(msg),
191        AuthDecision::Forbidden(msg) => forbidden(msg),
192        AuthDecision::Allow(claim_headers) => {
193            let dst = request.headers_mut();
194            for (name, value) in &claim_headers {
195                dst.insert(name.clone(), value.clone());
196            }
197            next.run(request).await
198        }
199    }
200}
201
202/// Extract the bearer token from the `Authorization` header.
203fn bearer_token(headers: &HeaderMap) -> Option<String> {
204    let value = headers.get("authorization")?.to_str().ok()?;
205    let token = value
206        .strip_prefix("Bearer ")
207        .or_else(|| value.strip_prefix("bearer "))?;
208    let token = token.trim();
209    (!token.is_empty()).then(|| token.to_string())
210}
211
212/// Resolve a (possibly dotted) claim path to a JSON value.
213fn claim_at<'a>(claims: &'a Value, path: &str) -> Option<&'a Value> {
214    let mut cur = claims;
215    for seg in path.split('.') {
216        cur = cur.get(seg)?;
217    }
218    Some(cur)
219}
220
221/// Collect the caller's roles from the configured claim (an array of strings).
222fn extract_roles(claims: &Value, roles_claim: &str) -> HashSet<String> {
223    claim_at(claims, roles_claim)
224        .and_then(Value::as_array)
225        .map(|arr| {
226            arr.iter()
227                .filter_map(|v| v.as_str().map(str::to_string))
228                .collect()
229        })
230        .unwrap_or_default()
231}
232
233/// Remove any incoming values for the proxy-controlled claim headers, so a
234/// client cannot forge them onto the upstream.
235fn strip_claim_headers(headers: &mut HeaderMap, mapping: &HashMap<String, String>) {
236    for header in mapping.values() {
237        if let Ok(name) = HeaderName::try_from(header.as_str()) {
238            while headers.remove(&name).is_some() {}
239        }
240    }
241}
242
243/// Inject configured claims as request headers forwarded to the upstream.
244fn inject_claim_headers(
245    headers: &mut HeaderMap,
246    claims: &Value,
247    mapping: &HashMap<String, String>,
248) {
249    for (claim, header) in mapping {
250        let Some(value) = claim_at(claims, claim) else {
251            continue;
252        };
253        let rendered = match value {
254            Value::String(s) => s.clone(),
255            Value::Number(n) => n.to_string(),
256            Value::Bool(b) => b.to_string(),
257            // Skip arrays/objects/null: not meaningful as a single header value.
258            _ => continue,
259        };
260        if let (Ok(name), Ok(val)) = (
261            HeaderName::try_from(header.as_str()),
262            HeaderValue::try_from(rendered),
263        ) {
264            headers.insert(name, val);
265        }
266    }
267}
268
269fn unauthorized(message: &str) -> Response {
270    (
271        StatusCode::UNAUTHORIZED,
272        Json(serde_json::json!({ "error": "UNAUTHENTICATED", "message": message })),
273    )
274        .into_response()
275}
276
277fn forbidden(message: &str) -> Response {
278    (
279        StatusCode::FORBIDDEN,
280        Json(serde_json::json!({ "error": "PERMISSION_DENIED", "message": message })),
281    )
282        .into_response()
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    // --- unit tests for the pure helpers ---

    /// `bearer_token` returns the token for a `Bearer` header and `None` for
    /// other schemes or a missing header.
    #[test]
    fn bearer_token_parsing() {
        let mut h = HeaderMap::new();
        h.insert("authorization", "Bearer abc.def.ghi".parse().unwrap());
        assert_eq!(bearer_token(&h).as_deref(), Some("abc.def.ghi"));

        let mut h2 = HeaderMap::new();
        h2.insert("authorization", "Basic xyz".parse().unwrap());
        assert_eq!(bearer_token(&h2), None);
        assert_eq!(bearer_token(&HeaderMap::new()), None);
    }

    /// Roles are read from a top-level array, from a dotted nested path, and a
    /// missing claim yields no roles.
    #[test]
    fn extract_roles_reads_array_and_dotted_path() {
        let claims = serde_json::json!({
            "roles": ["admin", "billing"],
            "realm_access": { "roles": ["nested"] }
        });
        assert!(extract_roles(&claims, "roles").contains("admin"));
        assert!(extract_roles(&claims, "realm_access.roles").contains("nested"));
        assert!(extract_roles(&claims, "missing").is_empty());
    }

    /// String and number claims are rendered into headers; object claims are skipped.
    #[test]
    fn inject_claim_headers_renders_scalars() {
        let claims = serde_json::json!({ "sub": "u-1", "n": 7, "obj": {"x": 1} });
        let mapping = HashMap::from([
            ("sub".to_string(), "x-user-id".to_string()),
            ("n".to_string(), "x-n".to_string()),
            ("obj".to_string(), "x-obj".to_string()),
        ]);
        let mut headers = HeaderMap::new();
        inject_claim_headers(&mut headers, &claims, &mapping);
        assert_eq!(headers["x-user-id"], "u-1");
        assert_eq!(headers["x-n"], "7");
        // Object claim is skipped (not a scalar).
        assert!(!headers.contains_key("x-obj"));
    }

    // --- end-to-end JWT validation + policy enforcement ---

    use crate::config::{AuthConfig, ForwardAuthConfig, JwtConfig, RoutePolicyConfig};
    use axum::http::Request as HttpRequest;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::sync::atomic::{AtomicU32, Ordering};
    use tower::ServiceExt;

    // Ed25519 test keypair (generated for tests only; not a secret).
    const TEST_PRIV_PEM: &str = "-----BEGIN PRIVATE KEY-----\n\
        MC4CAQAwBQYDK2VwBCIEIEVVO7H+T5tERRn/dzukOc8i9iYEKKtPh//qcrES+dCt\n\
        -----END PRIVATE KEY-----\n";
    const TEST_PUB_PEM: &str = "-----BEGIN PUBLIC KEY-----\n\
        MCowBQYDK2VwAyEARCMxEnaM2/dblLuPNgBZpTvSUXO5ir+XQ1nyzJm4CFw=\n\
        -----END PUBLIC KEY-----\n";

    /// Write the test public key to a fresh temp file and return its path.
    ///
    /// The process id plus a per-process counter keeps names unique across
    /// concurrently running tests. The files are not deleted afterwards.
    fn temp_pub_pem() -> std::path::PathBuf {
        static N: AtomicU32 = AtomicU32::new(0);
        let path = std::env::temp_dir().join(format!(
            "sp_auth_{}_{}.pem",
            std::process::id(),
            N.fetch_add(1, Ordering::Relaxed)
        ));
        std::fs::write(&path, TEST_PUB_PEM).unwrap();
        path
    }

    /// Sign `claims` as an EdDSA JWT with the test private key (no `kid`).
    fn sign(claims: serde_json::Value) -> String {
        let key = EncodingKey::from_ed_pem(TEST_PRIV_PEM.as_bytes()).unwrap();
        encode(&Header::new(Algorithm::EdDSA), &claims, &key).unwrap()
    }

    /// An `exp` value one hour after the current wall-clock time (Unix seconds).
    fn future_exp() -> i64 {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        now + 3600
    }

    /// Build an `Auth` with the following setup:
    /// - PEM-based keys;
    /// - issuer `test-iss` and audience `test-aud`;
    /// - `sub` forwarded as `x-user`;
    /// - a single policy on `/secure` (any method) that requires auth plus the given roles.
    fn auth_with_policy(roles: &[&str]) -> Arc<Auth> {
        let cfg = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(JwtConfig {
                jwks_uri: None,
                issuer: Some("test-iss".into()),
                audience: Some("test-aud".into()),
                public_key_pem_file: Some(temp_pub_pem()),
                claims_headers: HashMap::from([("sub".to_string(), "x-user".to_string())]),
                roles_claim: "roles".into(),
            }),
            forward_auth: Some(ForwardAuthConfig {
                enabled: true,
                path: "/auth/verify".into(),
                policies: vec![RoutePolicyConfig {
                    path: "/secure".into(),
                    methods: vec!["*".into()],
                    require_auth: true,
                    required_roles: roles.iter().map(|s| s.to_string()).collect(),
                }],
                login_url: None,
                applications_path: None,
            }),
            authz: None,
        };
        Auth::build(&cfg).unwrap().unwrap()
    }

    /// A router with `/secure` and `/open` behind the auth middleware.
    fn app(auth: Arc<Auth>) -> axum::Router {
        // Both routes echo the x-user header the upstream would receive.
        let echo = |headers: HeaderMap| async move {
            headers
                .get("x-user")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string()
        };
        axum::Router::new()
            .route("/secure", axum::routing::get(echo))
            .route("/open", axum::routing::get(echo))
            .layer(axum::middleware::from_fn_with_state(auth, middleware))
    }

    /// Read a response body (at most 1024 bytes) as UTF-8 text.
    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn strips_client_supplied_claim_headers() {
        // A client forges x-user on an unprotected route with no token; the
        // proxy must not forward the forged value to the upstream.
        let app = app(auth_with_policy(&[]));
        let resp = app
            .oneshot(
                HttpRequest::get("/open")
                    .header("x-user", "forged-admin")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(body_string(resp).await, "");
    }

    #[tokio::test]
    async fn unauthenticated_role_check_is_401_not_403() {
        // Policy requires a role but not auth; a request with no token is
        // unauthenticated, so it must get 401, not 403.
        let cfg = AuthConfig {
            mode: "jwt".into(),
            jwt: Some(JwtConfig {
                jwks_uri: None,
                issuer: None,
                audience: None,
                public_key_pem_file: Some(temp_pub_pem()),
                claims_headers: HashMap::new(),
                roles_claim: "roles".into(),
            }),
            forward_auth: Some(ForwardAuthConfig {
                enabled: true,
                path: "/auth/verify".into(),
                policies: vec![RoutePolicyConfig {
                    path: "/secure".into(),
                    methods: vec!["*".into()],
                    require_auth: false,
                    required_roles: vec!["admin".into()],
                }],
                login_url: None,
                applications_path: None,
            }),
            authz: None,
        };
        let auth = Auth::build(&cfg).unwrap().unwrap();
        let resp = app(auth)
            .oneshot(
                HttpRequest::get("/secure")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// No Authorization header on a `require_auth` route is a 401.
    #[tokio::test]
    async fn rejects_missing_token_on_protected_route() {
        let app = app(auth_with_policy(&[]));
        let resp = app
            .oneshot(
                HttpRequest::get("/secure")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    /// A valid token carrying the required role passes, and its `sub` reaches
    /// the handler as `x-user`.
    #[tokio::test]
    async fn accepts_valid_token_and_injects_claim_header() {
        let app = app(auth_with_policy(&["admin"]));
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["admin"]
        }));
        let resp = app
            .oneshot(
                HttpRequest::get("/secure")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body = axum::body::to_bytes(resp.into_body(), 1024).await.unwrap();
        // The sub claim was forwarded to the handler as x-user.
        assert_eq!(&body[..], b"user-42");
    }

    /// A valid token without the required role is a 403.
    #[tokio::test]
    async fn forbids_when_required_role_missing() {
        let app = app(auth_with_policy(&["admin"]));
        let token = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": future_exp(),
            "sub": "user-42", "roles": ["viewer"]
        }));
        let resp = app
            .oneshot(
                HttpRequest::get("/secure")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// Correctly signed tokens that are expired, or that carry the wrong
    /// issuer, are both rejected with 401.
    #[tokio::test]
    async fn rejects_expired_and_wrong_issuer() {
        let app = app(auth_with_policy(&[]));
        let expired = sign(serde_json::json!({
            "iss": "test-iss", "aud": "test-aud", "exp": 1, "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .clone()
            .oneshot(
                HttpRequest::get("/secure")
                    .header("authorization", format!("Bearer {expired}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let wrong_iss = sign(serde_json::json!({
            "iss": "evil", "aud": "test-aud", "exp": future_exp(), "sub": "u", "roles": ["admin"]
        }));
        let resp = app
            .oneshot(
                HttpRequest::get("/secure")
                    .header("authorization", format!("Bearer {wrong_iss}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}