Skip to main content

uselesskey_axum/
lib.rs

1#![forbid(unsafe_code)]
2
3//! `axum` auth-test helpers built on deterministic `uselesskey` fixtures.
4//!
5//! This crate is intentionally test-focused and scoped to common drop-in needs:
6//! - JWKS and OIDC discovery routers
7//! - Bearer token verification middleware for tests
8//! - Typed deterministic auth context extraction/injection
9
10use axum::extract::{FromRequestParts, State};
11use axum::http::{Request, StatusCode, header::AUTHORIZATION, request::Parts};
12use axum::middleware::Next;
13use axum::response::{IntoResponse, Response};
14use axum::routing::get;
15use axum::{Json, Router};
16use jsonwebtoken::{
17    Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, decode_header, encode,
18};
19use serde::{Deserialize, Serialize};
20use serde_json::{Value, json};
21use std::sync::Arc;
22use uselesskey_core::{Factory, Seed};
23use uselesskey_rsa::{RsaFactoryExt, RsaKeyPair, RsaSpec};
24
/// Route at which [`jwks_router`] serves the public JWKS document.
const DEFAULT_JWKS_PATH: &str = "/.well-known/jwks.json";
/// Route at which [`oidc_router`] serves the OIDC discovery document.
const DEFAULT_OIDC_PATH: &str = "/.well-known/openid-configuration";
27
28/// Expected JWT shape for test verification.
29#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
30pub struct AuthExpectations {
31    /// Expected `iss` claim.
32    pub issuer: String,
33    /// Expected `aud` claim.
34    pub audience: String,
35    /// Expected key id from JWT header.
36    pub kid: String,
37}
38
39impl AuthExpectations {
40    /// Build expected issuer/audience/kid values.
41    pub fn new(
42        issuer: impl Into<String>,
43        audience: impl Into<String>,
44        kid: impl Into<String>,
45    ) -> Self {
46        Self {
47            issuer: issuer.into(),
48            audience: audience.into(),
49            kid: kid.into(),
50        }
51    }
52
53    /// Replace expected issuer.
54    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
55        self.issuer = issuer.into();
56        self
57    }
58
59    /// Replace expected audience.
60    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
61        self.audience = audience.into();
62        self
63    }
64
65    /// Replace expected kid.
66    pub fn with_kid(mut self, kid: impl Into<String>) -> Self {
67        self.kid = kid.into();
68        self
69    }
70}
71
/// Which key of the deterministic two-key rotation sequence to derive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RotationPhase {
    /// Primary signing key.
    Primary,
    /// Next key in the deterministic rotation sequence.
    Next,
}

impl RotationPhase {
    /// Suffix appended to the fixture label so each phase gets its own label.
    fn suffix(self) -> &'static str {
        match self {
            RotationPhase::Next => "next",
            RotationPhase::Primary => "primary",
        }
    }
}
89
90/// Deterministic signer + JWKS test fixture for one rotation phase.
91#[derive(Clone)]
92pub struct DeterministicJwksPhase {
93    keypair: RsaKeyPair,
94    expectations: AuthExpectations,
95}
96
97impl DeterministicJwksPhase {
98    /// Build deterministic material for a rotation phase.
99    pub fn new(
100        seed: Seed,
101        label: impl AsRef<str>,
102        phase: RotationPhase,
103        issuer: impl Into<String>,
104        audience: impl Into<String>,
105    ) -> Self {
106        let fx = Factory::deterministic(seed);
107        let keypair = fx.rsa(
108            format!("{}:{}", label.as_ref(), phase.suffix()),
109            RsaSpec::rs256(),
110        );
111        let kid = keypair.kid();
112        Self {
113            keypair,
114            expectations: AuthExpectations::new(issuer, audience, kid),
115        }
116    }
117
118    /// Public JWKS payload for this phase.
119    pub fn jwks_json(&self) -> Value {
120        self.keypair.public_jwks_json()
121    }
122
123    /// Expected issuer/audience/kid values.
124    pub fn expectations(&self) -> &AuthExpectations {
125        &self.expectations
126    }
127
128    /// Create RS256 bearer token for test claims.
129    pub fn issue_token(&self, mut claims: Value, ttl_seconds: u64) -> String {
130        let now = current_unix_seconds();
131        if claims.get("iss").is_none() {
132            claims["iss"] = Value::String(self.expectations.issuer.clone());
133        }
134        if claims.get("aud").is_none() {
135            claims["aud"] = Value::String(self.expectations.audience.clone());
136        }
137        if claims.get("iat").is_none() {
138            claims["iat"] = Value::Number((now as u64).into());
139        }
140        if claims.get("exp").is_none() {
141            claims["exp"] = Value::Number((now as u64 + ttl_seconds).into());
142        }
143
144        let mut header = Header::new(Algorithm::RS256);
145        header.kid = Some(self.expectations.kid.clone());
146
147        encode(
148            &header,
149            &claims,
150            &EncodingKey::from_rsa_pem(self.keypair.private_key_pkcs8_pem().as_bytes())
151                .expect("deterministic fixture key should produce valid RSA encoding key"),
152        )
153        .expect("deterministic fixture key should produce valid JWT")
154    }
155
156    fn decoding_key(&self) -> DecodingKey {
157        DecodingKey::from_rsa_pem(self.keypair.public_key_spki_pem().as_bytes())
158            .expect("deterministic fixture key should produce valid RSA decoding key")
159    }
160}
161
162/// Typed auth context inserted by helpers and extracted in handlers.
163#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
164pub struct TestAuthContext {
165    pub sub: String,
166    pub iss: String,
167    pub aud: String,
168    pub kid: String,
169    pub exp: u64,
170}
171
172impl<S> FromRequestParts<S> for TestAuthContext
173where
174    S: Send + Sync,
175{
176    type Rejection = (StatusCode, &'static str);
177
178    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179        parts
180            .extensions
181            .get::<Self>()
182            .cloned()
183            .ok_or((StatusCode::UNAUTHORIZED, "missing auth context"))
184    }
185}
186
187/// Middleware verification state.
188#[derive(Clone)]
189pub struct MockJwtVerifierState {
190    signer: DeterministicJwksPhase,
191}
192
193impl MockJwtVerifierState {
194    /// Build middleware state from a deterministic phase.
195    pub fn new(signer: DeterministicJwksPhase) -> Self {
196        Self { signer }
197    }
198
199    /// Produce JWKS JSON served by [`jwks_router`].
200    pub fn jwks_json(&self) -> Value {
201        self.signer.jwks_json()
202    }
203
204    /// Produce OIDC discovery JSON served by [`oidc_router`].
205    pub fn oidc_json(&self, base_url: impl AsRef<str>) -> Value {
206        let base = base_url.as_ref().trim_end_matches('/');
207        json!({
208            "issuer": self.signer.expectations().issuer,
209            "jwks_uri": format!("{base}{DEFAULT_JWKS_PATH}"),
210            "id_token_signing_alg_values_supported": ["RS256"],
211            "token_endpoint_auth_methods_supported": ["none"],
212            "response_types_supported": ["token"],
213            "subject_types_supported": ["public"],
214        })
215    }
216
217    /// Generate a valid bearer token for this state.
218    pub fn issue_token(&self, claims: Value, ttl_seconds: u64) -> String {
219        self.signer.issue_token(claims, ttl_seconds)
220    }
221
222    /// Clone expected claims checks.
223    pub fn expectations(&self) -> AuthExpectations {
224        self.signer.expectations().clone()
225    }
226}
227
228/// Build a JWKS router mounted at `/.well-known/jwks.json`.
229pub fn jwks_router(state: MockJwtVerifierState) -> Router {
230    Router::new()
231        .route(DEFAULT_JWKS_PATH, get(jwks_handler))
232        .with_state(state)
233}
234
235/// Build an OIDC discovery router mounted at `/.well-known/openid-configuration`.
236pub fn oidc_router(state: MockJwtVerifierState, base_url: impl Into<String>) -> Router {
237    let state = OidcState {
238        verifier: state,
239        base_url: base_url.into(),
240    };
241    Router::new()
242        .route(DEFAULT_OIDC_PATH, get(oidc_handler))
243        .with_state(state)
244}
245
246/// Attach a middleware layer that verifies bearer tokens and inserts [`TestAuthContext`].
247pub fn mock_jwt_verifier_layer(router: Router, state: MockJwtVerifierState) -> Router {
248    let state = Arc::new(state);
249    router.layer(axum::middleware::from_fn(move |request, next| {
250        let state = Arc::clone(&state);
251        async move { verify_bearer_token(state.as_ref().clone(), request, next).await }
252    }))
253}
254
255/// Attach a middleware layer that injects a deterministic auth context without JWT parsing.
256pub fn inject_auth_context_layer(router: Router, context: TestAuthContext) -> Router {
257    let context = Arc::new(context);
258    router.layer(axum::middleware::from_fn(move |request, next| {
259        let context = Arc::clone(&context);
260        async move { inject_auth_context(context.as_ref().clone(), request, next).await }
261    }))
262}
263
264#[derive(Clone)]
265struct OidcState {
266    verifier: MockJwtVerifierState,
267    base_url: String,
268}
269
270async fn jwks_handler(State(state): State<MockJwtVerifierState>) -> Json<Value> {
271    Json(state.jwks_json())
272}
273
274async fn oidc_handler(State(state): State<OidcState>) -> Json<Value> {
275    Json(state.verifier.oidc_json(&state.base_url))
276}
277
278async fn inject_auth_context(
279    context: TestAuthContext,
280    mut request: Request<axum::body::Body>,
281    next: Next,
282) -> Response {
283    request.extensions_mut().insert(context);
284    next.run(request).await
285}
286
287async fn verify_bearer_token(
288    state: MockJwtVerifierState,
289    mut request: Request<axum::body::Body>,
290    next: Next,
291) -> Response {
292    let bearer = match extract_bearer(request.headers()) {
293        Ok(token) => token,
294        Err((code, msg)) => return (code, msg).into_response(),
295    };
296
297    let header = match decode_header(bearer) {
298        Ok(header) => header,
299        Err(_) => return (StatusCode::UNAUTHORIZED, "invalid jwt header").into_response(),
300    };
301
302    let expected = state.expectations();
303    if header.kid.as_deref() != Some(expected.kid.as_str()) {
304        return (StatusCode::UNAUTHORIZED, "unexpected kid").into_response();
305    }
306
307    let mut validation = Validation::new(Algorithm::RS256);
308    validation.set_issuer(std::slice::from_ref(&expected.issuer));
309    validation.set_audience(std::slice::from_ref(&expected.audience));
310    validation.leeway = 0;
311
312    let token = match decode::<Value>(bearer, &state.signer.decoding_key(), &validation) {
313        Ok(token) => token,
314        Err(_) => return (StatusCode::UNAUTHORIZED, "token verification failed").into_response(),
315    };
316
317    let sub = token
318        .claims
319        .get("sub")
320        .and_then(Value::as_str)
321        .unwrap_or("unknown")
322        .to_owned();
323    let iss = token
324        .claims
325        .get("iss")
326        .and_then(Value::as_str)
327        .unwrap_or_default()
328        .to_owned();
329    let aud = token
330        .claims
331        .get("aud")
332        .and_then(Value::as_str)
333        .unwrap_or_default()
334        .to_owned();
335    let exp = token
336        .claims
337        .get("exp")
338        .and_then(Value::as_u64)
339        .unwrap_or_default();
340
341    request.extensions_mut().insert(TestAuthContext {
342        sub,
343        iss,
344        aud,
345        kid: expected.kid,
346        exp,
347    });
348
349    next.run(request).await
350}
351
352fn extract_bearer(headers: &axum::http::HeaderMap) -> Result<&str, (StatusCode, &'static str)> {
353    let header = headers
354        .get(AUTHORIZATION)
355        .and_then(|value| value.to_str().ok())
356        .ok_or((StatusCode::UNAUTHORIZED, "missing authorization header"))?;
357    let token = header
358        .strip_prefix("Bearer ")
359        .ok_or((StatusCode::UNAUTHORIZED, "invalid authorization scheme"))?;
360    if token.is_empty() {
361        return Err((StatusCode::UNAUTHORIZED, "empty bearer token"));
362    }
363    Ok(token)
364}
365
/// Whole seconds elapsed since the Unix epoch, according to the system clock.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn current_unix_seconds() -> usize {
    let since_epoch = std::time::UNIX_EPOCH
        .elapsed()
        .expect("current time should be >= unix epoch");
    since_epoch.as_secs() as usize
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use axum::response::IntoResponse;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Deterministic fixture shared by every test; phases differ only in key/kid.
    fn phase(phase: RotationPhase) -> DeterministicJwksPhase {
        let seed = Seed::from_env_value("uselesskey-axum-tests").expect("seed parse");
        DeterministicJwksPhase::new(
            seed,
            "auth-suite",
            phase,
            "https://issuer.example.test",
            "api://example-aud",
        )
    }

    /// Verifier state for the given rotation phase.
    fn verifier(rotation: RotationPhase) -> MockJwtVerifierState {
        MockJwtVerifierState::new(phase(rotation))
    }

    /// `GET /me`, with `Authorization: Bearer <token>` when a token is given.
    fn me_request(token: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri("/me");
        let builder = match token {
            Some(token) => builder.header(AUTHORIZATION, format!("Bearer {token}")),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// `/me` behind the verifier layer; the handler echoes the extracted context as JSON.
    fn verified_echo_app(state: MockJwtVerifierState) -> Router {
        mock_jwt_verifier_layer(
            Router::new().route(
                "/me",
                get(|auth: TestAuthContext| async move { Json(auth).into_response() }),
            ),
            state,
        )
    }

    /// Read the whole response body and parse it as JSON.
    async fn body_json<T: serde::de::DeserializeOwned>(response: Response) -> T {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn jwks_and_oidc_routes_respond() {
        let state = verifier(RotationPhase::Primary);
        let app = jwks_router(state.clone()).merge(oidc_router(state, "http://localhost:3000"));

        let jwks_res = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri(DEFAULT_JWKS_PATH)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(jwks_res.status(), StatusCode::OK);

        let oidc_res = app
            .oneshot(
                Request::builder()
                    .uri(DEFAULT_OIDC_PATH)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(oidc_res.status(), StatusCode::OK);
    }

    #[test]
    fn rotation_phase_produces_distinct_kids() {
        let primary = phase(RotationPhase::Primary);
        let next = phase(RotationPhase::Next);
        assert_ne!(primary.expectations().kid, next.expectations().kid);
    }

    #[tokio::test]
    async fn verifier_accepts_valid_token_and_exposes_context() {
        let state = verifier(RotationPhase::Primary);
        let expected = state.expectations();
        let token = state.issue_token(json!({"sub": "alice"}), 300);

        let response = verified_echo_app(state)
            .oneshot(me_request(Some(token.as_str())))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let context: TestAuthContext = body_json(response).await;
        assert_eq!(context.sub, "alice");
        assert_eq!(context.iss, expected.issuer);
        assert_eq!(context.aud, expected.audience);
        assert_eq!(context.kid, expected.kid);
    }

    #[tokio::test]
    async fn verifier_rejects_missing_authorization_header() {
        let response = verified_echo_app(verifier(RotationPhase::Primary))
            .oneshot(me_request(None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn verifier_rejects_token_signed_by_other_phase() {
        let token = verifier(RotationPhase::Next).issue_token(json!({"sub": "alice"}), 300);
        let response = verified_echo_app(verifier(RotationPhase::Primary))
            .oneshot(me_request(Some(token.as_str())))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn verifier_rejects_wrong_audience() {
        let state = verifier(RotationPhase::Primary);
        let token = state.issue_token(json!({"sub":"alice", "aud":"api://wrong-aud"}), 300);

        let response = verified_echo_app(state)
            .oneshot(me_request(Some(token.as_str())))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn verifier_rejects_expired_token() {
        let state = verifier(RotationPhase::Primary);
        let now = current_unix_seconds() as u64;
        let token = state.issue_token(
            json!({"sub":"alice", "exp": now.saturating_sub(5), "iat": now.saturating_sub(10)}),
            300,
        );

        let response = verified_echo_app(state)
            .oneshot(me_request(Some(token.as_str())))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn deterministic_auth_context_injection_works() {
        let app = inject_auth_context_layer(
            Router::new().route(
                "/me",
                get(|auth: TestAuthContext| async move {
                    Json(json!({"sub": auth.sub, "kid": auth.kid})).into_response()
                }),
            ),
            TestAuthContext {
                sub: "test-user".into(),
                iss: "iss".into(),
                aud: "aud".into(),
                kid: "kid-1".into(),
                exp: 42,
            },
        );

        let response = app.oneshot(me_request(None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body: Value = body_json(response).await;
        assert_eq!(body, json!({"sub": "test-user", "kid": "kid-1"}));
    }
}