1use core::fmt;
2use log::warn;
3use std::collections::BTreeMap;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use anyhow::{anyhow, Error, Result};
7use openssl::bn::BigNum;
8use openssl::nid::Nid;
9use openssl::{
10 ec::{EcGroup, EcKey},
11 pkey::Public,
12};
13
14use base64::decode_config as b64_dec;
15use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
16use serde::{de::DeserializeOwned, Deserialize, Serialize};
17use serde_json::Value as JsonValue;
18
19use crate::SpiffeID;
20
/// Number of `.`-separated segments `JwtBundle::verify_token` requires in an encoded token;
/// any other count is rejected before decoding starts.
const SEGMENTS_COUNT: usize = 3;
22
/// One public key entry of a JWT bundle.
///
/// The serde renames map the fields to the `kty`, `kid` and `crv` members on the wire;
/// `x` and `y` keep their names.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
pub struct JwtKey {
    /// Key type (`kty`). Stored as given; nothing in this file inspects it.
    #[serde(rename = "kty")]
    pub key_type: String,
    /// Key identifier (`kid`). Used as the map key when a `JwtBundle` is deserialized.
    #[serde(rename = "kid")]
    pub key_id: String,
    /// Curve name (`crv`). `as_openssl_public_key` accepts only `"P-256"` and `"P-384"`.
    #[serde(rename = "crv")]
    pub curve: String,
    /// Base64 (URL-safe alphabet) text that `as_openssl_public_key` decodes into the
    /// affine x coordinate of the public point.
    pub x: String,
    /// Base64 (URL-safe alphabet) text that `as_openssl_public_key` decodes into the
    /// affine y coordinate of the public point.
    pub y: String,
}
34
/// A set of JWT verification keys.
///
/// The `Deserialize` implementation in this file fills `inner` from a sequence of
/// `JwtKey` values, using each key's `key_id` as the map key.
#[derive(PartialEq, Debug)]
pub struct JwtBundle {
    /// Keys of the bundle; `verify_token` looks the token header's `kid` up in this map.
    pub inner: BTreeMap<String, JwtKey>,
}
39
40impl fmt::Display for JwtBundle {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 writeln!(
43 f,
44 "{}",
45 serde_json::to_string(&self.inner).unwrap_or_default()
46 )
47 }
48}
49
50impl Serialize for JwtBundle {
51 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
52 serializer.collect_seq(self.inner.values())
53 }
54}
55
56impl<'de> Deserialize<'de> for JwtBundle {
57 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
58 let raw = Vec::<JwtKey>::deserialize(deserializer)?;
59 Ok(JwtBundle {
60 inner: raw.into_iter().map(|x| (x.key_id.clone(), x)).collect(),
61 })
62 }
63}
64
65impl JwtKey {
66 pub fn as_openssl_public_key(&self) -> Result<EcKey<Public>> {
67 let nid = match &*self.curve {
68 "P-256" => Nid::X9_62_PRIME256V1,
69 "P-384" => Nid::SECP384R1,
70 _ => {
71 return Err(anyhow!(
72 "invalid curve in jwt key '{}': {}",
73 self.key_id,
74 self.curve
75 ))
76 }
77 };
78 let group = EcGroup::from_curve_name(nid)?;
79
80 let x = base64::decode_config(&self.x, base64::URL_SAFE)?;
81 let x = BigNum::from_slice(&x[..])?;
82 let y = base64::decode_config(&self.y, base64::URL_SAFE)?;
83 let y = BigNum::from_slice(&y[..])?;
84 Ok(EcKey::from_public_key_affine_coordinates(&group, &x, &y)?)
85 }
86}
87
/// Header segment of a token as parsed by `JwtBundle::verify_token`.
///
/// All three members are required; a header missing any of them fails to deserialize.
#[derive(Deserialize, Serialize)]
struct JwtHeader {
    /// The `alg` member of the header.
    #[serde(rename = "alg")]
    algorithm: String,
    /// The `kid` member; `verify_token` uses it to select a key from the bundle.
    #[serde(rename = "kid")]
    key_id: String,
    /// The `typ` member; `verify_token` rejects any value other than `"JWT"`.
    typ: String,
}
96
/// Payload view used by `JwtBundle::verify_spiffe_id`: only the `sub` claim, parsed as a
/// `SpiffeID`.
#[derive(Deserialize, Serialize, Debug)]
struct JwtPayload {
    /// Subject of the token.
    sub: SpiffeID,
}
101
/// Claim set that `JwtBundle::verify_token` asks `jsonwebtoken::decode` to deserialize.
#[derive(Serialize, Deserialize, Debug)]
struct Claims {
    /// The `aud` claim, as a list of strings.
    aud: Vec<String>,
    /// The `exp` claim; `verify_token` compares it with the current Unix time in seconds.
    exp: usize,
    /// The optional `iat` claim; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    iat: Option<usize>,
    /// The optional `iss` claim; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    iss: Option<String>,
    /// The `sub` claim, kept as a plain string here.
    sub: String,
}
112
113impl JwtBundle {
114 pub fn verify_token<T: DeserializeOwned>(&self, encoded_token: &str) -> Result<T> {
115 let raw_segments: Vec<&str> = encoded_token.split('.').collect();
117 if raw_segments.len() != SEGMENTS_COUNT {
118 return Err(anyhow!("jwt token has incorrect amounts of segments"));
119 }
120 let header_segment = raw_segments[0];
121 let payload_segment = raw_segments[1];
122 let b64_to_json = |seg| -> Result<JsonValue, Error> {
123 serde_json::from_slice(b64_dec(seg, base64::URL_SAFE_NO_PAD)?.as_slice())
124 .map_err(Error::from)
125 };
126 let payload_json = b64_to_json(payload_segment)?;
127
128 let header = header_segment;
130 let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
131 let header: JwtHeader = serde_json::from_slice(&header[..])?;
132 if header.typ != "JWT" {
133 return Err(anyhow!("header 'typ' not 'JWT': {}", header.typ));
134 }
135 let key = self
136 .inner
137 .get(&header.key_id)
138 .ok_or_else(|| anyhow!("key id '{}' not found in bundle", header.key_id))?;
139
140 let ec_public_key = key.as_openssl_public_key()?;
141 let public_key_pem =
142 openssl::pkey::PKey::from_ec_key(ec_public_key)?.public_key_to_pem()?;
143 let public_key_u8 = public_key_pem.as_slice();
144
145 let validation = Validation {
146 algorithms: vec![Algorithm::ES256, Algorithm::ES384],
147 validate_exp: false, ..Validation::default()
149 };
150
151 let token_data = match decode::<Claims>(
152 encoded_token,
153 &DecodingKey::from_ec_pem(public_key_u8)?,
154 &validation,
155 ) {
156 Ok(c) => c,
157 Err(err) => return Err(anyhow!("{:?} happened during decoding Jwt token", err)),
158 };
159
160 let start = SystemTime::now();
162 let now = start.duration_since(UNIX_EPOCH)?;
163
164 let now = now.as_secs() as usize;
165
166 if token_data.claims.exp < now as usize {
168 if token_data.claims.exp < (now - 86400) as usize {
169 return Err(anyhow!("Token has expired for over 24 hours"));
171 } else {
172 warn!("Token is about to expire in 24 hours")
173 }
174 }
175
176 Ok(serde_json::from_value(payload_json)?)
177 }
178
179 pub fn verify_spiffe_id(&self, encoded_token: &str) -> Result<SpiffeID> {
180 let payload: JwtPayload = self.verify_token(encoded_token)?;
181 Ok(payload.sub)
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use serde_test::{assert_tokens, Token};

    /// Key id carried by every fixture key and by every token expected to resolve a key.
    const KEY_ID: &str = "dummy_keyid";
    /// Private key used to sign the ES256 fixture tokens.
    const P256_PRIVATE_KEY: &[u8] = include_bytes!("../tests/data/priv_key_256v1.pem");
    /// Private key used to sign the ES384 fixture tokens.
    const P384_PRIVATE_KEY: &[u8] = include_bytes!("../tests/data/priv_key_384r1.pem");

    /// Current Unix time in whole seconds.
    fn unix_now() -> usize {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs() as usize
    }

    impl Default for JwtKey {
        fn default() -> Self {
            JwtKey {
                key_type: String::from("JWT"),
                key_id: String::from(KEY_ID),
                curve: String::from("P-256"),
                x: String::from("ovsRfW7L2V8zyGyJkOLA_JlczbgssQ7JrVQ2pzS74QY"),
                y: String::from("kO_n1Pz9qbK8gNzfXA4Hfo1K11-Dyl1JilDFYltNyhw"),
            }
        }
    }

    impl Default for JwtBundle {
        fn default() -> Self {
            single_key_bundle(JwtKey::default())
        }
    }

    impl Default for Claims {
        fn default() -> Self {
            Claims {
                aud: vec![String::from("dummy_audience")],
                // Relative to the current time: a fixed timestamp eventually falls
                // outside the 24 hour grace period and breaks every test that expects a
                // freshly signed token to verify.
                exp: unix_now() + 3600,
                iat: None,
                iss: None,
                sub: String::from("spiffe://dummy.org/ns:dummy/id:dummy"),
            }
        }
    }

    /// Wraps `key` in a bundle, indexed by its own key id.
    fn single_key_bundle(key: JwtKey) -> JwtBundle {
        let mut inner = BTreeMap::new();
        inner.insert(key.key_id.clone(), key);
        JwtBundle { inner }
    }

    /// Key carrying the P-384 fixture coordinates, labelled with `curve`.
    fn p384_point_key(curve: &str) -> JwtKey {
        JwtKey {
            x: String::from("_Ukg1KZI3nxFNp94Dt6Zh4sDFMBtsCOpFpHNBw0K_R4OSW2veXsCta-mIUfbKGr-"),
            y: String::from("4fQDA18hHXcB3Z8Ld-h0GG7ZGDyZjhsez1AlJ7Swvd8ruXiC3cVpVt27UPIv0f70"),
            curve: String::from(curve),
            ..JwtKey::default()
        }
    }

    struct Setup {
        bundle_p256: JwtBundle,
        bundle_p384: JwtBundle,
        bundle_invalid_curve: JwtBundle,
        bundle_reserved_slot: JwtBundle,
        token_p256: String,
        token_p384: String,
        token_reserved_slot: String,
        token_wrong_sig: &'static str,
        token_invalid_segment_length: &'static str,
        token_invalid_header_type: String,
        token_invalid_key_id: String,
        token_expired: String,
        token_about_to_expire: String,
        token_with_issuer: String,
        token_with_iat: String,
    }

    impl Setup {
        fn new() -> Self {
            let now = unix_now();
            Self {
                bundle_p256: JwtBundle::default(),
                bundle_p384: single_key_bundle(p384_point_key("P-384")),
                bundle_invalid_curve: single_key_bundle(p384_point_key("P-521")),
                bundle_reserved_slot: single_key_bundle(JwtKey {
                    key_type: String::from("JWT"),
                    key_id: String::from("dummy_keyid"),
                    curve: String::from("P-256"),
                    x: String::from("ovsRfW7L2V8zyGyJkOLA_JlczbgssQ7JrVQ2pzS74QY"),
                    y: String::from("kO_n1Pz9qbK8gNzfXA4Hfo1K11-Dyl1JilDFYltNyhw"),
                }),
                token_p256: generate_token_on_algorithm(Algorithm::ES256),
                token_p384: generate_token_on_algorithm(Algorithm::ES384),
                token_reserved_slot: String::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImR1bW15X2tleWlkIn0.eyJhdWQiOlsiZHVtbXlfYXVkaWVuY2UiXSwiZXhwIjoxNzUzNzE3MTE4LCJzdWIiOiJzcGlmZmU6Ly9kdW1teS5vcmcvbnM6ZHVtbXkvaWQ6ZHVtbXkifQ.Js21h6fppSJ4MUjrbk3pvIrJ_ybLD0UvZcA9lldXnaxVEbiB3NBfcOPAaKHaQJN_BYQZiERRyFrUMRprx-MvTA"),
                token_wrong_sig: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImR1bW15X2tleWlkIn0.eyJhdWQiOlsiZHVtbXlfYXVkaWVuY2UiXSwiZXhwIjoxNzUzNzE3MTE4LCJpYXQiOjE2MjcwMTUyMjIsImlzcyI6InVzZXIiLCJzdWIiOiJzcGlmZmU6Ly9kdW1teS5vcmcvbnM6ZHVtbXkvaWQ6ZHVtbXkifQ.q7RMpz74PigIib2x34bSU6mp72Bw26tTS9Zl3nV_Gwzpt7-RsQFktbKefZC9JV0uJptCKJNLeyXBdNs3NgV7GA",
                token_invalid_segment_length: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6ImR1bW15X2tleWlkIn0",
                token_invalid_header_type: generate_token_wrong_header(),
                token_invalid_key_id: generate_token_invalid_key_id(),
                // Expired one minute beyond the 24 hour grace period.
                token_expired: generate_token_on_expire(now - 86460),
                // Already expired, but still inside the 24 hour grace period.
                token_about_to_expire: generate_token_on_expire(now - 46460),
                token_with_issuer: generate_token_with_issuer(),
                token_with_iat: generate_token_with_iat(),
            }
        }
    }

    /// Header with the given algorithm and key id; every other member keeps its default.
    fn header_for(algorithm: Algorithm, key_id: &str) -> Header {
        Header {
            alg: algorithm,
            kid: Some(key_id.to_owned()),
            ..Header::default()
        }
    }

    /// Signs `claims` under `header` with the EC private key in `private_key_pem`.
    fn sign_token(header: &Header, claims: &Claims, private_key_pem: &[u8]) -> String {
        let key = openssl::pkey::PKey::private_key_from_pem(private_key_pem).unwrap();
        // The key is re-encoded as PKCS#8 before being handed to `EncodingKey`.
        let pem = key.private_key_to_pem_pkcs8().unwrap();
        encode(
            header,
            claims,
            &EncodingKey::from_ec_pem(pem.as_slice()).unwrap(),
        )
        .unwrap()
    }

    /// Signs `claims` with the P-256 key under the standard ES256 header.
    fn sign_es256(claims: &Claims) -> String {
        sign_token(
            &header_for(Algorithm::ES256, KEY_ID),
            claims,
            P256_PRIVATE_KEY,
        )
    }

    fn generate_token_on_algorithm(algorithm: Algorithm) -> String {
        let private_key_pem = match algorithm {
            Algorithm::ES256 => P256_PRIVATE_KEY,
            _ => P384_PRIVATE_KEY,
        };
        sign_token(
            &header_for(algorithm, KEY_ID),
            &Claims::default(),
            private_key_pem,
        )
    }

    fn generate_token_wrong_header() -> String {
        let header = Header {
            typ: Some("error".to_owned()),
            ..header_for(Algorithm::ES256, KEY_ID)
        };
        sign_token(&header, &Claims::default(), P256_PRIVATE_KEY)
    }

    fn generate_token_invalid_key_id() -> String {
        sign_token(
            &header_for(Algorithm::ES256, "error"),
            &Claims::default(),
            P256_PRIVATE_KEY,
        )
    }

    fn generate_token_on_expire(expire_time: usize) -> String {
        sign_es256(&Claims {
            exp: expire_time,
            ..Claims::default()
        })
    }

    fn generate_token_with_issuer() -> String {
        sign_es256(&Claims {
            iss: Some(String::from("user")),
            ..Claims::default()
        })
    }

    fn generate_token_with_iat() -> String {
        sign_es256(&Claims {
            iat: Some(1627015222),
            ..Claims::default()
        })
    }

    /// Runs `verify_token` expecting a failure and returns the formatted error.
    fn verification_error(bundle: &JwtBundle, token: &str) -> String {
        let err = bundle.verify_token::<JwtPayload>(token).unwrap_err();
        format!("{:#}", err)
    }

    #[test]
    fn test_verify_token_p256() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p256
                .verify_token::<JwtPayload>(&setup.token_p256)
                .is_ok(),
            "token verification failed"
        );
    }

    #[test]
    fn test_verify_token_p384() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p384
                .verify_token::<JwtPayload>(&setup.token_p384)
                .is_ok(),
            "token verification failed"
        );
    }

    #[test]
    fn test_verify_token_reserved_slot() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_reserved_slot, &setup.token_reserved_slot),
            "Error(InvalidSignature) happened during decoding Jwt token"
        );
    }

    #[test]
    fn test_verify_token_wrong_sig() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_p256, setup.token_wrong_sig),
            "Error(InvalidSignature) happened during decoding Jwt token"
        );
    }

    #[test]
    fn test_verify_token_bundle_invalid_curve() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_invalid_curve, &setup.token_p256),
            "invalid curve in jwt key 'dummy_keyid': P-521"
        );
    }

    #[test]
    fn test_verify_token_invalid_header_type() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_p256, &setup.token_invalid_header_type),
            "header 'typ' not 'JWT': error"
        );
    }

    #[test]
    fn test_verify_token_invalid_key_id() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_p256, &setup.token_invalid_key_id),
            "key id 'error' not found in bundle"
        );
    }

    #[test]
    fn test_verify_token_invalid_segment_length() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_p256, setup.token_invalid_segment_length),
            "jwt token has incorrect amounts of segments"
        );
    }

    #[test]
    fn test_verify_token_expired() {
        let setup = Setup::new();
        assert_eq!(
            verification_error(&setup.bundle_p256, &setup.token_expired),
            "Token has expired for over 24 hours"
        );
    }

    #[test]
    fn test_verify_token_about_to_expire() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p256
                .verify_token::<JwtPayload>(&setup.token_about_to_expire)
                .is_ok(),
            "Token about to expire verification failed"
        );
    }

    #[test]
    fn test_verify_token_with_issuer() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p256
                .verify_token::<JwtPayload>(&setup.token_with_issuer)
                .is_ok(),
            "Token with issuer verification failed"
        );
    }

    #[test]
    fn test_verify_token_with_iat() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p256
                .verify_token::<JwtPayload>(&setup.token_with_iat)
                .is_ok(),
            "Token with iat verification failed"
        );
    }

    #[test]
    fn test_verify_spiffe_id_p256() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p256
                .verify_spiffe_id(&setup.token_p256)
                .is_ok(),
            "Spiffe ID verification failed"
        );
    }

    #[test]
    fn test_verify_spiffe_id_p384() {
        let setup = Setup::new();
        assert!(
            setup
                .bundle_p384
                .verify_spiffe_id(&setup.token_p384)
                .is_ok(),
            "Spiffe ID verification failed"
        );
    }

    #[test]
    fn test_ser_de() {
        let setup = Setup::new();
        assert_tokens(
            &setup.bundle_p256,
            &[
                Token::Seq { len: Some(1) },
                Token::Struct {
                    name: "JwtKey",
                    len: 5,
                },
                Token::String("kty"),
                Token::String("JWT"),
                Token::String("kid"),
                Token::String("dummy_keyid"),
                Token::String("crv"),
                Token::String("P-256"),
                Token::String("x"),
                Token::String("ovsRfW7L2V8zyGyJkOLA_JlczbgssQ7JrVQ2pzS74QY"),
                Token::String("y"),
                Token::String("kO_n1Pz9qbK8gNzfXA4Hfo1K11-Dyl1JilDFYltNyhw"),
                Token::StructEnd,
                Token::SeqEnd,
            ],
        );
    }

    #[test]
    fn test_jwt_bundle_display() {
        let setup = Setup::new();
        assert_eq!(
            format!("{}", setup.bundle_p256),
            String::from("{\"dummy_keyid\":{\"kty\":\"JWT\",\"kid\":\"dummy_keyid\",\"crv\":\"P-256\",\"x\":\"ovsRfW7L2V8zyGyJkOLA_JlczbgssQ7JrVQ2pzS74QY\",\"y\":\"kO_n1Pz9qbK8gNzfXA4Hfo1K11-Dyl1JilDFYltNyhw\"}}\n")
        );
    }
}