Skip to main content

tonin_core/auth/
default.rs

//! Default implementations of the auth traits. Users import only the ones
//! they need; replacing one with a custom implementation means the
//! dependencies it pulls in can be dropped from `Cargo.toml`.
//!
//! - [`BearerHeaderExtractor`] — reads `Authorization: Bearer <token>`
//! - [`JwtValidator`] — full JWT validation (signature + `exp` + `iss` +
//!   `aud`) with JWKS caching
//! - [`HttpServiceTokenMinter`] — POSTs to an auth-service endpoint to
//!   obtain a service-identity token
10
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15use async_trait::async_trait;
16use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
17use serde::Deserialize;
18use tonic::metadata::MetadataMap;
19
20use super::{
21    AuthCtx, AuthError, PrincipalKind, RawToken, ServiceTokenMinter, TokenExtractor, TokenVerifier,
22};
23
24// ---------- BearerHeaderExtractor ----------
25
26/// Reads `Authorization: Bearer <token>` from request metadata.
27#[derive(Clone, Copy, Debug, Default)]
28pub struct BearerHeaderExtractor;
29
30impl TokenExtractor for BearerHeaderExtractor {
31    fn extract(&self, metadata: &MetadataMap) -> Result<RawToken, AuthError> {
32        let header = metadata
33            .get("authorization")
34            .ok_or(AuthError::MissingToken)?;
35        let value = header.to_str().map_err(|_| AuthError::MissingToken)?;
36        let token = value
37            .strip_prefix("Bearer ")
38            .or_else(|| value.strip_prefix("bearer "))
39            .ok_or(AuthError::MissingToken)?
40            .trim();
41        if token.is_empty() {
42            return Err(AuthError::MissingToken);
43        }
44        Ok(RawToken {
45            value: token.to_string(),
46            kind: "bearer-jwt",
47        })
48    }
49}
50
51// ---------- JwtValidator ----------
52
53/// JWT validator with JWKS caching.
54///
55/// Configuration via env (consumed by [`Self::from_env`]):
56/// - `TONIN_AUTH_ISSUER` — required `iss` claim
57/// - `TONIN_AUTH_AUDIENCE` — required `aud` claim
58/// - `TONIN_AUTH_JWKS_URL` — public-key endpoint; refreshed on cache miss
59///   (and at most once per `TONIN_AUTH_JWKS_TTL_SECS`, default 600)
60/// - `TONIN_AUTH_INSECURE_DEV` — when `=1`, skip signature/expiry/aud/iss
61///   and accept any well-formed JWT. **Logs a loud warning at startup
62///   and on every call.** Local dev only.
63pub struct JwtValidator {
64    config: JwtConfig,
65    keys: Arc<RwLock<JwksCache>>,
66    http: reqwest::Client,
67}
68
69#[derive(Clone)]
70struct JwtConfig {
71    issuer: String,
72    audience: String,
73    jwks_url: Option<String>,
74    jwks_ttl: Duration,
75    insecure_dev: bool,
76    /// Static decoding key for unit tests: when set, skip JWKS fetch and
77    /// use this key directly. Never set via env; only by test helpers.
78    /// (`DecodingKey` itself is not `Debug`, hence no `Debug` derive
79    /// on the struct.)
80    static_key: Option<DecodingKey>,
81    /// Algorithm for static-key tests. JWKS-served keys carry their own alg.
82    static_alg: Algorithm,
83}
84
85#[derive(Default)]
86struct JwksCache {
87    /// key-id → DecodingKey
88    keys: HashMap<String, DecodingKey>,
89    fetched_at: Option<SystemTime>,
90}
91
92impl JwtValidator {
93    /// Build from environment variables. See type docs for the recognized vars.
94    pub fn from_env() -> Result<Self, AuthError> {
95        let insecure_dev = std::env::var("TONIN_AUTH_INSECURE_DEV").ok().as_deref() == Some("1");
96        let issuer = std::env::var("TONIN_AUTH_ISSUER").ok();
97        let audience = std::env::var("TONIN_AUTH_AUDIENCE").ok();
98        let jwks_url = std::env::var("TONIN_AUTH_JWKS_URL").ok();
99        let ttl_secs = std::env::var("TONIN_AUTH_JWKS_TTL_SECS")
100            .ok()
101            .and_then(|s| s.parse::<u64>().ok())
102            .unwrap_or(600);
103
104        if insecure_dev {
105            tracing::warn!(
106                "TONIN_AUTH_INSECURE_DEV=1 — JWT signatures NOT verified. Local dev only."
107            );
108            return Ok(Self::insecure_dev_inner(
109                issuer.unwrap_or_default(),
110                audience.unwrap_or_default(),
111            ));
112        }
113
114        let issuer = issuer.ok_or_else(|| {
115            AuthError::Config(
116                "TONIN_AUTH_ISSUER unset (set TONIN_AUTH_INSECURE_DEV=1 for dev)".into(),
117            )
118        })?;
119        let audience =
120            audience.ok_or_else(|| AuthError::Config("TONIN_AUTH_AUDIENCE unset".into()))?;
121        let jwks_url =
122            jwks_url.ok_or_else(|| AuthError::Config("TONIN_AUTH_JWKS_URL unset".into()))?;
123
124        Ok(Self {
125            config: JwtConfig {
126                issuer,
127                audience,
128                jwks_url: Some(jwks_url),
129                jwks_ttl: Duration::from_secs(ttl_secs),
130                insecure_dev: false,
131                static_key: None,
132                static_alg: Algorithm::RS256,
133            },
134            keys: Arc::new(RwLock::new(JwksCache::default())),
135            http: reqwest::Client::builder()
136                .timeout(Duration::from_secs(5))
137                .build()
138                .map_err(|e| AuthError::Config(format!("http client init: {e}")))?,
139        })
140    }
141
142    /// Parse-only fallback: well-formed JWT accepted, no signature check.
143    /// **Local dev only.** Triggered by `TONIN_AUTH_INSECURE_DEV=1` or by
144    /// calling this constructor directly.
145    pub fn insecure_dev() -> Self {
146        Self::insecure_dev_inner(String::new(), String::new())
147    }
148
149    fn insecure_dev_inner(issuer: String, audience: String) -> Self {
150        Self {
151            config: JwtConfig {
152                issuer,
153                audience,
154                jwks_url: None,
155                jwks_ttl: Duration::from_secs(0),
156                insecure_dev: true,
157                static_key: None,
158                static_alg: Algorithm::RS256,
159            },
160            keys: Arc::new(RwLock::new(JwksCache::default())),
161            http: reqwest::Client::new(),
162        }
163    }
164
165    /// Test helper: build a validator that uses a fixed `DecodingKey` and
166    /// skips JWKS fetch. Audience + issuer still enforced.
167    #[cfg(test)]
168    pub(crate) fn with_static_key(
169        issuer: String,
170        audience: String,
171        key: DecodingKey,
172        alg: Algorithm,
173    ) -> Self {
174        Self {
175            config: JwtConfig {
176                issuer,
177                audience,
178                jwks_url: None,
179                jwks_ttl: Duration::from_secs(0),
180                insecure_dev: false,
181                static_key: Some(key),
182                static_alg: alg,
183            },
184            keys: Arc::new(RwLock::new(JwksCache::default())),
185            http: reqwest::Client::new(),
186        }
187    }
188
189    async fn resolve_key(&self, kid: Option<&str>) -> Result<DecodingKey, AuthError> {
190        // Static key takes precedence (tests).
191        if let Some(k) = &self.config.static_key {
192            return Ok(k.clone());
193        }
194        let jwks_url = self
195            .config
196            .jwks_url
197            .as_deref()
198            .ok_or_else(|| AuthError::Config("no JWKS URL configured".into()))?;
199
200        // Cache hit?
201        if let Some(kid) = kid {
202            let cache = self.keys.read().expect("jwks cache poisoned");
203            if let Some(k) = cache.keys.get(kid)
204                && let Some(fetched) = cache.fetched_at
205                && SystemTime::now()
206                    .duration_since(fetched)
207                    .unwrap_or_default()
208                    < self.config.jwks_ttl
209            {
210                return Ok(k.clone());
211            }
212        }
213
214        // Miss → refresh.
215        self.refresh_jwks(jwks_url).await?;
216
217        let cache = self.keys.read().expect("jwks cache poisoned");
218        match kid {
219            Some(k) => cache
220                .keys
221                .get(k)
222                .cloned()
223                .ok_or_else(|| AuthError::Verification(format!("no JWKS key for kid={k}"))),
224            None => cache
225                .keys
226                .values()
227                .next()
228                .cloned()
229                .ok_or_else(|| AuthError::Verification("JWKS empty".into())),
230        }
231    }
232
233    async fn refresh_jwks(&self, url: &str) -> Result<(), AuthError> {
234        let resp = self
235            .http
236            .get(url)
237            .send()
238            .await
239            .map_err(|e| AuthError::Transport(e.to_string()))?;
240        if !resp.status().is_success() {
241            return Err(AuthError::Transport(format!(
242                "JWKS fetch failed: HTTP {}",
243                resp.status()
244            )));
245        }
246        let jwks: Jwks = resp
247            .json()
248            .await
249            .map_err(|e| AuthError::Verification(format!("JWKS parse: {e}")))?;
250
251        let mut new_keys = HashMap::new();
252        for k in jwks.keys {
253            if let (Some(kid), Some(n), Some(e)) = (k.kid, k.n, k.e)
254                && let Ok(dk) = DecodingKey::from_rsa_components(&n, &e)
255            {
256                new_keys.insert(kid, dk);
257            }
258        }
259
260        let mut cache = self.keys.write().expect("jwks cache poisoned");
261        cache.keys = new_keys;
262        cache.fetched_at = Some(SystemTime::now());
263        Ok(())
264    }
265}
266
267#[derive(Deserialize)]
268struct Jwks {
269    keys: Vec<Jwk>,
270}
271
272#[derive(Deserialize)]
273struct Jwk {
274    kid: Option<String>,
275    n: Option<String>,
276    e: Option<String>,
277}
278
279/// Standard JWT claims we extract into `AuthCtx`.
280#[derive(Deserialize, Debug)]
281struct Claims {
282    sub: String,
283    iss: String,
284    aud: AudClaim,
285    exp: i64,
286    #[serde(default)]
287    scope: Option<String>,
288    #[serde(default)]
289    scopes: Option<Vec<String>>,
290    /// `kind` is non-standard; agnitiv-style claim distinguishing user vs service.
291    #[serde(default)]
292    kind: Option<String>,
293    #[serde(flatten)]
294    extra: HashMap<String, serde_json::Value>,
295}
296
297/// `aud` may be a string or array. Handle both.
298#[derive(Deserialize, Debug)]
299#[serde(untagged)]
300enum AudClaim {
301    Single(String),
302    Multi(Vec<String>),
303}
304
305impl AudClaim {
306    fn first(&self) -> String {
307        match self {
308            AudClaim::Single(s) => s.clone(),
309            AudClaim::Multi(v) => v.first().cloned().unwrap_or_default(),
310        }
311    }
312}
313
314#[async_trait]
315impl TokenVerifier for JwtValidator {
316    async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError> {
317        if self.config.insecure_dev {
318            return verify_insecure(&token.value, &self.config);
319        }
320
321        let header = decode_header(&token.value).map_err(|e| match e.kind() {
322            jsonwebtoken::errors::ErrorKind::InvalidToken => {
323                AuthError::Verification("malformed".into())
324            }
325            _ => AuthError::Verification(e.to_string()),
326        })?;
327
328        let key = self.resolve_key(header.kid.as_deref()).await?;
329        let alg = if self.config.static_key.is_some() {
330            self.config.static_alg
331        } else {
332            header.alg
333        };
334
335        let mut validation = Validation::new(alg);
336        validation.set_audience(&[&self.config.audience]);
337        validation.set_issuer(&[&self.config.issuer]);
338        validation.validate_exp = true;
339
340        let data =
341            decode::<Claims>(&token.value, &key, &validation).map_err(|e| match e.kind() {
342                jsonwebtoken::errors::ErrorKind::InvalidSignature => AuthError::Signature,
343                jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::Expired,
344                jsonwebtoken::errors::ErrorKind::InvalidAudience => AuthError::Audience {
345                    expected: self.config.audience.clone(),
346                    got: "(rejected by validator)".into(),
347                },
348                jsonwebtoken::errors::ErrorKind::InvalidIssuer => AuthError::Issuer {
349                    expected: self.config.issuer.clone(),
350                    got: "(rejected by validator)".into(),
351                },
352                _ => AuthError::Verification(e.to_string()),
353            })?;
354
355        Ok(claims_to_authctx(data.claims, &token.value))
356    }
357}
358
359fn verify_insecure(jwt: &str, cfg: &JwtConfig) -> Result<AuthCtx, AuthError> {
360    // Parse-only: split, base64-decode the payload, deserialize.
361    let parts: Vec<&str> = jwt.split('.').collect();
362    if parts.len() != 3 {
363        return Err(AuthError::Verification("not a JWT".into()));
364    }
365    let payload = base64_url_decode(parts[1])
366        .map_err(|e| AuthError::Verification(format!("payload base64: {e}")))?;
367    let claims: Claims = serde_json::from_slice(&payload)
368        .map_err(|e| AuthError::Verification(format!("payload json: {e}")))?;
369    let ctx = claims_to_authctx(claims, jwt);
370    // Best-effort issuer/audience enforcement even in dev, if configured.
371    if !cfg.issuer.is_empty() && ctx.issuer != cfg.issuer {
372        return Err(AuthError::Issuer {
373            expected: cfg.issuer.clone(),
374            got: ctx.issuer,
375        });
376    }
377    if !cfg.audience.is_empty() && ctx.audience != cfg.audience {
378        return Err(AuthError::Audience {
379            expected: cfg.audience.clone(),
380            got: ctx.audience,
381        });
382    }
383    tracing::warn!(subject = %ctx.subject, "INSECURE_DEV: accepted unsigned JWT");
384    Ok(ctx)
385}
386
387fn base64_url_decode(s: &str) -> Result<Vec<u8>, String> {
388    use base64::Engine;
389    base64::engine::general_purpose::URL_SAFE_NO_PAD
390        .decode(s)
391        .map_err(|e| e.to_string())
392}
393
394fn claims_to_authctx(c: Claims, raw: &str) -> AuthCtx {
395    let mut scopes = Vec::new();
396    if let Some(s) = c.scope {
397        scopes.extend(s.split_whitespace().map(String::from));
398    }
399    if let Some(v) = c.scopes {
400        scopes.extend(v);
401    }
402    let kind = match c.kind.as_deref() {
403        Some("service") => PrincipalKind::Service,
404        Some("agent") => PrincipalKind::Agent,
405        _ => PrincipalKind::User,
406    };
407    AuthCtx {
408        subject: c.sub,
409        issuer: c.iss,
410        audience: c.aud.first(),
411        scopes,
412        kind,
413        raw_token: raw.to_string(),
414        expires_at: c.exp.max(0) as f64,
415        extra: c.extra,
416    }
417}
418
419// ---------- HttpServiceTokenMinter ----------
420
421/// Mints service-identity tokens by POSTing to an auth-service endpoint.
422///
423/// Configuration:
424/// - `TONIN_AUTH_SERVICE_TOKEN_URL` — endpoint to POST to (required)
425/// - `TONIN_AUTH_SERVICE_AUDIENCE` — `aud` to request (defaults to service's name)
426/// - `TONIN_AUTH_SERVICE_TOKEN_SCOPES` — comma-separated scopes
427///
428/// The minter caches the token in memory and refreshes 60s before expiry.
429pub struct HttpServiceTokenMinter {
430    url: String,
431    audience: String,
432    scopes: Vec<String>,
433    http: reqwest::Client,
434    cached: tokio::sync::RwLock<Option<AuthCtx>>,
435}
436
437impl HttpServiceTokenMinter {
438    pub fn from_env() -> Result<Self, AuthError> {
439        let url = std::env::var("TONIN_AUTH_SERVICE_TOKEN_URL")
440            .map_err(|_| AuthError::Config("TONIN_AUTH_SERVICE_TOKEN_URL unset".into()))?;
441        let audience = std::env::var("TONIN_AUTH_SERVICE_AUDIENCE").unwrap_or_default();
442        let scopes = std::env::var("TONIN_AUTH_SERVICE_TOKEN_SCOPES")
443            .ok()
444            .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
445            .unwrap_or_default();
446        Ok(Self {
447            url,
448            audience,
449            scopes,
450            http: reqwest::Client::builder()
451                .timeout(Duration::from_secs(5))
452                .build()
453                .map_err(|e| AuthError::Config(format!("http client: {e}")))?,
454            cached: tokio::sync::RwLock::new(None),
455        })
456    }
457}
458
459#[derive(serde::Serialize)]
460struct MintRequest<'a> {
461    audience: &'a str,
462    scopes: &'a [String],
463}
464
465#[derive(Deserialize)]
466struct MintResponse {
467    token: String,
468    /// Optional expiry hint in seconds; not strictly required.
469    #[serde(default)]
470    expires_in: Option<u64>,
471}
472
473#[async_trait]
474impl ServiceTokenMinter for HttpServiceTokenMinter {
475    async fn mint(&self) -> Result<AuthCtx, AuthError> {
476        // Return cached if still valid (>60s remaining).
477        {
478            let cached = self.cached.read().await;
479            if let Some(ctx) = cached.as_ref() {
480                let now = SystemTime::now()
481                    .duration_since(UNIX_EPOCH)
482                    .map(|d| d.as_secs_f64())
483                    .unwrap_or(0.0);
484                if ctx.expires_at - now > 60.0 {
485                    return Ok(ctx.clone());
486                }
487            }
488        }
489
490        let body = MintRequest {
491            audience: &self.audience,
492            scopes: &self.scopes,
493        };
494        let resp = self
495            .http
496            .post(&self.url)
497            .json(&body)
498            .send()
499            .await
500            .map_err(|e| AuthError::Transport(e.to_string()))?;
501        if !resp.status().is_success() {
502            return Err(AuthError::Transport(format!(
503                "service-token mint failed: HTTP {}",
504                resp.status()
505            )));
506        }
507        let body: MintResponse = resp
508            .json()
509            .await
510            .map_err(|e| AuthError::Verification(format!("mint response: {e}")))?;
511
512        let now_secs = SystemTime::now()
513            .duration_since(UNIX_EPOCH)
514            .map(|d| d.as_secs_f64())
515            .unwrap_or(0.0);
516        let expires_at = now_secs + body.expires_in.unwrap_or(3600) as f64;
517
518        let ctx = AuthCtx {
519            subject: "service".into(),
520            issuer: "micro-auth-svc".into(),
521            audience: self.audience.clone(),
522            scopes: self.scopes.clone(),
523            kind: PrincipalKind::Service,
524            raw_token: body.token,
525            expires_at,
526            extra: HashMap::new(),
527        };
528        *self.cached.write().await = Some(ctx.clone());
529        Ok(ctx)
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    const ISSUER: &str = "https://auth.example.com";
    const AUDIENCE: &str = "billing-service";

    /// Symmetric (HS256) key pair for the static-key tests.
    ///
    /// HS256 keeps the tests free of key files. Tokens still go through the
    /// same `verify` code as JWKS-resolved keys, minus key resolution. The
    /// JWKS fetch path itself is not exercised in this module.
    fn signing_keypair() -> (EncodingKey, DecodingKey) {
        let secret = b"a-test-secret-at-least-32-bytes-long-please";
        (
            EncodingKey::from_secret(secret),
            DecodingKey::from_secret(secret),
        )
    }

    /// Validator pinned to [`ISSUER`]/[`AUDIENCE`] that verifies with `key`.
    fn validator(key: DecodingKey) -> JwtValidator {
        JwtValidator::with_static_key(ISSUER.into(), AUDIENCE.into(), key, Algorithm::HS256)
    }

    /// Wraps an encoded JWT as the extractor would.
    fn bearer(jwt: String) -> RawToken {
        RawToken {
            value: jwt,
            kind: "bearer-jwt",
        }
    }

    /// Signs an HS256 JWT that expires `ttl_secs` from now. A negative value
    /// produces an already-expired token.
    fn build_jwt(
        signing: &EncodingKey,
        sub: &str,
        iss: &str,
        aud: &str,
        scopes: &[&str],
        ttl_secs: i64,
    ) -> String {
        #[derive(serde::Serialize)]
        struct Cl<'a> {
            sub: &'a str,
            iss: &'a str,
            aud: &'a str,
            exp: i64,
            scope: String,
        }
        let cl = Cl {
            sub,
            iss,
            aud,
            exp: unix_now() + ttl_secs,
            scope: scopes.join(" "),
        };
        encode(&Header::new(Algorithm::HS256), &cl, signing).unwrap()
    }

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    #[tokio::test]
    async fn jwt_validator_accepts_valid_token() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(
            &signing,
            "alice",
            ISSUER,
            AUDIENCE,
            &["read:billing", "write:billing"],
            300,
        );
        let ctx = v.verify(&bearer(jwt)).await.unwrap();
        assert_eq!(ctx.subject, "alice");
        assert_eq!(ctx.audience, AUDIENCE);
        assert!(ctx.scopes.contains(&"read:billing".to_string()));
    }

    #[tokio::test]
    async fn jwt_validator_rejects_expired_token() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        // jsonwebtoken's Validation has a default leeway of 60s. Use a
        // much-larger negative ttl so we land safely outside it.
        let jwt = build_jwt(&signing, "alice", ISSUER, AUDIENCE, &[], -3600);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Expired), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_wrong_audience() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(&signing, "alice", ISSUER, "WRONG", &[], 300);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Audience { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_wrong_issuer() {
        let (signing, verifying) = signing_keypair();
        let v = validator(verifying);
        let jwt = build_jwt(&signing, "alice", "https://evil.example.com", AUDIENCE, &[], 300);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Issuer { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn jwt_validator_rejects_bad_signature() {
        let (signing, _) = signing_keypair();
        // Verify with a key unrelated to the one that signed the token.
        let v = validator(DecodingKey::from_secret(
            b"different-secret-also-32-bytes-or-more!",
        ));
        let jwt = build_jwt(&signing, "alice", ISSUER, AUDIENCE, &[], 300);
        let err = v.verify(&bearer(jwt)).await.unwrap_err();
        assert!(matches!(err, AuthError::Signature), "got {err:?}");
    }

    #[tokio::test]
    async fn insecure_dev_accepts_token_signed_with_any_key() {
        let signing = EncodingKey::from_secret(b"some-unrelated-secret-32-bytes-long!!");
        let jwt = build_jwt(&signing, "alice", ISSUER, AUDIENCE, &["read:billing"], 300);
        let ctx = JwtValidator::insecure_dev()
            .verify(&bearer(jwt))
            .await
            .unwrap();
        assert_eq!(ctx.subject, "alice");
        assert_eq!(ctx.scopes, vec!["read:billing".to_string()]);
    }

    #[test]
    fn bearer_extractor_parses_authorization() {
        let mut md = MetadataMap::new();
        md.insert("authorization", "Bearer test-token".parse().unwrap());
        let t = BearerHeaderExtractor.extract(&md).unwrap();
        assert_eq!(t.value, "test-token");
        assert_eq!(t.kind, "bearer-jwt");
    }

    #[test]
    fn bearer_extractor_accepts_lowercase_scheme() {
        let mut md = MetadataMap::new();
        md.insert("authorization", "bearer test-token".parse().unwrap());
        let t = BearerHeaderExtractor.extract(&md).unwrap();
        assert_eq!(t.value, "test-token");
    }

    #[test]
    fn bearer_extractor_rejects_other_schemes_and_empty_tokens() {
        for value in ["Basic dXNlcjpwYXNz", "Bearer    ", "Bearer"] {
            let mut md = MetadataMap::new();
            md.insert("authorization", value.parse().unwrap());
            let err = BearerHeaderExtractor.extract(&md).unwrap_err();
            assert!(
                matches!(err, AuthError::MissingToken),
                "{value:?} → {err:?}"
            );
        }
    }

    #[test]
    fn bearer_extractor_missing_header() {
        let md = MetadataMap::new();
        let err = BearerHeaderExtractor.extract(&md).unwrap_err();
        assert!(matches!(err, AuthError::MissingToken));
    }
}