turbomcp_auth/
manager.rs

1//! Authentication Manager
2//!
3//! Central authentication manager for coordinating multiple authentication providers.
4//!
5//! # MCP Compliance
6//!
7//! Per MCP specification (2025-06-18), authentication is **stateless**.
8//! Each request must include valid credentials (Bearer token in Authorization header).
9//! This manager does NOT maintain server-side session state for authentication decisions.
10//!
11//! ## Stateless Authentication Flow
12//!
13//! ```rust,no_run
14//! # use turbomcp_auth::{AuthManager, AuthCredentials, config::{AuthConfig, AuthorizationConfig}};
15//! # use std::collections::HashMap;
16//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
17//! # let config = AuthConfig {
18//! #     enabled: true,
19//! #     providers: vec![],
20//! #     authorization: AuthorizationConfig {
21//! #         rbac_enabled: false,
22//! #         default_roles: vec![],
23//! #         inheritance_rules: HashMap::new(),
24//! #         resource_permissions: HashMap::new(),
25//! #     },
26//! # };
27//! # let manager = AuthManager::new(config);
28//! # let credentials = AuthCredentials::ApiKey { key: "test".to_string() };
29//! // 1. Authenticate user and get auth context
30//! let auth_context = manager.authenticate("oauth2", credentials).await?;
31//!
32//! // 2. Extract token from auth context
33//! let token = auth_context.token.as_ref().unwrap().access_token.clone();
34//!
35//! // 3. On subsequent requests, validate token EVERY TIME
36//! let validated_context = manager.validate_token(&token, Some("oauth2")).await?;
37//! // ✅ Token validated via provider - truly stateless
38//! # Ok(())
39//! # }
40//! ```
41
42use std::collections::HashMap;
43use std::sync::Arc;
44
45use tokio::sync::RwLock;
46
47use super::config::AuthConfig;
48use super::context::AuthContext as UnifiedAuthContext; // Unified AuthContext for external API
49use super::types::{AuthCredentials, AuthProvider};
50use turbomcp_protocol::{Error as McpError, Result as McpResult};
51
52/// Authentication manager for coordinating multiple authentication providers
53///
54/// # MCP Specification Compliance
55///
56/// This manager implements **stateless** authentication per MCP spec (RFC 9728).
57/// No server-side session state is maintained. All authentication decisions are made
58/// by validating credentials on EVERY request.
59#[derive(Debug)]
60pub struct AuthManager {
61    /// Authentication configuration
62    config: AuthConfig,
63    /// Registered authentication providers
64    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
65}
66
67impl AuthManager {
68    /// Create a new authentication manager
69    ///
70    /// # MCP Specification Compliance
71    ///
72    /// Creates a stateless authentication manager per MCP spec.
73    /// No server-side session state is maintained.
74    #[must_use]
75    pub fn new(config: AuthConfig) -> Self {
76        Self {
77            config,
78            providers: Arc::new(RwLock::new(HashMap::new())),
79        }
80    }
81
82    /// Add an authentication provider
83    pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
84        let name = provider.name().to_string();
85        self.providers.write().await.insert(name, provider);
86    }
87
88    /// Remove an authentication provider
89    pub async fn remove_provider(&self, name: &str) -> bool {
90        self.providers.write().await.remove(name).is_some()
91    }
92
93    /// List available providers
94    pub async fn list_providers(&self) -> Vec<String> {
95        self.providers.read().await.keys().cloned().collect()
96    }
97
98    /// Authenticate user with credentials
99    ///
100    /// # MCP Specification Compliance
101    ///
102    /// Authenticates the user and returns an `AuthContext`.
103    /// **NO server-side session state is created** - per MCP stateless requirement.
104    ///
105    /// The returned `AuthContext` contains a token (if applicable) that the client
106    /// must include in subsequent requests via the `Authorization` header.
107    ///
108    /// # Example
109    ///
110    /// ```rust,no_run
111    /// # use turbomcp_auth::{AuthManager, AuthCredentials, config::{AuthConfig, AuthorizationConfig}};
112    /// # use std::collections::HashMap;
113    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
114    /// # let config = AuthConfig {
115    /// #     enabled: true,
116    /// #     providers: vec![],
117    /// #     authorization: AuthorizationConfig {
118    /// #         rbac_enabled: false,
119    /// #         default_roles: vec![],
120    /// #         inheritance_rules: HashMap::new(),
121    /// #         resource_permissions: HashMap::new(),
122    /// #     },
123    /// # };
124    /// # let manager = AuthManager::new(config);
125    /// let credentials = AuthCredentials::ApiKey {
126    ///     key: "secret_key".to_string(),
127    /// };
128    ///
129    /// let auth_context = manager.authenticate("api", credentials).await?;
130    ///
131    /// // Extract token for subsequent requests
132    /// if let Some(token_info) = &auth_context.token {
133    ///     let access_token = &token_info.access_token;
134    ///     // Client must send: Authorization: Bearer {access_token}
135    /// }
136    /// # Ok(())
137    /// # }
138    /// ```
139    pub async fn authenticate(
140        &self,
141        provider_name: &str,
142        credentials: AuthCredentials,
143    ) -> McpResult<UnifiedAuthContext> {
144        if !self.config.enabled {
145            return Err(McpError::internal("Authentication is disabled".to_string()));
146        }
147
148        let providers = self.providers.read().await;
149        let provider = providers
150            .get(provider_name)
151            .ok_or_else(|| McpError::internal(format!("Provider '{provider_name}' not found")))?;
152
153        let mut auth_context = provider.authenticate(credentials).await?;
154
155        // Apply default roles if configured
156        if auth_context.roles.is_empty() {
157            auth_context.roles = self.config.authorization.default_roles.clone();
158        }
159
160        // MCP Spec: Stateless authentication - NO session storage
161        // Client must include token in Authorization header on every request
162        Ok(auth_context)
163    }
164
165    /// Validate token and get authentication context
166    ///
167    /// # MCP Specification Compliance
168    ///
169    /// Validates the token on EVERY request per MCP stateless requirement.
170    /// This method MUST be called for each incoming request to ensure the token
171    /// is still valid (not expired, not revoked, etc.).
172    ///
173    /// # Arguments
174    ///
175    /// * `token` - The access token to validate (from Authorization header)
176    /// * `provider_name` - Optional provider name (if known). If None, tries all providers.
177    ///
178    /// # Example
179    ///
180    /// ```rust,no_run
181    /// # use turbomcp_auth::AuthManager;
182    /// # async fn handle_request(manager: &AuthManager, auth_header: &str) -> Result<(), Box<dyn std::error::Error>> {
183    /// // Extract token from Authorization header
184    /// let token = auth_header.strip_prefix("Bearer ").unwrap();
185    ///
186    /// // Validate token on EVERY request (stateless)
187    /// let auth_context = manager.validate_token(token, None).await?;
188    ///
189    /// // Use auth_context for authorization decisions
190    /// println!("Authenticated user: {}", auth_context.user.username);
191    /// # Ok(())
192    /// # }
193    /// ```
194    pub async fn validate_token(
195        &self,
196        token: &str,
197        provider_name: Option<&str>,
198    ) -> McpResult<UnifiedAuthContext> {
199        if !self.config.enabled {
200            return Err(McpError::internal("Authentication is disabled".to_string()));
201        }
202
203        let providers = self.providers.read().await;
204
205        if let Some(provider_name) = provider_name {
206            let provider = providers.get(provider_name).ok_or_else(|| {
207                McpError::internal(format!("Provider '{provider_name}' not found"))
208            })?;
209            provider.validate_token(token).await
210        } else {
211            // Try all providers
212            for provider in providers.values() {
213                if let Ok(auth_context) = provider.validate_token(token).await {
214                    return Ok(auth_context);
215                }
216            }
217            Err(McpError::internal("Token validation failed".to_string()))
218        }
219    }
220
221    /// Check if user has permission
222    #[must_use]
223    pub fn check_permission(&self, context: &UnifiedAuthContext, permission: &str) -> bool {
224        context.permissions.contains(&permission.to_string())
225            || context.roles.iter().any(|role| {
226                self.config
227                    .authorization
228                    .inheritance_rules
229                    .get(role)
230                    .is_some_and(|perms| perms.contains(&permission.to_string()))
231            })
232    }
233
234    /// Check if user has role
235    #[must_use]
236    pub fn check_role(&self, context: &UnifiedAuthContext, role: &str) -> bool {
237        context.roles.contains(&role.to_string())
238    }
239}
240
241// Note: PKCE functionality is handled by the oauth2 crate's built-in
242// PkceCodeChallenge::new_random_sha256() method for maximum security
243
244/// Global authentication manager
245static GLOBAL_AUTH_MANAGER: once_cell::sync::Lazy<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
246    once_cell::sync::Lazy::new(|| tokio::sync::RwLock::new(None));
247
248/// Set the global authentication manager
249pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
250    *GLOBAL_AUTH_MANAGER.write().await = Some(manager);
251}
252
253/// Get the global authentication manager
254pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
255    GLOBAL_AUTH_MANAGER.read().await.clone()
256}
257
258/// Convenience function to check authentication
259pub async fn check_auth(token: &str) -> McpResult<UnifiedAuthContext> {
260    if let Some(manager) = global_auth_manager().await {
261        manager.validate_token(token, None).await
262    } else {
263        Err(McpError::internal(
264            "Authentication manager not initialized".to_string(),
265        ))
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use super::*;
272    use crate::{
273        config::{AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel},
274        providers::ApiKeyProvider,
275        types::UserInfo,
276    };
277    use std::collections::HashMap;
278
279    #[test]
280    fn test_oauth2_config() {
281        let config = OAuth2Config {
282            client_id: "test_client".to_string(),
283            client_secret: "test_secret".to_string().into(),
284            auth_url: "https://auth.example.com/oauth/authorize".to_string(),
285            token_url: "https://auth.example.com/oauth/token".to_string(),
286            revocation_url: None,
287            redirect_uri: "http://localhost:8080/callback".to_string(),
288            scopes: vec!["read".to_string(), "write".to_string()],
289            flow_type: OAuth2FlowType::AuthorizationCode,
290            additional_params: HashMap::new(),
291            security_level: SecurityLevel::Standard,
292            mcp_resource_uri: None,
293            auto_resource_indicators: false,
294            #[cfg(feature = "dpop")]
295            dpop_config: None,
296        };
297
298        assert_eq!(config.client_id, "test_client");
299        assert_eq!(config.flow_type, OAuth2FlowType::AuthorizationCode);
300    }
301
302    #[test]
303    fn test_oauth2_pkce_integration() {
304        // Test that oauth2 crate PKCE functionality works as expected
305        let (challenge1, _verifier1) = oauth2::PkceCodeChallenge::new_random_sha256();
306        let (challenge2, _verifier2) = oauth2::PkceCodeChallenge::new_random_sha256();
307
308        // Each PKCE challenge should be unique
309        assert_ne!(challenge1.as_str(), challenge2.as_str());
310        assert!(!challenge1.as_str().is_empty());
311        assert!(!challenge2.as_str().is_empty());
312    }
313
314    #[tokio::test]
315    async fn test_api_key_provider() {
316        let provider = ApiKeyProvider::new("test_api".to_string());
317
318        let user_info = UserInfo {
319            id: "user123".to_string(),
320            username: "testuser".to_string(),
321            email: Some("test@example.com".to_string()),
322            display_name: Some("Test User".to_string()),
323            avatar_url: None,
324            metadata: HashMap::new(),
325        };
326
327        provider
328            .add_api_key("test_key_123".to_string(), user_info.clone())
329            .await;
330
331        let credentials = AuthCredentials::ApiKey {
332            key: "test_key_123".to_string(),
333        };
334
335        let auth_result = provider.authenticate(credentials).await;
336        assert!(auth_result.is_ok());
337
338        let context = auth_result.unwrap();
339        assert_eq!(context.user.username, "testuser");
340        assert_eq!(context.provider, "test_api");
341    }
342
343    #[tokio::test]
344    async fn test_auth_manager() {
345        let config = AuthConfig {
346            enabled: true,
347            providers: vec![],
348            authorization: AuthorizationConfig {
349                rbac_enabled: true,
350                default_roles: vec!["user".to_string()],
351                inheritance_rules: HashMap::new(),
352                resource_permissions: HashMap::new(),
353            },
354        };
355
356        let manager = AuthManager::new(config);
357        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
358        manager.add_provider(api_provider.clone()).await;
359
360        let providers = manager.list_providers().await;
361        assert!(providers.contains(&"api".to_string()));
362    }
363}