ubl_axum_kit/
lib.rs

1#![forbid(unsafe_code)]
2use axum::{http::{request::Parts, HeaderMap}, response::IntoResponse};
3use base64::{engine::general_purpose::URL_SAFE_NO_PAD as B64URL, Engine as _};
4use ed25519_dalek::{VerifyingKey, Signature, Verifier};
5use once_cell::sync::Lazy;
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as Json;
9use std::{collections::HashMap, time::{SystemTime, UNIX_EPOCH}};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SubjectCtx {
13    pub did: String,
14    pub iss: String,
15    pub jti: Option<String>,
16    pub scope: Vec<String>,
17    pub aud: Option<String>,
18}
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct TenantCtx {
21    pub tenant: String,
22}
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct WalletPop {
25    pub wallet_did: String,
26    pub has_ath: bool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ApiError { pub error: String, pub code: String, pub trace_id: String }
31impl IntoResponse for ApiError {
32    fn into_response(self) -> axum::response::Response {
33        let body = serde_json::to_string(&self).unwrap_or_else(|_| "{\"error\":\"internal\"}".to_string());
34        (axum::http::StatusCode::BAD_REQUEST, [("content-type","application/json")], body).into_response()
35    }
36}
37
38/// Success response wrapper
39#[derive(Debug, Clone)]
40pub struct ApiOk(pub serde_json::Value);
41impl IntoResponse for ApiOk {
42    fn into_response(self) -> axum::response::Response {
43        let body = serde_json::to_string(&self.0).unwrap_or_else(|_| "{}".to_string());
44        (axum::http::StatusCode::OK, [("content-type","application/json")], body).into_response()
45    }
46}
47
48/// Not found response
49#[derive(Debug, Clone)]
50pub struct ApiNotFound { pub error: String, pub trace_id: String }
51impl IntoResponse for ApiNotFound {
52    fn into_response(self) -> axum::response::Response {
53        let body = serde_json::to_string(&serde_json::json!({"error": self.error, "trace_id": self.trace_id})).unwrap_or_else(|_| "{\"error\":\"not found\"}".to_string());
54        (axum::http::StatusCode::NOT_FOUND, [("content-type","application/json")], body).into_response()
55    }
56}
57
58#[derive(Debug, Deserialize, Clone)]
59struct Jwk { kty:String, crv:Option<String>, x:Option<String>, kid:Option<String> }
60#[derive(Debug, Deserialize, Clone)]
61struct Jwks { keys: Vec<Jwk> }
62
63static GLOBAL_JWKS: Lazy<Mutex<HashMap<String, (Jwks, i64)>>> = Lazy::new(|| Mutex::new(HashMap::new()));
64
65fn fetch_jwks(uri: &str) -> Result<Jwks, String> {
66    let g = GLOBAL_JWKS.lock();
67    let now = now_ts();
68    if let Some((jwks, ts)) = g.get(uri).cloned() {
69        if now - ts <= 300 { return Ok(jwks); }
70    }
71    drop(g);
72    let resp = ureq::get(uri).call().map_err(|e| format!("jwks http: {e}"))?;
73    let body = resp.into_string().map_err(|e| format!("jwks body: {e}"))?;
74    let jwks: Jwks = serde_json::from_str(&body).map_err(|e| format!("jwks json: {e}"))?;
75    let mut g2 = GLOBAL_JWKS.lock();
76    g2.insert(uri.to_string(), (jwks.clone(), now));
77    Ok(jwks)
78}
79fn key_by_kid(jwks: &Jwks, kid: &str) -> Option<VerifyingKey> {
80    for k in &jwks.keys {
81        if k.kty != "OKP" { continue; }
82        if k.crv.as_deref() != Some("Ed25519") { continue; }
83        let k_kid = k.kid.as_deref().unwrap_or_default();
84        if k_kid == kid || k_kid.is_empty() {
85            if let Some(x) = &k.x {
86                if let Ok(bytes) = B64URL.decode(x.as_bytes()) {
87                    if let Ok(vk) = VerifyingKey::from_bytes(bytes[..].try_into().ok()?) {
88                        return Some(vk);
89                    }
90                }
91            }
92        }
93    }
94    None
95}
96
/// Current Unix time in whole seconds; `0` if the system clock reads earlier
/// than the Unix epoch.
pub fn now_ts() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
98pub fn new_trace_id() -> String { ulid::Ulid::new().to_string() }
99
100pub async fn actor_from_headers(headers: &HeaderMap, jwks_uri: &str, require_aud: Option<&str>) -> Result<SubjectCtx, String> {
101    let authz = headers.get("authorization").and_then(|v| v.to_str().ok()).ok_or("missing authorization")?;
102    if !authz.starts_with("Bearer ") { return Err("missing bearer".into()); }
103    let token = &authz[7..];
104    let parts: Vec<&str> = token.split('.').collect();
105    if parts.len() != 3 { return Err("bad jwt".into()); }
106
107    let hdr_json = String::from_utf8(B64URL.decode(parts[0].as_bytes()).map_err(|_| "b64")?).map_err(|_| "utf8")?;
108    let pld_json = String::from_utf8(B64URL.decode(parts[1].as_bytes()).map_err(|_| "b64")?).map_err(|_| "utf8")?;
109    let hdr: Json = serde_json::from_str(&hdr_json).map_err(|_| "json")?;
110    let pld: Json = serde_json::from_str(&pld_json).map_err(|_| "json")?;
111
112    if hdr.get("alg").and_then(|v| v.as_str()) != Some("EdDSA") { return Err("alg".into()); }
113    let kid = hdr.get("kid").and_then(|v| v.as_str()).ok_or("kid")?;
114    let jwks = fetch_jwks(jwks_uri)?;
115    let vk = key_by_kid(&jwks, kid).ok_or("no key")?;
116
117    let msg = format!("{}.{}", parts[0], parts[1]);
118    let sig_bytes = B64URL.decode(parts[2].as_bytes()).map_err(|_| "b64")?;
119    let sig = Signature::from_bytes(sig_bytes[..].try_into().map_err(|_| "sig")?);
120    vk.verify_strict(msg.as_bytes(), &sig).map_err(|_| "sig verify")?;
121
122    let did = pld.get("sub").and_then(|v| v.as_str()).ok_or("sub")?.to_string();
123    let iss = pld.get("iss").and_then(|v| v.as_str()).unwrap_or("").to_string();
124    let jti = pld.get("jti").and_then(|v| v.as_str()).map(|s| s.to_string());
125    let scope = pld.get("scope").and_then(|v| v.as_str()).unwrap_or("").split_whitespace().map(|s| s.to_string()).collect::<Vec<_>>();
126    let aud = pld.get("aud").and_then(|v| v.as_str()).map(|s| s.to_string());
127    if let Some(expect) = require_aud {
128        if aud.as_deref() != Some(expect) { return Err("aud mismatch".into()); }
129    }
130    Ok(SubjectCtx{ did, iss, jti, scope, aud })
131}
132
133pub async fn tenant_from_parts(parts: &Parts) -> Result<TenantCtx, String> {
134    // try header first
135    if let Some(t) = parts.headers.get("x-ubl-tenant").and_then(|v| v.to_str().ok()) {
136        return Ok(TenantCtx { tenant: t.to_string() });
137    }
138    // else parse from path: expect /t/{tenant}/...
139    let path = parts.uri.path();
140    let segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
141    if segs.len() >= 2 && segs[0] == "t" {
142        return Ok(TenantCtx { tenant: segs[1].to_string() });
143    }
144    Err("missing tenant".into())
145}
146
147pub async fn verify_pop(headers: &HeaderMap, method: &str, path: &str, bearer_opt: Option<&str>) -> Result<WalletPop, String> {
148    // X-UBL-POW: base64url(JSON{wallet_did, ts, method, path, sig, ath?})
149    let h = headers.get("x-ubl-pow").and_then(|v| v.to_str().ok()).ok_or("missing X-UBL-POW")?;
150    let raw = B64URL.decode(h.as_bytes()).map_err(|_| "b64")?;
151    let obj: Json = serde_json::from_slice(&raw).map_err(|_| "json")?;
152    let wallet_did = obj.get("wallet_did").and_then(|v| v.as_str()).ok_or("wallet_did")?;
153    let ts = obj.get("ts").and_then(|v| v.as_i64()).ok_or("ts")?;
154    let m = obj.get("method").and_then(|v| v.as_str()).ok_or("method")?;
155    let p = obj.get("path").and_then(|v| v.as_str()).ok_or("path")?;
156    if m != method || p != path { return Err("request binding mismatch".into()); }
157    if (now_ts() - ts).abs() > 120 { return Err("ts skew".into()); }
158    // optional token binding
159    let has_ath = if let (Some(bearer), Some(ath)) = (bearer_opt, obj.get("ath").and_then(|v| v.as_str())) {
160        let tok = bearer.strip_prefix("Bearer ").unwrap_or(bearer);
161        let digest = blake3::hash(tok.as_bytes());
162        let ath_calc = B64URL.encode(digest.as_bytes());
163        if ath != ath_calc { return Err("ath mismatch".into()); }
164        true
165    } else { false };
166    Ok(WalletPop{ wallet_did: wallet_did.to_string(), has_ath })
167}
168
169/// dumb rate limiter: per (tenant, id), 60 req/min
170#[derive(Default)]
171pub struct RateLimiter { inner: Mutex<HashMap<(String, String), Vec<i64>>> }
172impl RateLimiter {
173    pub fn allow(&self, tenant: &str, id: &str, limit: usize) -> bool {
174        let now = now_ts();
175        let mut m = self.inner.lock();
176        let v = m.entry((tenant.to_string(), id.to_string())).or_default();
177        v.retain(|&t| now - t <= 60);
178        if v.len() >= limit { return false; }
179        v.push(now);
180        true
181    }
182}