rust_cfzt_validator/
cache.rs1use crate::keys;
2use jsonwebtoken::{
3 jwk::{self, JwkSet},
4 DecodingKey,
5};
6use std::{collections::{HashMap, HashSet}, mem::replace, sync::RwLock};
7
8fn assert_key(key_id: &str, keymap: &keys::AccessKeyMap) {
9 if !keymap.contains_key(key_id) {
10 panic!("kid '{key_id}' is not in key set");
11 }
12}
13
14fn build_kid_set(keymap: &keys::AccessKeyMap) -> HashSet<String> {
15 let mut kids: HashSet<String> = HashSet::new();
16
17 for kid in keymap.keys() {
18 kids.insert(kid.to_string());
19 }
20
21 return kids;
22}
23
24fn build_jwk_set(keymap: &keys::AccessKeyMap) -> jwk::JwkSet {
25 let mut jwks: Vec<jwk::Jwk> = Vec::new();
26
27 for key in keymap.values() {
28 jwks.push(key.get_jwk());
29 }
30
31 JwkSet { keys: jwks }
32}
33
34struct KeySet {
35 kid_set: HashSet<String>,
36 key_set: jwk::JwkSet,
37}
38
39impl KeySet {
40 pub fn new(keymap: keys::AccessKeyMap) -> Self {
41 Self {
42 kid_set: build_kid_set(&keymap),
43 key_set: build_jwk_set(&keymap),
44 }
45 }
46
47 pub fn contains(&self, key_id: &str) -> bool {
48 self.kid_set.contains(key_id)
49 }
50
51 pub fn find(&self, key_id: &str) -> Option<&jwk::Jwk> {
52 self.key_set.find(key_id)
53 }
54
55 pub fn get_key_ids(&self) -> HashSet<String> {
56 self.kid_set.clone()
57 }
58}
59
60
61pub struct Cache {
66 latest_key_id: RwLock<String>,
67 key_set: RwLock<KeySet>,
68 decoding_keys: RwLock<HashMap<String, DecodingKey>>,
69}
70
71impl Cache {
72 fn contains_key(&self, key_id: &str) -> bool {
73 self.key_set.read().unwrap().contains(&key_id)
74 }
75
76 fn is_decoding_key_cached(&self, key_id: &str) -> bool {
77 self.decoding_keys.read().unwrap().contains_key(key_id)
78 }
79
80 fn get_key(&self, key_id: &str) -> Option<jwk::Jwk> {
81 Some(self.key_set.read().unwrap().find(key_id)?.clone())
82 }
83
84 fn flush_stale_decoding_keys(&self) {
85 let mut decoding_keys = self.decoding_keys.write().unwrap();
87
88 let cached_key_ids: Vec<String> = decoding_keys
90 .keys()
91 .map(|key_id| key_id.to_owned())
92 .collect();
93
94 for key_id in cached_key_ids {
96 if !self.contains_key(&key_id) {
97 decoding_keys.remove(&key_id);
98 }
99 }
100 }
101
102 fn build_decoding_key(&self, key_id: &str) {
103 if !self.is_decoding_key_cached(key_id) {
104 let mut decoding_keys = self.decoding_keys.write().unwrap();
105 let jwk = self.get_key(key_id).unwrap();
106 let decoding_key = DecodingKey::from_jwk(&jwk).unwrap();
107 decoding_keys.insert(key_id.to_string(), decoding_key);
108 }
109 }
110
111 pub fn new(latest_key_id: &str, keymap: keys::AccessKeyMap) -> Self {
114 assert_key(latest_key_id, &keymap);
115
116 let this = Cache {
117 latest_key_id: RwLock::new(latest_key_id.to_string()),
118 key_set: RwLock::new(KeySet::new(keymap)),
119 decoding_keys: RwLock::new(HashMap::new()),
120 };
121
122 this.build_decoding_key(latest_key_id);
124
125 this
126 }
127
128 pub fn is_rotation_needed(&self, candidate_key_ids: HashSet<String>) -> bool {
130 let current = self.get_key_ids();
131 let diff: Vec<&String> = current.difference(&candidate_key_ids).collect();
132 diff.len() > 0
133 }
134
135 pub fn get_latest_key_id(&self) -> String {
137 self.latest_key_id.read().unwrap().clone()
138 }
139
140 pub fn rotate_keys(&self, latest_key_id: &str, latest_keymap: keys::AccessKeyMap) {
142 assert_key(latest_key_id, &latest_keymap);
143
144 let _ = replace(&mut *self.latest_key_id.write().unwrap(), latest_key_id.to_string());
145 let _ = replace(&mut *self.key_set.write().unwrap(), KeySet::new(latest_keymap));
146
147 self.flush_stale_decoding_keys();
148 self.build_decoding_key(latest_key_id);
149 }
150
151 pub fn get_key_ids(&self) -> HashSet<String> {
153 self.key_set.read().unwrap().get_key_ids()
154 }
155
156 pub fn get_decoding_key(&self, key_id: &str) -> Option<DecodingKey> {
158 if self.contains_key(&key_id) {
159 self.build_decoding_key(key_id);
160 return Some(self.decoding_keys.read().unwrap().get(key_id)?.to_owned());
161 }
162 None
163 }
164}
165
#[cfg(test)]
mod tests {
    use crate::{api, keys};

    use super::*;
    use jsonwebtoken;
    use serde::{Deserialize, Serialize};
    use serde_json;

    /// Claims expected in the payload of the sample tokens below.
    #[derive(Debug, Serialize, Deserialize)]
    struct Claims {
        foo: String,
        bin: String,
    }

    // Mock payloads: the first seeds a fresh cache, the second is used for rotation.
    const SAMPLE_NEW_PAYLOAD: &str = include_str!("../test_data/mock_signing_key_1.json");
    const SAMPLE_ROTATION_PAYLOAD: &str = include_str!("../test_data/mock_signing_key_2.json");

    // Key ids matching the `kid` header of `TOKEN_NEW` and `TOKEN_ROTATE` respectively.
    const KEY_ID_NEW: &str = "o3KvfajHFSE6XLTo0oP98efQvVmfpS0CkPKlNSTzNjA";
    const KEY_ID_ROTATE: &str = "X33sNdmTvRC0O6irH8lKcncS9klV37WVzKlV7v2zY_s";

    const TOKEN_NEW: &str = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im8zS3ZmYWpIRlNFNlhMVG8wb1A5OGVmUXZWbWZwUzBDa1BLbE5TVHpOakEifQ.eyJiaW4iOiJiYXoiLCJmb28iOiJiYXIifQ.jRRcOsa4Wayx5dbYC-Rk5qF5SKUq9OnYqRlilK8tuugXFrYkxGXpmX-2_TzRGH8--lnS-OWXVacnbTwKVyS1w3uswAph40ySIGUnOg9oKkL2Gu5aIq8AejmseqQkwWGep9a5dcklAiBMgiwTw2B2rQTay2ZCKKjY0TJm8Lh0Msngsb1aXlMWcLWZxUtEh5bVr7y3m23CT4NuL0hGMxFW9okzuRHW8pyWAgXln8ii2U8-ypVyJ0YLYjpvXPRGg12rPp3NgWh6uGe_HuRqVuHSSWVTUT-bwP4vcTndvq9943gc_O_VRd-OTnN2CRen8KXWdJLwW63mKvxUa4M9RFW-Iw";
    const TOKEN_ROTATE: &str = "eyJhbGciOiJSUzI1NiIsImtpZCI6IlgzM3NOZG1UdlJDME82aXJIOGxLY25jUzlrbFYzN1dWektsVjd2MnpZX3MifQ.eyJiaW4iOiJiYXoiLCJmb28iOiJiYXIifQ.GCfxwZLaDpECHKRYbAg28ZE745ktgCnOlWnlPdT6JNnW3NQIDEHK1hTIjKU8I8yi88JAW77BWiJl7bUW-b_Ykmi3bltDuI4RfGdArQXgWsX5kNCyChMyT63JEh70USmZ7QsBuE3loMHM-gcmP_DD6iKvbCk2vY9TaxIsYfxJtSxZ8i9mYCR93W0qtY9uuSV6Tls6fYHj5shexrbbVmIDMYynxrsbhgbsm6q915k1OnTyxa8fc5Az3-c2zJc3yvOFcwo6z1c9SaRScmeV_U24PqBfWKCknJafv-atv4zkn-ClSZtxdW_JE3mRumib3a7F7gSfany2EhXsp7fOTNgeBg";

    /// Parses a mock payload and extracts its latest key id and its key map.
    fn load_mock_data(text: &str) -> (String, keys::AccessKeyMap) {
        let payload = serde_json::from_str::<serde_json::Value>(text).unwrap();

        (
            api::extract_latest_key_id(&payload).unwrap(),
            api::extract_current_keys(&payload).unwrap(),
        )
    }

    /// Builds a cache seeded with the first mock payload.
    fn get_cache() -> Cache {
        let (seed_key_id, seed_keymap) = load_mock_data(SAMPLE_NEW_PAYLOAD);
        Cache::new(&seed_key_id, seed_keymap)
    }

    /// Checks that `cache` reports `key_id` as its latest key and can decode
    /// `token` with the decoding key cached under the token's `kid` header.
    fn test_cache(cache: Cache, key_id: &str, token: &str) {
        assert_eq!(cache.get_latest_key_id(), key_id);
        assert!(cache.get_key_ids().contains(key_id));

        let header_kid = jsonwebtoken::decode_header(token).unwrap().kid.unwrap();
        assert_eq!(header_kid, key_id);

        // The sample tokens only carry `foo` and `bin`, so no spec claim is required.
        let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
        validation.required_spec_claims = HashSet::new();

        let decoding_key = cache.get_decoding_key(&header_kid).unwrap();
        let decoded = jsonwebtoken::decode::<Claims>(token, &decoding_key, &validation);
        assert!(decoded.is_ok());

        let decoded = decoded.unwrap();
        assert_eq!(decoded.claims.foo.as_str(), "bar");
        assert_eq!(decoded.claims.bin.as_str(), "baz");
        assert_eq!(decoded.header.kid.unwrap(), key_id);
    }

    #[test]
    fn test_fresh_cache() {
        test_cache(get_cache(), KEY_ID_NEW, TOKEN_NEW);
    }

    #[test]
    fn test_cache_rotation() {
        let cache = get_cache();

        // The cache's own key ids must not be reported as requiring a rotation.
        assert!(!cache.is_rotation_needed(cache.get_key_ids()));

        let (rotated_key_id, rotated_keymap) = load_mock_data(SAMPLE_ROTATION_PAYLOAD);
        let rotated_key_ids = rotated_keymap.keys().cloned().collect::<HashSet<String>>();
        assert!(cache.is_rotation_needed(rotated_key_ids));

        cache.rotate_keys(&rotated_key_id, rotated_keymap);

        // NOTE(review): this checks for TOKEN_NEW (a JWT), not a key id, so it
        // looks like it can never fail; KEY_ID_NEW may have been intended.
        // Confirm against the rotation fixture before changing it.
        assert!(!cache.get_key_ids().contains(TOKEN_NEW));

        test_cache(cache, KEY_ID_ROTATE, TOKEN_ROTATE);
    }
}