1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use jsonwebtoken::DecodingKey;
8use tokio::sync::RwLock;
9
10use crate::error::IdentityError;
11
12#[derive(Clone)]
14struct CachedKey {
15 alg: String,
17 decoding_key: DecodingKey,
19}
20
21struct CacheState {
23 keys: HashMap<String, CachedKey>,
24 fetched_at: Option<Instant>,
25}
26
27pub struct JwksKeyStore {
33 url: String,
34 ttl: Duration,
35 cache: Arc<RwLock<CacheState>>,
36}
37
38impl JwksKeyStore {
39 #[must_use]
41 pub fn new(url: &str, ttl: Duration) -> Self {
42 Self {
43 url: url.to_owned(),
44 ttl,
45 cache: Arc::new(RwLock::new(CacheState {
46 keys: HashMap::new(),
47 fetched_at: None,
48 })),
49 }
50 }
51
52 pub async fn fetch(&self) -> Result<(), IdentityError> {
57 let jwks = fetch_jwks_http(&self.url).await?;
58 let keys = parse_jwks(&jwks)?;
59
60 let mut cache = self.cache.write().await;
61 cache.keys = keys;
62 cache.fetched_at = Some(Instant::now());
63 Ok(())
64 }
65
66 pub async fn get_key(&self, kid: &str) -> Option<DecodingKey> {
69 {
71 let cache = self.cache.read().await;
72 if let Some(fetched_at) = cache.fetched_at {
73 if fetched_at.elapsed() < self.ttl {
74 return cache.keys.get(kid).map(|k| k.decoding_key.clone());
75 }
76 }
77 }
78
79 if let Err(e) = self.fetch().await {
81 tracing::warn!("JWKS refresh failed: {e}, using stale cache if available");
82 let cache = self.cache.read().await;
84 return cache.keys.get(kid).map(|k| k.decoding_key.clone());
85 }
86
87 let cache = self.cache.read().await;
88 cache.keys.get(kid).map(|k| k.decoding_key.clone())
89 }
90
91 pub async fn get_algorithm(&self, kid: &str) -> Option<String> {
93 let cache = self.cache.read().await;
94 cache.keys.get(kid).map(|k| k.alg.clone())
95 }
96
97 pub async fn is_cache_valid(&self) -> bool {
99 let cache = self.cache.read().await;
100 if let Some(fetched_at) = cache.fetched_at {
101 fetched_at.elapsed() < self.ttl && !cache.keys.is_empty()
102 } else {
103 false
104 }
105 }
106}
107
108#[derive(serde::Deserialize)]
110struct JwksDocument {
111 keys: Vec<JwkKey>,
112}
113
114#[derive(serde::Deserialize)]
115struct JwkKey {
116 #[serde(default)]
117 kid: Option<String>,
118 kty: String,
119 #[serde(default)]
120 alg: Option<String>,
121 #[serde(default)]
122 n: Option<String>,
123 #[serde(default)]
124 e: Option<String>,
125 #[serde(default)]
126 x: Option<String>,
127 #[serde(default)]
128 y: Option<String>,
129 #[serde(default)]
130 #[allow(dead_code)]
131 crv: Option<String>,
132}
133
134async fn fetch_jwks_http(url: &str) -> Result<String, IdentityError> {
136 let url_parsed = url::Url::parse(url).map_err(|_| IdentityError::ProviderUnavailable)?;
140
141 let host = url_parsed
142 .host_str()
143 .ok_or(IdentityError::ProviderUnavailable)?;
144 let port = url_parsed.port().unwrap_or(match url_parsed.scheme() {
145 "https" => 443,
146 _ => 80,
147 });
148 let path = url_parsed.path();
149
150 let addr = format!("{host}:{port}");
151 let stream = tokio::net::TcpStream::connect(&addr)
152 .await
153 .map_err(|_| IdentityError::ProviderUnavailable)?;
154
155 use tokio::io::{AsyncReadExt, AsyncWriteExt};
156 let mut stream = stream;
157 let request = format!("GET {path} HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\n\r\n");
158 stream
159 .write_all(request.as_bytes())
160 .await
161 .map_err(|_| IdentityError::ProviderUnavailable)?;
162
163 let mut response = Vec::new();
164 stream
165 .read_to_end(&mut response)
166 .await
167 .map_err(|_| IdentityError::ProviderUnavailable)?;
168
169 let response_str = String::from_utf8_lossy(&response);
170 let body_start = response_str
172 .find("\r\n\r\n")
173 .map(|i| i + 4)
174 .ok_or(IdentityError::ProviderUnavailable)?;
175 Ok(response_str[body_start..].to_string())
176}
177
178fn parse_jwks(json: &str) -> Result<HashMap<String, CachedKey>, IdentityError> {
180 let doc: JwksDocument =
181 serde_json::from_str(json).map_err(|_| IdentityError::TokenMalformed)?;
182
183 let mut keys = HashMap::new();
184 for jwk in &doc.keys {
185 let kid = match &jwk.kid {
186 Some(k) => k.clone(),
187 None => continue, };
189 let alg = jwk.alg.clone().unwrap_or_default();
190
191 let decoding_key = match jwk.kty.as_str() {
192 "RSA" => {
193 let n = jwk.n.as_deref().ok_or(IdentityError::TokenMalformed)?;
194 let e = jwk.e.as_deref().ok_or(IdentityError::TokenMalformed)?;
195 DecodingKey::from_rsa_components(n, e).map_err(|_| IdentityError::TokenMalformed)?
196 }
197 "EC" => {
198 let x = jwk.x.as_deref().ok_or(IdentityError::TokenMalformed)?;
199 let y = jwk.y.as_deref().ok_or(IdentityError::TokenMalformed)?;
200 DecodingKey::from_ec_components(x, y).map_err(|_| IdentityError::TokenMalformed)?
201 }
202 _ => continue, };
204
205 keys.insert(kid, CachedKey { alg, decoding_key });
206 }
207
208 Ok(keys)
209}