1use std::collections::HashMap;
11
12use crate::jws::{decode, decode_header, Algorithm, DecodingKey, Validation};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15
16use crate::bridges::{Bridge, BridgeError, BridgeKind};
17use crate::generated::{
18 ActorIdentity, ActorIdentity_IdentityVersion, ActorType, AuthorityRoot, AuthorityRoot_Kind,
19 PublicKey, PublicKey_Purpose, TrustLevel,
20};
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
23pub struct OAuthBridgeConfig {
24 pub bridge_id: String,
25 pub trust_domain: String,
26 pub jwks: Jwks,
27 pub allowed_algorithms: Vec<String>,
28 pub issuer: String,
29 pub audience: Vec<String>,
30 #[serde(default = "default_clock_tolerance")]
31 pub clock_tolerance_seconds: u64,
32}
33
/// Serde default for `clock_tolerance_seconds`: one minute.
fn default_clock_tolerance() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}
37
38#[derive(Clone, Debug, Serialize, Deserialize)]
39pub struct Jwks {
40 pub keys: Vec<Jwk>,
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct Jwk {
47 pub kty: String,
48 #[serde(default)]
49 pub alg: Option<String>,
50 #[serde(default)]
51 pub kid: Option<String>,
52 #[serde(default)]
53 pub crv: Option<String>,
54 #[serde(default)]
55 pub x: Option<String>,
56 #[serde(default)]
57 pub y: Option<String>,
58 #[serde(default)]
59 pub n: Option<String>,
60 #[serde(default)]
61 pub e: Option<String>,
62}
63
64#[derive(Clone, Debug, Serialize, Deserialize)]
65pub struct OAuthClaims {
66 pub iss: Option<String>,
67 pub sub: Option<String>,
68 pub aud: Option<Value>,
69 pub exp: Option<u64>,
70 pub iat: Option<u64>,
71 pub scope: Option<Value>,
72 #[serde(rename = "tf_actor_type", default)]
73 pub tf_actor_type: Option<String>,
74 #[serde(flatten)]
75 pub extra: HashMap<String, Value>,
76}
77
78#[derive(Clone, Debug)]
79pub struct OAuthVerificationResult {
80 pub identity: ActorIdentity,
81 pub capabilities: Vec<String>,
82 pub claims: OAuthClaims,
83}
84
85pub struct OAuthBridge {
86 cfg: OAuthBridgeConfig,
87}
88
89impl OAuthBridge {
90 pub fn new(cfg: OAuthBridgeConfig) -> Self {
91 OAuthBridge { cfg }
92 }
93
94 pub fn verify_token(&self, token: &str) -> Result<OAuthVerificationResult, BridgeError> {
95 if token.is_empty() {
96 return Err(BridgeError::InvalidInput("empty token".into()));
97 }
98 let header = decode_header(token)
99 .map_err(|e| BridgeError::Rejected(format!("malformed JWT: {}", e)))?;
100 let alg = header
101 .algorithm()
102 .map_err(|e| BridgeError::Rejected(e.to_string()))?;
103 let alg_name = alg.name().to_string();
104 if !self
105 .cfg
106 .allowed_algorithms
107 .iter()
108 .any(|a| a.eq_ignore_ascii_case(&alg_name))
109 {
110 return Err(BridgeError::Rejected(format!(
111 "algorithm {} not in allow-list",
112 alg_name
113 )));
114 }
115
116 let kid = header
117 .kid
118 .clone()
119 .ok_or_else(|| BridgeError::Rejected("JWT header missing kid".into()))?;
120 let jwk = self
121 .cfg
122 .jwks
123 .keys
124 .iter()
125 .find(|k| k.kid.as_deref() == Some(&kid))
126 .ok_or_else(|| BridgeError::Rejected(format!("no JWK with kid {}", kid)))?;
127 let key = decoding_key_for(jwk)?;
128
129 let mut validation = Validation::new(alg);
130 validation.set_issuer(&[self.cfg.issuer.as_str()]);
131 validation.set_audience(&self.cfg.audience);
132 validation.leeway = self.cfg.clock_tolerance_seconds;
133 validation.algorithms = vec![alg];
134
135 let data = decode::<OAuthClaims>(token, &key, &validation)
136 .map_err(|e| BridgeError::Rejected(format!("JWT verify failed: {}", e)))?;
137 let claims = data.claims;
138 let subject = claims
139 .sub
140 .clone()
141 .ok_or_else(|| BridgeError::Rejected("JWT missing sub claim".into()))?;
142 let actor_type_str = claims.tf_actor_type.as_deref().unwrap_or("human");
143 let actor_type = match actor_type_str {
144 "human" => ActorType::Human,
145 "agent" => ActorType::Agent,
146 "device" => ActorType::Device,
147 "service" => ActorType::Service,
148 "site" => ActorType::Site,
149 "organization" => ActorType::Organization,
150 other => {
151 return Err(BridgeError::Rejected(format!(
152 "unsupported tf_actor_type: {}",
153 other
154 )))
155 }
156 };
157 let encoded_subject = encode_subject(&subject);
158 let actor_id = format!(
159 "tf:actor:{}:{}/{}",
160 actor_type_str, self.cfg.trust_domain, encoded_subject
161 );
162
163 let identity = ActorIdentity {
164 identity_version: ActorIdentity_IdentityVersion::V1,
165 actor_id,
166 actor_type,
167 instance_id: None,
168 public_keys: vec![project_jwk_to_public_key(jwk)?],
169 trust_levels: vec![TrustLevel::T3],
170 authority_roots: vec![AuthorityRoot {
171 kind: AuthorityRoot_Kind::Organization,
172 id: self.cfg.issuer.clone(),
173 }],
174 attestations: None,
175 valid_from: claims
176 .iat
177 .map(timestamp)
178 .unwrap_or_else(|| timestamp(now_unix())),
179 valid_until: claims.exp.map(timestamp),
180 revocation_ref: None,
181 signature: None,
182 };
183
184 let capabilities = scopes_from_claims(&claims);
185
186 Ok(OAuthVerificationResult {
187 identity,
188 capabilities,
189 claims,
190 })
191 }
192}
193
194impl Bridge for OAuthBridge {
195 fn bridge_id(&self) -> &str {
196 &self.cfg.bridge_id
197 }
198 fn kind(&self) -> BridgeKind {
199 BridgeKind::Oauth
200 }
201 fn trust_domain(&self) -> &str {
202 &self.cfg.trust_domain
203 }
204}
205
206fn decoding_key_for(jwk: &Jwk) -> Result<DecodingKey, BridgeError> {
207 match jwk.kty.as_str() {
208 "EC" => {
209 let x = jwk
210 .x
211 .as_ref()
212 .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing x".into()))?;
213 let y = jwk
214 .y
215 .as_ref()
216 .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing y".into()))?;
217 DecodingKey::from_ec_components(x, y)
218 .map_err(|e| BridgeError::InvalidInput(format!("bad EC components: {}", e)))
219 }
220 "RSA" => {
221 let n = jwk
222 .n
223 .as_ref()
224 .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing n".into()))?;
225 let e = jwk
226 .e
227 .as_ref()
228 .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing e".into()))?;
229 DecodingKey::from_rsa_components(n, e)
230 .map_err(|e| BridgeError::InvalidInput(format!("bad RSA components: {}", e)))
231 }
232 "OKP" => {
233 let x = jwk
234 .x
235 .as_ref()
236 .ok_or_else(|| BridgeError::InvalidInput("OKP JWK missing x".into()))?;
237 DecodingKey::from_ed_components(x)
238 .map_err(|e| BridgeError::InvalidInput(format!("bad OKP components: {}", e)))
239 }
240 other => Err(BridgeError::InvalidInput(format!(
241 "unsupported kty {}",
242 other
243 ))),
244 }
245}
246
/// Percent-encodes `s` byte-by-byte so it is safe as an actor-id path segment.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through.
/// Every other byte, including each byte of multi-byte UTF-8, becomes `%XX`
/// with uppercase hex.
fn encode_subject(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }
    }
    encoded
}
261
262fn scopes_from_claims(claims: &OAuthClaims) -> Vec<String> {
263 match &claims.scope {
264 Some(Value::String(s)) => s.split_whitespace().map(str::to_string).collect(),
265 Some(Value::Array(arr)) => arr
266 .iter()
267 .filter_map(|v| v.as_str().map(str::to_string))
268 .collect(),
269 _ => Vec::new(),
270 }
271}
272
273fn timestamp(t: u64) -> String {
274 let datetime = std::time::UNIX_EPOCH + std::time::Duration::from_secs(t);
276 let secs = datetime
277 .duration_since(std::time::UNIX_EPOCH)
278 .expect("post-epoch")
279 .as_secs() as i64;
280 let (year, month, day, hour, minute, second) = secs_to_ymdhms(secs);
282 format!(
283 "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
284 year, month, day, hour, minute, second
285 )
286}
287
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
294
/// Splits signed Unix seconds into proleptic-Gregorian UTC
/// `(year, month, day, hour, minute, second)`.
///
/// Uses Howard Hinnant's `civil_from_days` algorithm. Days are counted in a
/// March-based year so that the leap day falls at the end of the year.
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    const SECS_PER_DAY: i64 = 86_400;
    // Number of days in one 400-year Gregorian cycle.
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT_DAYS: i64 = 719_468;

    let day_number = secs.div_euclid(SECS_PER_DAY);
    let secs_of_day = secs.rem_euclid(SECS_PER_DAY);
    let hour = (secs_of_day / 3600) as u32;
    let minute = (secs_of_day / 60 % 60) as u32;
    let second = (secs_of_day % 60) as u32;

    let shifted = day_number + EPOCH_SHIFT_DAYS;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    // Always in [0, 146096].
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u64;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index 0 is March and 11 is February.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = (day_of_year - (153 * month_index + 2) / 5 + 1) as u32;
    let month = ((month_index + 2) % 12 + 1) as u32;
    let march_based_year = year_of_era as i64 + era * 400;
    // January and February belong to the following civil year.
    let year = if month <= 2 {
        march_based_year + 1
    } else {
        march_based_year
    };
    (year as i32, month, day, hour, minute, second)
}
319
320pub fn parse_algorithm(name: &str) -> Result<Algorithm, BridgeError> {
321 Algorithm::parse(name).map_err(|e| BridgeError::InvalidInput(e.to_string()))
322}
323
324pub fn project_jwk_to_public_key(jwk: &Jwk) -> Result<PublicKey, BridgeError> {
328 use crate::encoding::{STANDARD, URL_SAFE_NO_PAD};
329 let key_id = jwk
330 .kid
331 .clone()
332 .unwrap_or_else(|| "oauth-bridge-bearer".to_string());
333 match jwk.kty.as_str() {
334 "OKP" => {
335 let x = jwk
337 .x
338 .as_ref()
339 .ok_or_else(|| BridgeError::InvalidInput("OKP JWK missing x".into()))?;
340 let bytes = URL_SAFE_NO_PAD
341 .decode(x)
342 .map_err(|e| BridgeError::InvalidInput(format!("base64url x: {}", e)))?;
343 Ok(PublicKey {
344 key_id,
345 algorithm: "ed25519".into(),
346 public_key: STANDARD.encode(bytes),
347 purpose: PublicKey_Purpose::Signing,
348 valid_from: None,
349 valid_until: None,
350 })
351 }
352 "EC" => {
353 let x = jwk
354 .x
355 .as_ref()
356 .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing x".into()))?;
357 let y = jwk
358 .y
359 .as_ref()
360 .ok_or_else(|| BridgeError::InvalidInput("EC JWK missing y".into()))?;
361 let xb = URL_SAFE_NO_PAD
362 .decode(x)
363 .map_err(|e| BridgeError::InvalidInput(format!("base64url x: {}", e)))?;
364 let yb = URL_SAFE_NO_PAD
365 .decode(y)
366 .map_err(|e| BridgeError::InvalidInput(format!("base64url y: {}", e)))?;
367 let mut sec1 = Vec::with_capacity(1 + xb.len() + yb.len());
368 sec1.push(0x04);
369 sec1.extend_from_slice(&xb);
370 sec1.extend_from_slice(&yb);
371 let crv = jwk.crv.as_deref().unwrap_or("");
372 let alg = match crv {
373 "P-256" => "p256",
374 "P-384" => "p384",
375 "P-521" => "p521",
376 _ => "ec",
377 };
378 Ok(PublicKey {
379 key_id,
380 algorithm: alg.into(),
381 public_key: STANDARD.encode(sec1),
382 purpose: PublicKey_Purpose::Signing,
383 valid_from: None,
384 valid_until: None,
385 })
386 }
387 "RSA" => {
388 let n = jwk
389 .n
390 .as_ref()
391 .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing n".into()))?;
392 let e = jwk
393 .e
394 .as_ref()
395 .ok_or_else(|| BridgeError::InvalidInput("RSA JWK missing e".into()))?;
396 let nb = URL_SAFE_NO_PAD
397 .decode(n)
398 .map_err(|err| BridgeError::InvalidInput(format!("base64url n: {}", err)))?;
399 let eb = URL_SAFE_NO_PAD
400 .decode(e)
401 .map_err(|err| BridgeError::InvalidInput(format!("base64url e: {}", err)))?;
402 let der = encode_rsa_spki(&nb, &eb);
403 Ok(PublicKey {
404 key_id,
405 algorithm: "rsa".into(),
406 public_key: STANDARD.encode(der),
407 purpose: PublicKey_Purpose::Signing,
408 valid_from: None,
409 valid_until: None,
410 })
411 }
412 other => Err(BridgeError::Unsupported(format!(
413 "unsupported JWK kty: {}",
414 other
415 ))),
416 }
417}
418
419fn encode_rsa_spki(n: &[u8], e: &[u8]) -> Vec<u8> {
420 let rsa_public_key = der_sequence(&[der_integer(n), der_integer(e)]);
421 let oid_rsa_encryption: [u8; 11] = [
422 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
423 ];
424 let null_params: [u8; 2] = [0x05, 0x00];
425 let alg_id = der_sequence(&[oid_rsa_encryption.to_vec(), null_params.to_vec()]);
426 let mut bit_string_body = Vec::with_capacity(1 + rsa_public_key.len());
427 bit_string_body.push(0x00);
428 bit_string_body.extend_from_slice(&rsa_public_key);
429 let mut bit_string = Vec::with_capacity(2 + bit_string_body.len());
430 bit_string.push(0x03);
431 bit_string.extend_from_slice(&der_len(bit_string_body.len()));
432 bit_string.extend_from_slice(&bit_string_body);
433 der_sequence(&[alg_id, bit_string])
434}
435
436fn der_sequence(parts: &[Vec<u8>]) -> Vec<u8> {
437 let body: Vec<u8> = parts.iter().flat_map(|p| p.clone()).collect();
438 let mut out = Vec::with_capacity(2 + body.len());
439 out.push(0x30);
440 out.extend_from_slice(&der_len(body.len()));
441 out.extend_from_slice(&body);
442 out
443}
444
/// DER-encodes a big-endian unsigned magnitude as an `INTEGER` (tag 0x02).
///
/// Redundant leading zero bytes are stripped, as DER requires minimal
/// encoding. A `0x00` byte is prepended when the top bit is set, so the value
/// is not read as negative. An empty or all-zero input encodes as zero
/// (`02 01 00`); the old version panicked on empty input.
fn der_integer(bytes: &[u8]) -> Vec<u8> {
    let payload: &[u8] = match bytes.iter().position(|&b| b != 0) {
        Some(first_significant) => &bytes[first_significant..],
        // Empty or all zeros: the minimal encoding of 0 is a single 0x00 byte.
        None => &[0x00],
    };
    let needs_pad = payload[0] & 0x80 != 0;
    let len = payload.len() + usize::from(needs_pad);
    let mut out = Vec::with_capacity(2 + len);
    out.push(0x02);
    out.extend_from_slice(&der_len(len));
    if needs_pad {
        out.push(0x00);
    }
    out.extend_from_slice(payload);
    out
}

/// Encodes a DER length field.
///
/// Lengths below 0x80 use the one-byte short form. Larger lengths use the long
/// form: `0x80 | k` followed by the `k` big-endian bytes of the length, with no
/// leading zeros.
fn der_len(n: usize) -> Vec<u8> {
    if n < 0x80 {
        return vec![n as u8];
    }
    let be = n.to_be_bytes();
    // n >= 0x80, so at least one byte is non-zero.
    let first = be.iter().position(|&b| b != 0).unwrap_or(be.len() - 1);
    let significant = &be[first..];
    let mut out = Vec::with_capacity(1 + significant.len());
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
477}