Skip to main content

stormchaser_api/auth/mod.rs

1pub mod jwks;
2/// OPA integration for authorization
3pub mod opa;
4
5use crate::AppState;
6use axum::{
7    extract::FromRequestParts,
8    http::{request::Parts, StatusCode},
9};
10use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
11use serde::{Deserialize, Serialize};
12pub use stormchaser_model::auth::Claims;
13
/// HMAC secret used to verify locally issued (legacy / development) JWTs when OIDC
/// validation is not configured or does not accept the token.
///
/// The value is a literal compiled into the binary and visible in source, so anyone who
/// can read this file can mint tokens that verify against it.
/// NOTE(review): the [`AuthClaims`] extractor falls back to this key even when OIDC is
/// configured — confirm this is acceptable outside local development.
pub const JWT_SECRET: &[u8] = "stormchaser-secret-dev-only".as_bytes();
16
17/// Extractor for authenticated user claims
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct AuthClaims(pub Claims);
20
21#[axum::async_trait]
22impl FromRequestParts<AppState> for AuthClaims {
23    type Rejection = StatusCode;
24
25    async fn from_request_parts(
26        parts: &mut Parts,
27        state: &AppState,
28    ) -> Result<Self, Self::Rejection> {
29        let auth_header = parts
30            .headers
31            .get(axum::http::header::AUTHORIZATION)
32            .and_then(|h| h.to_str().ok())
33            .ok_or(StatusCode::UNAUTHORIZED)?;
34
35        if !auth_header.starts_with("Bearer ") {
36            return Err(StatusCode::UNAUTHORIZED);
37        }
38
39        let token = &auth_header["Bearer ".len()..];
40
41        // 1. Try OIDC/Dex validation if configured
42        if let Some(oidc_config) = &state.oidc_config {
43            if let Ok(header) = decode_header(token) {
44                if let Some(kid) = header.kid {
45                    let mut jwk_opt = state.jwks.read().await.get(&kid).cloned();
46
47                    if jwk_opt.is_none() {
48                        tracing::warn!("kid {} not found in JWKS cache, attempting refresh", kid);
49                        let new_jwks = crate::auth::jwks::fetch_jwks(&oidc_config.jwks_url).await;
50                        let mut jwks_write = state.jwks.write().await;
51                        *jwks_write = new_jwks;
52                        jwk_opt = jwks_write.get(&kid).cloned();
53                    }
54
55                    if let Some(jwk) = jwk_opt {
56                        let mut validation = Validation::new(header.alg);
57                        validation.set_audience(std::slice::from_ref(&oidc_config.client_id));
58                        validation.set_issuer(&[
59                            oidc_config.issuer.as_str(),
60                            oidc_config.external_issuer.as_str(),
61                        ]);
62
63                        if let Ok(decoding_key) = DecodingKey::from_jwk(&jwk) {
64                            if let Ok(token_data) =
65                                decode::<Claims>(token, &decoding_key, &validation)
66                            {
67                                return Ok(AuthClaims(token_data.claims));
68                            }
69                        }
70                    }
71                }
72            }
73        }
74
75        // 2. Fallback to local secret for legacy/dev tokens
76        let mut validation = Validation::default();
77        validation.validate_exp = true;
78        // Skip audience/issuer check for local tokens as they don't have them set usually in the current model
79        validation.required_spec_claims.remove("aud");
80
81        let token_data =
82            decode::<Claims>(token, &DecodingKey::from_secret(JWT_SECRET), &validation)
83                .inspect_err(|e| tracing::error!("JWT decode failed: {:?}", e))
84                .map_err(|_| StatusCode::UNAUTHORIZED)?;
85
86        Ok(AuthClaims(token_data.claims))
87    }
88}