1use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22
23use crate::encoding::URL_SAFE_NO_PAD;
24
/// Every failure this module can report while encoding, decoding, verifying or validating a
/// token. String payloads carry a human-readable detail that is echoed by `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwsError {
    /// The token (or a segment, header or payload inside it) could not be decoded or parsed.
    Malformed(String),
    /// The header names an algorithm this module does not implement.
    UnsupportedAlgorithm(String),
    /// The algorithm is known but not permitted here (allow-list or key-type mismatch).
    AlgorithmNotAllowed(String),
    /// Key material could not be decoded or is otherwise unusable.
    BadKey(String),
    /// The signature did not parse or did not verify.
    BadSignature,
    /// A registered claim (`exp`, `nbf`, `iss`, `aud`) failed validation.
    InvalidClaim(String),
}

impl std::fmt::Display for JwsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Malformed(detail) => write!(f, "malformed JWT: {detail}"),
            Self::UnsupportedAlgorithm(name) => write!(f, "unsupported algorithm: {name}"),
            Self::AlgorithmNotAllowed(name) => write!(f, "algorithm {name} not allowed"),
            Self::BadKey(detail) => write!(f, "bad key: {detail}"),
            Self::BadSignature => f.write_str("signature verification failed"),
            Self::InvalidClaim(detail) => write!(f, "invalid claim: {detail}"),
        }
    }
}

/// `JwsError` has no underlying cause, so the default `source()` (returning `None`) is kept.
impl std::error::Error for JwsError {}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq)]
51pub enum Algorithm {
52 ES256,
53 ES384,
54 RS256,
55 RS384,
56 RS512,
57 EdDSA,
58}
59
60impl Algorithm {
61 pub fn parse(name: &str) -> Result<Self, JwsError> {
62 match name.to_ascii_uppercase().as_str() {
63 "ES256" => Ok(Algorithm::ES256),
64 "ES384" => Ok(Algorithm::ES384),
65 "RS256" => Ok(Algorithm::RS256),
66 "RS384" => Ok(Algorithm::RS384),
67 "RS512" => Ok(Algorithm::RS512),
68 "EDDSA" => Ok(Algorithm::EdDSA),
69 other => Err(JwsError::UnsupportedAlgorithm(other.to_string())),
70 }
71 }
72
73 pub fn name(&self) -> &'static str {
74 match self {
75 Algorithm::ES256 => "ES256",
76 Algorithm::ES384 => "ES384",
77 Algorithm::RS256 => "RS256",
78 Algorithm::RS384 => "RS384",
79 Algorithm::RS512 => "RS512",
80 Algorithm::EdDSA => "EdDSA",
81 }
82 }
83}
84
85impl std::fmt::Display for Algorithm {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.write_str(self.name())
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Header {
93 pub alg: String,
94 #[serde(skip_serializing_if = "Option::is_none", default)]
95 pub kid: Option<String>,
96 #[serde(skip_serializing_if = "Option::is_none", default)]
97 pub typ: Option<String>,
98 #[serde(skip_serializing_if = "Option::is_none", default)]
100 pub jwk: Option<Value>,
101}
102
103impl Header {
104 pub fn new(alg: Algorithm) -> Self {
105 Header {
106 alg: alg.name().to_string(),
107 kid: None,
108 typ: Some("JWT".to_string()),
109 jwk: None,
110 }
111 }
112
113 pub fn algorithm(&self) -> Result<Algorithm, JwsError> {
114 Algorithm::parse(&self.alg)
115 }
116}
117
118pub fn decode_header(token: &str) -> Result<Header, JwsError> {
121 let first = token
122 .split('.')
123 .next()
124 .ok_or_else(|| JwsError::Malformed("empty token".into()))?;
125 let bytes = URL_SAFE_NO_PAD
126 .decode(first)
127 .map_err(|e| JwsError::Malformed(format!("header base64url: {e}")))?;
128 serde_json::from_slice(&bytes).map_err(|e| JwsError::Malformed(format!("header JSON: {e}")))
129}
130
131pub enum DecodingKey {
136 Ed25519(ed25519_dalek::VerifyingKey),
137 P256(p256::ecdsa::VerifyingKey),
138 P384(p384::ecdsa::VerifyingKey),
139 Rsa(rsa::RsaPublicKey),
140}
141
142impl DecodingKey {
143 pub fn from_ec_components(x: &str, y: &str) -> Result<Self, JwsError> {
146 let xb = b64u(x, "x")?;
147 let yb = b64u(y, "y")?;
148 if xb.len() != yb.len() {
149 return Err(JwsError::BadKey("EC x/y length mismatch".into()));
150 }
151 let mut sec1 = Vec::with_capacity(1 + xb.len() + yb.len());
152 sec1.push(0x04);
153 sec1.extend_from_slice(&xb);
154 sec1.extend_from_slice(&yb);
155 match xb.len() {
156 32 => p256::ecdsa::VerifyingKey::from_sec1_bytes(&sec1)
157 .map(DecodingKey::P256)
158 .map_err(|e| JwsError::BadKey(format!("P-256 point: {e}"))),
159 48 => p384::ecdsa::VerifyingKey::from_sec1_bytes(&sec1)
160 .map(DecodingKey::P384)
161 .map_err(|e| JwsError::BadKey(format!("P-384 point: {e}"))),
162 n => Err(JwsError::BadKey(format!("unsupported EC width {n}"))),
163 }
164 }
165
166 pub fn from_rsa_components(n: &str, e: &str) -> Result<Self, JwsError> {
168 let nb = b64u(n, "n")?;
169 let eb = b64u(e, "e")?;
170 let key = rsa::RsaPublicKey::new(
171 rsa::BigUint::from_bytes_be(&nb),
172 rsa::BigUint::from_bytes_be(&eb),
173 )
174 .map_err(|e| JwsError::BadKey(format!("RSA components: {e}")))?;
175 Ok(DecodingKey::Rsa(key))
176 }
177
178 pub fn from_ed_components(x: &str) -> Result<Self, JwsError> {
180 let xb = b64u(x, "x")?;
181 let arr: [u8; 32] = xb
182 .as_slice()
183 .try_into()
184 .map_err(|_| JwsError::BadKey("Ed25519 x must be 32 bytes".into()))?;
185 ed25519_dalek::VerifyingKey::from_bytes(&arr)
186 .map(DecodingKey::Ed25519)
187 .map_err(|e| JwsError::BadKey(format!("Ed25519 point: {e}")))
188 }
189
190 fn verify(&self, alg: Algorithm, message: &[u8], signature: &[u8]) -> Result<(), JwsError> {
191 match (self, alg) {
192 (DecodingKey::Ed25519(key), Algorithm::EdDSA) => {
193 use ed25519_dalek::Verifier;
194 let sig = ed25519_dalek::Signature::from_slice(signature)
195 .map_err(|_| JwsError::BadSignature)?;
196 key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
197 }
198 (DecodingKey::P256(key), Algorithm::ES256) => {
199 use p256::ecdsa::signature::Verifier;
200 let sig = p256::ecdsa::Signature::from_slice(signature)
201 .map_err(|_| JwsError::BadSignature)?;
202 key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
203 }
204 (DecodingKey::P384(key), Algorithm::ES384) => {
205 use p384::ecdsa::signature::Verifier;
206 let sig = p384::ecdsa::Signature::from_slice(signature)
207 .map_err(|_| JwsError::BadSignature)?;
208 key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
209 }
210 (DecodingKey::Rsa(key), Algorithm::RS256) => {
211 verify_rsa::<sha2::Sha256>(key, message, signature)
212 }
213 (DecodingKey::Rsa(key), Algorithm::RS384) => {
214 verify_rsa::<sha2::Sha384>(key, message, signature)
215 }
216 (DecodingKey::Rsa(key), Algorithm::RS512) => {
217 verify_rsa::<sha2::Sha512>(key, message, signature)
218 }
219 _ => Err(JwsError::AlgorithmNotAllowed(format!(
221 "{} incompatible with the provided key type",
222 alg
223 ))),
224 }
225 }
226}
227
228fn verify_rsa<D>(key: &rsa::RsaPublicKey, message: &[u8], signature: &[u8]) -> Result<(), JwsError>
229where
230 D: rsa::sha2::Digest + rsa::pkcs8::AssociatedOid,
231{
232 use rsa::signature::Verifier;
233 let verifying = rsa::pkcs1v15::VerifyingKey::<D>::new(key.clone());
234 let sig = rsa::pkcs1v15::Signature::try_from(signature).map_err(|_| JwsError::BadSignature)?;
235 verifying
236 .verify(message, &sig)
237 .map_err(|_| JwsError::BadSignature)
238}
239
240fn b64u(s: &str, what: &str) -> Result<Vec<u8>, JwsError> {
241 URL_SAFE_NO_PAD
242 .decode(s)
243 .map_err(|e| JwsError::BadKey(format!("base64url {what}: {e}")))
244}
245
246#[derive(Debug, Clone)]
251pub struct Validation {
252 pub algorithms: Vec<Algorithm>,
253 pub leeway: u64,
255 pub validate_exp: bool,
256 pub validate_nbf: bool,
257 issuer: Option<Vec<String>>,
258 audience: Option<Vec<String>>,
259}
260
261impl Validation {
262 pub fn new(alg: Algorithm) -> Self {
263 Validation {
264 algorithms: vec![alg],
265 leeway: 0,
266 validate_exp: true,
267 validate_nbf: false,
268 issuer: None,
269 audience: None,
270 }
271 }
272
273 pub fn set_issuer<T: ToString>(&mut self, issuers: &[T]) {
274 self.issuer = Some(issuers.iter().map(|i| i.to_string()).collect());
275 }
276
277 pub fn set_audience<T: ToString>(&mut self, audiences: &[T]) {
278 self.audience = Some(audiences.iter().map(|a| a.to_string()).collect());
279 }
280}
281
282#[derive(Debug)]
283pub struct TokenData<T> {
284 pub header: Header,
285 pub claims: T,
286}
287
288pub fn decode<T: DeserializeOwned>(
291 token: &str,
292 key: &DecodingKey,
293 validation: &Validation,
294) -> Result<TokenData<T>, JwsError> {
295 let mut parts = token.split('.');
296 let (h, p, s) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
297 (Some(h), Some(p), Some(s), None) => (h, p, s),
298 _ => return Err(JwsError::Malformed("expected three dot-separated segments".into())),
299 };
300 let header: Header = {
301 let bytes = URL_SAFE_NO_PAD
302 .decode(h)
303 .map_err(|e| JwsError::Malformed(format!("header base64url: {e}")))?;
304 serde_json::from_slice(&bytes).map_err(|e| JwsError::Malformed(format!("header JSON: {e}")))?
305 };
306 let alg = header.algorithm()?;
307 if !validation.algorithms.contains(&alg) {
308 return Err(JwsError::AlgorithmNotAllowed(alg.name().to_string()));
309 }
310 let signature = URL_SAFE_NO_PAD
311 .decode(s)
312 .map_err(|e| JwsError::Malformed(format!("signature base64url: {e}")))?;
313 let message_len = h.len() + 1 + p.len();
314 let message = &token.as_bytes()[..message_len];
315 key.verify(alg, message, &signature)?;
316
317 let payload_bytes = URL_SAFE_NO_PAD
318 .decode(p)
319 .map_err(|e| JwsError::Malformed(format!("payload base64url: {e}")))?;
320 let claims_value: Value = serde_json::from_slice(&payload_bytes)
321 .map_err(|e| JwsError::Malformed(format!("payload JSON: {e}")))?;
322 validate_registered_claims(&claims_value, validation)?;
323 let claims = serde_json::from_value(claims_value)
324 .map_err(|e| JwsError::Malformed(format!("claims shape: {e}")))?;
325 Ok(TokenData { header, claims })
326}
327
328fn validate_registered_claims(claims: &Value, v: &Validation) -> Result<(), JwsError> {
329 let now = now_unix();
330 if v.validate_exp {
331 let exp = claims
332 .get("exp")
333 .and_then(Value::as_u64)
334 .ok_or_else(|| JwsError::InvalidClaim("exp missing".into()))?;
335 if exp.saturating_add(v.leeway) < now {
336 return Err(JwsError::InvalidClaim("token expired".into()));
337 }
338 }
339 if v.validate_nbf {
340 if let Some(nbf) = claims.get("nbf").and_then(Value::as_u64) {
341 if nbf.saturating_sub(v.leeway) > now {
342 return Err(JwsError::InvalidClaim("token not yet valid".into()));
343 }
344 }
345 }
346 if let Some(issuers) = &v.issuer {
347 let iss = claims
348 .get("iss")
349 .and_then(Value::as_str)
350 .ok_or_else(|| JwsError::InvalidClaim("iss missing".into()))?;
351 if !issuers.iter().any(|i| i == iss) {
352 return Err(JwsError::InvalidClaim(format!("issuer {iss} not accepted")));
353 }
354 }
355 if let Some(audiences) = &v.audience {
356 let ok = match claims.get("aud") {
357 Some(Value::String(a)) => audiences.iter().any(|x| x == a),
358 Some(Value::Array(arr)) => arr
359 .iter()
360 .filter_map(Value::as_str)
361 .any(|a| audiences.iter().any(|x| x == a)),
362 _ => false,
363 };
364 if !ok {
365 return Err(JwsError::InvalidClaim("audience not accepted".into()));
366 }
367 }
368 Ok(())
369}
370
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields `0` instead of an error.
fn now_unix() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs()
}
377
378pub enum EncodingKey {
383 Ed25519(Box<ed25519_dalek::SigningKey>),
384 P256(Box<p256::ecdsa::SigningKey>),
385}
386
387impl EncodingKey {
388 pub fn from_ed_pem(pem: &[u8]) -> Result<Self, JwsError> {
389 use ed25519_dalek::pkcs8::DecodePrivateKey;
390 let text =
391 std::str::from_utf8(pem).map_err(|_| JwsError::BadKey("PEM not UTF-8".into()))?;
392 ed25519_dalek::SigningKey::from_pkcs8_pem(text)
393 .map(|k| EncodingKey::Ed25519(Box::new(k)))
394 .map_err(|e| JwsError::BadKey(format!("Ed25519 PKCS#8: {e}")))
395 }
396
397 pub fn from_ec_pem(pem: &[u8]) -> Result<Self, JwsError> {
398 use p256::pkcs8::DecodePrivateKey;
399 let text =
400 std::str::from_utf8(pem).map_err(|_| JwsError::BadKey("PEM not UTF-8".into()))?;
401 p256::SecretKey::from_pkcs8_pem(text)
402 .map(|k| EncodingKey::P256(Box::new(p256::ecdsa::SigningKey::from(k))))
403 .map_err(|e| JwsError::BadKey(format!("EC PKCS#8: {e}")))
404 }
405
406 fn sign(&self, alg: Algorithm, message: &[u8]) -> Result<Vec<u8>, JwsError> {
407 match (self, alg) {
408 (EncodingKey::Ed25519(key), Algorithm::EdDSA) => {
409 use ed25519_dalek::Signer;
410 Ok(key.sign(message).to_bytes().to_vec())
411 }
412 (EncodingKey::P256(key), Algorithm::ES256) => {
413 use p256::ecdsa::signature::Signer;
414 let sig: p256::ecdsa::Signature = key.sign(message);
415 Ok(sig.to_bytes().to_vec())
416 }
417 _ => Err(JwsError::AlgorithmNotAllowed(format!(
418 "{} incompatible with the provided signing key",
419 alg
420 ))),
421 }
422 }
423}
424
425pub fn encode<T: Serialize>(
427 header: &Header,
428 claims: &T,
429 key: &EncodingKey,
430) -> Result<String, JwsError> {
431 let alg = header.algorithm()?;
432 let header_json =
433 serde_json::to_vec(header).map_err(|e| JwsError::Malformed(e.to_string()))?;
434 let payload_json =
435 serde_json::to_vec(claims).map_err(|e| JwsError::Malformed(e.to_string()))?;
436 let mut token = String::new();
437 token.push_str(&URL_SAFE_NO_PAD.encode(header_json));
438 token.push('.');
439 token.push_str(&URL_SAFE_NO_PAD.encode(payload_json));
440 let signature = key.sign(alg, token.as_bytes())?;
441 token.push('.');
442 token.push_str(&URL_SAFE_NO_PAD.encode(signature));
443 Ok(token)
444}
445
// Round-trip and rejection tests for `encode` / `decode`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh Ed25519 key pair. The decoding half is rebuilt from the base64url public key via
    /// `from_ed_components`, so that path is exercised as well.
    fn ed_pair() -> (EncodingKey, DecodingKey) {
        let signing = ed25519_dalek::SigningKey::generate(&mut rand::rngs::OsRng);
        let x = URL_SAFE_NO_PAD.encode(signing.verifying_key().as_bytes());
        (
            EncodingKey::Ed25519(Box::new(signing)),
            DecodingKey::from_ed_components(&x).unwrap(),
        )
    }

    /// Claims with fixed `iss`/`sub`/`aud` and `exp` set `exp_offset` seconds from now.
    /// A negative offset yields an already-expired token.
    fn claims(exp_offset: i64) -> Value {
        json!({
            "iss": "https://idp.example.com",
            "sub": "alice",
            "aud": "tf://example.com",
            "exp": (now_unix() as i64 + exp_offset) as u64,
        })
    }

    /// EdDSA-only validation accepting exactly the issuer and audience produced by `claims`.
    fn validation() -> Validation {
        let mut v = Validation::new(Algorithm::EdDSA);
        v.set_issuer(&["https://idp.example.com"]);
        v.set_audience(&["tf://example.com"]);
        v
    }

    #[test]
    fn round_trip_eddsa() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        let data: TokenData<Value> = decode(&token, &dec, &validation()).unwrap();
        assert_eq!(data.claims["sub"], "alice");
    }

    #[test]
    fn round_trip_es256() {
        use p256::elliptic_curve::sec1::ToEncodedPoint;
        let secret = p256::SecretKey::random(&mut rand::rngs::OsRng);
        // Uncompressed point, so both affine coordinates are available for the x/y constructor.
        let point = secret.public_key().to_encoded_point(false);
        let dec = DecodingKey::from_ec_components(
            &URL_SAFE_NO_PAD.encode(point.x().unwrap()),
            &URL_SAFE_NO_PAD.encode(point.y().unwrap()),
        )
        .unwrap();
        let enc = EncodingKey::P256(Box::new(p256::ecdsa::SigningKey::from(secret)));
        let mut v = Validation::new(Algorithm::ES256);
        v.set_issuer(&["https://idp.example.com"]);
        v.set_audience(&["tf://example.com"]);
        let token = encode(&Header::new(Algorithm::ES256), &claims(300), &enc).unwrap();
        let data: TokenData<Value> = decode(&token, &dec, &v).unwrap();
        assert_eq!(data.claims["sub"], "alice");
    }

    #[test]
    fn tampered_signature_rejected() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        let mut bad = token.clone();
        // Swap the last base64url character of the signature. Depending on the original
        // character this either changes signature bits (BadSignature) or sets trailing bits
        // that a strict base64url decoder may refuse (Malformed). Both outcomes are rejections.
        bad.pop();
        bad.push(if token.ends_with('A') { 'B' } else { 'A' });
        let err = decode::<Value>(&bad, &dec, &validation()).unwrap_err();
        assert!(matches!(err, JwsError::BadSignature | JwsError::Malformed(_)));
    }

    #[test]
    fn alg_none_unrepresentable_and_rejected() {
        // `none` has no `Algorithm` variant, so this hand-built unsigned token is refused while
        // the header `alg` is parsed, before any signature check.
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none"}"#);
        let payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"alice"}"#);
        let token = format!("{header}.{payload}.");
        let (_, dec) = ed_pair();
        let err = decode::<Value>(&token, &dec, &validation()).unwrap_err();
        assert!(matches!(err, JwsError::UnsupportedAlgorithm(_)));
    }

    #[test]
    fn wrong_alg_for_key_rejected() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        let mut v = validation();
        // Allow only ES256: the EdDSA token must fail the allow-list check.
        v.algorithms = vec![Algorithm::ES256];
        let err = decode::<Value>(&token, &dec, &v).unwrap_err();
        assert!(matches!(err, JwsError::AlgorithmNotAllowed(_)));
    }

    #[test]
    fn expired_token_rejected_with_leeway() {
        let (enc, dec) = ed_pair();
        // `exp` two minutes in the past.
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(-120), &enc).unwrap();
        let err = decode::<Value>(&token, &dec, &validation()).unwrap_err();
        assert!(matches!(err, JwsError::InvalidClaim(_)));
        let mut v = validation();
        // An hour of leeway comfortably covers the two minutes since `exp`.
        v.leeway = 3600;
        assert!(decode::<Value>(&token, &dec, &v).is_ok());
    }

    #[test]
    fn issuer_and_audience_enforced() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        // A non-matching issuer alone must cause rejection...
        let mut v = validation();
        v.set_issuer(&["https://other.example.com"]);
        assert!(decode::<Value>(&token, &dec, &v).is_err());
        // ...and so must a non-matching audience alone.
        let mut v = validation();
        v.set_audience(&["tf://other.example.com"]);
        assert!(decode::<Value>(&token, &dec, &v).is_err());
    }
}