spiffe_rs/bundle/jwtbundle/
mod.rs

1use crate::internal::jwk::JwkDocument;
2use crate::internal::jwtutil;
3use crate::spiffeid::TrustDomain;
4use base64::Engine;
5use serde::Serialize;
6use std::collections::HashMap;
7use std::fs;
8use std::io::Read;
9use std::sync::RwLock;
10
/// Error type returned by the `jwtbundle` module.
///
/// Wraps a human-readable message. Errors raised by this module's own
/// functions go through `wrap_error` and carry a `"jwtbundle: "` prefix;
/// errors built with [`Error::new`] keep the message verbatim.
#[derive(Debug, Clone)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call: `Display` is not imported in this module, so the
        // method-call form would depend on a trait that is not in scope. This
        // forwards width/alignment flags exactly like `String`'s own impl.
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::error::Error for Error {}

impl Error {
    /// Creates an error whose message is `message`, unchanged (no prefix added).
    pub fn new(message: impl Into<String>) -> Error {
        Error(message.into())
    }
}

/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
29
30pub use crate::internal::jwk::JwtKey;
31
32fn wrap_error(message: impl std::fmt::Display) -> Error {
33    Error(format!("jwtbundle: {}", message))
34}
35
/// Removes one leading `"jwtbundle: "` tag from `message`, if present, so that
/// re-wrapping an already-wrapped error does not double the prefix.
fn strip_prefix(message: &str) -> &str {
    const TAG: &str = "jwtbundle: ";
    match message.strip_prefix(TAG) {
        Some(rest) => rest,
        None => message,
    }
}
39
40/// A JWT bundle contains the JWT authorities (public keys) for a trust domain.
41#[derive(Debug)]
42pub struct Bundle {
43    trust_domain: TrustDomain,
44    jwt_authorities: RwLock<HashMap<String, JwtKey>>,
45}
46
47impl Bundle {
48    /// Creates a new empty `Bundle` for the given trust domain.
49    pub fn new(trust_domain: TrustDomain) -> Bundle {
50        Bundle {
51            trust_domain,
52            jwt_authorities: RwLock::new(HashMap::new()),
53        }
54    }
55
56    /// Creates a new `Bundle` for the given trust domain and authorities.
57    pub fn from_jwt_authorities(
58        trust_domain: TrustDomain,
59        jwt_authorities: &HashMap<String, JwtKey>,
60    ) -> Bundle {
61        Bundle {
62            trust_domain,
63            jwt_authorities: RwLock::new(jwtutil::copy_jwt_authorities(jwt_authorities)),
64        }
65    }
66
67    /// Loads a JWT bundle from a JSON file (JWKS).
68    pub fn load(trust_domain: TrustDomain, path: &str) -> Result<Bundle> {
69        let bytes =
70            fs::read(path).map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?;
71        Bundle::parse(trust_domain, &bytes)
72    }
73
74    /// Reads a JWT bundle from a reader.
75    pub fn read(trust_domain: TrustDomain, reader: &mut dyn Read) -> Result<Bundle> {
76        let mut bytes = Vec::new();
77        reader
78            .read_to_end(&mut bytes)
79            .map_err(|err| wrap_error(format!("unable to read: {}", err)))?;
80        Bundle::parse(trust_domain, &bytes)
81    }
82
83    /// Parses a JWT bundle from JSON bytes (JWKS).
84    pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
85        let jwks: JwkDocument =
86            serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
87        let bundle = Bundle::new(trust_domain);
88        let keys = jwks.keys.unwrap_or_default();
89        for (idx, key) in keys.iter().enumerate() {
90            let key_id = key.key_id().unwrap_or_default();
91            let jwt_key = key
92                .to_jwt_key()
93                .map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?;
94            if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) {
95                return Err(wrap_error(format!(
96                    "error adding authority {} of JWKS: {}",
97                    idx,
98                    strip_prefix(&err.to_string())
99                )));
100            }
101        }
102        Ok(bundle)
103    }
104
105    /// Returns the trust domain of the bundle.
106    pub fn trust_domain(&self) -> TrustDomain {
107        self.trust_domain.clone()
108    }
109
110    /// Returns the JWT authorities in the bundle.
111    pub fn jwt_authorities(&self) -> HashMap<String, JwtKey> {
112        self.jwt_authorities
113            .read()
114            .map(|guard| jwtutil::copy_jwt_authorities(&guard))
115            .unwrap_or_default()
116    }
117
118    /// Finds a JWT authority by its key ID.
119    pub fn find_jwt_authority(&self, key_id: &str) -> Option<JwtKey> {
120        self.jwt_authorities
121            .read()
122            .ok()
123            .and_then(|guard| guard.get(key_id).cloned())
124    }
125
126    /// Returns `true` if the bundle has an authority with the given key ID.
127    pub fn has_jwt_authority(&self, key_id: &str) -> bool {
128        self.jwt_authorities
129            .read()
130            .map(|guard| guard.contains_key(key_id))
131            .unwrap_or(false)
132    }
133
134    /// Adds a JWT authority to the bundle.
135    pub fn add_jwt_authority(&self, key_id: &str, jwt_authority: JwtKey) -> Result<()> {
136        if key_id.is_empty() {
137            return Err(wrap_error("keyID cannot be empty"));
138        }
139        if let Ok(mut guard) = self.jwt_authorities.write() {
140            guard.insert(key_id.to_string(), jwt_authority);
141        }
142        Ok(())
143    }
144
145    /// Removes a JWT authority from the bundle.
146    pub fn remove_jwt_authority(&self, key_id: &str) {
147        if let Ok(mut guard) = self.jwt_authorities.write() {
148            guard.remove(key_id);
149        }
150    }
151
152    /// Sets the JWT authorities in the bundle.
153    pub fn set_jwt_authorities(&self, jwt_authorities: &HashMap<String, JwtKey>) {
154        if let Ok(mut guard) = self.jwt_authorities.write() {
155            *guard = jwtutil::copy_jwt_authorities(jwt_authorities);
156        }
157    }
158
159    /// Returns `true` if the bundle is empty.
160    pub fn empty(&self) -> bool {
161        self.jwt_authorities
162            .read()
163            .map(|guard| guard.is_empty())
164            .unwrap_or(true)
165    }
166
167    /// Marshals the bundle to JSON bytes (JWKS).
168    pub fn marshal(&self) -> Result<Vec<u8>> {
169        let mut keys = Vec::new();
170        let authorities = self.jwt_authorities();
171        for (key_id, jwt_key) in authorities {
172            keys.push(JwksKey::from_jwt_key(&key_id, &jwt_key));
173        }
174        let jwks = Jwks { keys };
175        serde_json::to_vec(&jwks).map_err(|err| wrap_error(err))
176    }
177
178    /// Clones the bundle.
179    pub fn clone_bundle(&self) -> Bundle {
180        Bundle::from_jwt_authorities(self.trust_domain(), &self.jwt_authorities())
181    }
182
183    /// Returns `true` if this bundle is equal to another bundle.
184    pub fn equal(&self, other: &Bundle) -> bool {
185        self.trust_domain == other.trust_domain
186            && jwtutil::jwt_authorities_equal(&self.jwt_authorities(), &other.jwt_authorities())
187    }
188
189    /// Returns the bundle for the given trust domain if it matches.
190    pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
191        if self.trust_domain != trust_domain {
192            return Err(wrap_error(format!(
193                "no JWT bundle for trust domain \"{}\"",
194                trust_domain
195            )));
196        }
197        Ok(self.clone_bundle())
198    }
199}
200
201/// A source of JWT bundles.
202pub trait Source {
203    /// Returns the JWT bundle for the given trust domain.
204    fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle>;
205}
206
207/// A set of JWT bundles for multiple trust domains.
208#[derive(Debug)]
209pub struct Set {
210    bundles: RwLock<HashMap<TrustDomain, Bundle>>,
211}
212
213impl Set {
214    /// Creates a new `Set` from the given bundles.
215    pub fn new(bundles: &[Bundle]) -> Set {
216        let mut map = HashMap::new();
217        for bundle in bundles {
218            map.insert(bundle.trust_domain(), bundle.clone_bundle());
219        }
220        Set {
221            bundles: RwLock::new(map),
222        }
223    }
224
225    /// Adds a bundle to the set.
226    pub fn add(&self, bundle: &Bundle) {
227        if let Ok(mut guard) = self.bundles.write() {
228            guard.insert(bundle.trust_domain(), bundle.clone_bundle());
229        }
230    }
231
232    /// Removes the bundle for the given trust domain from the set.
233    pub fn remove(&self, trust_domain: TrustDomain) {
234        if let Ok(mut guard) = self.bundles.write() {
235            guard.remove(&trust_domain);
236        }
237    }
238
239    /// Returns `true` if the set has a bundle for the given trust domain.
240    pub fn has(&self, trust_domain: TrustDomain) -> bool {
241        self.bundles
242            .read()
243            .map(|guard| guard.contains_key(&trust_domain))
244            .unwrap_or(false)
245    }
246
247    /// Returns the bundle for the given trust domain from the set.
248    pub fn get(&self, trust_domain: TrustDomain) -> Option<Bundle> {
249        self.bundles
250            .read()
251            .ok()
252            .and_then(|guard| guard.get(&trust_domain).map(|b| b.clone_bundle()))
253    }
254
255    /// Returns all bundles in the set.
256    pub fn bundles(&self) -> Vec<Bundle> {
257        let mut bundles = self
258            .bundles
259            .read()
260            .map(|guard| guard.values().map(|b| b.clone_bundle()).collect::<Vec<_>>())
261            .unwrap_or_default();
262        bundles.sort_by(|a, b| a.trust_domain().compare(&b.trust_domain()));
263        bundles
264    }
265
266    /// Returns the number of bundles in the set.
267    pub fn len(&self) -> usize {
268        self.bundles.read().map(|guard| guard.len()).unwrap_or(0)
269    }
270
271    /// Returns the JWT bundle for the given trust domain.
272    pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
273        let guard = self
274            .bundles
275            .read()
276            .map_err(|_| wrap_error("bundle store poisoned"))?;
277        let bundle = guard.get(&trust_domain).ok_or_else(|| {
278            wrap_error(format!(
279                "no JWT bundle for trust domain \"{}\"",
280                trust_domain
281            ))
282        })?;
283        Ok(bundle.clone_bundle())
284    }
285}
286
287impl Source for Set {
288    fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
289        self.get_jwt_bundle_for_trust_domain(trust_domain)
290    }
291}
292
293impl Source for Bundle {
294    fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
295        self.get_jwt_bundle_for_trust_domain(trust_domain)
296    }
297}
298
299#[derive(Serialize)]
300struct Jwks {
301    keys: Vec<JwksKey>,
302}
303
304#[derive(Serialize)]
305struct JwksKey {
306    kty: String,
307    kid: String,
308    #[serde(skip_serializing_if = "Option::is_none")]
309    crv: Option<String>,
310    #[serde(skip_serializing_if = "Option::is_none")]
311    x: Option<String>,
312    #[serde(skip_serializing_if = "Option::is_none")]
313    y: Option<String>,
314    #[serde(skip_serializing_if = "Option::is_none")]
315    n: Option<String>,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    e: Option<String>,
318}
319
320impl JwksKey {
321    fn from_jwt_key(key_id: &str, key: &JwtKey) -> JwksKey {
322        match key {
323            JwtKey::Ec { crv, x, y } => JwksKey {
324                kty: "EC".to_string(),
325                kid: key_id.to_string(),
326                crv: Some(crv.clone()),
327                x: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x)),
328                y: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y)),
329                n: None,
330                e: None,
331            },
332            JwtKey::Rsa { n, e } => JwksKey {
333                kty: "RSA".to_string(),
334                kid: key_id.to_string(),
335                crv: None,
336                x: None,
337                y: None,
338                n: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(n)),
339                e: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(e)),
340            },
341        }
342    }
343}