ultrafast_mcp_auth/
validation.rs1use crate::{error::AuthError, types::TokenClaims};
2use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
3
4#[derive(Clone)]
6pub struct TokenValidator {
7 validation: Validation,
8 decoding_key: Option<DecodingKey>,
9 secret: String,
10}
11
12impl TokenValidator {
13 pub fn new(secret: String) -> Self {
15 let mut validation = Validation::new(Algorithm::HS256);
16 validation.validate_exp = true;
17 validation.validate_nbf = true;
18
19 Self {
20 validation,
21 decoding_key: Some(DecodingKey::from_secret(secret.as_ref())),
22 secret,
23 }
24 }
25
26 pub fn get_secret(&self) -> &str {
28 &self.secret
29 }
30
31 pub fn with_decoding_key(mut self, key: DecodingKey) -> Self {
33 self.decoding_key = Some(key);
34 self
35 }
36
37 pub fn with_audience<T: ToString>(mut self, audience: T) -> Self {
39 self.validation.set_audience(&[audience.to_string()]);
40 self
41 }
42
43 pub fn with_issuer<T: ToString>(mut self, issuer: T) -> Self {
45 self.validation.set_issuer(&[issuer.to_string()]);
46 self
47 }
48
49 pub async fn validate_token(&self, token: &str) -> Result<TokenClaims, AuthError> {
51 let decoding_key =
52 self.decoding_key
53 .as_ref()
54 .ok_or_else(|| AuthError::TokenValidationError {
55 reason: "No decoding key configured".to_string(),
56 })?;
57
58 let token_data = decode::<TokenClaims>(token, decoding_key, &self.validation)
59 .map_err(|e| AuthError::InvalidToken(e.to_string()))?;
60
61 Ok(token_data.claims)
62 }
63
64 pub fn validate_audience(
66 &self,
67 claims: &TokenClaims,
68 expected_audience: &str,
69 ) -> Result<(), AuthError> {
70 if !claims.aud.contains(&expected_audience.to_string()) {
71 return Err(AuthError::InvalidAudience {
72 expected: expected_audience.to_string(),
73 actual: claims.aud.join(", "),
74 });
75 }
76 Ok(())
77 }
78
79 pub fn validate_scopes(
81 &self,
82 claims: &TokenClaims,
83 required_scopes: &[String],
84 ) -> Result<(), AuthError> {
85 let token_scopes = claims
86 .scope
87 .as_ref()
88 .map(|s| s.split_whitespace().collect::<Vec<_>>())
89 .unwrap_or_default();
90
91 for required_scope in required_scopes {
92 if !token_scopes.contains(&required_scope.as_str()) {
93 return Err(AuthError::MissingScope {
94 scope: required_scope.clone(),
95 });
96 }
97 }
98
99 Ok(())
100 }
101}
102
103impl Default for TokenValidator {
104 fn default() -> Self {
105 Self::new("".to_string())
106 }
107}
108
109pub fn extract_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
111 if !auth_header.starts_with("Bearer ") {
112 return Err(AuthError::InvalidToken("Not a Bearer token".to_string()));
113 }
114 let token = auth_header.strip_prefix("Bearer ").expect("Bearer prefix should be present");
115 let token = token.trim();
116 if token.is_empty() {
117 return Err(AuthError::InvalidToken("Empty token".to_string()));
118 }
119
120 const MAX_TOKEN_LENGTH: usize = 4096;
122 if token.len() > MAX_TOKEN_LENGTH {
123 return Err(AuthError::InvalidToken(format!(
124 "Token too long: {} characters (max: {})",
125 token.len(),
126 MAX_TOKEN_LENGTH
127 )));
128 }
129
130 Ok(token)
131}
132
133pub fn extract_jwt_bearer_token(auth_header: &str) -> Result<&str, AuthError> {
135 let token = extract_bearer_token(auth_header)?;
136
137 let parts: Vec<&str> = token.split('.').collect();
139 if parts.len() != 3 {
140 return Err(AuthError::InvalidToken(
141 "Invalid JWT format: expected 3 parts".to_string(),
142 ));
143 }
144
145 Ok(token)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed compact JWT (three base64url segments) used as a fixture.
    const SAMPLE_JWT: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(extract_bearer_token("Bearer abc123"), Ok("abc123"));
        assert_eq!(extract_bearer_token("Bearer abc123 "), Ok("abc123"));

        assert_eq!(
            extract_bearer_token(&format!("Bearer {SAMPLE_JWT}")),
            Ok(SAMPLE_JWT)
        );

        assert_eq!(
            extract_bearer_token("Basic abc123"),
            Err(AuthError::InvalidToken("Not a Bearer token".to_string()))
        );

        assert_eq!(
            extract_bearer_token(""),
            Err(AuthError::InvalidToken("Not a Bearer token".to_string()))
        );

        let long_token = "a".repeat(4097);
        assert_eq!(
            extract_bearer_token(&format!("Bearer {long_token}")),
            Err(AuthError::InvalidToken(
                "Token too long: 4097 characters (max: 4096)".to_string()
            ))
        );
    }

    #[test]
    fn test_extract_bearer_token_edge_cases() {
        // The scheme name alone, without the separating space, is not a Bearer header.
        assert_eq!(
            extract_bearer_token("Bearer"),
            Err(AuthError::InvalidToken("Not a Bearer token".to_string()))
        );

        // Whitespace-only credentials trim down to nothing.
        assert_eq!(
            extract_bearer_token("Bearer    "),
            Err(AuthError::InvalidToken("Empty token".to_string()))
        );

        // Whitespace is trimmed before the length check, and the limit is inclusive.
        let max_token = "a".repeat(4096);
        assert_eq!(
            extract_bearer_token(&format!("Bearer  {max_token}  ")),
            Ok(max_token.as_str())
        );
    }

    #[test]
    fn test_extract_jwt_bearer_token() {
        assert_eq!(
            extract_jwt_bearer_token(&format!("Bearer {SAMPLE_JWT}")),
            Ok(SAMPLE_JWT)
        );

        for bad in [
            "Bearer abc123",
            "Bearer part1.part2",
            "Bearer part1.part2.part3.part4",
        ] {
            assert_eq!(
                extract_jwt_bearer_token(bad),
                Err(AuthError::InvalidToken(
                    "Invalid JWT format: expected 3 parts".to_string()
                )),
                "input: {bad}"
            );
        }

        // Errors from the bearer-scheme check propagate unchanged.
        assert_eq!(
            extract_jwt_bearer_token(&format!("Basic {SAMPLE_JWT}")),
            Err(AuthError::InvalidToken("Not a Bearer token".to_string()))
        );
    }
}