Skip to main content

tonin_core/auth/
mod.rs

1//! Authentication and authorization layer.
2//!
3//! Three traits, one concrete type:
4//!
5//! - [`TokenExtractor`] — where the token comes from on an incoming request
6//! - [`TokenVerifier`] — how the token gets validated → [`AuthCtx`]
7//! - [`ServiceTokenMinter`] — how a service identity token gets minted
8//!   (for background jobs / queue consumers with no user request to
9//!   propagate from)
10//!
11//! [`AuthCtx`] is a stable struct — not generic — so the propagation
12//! layer, generated client crates, and the [`CURRENT_AUTH`] task-local
13//! all refer to it without generic-bounds threading. **The type itself
14//! lives in `tonin-client::auth`** so client-only consumers don't
15//! have to depend on the server framework; this module re-exports it.
16//!
17//! ## Common case (one-liner)
18//!
19//! ```no_run
20//! # use tonin_core::Service;
21//! # use tonin_core::auth::default::JwtValidator;
22//! # async fn run() -> tonin_core::Result<()> {
23//! let svc = Service::new("my-svc")
24//!     .with_auth(JwtValidator::from_env()?);
25//! # Ok(()) }
26//! ```
27//!
28//! ## Custom verifier
29//!
30//! ```ignore
31//! use tonin_core::auth::{TokenVerifier, AuthCtx, AuthError, RawToken};
32//! use async_trait::async_trait;
33//!
34//! struct OktaVerifier { /* ... */ }
35//!
36//! #[async_trait]
37//! impl TokenVerifier for OktaVerifier {
38//!     async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError> {
39//!         // ...
40//!     }
41//! }
42//!
43//! let svc = Service::new("my-svc").with_auth(OktaVerifier::from_env());
44//! ```
45
46pub mod default;
47mod layer;
48
49use std::sync::Arc;
50
51use async_trait::async_trait;
52
53pub use layer::AuthLayer;
54
55// Re-export the shared client/server types from tonin-client. These
56// are the contract between the inbound auth layer (here) and the
57// outbound propagation in generated client SDKs (which depend on
58// tonin-client but NOT on this crate).
59pub use tonin_client::auth::{AuthCtx, AuthError, PrincipalKind, RawToken};
60
61// ---------- traits ----------
62
63/// Pulls a `RawToken` out of an incoming request's metadata.
64///
65/// Default: [`default::BearerHeaderExtractor`] reads `Authorization: Bearer <token>`.
66///
67/// The trait takes `&MetadataMap` (not `&Request<T>`) so it stays
68/// dyn-safe; the [`AuthLayer`] builds a metadata view from the raw
69/// `http::Request` before calling.
70pub trait TokenExtractor: Send + Sync + 'static {
71    fn extract(&self, metadata: &tonic::metadata::MetadataMap) -> Result<RawToken, AuthError>;
72}
73
74/// Verifies a [`RawToken`] and returns the resulting [`AuthCtx`].
75///
76/// Default: [`default::JwtValidator`] (signature + exp + iss + aud via JWKS).
77#[async_trait]
78pub trait TokenVerifier: Send + Sync + 'static {
79    async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError>;
80}
81
82/// Mints an [`AuthCtx`] representing this service (no user). Used by
83/// background jobs and queue consumers.
84///
85/// Default: [`default::HttpServiceTokenMinter`] POSTs to a configured
86/// auth-service endpoint.
87#[async_trait]
88pub trait ServiceTokenMinter: Send + Sync + 'static {
89    async fn mint(&self) -> Result<AuthCtx, AuthError>;
90}
91
92// ---------- composition ----------
93
94/// Try multiple verifiers in order; first success wins. If all fail,
95/// returns the last error.
96pub struct ChainVerifier {
97    inner: Vec<Arc<dyn TokenVerifier>>,
98}
99
100impl ChainVerifier {
101    pub fn new() -> Self {
102        Self { inner: Vec::new() }
103    }
104    #[allow(clippy::should_implement_trait)] // builder method; std::ops::Add is the wrong shape (consumes self, two args)
105    pub fn add<V: TokenVerifier>(mut self, v: V) -> Self {
106        self.inner.push(Arc::new(v));
107        self
108    }
109}
110
111impl Default for ChainVerifier {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117#[async_trait]
118impl TokenVerifier for ChainVerifier {
119    async fn verify(&self, token: &RawToken) -> Result<AuthCtx, AuthError> {
120        let mut last_err = AuthError::MissingToken;
121        for v in &self.inner {
122            match v.verify(token).await {
123                Ok(ctx) => return Ok(ctx),
124                Err(e) => last_err = e,
125            }
126        }
127        Err(last_err)
128    }
129}
130
131/// Verifier that always returns anonymous. Used by [`crate::Service::without_auth()`].
132pub(crate) struct AnonymousVerifier;
133
134#[async_trait]
135impl TokenVerifier for AnonymousVerifier {
136    async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
137        Ok(AuthCtx::anonymous())
138    }
139}
140
141// ---------- task-local ----------
142
143tokio::task_local! {
144    /// The current request's [`AuthCtx`]. Set by the auth layer for the
145    /// duration of each handler invocation. Generated client code reads
146    /// this on outbound calls to propagate the caller's identity.
147    ///
148    /// **Spawn pitfall:** `CURRENT_AUTH` is task-local. If you
149    /// `tokio::spawn` a future that calls a downstream service, capture
150    /// `AuthCtx` first and pass it explicitly:
151    ///
152    /// ```ignore
153    /// let auth = AuthCtx::from(&req);
154    /// tokio::spawn(async move {
155    ///     billing.do_thing_as(&auth, ...).await;
156    /// });
157    /// ```
158    pub static CURRENT_AUTH: AuthCtx;
159}
160
161/// Convenience: read the current task's `AuthCtx`, returning anonymous
162/// if no auth layer is active.
163pub fn current() -> AuthCtx {
164    CURRENT_AUTH
165        .try_with(|a| a.clone())
166        .unwrap_or_else(|_| AuthCtx::anonymous())
167}
168
169// ---------- service-token helper ----------
170
171/// Mint a service-identity token using the configured
172/// [`ServiceTokenMinter`]. For background jobs / queue consumers.
173///
174/// Defaults to [`default::HttpServiceTokenMinter::from_env`] if no
175/// custom minter has been registered.
176pub async fn service_token() -> Result<AuthCtx, AuthError> {
177    static MINTER: tokio::sync::OnceCell<Arc<dyn ServiceTokenMinter>> =
178        tokio::sync::OnceCell::const_new();
179    let minter = MINTER
180        .get_or_try_init(|| async {
181            let m = default::HttpServiceTokenMinter::from_env()?;
182            Ok::<Arc<dyn ServiceTokenMinter>, AuthError>(Arc::new(m))
183        })
184        .await?;
185    minter.mint().await
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the throwaway bearer token both tests feed to the chain; the
    /// stub verifiers below never look at its contents.
    fn dummy_token() -> RawToken {
        RawToken {
            value: "x".into(),
            kind: "bearer-jwt",
        }
    }

    /// A failing verifier followed by a succeeding one must yield the
    /// succeeding verifier's context.
    #[tokio::test]
    async fn chain_verifier_first_success_wins() {
        struct AlwaysOk(AuthCtx);
        #[async_trait]
        impl TokenVerifier for AlwaysOk {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Ok(self.0.clone())
            }
        }
        struct AlwaysErr;
        #[async_trait]
        impl TokenVerifier for AlwaysErr {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Signature)
            }
        }
        let mut ok = AuthCtx::anonymous();
        ok.subject = "alice".into();

        let chain = ChainVerifier::new().add(AlwaysErr).add(AlwaysOk(ok));
        let out = chain.verify(&dummy_token()).await.unwrap();
        assert_eq!(out.subject, "alice");
    }

    /// When every verifier rejects the token, the error reported is the one
    /// from the last verifier in the chain.
    #[tokio::test]
    async fn chain_verifier_returns_last_err_when_all_fail() {
        struct ErrA;
        struct ErrB;
        #[async_trait]
        impl TokenVerifier for ErrA {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Signature)
            }
        }
        #[async_trait]
        impl TokenVerifier for ErrB {
            async fn verify(&self, _: &RawToken) -> Result<AuthCtx, AuthError> {
                Err(AuthError::Expired)
            }
        }
        let chain = ChainVerifier::new().add(ErrA).add(ErrB);
        let err = chain.verify(&dummy_token()).await.unwrap_err();
        // Last err wins. A bare `matches!` only evaluates to a bool, so it
        // must be wrapped in `assert!` for the test to check anything.
        assert!(
            matches!(err, AuthError::Expired),
            "expected the last verifier's error (Expired), got {:?}",
            err
        );
    }
}