ultrafast_mcp_auth/
validation.rs

1use crate::{error::AuthError, types::TokenClaims};
2use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
3
4/// Token validator for JWT access tokens
5#[derive(Clone)]
6pub struct TokenValidator {
7    validation: Validation,
8    decoding_key: Option<DecodingKey>,
9    secret: String,
10}
11
12impl TokenValidator {
13    /// Create a new token validator with a secret
14    pub fn new(secret: String) -> Self {
15        let mut validation = Validation::new(Algorithm::HS256);
16        validation.validate_exp = true;
17        validation.validate_nbf = true;
18
19        Self {
20            validation,
21            decoding_key: Some(DecodingKey::from_secret(secret.as_ref())),
22            secret,
23        }
24    }
25
26    /// Get the secret (for testing)
27    pub fn get_secret(&self) -> &str {
28        &self.secret
29    }
30
31    /// Set the decoding key for JWT verification
32    pub fn with_decoding_key(mut self, key: DecodingKey) -> Self {
33        self.decoding_key = Some(key);
34        self
35    }
36
37    /// Set required audience
38    pub fn with_audience<T: ToString>(mut self, audience: T) -> Self {
39        self.validation.set_audience(&[audience.to_string()]);
40        self
41    }
42
43    /// Set required issuer
44    pub fn with_issuer<T: ToString>(mut self, issuer: T) -> Self {
45        self.validation.set_issuer(&[issuer.to_string()]);
46        self
47    }
48
49    /// Validate a JWT access token
50    pub async fn validate_token(&self, token: &str) -> Result<TokenClaims, AuthError> {
51        let decoding_key =
52            self.decoding_key
53                .as_ref()
54                .ok_or_else(|| AuthError::TokenValidationError {
55                    reason: "No decoding key configured".to_string(),
56                })?;
57
58        let token_data = decode::<TokenClaims>(token, decoding_key, &self.validation)
59            .map_err(|e| AuthError::InvalidToken(e.to_string()))?;
60
61        Ok(token_data.claims)
62    }
63
64    /// Validate token audience specifically (RFC 8707)
65    pub fn validate_audience(
66        &self,
67        claims: &TokenClaims,
68        expected_audience: &str,
69    ) -> Result<(), AuthError> {
70        if !claims.aud.contains(&expected_audience.to_string()) {
71            return Err(AuthError::InvalidAudience {
72                expected: expected_audience.to_string(),
73                actual: claims.aud.join(", "),
74            });
75        }
76        Ok(())
77    }
78
79    /// Validate required scopes
80    pub fn validate_scopes(
81        &self,
82        claims: &TokenClaims,
83        required_scopes: &[String],
84    ) -> Result<(), AuthError> {
85        let token_scopes = claims
86            .scope
87            .as_ref()
88            .map(|s| s.split_whitespace().collect::<Vec<_>>())
89            .unwrap_or_default();
90
91        for required_scope in required_scopes {
92            if !token_scopes.contains(&required_scope.as_str()) {
93                return Err(AuthError::MissingScope {
94                    scope: required_scope.clone(),
95                });
96            }
97        }
98
99        Ok(())
100    }
101}
102
103impl Default for TokenValidator {
104    fn default() -> Self {
105        Self::new("".to_string())
106    }
107}
108
109/// Extract bearer token from Authorization header
110pub fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
111    if !auth_header.starts_with("Bearer ") {
112        return Err(AuthError::InvalidToken("Not a Bearer token".to_string()));
113    }
114    let token = auth_header.strip_prefix("Bearer ").expect("Bearer prefix should be present");
115    let token = token.trim();
116    if token.is_empty() {
117        return Err(AuthError::InvalidToken("Empty token".to_string()));
118    }
119
120    // Validate token length (prevent extremely long tokens)
121    const MAX_TOKEN_LENGTH: usize = 4096;
122    if token.len() > MAX_TOKEN_LENGTH {
123        return Err(AuthError::InvalidToken(format!(
124            "Token too long: {} characters (max: {})",
125            token.len(),
126            MAX_TOKEN_LENGTH
127        )));
128    }
129
130    Ok(token)
131}
132
133/// Extract and validate JWT bearer token from Authorization header
134pub fn extract_jwt_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
135    let token = extract_bearer_token(auth_header)?;
136
137    // Basic JWT format validation (should contain exactly 3 parts separated by dots)
138    let parts: Vec<&str> = token.split('.').collect();
139    if parts.len() != 3 {
140        return Err(AuthError::InvalidToken(
141            "Invalid JWT format: expected 3 parts".to_string(),
142        ));
143    }
144
145    Ok(token)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-segment sample JWT shared by the tests below.
    const SAMPLE_JWT: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

    /// Build the `InvalidToken` error the extractors are expected to return.
    fn invalid_token(message: &str) -> AuthError {
        AuthError::InvalidToken(message.to_string())
    }

    #[test]
    fn test_extract_bearer_token() {
        // Simple tokens (for backward compatibility), with and without padding.
        assert_eq!(extract_bearer_token("Bearer abc123"), Ok("abc123"));
        assert_eq!(extract_bearer_token("Bearer  abc123  "), Ok("abc123"));

        // JWT tokens pass through untouched.
        assert_eq!(
            extract_bearer_token(&format!("Bearer {SAMPLE_JWT}")),
            Ok(SAMPLE_JWT)
        );

        // Wrong scheme and empty header.
        assert_eq!(
            extract_bearer_token("Basic abc123"),
            Err(invalid_token("Not a Bearer token"))
        );
        assert_eq!(
            extract_bearer_token(""),
            Err(invalid_token("Not a Bearer token"))
        );

        // Scheme present but no token after it.
        assert_eq!(
            extract_bearer_token("Bearer "),
            Err(invalid_token("Empty token"))
        );
        assert_eq!(
            extract_bearer_token("Bearer    "),
            Err(invalid_token("Empty token"))
        );

        // Token length validation: the limit itself is accepted...
        let max_token = "a".repeat(4096);
        assert_eq!(
            extract_bearer_token(&format!("Bearer {max_token}")),
            Ok(max_token.as_str())
        );

        // ...and one byte more is rejected.
        let long_token = "a".repeat(4097);
        assert_eq!(
            extract_bearer_token(&format!("Bearer {long_token}")),
            Err(invalid_token("Token too long: 4097 characters (max: 4096)"))
        );
    }

    #[test]
    fn test_extract_jwt_bearer_token() {
        // Valid JWT tokens.
        assert_eq!(
            extract_jwt_bearer_token(&format!("Bearer {SAMPLE_JWT}")),
            Ok(SAMPLE_JWT)
        );

        // Invalid JWT format: simple token, too few parts, too many parts.
        for header in [
            "Bearer abc123",
            "Bearer part1.part2",
            "Bearer part1.part2.part3.part4",
        ] {
            assert_eq!(
                extract_jwt_bearer_token(header),
                Err(invalid_token("Invalid JWT format: expected 3 parts"))
            );
        }
    }
}