structured_proxy/auth/
jwks.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use jsonwebtoken::jwk::{AlgorithmParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm};
11use jsonwebtoken::{Algorithm, DecodingKey};
12use tokio::sync::{Mutex, RwLock};
13
14#[derive(Clone)]
16pub struct VerifyingKey {
17 pub key: Arc<DecodingKey>,
18 pub algorithm: Algorithm,
19}
20
21pub struct JwksCache {
23 uri: String,
24 client: reqwest::Client,
25 keys: RwLock<HashMap<String, VerifyingKey>>,
26 last_refresh: Mutex<Option<Instant>>,
27}
28
/// Minimum spacing between JWKS refetches while the cache already holds keys.
/// This bounds how often requests carrying unknown `kid`s can hit the endpoint.
const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(60);

/// Upper bound on how long a single JWKS HTTP request may take.
const JWKS_HTTP_TIMEOUT: Duration = Duration::from_secs(5);
35
36impl JwksCache {
37 pub fn new(uri: String) -> Self {
39 let client = reqwest::Client::builder()
40 .timeout(JWKS_HTTP_TIMEOUT)
41 .build()
42 .unwrap_or_default();
43 Self {
44 uri,
45 client,
46 keys: RwLock::new(HashMap::new()),
47 last_refresh: Mutex::new(None),
48 }
49 }
50
51 pub async fn key_for(&self, kid: &str) -> Option<VerifyingKey> {
54 if let Some(k) = self.keys.read().await.get(kid).cloned() {
55 return Some(k);
56 }
57 if self.refresh().await.is_err() {
58 return None;
59 }
60 self.keys.read().await.get(kid).cloned()
61 }
62
63 async fn refresh(&self) -> Result<(), String> {
66 {
69 let mut last = self.last_refresh.lock().await;
70 if let Some(t) = *last {
71 let empty = self.keys.read().await.is_empty();
72 if !empty && t.elapsed() < MIN_REFRESH_INTERVAL {
73 return Err("refresh throttled".to_string());
74 }
75 }
76 *last = Some(Instant::now());
77 }
78
79 let set: JwkSet = self
80 .client
81 .get(&self.uri)
82 .send()
83 .await
84 .map_err(|e| format!("JWKS fetch failed: {e}"))?
85 .json()
86 .await
87 .map_err(|e| format!("JWKS decode failed: {e}"))?;
88
89 let new_keys = parse_jwks(&set);
90 *self.keys.write().await = new_keys;
91 Ok(())
92 }
93}
94
95fn parse_jwks(set: &JwkSet) -> HashMap<String, VerifyingKey> {
98 let mut map = HashMap::new();
99 for jwk in &set.keys {
100 let Some(kid) = jwk.common.key_id.clone() else {
101 continue;
102 };
103 let Some(algorithm) = algorithm_for(jwk) else {
104 continue;
105 };
106 if let Ok(key) = DecodingKey::from_jwk(jwk) {
107 map.insert(
108 kid,
109 VerifyingKey {
110 key: Arc::new(key),
111 algorithm,
112 },
113 );
114 }
115 }
116 map
117}
118
119fn algorithm_for(jwk: &Jwk) -> Option<Algorithm> {
125 if let Some(alg) = jwk.common.key_algorithm.and_then(key_algorithm_to_alg) {
126 return Some(alg);
127 }
128 match &jwk.algorithm {
129 AlgorithmParameters::RSA(_) => Some(Algorithm::RS256),
130 AlgorithmParameters::EllipticCurve(ec) => match ec.curve {
131 EllipticCurve::P256 => Some(Algorithm::ES256),
132 EllipticCurve::P384 => Some(Algorithm::ES384),
133 _ => None,
135 },
136 AlgorithmParameters::OctetKeyPair(_) => Some(Algorithm::EdDSA),
137 AlgorithmParameters::OctetKey(_) => None,
138 }
139}
140
141fn key_algorithm_to_alg(ka: KeyAlgorithm) -> Option<Algorithm> {
144 Some(match ka {
145 KeyAlgorithm::ES256 => Algorithm::ES256,
146 KeyAlgorithm::ES384 => Algorithm::ES384,
147 KeyAlgorithm::RS256 => Algorithm::RS256,
148 KeyAlgorithm::RS384 => Algorithm::RS384,
149 KeyAlgorithm::RS512 => Algorithm::RS512,
150 KeyAlgorithm::PS256 => Algorithm::PS256,
151 KeyAlgorithm::PS384 => Algorithm::PS384,
152 KeyAlgorithm::PS512 => Algorithm::PS512,
153 KeyAlgorithm::EdDSA => Algorithm::EdDSA,
154 _ => return None,
155 })
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `JwkSet` fixture from inline JSON; a malformed fixture is a test bug.
    fn jwk_set(value: serde_json::Value) -> JwkSet {
        serde_json::from_value(value).unwrap()
    }

    /// Builds a single `Jwk` fixture from inline JSON; a malformed fixture is a test bug.
    fn single_jwk(value: serde_json::Value) -> Jwk {
        serde_json::from_value(value).unwrap()
    }

    #[test]
    fn parse_jwks_keeps_asymmetric_keys_and_maps_algorithms() {
        // An RSA signing key without `alg` is kept and defaults to RS256.
        let parsed = parse_jwks(&jwk_set(serde_json::json!({
            "keys": [{
                "kty": "RSA",
                "kid": "rsa-1",
                "use": "sig",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368Qen-JS7-zw04o6sJ9qjp6lFm5_T4nzcCqRfMOgRA_g_S0d7e9k7B0v0vqHr0e1V_o-z0ow5dWpql8-zKj4hQp8sg_Pn8O0R5ZQS4t8hUE-3-r3ftt1YzQ",
                "e": "AQAB"
            }]
        })));
        let entry = parsed.get("rsa-1");
        assert!(entry.is_some());
        assert_eq!(entry.map(|vk| vk.algorithm), Some(Algorithm::RS256));
    }

    #[test]
    fn algorithm_prefers_explicit_jwk_alg() {
        let key = single_jwk(serde_json::json!({
            "kty": "EC", "crv": "P-384", "alg": "ES384", "kid": "k",
            "x": "AAAA", "y": "AAAA"
        }));
        assert_eq!(algorithm_for(&key), Some(Algorithm::ES384));
    }

    #[test]
    fn algorithm_falls_back_to_curve_not_es256() {
        // Without `alg`, a P-384 key maps to ES384 rather than a blanket ES256.
        let key = single_jwk(serde_json::json!({
            "kty": "EC", "crv": "P-384", "kid": "k", "x": "AAAA", "y": "AAAA"
        }));
        assert_eq!(algorithm_for(&key), Some(Algorithm::ES384));
    }

    #[test]
    fn parse_jwks_skips_symmetric_and_keyless() {
        // Neither an HMAC (`oct`) key nor a key lacking `kid` may enter the cache.
        let set = jwk_set(serde_json::json!({
            "keys": [
                { "kty": "oct", "kid": "hmac", "k": "c2VjcmV0" },
                { "kty": "RSA", "n": "0vx7ag", "e": "AQAB" }
            ]
        }));
        assert!(parse_jwks(&set).is_empty());
    }
}