Skip to main content

tf_types/
bridge_oauth.rs

//! OAuth/GNAP bridge — verify a JWT bearer token with the in-house
//! `crate::jws` module against the JWKS held in the bridge configuration,
//! and project the verified claims into a TrustForge actor identity +
//! capabilities.
//!
//! Supports ES256 / ES384 / RS256 / RS384 / RS512 / EdDSA. Algorithm
//! confusion attacks (alg:none, HS256-with-RSA-key) are guarded by the
//! mandatory allow-list passed at bridge construction time.
9
10use std::collections::HashMap;
11
12use crate::jws::{decode, decode_header, Algorithm, DecodingKey, Validation};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15
16use crate::bridges::{Bridge, BridgeError, BridgeKind};
17use crate::generated::{
18    ActorIdentity, ActorIdentity_IdentityVersion, ActorType, AuthorityRoot, AuthorityRoot_Kind,
19    PublicKey, PublicKey_Purpose, TrustLevel,
20};
21
/// Configuration for an [`OAuthBridge`]: which keys, algorithms, issuer and
/// audiences a bearer token must match.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OAuthBridgeConfig {
    /// Identifier returned by `Bridge::bridge_id`.
    pub bridge_id: String,
    /// Trust domain returned by `Bridge::trust_domain`; also embedded in every
    /// minted actor id (`tf:actor:<type>:<trust_domain>/<subject>`).
    pub trust_domain: String,
    /// Verification keys; a token's header `kid` selects one of them.
    pub jwks: Jwks,
    /// JOSE algorithm names a token header may declare, compared
    /// case-insensitively. Tokens declaring any other algorithm are rejected.
    pub allowed_algorithms: Vec<String>,
    /// Expected `iss` value (passed to `Validation::set_issuer`); also recorded
    /// as the identity's organization authority root.
    pub issuer: String,
    /// Accepted `aud` values, passed to `Validation::set_audience`.
    pub audience: Vec<String>,
    /// Clock-skew allowance in seconds, passed as `Validation::leeway`.
    /// Defaults to 60 when absent from the serialized configuration.
    #[serde(default = "default_clock_tolerance")]
    pub clock_tolerance_seconds: u64,
}
33
/// Serde default for `OAuthBridgeConfig::clock_tolerance_seconds`: one
/// minute of leeway.
fn default_clock_tolerance() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}
37
/// A JSON Web Key Set: the candidate verification keys for the bridge.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Jwks {
    /// Keys searched by `kid` when verifying a token.
    pub keys: Vec<Jwk>,
}
42
/// Minimal JWK shape the bridge accepts. ES256/ES384 use x/y; RS* use n/e;
/// EdDSA uses crv=Ed25519 + x.
///
/// Only `kty` is required at deserialization time. Which optional members
/// must be present depends on `kty` ("EC": `x`, `y`; "RSA": `n`, `e`;
/// "OKP": `x`) and is checked when the key is used. Key-material members are
/// base64url strings without padding.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Jwk {
    /// Key type; "EC", "RSA" and "OKP" are handled, anything else is rejected
    /// when the key is used.
    pub kty: String,
    /// RFC 7517 `alg` member, if the key set declares one.
    #[serde(default)]
    pub alg: Option<String>,
    /// Key id, matched against the JWT header `kid`; also used as the
    /// projected `PublicKey::key_id`.
    #[serde(default)]
    pub kid: Option<String>,
    /// Curve name. For EC keys, "P-256" / "P-384" / "P-521" select the
    /// projected algorithm label.
    #[serde(default)]
    pub crv: Option<String>,
    /// EC x-coordinate, or the raw OKP public key.
    #[serde(default)]
    pub x: Option<String>,
    /// EC y-coordinate.
    #[serde(default)]
    pub y: Option<String>,
    /// RSA modulus.
    #[serde(default)]
    pub n: Option<String>,
    /// RSA public exponent.
    #[serde(default)]
    pub e: Option<String>,
}
63
/// Claims read from a verified token payload. Claims without a dedicated
/// field are kept in `extra`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OAuthClaims {
    /// Issuer.
    pub iss: Option<String>,
    /// Subject; `verify_token` requires it and percent-encodes it into the
    /// actor id.
    pub sub: Option<String>,
    /// Audience, kept as raw JSON (RFC 7519 allows a string or an array).
    pub aud: Option<Value>,
    /// Expiry in seconds since the UNIX epoch; becomes the identity's
    /// `valid_until`.
    pub exp: Option<u64>,
    /// Issued-at in seconds since the UNIX epoch; becomes `valid_from`.
    pub iat: Option<u64>,
    /// Granted scopes: a space-separated string or a JSON array of strings.
    pub scope: Option<Value>,
    /// TrustForge actor type ("human", "agent", ...); `verify_token` treats
    /// a missing value as "human".
    // The `rename` equals the field name, so it does not change the wire key.
    #[serde(rename = "tf_actor_type", default)]
    pub tf_actor_type: Option<String>,
    /// Every remaining claim in the payload.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
77
/// Outcome of a successful `OAuthBridge::verify_token` call.
#[derive(Clone, Debug)]
pub struct OAuthVerificationResult {
    /// TrustForge identity projected from the token and its signing key.
    pub identity: ActorIdentity,
    /// Scopes taken from the `scope` claim.
    pub capabilities: Vec<String>,
    /// The verified claims themselves.
    pub claims: OAuthClaims,
}
84
/// Bridge that turns OAuth/GNAP JWT bearer tokens into TrustForge identities,
/// driven entirely by its [`OAuthBridgeConfig`].
pub struct OAuthBridge {
    cfg: OAuthBridgeConfig,
}
88
89impl OAuthBridge {
90    pub fn new(cfg: OAuthBridgeConfig) -> Self {
91        OAuthBridge { cfg }
92    }
93
94    pub fn verify_token(&self, token: &str) -> Result<OAuthVerificationResult, BridgeError> {
95        if token.is_empty() {
96            return Err(BridgeError::InvalidInput("empty token".into()));
97        }
98        let header = decode_header(token)
99            .map_err(|e| BridgeError::Rejected(format!("malformed JWT: {}", e)))?;
100        let alg = header
101            .algorithm()
102            .map_err(|e| BridgeError::Rejected(e.to_string()))?;
103        let alg_name = alg.name().to_string();
104        if !self
105            .cfg
106            .allowed_algorithms
107            .iter()
108            .any(|a| a.eq_ignore_ascii_case(&alg_name))
109        {
110            return Err(BridgeError::Rejected(format!(
111                "algorithm {} not in allow-list",
112                alg_name
113            )));
114        }
115
116        let kid = header
117            .kid
118            .clone()
119            .ok_or_else(|| BridgeError::Rejected("JWT header missing kid".into()))?;
120        let jwk = self
121            .cfg
122            .jwks
123            .keys
124            .iter()
125            .find(|k| k.kid.as_deref() == Some(&kid))
126            .ok_or_else(|| BridgeError::Rejected(format!("no JWK with kid {}", kid)))?;
127        let key = decoding_key_for(jwk)?;
128
129        let mut validation = Validation::new(alg);
130        validation.set_issuer(&[self.cfg.issuer.as_str()]);
131        validation.set_audience(&self.cfg.audience);
132        validation.leeway = self.cfg.clock_tolerance_seconds;
133        validation.algorithms = vec![alg];
134
135        let data = decode::<OAuthClaims>(token, &key, &validation)
136            .map_err(|e| BridgeError::Rejected(format!("JWT verify failed: {}", e)))?;
137        let claims = data.claims;
138        let subject = claims
139            .sub
140            .clone()
141            .ok_or_else(|| BridgeError::Rejected("JWT missing sub claim".into()))?;
142        let actor_type_str = claims.tf_actor_type.as_deref().unwrap_or("human");
143        let actor_type = match actor_type_str {
144            "human" => ActorType::Human,
145            "agent" => ActorType::Agent,
146            "device" => ActorType::Device,
147            "service" => ActorType::Service,
148            "site" => ActorType::Site,
149            "organization" => ActorType::Organization,
150            other => {
151                return Err(BridgeError::Rejected(format!(
152                    "unsupported tf_actor_type: {}",
153                    other
154                )))
155            }
156        };
157        let encoded_subject = encode_subject(&subject);
158        let actor_id = format!(
159            "tf:actor:{}:{}/{}",
160            actor_type_str, self.cfg.trust_domain, encoded_subject
161        );
162
163        let identity = ActorIdentity {
164            identity_version: ActorIdentity_IdentityVersion::V1,
165            actor_id,
166            actor_type,
167            instance_id: None,
168            public_keys: vec![project_jwk_to_public_key(jwk)?],
169            trust_levels: vec![TrustLevel::T3],
170            authority_roots: vec![AuthorityRoot {
171                kind: AuthorityRoot_Kind::Organization,
172                id: self.cfg.issuer.clone(),
173            }],
174            attestations: None,
175            valid_from: claims
176                .iat
177                .map(timestamp)
178                .unwrap_or_else(|| timestamp(now_unix())),
179            valid_until: claims.exp.map(timestamp),
180            revocation_ref: None,
181            signature: None,
182        };
183
184        let capabilities = scopes_from_claims(&claims);
185
186        Ok(OAuthVerificationResult {
187            identity,
188            capabilities,
189            claims,
190        })
191    }
192}
193
194impl Bridge for OAuthBridge {
195    fn bridge_id(&self) -> &str {
196        &self.cfg.bridge_id
197    }
198    fn kind(&self) -> BridgeKind {
199        BridgeKind::Oauth
200    }
201    fn trust_domain(&self) -> &str {
202        &self.cfg.trust_domain
203    }
204}
205
206fn decoding_key_for(jwk: &Jwk) -> Result<DecodingKey, BridgeError> {
207    match jwk.kty.as_str() {
208        "EC" => {
209            let x = jwk
210                .x
211                .as_ref()
212                .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing x".into()))?;
213            let y = jwk
214                .y
215                .as_ref()
216                .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing y".into()))?;
217            DecodingKey::from_ec_components(x, y)
218                .map_err(|e| BridgeError::InvalidInput(format!("bad EC components: {}", e)))
219        }
220        "RSA" => {
221            let n = jwk
222                .n
223                .as_ref()
224                .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing n".into()))?;
225            let e = jwk
226                .e
227                .as_ref()
228                .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing e".into()))?;
229            DecodingKey::from_rsa_components(n, e)
230                .map_err(|e| BridgeError::InvalidInput(format!("bad RSA components: {}", e)))
231        }
232        "OKP" => {
233            let x = jwk
234                .x
235                .as_ref()
236                .ok_or_else(|| BridgeError::InvalidInput("OKP JWK missing x".into()))?;
237            DecodingKey::from_ed_components(x)
238                .map_err(|e| BridgeError::InvalidInput(format!("bad OKP components: {}", e)))
239        }
240        other => Err(BridgeError::InvalidInput(format!(
241            "unsupported kty {}",
242            other
243        ))),
244    }
245}
246
/// Percent-encode `s` so it can sit inside a single actor-URI path segment.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// sequence, becomes `%XX` with upper-case hex digits.
fn encode_subject(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        encoded
    })
}
261
262fn scopes_from_claims(claims: &OAuthClaims) -> Vec<String> {
263    match &claims.scope {
264        Some(Value::String(s)) => s.split_whitespace().map(str::to_string).collect(),
265        Some(Value::Array(arr)) => arr
266            .iter()
267            .filter_map(|v| v.as_str().map(str::to_string))
268            .collect(),
269        _ => Vec::new(),
270    }
271}
272
/// Largest UNIX time that still has a four-digit year in RFC 3339:
/// 9999-12-31T23:59:59Z.
const MAX_RFC3339_UNIX_SECS: u64 = 253_402_300_799;

/// Format a UNIX time (whole seconds) as an RFC 3339 UTC string,
/// `YYYY-MM-DDTHH:MM:SSZ`.
///
/// `t` usually comes from token `iat`/`exp` claims, so it can be arbitrarily
/// large. Values past 9999-12-31T23:59:59Z are clamped to that instant
/// rather than overflowing (`SystemTime + Duration` panics on overflow) or
/// wrapping into a nonsensical or negative year.
fn timestamp(t: u64) -> String {
    // Lossless: the clamp keeps the value far below i64::MAX.
    let secs = t.min(MAX_RFC3339_UNIX_SECS) as i64;
    let (year, month, day, hour, minute, second) = secs_to_ymdhms(secs);
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second
    )
}

/// Current UNIX time in whole seconds. A system clock set before 1970
/// yields 0 instead of an error.
fn now_unix() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}

/// Split signed seconds since the UNIX epoch into proleptic-Gregorian UTC
/// `(year, month, day, hour, minute, second)`.
///
/// Uses Howard Hinnant's `civil_from_days` algorithm; negative inputs are
/// handled through Euclidean division. The year is narrowed to `i32`, so
/// callers must keep `secs` within that range (`timestamp` clamps to 9999).
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    let days = secs.div_euclid(86_400);
    let time = secs.rem_euclid(86_400);
    let hour = (time / 3600) as u32;
    let minute = ((time % 3600) / 60) as u32;
    let second = (time % 60) as u32;

    // Shift the epoch to 0000-03-01 so leap days fall at the end of a year.
    let z = days + 719_468;
    let era = (if z >= 0 { z } else { z - 146_096 }) / 146_097;
    let doe = (z - era * 146_097) as u64; // day of era, [0, 146096]
    let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; // [0, 399]
    let y = yoe as i64 + era * 400;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year
    let mp = (5 * doy + 2) / 153; // month index with March = 0
    let d = (doy - (153 * mp + 2) / 5 + 1) as u32;
    let m = if mp < 10 {
        (mp + 3) as u32
    } else {
        (mp - 9) as u32
    };
    // January and February belong to the following civil year.
    let year = if m <= 2 { y + 1 } else { y };
    (year as i32, m, d, hour, minute, second)
}
319
320pub fn parse_algorithm(name: &str) -> Result<Algorithm, BridgeError> {
321    Algorithm::parse(name).map_err(|e| BridgeError::InvalidInput(e.to_string()))
322}
323
324/// Project a JWK into the TrustForge `PublicKey` shape (raw bytes,
325/// base64-encoded, with the algorithm name normalised to TrustForge's
326/// vocabulary). Mirrors the TS `projectJwkToPublicKey`.
327pub fn project_jwk_to_public_key(jwk: &Jwk) -> Result<PublicKey, BridgeError> {
328    use crate::encoding::{STANDARD, URL_SAFE_NO_PAD};
329    let key_id = jwk
330        .kid
331        .clone()
332        .unwrap_or_else(|| "oauth-bridge-bearer".to_string());
333    match jwk.kty.as_str() {
334        "OKP" => {
335            // Ed25519 — raw 32-byte x.
336            let x = jwk
337                .x
338                .as_ref()
339                .ok_or_else(|| BridgeError::InvalidInput("OKP JWK missing x".into()))?;
340            let bytes = URL_SAFE_NO_PAD
341                .decode(x)
342                .map_err(|e| BridgeError::InvalidInput(format!("base64url x: {}", e)))?;
343            Ok(PublicKey {
344                key_id,
345                algorithm: "ed25519".into(),
346                public_key: STANDARD.encode(bytes),
347                purpose: PublicKey_Purpose::Signing,
348                valid_from: None,
349                valid_until: None,
350            })
351        }
352        "EC" => {
353            let x = jwk
354                .x
355                .as_ref()
356                .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing x".into()))?;
357            let y = jwk
358                .y
359                .as_ref()
360                .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing y".into()))?;
361            let xb = URL_SAFE_NO_PAD
362                .decode(x)
363                .map_err(|e| BridgeError::InvalidInput(format!("base64url x: {}", e)))?;
364            let yb = URL_SAFE_NO_PAD
365                .decode(y)
366                .map_err(|e| BridgeError::InvalidInput(format!("base64url y: {}", e)))?;
367            let mut sec1 = Vec::with_capacity(1 + xb.len() + yb.len());
368            sec1.push(0x04);
369            sec1.extend_from_slice(&xb);
370            sec1.extend_from_slice(&yb);
371            let crv = jwk.crv.as_deref().unwrap_or("");
372            let alg = match crv {
373                "P-256" => "p256",
374                "P-384" => "p384",
375                "P-521" => "p521",
376                _ => "ec",
377            };
378            Ok(PublicKey {
379                key_id,
380                algorithm: alg.into(),
381                public_key: STANDARD.encode(sec1),
382                purpose: PublicKey_Purpose::Signing,
383                valid_from: None,
384                valid_until: None,
385            })
386        }
387        "RSA" => {
388            let n = jwk
389                .n
390                .as_ref()
391                .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing n".into()))?;
392            let e = jwk
393                .e
394                .as_ref()
395                .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing e".into()))?;
396            let nb = URL_SAFE_NO_PAD
397                .decode(n)
398                .map_err(|err| BridgeError::InvalidInput(format!("base64url n: {}", err)))?;
399            let eb = URL_SAFE_NO_PAD
400                .decode(e)
401                .map_err(|err| BridgeError::InvalidInput(format!("base64url e: {}", err)))?;
402            let der = encode_rsa_spki(&nb, &eb);
403            Ok(PublicKey {
404                key_id,
405                algorithm: "rsa".into(),
406                public_key: STANDARD.encode(der),
407                purpose: PublicKey_Purpose::Signing,
408                valid_from: None,
409                valid_until: None,
410            })
411        }
412        other => Err(BridgeError::Unsupported(format!(
413            "unsupported JWK kty: {}",
414            other
415        ))),
416    }
417}
418
/// Wrap an RSA modulus `n` and public exponent `e` (unsigned big-endian
/// bytes, as carried in a JWK) in a DER `SubjectPublicKeyInfo` with the
/// `rsaEncryption` algorithm identifier (RFC 5280 §4.1, RFC 8017 A.1.1).
fn encode_rsa_spki(n: &[u8], e: &[u8]) -> Vec<u8> {
    // OID 1.2.840.113549.1.1.1 (rsaEncryption), DER-encoded with tag + length.
    const OID_RSA_ENCRYPTION: [u8; 11] = [
        0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
    ];
    // rsaEncryption requires explicit NULL parameters.
    const DER_NULL: [u8; 2] = [0x05, 0x00];

    // RSAPublicKey ::= SEQUENCE { modulus INTEGER, publicExponent INTEGER }
    let rsa_public_key = der_sequence(&[der_integer(n), der_integer(e)]);
    let alg_id = der_sequence(&[OID_RSA_ENCRYPTION.to_vec(), DER_NULL.to_vec()]);
    // subjectPublicKey BIT STRING: the leading 0x00 means "no unused bits".
    let mut bit_string_body = Vec::with_capacity(1 + rsa_public_key.len());
    bit_string_body.push(0x00);
    bit_string_body.extend_from_slice(&rsa_public_key);
    let bit_string = der_tlv(0x03, &bit_string_body);
    der_sequence(&[alg_id, bit_string])
}

/// Encode one DER tag-length-value element with a single-byte `tag`.
fn der_tlv(tag: u8, body: &[u8]) -> Vec<u8> {
    let len = der_len(body.len());
    let mut out = Vec::with_capacity(1 + len.len() + body.len());
    out.push(tag);
    out.extend_from_slice(&len);
    out.extend_from_slice(body);
    out
}

/// Encode a DER SEQUENCE whose contents are the already-encoded `parts`,
/// in order.
fn der_sequence(parts: &[Vec<u8>]) -> Vec<u8> {
    der_tlv(0x30, &parts.concat())
}

/// Encode an unsigned big-endian integer as a minimal DER INTEGER.
///
/// Leading zero bytes are stripped (at least one byte is always kept). A
/// 0x00 pad is prepended when the top bit is set, so the value is not read
/// as negative. An empty or all-zero input encodes the value 0 instead of
/// panicking (the length arithmetic used to underflow on an empty slice).
fn der_integer(bytes: &[u8]) -> Vec<u8> {
    let payload: &[u8] = match bytes.iter().position(|&b| b != 0) {
        Some(first_significant) => &bytes[first_significant..],
        None => &[0x00],
    };
    if payload[0] & 0x80 != 0 {
        let mut padded = Vec::with_capacity(payload.len() + 1);
        padded.push(0x00);
        padded.extend_from_slice(payload);
        der_tlv(0x02, &padded)
    } else {
        der_tlv(0x02, payload)
    }
}

/// Encode a DER definite length. Lengths below 0x80 use the short form.
/// Larger lengths use the long form: `0x80 | byte_count` followed by the
/// minimal big-endian bytes of `n`.
fn der_len(n: usize) -> Vec<u8> {
    if n < 0x80 {
        return vec![n as u8];
    }
    let be = n.to_be_bytes();
    let leading_zeros = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[leading_zeros..];
    let mut out = Vec::with_capacity(1 + significant.len());
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}