turbomcp_auth/
manager.rs

//! Authentication Manager
//!
//! Central authentication manager for coordinating multiple authentication providers.
//!
//! # MCP Compliance
//!
//! Per the MCP authorization specification (2025-06-18), authentication is **stateless**:
//! every request must carry valid credentials (a Bearer token in the `Authorization` header).
//! This manager does NOT keep server-side session state for authentication decisions.
//!
//! ## Stateless Authentication Flow
//!
//! ```rust,no_run
//! # use turbomcp_auth::{AuthManager, AuthCredentials, config::AuthConfig};
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let config = AuthConfig {
//! #     enabled: true,
//! #     providers: vec![],
//! #     authorization: Default::default(),
//! # };
//! # let manager = AuthManager::new(config);
//! # let credentials = AuthCredentials::ApiKey { key: "test".to_string() };
//! // 1. Authenticate the user with a registered provider and get an auth context
//! let auth_context = manager.authenticate("api", credentials).await?;
//!
//! // 2. Extract the access token (not every provider issues one)
//! let token = auth_context
//!     .token
//!     .as_ref()
//!     .ok_or("provider did not issue a token")?
//!     .access_token
//!     .clone();
//!
//! // 3. On EVERY subsequent request, validate the token again
//! let validated_context = manager.validate_token(&token, Some("api")).await?;
//! // The provider validates the token on each call - no server-side session involved
//! # Ok(())
//! # }
//! ```
35
36use std::collections::HashMap;
37use std::sync::Arc;
38
39use tokio::sync::RwLock;
40
41use super::config::AuthConfig;
42use super::context::AuthContext as UnifiedAuthContext; // Unified AuthContext for external API
43use super::types::{AuthCredentials, AuthProvider};
44use turbomcp_protocol::{Error as McpError, Result as McpResult};
45
46/// Authentication manager for coordinating multiple authentication providers
47///
48/// # MCP Specification Compliance
49///
50/// This manager implements **stateless** authentication per MCP spec (RFC 9728).
51/// No server-side session state is maintained. All authentication decisions are made
52/// by validating credentials on EVERY request.
53#[derive(Debug)]
54pub struct AuthManager {
55    /// Authentication configuration
56    config: AuthConfig,
57    /// Registered authentication providers
58    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
59}
60
61impl AuthManager {
62    /// Create a new authentication manager
63    ///
64    /// # MCP Specification Compliance
65    ///
66    /// Creates a stateless authentication manager per MCP spec.
67    /// No server-side session state is maintained.
68    #[must_use]
69    pub fn new(config: AuthConfig) -> Self {
70        Self {
71            config,
72            providers: Arc::new(RwLock::new(HashMap::new())),
73        }
74    }
75
76    /// Add an authentication provider
77    pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
78        let name = provider.name().to_string();
79        self.providers.write().await.insert(name, provider);
80    }
81
82    /// Remove an authentication provider
83    pub async fn remove_provider(&self, name: &str) -> bool {
84        self.providers.write().await.remove(name).is_some()
85    }
86
87    /// List available providers
88    pub async fn list_providers(&self) -> Vec<String> {
89        self.providers.read().await.keys().cloned().collect()
90    }
91
92    /// Authenticate user with credentials
93    ///
94    /// # MCP Specification Compliance
95    ///
96    /// Authenticates the user and returns an `AuthContext`.
97    /// **NO server-side session state is created** - per MCP stateless requirement.
98    ///
99    /// The returned `AuthContext` contains a token (if applicable) that the client
100    /// must include in subsequent requests via the `Authorization` header.
101    ///
102    /// # Example
103    ///
104    /// ```rust,no_run
105    /// # use turbomcp_auth::{AuthManager, AuthCredentials, config::AuthConfig};
106    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
107    /// # let config = AuthConfig {
108    /// #     enabled: true,
109    /// #     providers: vec![],
110    /// #     authorization: Default::default(),
111    /// # };
112    /// # let manager = AuthManager::new(config);
113    /// let credentials = AuthCredentials::ApiKey {
114    ///     key: "secret_key".to_string(),
115    /// };
116    ///
117    /// let auth_context = manager.authenticate("api", credentials).await?;
118    ///
119    /// // Extract token for subsequent requests
120    /// if let Some(token_info) = &auth_context.token {
121    ///     let access_token = &token_info.access_token;
122    ///     // Client must send: Authorization: Bearer {access_token}
123    /// }
124    /// # Ok(())
125    /// # }
126    /// ```
127    pub async fn authenticate(
128        &self,
129        provider_name: &str,
130        credentials: AuthCredentials,
131    ) -> McpResult<UnifiedAuthContext> {
132        if !self.config.enabled {
133            return Err(McpError::internal("Authentication is disabled".to_string()));
134        }
135
136        let providers = self.providers.read().await;
137        let provider = providers
138            .get(provider_name)
139            .ok_or_else(|| McpError::internal(format!("Provider '{provider_name}' not found")))?;
140
141        let mut auth_context = provider.authenticate(credentials).await?;
142
143        // Apply default roles if configured
144        if auth_context.roles.is_empty() {
145            auth_context.roles = self.config.authorization.default_roles.clone();
146        }
147
148        // MCP Spec: Stateless authentication - NO session storage
149        // Client must include token in Authorization header on every request
150        Ok(auth_context)
151    }
152
153    /// Validate token and get authentication context
154    ///
155    /// # MCP Specification Compliance
156    ///
157    /// Validates the token on EVERY request per MCP stateless requirement.
158    /// This method MUST be called for each incoming request to ensure the token
159    /// is still valid (not expired, not revoked, etc.).
160    ///
161    /// # Arguments
162    ///
163    /// * `token` - The access token to validate (from Authorization header)
164    /// * `provider_name` - Optional provider name (if known). If None, tries all providers.
165    ///
166    /// # Example
167    ///
168    /// ```rust,no_run
169    /// # use turbomcp_auth::AuthManager;
170    /// # async fn handle_request(manager: &AuthManager, auth_header: &str) -> Result<(), Box<dyn std::error::Error>> {
171    /// // Extract token from Authorization header
172    /// let token = auth_header.strip_prefix("Bearer ").unwrap();
173    ///
174    /// // Validate token on EVERY request (stateless)
175    /// let auth_context = manager.validate_token(token, None).await?;
176    ///
177    /// // Use auth_context for authorization decisions
178    /// println!("Authenticated user: {}", auth_context.user.username);
179    /// # Ok(())
180    /// # }
181    /// ```
182    pub async fn validate_token(
183        &self,
184        token: &str,
185        provider_name: Option<&str>,
186    ) -> McpResult<UnifiedAuthContext> {
187        if !self.config.enabled {
188            return Err(McpError::internal("Authentication is disabled".to_string()));
189        }
190
191        let providers = self.providers.read().await;
192
193        if let Some(provider_name) = provider_name {
194            let provider = providers.get(provider_name).ok_or_else(|| {
195                McpError::internal(format!("Provider '{provider_name}' not found"))
196            })?;
197            provider.validate_token(token).await
198        } else {
199            // Try all providers
200            for provider in providers.values() {
201                if let Ok(auth_context) = provider.validate_token(token).await {
202                    return Ok(auth_context);
203                }
204            }
205            Err(McpError::internal("Token validation failed".to_string()))
206        }
207    }
208
209    /// Check if user has permission
210    #[must_use]
211    pub fn check_permission(&self, context: &UnifiedAuthContext, permission: &str) -> bool {
212        context.permissions.contains(&permission.to_string())
213            || context.roles.iter().any(|role| {
214                self.config
215                    .authorization
216                    .inheritance_rules
217                    .get(role)
218                    .is_some_and(|perms| perms.contains(&permission.to_string()))
219            })
220    }
221
222    /// Check if user has role
223    #[must_use]
224    pub fn check_role(&self, context: &UnifiedAuthContext, role: &str) -> bool {
225        context.roles.contains(&role.to_string())
226    }
227}
228
// Note: PKCE is provided by the oauth2 crate's built-in
// `PkceCodeChallenge::new_random_sha256()` (S256 challenge method).
231
232/// Global authentication manager
233static GLOBAL_AUTH_MANAGER: once_cell::sync::Lazy<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
234    once_cell::sync::Lazy::new(|| tokio::sync::RwLock::new(None));
235
236/// Set the global authentication manager
237pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
238    *GLOBAL_AUTH_MANAGER.write().await = Some(manager);
239}
240
241/// Get the global authentication manager
242pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
243    GLOBAL_AUTH_MANAGER.read().await.clone()
244}
245
246/// Convenience function to check authentication
247pub async fn check_auth(token: &str) -> McpResult<UnifiedAuthContext> {
248    if let Some(manager) = global_auth_manager().await {
249        manager.validate_token(token, None).await
250    } else {
251        Err(McpError::internal(
252            "Authentication manager not initialized".to_string(),
253        ))
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::{AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel},
        providers::ApiKeyProvider,
        types::UserInfo,
    };
    use std::collections::HashMap;

    /// Manager configuration with the given enable switch and a single `"user"`
    /// default role.
    fn test_config(enabled: bool) -> AuthConfig {
        AuthConfig {
            enabled,
            providers: vec![],
            authorization: AuthorizationConfig {
                rbac_enabled: true,
                default_roles: vec!["user".to_string()],
                inheritance_rules: HashMap::new(),
                resource_permissions: HashMap::new(),
            },
        }
    }

    /// Fixed user record shared by the API-key tests.
    fn test_user() -> UserInfo {
        UserInfo {
            id: "user123".to_string(),
            username: "testuser".to_string(),
            email: Some("test@example.com".to_string()),
            display_name: Some("Test User".to_string()),
            avatar_url: None,
            metadata: HashMap::new(),
        }
    }

    /// API-key credentials for `key`.
    fn api_key(key: &str) -> AuthCredentials {
        AuthCredentials::ApiKey {
            key: key.to_string(),
        }
    }

    /// A provider named `name` that accepts `key` for [`test_user`].
    async fn api_provider_with_key(name: &str, key: &str) -> Arc<ApiKeyProvider> {
        let provider = Arc::new(ApiKeyProvider::new(name.to_string()));
        provider.add_api_key(key.to_string(), test_user()).await;
        provider
    }

    #[test]
    fn test_oauth2_config() {
        let config = OAuth2Config {
            client_id: "test_client".to_string(),
            client_secret: "test_secret".to_string().into(),
            auth_url: "https://auth.example.com/oauth/authorize".to_string(),
            token_url: "https://auth.example.com/oauth/token".to_string(),
            revocation_url: None,
            redirect_uri: "http://localhost:8080/callback".to_string(),
            scopes: vec!["read".to_string(), "write".to_string()],
            flow_type: OAuth2FlowType::AuthorizationCode,
            additional_params: HashMap::new(),
            security_level: SecurityLevel::Standard,
            mcp_resource_uri: None,
            auto_resource_indicators: false,
            #[cfg(feature = "dpop")]
            dpop_config: None,
        };

        assert_eq!(config.client_id, "test_client");
        assert_eq!(config.flow_type, OAuth2FlowType::AuthorizationCode);
    }

    #[test]
    fn test_oauth2_pkce_integration() {
        // Test that oauth2 crate PKCE functionality works as expected
        let (challenge1, _verifier1) = oauth2::PkceCodeChallenge::new_random_sha256();
        let (challenge2, _verifier2) = oauth2::PkceCodeChallenge::new_random_sha256();

        // Each PKCE challenge should be unique
        assert_ne!(challenge1.as_str(), challenge2.as_str());
        assert!(!challenge1.as_str().is_empty());
        assert!(!challenge2.as_str().is_empty());
    }

    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = api_provider_with_key("test_api", "test_key_123").await;

        let context = provider
            .authenticate(api_key("test_key_123"))
            .await
            .expect("registered API key should authenticate");

        assert_eq!(context.user.username, "testuser");
        assert_eq!(context.provider, "test_api");
    }

    #[tokio::test]
    async fn test_auth_manager() {
        let manager = AuthManager::new(test_config(true));
        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        manager.add_provider(api_provider.clone()).await;

        let providers = manager.list_providers().await;
        assert!(providers.contains(&"api".to_string()));
    }

    #[tokio::test]
    async fn test_remove_provider() {
        let manager = AuthManager::new(test_config(true));
        manager
            .add_provider(Arc::new(ApiKeyProvider::new("api".to_string())))
            .await;

        assert!(manager.remove_provider("api").await);
        assert!(!manager.remove_provider("api").await);
        assert!(manager.list_providers().await.is_empty());
    }

    #[tokio::test]
    async fn test_disabled_manager_rejects_requests() {
        let manager = AuthManager::new(test_config(false));

        assert!(manager.authenticate("api", api_key("any")).await.is_err());
        assert!(manager.validate_token("any", None).await.is_err());
        assert!(manager.validate_token("any", Some("api")).await.is_err());
    }

    #[tokio::test]
    async fn test_unknown_provider_is_rejected() {
        let manager = AuthManager::new(test_config(true));

        assert!(manager.authenticate("missing", api_key("any")).await.is_err());
        assert!(manager.validate_token("any", Some("missing")).await.is_err());
        // With no providers registered, the try-all path has nothing to accept the token.
        assert!(manager.validate_token("any", None).await.is_err());
    }

    #[tokio::test]
    async fn test_manager_authenticate_and_authorization_checks() {
        let mut config = test_config(true);
        config.authorization.inheritance_rules = HashMap::from([(
            "editor".to_string(),
            vec!["docs:write".to_string()].into_iter().collect(),
        )]);
        let manager = AuthManager::new(config);
        manager
            .add_provider(api_provider_with_key("api", "secret").await)
            .await;

        let mut context = manager
            .authenticate("api", api_key("secret"))
            .await
            .expect("registered API key should authenticate through the manager");
        assert_eq!(context.user.username, "testuser");
        assert_eq!(context.provider, "api");

        // Pin the roles/permissions so the checks exercise only the inheritance rules.
        context.roles = vec!["editor".to_string()];
        context.permissions.clear();

        assert!(manager.check_role(&context, "editor"));
        assert!(!manager.check_role(&context, "admin"));
        assert!(manager.check_permission(&context, "docs:write"));
        assert!(!manager.check_permission(&context, "docs:delete"));
    }
}