turbomcp_auth/manager.rs
1//! Authentication Manager
2//!
3//! Central authentication manager for coordinating multiple authentication providers.
4//!
5//! # MCP Compliance
6//!
7//! Per MCP specification (2025-06-18), authentication is **stateless**.
8//! Each request must include valid credentials (Bearer token in Authorization header).
9//! This manager does NOT maintain server-side session state for authentication decisions.
10//!
11//! ## Stateless Authentication Flow
12//!
13//! ```rust,no_run
14//! # use turbomcp_auth::{AuthManager, AuthCredentials, config::{AuthConfig, AuthorizationConfig}};
15//! # use std::collections::HashMap;
16//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
17//! # let config = AuthConfig {
18//! # enabled: true,
19//! # providers: vec![],
20//! # authorization: AuthorizationConfig {
21//! # rbac_enabled: false,
22//! # default_roles: vec![],
23//! # inheritance_rules: HashMap::new(),
24//! # resource_permissions: HashMap::new(),
25//! # },
26//! # };
27//! # let manager = AuthManager::new(config);
28//! # let credentials = AuthCredentials::ApiKey { key: "test".to_string() };
29//! // 1. Authenticate user and get auth context
30//! let auth_context = manager.authenticate("oauth2", credentials).await?;
31//!
32//! // 2. Extract token from auth context
33//! let token = auth_context.token.as_ref().unwrap().access_token.clone();
34//!
35//! // 3. On subsequent requests, validate token EVERY TIME
36//! let validated_context = manager.validate_token(&token, Some("oauth2")).await?;
37//! // ✅ Token validated via provider - truly stateless
38//! # Ok(())
39//! # }
40//! ```
41
42use std::collections::HashMap;
43use std::sync::Arc;
44
45use tokio::sync::RwLock;
46
47use super::config::AuthConfig;
48use super::context::AuthContext as UnifiedAuthContext; // Unified AuthContext for external API
49use super::types::{AuthCredentials, AuthProvider};
50use turbomcp_protocol::{Error as McpError, Result as McpResult};
51
/// Authentication manager for coordinating multiple authentication providers
///
/// Holds the [`AuthConfig`] it was constructed with and a registry of
/// providers. The registry sits behind an async `RwLock` so providers can be
/// added and removed through a shared reference.
///
/// # MCP Specification Compliance
///
/// This manager implements **stateless** authentication per the MCP
/// specification (2025-06-18). No server-side session state is maintained.
/// All authentication decisions are made by validating credentials on EVERY
/// request.
#[derive(Debug)]
pub struct AuthManager {
    /// Authentication configuration
    config: AuthConfig,
    /// Registered authentication providers, keyed by the provider's name
    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
}
66
67impl AuthManager {
68 /// Create a new authentication manager
69 ///
70 /// # MCP Specification Compliance
71 ///
72 /// Creates a stateless authentication manager per MCP spec.
73 /// No server-side session state is maintained.
74 #[must_use]
75 pub fn new(config: AuthConfig) -> Self {
76 Self {
77 config,
78 providers: Arc::new(RwLock::new(HashMap::new())),
79 }
80 }
81
82 /// Add an authentication provider
83 pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
84 let name = provider.name().to_string();
85 self.providers.write().await.insert(name, provider);
86 }
87
88 /// Remove an authentication provider
89 pub async fn remove_provider(&self, name: &str) -> bool {
90 self.providers.write().await.remove(name).is_some()
91 }
92
93 /// List available providers
94 pub async fn list_providers(&self) -> Vec<String> {
95 self.providers.read().await.keys().cloned().collect()
96 }
97
98 /// Authenticate user with credentials
99 ///
100 /// # MCP Specification Compliance
101 ///
102 /// Authenticates the user and returns an `AuthContext`.
103 /// **NO server-side session state is created** - per MCP stateless requirement.
104 ///
105 /// The returned `AuthContext` contains a token (if applicable) that the client
106 /// must include in subsequent requests via the `Authorization` header.
107 ///
108 /// # Example
109 ///
110 /// ```rust,no_run
111 /// # use turbomcp_auth::{AuthManager, AuthCredentials, config::{AuthConfig, AuthorizationConfig}};
112 /// # use std::collections::HashMap;
113 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
114 /// # let config = AuthConfig {
115 /// # enabled: true,
116 /// # providers: vec![],
117 /// # authorization: AuthorizationConfig {
118 /// # rbac_enabled: false,
119 /// # default_roles: vec![],
120 /// # inheritance_rules: HashMap::new(),
121 /// # resource_permissions: HashMap::new(),
122 /// # },
123 /// # };
124 /// # let manager = AuthManager::new(config);
125 /// let credentials = AuthCredentials::ApiKey {
126 /// key: "secret_key".to_string(),
127 /// };
128 ///
129 /// let auth_context = manager.authenticate("api", credentials).await?;
130 ///
131 /// // Extract token for subsequent requests
132 /// if let Some(token_info) = &auth_context.token {
133 /// let access_token = &token_info.access_token;
134 /// // Client must send: Authorization: Bearer {access_token}
135 /// }
136 /// # Ok(())
137 /// # }
138 /// ```
139 pub async fn authenticate(
140 &self,
141 provider_name: &str,
142 credentials: AuthCredentials,
143 ) -> McpResult<UnifiedAuthContext> {
144 if !self.config.enabled {
145 return Err(McpError::internal("Authentication is disabled".to_string()));
146 }
147
148 let providers = self.providers.read().await;
149 let provider = providers
150 .get(provider_name)
151 .ok_or_else(|| McpError::internal(format!("Provider '{provider_name}' not found")))?;
152
153 let mut auth_context = provider.authenticate(credentials).await?;
154
155 // Apply default roles if configured
156 if auth_context.roles.is_empty() {
157 auth_context.roles = self.config.authorization.default_roles.clone();
158 }
159
160 // MCP Spec: Stateless authentication - NO session storage
161 // Client must include token in Authorization header on every request
162 Ok(auth_context)
163 }
164
165 /// Validate token and get authentication context
166 ///
167 /// # MCP Specification Compliance
168 ///
169 /// Validates the token on EVERY request per MCP stateless requirement.
170 /// This method MUST be called for each incoming request to ensure the token
171 /// is still valid (not expired, not revoked, etc.).
172 ///
173 /// # Arguments
174 ///
175 /// * `token` - The access token to validate (from Authorization header)
176 /// * `provider_name` - Optional provider name (if known). If None, tries all providers.
177 ///
178 /// # Example
179 ///
180 /// ```rust,no_run
181 /// # use turbomcp_auth::AuthManager;
182 /// # async fn handle_request(manager: &AuthManager, auth_header: &str) -> Result<(), Box<dyn std::error::Error>> {
183 /// // Extract token from Authorization header
184 /// let token = auth_header.strip_prefix("Bearer ").unwrap();
185 ///
186 /// // Validate token on EVERY request (stateless)
187 /// let auth_context = manager.validate_token(token, None).await?;
188 ///
189 /// // Use auth_context for authorization decisions
190 /// println!("Authenticated user: {}", auth_context.user.username);
191 /// # Ok(())
192 /// # }
193 /// ```
194 pub async fn validate_token(
195 &self,
196 token: &str,
197 provider_name: Option<&str>,
198 ) -> McpResult<UnifiedAuthContext> {
199 if !self.config.enabled {
200 return Err(McpError::internal("Authentication is disabled".to_string()));
201 }
202
203 let providers = self.providers.read().await;
204
205 if let Some(provider_name) = provider_name {
206 let provider = providers.get(provider_name).ok_or_else(|| {
207 McpError::internal(format!("Provider '{provider_name}' not found"))
208 })?;
209 provider.validate_token(token).await
210 } else {
211 // Try all providers
212 for provider in providers.values() {
213 if let Ok(auth_context) = provider.validate_token(token).await {
214 return Ok(auth_context);
215 }
216 }
217 Err(McpError::internal("Token validation failed".to_string()))
218 }
219 }
220
221 /// Check if user has permission
222 #[must_use]
223 pub fn check_permission(&self, context: &UnifiedAuthContext, permission: &str) -> bool {
224 context.permissions.contains(&permission.to_string())
225 || context.roles.iter().any(|role| {
226 self.config
227 .authorization
228 .inheritance_rules
229 .get(role)
230 .is_some_and(|perms| perms.contains(&permission.to_string()))
231 })
232 }
233
234 /// Check if user has role
235 #[must_use]
236 pub fn check_role(&self, context: &UnifiedAuthContext, role: &str) -> bool {
237 context.roles.contains(&role.to_string())
238 }
239}
240
241// Note: PKCE functionality is handled by the oauth2 crate's built-in
242// PkceCodeChallenge::new_random_sha256() method for maximum security
243
244/// Global authentication manager
245static GLOBAL_AUTH_MANAGER: once_cell::sync::Lazy<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
246 once_cell::sync::Lazy::new(|| tokio::sync::RwLock::new(None));
247
248/// Set the global authentication manager
249pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
250 *GLOBAL_AUTH_MANAGER.write().await = Some(manager);
251}
252
253/// Get the global authentication manager
254pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
255 GLOBAL_AUTH_MANAGER.read().await.clone()
256}
257
258/// Convenience function to check authentication
259pub async fn check_auth(token: &str) -> McpResult<UnifiedAuthContext> {
260 if let Some(manager) = global_auth_manager().await {
261 manager.validate_token(token, None).await
262 } else {
263 Err(McpError::internal(
264 "Authentication manager not initialized".to_string(),
265 ))
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::{AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel},
        providers::ApiKeyProvider,
        types::UserInfo,
    };
    use std::collections::HashMap;

    /// An `OAuth2Config` literal retains the client id and flow type it was
    /// built with.
    #[test]
    fn test_oauth2_config() {
        let oauth_config = OAuth2Config {
            client_id: "test_client".to_string(),
            client_secret: "test_secret".to_string().into(),
            auth_url: "https://auth.example.com/oauth/authorize".to_string(),
            token_url: "https://auth.example.com/oauth/token".to_string(),
            revocation_url: None,
            redirect_uri: "http://localhost:8080/callback".to_string(),
            scopes: vec!["read".to_string(), "write".to_string()],
            flow_type: OAuth2FlowType::AuthorizationCode,
            additional_params: HashMap::new(),
            security_level: SecurityLevel::Standard,
            mcp_resource_uri: None,
            auto_resource_indicators: false,
            #[cfg(feature = "dpop")]
            dpop_config: None,
        };

        assert_eq!(oauth_config.client_id, "test_client");
        assert_eq!(oauth_config.flow_type, OAuth2FlowType::AuthorizationCode);
    }

    /// Two freshly generated PKCE challenges are non-empty and differ from
    /// each other.
    #[test]
    fn test_oauth2_pkce_integration() {
        // Exercise the oauth2 crate's PKCE generator directly
        let (first, _first_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
        let (second, _second_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();

        // Challenges must be unique per generation
        assert_ne!(first.as_str(), second.as_str());
        assert!(!first.as_str().is_empty());
        assert!(!second.as_str().is_empty());
    }

    /// A key registered on an `ApiKeyProvider` authenticates to the user it
    /// was registered for.
    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = ApiKeyProvider::new("test_api".to_string());

        let registered_user = UserInfo {
            id: "user123".to_string(),
            username: "testuser".to_string(),
            email: Some("test@example.com".to_string()),
            display_name: Some("Test User".to_string()),
            avatar_url: None,
            metadata: HashMap::new(),
        };

        provider
            .add_api_key("test_key_123".to_string(), registered_user.clone())
            .await;

        let outcome = provider
            .authenticate(AuthCredentials::ApiKey {
                key: "test_key_123".to_string(),
            })
            .await;
        assert!(outcome.is_ok());

        let context = outcome.unwrap();
        assert_eq!(context.user.username, "testuser");
        assert_eq!(context.provider, "test_api");
    }

    /// A provider added to the manager shows up in `list_providers`.
    #[tokio::test]
    async fn test_auth_manager() {
        let authorization = AuthorizationConfig {
            rbac_enabled: true,
            default_roles: vec!["user".to_string()],
            inheritance_rules: HashMap::new(),
            resource_permissions: HashMap::new(),
        };
        let manager = AuthManager::new(AuthConfig {
            enabled: true,
            providers: vec![],
            authorization,
        });

        manager
            .add_provider(Arc::new(ApiKeyProvider::new("api".to_string())))
            .await;

        let registered = manager.list_providers().await;
        assert!(registered.iter().any(|name| name == "api"));
    }
}
363}