// File: ultrafast_mcp_auth/middleware.rs

1use crate::{AuthMethod, TokenValidator, error::AuthError, types::TokenClaims};
2use base64::Engine;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::{debug, warn};
7
8/// Authentication context for requests
9#[derive(Debug, Clone)]
10pub struct AuthContext {
11    pub user_id: Option<String>,
12    pub scopes: Vec<String>,
13    pub claims: Option<TokenClaims>,
14    pub auth_method: AuthMethod,
15    pub is_authenticated: bool,
16}
17
18impl AuthContext {
19    pub fn new() -> Self {
20        Self {
21            user_id: None,
22            scopes: Vec::new(),
23            claims: None,
24            auth_method: AuthMethod::None,
25            is_authenticated: false,
26        }
27    }
28
29    pub fn with_user_id(mut self, user_id: String) -> Self {
30        self.user_id = Some(user_id);
31        self
32    }
33
34    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
35        self.scopes = scopes;
36        self
37    }
38
39    pub fn with_claims(mut self, claims: TokenClaims) -> Self {
40        self.claims = Some(claims);
41        self
42    }
43
44    pub fn with_auth_method(mut self, auth_method: AuthMethod) -> Self {
45        self.auth_method = auth_method;
46        self
47    }
48
49    pub fn authenticated(mut self) -> Self {
50        self.is_authenticated = true;
51        self
52    }
53
54    pub fn has_scope(&self, scope: &str) -> bool {
55        self.scopes.contains(&scope.to_string())
56    }
57
58    pub fn has_any_scope(&self, scopes: &[String]) -> bool {
59        scopes.iter().any(|scope| self.has_scope(scope))
60    }
61
62    pub fn has_all_scopes(&self, scopes: &[String]) -> bool {
63        scopes.iter().all(|scope| self.has_scope(scope))
64    }
65}
66
67impl Default for AuthContext {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73/// Server-side authentication middleware
74pub struct ServerAuthMiddleware {
75    token_validator: Arc<TokenValidator>,
76    required_scopes: Vec<String>,
77    auth_enabled: bool,
78    session_store: Arc<RwLock<HashMap<String, AuthContext>>>,
79}
80
81impl ServerAuthMiddleware {
82    pub fn new(token_validator: TokenValidator) -> Self {
83        Self {
84            token_validator: Arc::new(token_validator),
85            required_scopes: Vec::new(),
86            auth_enabled: true,
87            session_store: Arc::new(RwLock::new(HashMap::new())),
88        }
89    }
90
91    pub fn with_required_scopes(mut self, scopes: Vec<String>) -> Self {
92        self.required_scopes = scopes;
93        self
94    }
95
96    pub fn with_auth_enabled(mut self, enabled: bool) -> Self {
97        self.auth_enabled = enabled;
98        self
99    }
100
101    /// Validate authentication headers and return auth context
102    pub async fn validate_request(
103        &self,
104        headers: &HashMap<String, String>,
105    ) -> Result<AuthContext, AuthError> {
106        if !self.auth_enabled {
107            return Ok(AuthContext::new().authenticated());
108        }
109
110        // Check for Authorization header
111        if let Some(auth_header) = headers.get("Authorization") {
112            return self.validate_auth_header(auth_header).await;
113        }
114
115        // Check for API key headers
116        for (key, value) in headers {
117            if key.to_lowercase().contains("api-key") || key.to_lowercase().contains("x-api-key") {
118                return self.validate_api_key(key, value).await;
119            }
120        }
121
122        // No authentication found
123        if self.required_scopes.is_empty() {
124            // If no scopes required, allow unauthenticated access
125            Ok(AuthContext::new())
126        } else {
127            Err(AuthError::InvalidCredentials)
128        }
129    }
130
131    /// Validate Authorization header
132    async fn validate_auth_header(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
133        if auth_header.starts_with("Bearer ") {
134            self.validate_bearer_token(auth_header).await
135        } else if auth_header.starts_with("Basic ") {
136            self.validate_basic_auth(auth_header).await
137        } else {
138            Err(AuthError::InvalidToken(
139                "Unsupported authorization scheme".to_string(),
140            ))
141        }
142    }
143
144    /// Validate Bearer token
145    async fn validate_bearer_token(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
146        let token = crate::validation::extract_bearer_token(auth_header)?;
147
148        // Validate JWT token
149        let claims = self.token_validator.validate_token(token).await?;
150
151        // Check required scopes
152        if !self.required_scopes.is_empty() {
153            self.token_validator
154                .validate_scopes(&claims, &self.required_scopes)?;
155        }
156
157        let scopes = claims
158            .scope
159            .as_ref()
160            .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
161            .unwrap_or_default();
162
163        let auth_context = AuthContext::new()
164            .with_user_id(claims.sub.clone())
165            .with_scopes(scopes)
166            .with_claims(claims)
167            .with_auth_method(AuthMethod::bearer(token.to_string()))
168            .authenticated();
169
170        debug!(
171            "Bearer token validated for user: {}",
172            auth_context
173                .user_id
174                .as_ref()
175                .unwrap_or(&"unknown".to_string())
176        );
177        Ok(auth_context)
178    }
179
180    /// Validate Basic authentication
181    async fn validate_basic_auth(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
182        // Extract and decode basic auth credentials
183        let encoded = auth_header
184            .strip_prefix("Basic ")
185            .ok_or_else(|| AuthError::InvalidToken("Invalid Basic auth format".to_string()))?;
186
187        let decoded = base64::engine::general_purpose::STANDARD
188            .decode(encoded)
189            .map_err(|_| AuthError::InvalidToken("Invalid Basic auth encoding".to_string()))?;
190
191        let credentials = String::from_utf8(decoded)
192            .map_err(|_| AuthError::InvalidToken("Invalid Basic auth credentials".to_string()))?;
193
194        let parts: Vec<&str> = credentials.splitn(2, ':').collect();
195        if parts.len() != 2 {
196            return Err(AuthError::InvalidToken(
197                "Invalid Basic auth format".to_string(),
198            ));
199        }
200
201        let username = parts[0];
202        let password = parts[1];
203
204        // In a real implementation, you would validate against a user database
205        // For now, we'll use a simple validation
206        if username.is_empty() || password.is_empty() {
207            return Err(AuthError::InvalidCredentials);
208        }
209
210        let auth_context = AuthContext::new()
211            .with_user_id(username.to_string())
212            .with_auth_method(AuthMethod::basic(
213                username.to_string(),
214                password.to_string(),
215            ))
216            .authenticated();
217
218        debug!("Basic auth validated for user: {}", username);
219        Ok(auth_context)
220    }
221
222    /// Validate API key
223    async fn validate_api_key(
224        &self,
225        _header_name: &str,
226        api_key: &str,
227    ) -> Result<AuthContext, AuthError> {
228        if api_key.is_empty() {
229            return Err(AuthError::InvalidCredentials);
230        }
231
232        // In a real implementation, you would validate the API key against a database
233        // For now, we'll accept any non-empty API key
234        let auth_context = AuthContext::new()
235            .with_user_id(format!("api_user_{}", &api_key[..8.min(api_key.len())]))
236            .with_auth_method(AuthMethod::api_key(api_key.to_string()))
237            .authenticated();
238
239        debug!(
240            "API key validated for user: {}",
241            auth_context
242                .user_id
243                .as_ref()
244                .unwrap_or(&"unknown".to_string())
245        );
246        Ok(auth_context)
247    }
248
249    /// Store session authentication context
250    pub async fn store_session(&self, session_id: String, auth_context: AuthContext) {
251        let mut sessions = self.session_store.write().await;
252        sessions.insert(session_id, auth_context);
253    }
254
255    /// Get session authentication context
256    pub async fn get_session(&self, session_id: &str) -> Option<AuthContext> {
257        let sessions = self.session_store.read().await;
258        sessions.get(session_id).cloned()
259    }
260
261    /// Remove session authentication context
262    pub async fn remove_session(&self, session_id: &str) {
263        let mut sessions = self.session_store.write().await;
264        sessions.remove(session_id);
265    }
266
267    /// Check if user has required scopes
268    pub fn check_scopes(
269        &self,
270        auth_context: &AuthContext,
271        required_scopes: &[String],
272    ) -> Result<(), AuthError> {
273        if required_scopes.is_empty() {
274            return Ok(());
275        }
276
277        if !auth_context.has_all_scopes(required_scopes) {
278            let missing_scopes: Vec<String> = required_scopes
279                .iter()
280                .filter(|scope| !auth_context.has_scope(scope))
281                .cloned()
282                .collect();
283
284            return Err(AuthError::MissingScope {
285                scope: missing_scopes.join(", "),
286            });
287        }
288
289        Ok(())
290    }
291}
292
293/// Client-side authentication middleware
294pub struct ClientAuthMiddleware {
295    auth_method: AuthMethod,
296    auto_refresh: bool,
297}
298
299impl ClientAuthMiddleware {
300    pub fn new(auth_method: AuthMethod) -> Self {
301        Self {
302            auth_method,
303            auto_refresh: true,
304        }
305    }
306
307    pub fn with_auto_refresh(mut self, enabled: bool) -> Self {
308        self.auto_refresh = enabled;
309        self
310    }
311
312    /// Get authentication headers for outgoing requests
313    pub async fn get_headers(&mut self) -> Result<HashMap<String, String>, AuthError> {
314        // Refresh token if needed and auto-refresh is enabled
315        if self.auto_refresh && self.auth_method.requires_refresh() {
316            if let Err(e) = self.auth_method.refresh().await {
317                warn!("Failed to refresh authentication: {:?}", e);
318            }
319        }
320
321        self.auth_method.get_headers().await
322    }
323
324    /// Update authentication method
325    pub fn with_auth_method(mut self, auth_method: AuthMethod) -> Self {
326        self.auth_method = auth_method;
327        self
328    }
329
330    /// Get current authentication method
331    pub fn get_auth_method(&self) -> &AuthMethod {
332        &self.auth_method
333    }
334}
335
// Note: no `From` conversion into `AuthError` is defined here. The validation helpers already use `AuthError` as their error type, so such an impl would conflict.
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Middleware backed by a validator with a fixed test secret.
    fn test_middleware() -> ServerAuthMiddleware {
        ServerAuthMiddleware::new(TokenValidator::new("test_secret".to_string()))
    }

    /// A header map holding exactly one `name: value` entry.
    fn one_header(name: &str, value: &str) -> HashMap<String, String> {
        std::iter::once((name.to_string(), value.to_string())).collect()
    }

    /// Converts string literals into the owned scope list the APIs expect.
    fn scopes(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| String::from(name)).collect()
    }

    #[tokio::test]
    async fn test_bearer_auth_validation() {
        // A malformed token must be rejected even though the scheme is recognised.
        let middleware = test_middleware().with_required_scopes(scopes(&["read", "write"]));
        let headers = one_header("Authorization", "Bearer invalid_token");

        let outcome = middleware.validate_request(&headers).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_basic_auth_validation() {
        let middleware = test_middleware();
        // "dXNlcjpwYXNz" is base64 for "user:pass".
        let headers = one_header("Authorization", "Basic dXNlcjpwYXNz");

        let outcome = middleware.validate_request(&headers).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_api_key_validation() {
        let middleware = test_middleware();
        let headers = one_header("X-API-Key", "test_api_key");

        let outcome = middleware.validate_request(&headers).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_auth_context_scopes() {
        let context = AuthContext::new().with_scopes(scopes(&["read", "write"]));

        assert!(context.has_scope("read"));
        assert!(context.has_scope("write"));
        assert!(!context.has_scope("delete"));

        assert!(context.has_any_scope(&scopes(&["read", "delete"])));
        assert!(context.has_all_scopes(&scopes(&["read", "write"])));
        assert!(!context.has_all_scopes(&scopes(&["read", "delete"])));
    }
}