spiffe_rs/bundle/jwtbundle/
mod.rs1use crate::internal::jwk::JwkDocument;
2use crate::internal::jwtutil;
3use crate::spiffeid::TrustDomain;
4use base64::Engine;
5use serde::Serialize;
6use std::collections::HashMap;
7use std::fs;
8use std::io::Read;
9use std::sync::RwLock;
10
/// Error produced by JWT bundle operations.
///
/// Errors raised inside this module via `wrap_error` carry a `jwtbundle: `
/// prefix; values built with [`Error::new`] keep the message exactly as given.
#[derive(Debug, Clone)]
pub struct Error(String);

impl std::fmt::Display for Error {
    /// Writes the error message, honouring width/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified: an unqualified `self.0.fmt(f)` depends on which
        // formatting trait is in scope and could resolve to `Debug::fmt`
        // (which quotes the string) if `Debug` were imported here.
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::error::Error for Error {}

impl Error {
    /// Creates an error carrying `message` verbatim (no prefix is added).
    pub fn new(message: impl Into<String>) -> Error {
        Error(message.into())
    }
}

/// Result type whose error is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
29
30pub use crate::internal::jwk::JwtKey;
31
32fn wrap_error(message: impl std::fmt::Display) -> Error {
33 Error(format!("jwtbundle: {}", message))
34}
35
/// Returns `message` without its leading `jwtbundle: ` prefix, if it has one.
///
/// Used when an error from this module is re-wrapped with more context, so the
/// final message does not repeat the prefix.
fn strip_prefix(message: &str) -> &str {
    match message.strip_prefix("jwtbundle: ") {
        Some(rest) => rest,
        None => message,
    }
}
39
40#[derive(Debug)]
42pub struct Bundle {
43 trust_domain: TrustDomain,
44 jwt_authorities: RwLock<HashMap<String, JwtKey>>,
45}
46
47impl Bundle {
48 pub fn new(trust_domain: TrustDomain) -> Bundle {
50 Bundle {
51 trust_domain,
52 jwt_authorities: RwLock::new(HashMap::new()),
53 }
54 }
55
56 pub fn from_jwt_authorities(
58 trust_domain: TrustDomain,
59 jwt_authorities: &HashMap<String, JwtKey>,
60 ) -> Bundle {
61 Bundle {
62 trust_domain,
63 jwt_authorities: RwLock::new(jwtutil::copy_jwt_authorities(jwt_authorities)),
64 }
65 }
66
67 pub fn load(trust_domain: TrustDomain, path: &str) -> Result<Bundle> {
69 let bytes =
70 fs::read(path).map_err(|err| wrap_error(format!("unable to read JWT bundle: {}", err)))?;
71 Bundle::parse(trust_domain, &bytes)
72 }
73
74 pub fn read(trust_domain: TrustDomain, reader: &mut dyn Read) -> Result<Bundle> {
76 let mut bytes = Vec::new();
77 reader
78 .read_to_end(&mut bytes)
79 .map_err(|err| wrap_error(format!("unable to read: {}", err)))?;
80 Bundle::parse(trust_domain, &bytes)
81 }
82
83 pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
85 let jwks: JwkDocument =
86 serde_json::from_slice(bytes).map_err(|err| wrap_error(format!("unable to parse JWKS: {}", err)))?;
87 let bundle = Bundle::new(trust_domain);
88 let keys = jwks.keys.unwrap_or_default();
89 for (idx, key) in keys.iter().enumerate() {
90 let key_id = key.key_id().unwrap_or_default();
91 let jwt_key = key
92 .to_jwt_key()
93 .map_err(|err| wrap_error(format!("error adding authority {} of JWKS: {}", idx, err)))?;
94 if let Err(err) = bundle.add_jwt_authority(key_id, jwt_key) {
95 return Err(wrap_error(format!(
96 "error adding authority {} of JWKS: {}",
97 idx,
98 strip_prefix(&err.to_string())
99 )));
100 }
101 }
102 Ok(bundle)
103 }
104
105 pub fn trust_domain(&self) -> TrustDomain {
107 self.trust_domain.clone()
108 }
109
110 pub fn jwt_authorities(&self) -> HashMap<String, JwtKey> {
112 self.jwt_authorities
113 .read()
114 .map(|guard| jwtutil::copy_jwt_authorities(&guard))
115 .unwrap_or_default()
116 }
117
118 pub fn find_jwt_authority(&self, key_id: &str) -> Option<JwtKey> {
120 self.jwt_authorities
121 .read()
122 .ok()
123 .and_then(|guard| guard.get(key_id).cloned())
124 }
125
126 pub fn has_jwt_authority(&self, key_id: &str) -> bool {
128 self.jwt_authorities
129 .read()
130 .map(|guard| guard.contains_key(key_id))
131 .unwrap_or(false)
132 }
133
134 pub fn add_jwt_authority(&self, key_id: &str, jwt_authority: JwtKey) -> Result<()> {
136 if key_id.is_empty() {
137 return Err(wrap_error("keyID cannot be empty"));
138 }
139 if let Ok(mut guard) = self.jwt_authorities.write() {
140 guard.insert(key_id.to_string(), jwt_authority);
141 }
142 Ok(())
143 }
144
145 pub fn remove_jwt_authority(&self, key_id: &str) {
147 if let Ok(mut guard) = self.jwt_authorities.write() {
148 guard.remove(key_id);
149 }
150 }
151
152 pub fn set_jwt_authorities(&self, jwt_authorities: &HashMap<String, JwtKey>) {
154 if let Ok(mut guard) = self.jwt_authorities.write() {
155 *guard = jwtutil::copy_jwt_authorities(jwt_authorities);
156 }
157 }
158
159 pub fn empty(&self) -> bool {
161 self.jwt_authorities
162 .read()
163 .map(|guard| guard.is_empty())
164 .unwrap_or(true)
165 }
166
167 pub fn marshal(&self) -> Result<Vec<u8>> {
169 let mut keys = Vec::new();
170 let authorities = self.jwt_authorities();
171 for (key_id, jwt_key) in authorities {
172 keys.push(JwksKey::from_jwt_key(&key_id, &jwt_key));
173 }
174 let jwks = Jwks { keys };
175 serde_json::to_vec(&jwks).map_err(|err| wrap_error(err))
176 }
177
178 pub fn clone_bundle(&self) -> Bundle {
180 Bundle::from_jwt_authorities(self.trust_domain(), &self.jwt_authorities())
181 }
182
183 pub fn equal(&self, other: &Bundle) -> bool {
185 self.trust_domain == other.trust_domain
186 && jwtutil::jwt_authorities_equal(&self.jwt_authorities(), &other.jwt_authorities())
187 }
188
189 pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
191 if self.trust_domain != trust_domain {
192 return Err(wrap_error(format!(
193 "no JWT bundle for trust domain \"{}\"",
194 trust_domain
195 )));
196 }
197 Ok(self.clone_bundle())
198 }
199}
200
201pub trait Source {
203 fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle>;
205}
206
207#[derive(Debug)]
209pub struct Set {
210 bundles: RwLock<HashMap<TrustDomain, Bundle>>,
211}
212
213impl Set {
214 pub fn new(bundles: &[Bundle]) -> Set {
216 let mut map = HashMap::new();
217 for bundle in bundles {
218 map.insert(bundle.trust_domain(), bundle.clone_bundle());
219 }
220 Set {
221 bundles: RwLock::new(map),
222 }
223 }
224
225 pub fn add(&self, bundle: &Bundle) {
227 if let Ok(mut guard) = self.bundles.write() {
228 guard.insert(bundle.trust_domain(), bundle.clone_bundle());
229 }
230 }
231
232 pub fn remove(&self, trust_domain: TrustDomain) {
234 if let Ok(mut guard) = self.bundles.write() {
235 guard.remove(&trust_domain);
236 }
237 }
238
239 pub fn has(&self, trust_domain: TrustDomain) -> bool {
241 self.bundles
242 .read()
243 .map(|guard| guard.contains_key(&trust_domain))
244 .unwrap_or(false)
245 }
246
247 pub fn get(&self, trust_domain: TrustDomain) -> Option<Bundle> {
249 self.bundles
250 .read()
251 .ok()
252 .and_then(|guard| guard.get(&trust_domain).map(|b| b.clone_bundle()))
253 }
254
255 pub fn bundles(&self) -> Vec<Bundle> {
257 let mut bundles = self
258 .bundles
259 .read()
260 .map(|guard| guard.values().map(|b| b.clone_bundle()).collect::<Vec<_>>())
261 .unwrap_or_default();
262 bundles.sort_by(|a, b| a.trust_domain().compare(&b.trust_domain()));
263 bundles
264 }
265
266 pub fn len(&self) -> usize {
268 self.bundles.read().map(|guard| guard.len()).unwrap_or(0)
269 }
270
271 pub fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
273 let guard = self
274 .bundles
275 .read()
276 .map_err(|_| wrap_error("bundle store poisoned"))?;
277 let bundle = guard.get(&trust_domain).ok_or_else(|| {
278 wrap_error(format!(
279 "no JWT bundle for trust domain \"{}\"",
280 trust_domain
281 ))
282 })?;
283 Ok(bundle.clone_bundle())
284 }
285}
286
287impl Source for Set {
288 fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
289 self.get_jwt_bundle_for_trust_domain(trust_domain)
290 }
291}
292
293impl Source for Bundle {
294 fn get_jwt_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
295 self.get_jwt_bundle_for_trust_domain(trust_domain)
296 }
297}
298
299#[derive(Serialize)]
300struct Jwks {
301 keys: Vec<JwksKey>,
302}
303
304#[derive(Serialize)]
305struct JwksKey {
306 kty: String,
307 kid: String,
308 #[serde(skip_serializing_if = "Option::is_none")]
309 crv: Option<String>,
310 #[serde(skip_serializing_if = "Option::is_none")]
311 x: Option<String>,
312 #[serde(skip_serializing_if = "Option::is_none")]
313 y: Option<String>,
314 #[serde(skip_serializing_if = "Option::is_none")]
315 n: Option<String>,
316 #[serde(skip_serializing_if = "Option::is_none")]
317 e: Option<String>,
318}
319
320impl JwksKey {
321 fn from_jwt_key(key_id: &str, key: &JwtKey) -> JwksKey {
322 match key {
323 JwtKey::Ec { crv, x, y } => JwksKey {
324 kty: "EC".to_string(),
325 kid: key_id.to_string(),
326 crv: Some(crv.clone()),
327 x: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x)),
328 y: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y)),
329 n: None,
330 e: None,
331 },
332 JwtKey::Rsa { n, e } => JwksKey {
333 kty: "RSA".to_string(),
334 kid: key_id.to_string(),
335 crv: None,
336 x: None,
337 y: None,
338 n: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(n)),
339 e: Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(e)),
340 },
341 }
342 }
343}