Skip to main content

secure_identity/
jwks.rs

1//! JWKS (JSON Web Key Set) key store with TTL-based caching.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::DecodingKey;
8use tokio::sync::RwLock;
9
10use crate::error::IdentityError;
11
12/// A cached JWKS key entry.
13#[derive(Clone)]
14struct CachedKey {
15    /// The algorithm, e.g. "RS256", "ES256".
16    alg: String,
17    /// The decoding key for signature verification.
18    decoding_key: DecodingKey,
19}
20
21/// Internal cache state.
22struct CacheState {
23    keys: HashMap<String, CachedKey>,
24    fetched_at: Option<Instant>,
25}
26
27/// A JWKS key store that fetches and caches public keys from a JWKS endpoint.
28///
29/// Keys are cached with a configurable TTL. When the cache expires, the next
30/// lookup triggers a refresh. If the endpoint is unavailable but the cache is warm,
31/// stale cached keys are used with a warning logged.
32pub struct JwksKeyStore {
33    url: String,
34    ttl: Duration,
35    cache: Arc<RwLock<CacheState>>,
36}
37
38impl JwksKeyStore {
39    /// Creates a new [`JwksKeyStore`] with the given endpoint URL and cache TTL.
40    #[must_use]
41    pub fn new(url: &str, ttl: Duration) -> Self {
42        Self {
43            url: url.to_owned(),
44            ttl,
45            cache: Arc::new(RwLock::new(CacheState {
46                keys: HashMap::new(),
47                fetched_at: None,
48            })),
49        }
50    }
51
52    /// Fetches the JWKS from the configured endpoint and updates the cache.
53    ///
54    /// # Errors
55    /// Returns `IdentityError::ProviderUnavailable` if the endpoint cannot be reached.
56    pub async fn fetch(&self) -> Result<(), IdentityError> {
57        let jwks = fetch_jwks_http(&self.url).await?;
58        let keys = parse_jwks(&jwks)?;
59
60        let mut cache = self.cache.write().await;
61        cache.keys = keys;
62        cache.fetched_at = Some(Instant::now());
63        Ok(())
64    }
65
66    /// Returns the [`DecodingKey`] for the given `kid`, fetching from the endpoint if
67    /// the cache is expired or empty.
68    pub async fn get_key(&self, kid: &str) -> Option<DecodingKey> {
69        // Check cache first
70        {
71            let cache = self.cache.read().await;
72            if let Some(fetched_at) = cache.fetched_at {
73                if fetched_at.elapsed() < self.ttl {
74                    return cache.keys.get(kid).map(|k| k.decoding_key.clone());
75                }
76            }
77        }
78
79        // Cache expired or empty — try to refresh
80        if let Err(e) = self.fetch().await {
81            tracing::warn!("JWKS refresh failed: {e}, using stale cache if available");
82            // Fall back to stale cache
83            let cache = self.cache.read().await;
84            return cache.keys.get(kid).map(|k| k.decoding_key.clone());
85        }
86
87        let cache = self.cache.read().await;
88        cache.keys.get(kid).map(|k| k.decoding_key.clone())
89    }
90
91    /// Returns the algorithm string for the given `kid`, if cached.
92    pub async fn get_algorithm(&self, kid: &str) -> Option<String> {
93        let cache = self.cache.read().await;
94        cache.keys.get(kid).map(|k| k.alg.clone())
95    }
96
97    /// Returns true if the cache has keys and is within TTL.
98    pub async fn is_cache_valid(&self) -> bool {
99        let cache = self.cache.read().await;
100        if let Some(fetched_at) = cache.fetched_at {
101            fetched_at.elapsed() < self.ttl && !cache.keys.is_empty()
102        } else {
103            false
104        }
105    }
106}
107
108/// JWKS JSON structures for parsing.
109#[derive(serde::Deserialize)]
110struct JwksDocument {
111    keys: Vec<JwkKey>,
112}
113
114#[derive(serde::Deserialize)]
115struct JwkKey {
116    #[serde(default)]
117    kid: Option<String>,
118    kty: String,
119    #[serde(default)]
120    alg: Option<String>,
121    #[serde(default)]
122    n: Option<String>,
123    #[serde(default)]
124    e: Option<String>,
125    #[serde(default)]
126    x: Option<String>,
127    #[serde(default)]
128    y: Option<String>,
129    #[serde(default)]
130    #[allow(dead_code)]
131    crv: Option<String>,
132}
133
134/// Fetches JWKS JSON from a URL using a simple HTTP GET.
135async fn fetch_jwks_http(url: &str) -> Result<String, IdentityError> {
136    // Use a simple TCP-based HTTP client to avoid requiring reqwest at compile time.
137    // For production use with the `jwks` feature, this would use reqwest.
138    // This implementation handles http:// URLs for testing and is intentionally simple.
139    let url_parsed = url::Url::parse(url).map_err(|_| IdentityError::ProviderUnavailable)?;
140
141    let host = url_parsed
142        .host_str()
143        .ok_or(IdentityError::ProviderUnavailable)?;
144    let port = url_parsed.port().unwrap_or(match url_parsed.scheme() {
145        "https" => 443,
146        _ => 80,
147    });
148    let path = url_parsed.path();
149
150    let addr = format!("{host}:{port}");
151    let stream = tokio::net::TcpStream::connect(&addr)
152        .await
153        .map_err(|_| IdentityError::ProviderUnavailable)?;
154
155    use tokio::io::{AsyncReadExt, AsyncWriteExt};
156    let mut stream = stream;
157    let request = format!("GET {path} HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\n\r\n");
158    stream
159        .write_all(request.as_bytes())
160        .await
161        .map_err(|_| IdentityError::ProviderUnavailable)?;
162
163    let mut response = Vec::new();
164    stream
165        .read_to_end(&mut response)
166        .await
167        .map_err(|_| IdentityError::ProviderUnavailable)?;
168
169    let response_str = String::from_utf8_lossy(&response);
170    // Find body after \r\n\r\n
171    let body_start = response_str
172        .find("\r\n\r\n")
173        .map(|i| i + 4)
174        .ok_or(IdentityError::ProviderUnavailable)?;
175    Ok(response_str[body_start..].to_string())
176}
177
178/// Parses a JWKS JSON document into a map of kid → CachedKey.
179fn parse_jwks(json: &str) -> Result<HashMap<String, CachedKey>, IdentityError> {
180    let doc: JwksDocument =
181        serde_json::from_str(json).map_err(|_| IdentityError::TokenMalformed)?;
182
183    let mut keys = HashMap::new();
184    for jwk in &doc.keys {
185        let kid = match &jwk.kid {
186            Some(k) => k.clone(),
187            None => continue, // skip keys without kid
188        };
189        let alg = jwk.alg.clone().unwrap_or_default();
190
191        let decoding_key = match jwk.kty.as_str() {
192            "RSA" => {
193                let n = jwk.n.as_deref().ok_or(IdentityError::TokenMalformed)?;
194                let e = jwk.e.as_deref().ok_or(IdentityError::TokenMalformed)?;
195                DecodingKey::from_rsa_components(n, e).map_err(|_| IdentityError::TokenMalformed)?
196            }
197            "EC" => {
198                let x = jwk.x.as_deref().ok_or(IdentityError::TokenMalformed)?;
199                let y = jwk.y.as_deref().ok_or(IdentityError::TokenMalformed)?;
200                DecodingKey::from_ec_components(x, y).map_err(|_| IdentityError::TokenMalformed)?
201            }
202            _ => continue, // skip unknown key types
203        };
204
205        keys.insert(kid, CachedKey { alg, decoding_key });
206    }
207
208    Ok(keys)
209}