Skip to main content

tf_types/
jws.rs

//! In-house JWS/JWT compact-serialization verify + sign (RFC 7515/7519)
//! — TrustForge owns its envelope layer; see `docs/dependency-audit.md`.
//! Mirror of `tools/tf-types-ts/src/core/jws.ts`.
//!
//! **No custom cryptography**: every signature operation delegates to a
//! reviewed primitive crate — `ed25519-dalek` (EdDSA), `p256`/`p384`
//! (ES256/ES384), `rsa` (RS256/RS384/RS512). This module owns only the
//! *envelope*: compact-form parsing, base64url handling, the algorithm
//! allow-list, and registered-claim validation.
//!
//! Security posture (deliberate, do not relax):
//! - `alg` is never trusted from the token alone — verification requires
//!   the caller's explicit allow-list, and `none` is unrepresentable.
//! - Key type and algorithm must agree (an RSA key never verifies an
//!   ES256 token), which rules out key-confusion downgrades.
//! - `exp` is validated by default; `iss`/`aud` are validated whenever
//!   the caller configures them, and configured-but-missing claims fail.
18
19use serde::de::DeserializeOwned;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22
23use crate::encoding::URL_SAFE_NO_PAD;
24
/// Every failure surfaced by this module's parsing, key-loading, signing,
/// verification and claim-validation paths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwsError {
    /// The token (or a value being encoded) could not be split,
    /// base64url-decoded, or (de)serialized.
    Malformed(String),
    /// The `alg` name is not one this module implements.
    UnsupportedAlgorithm(String),
    /// The algorithm is implemented but was rejected, either by the caller's
    /// allow-list or because it does not match the key type.
    AlgorithmNotAllowed(String),
    /// Key material could not be decoded or was rejected.
    BadKey(String),
    /// Signature bytes were unparsable or did not verify.
    BadSignature,
    /// A registered claim (`exp`, `nbf`, `iss`, `aud`) failed validation.
    InvalidClaim(String),
}

impl std::fmt::Display for JwsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JwsError::Malformed(m) => write!(f, "malformed JWT: {m}"),
            JwsError::UnsupportedAlgorithm(a) => write!(f, "unsupported algorithm: {a}"),
            JwsError::AlgorithmNotAllowed(a) => write!(f, "algorithm {a} not allowed"),
            JwsError::BadKey(m) => write!(f, "bad key: {m}"),
            JwsError::BadSignature => write!(f, "signature verification failed"),
            JwsError::InvalidClaim(m) => write!(f, "invalid claim: {m}"),
        }
    }
}

impl std::error::Error for JwsError {}

/// JWS signature algorithms this module can sign or verify (RFC 7518 names).
///
/// There is deliberately no `none` variant, so an unsigned token can never
/// be represented, let alone accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
    ES256,
    ES384,
    RS256,
    RS384,
    RS512,
    EdDSA,
}

impl Algorithm {
    /// Every supported algorithm, used as the lookup table for [`Algorithm::parse`].
    const ALL: [Algorithm; 6] = [
        Algorithm::ES256,
        Algorithm::ES384,
        Algorithm::RS256,
        Algorithm::RS384,
        Algorithm::RS512,
        Algorithm::EdDSA,
    ];

    /// Parse an `alg` name such as `"ES256"` or `"EdDSA"`.
    ///
    /// Matching is ASCII case-insensitive, so `"eddsa"` is accepted.
    /// RFC 7515 §4.1.1 defines `alg` as case-sensitive. The lenient match is
    /// kept for backward compatibility, presumably matching the TypeScript
    /// mirror. Confirm that before tightening it.
    ///
    /// # Errors
    /// [`JwsError::UnsupportedAlgorithm`] carrying `name` exactly as supplied
    /// (not case-normalized), e.g. for `"none"`.
    pub fn parse(name: &str) -> Result<Self, JwsError> {
        Self::ALL
            .iter()
            .copied()
            .find(|alg| alg.name().eq_ignore_ascii_case(name))
            .ok_or_else(|| JwsError::UnsupportedAlgorithm(name.to_string()))
    }

    /// Canonical RFC 7518 spelling, as written into the `alg` header.
    pub fn name(&self) -> &'static str {
        match self {
            Algorithm::ES256 => "ES256",
            Algorithm::ES384 => "ES384",
            Algorithm::RS256 => "RS256",
            Algorithm::RS384 => "RS384",
            Algorithm::RS512 => "RS512",
            Algorithm::EdDSA => "EdDSA",
        }
    }
}

/// Enables `"ES256".parse::<Algorithm>()`; same rules as [`Algorithm::parse`].
impl std::str::FromStr for Algorithm {
    type Err = JwsError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Algorithm::parse(s)
    }
}

impl std::fmt::Display for Algorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Header {
93    pub alg: String,
94    #[serde(skip_serializing_if = "Option::is_none", default)]
95    pub kid: Option<String>,
96    #[serde(skip_serializing_if = "Option::is_none", default)]
97    pub typ: Option<String>,
98    /// DPoP-style embedded public key.
99    #[serde(skip_serializing_if = "Option::is_none", default)]
100    pub jwk: Option<Value>,
101}
102
103impl Header {
104    pub fn new(alg: Algorithm) -> Self {
105        Header {
106            alg: alg.name().to_string(),
107            kid: None,
108            typ: Some("JWT".to_string()),
109            jwk: None,
110        }
111    }
112
113    pub fn algorithm(&self) -> Result<Algorithm, JwsError> {
114        Algorithm::parse(&self.alg)
115    }
116}
117
118/// Parse the (unverified!) header segment. Never make a trust decision
119/// from this alone.
120pub fn decode_header(token: &str) -> Result<Header, JwsError> {
121    let first = token
122        .split('.')
123        .next()
124        .ok_or_else(|| JwsError::Malformed("empty token".into()))?;
125    let bytes = URL_SAFE_NO_PAD
126        .decode(first)
127        .map_err(|e| JwsError::Malformed(format!("header base64url: {e}")))?;
128    serde_json::from_slice(&bytes).map_err(|e| JwsError::Malformed(format!("header JSON: {e}")))
129}
130
/* ------------------------------------------------------------------ */
/*  Keys                                                               */
/* ------------------------------------------------------------------ */
134
135pub enum DecodingKey {
136    Ed25519(ed25519_dalek::VerifyingKey),
137    P256(p256::ecdsa::VerifyingKey),
138    P384(p384::ecdsa::VerifyingKey),
139    Rsa(rsa::RsaPublicKey),
140}
141
142impl DecodingKey {
143    /// From JWK EC members (base64url x/y). Curve is inferred from the
144    /// coordinate width.
145    pub fn from_ec_components(x: &str, y: &str) -> Result<Self, JwsError> {
146        let xb = b64u(x, "x")?;
147        let yb = b64u(y, "y")?;
148        if xb.len() != yb.len() {
149            return Err(JwsError::BadKey("EC x/y length mismatch".into()));
150        }
151        let mut sec1 = Vec::with_capacity(1 + xb.len() + yb.len());
152        sec1.push(0x04);
153        sec1.extend_from_slice(&xb);
154        sec1.extend_from_slice(&yb);
155        match xb.len() {
156            32 => p256::ecdsa::VerifyingKey::from_sec1_bytes(&sec1)
157                .map(DecodingKey::P256)
158                .map_err(|e| JwsError::BadKey(format!("P-256 point: {e}"))),
159            48 => p384::ecdsa::VerifyingKey::from_sec1_bytes(&sec1)
160                .map(DecodingKey::P384)
161                .map_err(|e| JwsError::BadKey(format!("P-384 point: {e}"))),
162            n => Err(JwsError::BadKey(format!("unsupported EC width {n}"))),
163        }
164    }
165
166    /// From JWK RSA members (base64url n/e).
167    pub fn from_rsa_components(n: &str, e: &str) -> Result<Self, JwsError> {
168        let nb = b64u(n, "n")?;
169        let eb = b64u(e, "e")?;
170        let key = rsa::RsaPublicKey::new(
171            rsa::BigUint::from_bytes_be(&nb),
172            rsa::BigUint::from_bytes_be(&eb),
173        )
174        .map_err(|e| JwsError::BadKey(format!("RSA components: {e}")))?;
175        Ok(DecodingKey::Rsa(key))
176    }
177
178    /// From JWK OKP member (base64url x, Ed25519).
179    pub fn from_ed_components(x: &str) -> Result<Self, JwsError> {
180        let xb = b64u(x, "x")?;
181        let arr: [u8; 32] = xb
182            .as_slice()
183            .try_into()
184            .map_err(|_| JwsError::BadKey("Ed25519 x must be 32 bytes".into()))?;
185        ed25519_dalek::VerifyingKey::from_bytes(&arr)
186            .map(DecodingKey::Ed25519)
187            .map_err(|e| JwsError::BadKey(format!("Ed25519 point: {e}")))
188    }
189
190    fn verify(&self, alg: Algorithm, message: &[u8], signature: &[u8]) -> Result<(), JwsError> {
191        match (self, alg) {
192            (DecodingKey::Ed25519(key), Algorithm::EdDSA) => {
193                use ed25519_dalek::Verifier;
194                let sig = ed25519_dalek::Signature::from_slice(signature)
195                    .map_err(|_| JwsError::BadSignature)?;
196                key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
197            }
198            (DecodingKey::P256(key), Algorithm::ES256) => {
199                use p256::ecdsa::signature::Verifier;
200                let sig = p256::ecdsa::Signature::from_slice(signature)
201                    .map_err(|_| JwsError::BadSignature)?;
202                key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
203            }
204            (DecodingKey::P384(key), Algorithm::ES384) => {
205                use p384::ecdsa::signature::Verifier;
206                let sig = p384::ecdsa::Signature::from_slice(signature)
207                    .map_err(|_| JwsError::BadSignature)?;
208                key.verify(message, &sig).map_err(|_| JwsError::BadSignature)
209            }
210            (DecodingKey::Rsa(key), Algorithm::RS256) => {
211                verify_rsa::<sha2::Sha256>(key, message, signature)
212            }
213            (DecodingKey::Rsa(key), Algorithm::RS384) => {
214                verify_rsa::<sha2::Sha384>(key, message, signature)
215            }
216            (DecodingKey::Rsa(key), Algorithm::RS512) => {
217                verify_rsa::<sha2::Sha512>(key, message, signature)
218            }
219            // Key type and algorithm must agree — no cross-verification.
220            _ => Err(JwsError::AlgorithmNotAllowed(format!(
221                "{} incompatible with the provided key type",
222                alg
223            ))),
224        }
225    }
226}
227
228fn verify_rsa<D>(key: &rsa::RsaPublicKey, message: &[u8], signature: &[u8]) -> Result<(), JwsError>
229where
230    D: rsa::sha2::Digest + rsa::pkcs8::AssociatedOid,
231{
232    use rsa::signature::Verifier;
233    let verifying = rsa::pkcs1v15::VerifyingKey::<D>::new(key.clone());
234    let sig = rsa::pkcs1v15::Signature::try_from(signature).map_err(|_| JwsError::BadSignature)?;
235    verifying
236        .verify(message, &sig)
237        .map_err(|_| JwsError::BadSignature)
238}
239
240fn b64u(s: &str, what: &str) -> Result<Vec<u8>, JwsError> {
241    URL_SAFE_NO_PAD
242        .decode(s)
243        .map_err(|e| JwsError::BadKey(format!("base64url {what}: {e}")))
244}
245
/* ------------------------------------------------------------------ */
/*  Validation                                                         */
/* ------------------------------------------------------------------ */
249
250#[derive(Debug, Clone)]
251pub struct Validation {
252    pub algorithms: Vec<Algorithm>,
253    /// Clock tolerance in seconds, applied to `exp` and `nbf`.
254    pub leeway: u64,
255    pub validate_exp: bool,
256    pub validate_nbf: bool,
257    issuer: Option<Vec<String>>,
258    audience: Option<Vec<String>>,
259}
260
261impl Validation {
262    pub fn new(alg: Algorithm) -> Self {
263        Validation {
264            algorithms: vec![alg],
265            leeway: 0,
266            validate_exp: true,
267            validate_nbf: false,
268            issuer: None,
269            audience: None,
270        }
271    }
272
273    pub fn set_issuer<T: ToString>(&mut self, issuers: &[T]) {
274        self.issuer = Some(issuers.iter().map(|i| i.to_string()).collect());
275    }
276
277    pub fn set_audience<T: ToString>(&mut self, audiences: &[T]) {
278        self.audience = Some(audiences.iter().map(|a| a.to_string()).collect());
279    }
280}
281
282#[derive(Debug)]
283pub struct TokenData<T> {
284    pub header: Header,
285    pub claims: T,
286}
287
288/// Verify a compact JWS and deserialize its payload, enforcing the
289/// registered claims configured on `validation`.
290pub fn decode<T: DeserializeOwned>(
291    token: &str,
292    key: &DecodingKey,
293    validation: &Validation,
294) -> Result<TokenData<T>, JwsError> {
295    let mut parts = token.split('.');
296    let (h, p, s) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
297        (Some(h), Some(p), Some(s), None) => (h, p, s),
298        _ => return Err(JwsError::Malformed("expected three dot-separated segments".into())),
299    };
300    let header: Header = {
301        let bytes = URL_SAFE_NO_PAD
302            .decode(h)
303            .map_err(|e| JwsError::Malformed(format!("header base64url: {e}")))?;
304        serde_json::from_slice(&bytes).map_err(|e| JwsError::Malformed(format!("header JSON: {e}")))?
305    };
306    let alg = header.algorithm()?;
307    if !validation.algorithms.contains(&alg) {
308        return Err(JwsError::AlgorithmNotAllowed(alg.name().to_string()));
309    }
310    let signature = URL_SAFE_NO_PAD
311        .decode(s)
312        .map_err(|e| JwsError::Malformed(format!("signature base64url: {e}")))?;
313    let message_len = h.len() + 1 + p.len();
314    let message = &token.as_bytes()[..message_len];
315    key.verify(alg, message, &signature)?;
316
317    let payload_bytes = URL_SAFE_NO_PAD
318        .decode(p)
319        .map_err(|e| JwsError::Malformed(format!("payload base64url: {e}")))?;
320    let claims_value: Value = serde_json::from_slice(&payload_bytes)
321        .map_err(|e| JwsError::Malformed(format!("payload JSON: {e}")))?;
322    validate_registered_claims(&claims_value, validation)?;
323    let claims = serde_json::from_value(claims_value)
324        .map_err(|e| JwsError::Malformed(format!("claims shape: {e}")))?;
325    Ok(TokenData { header, claims })
326}
327
328fn validate_registered_claims(claims: &Value, v: &Validation) -> Result<(), JwsError> {
329    let now = now_unix();
330    if v.validate_exp {
331        let exp = claims
332            .get("exp")
333            .and_then(Value::as_u64)
334            .ok_or_else(|| JwsError::InvalidClaim("exp missing".into()))?;
335        if exp.saturating_add(v.leeway) < now {
336            return Err(JwsError::InvalidClaim("token expired".into()));
337        }
338    }
339    if v.validate_nbf {
340        if let Some(nbf) = claims.get("nbf").and_then(Value::as_u64) {
341            if nbf.saturating_sub(v.leeway) > now {
342                return Err(JwsError::InvalidClaim("token not yet valid".into()));
343            }
344        }
345    }
346    if let Some(issuers) = &v.issuer {
347        let iss = claims
348            .get("iss")
349            .and_then(Value::as_str)
350            .ok_or_else(|| JwsError::InvalidClaim("iss missing".into()))?;
351        if !issuers.iter().any(|i| i == iss) {
352            return Err(JwsError::InvalidClaim(format!("issuer {iss} not accepted")));
353        }
354    }
355    if let Some(audiences) = &v.audience {
356        let ok = match claims.get("aud") {
357            Some(Value::String(a)) => audiences.iter().any(|x| x == a),
358            Some(Value::Array(arr)) => arr
359                .iter()
360                .filter_map(Value::as_str)
361                .any(|a| audiences.iter().any(|x| x == a)),
362            _ => false,
363        };
364        if !ok {
365            return Err(JwsError::InvalidClaim("audience not accepted".into()));
366        }
367    }
368    Ok(())
369}
370
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than an error.
// NOTE(review): with `now == 0`, `exp + leeway < now` can never hold, so every
// `exp` check would pass. Consider failing closed instead.
fn now_unix() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
377
/* ------------------------------------------------------------------ */
/*  Signing (tests and vector generation)                              */
/* ------------------------------------------------------------------ */
381
382pub enum EncodingKey {
383    Ed25519(Box<ed25519_dalek::SigningKey>),
384    P256(Box<p256::ecdsa::SigningKey>),
385}
386
387impl EncodingKey {
388    pub fn from_ed_pem(pem: &[u8]) -> Result<Self, JwsError> {
389        use ed25519_dalek::pkcs8::DecodePrivateKey;
390        let text =
391            std::str::from_utf8(pem).map_err(|_| JwsError::BadKey("PEM not UTF-8".into()))?;
392        ed25519_dalek::SigningKey::from_pkcs8_pem(text)
393            .map(|k| EncodingKey::Ed25519(Box::new(k)))
394            .map_err(|e| JwsError::BadKey(format!("Ed25519 PKCS#8: {e}")))
395    }
396
397    pub fn from_ec_pem(pem: &[u8]) -> Result<Self, JwsError> {
398        use p256::pkcs8::DecodePrivateKey;
399        let text =
400            std::str::from_utf8(pem).map_err(|_| JwsError::BadKey("PEM not UTF-8".into()))?;
401        p256::SecretKey::from_pkcs8_pem(text)
402            .map(|k| EncodingKey::P256(Box::new(p256::ecdsa::SigningKey::from(k))))
403            .map_err(|e| JwsError::BadKey(format!("EC PKCS#8: {e}")))
404    }
405
406    fn sign(&self, alg: Algorithm, message: &[u8]) -> Result<Vec<u8>, JwsError> {
407        match (self, alg) {
408            (EncodingKey::Ed25519(key), Algorithm::EdDSA) => {
409                use ed25519_dalek::Signer;
410                Ok(key.sign(message).to_bytes().to_vec())
411            }
412            (EncodingKey::P256(key), Algorithm::ES256) => {
413                use p256::ecdsa::signature::Signer;
414                let sig: p256::ecdsa::Signature = key.sign(message);
415                Ok(sig.to_bytes().to_vec())
416            }
417            _ => Err(JwsError::AlgorithmNotAllowed(format!(
418                "{} incompatible with the provided signing key",
419                alg
420            ))),
421        }
422    }
423}
424
425/// Mint a compact JWS.
426pub fn encode<T: Serialize>(
427    header: &Header,
428    claims: &T,
429    key: &EncodingKey,
430) -> Result<String, JwsError> {
431    let alg = header.algorithm()?;
432    let header_json =
433        serde_json::to_vec(header).map_err(|e| JwsError::Malformed(e.to_string()))?;
434    let payload_json =
435        serde_json::to_vec(claims).map_err(|e| JwsError::Malformed(e.to_string()))?;
436    let mut token = String::new();
437    token.push_str(&URL_SAFE_NO_PAD.encode(header_json));
438    token.push('.');
439    token.push_str(&URL_SAFE_NO_PAD.encode(payload_json));
440    let signature = key.sign(alg, token.as_bytes())?;
441    token.push('.');
442    token.push_str(&URL_SAFE_NO_PAD.encode(signature));
443    Ok(token)
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh Ed25519 key pair: a signing key plus the matching decoding key
    /// built through the JWK (`x`) path.
    fn ed_pair() -> (EncodingKey, DecodingKey) {
        let signing = ed25519_dalek::SigningKey::generate(&mut rand::rngs::OsRng);
        let x = URL_SAFE_NO_PAD.encode(signing.verifying_key().as_bytes());
        (
            EncodingKey::Ed25519(Box::new(signing)),
            DecodingKey::from_ed_components(&x).unwrap(),
        )
    }

    /// Standard claim set expiring `exp_offset` seconds from now.
    fn claims(exp_offset: i64) -> Value {
        json!({
            "iss": "https://idp.example.com",
            "sub": "alice",
            "aud": "tf://example.com",
            "exp": (now_unix() as i64 + exp_offset) as u64,
        })
    }

    /// EdDSA-only policy matching the issuer/audience in `claims`.
    fn validation() -> Validation {
        let mut v = Validation::new(Algorithm::EdDSA);
        v.set_issuer(&["https://idp.example.com"]);
        v.set_audience(&["tf://example.com"]);
        v
    }

    #[test]
    fn round_trip_eddsa() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        let data: TokenData<Value> = decode(&token, &dec, &validation()).unwrap();
        assert_eq!(data.claims["sub"], "alice");
    }

    #[test]
    fn round_trip_es256() {
        use p256::elliptic_curve::sec1::ToEncodedPoint;
        let secret = p256::SecretKey::random(&mut rand::rngs::OsRng);
        let point = secret.public_key().to_encoded_point(false);
        let dec = DecodingKey::from_ec_components(
            &URL_SAFE_NO_PAD.encode(point.x().unwrap()),
            &URL_SAFE_NO_PAD.encode(point.y().unwrap()),
        )
        .unwrap();
        let enc = EncodingKey::P256(Box::new(p256::ecdsa::SigningKey::from(secret)));
        let mut v = Validation::new(Algorithm::ES256);
        v.set_issuer(&["https://idp.example.com"]);
        v.set_audience(&["tf://example.com"]);
        let token = encode(&Header::new(Algorithm::ES256), &claims(300), &enc).unwrap();
        let data: TokenData<Value> = decode(&token, &dec, &v).unwrap();
        assert_eq!(data.claims["sub"], "alice");
    }

    #[test]
    fn tampered_signature_rejected() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        // Flip the FIRST signature character: all 6 of its bits land in the
        // first signature byte, so the decoded signature really changes.
        // Flipping the last character can touch only base64url padding bits.
        // Depending on the decoder, that either fails to decode or yields the
        // same bytes, and the signature check is never exercised.
        let sig_start = token.rfind('.').unwrap() + 1;
        let mut bytes = token.into_bytes();
        bytes[sig_start] = if bytes[sig_start] == b'A' { b'B' } else { b'A' };
        let bad = String::from_utf8(bytes).unwrap();
        let err = decode::<Value>(&bad, &dec, &validation()).unwrap_err();
        assert_eq!(err, JwsError::BadSignature);
    }

    #[test]
    fn alg_none_unrepresentable_and_rejected() {
        // A hand-built alg:none token must fail: parse (unknown alg) and
        // allow-list both reject it.
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none"}"#);
        let payload = URL_SAFE_NO_PAD.encode(br#"{"sub":"alice"}"#);
        let token = format!("{header}.{payload}.");
        let (_, dec) = ed_pair();
        let err = decode::<Value>(&token, &dec, &validation()).unwrap_err();
        assert!(matches!(err, JwsError::UnsupportedAlgorithm(_)));
    }

    #[test]
    fn wrong_alg_for_key_rejected() {
        // Exercises the allow-list: the token's EdDSA is not in the policy.
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        // Validation allows ES256 only.
        let mut v = validation();
        v.algorithms = vec![Algorithm::ES256];
        let err = decode::<Value>(&token, &dec, &v).unwrap_err();
        assert!(matches!(err, JwsError::AlgorithmNotAllowed(_)));
    }

    #[test]
    fn key_type_must_match_algorithm() {
        // Exercises the key/alg pairing: ES256 IS allowed, but the caller's
        // key is Ed25519, so verification must refuse rather than try it.
        let secret = p256::SecretKey::random(&mut rand::rngs::OsRng);
        let enc = EncodingKey::P256(Box::new(p256::ecdsa::SigningKey::from(secret)));
        let token = encode(&Header::new(Algorithm::ES256), &claims(300), &enc).unwrap();
        let (_, ed_dec) = ed_pair();
        let mut v = validation();
        v.algorithms = vec![Algorithm::ES256];
        let err = decode::<Value>(&token, &ed_dec, &v).unwrap_err();
        assert!(matches!(err, JwsError::AlgorithmNotAllowed(_)));
    }

    #[test]
    fn expired_token_rejected_with_leeway() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(-120), &enc).unwrap();
        let err = decode::<Value>(&token, &dec, &validation()).unwrap_err();
        assert!(matches!(err, JwsError::InvalidClaim(_)));
        // Generous leeway lets it pass.
        let mut v = validation();
        v.leeway = 3600;
        assert!(decode::<Value>(&token, &dec, &v).is_ok());
    }

    #[test]
    fn issuer_and_audience_enforced() {
        let (enc, dec) = ed_pair();
        let token = encode(&Header::new(Algorithm::EdDSA), &claims(300), &enc).unwrap();
        let mut v = validation();
        v.set_issuer(&["https://other.example.com"]);
        assert!(decode::<Value>(&token, &dec, &v).is_err());
        let mut v = validation();
        v.set_audience(&["tf://other.example.com"]);
        assert!(decode::<Value>(&token, &dec, &v).is_err());
    }
}