Skip to main content

spikard_http/
auth.rs

1//! Authentication middleware for JWT and API keys.
2//!
//! This module provides tower middleware for authenticating requests using:
//! - JWTs (Bearer tokens in the `Authorization` header)
//! - API keys (via a configurable header, with a query-parameter fallback)
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{HeaderMap, StatusCode, Uri},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
15use serde::{Deserialize, Serialize};
16use std::collections::HashSet;
17
18use crate::{ApiKeyConfig, JwtConfig, ProblemDetails};
19
/// Problem Details `type` URI attached to 401 responses for failed authentication.
const TYPE_AUTH_ERROR: &str = "https://spikard.dev/errors/unauthorized";

/// Problem Details `type` URI attached to 500 responses caused by a bad auth configuration.
const TYPE_CONFIG_ERROR: &str = "https://spikard.dev/errors/configuration-error";

/// Name of the internal header reserved for handing validated JWT claims to handlers.
pub const INTERNAL_JWT_CLAIMS_HEADER: &str = "x-spikard-jwt-claims";
28
29/// JWT claims structure - can be extended based on needs
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Claims {
32    pub sub: String,
33    pub exp: usize,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub iat: Option<usize>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub nbf: Option<usize>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub aud: Option<Vec<String>>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub iss: Option<String>,
42}
43
44/// JWT authentication middleware
45///
46/// Validates JWT tokens from the Authorization header (Bearer scheme).
47/// On success, the validated claims are available to downstream handlers.
48/// On failure, returns 401 Unauthorized with RFC 9457 Problem Details.
49///
50/// Coverage: Tested via integration tests (`auth_integration.rs`)
51///
52/// # Errors
53/// Returns an error response when the Authorization header is missing, malformed,
54/// the token is invalid, or configuration is incorrect.
55#[cfg(not(tarpaulin_include))]
56pub async fn jwt_auth_middleware(
57    config: JwtConfig,
58    headers: HeaderMap,
59    request: Request<Body>,
60    next: Next,
61) -> Result<Response, Response> {
62    let auth_header = headers
63        .get("authorization")
64        .and_then(|v| v.to_str().ok())
65        .ok_or_else(|| {
66            let problem = ProblemDetails::new(
67                TYPE_AUTH_ERROR,
68                "Missing or invalid Authorization header",
69                StatusCode::UNAUTHORIZED,
70            )
71            .with_detail("Expected 'Authorization: Bearer <token>'");
72            (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
73        })?;
74
75    let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
76        let problem = ProblemDetails::new(
77            TYPE_AUTH_ERROR,
78            "Invalid Authorization header format",
79            StatusCode::UNAUTHORIZED,
80        )
81        .with_detail("Authorization header must use Bearer scheme: 'Bearer <token>'");
82        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
83    })?;
84
85    let parts: Vec<&str> = token.split('.').collect();
86    if parts.len() != 3 {
87        let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Malformed JWT token", StatusCode::UNAUTHORIZED)
88            .with_detail(format!(
89                "Malformed JWT token: expected 3 parts separated by dots, found {}",
90                parts.len()
91            ));
92        return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
93    }
94
95    let algorithm = parse_algorithm(&config.algorithm).map_err(|_| {
96        let problem = ProblemDetails::new(
97            TYPE_CONFIG_ERROR,
98            "Invalid JWT configuration",
99            StatusCode::INTERNAL_SERVER_ERROR,
100        )
101        .with_detail(format!("Unsupported algorithm: {}", config.algorithm));
102        (StatusCode::INTERNAL_SERVER_ERROR, axum::Json(problem)).into_response()
103    })?;
104
105    let mut validation = Validation::new(algorithm);
106    if let Some(ref aud) = config.audience {
107        validation.set_audience(aud);
108    }
109    if let Some(ref iss) = config.issuer {
110        validation.set_issuer(std::slice::from_ref(iss));
111    }
112    validation.leeway = config.leeway;
113    validation.validate_nbf = true;
114
115    let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
116    let token_data = decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
117        let detail = match e.kind() {
118            jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token has expired".to_string(),
119            jsonwebtoken::errors::ErrorKind::InvalidToken => "Token is invalid".to_string(),
120            jsonwebtoken::errors::ErrorKind::InvalidSignature | jsonwebtoken::errors::ErrorKind::Base64(_) => {
121                "Token signature is invalid".to_string()
122            }
123            jsonwebtoken::errors::ErrorKind::InvalidAudience => "Token audience is invalid".to_string(),
124            jsonwebtoken::errors::ErrorKind::InvalidIssuer => config.issuer.as_ref().map_or_else(
125                || "Token issuer is invalid".to_string(),
126                |expected_iss| format!("Token issuer is invalid, expected '{expected_iss}'"),
127            ),
128            jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
129                "JWT not valid yet, not before claim is in the future".to_string()
130            }
131            _ => format!("Token validation failed: {e}"),
132        };
133
134        let problem =
135            ProblemDetails::new(TYPE_AUTH_ERROR, "JWT validation failed", StatusCode::UNAUTHORIZED).with_detail(detail);
136        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
137    })?;
138
139    let mut request = request;
140    request.extensions_mut().insert(token_data.claims);
141
142    Ok(next.run(request).await)
143}
144
145/// Parse JWT algorithm string to jsonwebtoken Algorithm enum
146fn parse_algorithm(alg: &str) -> Result<Algorithm, String> {
147    match alg {
148        "HS256" => Ok(Algorithm::HS256),
149        "HS384" => Ok(Algorithm::HS384),
150        "HS512" => Ok(Algorithm::HS512),
151        "RS256" => Ok(Algorithm::RS256),
152        "RS384" => Ok(Algorithm::RS384),
153        "RS512" => Ok(Algorithm::RS512),
154        "ES256" => Ok(Algorithm::ES256),
155        "ES384" => Ok(Algorithm::ES384),
156        "PS256" => Ok(Algorithm::PS256),
157        "PS384" => Ok(Algorithm::PS384),
158        "PS512" => Ok(Algorithm::PS512),
159        _ => Err(format!("Unsupported algorithm: {alg}")),
160    }
161}
162
163/// API Key authentication middleware
164///
165/// Validates API keys from a custom header (default: X-API-Key) or query parameter.
166/// Checks header first, then query parameter as fallback.
167/// On success, the request proceeds to the next handler.
168/// On failure, returns 401 Unauthorized with RFC 9457 Problem Details.
169///
170/// Coverage: Tested via integration tests (`auth_integration.rs`)
171///
172/// # Errors
173/// Returns an error response when the API key is missing or invalid.
174#[cfg(not(tarpaulin_include))]
175pub async fn api_key_auth_middleware(
176    config: ApiKeyConfig,
177    headers: HeaderMap,
178    request: Request<Body>,
179    next: Next,
180) -> Result<Response, Response> {
181    let valid_keys: HashSet<String> = config.keys.into_iter().collect();
182
183    let uri = request.uri().clone();
184
185    let api_key_from_header = headers.get(&config.header_name).and_then(|v| v.to_str().ok());
186
187    let api_key = api_key_from_header.map_or_else(|| extract_api_key_from_query(&uri), Some);
188
189    let api_key = api_key.ok_or_else(|| {
190        let problem =
191            ProblemDetails::new(TYPE_AUTH_ERROR, "Missing API key", StatusCode::UNAUTHORIZED).with_detail(format!(
192                "Expected '{}' header or 'api_key' query parameter with valid API key",
193                config.header_name
194            ));
195        (StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response()
196    })?;
197
198    if !valid_keys.contains(api_key) {
199        let problem = ProblemDetails::new(TYPE_AUTH_ERROR, "Invalid API key", StatusCode::UNAUTHORIZED)
200            .with_detail("The provided API key is not valid");
201        return Err((StatusCode::UNAUTHORIZED, axum::Json(problem)).into_response());
202    }
203
204    Ok(next.run(request).await)
205}
206
207/// Extract API key from query parameters
208///
209/// Checks for common API key parameter names: api_key, apiKey, key
210fn extract_api_key_from_query(uri: &Uri) -> Option<&str> {
211    let query = uri.query()?;
212
213    for param in query.split('&') {
214        if let Some((key, value)) = param.split_once('=')
215            && (key == "api_key" || key == "apiKey" || key == "key")
216        {
217            return Some(value);
218        }
219    }
220
221    None
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_algorithm() {
        let supported = [
            ("HS256", Algorithm::HS256),
            ("HS384", Algorithm::HS384),
            ("HS512", Algorithm::HS512),
            ("RS256", Algorithm::RS256),
            ("RS384", Algorithm::RS384),
            ("RS512", Algorithm::RS512),
            ("ES256", Algorithm::ES256),
            ("ES384", Algorithm::ES384),
            ("PS256", Algorithm::PS256),
            ("PS384", Algorithm::PS384),
            ("PS512", Algorithm::PS512),
        ];
        for (name, expected) in supported {
            assert_eq!(parse_algorithm(name), Ok(expected), "algorithm {name}");
        }
        assert_eq!(
            parse_algorithm("INVALID"),
            Err("Unsupported algorithm: INVALID".to_string())
        );
    }

    #[test]
    fn test_parse_algorithm_is_case_sensitive() {
        assert!(parse_algorithm("hs256").is_err());
        assert!(parse_algorithm("").is_err());
    }

    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: 1234567890,
            iat: Some(1234567800),
            nbf: None,
            aud: Some(vec!["https://api.example.com".to_string()]),
            iss: Some("https://auth.example.com".to_string()),
        };

        let json = serde_json::to_string(&claims).unwrap();
        assert!(json.contains("user123"));
        assert!(json.contains("1234567890"));
    }

    #[test]
    fn test_claims_serialization_skips_absent_optional_claims() {
        let claims = Claims {
            sub: "user123".to_string(),
            exp: 1234567890,
            iat: None,
            nbf: None,
            aud: None,
            iss: None,
        };

        let value = serde_json::to_value(&claims).unwrap();
        assert_eq!(value, serde_json::json!({ "sub": "user123", "exp": 1234567890 }));
    }

    #[test]
    fn test_claims_deserialization_with_audience_array() {
        let claims: Claims = serde_json::from_str(r#"{"sub":"u","exp":10,"aud":["a","b"]}"#).unwrap();
        assert_eq!(claims.sub, "u");
        assert_eq!(claims.aud, Some(vec!["a".to_string(), "b".to_string()]));
        assert_eq!(claims.iss, None);
        assert_eq!(claims.nbf, None);
    }

    #[test]
    fn test_extract_api_key_from_query_api_key() {
        let uri: axum::http::Uri = "/api/endpoint?api_key=secret123".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("secret123"));
    }

    #[test]
    fn test_extract_api_key_from_query_api_key_camel_case() {
        let uri: axum::http::Uri = "/api/endpoint?apiKey=mykey456".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("mykey456"));
    }

    #[test]
    fn test_extract_api_key_from_query_key() {
        let uri: axum::http::Uri = "/api/endpoint?key=testkey789".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("testkey789"));
    }

    #[test]
    fn test_extract_api_key_from_query_no_key() {
        let uri: axum::http::Uri = "/api/endpoint?foo=bar&baz=qux".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_api_key_from_query_empty_string() {
        let uri: axum::http::Uri = "/api/endpoint".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_api_key_from_query_multiple_params() {
        let uri: axum::http::Uri = "/api/endpoint?foo=bar&api_key=found&baz=qux".parse().unwrap();
        let result = extract_api_key_from_query(&uri);
        assert_eq!(result, Some("found"));
    }

    #[test]
    fn test_extract_api_key_from_query_first_match_wins() {
        let uri: axum::http::Uri = "/api/endpoint?key=first&api_key=second".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), Some("first"));
    }

    #[test]
    fn test_extract_api_key_from_query_skips_params_without_value() {
        let uri: axum::http::Uri = "/api/endpoint?api_key&key=fallback".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), Some("fallback"));
    }

    #[test]
    fn test_extract_api_key_from_query_name_is_case_sensitive() {
        let uri: axum::http::Uri = "/api/endpoint?API_KEY=nope".parse().unwrap();
        assert_eq!(extract_api_key_from_query(&uri), None);
    }
}