rust_cfzt_validator/
cache.rs

1use crate::keys;
2use jsonwebtoken::{
3    jwk::{self, JwkSet},
4    DecodingKey,
5};
6use std::{collections::{HashMap, HashSet}, mem::replace, sync::RwLock};
7
8fn assert_key(key_id: &str, keymap: &keys::AccessKeyMap) {
9    if !keymap.contains_key(key_id) {
10        panic!("kid '{key_id}' is not in key set");
11    }
12}
13
14fn build_kid_set(keymap: &keys::AccessKeyMap) -> HashSet<String> {
15    let mut kids: HashSet<String> = HashSet::new();
16
17    for kid in keymap.keys() {
18        kids.insert(kid.to_string());
19    }
20
21    return kids;
22}
23
24fn build_jwk_set(keymap: &keys::AccessKeyMap) -> jwk::JwkSet {
25    let mut jwks: Vec<jwk::Jwk> = Vec::new();
26
27    for key in keymap.values() {
28        jwks.push(key.get_jwk());
29    }
30
31    JwkSet { keys: jwks }
32}
33
34struct KeySet {
35    kid_set: HashSet<String>,
36    key_set: jwk::JwkSet,
37}
38
39impl KeySet {
40    pub fn new(keymap: keys::AccessKeyMap) -> Self {
41        Self {
42            kid_set: build_kid_set(&keymap),
43            key_set: build_jwk_set(&keymap),
44        }
45    }
46
47    pub fn contains(&self, key_id: &str) -> bool {
48        self.kid_set.contains(key_id)
49    }
50
51    pub fn find(&self, key_id: &str) -> Option<&jwk::Jwk> {
52        self.key_set.find(key_id)
53    }
54
55    pub fn get_key_ids(&self) -> HashSet<String> {
56        self.kid_set.clone()
57    }
58}
59
60
61/// Maintains the autoritative list of currently trusted JWKs for a single team
62/// and caches the DecodingKey structs derived from them.
63/// Needs to be periodically seeded with latest keys by some external trigger
64/// invoking the rotate_keys() method.
65pub struct Cache {
66    latest_key_id: RwLock<String>,
67    key_set: RwLock<KeySet>,
68    decoding_keys: RwLock<HashMap<String, DecodingKey>>,
69}
70
71impl Cache {
72    fn contains_key(&self, key_id: &str) -> bool {
73        self.key_set.read().unwrap().contains(&key_id)
74    }
75
76    fn is_decoding_key_cached(&self, key_id: &str) -> bool {
77        self.decoding_keys.read().unwrap().contains_key(key_id)
78    }
79
80    fn get_key(&self, key_id: &str) -> Option<jwk::Jwk> {
81        Some(self.key_set.read().unwrap().find(key_id)?.clone())
82    }
83
84    fn flush_stale_decoding_keys(&self) {
85        // Acquire write lock
86        let mut decoding_keys = self.decoding_keys.write().unwrap();
87
88        // Take a snapshot of current keys
89        let cached_key_ids: Vec<String> = decoding_keys
90            .keys()
91            .map(|key_id| key_id.to_owned())
92            .collect();
93
94        // Identify stale entries in decoding key cache and purge them
95        for key_id in cached_key_ids {
96            if !self.contains_key(&key_id) {
97                decoding_keys.remove(&key_id);
98            }
99        }
100    }
101
102    fn build_decoding_key(&self, key_id: &str) {
103        if !self.is_decoding_key_cached(key_id) {
104            let mut decoding_keys = self.decoding_keys.write().unwrap();
105            let jwk = self.get_key(key_id).unwrap();
106            let decoding_key = DecodingKey::from_jwk(&jwk).unwrap();
107            decoding_keys.insert(key_id.to_string(), decoding_key);
108        }
109    }
110
111    /// Constructs a new Cache from a key ID denoting the latest JWK
112    /// and a HashMap of key IDs to AccessKey structs.
113    pub fn new(latest_key_id: &str, keymap: keys::AccessKeyMap) -> Self {
114        assert_key(latest_key_id, &keymap);
115
116        let this = Cache {
117            latest_key_id: RwLock::new(latest_key_id.to_string()),
118            key_set: RwLock::new(KeySet::new(keymap)),
119            decoding_keys: RwLock::new(HashMap::new()),
120        };
121
122        // Prewarm the cache with the latest key
123        this.build_decoding_key(latest_key_id);
124
125        this
126    }
127
128    /// Given a specific map of new keys, check if an update is required.
129    pub fn is_rotation_needed(&self, candidate_key_ids: HashSet<String>) -> bool {
130        let current = self.get_key_ids();
131        let diff: Vec<&String> = current.difference(&candidate_key_ids).collect();
132        diff.len() > 0
133    }
134
135    /// Retrieve the latest key id
136    pub fn get_latest_key_id(&self) -> String {
137        self.latest_key_id.read().unwrap().clone()
138    }
139
140    /// Updates the Cache with a new latest key ID and map of AccessKey structs.
141    pub fn rotate_keys(&self, latest_key_id: &str, latest_keymap: keys::AccessKeyMap) {
142        assert_key(latest_key_id, &latest_keymap);
143
144        let _ = replace(&mut *self.latest_key_id.write().unwrap(), latest_key_id.to_string());
145        let _ = replace(&mut *self.key_set.write().unwrap(), KeySet::new(latest_keymap));
146
147        self.flush_stale_decoding_keys();
148        self.build_decoding_key(latest_key_id);
149    }
150
151    /// Get the current list of trusted key IDs.
152    pub fn get_key_ids(&self) -> HashSet<String> {
153        self.key_set.read().unwrap().get_key_ids()
154    }
155
156    /// Attempt to retrieve a specific key as a DecodingKey struct.
157    pub fn get_decoding_key(&self, key_id: &str) -> Option<DecodingKey> {
158        if self.contains_key(&key_id) {
159                self.build_decoding_key(key_id);
160                return Some(self.decoding_keys.read().unwrap().get(key_id)?.to_owned());
161        }
162        None
163    }
164}
165
#[cfg(test)]
mod tests {
    use crate::{api, keys};

    use super::*;
    use jsonwebtoken;
    use serde::{Deserialize, Serialize};
    use serde_json;

    /// Claims carried by the mock tokens below.
    #[derive(Debug, Serialize, Deserialize)]
    struct Claims {
        foo: String,
        bin: String,
    }

    const SAMPLE_NEW_PAYLOAD: &str = include_str!("../test_data/mock_signing_key_1.json");
    const SAMPLE_ROTATION_PAYLOAD: &str = include_str!("../test_data/mock_signing_key_2.json");

    const KEY_ID_NEW: &str = "o3KvfajHFSE6XLTo0oP98efQvVmfpS0CkPKlNSTzNjA";
    const KEY_ID_ROTATE: &str = "X33sNdmTvRC0O6irH8lKcncS9klV37WVzKlV7v2zY_s";

    const TOKEN_NEW: &str = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im8zS3ZmYWpIRlNFNlhMVG8wb1A5OGVmUXZWbWZwUzBDa1BLbE5TVHpOakEifQ.eyJiaW4iOiJiYXoiLCJmb28iOiJiYXIifQ.jRRcOsa4Wayx5dbYC-Rk5qF5SKUq9OnYqRlilK8tuugXFrYkxGXpmX-2_TzRGH8--lnS-OWXVacnbTwKVyS1w3uswAph40ySIGUnOg9oKkL2Gu5aIq8AejmseqQkwWGep9a5dcklAiBMgiwTw2B2rQTay2ZCKKjY0TJm8Lh0Msngsb1aXlMWcLWZxUtEh5bVr7y3m23CT4NuL0hGMxFW9okzuRHW8pyWAgXln8ii2U8-ypVyJ0YLYjpvXPRGg12rPp3NgWh6uGe_HuRqVuHSSWVTUT-bwP4vcTndvq9943gc_O_VRd-OTnN2CRen8KXWdJLwW63mKvxUa4M9RFW-Iw";
    const TOKEN_ROTATE: &str = "eyJhbGciOiJSUzI1NiIsImtpZCI6IlgzM3NOZG1UdlJDME82aXJIOGxLY25jUzlrbFYzN1dWektsVjd2MnpZX3MifQ.eyJiaW4iOiJiYXoiLCJmb28iOiJiYXIifQ.GCfxwZLaDpECHKRYbAg28ZE745ktgCnOlWnlPdT6JNnW3NQIDEHK1hTIjKU8I8yi88JAW77BWiJl7bUW-b_Ykmi3bltDuI4RfGdArQXgWsX5kNCyChMyT63JEh70USmZ7QsBuE3loMHM-gcmP_DD6iKvbCk2vY9TaxIsYfxJtSxZ8i9mYCR93W0qtY9uuSV6Tls6fYHj5shexrbbVmIDMYynxrsbhgbsm6q915k1OnTyxa8fc5Az3-c2zJc3yvOFcwo6z1c9SaRScmeV_U24PqBfWKCknJafv-atv4zkn-ClSZtxdW_JE3mRumib3a7F7gSfany2EhXsp7fOTNgeBg";

    /// Parses a mock certs payload into its latest key ID and key map.
    fn load_mock_data(text: &str) -> (String, keys::AccessKeyMap) {
        let payload: serde_json::Value = serde_json::from_str(text).unwrap();
        let key_id = api::extract_latest_key_id(&payload).unwrap();
        let keymap = api::extract_current_keys(&payload).unwrap();

        (key_id, keymap)
    }

    /// Builds a Cache seeded from the first mock payload.
    fn get_cache() -> Cache {
        let (latest_key_id, keymap) = load_mock_data(SAMPLE_NEW_PAYLOAD);
        Cache::new(&latest_key_id, keymap)
    }

    /// Asserts that `cache` reports `key_id` as latest and trusted, and that
    /// `token` validates against the DecodingKey it serves for that kid.
    fn test_cache(cache: Cache, key_id: &str, token: &str) {
        assert_eq!(cache.get_latest_key_id(), key_id);
        assert!(cache.get_key_ids().contains(key_id));

        let header = jsonwebtoken::decode_header(token).unwrap();
        let header_kid = header.kid.unwrap();
        let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
        // The mock tokens carry no exp/nbf/etc., so require no spec claims.
        validation.required_spec_claims = HashSet::new();

        assert_eq!(header_kid, key_id);

        let decoding_key = cache.get_decoding_key(&header_kid).unwrap();
        let token = jsonwebtoken::decode::<Claims>(token, &decoding_key, &validation)
            .expect("token should validate against the cached decoding key");

        assert_eq!(token.claims.foo.as_str(), "bar");
        assert_eq!(token.claims.bin.as_str(), "baz");
        assert_eq!(token.header.kid.unwrap(), key_id);
    }

    #[test]
    fn test_fresh_cache() {
        let cache = get_cache();
        test_cache(cache, KEY_ID_NEW, TOKEN_NEW);
    }

    #[test]
    fn test_cache_rotation() {
        let cache = get_cache();
        let key_ids = cache.get_key_ids();
        assert!(!cache.is_rotation_needed(key_ids));

        let (latest_key_id, latest_keymap) = load_mock_data(SAMPLE_ROTATION_PAYLOAD);
        let latest_key_ids: HashSet<String> = latest_keymap.keys().cloned().collect();
        assert!(cache.is_rotation_needed(latest_key_ids.clone()));

        cache.rotate_keys(&latest_key_id, latest_keymap);

        // The trusted set now mirrors the rotated payload exactly.
        assert_eq!(cache.get_key_ids(), latest_key_ids);
        // A key the rotated payload no longer lists must no longer be served.
        if !latest_key_ids.contains(KEY_ID_NEW) {
            assert!(cache.get_decoding_key(KEY_ID_NEW).is_none());
        }
        assert!(!cache.is_rotation_needed(latest_key_ids));

        test_cache(cache, KEY_ID_ROTATE, TOKEN_ROTATE);
    }
}