1use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15use async_trait::async_trait;
16use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
17use serde::Deserialize;
18use tonic::metadata::MetadataMap;
19
20use super::{
21 AuthCtx, AuthError, PrincipalKind, RawToken, ServiceTokenMinter, TokenExtractor, TokenVerifier,
22};
23
/// Stateless [`TokenExtractor`] that reads the token from the `authorization`
/// metadata entry of a request.
///
/// The type has no fields, so it is freely copyable and default-constructible.
#[derive(Clone, Copy, Debug, Default)]
pub struct BearerHeaderExtractor;
29
30impl TokenExtractor for BearerHeaderExtractor {
31 fn extract(&self, metadata: &MetadataMap) -> Result<RawToken, AuthError> {
32 let header = metadata
33 .get("authorization")
34 .ok_or(AuthError::MissingToken)?;
35 let value = header.to_str().map_err(|_| AuthError::MissingToken)?;
36 let token = value
37 .strip_prefix("Bearer ")
38 .or_else(|| value.strip_prefix("bearer "))
39 .ok_or(AuthError::MissingToken)?
40 .trim();
41 if token.is_empty() {
42 return Err(AuthError::MissingToken);
43 }
44 Ok(RawToken {
45 value: token.to_string(),
46 kind: "bearer-jwt",
47 })
48 }
49}
50
51pub struct JwtValidator {
64 config: JwtConfig,
65 keys: Arc<RwLock<JwksCache>>,
66 http: reqwest::Client,
67}
68
69#[derive(Clone)]
70struct JwtConfig {
71 issuer: String,
72 audience: String,
73 jwks_url: Option<String>,
74 jwks_ttl: Duration,
75 insecure_dev: bool,
76 static_key: Option<DecodingKey>,
81 static_alg: Algorithm,
83}
84
85#[derive(Default)]
86struct JwksCache {
87 keys: HashMap<String, DecodingKey>,
89 fetched_at: Option<SystemTime>,
90}
91
92impl JwtValidator {
93 pub fn from_env() -> Result<Self, AuthError> {
95 let insecure_dev = std::env::var("TONIN_AUTH_INSECURE_DEV").ok().as_deref() == Some("1");
96 let issuer = std::env::var("TONIN_AUTH_ISSUER").ok();
97 let audience = std::env::var("TONIN_AUTH_AUDIENCE").ok();
98 let jwks_url = std::env::var("TONIN_AUTH_JWKS_URL").ok();
99 let ttl_secs = std::env::var("TONIN_AUTH_JWKS_TTL_SECS")
100 .ok()
101 .and_then(|s| s.parse::<u64>().ok())
102 .unwrap_or(600);
103
104 if insecure_dev {
105 tracing::warn!(
106 "TONIN_AUTH_INSECURE_DEV=1 — JWT signatures NOT verified. Local dev only."
107 );
108 return Ok(Self::insecure_dev_inner(
109 issuer.unwrap_or_default(),
110 audience.unwrap_or_default(),
111 ));
112 }
113
114 let issuer = issuer.ok_or_else(|| {
115 AuthError::Config(
116 "TONIN_AUTH_ISSUER unset (set TONIN_AUTH_INSECURE_DEV=1 for dev)".into(),
117 )
118 })?;
119 let audience =
120 audience.ok_or_else(|| AuthError::Config("TONIN_AUTH_AUDIENCE unset".into()))?;
121 let jwks_url =
122 jwks_url.ok_or_else(|| AuthError::Config("TONIN_AUTH_JWKS_URL unset".into()))?;
123
124 Ok(Self {
125 config: JwtConfig {
126 issuer,
127 audience,
128 jwks_url: Some(jwks_url),
129 jwks_ttl: Duration::from_secs(ttl_secs),
130 insecure_dev: false,
131 static_key: None,
132 static_alg: Algorithm::RS256,
133 },
134 keys: Arc::new(RwLock::new(JwksCache::default())),
135 http: reqwest::Client::builder()
136 .timeout(Duration::from_secs(5))
137 .build()
138 .map_err(|e| AuthError::Config(format!("http client init: {e}")))?,
139 })
140 }
141
142 pub fn insecure_dev() -> Self {
146 Self::insecure_dev_inner(String::new(), String::new())
147 }
148
149 fn insecure_dev_inner(issuer: String, audience: String) -> Self {
150 Self {
151 config: JwtConfig {
152 issuer,
153 audience,
154 jwks_url: None,
155 jwks_ttl: Duration::from_secs(0),
156 insecure_dev: true,
157 static_key: None,
158 static_alg: Algorithm::RS256,
159 },
160 keys: Arc::new(RwLock::new(JwksCache::default())),
161 http: reqwest::Client::new(),
162 }
163 }
164
165 #[cfg(test)]
168 pub(crate) fn with_static_key(
169 issuer: String,
170 audience: String,
171 key: DecodingKey,
172 alg: Algorithm,
173 ) -> Self {
174 Self {
175 config: JwtConfig {
176 issuer,
177 audience,
178 jwks_url: None,
179 jwks_ttl: Duration::from_secs(0),
180 insecure_dev: false,
181 static_key: Some(key),
182 static_alg: alg,
183 },
184 keys: Arc::new(RwLock::new(JwksCache::default())),
185 http: reqwest::Client::new(),
186 }
187 }
188
189 async fn resolve_key(&self, kid: Option<&str>) -> Result<DecodingKey, AuthError> {
190 if let Some(k) = &self.config.static_key {
192 return Ok(k.clone());
193 }
194 let jwks_url = self
195 .config
196 .jwks_url
197 .as_deref()
198 .ok_or_else(|| AuthError::Config("no JWKS URL configured".into()))?;
199
200 if let Some(kid) = kid {
202 let cache = self.keys.read().expect("jwks cache poisoned");
203 if let Some(k) = cache.keys.get(kid)
204 && let Some(fetched) = cache.fetched_at
205 && SystemTime::now()
206 .duration_since(fetched)
207 .unwrap_or_default()
208 < self.config.jwks_ttl
209 {
210 return Ok(k.clone());
211 }
212 }
213
214 self.refresh_jwks(jwks_url).await?;
216
217 let cache = self.keys.read().expect("jwks cache poisoned");
218 match kid {
219 Some(k) => cache
220 .keys
221 .get(k)
222 .cloned()
223 .ok_or_else(|| AuthError::Verification(format!("no JWKS key for kid={k}"))),
224 None => cache
225 .keys
226 .values()
227 .next()
228 .cloned()
229 .ok_or_else(|| AuthError::Verification("JWKS empty".into())),
230 }
231 }
232
233 async fn refresh_jwks(&self, url: &str) -> Result<(), AuthError> {
234 let resp = self
235 .http
236 .get(url)
237 .send()
238 .await
239 .map_err(|e| AuthError::Transport(e.to_string()))?;
240 if !resp.status().is_success() {
241 return Err(AuthError::Transport(format!(
242 "JWKS fetch failed: HTTP {}",
243 resp.status()
244 )));
245 }
246 let jwks: Jwks = resp
247 .json()
248 .await
249 .map_err(|e| AuthError::Verification(format!("JWKS parse: {e}")))?;
250
251 let mut new_keys = HashMap::new();
252 for k in jwks.keys {
253 if let (Some(kid), Some(n), Some(e)) = (k.kid, k.n, k.e)
254 && let Ok(dk) = DecodingKey::from_rsa_components(&n, &e)
255 {
256 new_keys.insert(kid, dk);
257 }
258 }
259
260 let mut cache = self.keys.write().expect("jwks cache poisoned");
261 cache.keys = new_keys;
262 cache.fetched_at = Some(SystemTime::now());
263 Ok(())
264 }
265}
266
267#[derive(Deserialize)]
268struct Jwks {
269 keys: Vec<Jwk>,
270}
271
272#[derive(Deserialize)]
273struct Jwk {
274 kid: Option<String>,
275 n: Option<String>,
276 e: Option<String>,
277}
278
279#[derive(Deserialize, Debug)]
281struct Claims {
282 sub: String,
283 iss: String,
284 aud: AudClaim,
285 exp: i64,
286 #[serde(default)]
287 scope: Option<String>,
288 #[serde(default)]
289 scopes: Option<Vec<String>>,
290 #[serde(default)]
292 kind: Option<String>,
293 #[serde(flatten)]
294 extra: HashMap<String, serde_json::Value>,
295}
296
297#[derive(Deserialize, Debug)]
299#[serde(untagged)]
300enum AudClaim {
301 Single(String),
302 Multi(Vec<String>),
303}
304
305impl AudClaim {
306 fn first(&self) -> String {
307 match self {
308 AudClaim::Single(s) => s.clone(),
309 AudClaim::Multi(v) => v.first().cloned().unwrap_or_default(),
310 }
311 }
312}
313
314#[async_trait]
315impl TokenVerifier for JwtValidator {
316 async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError> {
317 if self.config.insecure_dev {
318 return verify_insecure(&token.value, &self.config);
319 }
320
321 let header = decode_header(&token.value).map_err(|e| match e.kind() {
322 jsonwebtoken::errors::ErrorKind::InvalidToken => {
323 AuthError::Verification("malformed".into())
324 }
325 _ => AuthError::Verification(e.to_string()),
326 })?;
327
328 let key = self.resolve_key(header.kid.as_deref()).await?;
329 let alg = if self.config.static_key.is_some() {
330 self.config.static_alg
331 } else {
332 header.alg
333 };
334
335 let mut validation = Validation::new(alg);
336 validation.set_audience(&[&self.config.audience]);
337 validation.set_issuer(&[&self.config.issuer]);
338 validation.validate_exp = true;
339
340 let data =
341 decode::<Claims>(&token.value, &key, &validation).map_err(|e| match e.kind() {
342 jsonwebtoken::errors::ErrorKind::InvalidSignature => AuthError::Signature,
343 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::Expired,
344 jsonwebtoken::errors::ErrorKind::InvalidAudience => AuthError::Audience {
345 expected: self.config.audience.clone(),
346 got: "(rejected by validator)".into(),
347 },
348 jsonwebtoken::errors::ErrorKind::InvalidIssuer => AuthError::Issuer {
349 expected: self.config.issuer.clone(),
350 got: "(rejected by validator)".into(),
351 },
352 _ => AuthError::Verification(e.to_string()),
353 })?;
354
355 Ok(claims_to_authctx(data.claims, &token.value))
356 }
357}
358
359fn verify_insecure(jwt: &str, cfg: &JwtConfig) -> Result<AuthCtx, AuthError> {
360 let parts: Vec<&str> = jwt.split('.').collect();
362 if parts.len() != 3 {
363 return Err(AuthError::Verification("not a JWT".into()));
364 }
365 let payload = base64_url_decode(parts[1])
366 .map_err(|e| AuthError::Verification(format!("payload base64: {e}")))?;
367 let claims: Claims = serde_json::from_slice(&payload)
368 .map_err(|e| AuthError::Verification(format!("payload json: {e}")))?;
369 let ctx = claims_to_authctx(claims, jwt);
370 if !cfg.issuer.is_empty() && ctx.issuer != cfg.issuer {
372 return Err(AuthError::Issuer {
373 expected: cfg.issuer.clone(),
374 got: ctx.issuer,
375 });
376 }
377 if !cfg.audience.is_empty() && ctx.audience != cfg.audience {
378 return Err(AuthError::Audience {
379 expected: cfg.audience.clone(),
380 got: ctx.audience,
381 });
382 }
383 tracing::warn!(subject = %ctx.subject, "INSECURE_DEV: accepted unsigned JWT");
384 Ok(ctx)
385}
386
387fn base64_url_decode(s: &str) -> Result<Vec<u8>, String> {
388 use base64::Engine;
389 base64::engine::general_purpose::URL_SAFE_NO_PAD
390 .decode(s)
391 .map_err(|e| e.to_string())
392}
393
394fn claims_to_authctx(c: Claims, raw: &str) -> AuthCtx {
395 let mut scopes = Vec::new();
396 if let Some(s) = c.scope {
397 scopes.extend(s.split_whitespace().map(String::from));
398 }
399 if let Some(v) = c.scopes {
400 scopes.extend(v);
401 }
402 let kind = match c.kind.as_deref() {
403 Some("service") => PrincipalKind::Service,
404 Some("agent") => PrincipalKind::Agent,
405 _ => PrincipalKind::User,
406 };
407 AuthCtx {
408 subject: c.sub,
409 issuer: c.iss,
410 audience: c.aud.first(),
411 scopes,
412 kind,
413 raw_token: raw.to_string(),
414 expires_at: c.exp.max(0) as f64,
415 extra: c.extra,
416 }
417}
418
419pub struct HttpServiceTokenMinter {
430 url: String,
431 audience: String,
432 scopes: Vec<String>,
433 http: reqwest::Client,
434 cached: tokio::sync::RwLock<Option<AuthCtx>>,
435}
436
437impl HttpServiceTokenMinter {
438 pub fn from_env() -> Result<Self, AuthError> {
439 let url = std::env::var("TONIN_AUTH_SERVICE_TOKEN_URL")
440 .map_err(|_| AuthError::Config("TONIN_AUTH_SERVICE_TOKEN_URL unset".into()))?;
441 let audience = std::env::var("TONIN_AUTH_SERVICE_AUDIENCE").unwrap_or_default();
442 let scopes = std::env::var("TONIN_AUTH_SERVICE_TOKEN_SCOPES")
443 .ok()
444 .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
445 .unwrap_or_default();
446 Ok(Self {
447 url,
448 audience,
449 scopes,
450 http: reqwest::Client::builder()
451 .timeout(Duration::from_secs(5))
452 .build()
453 .map_err(|e| AuthError::Config(format!("http client: {e}")))?,
454 cached: tokio::sync::RwLock::new(None),
455 })
456 }
457}
458
459#[derive(serde::Serialize)]
460struct MintRequest<'a> {
461 audience: &'a str,
462 scopes: &'a [String],
463}
464
465#[derive(Deserialize)]
466struct MintResponse {
467 token: String,
468 #[serde(default)]
470 expires_in: Option<u64>,
471}
472
473#[async_trait]
474impl ServiceTokenMinter for HttpServiceTokenMinter {
475 async fn mint(&self) -> Result<AuthCtx, AuthError> {
476 {
478 let cached = self.cached.read().await;
479 if let Some(ctx) = cached.as_ref() {
480 let now = SystemTime::now()
481 .duration_since(UNIX_EPOCH)
482 .map(|d| d.as_secs_f64())
483 .unwrap_or(0.0);
484 if ctx.expires_at - now > 60.0 {
485 return Ok(ctx.clone());
486 }
487 }
488 }
489
490 let body = MintRequest {
491 audience: &self.audience,
492 scopes: &self.scopes,
493 };
494 let resp = self
495 .http
496 .post(&self.url)
497 .json(&body)
498 .send()
499 .await
500 .map_err(|e| AuthError::Transport(e.to_string()))?;
501 if !resp.status().is_success() {
502 return Err(AuthError::Transport(format!(
503 "service-token mint failed: HTTP {}",
504 resp.status()
505 )));
506 }
507 let body: MintResponse = resp
508 .json()
509 .await
510 .map_err(|e| AuthError::Verification(format!("mint response: {e}")))?;
511
512 let now_secs = SystemTime::now()
513 .duration_since(UNIX_EPOCH)
514 .map(|d| d.as_secs_f64())
515 .unwrap_or(0.0);
516 let expires_at = now_secs + body.expires_in.unwrap_or(3600) as f64;
517
518 let ctx = AuthCtx {
519 subject: "service".into(),
520 issuer: "micro-auth-svc".into(),
521 audience: self.audience.clone(),
522 scopes: self.scopes.clone(),
523 kind: PrincipalKind::Service,
524 raw_token: body.token,
525 expires_at,
526 extra: HashMap::new(),
527 };
528 *self.cached.write().await = Some(ctx.clone());
529 Ok(ctx)
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Issuer every test validator expects.
    const ISSUER: &str = "https://auth.example.com";
    /// Audience every test validator expects.
    const AUDIENCE: &str = "billing-service";

    /// Matching HS256 signing / verifying keys derived from one shared secret.
    fn signing_keypair() -> (EncodingKey, DecodingKey) {
        let secret = b"a-test-secret-at-least-32-bytes-long-please";
        (
            EncodingKey::from_secret(secret),
            DecodingKey::from_secret(secret),
        )
    }

    /// Signs an HS256 token with the given claims; `ttl_secs` is relative to
    /// now and may be negative to produce an already-expired token.
    fn build_jwt(
        signing: &EncodingKey,
        sub: &str,
        iss: &str,
        aud: &str,
        scopes: &[&str],
        ttl_secs: i64,
    ) -> String {
        #[derive(serde::Serialize)]
        struct Cl<'a> {
            sub: &'a str,
            iss: &'a str,
            aud: &'a str,
            exp: i64,
            scope: String,
        }
        let exp = chrono_now() + ttl_secs;
        let cl = Cl {
            sub,
            iss,
            aud,
            exp,
            scope: scopes.join(" "),
        };
        encode(&Header::new(Algorithm::HS256), &cl, signing).unwrap()
    }

    /// Current Unix time in whole seconds.
    fn chrono_now() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    /// Validator pinned to HS256 that expects [`ISSUER`] and [`AUDIENCE`].
    fn validator(key: DecodingKey) -> JwtValidator {
        JwtValidator::with_static_key(ISSUER.into(), AUDIENCE.into(), key, Algorithm::HS256)
    }

    /// Wraps a JWT string the way the bearer extractor would.
    fn bearer(jwt: String) -> RawToken {
        RawToken {
            value: jwt,
            kind: "bearer-jwt",
        }
    }

    #[tokio::test]
    async fn jwt_validator_accepts_valid_token() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(
            &signing,
            "alice",
            ISSUER,
            AUDIENCE,
            &["read:billing", "write:billing"],
            300,
        );
        let ctx = v.verify(&bearer(jwt)).await.unwrap();
        assert_eq!(ctx.subject, "alice");
        assert_eq!(ctx.audience, "billing-service");
        assert!(ctx.scopes.contains(&"read:billing".to_string()));
    }

    #[tokio::test]
    async fn jwt_validator_rejects_expired_token() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(&signing, "alice", ISSUER, AUDIENCE, &[], -3600);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Expired), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_wrong_audience() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(&signing, "alice", ISSUER, "WRONG", &[], 300);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Audience { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_wrong_issuer() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(
            &signing,
            "alice",
            "https://evil.example.com",
            AUDIENCE,
            &[],
            300,
        );
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Issuer { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_bad_signature() {
        // The token is signed with the shared secret, but the validator is
        // given a key derived from a different secret.
        let (signing, _) = signing_keypair();
        let v = validator(DecodingKey::from_secret(
            b"different-secret-also-32-bytes-or-more!",
        ));
        let jwt = build_jwt(&signing, "alice", ISSUER, AUDIENCE, &[], 300);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Signature), "got {err:?}");
    }

    #[test]
    fn bearer_extractor_parses_authorization() {
        let mut md = MetadataMap::new();
        md.insert("authorization", "Bearer test-token".parse().unwrap());
        let t = BearerHeaderExtractor.extract(&md).unwrap();
        assert_eq!(t.value, "test-token");
        assert_eq!(t.kind, "bearer-jwt");
    }

    #[test]
    fn bearer_extractor_accepts_lowercase_scheme() {
        let mut md = MetadataMap::new();
        md.insert("authorization", "bearer test-token".parse().unwrap());
        let t = BearerHeaderExtractor.extract(&md).unwrap();
        assert_eq!(t.value, "test-token");
    }

    #[test]
    fn bearer_extractor_rejects_other_scheme() {
        let mut md = MetadataMap::new();
        md.insert("authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
        let err = BearerHeaderExtractor.extract(&md).unwrap_err();
        assert!(matches!(err, AuthError::MissingToken));
    }

    #[test]
    fn bearer_extractor_missing_header() {
        let md = MetadataMap::new();
        let err = BearerHeaderExtractor.extract(&md).unwrap_err();
        assert!(matches!(err, AuthError::MissingToken));
    }
}