Skip to main content

ubl_auth/
lib.rs

1
2#![forbid(unsafe_code)]
3
4use base64::{engine::general_purpose::URL_SAFE_NO_PAD as B64URL, Engine as _};
5use ed25519_dalek::{VerifyingKey, Signature, Verifier};
6use once_cell::sync::Lazy;
7use parking_lot::Mutex;
8use serde::{Deserialize, Serialize};
9use serde_json::Value as Json;
10use std::{collections::HashMap, time::{SystemTime, UNIX_EPOCH}};
11
/// JWT payload claims, deserialized from the token payload once the
/// signature has been verified.
///
/// Every field except `sub` is optional. Payload keys that do not map to a
/// named field are collected into `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject. Required by deserialization; claim checking also rejects an
    /// empty string.
    pub sub: String,
    /// Issuer. Compared against `VerifyOptions::issuer` when that is set.
    #[serde(default)]
    pub iss: Option<String>,
    /// Audience: a single string or a list of strings.
    #[serde(default)]
    pub aud: Option<Aud>,
    /// Expiry, compared (with leeway) against the current Unix time in seconds.
    #[serde(default)]
    pub exp: Option<i64>,
    /// Not-before, compared (with leeway) against the current Unix time in seconds.
    #[serde(default)]
    pub nbf: Option<i64>,
    /// Issued-at. A value further in the future than the leeway is rejected.
    #[serde(default)]
    pub iat: Option<i64>,
    /// Token identifier. Not inspected by the verification code in this file.
    #[serde(default)]
    pub jti: Option<String>,
    /// Scope string. Not inspected by the verification code in this file.
    #[serde(default)]
    pub scope: Option<String>,
    /// All remaining payload members, keyed by claim name.
    #[serde(flatten)]
    pub extra: HashMap<String, Json>,
}
32
/// The `aud` claim, which may be encoded either as one string or as an array
/// of strings. `#[serde(untagged)]` lets either JSON shape deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Aud {
    /// A single audience value.
    One(String),
    /// Several audience values; a match on any one of them is accepted.
    Many(Vec<String>),
}
39
/// Knobs controlling how claims are validated after the signature check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifyOptions {
    /// Clock-skew tolerance in seconds, applied to the `exp`, `nbf` and `iat` checks.
    pub leeway_secs: i64,
    /// When `Some`, the token's `iss` must be present and equal to this value.
    pub issuer: Option<String>,
    /// When `Some`, the token's `aud` must be present and equal to (or contain) this value.
    pub audience: Option<String>,
    /// When `Some`, used as the current Unix time in seconds instead of the system clock.
    pub now: Option<i64>,
}
47impl Default for VerifyOptions {
48    fn default() -> Self {
49        Self { leeway_secs: 300, issuer: None, audience: None, now: None }
50    }
51}
52impl VerifyOptions {
53    pub fn with_issuer(mut self, iss: &str) -> Self { self.issuer = Some(iss.to_string()); self }
54    pub fn with_audience(mut self, aud: &str) -> Self { self.audience = Some(aud.to_string()); self }
55    pub fn with_leeway(mut self, secs: i64) -> Self { self.leeway_secs = secs; self }
56    pub fn with_now(mut self, now: i64) -> Self { self.now = Some(now); self }
57}
58
/// Reasons a token can fail verification.
#[derive(Debug, thiserror::Error)]
pub enum VerifyError {
    /// The token does not consist of exactly three dot-separated segments.
    #[error("bad token format")]
    BadFormat,
    /// A segment is not valid unpadded base64url, or the decoded header or
    /// payload is not valid UTF-8.
    #[error("base64 decode failed")]
    Base64,
    /// The header or payload is not valid JSON, or the payload could not be
    /// deserialized into `Claims`.
    #[error("json parse failed")]
    Json,
    /// The header `alg` is missing, is not a string, or is not `EdDSA`.
    #[error("alg not allowed (expected EdDSA)")]
    Alg,
    /// The header `kid` is missing or is not a string.
    #[error("missing kid in JWT header")]
    Kid,
    /// The JWKS request or reading its body failed; carries the error text.
    #[error("jwks http error: {0}")]
    JwksHttp(String),
    /// The JWKS response body could not be parsed.
    #[error("jwks parse error")]
    JwksJson,
    /// No usable key in the JWKS matched the token's `kid`.
    #[error("no matching key for kid")]
    NoKey,
    /// The decoded signature could not be converted to the fixed-size
    /// signature bytes, or signature verification failed.
    #[error("invalid signature")]
    Signature,
    /// `exp` plus leeway is earlier than the current time.
    #[error("claim 'exp' expired")]
    Expired,
    /// `nbf` (or `iat`) is later than the current time plus leeway.
    #[error("claim 'nbf' in future")]
    NotYetValid,
    /// An issuer was required and `iss` is absent or different.
    #[error("issuer mismatch")]
    Issuer,
    /// An audience was required and `aud` is absent or does not include it.
    #[error("audience mismatch")]
    Audience,
    /// The `sub` claim is an empty string.
    #[error("missing sub")]
    MissingSub,
}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Jwk { pub kty:String, #[serde(default)] pub crv:Option<String>, #[serde(default)] pub x:Option<String>, #[serde(default)] pub kid:Option<String> }
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct Jwks { pub keys: Vec<Jwk> }
95
96#[derive(Debug, Clone)]
97pub struct JwksCacheEntry { pub jwks: Jwks, pub fetched_at: i64 }
98#[derive(Debug)]
99pub struct JwksCache { ttl_secs: i64, inner: Mutex<HashMap<String, JwksCacheEntry>> }
100
/// Lazily created cache shared by every call to `verify_ed25519_jwt_with_jwks`;
/// entries are considered fresh for 300 seconds.
static GLOBAL_JWKS: Lazy<JwksCache> = Lazy::new(|| JwksCache::new(300));
102
103impl JwksCache {
104    pub fn new(ttl_secs: i64) -> Self { Self { ttl_secs, inner: Mutex::new(HashMap::new()) } }
105    pub fn put(&self, uri: &str, jwks: Jwks) {
106        let mut m = self.inner.lock();
107        m.insert(uri.to_string(), JwksCacheEntry{ jwks, fetched_at: now_ts() });
108    }
109    pub fn get_fresh(&self, uri: &str) -> Option<Jwks> {
110        let m = self.inner.lock();
111        if let Some(entry) = m.get(uri) {
112            if now_ts() - entry.fetched_at <= self.ttl_secs {
113                return Some(entry.jwks.clone());
114            }
115        }
116        None
117    }
118}
119
120pub fn verify_ed25519_jwt_with_jwks(token: &str, jwks_uri: &str, opts: &VerifyOptions) -> Result<Claims, VerifyError> {
121    verify_ed25519_jwt_with_cache(token, jwks_uri, &GLOBAL_JWKS, opts)
122}
123
124pub fn verify_ed25519_jwt_with_cache(token: &str, jwks_uri: &str, cache: &JwksCache, opts: &VerifyOptions) -> Result<Claims, VerifyError> {
125    let (header, payload, sig, signing_input) = split_and_decode(token)?;
126
127    let alg = header.get("alg").and_then(|v| v.as_str()).ok_or(VerifyError::Alg)?;
128    if alg != "EdDSA" { return Err(VerifyError::Alg); }
129    let kid = header.get("kid").and_then(|v| v.as_str()).ok_or(VerifyError::Kid)?;
130
131    let jwks = if let Some(j) = cache.get_fresh(jwks_uri) { j } else {
132        let fetched = fetch_jwks(jwks_uri)?;
133        cache.put(jwks_uri, fetched.clone());
134        fetched
135    };
136    let vk = key_by_kid(&jwks, kid).ok_or(VerifyError::NoKey)?;
137
138    vk.verify_strict(signing_input.as_bytes(), &sig).map_err(|_| VerifyError::Signature)?;
139
140    let claims: Claims = serde_json::from_value(payload).map_err(|_| VerifyError::Json)?;
141    check_claims(&claims, opts)?;
142    Ok(claims)
143}
144
145fn split_and_decode(token: &str) -> Result<(Json, Json, Signature, String), VerifyError> {
146    let parts: Vec<&str> = token.split('.').collect();
147    if parts.len() != 3 { return Err(VerifyError::BadFormat); }
148    let header_json = String::from_utf8(B64URL.decode(parts[0].as_bytes()).map_err(|_| VerifyError::Base64)?).map_err(|_| VerifyError::Base64)?;
149    let payload_json = String::from_utf8(B64URL.decode(parts[1].as_bytes()).map_err(|_| VerifyError::Base64)?).map_err(|_| VerifyError::Base64)?;
150    let sig_bytes = B64URL.decode(parts[2].as_bytes()).map_err(|_| VerifyError::Base64)?;
151    let sig = Signature::from_bytes(sig_bytes[..].try_into().map_err(|_| VerifyError::Signature)?);
152    let header: Json = serde_json::from_str(&header_json).map_err(|_| VerifyError::Json)?;
153    let payload: Json = serde_json::from_str(&payload_json).map_err(|_| VerifyError::Json)?;
154    Ok((header, payload, sig, format!("{}.{}", parts[0], parts[1])))
155}
156
157fn fetch_jwks(uri: &str) -> Result<Jwks, VerifyError> {
158    let resp = ureq::get(uri).call().map_err(|e| VerifyError::JwksHttp(e.to_string()))?;
159    let body = resp.into_string().map_err(|e| VerifyError::JwksHttp(e.to_string()))?;
160    serde_json::from_str(&body).map_err(|_| VerifyError::JwksJson)
161}
162
163fn key_by_kid(jwks: &Jwks, kid: &str) -> Option<VerifyingKey> {
164    for k in &jwks.keys {
165        if k.kty != "OKP" { continue; }
166        if k.crv.as_deref() != Some("Ed25519") { continue; }
167        let k_kid = k.kid.as_deref().unwrap_or_default();
168        if k_kid == kid || k_kid.is_empty() {
169            if let Some(x) = &k.x {
170                if let Ok(bytes) = B64URL.decode(x.as_bytes()) {
171                    if let Ok(vk) = VerifyingKey::from_bytes(bytes[..].try_into().ok()?) {
172                        return Some(vk);
173                    }
174                }
175            }
176        }
177    }
178    None
179}
180
/// Returns the current Unix time in whole seconds.
///
/// Yields `0` if the system clock reports a time before the Unix epoch, and
/// saturates at `i64::MAX` rather than wrapping if the second count does not
/// fit in an `i64`.
pub fn now_ts() -> i64 {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    i64::try_from(secs).unwrap_or(i64::MAX)
}
185
186fn check_claims(c: &Claims, opts: &VerifyOptions) -> Result<(), VerifyError> {
187    let now = opts.now.unwrap_or_else(now_ts);
188    if c.sub.is_empty() { return Err(VerifyError::MissingSub); }
189    if let Some(exp) = c.exp {
190        if now > exp + opts.leeway_secs { return Err(VerifyError::Expired); }
191    }
192    if let Some(nbf) = c.nbf {
193        if now + opts.leeway_secs < nbf { return Err(VerifyError::NotYetValid); }
194    }
195    if let Some(iat) = c.iat {
196        if iat > now + opts.leeway_secs { return Err(VerifyError::NotYetValid); }
197    }
198    if let Some(ref iss) = opts.issuer {
199        if c.iss.as_deref() != Some(iss) { return Err(VerifyError::Issuer); }
200    }
201    if let Some(ref aud) = opts.audience {
202        match &c.aud {
203            None => return Err(VerifyError::Audience),
204            Some(Aud::One(s)) if s != aud => return Err(VerifyError::Audience),
205            Some(Aud::Many(v)) if !v.iter().any(|x| x == aud) => return Err(VerifyError::Audience),
206            _ => {}
207        }
208    }
209    Ok(())
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::StdRng};
    use ed25519_dalek::{SigningKey, Signer};
    use serde_json::json;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD as B64URL, Engine as _};

    /// Signs a token with a deterministic key, seeds a cache with the matching
    /// public key, and checks that verification succeeds without any fetch.
    #[test]
    fn roundtrip_sign_and_verify_with_cache() {
        // Fixed seed keeps the key pair reproducible across runs.
        let mut seeded = StdRng::seed_from_u64(42);
        let signing_key = SigningKey::generate(&mut seeded);
        let public_x = B64URL.encode(signing_key.verifying_key().to_bytes());

        // Pre-populating the cache means the URI is never dereferenced.
        let cache = JwksCache::new(3600);
        let published = Jwk {
            kty: "OKP".into(),
            crv: Some("Ed25519".into()),
            x: Some(public_x),
            kid: Some("test".into()),
        };
        cache.put("mem://jwks", Jwks { keys: vec![published] });

        let issued = now_ts();
        let header = json!({"alg":"EdDSA","kid":"test","typ":"JWT"});
        let payload = json!({
            "sub":"did:key:zTest",
            "iss":"https://id.ubl.agency",
            "aud":"demo",
            "iat": issued,
            "nbf": issued - 5,
            "exp": issued + 3600
        });

        let signing_input = format!(
            "{}.{}",
            B64URL.encode(serde_json::to_string(&header).unwrap()),
            B64URL.encode(serde_json::to_string(&payload).unwrap()),
        );
        let signature = signing_key.sign(signing_input.as_bytes());
        let jwt = format!("{}.{}", signing_input, B64URL.encode(signature.to_bytes()));

        let opts = VerifyOptions::default()
            .with_issuer("https://id.ubl.agency")
            .with_audience("demo");
        let claims = verify_ed25519_jwt_with_cache(&jwt, "mem://jwks", &cache, &opts).expect("verify");
        assert_eq!(claims.sub, "did:key:zTest");
    }
}
250}