turbomcp_auth/
manager.rs

1//! Authentication Manager
2//!
3//! Central authentication manager for coordinating multiple authentication providers.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime};
8
9use tokio::sync::RwLock;
10
11use super::config::AuthConfig;
12use super::types::{AuthContext, AuthCredentials, AuthProvider};
13use turbomcp_protocol::{Error as McpError, Result as McpResult};
14
/// Authentication manager for coordinating multiple authentication providers.
///
/// Holds the [`AuthConfig`], a registry of named [`AuthProvider`]s, and an in-memory map of
/// active sessions keyed by session ID.
#[derive(Debug)]
pub struct AuthManager {
    /// Authentication configuration (enable flag, default roles, role -> permission rules).
    config: AuthConfig,
    /// Registered authentication providers, keyed by the provider's `name()`.
    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
    /// Active sessions, keyed by `AuthContext::session_id`.
    sessions: Arc<RwLock<HashMap<String, AuthContext>>>,
    /// Handle to the background session-cleanup task, if one was spawned.
    ///
    /// The handle is only held, never awaited. Note that dropping a tokio `JoinHandle`
    /// detaches the task rather than aborting it.
    _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
27
28impl AuthManager {
29    /// Create a new authentication manager
30    #[must_use]
31    pub fn new(config: AuthConfig) -> Self {
32        let manager = Self {
33            config,
34            providers: Arc::new(RwLock::new(HashMap::new())),
35            sessions: Arc::new(RwLock::new(HashMap::new())),
36            _cleanup_handle: None,
37        };
38
39        // Start session cleanup task
40        let sessions_clone = manager.sessions.clone();
41        let cleanup_handle = tokio::spawn(async move {
42            let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5 minutes
43            loop {
44                interval.tick().await;
45                let now = SystemTime::now();
46                let mut sessions = sessions_clone.write().await;
47                sessions
48                    .retain(|_, context| context.expires_at.is_none_or(|expires| expires > now));
49            }
50        });
51
52        Self {
53            _cleanup_handle: Some(cleanup_handle),
54            ..manager
55        }
56    }
57
58    /// Add an authentication provider
59    pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
60        let name = provider.name().to_string();
61        self.providers.write().await.insert(name, provider);
62    }
63
64    /// Remove an authentication provider
65    pub async fn remove_provider(&self, name: &str) -> bool {
66        self.providers.write().await.remove(name).is_some()
67    }
68
69    /// List available providers
70    pub async fn list_providers(&self) -> Vec<String> {
71        self.providers.read().await.keys().cloned().collect()
72    }
73
74    /// Authenticate user with credentials
75    pub async fn authenticate(
76        &self,
77        provider_name: &str,
78        credentials: AuthCredentials,
79    ) -> McpResult<AuthContext> {
80        if !self.config.enabled {
81            return Err(McpError::internal("Authentication is disabled".to_string()));
82        }
83
84        let providers = self.providers.read().await;
85        let provider = providers
86            .get(provider_name)
87            .ok_or_else(|| McpError::internal(format!("Provider '{provider_name}' not found")))?;
88
89        let mut auth_context = provider.authenticate(credentials).await?;
90
91        // Apply default roles if configured
92        if auth_context.roles.is_empty() {
93            auth_context.roles = self.config.authorization.default_roles.clone();
94        }
95
96        // Store session
97        let session_id = auth_context.session_id.clone();
98        self.sessions
99            .write()
100            .await
101            .insert(session_id, auth_context.clone());
102
103        Ok(auth_context)
104    }
105
106    /// Validate token and get authentication context
107    pub async fn validate_token(
108        &self,
109        token: &str,
110        provider_name: Option<&str>,
111    ) -> McpResult<AuthContext> {
112        if !self.config.enabled {
113            return Err(McpError::internal("Authentication is disabled".to_string()));
114        }
115
116        let providers = self.providers.read().await;
117
118        if let Some(provider_name) = provider_name {
119            let provider = providers.get(provider_name).ok_or_else(|| {
120                McpError::internal(format!("Provider '{provider_name}' not found"))
121            })?;
122            provider.validate_token(token).await
123        } else {
124            // Try all providers
125            for provider in providers.values() {
126                if let Ok(context) = provider.validate_token(token).await {
127                    return Ok(context);
128                }
129            }
130            Err(McpError::internal("Token validation failed".to_string()))
131        }
132    }
133
134    /// Get session by ID
135    pub async fn get_session(&self, session_id: &str) -> Option<AuthContext> {
136        self.sessions.read().await.get(session_id).cloned()
137    }
138
139    /// Revoke session
140    pub async fn revoke_session(&self, session_id: &str) -> McpResult<()> {
141        let context = self
142            .sessions
143            .write()
144            .await
145            .remove(session_id)
146            .ok_or_else(|| McpError::internal("Session not found".to_string()))?;
147
148        // Try to revoke token with provider
149        let providers = self.providers.read().await;
150        if let Some(provider) = providers.get(&context.provider)
151            && let Some(token) = &context.token
152        {
153            let _ = provider.revoke_token(&token.access_token).await;
154        }
155
156        Ok(())
157    }
158
159    /// Check if user has permission
160    #[must_use]
161    pub fn check_permission(&self, context: &AuthContext, permission: &str) -> bool {
162        context.permissions.contains(&permission.to_string())
163            || context.roles.iter().any(|role| {
164                self.config
165                    .authorization
166                    .inheritance_rules
167                    .get(role)
168                    .is_some_and(|perms| perms.contains(&permission.to_string()))
169            })
170    }
171
172    /// Check if user has role
173    #[must_use]
174    pub fn check_role(&self, context: &AuthContext, role: &str) -> bool {
175        context.roles.contains(&role.to_string())
176    }
177}
178
// Note: PKCE is not implemented in this module. Use the `oauth2` crate's
// `PkceCodeChallenge::new_random_sha256()`, which returns a random challenge/verifier pair.
181
182/// Global authentication manager
183static GLOBAL_AUTH_MANAGER: once_cell::sync::Lazy<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
184    once_cell::sync::Lazy::new(|| tokio::sync::RwLock::new(None));
185
186/// Set the global authentication manager
187pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
188    *GLOBAL_AUTH_MANAGER.write().await = Some(manager);
189}
190
191/// Get the global authentication manager
192pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
193    GLOBAL_AUTH_MANAGER.read().await.clone()
194}
195
196/// Convenience function to check authentication
197pub async fn check_auth(token: &str) -> McpResult<AuthContext> {
198    if let Some(manager) = global_auth_manager().await {
199        manager.validate_token(token, None).await
200    } else {
201        Err(McpError::internal(
202            "Authentication manager not initialized".to_string(),
203        ))
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::{
            AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel, SessionConfig,
            SessionStorageType,
        },
        providers::ApiKeyProvider,
        types::UserInfo,
    };
    use std::collections::HashMap;

    /// A hand-built `OAuth2Config` keeps the values it was constructed with.
    #[test]
    fn test_oauth2_config() {
        let oauth = OAuth2Config {
            client_id: String::from("test_client"),
            client_secret: String::from("test_secret"),
            auth_url: String::from("https://auth.example.com/oauth/authorize"),
            token_url: String::from("https://auth.example.com/oauth/token"),
            redirect_uri: String::from("http://localhost:8080/callback"),
            scopes: vec![String::from("read"), String::from("write")],
            flow_type: OAuth2FlowType::AuthorizationCode,
            additional_params: HashMap::new(),
            security_level: SecurityLevel::Standard,
            mcp_resource_uri: None,
            auto_resource_indicators: false,
            #[cfg(feature = "dpop")]
            dpop_config: None,
        };

        assert_eq!(oauth.flow_type, OAuth2FlowType::AuthorizationCode);
        assert_eq!(oauth.client_id, "test_client");
    }

    /// PKCE comes from the `oauth2` crate; every call must yield a fresh, non-empty challenge.
    #[test]
    fn test_oauth2_pkce_integration() {
        let (first, _first_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
        let (second, _second_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();

        for challenge in [&first, &second] {
            assert!(!challenge.as_str().is_empty());
        }
        assert_ne!(first.as_str(), second.as_str());
    }

    /// An `ApiKeyProvider` authenticates a registered key and reports the owning user.
    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = ApiKeyProvider::new("test_api".to_string());
        provider
            .add_api_key(
                "test_key_123".to_string(),
                UserInfo {
                    id: "user123".to_string(),
                    username: "testuser".to_string(),
                    email: Some("test@example.com".to_string()),
                    display_name: Some("Test User".to_string()),
                    avatar_url: None,
                    metadata: HashMap::new(),
                },
            )
            .await;

        let context = provider
            .authenticate(AuthCredentials::ApiKey {
                key: "test_key_123".to_string(),
            })
            .await
            .expect("a registered API key should authenticate");

        assert_eq!(context.provider, "test_api");
        assert_eq!(context.user.username, "testuser");
    }

    /// Providers registered with an `AuthManager` are reported by `list_providers`.
    #[tokio::test]
    async fn test_auth_manager() {
        let manager = AuthManager::new(AuthConfig {
            enabled: true,
            providers: Vec::new(),
            session: SessionConfig {
                timeout_seconds: 3600,
                secure_cookies: true,
                cookie_domain: None,
                storage: SessionStorageType::Memory,
                max_sessions_per_user: Some(5),
            },
            authorization: AuthorizationConfig {
                rbac_enabled: true,
                default_roles: vec!["user".to_string()],
                inheritance_rules: HashMap::new(),
                resource_permissions: HashMap::new(),
            },
        });

        manager
            .add_provider(Arc::new(ApiKeyProvider::new("api".to_string())))
            .await;

        let names = manager.list_providers().await;
        assert!(names.iter().any(|name| name == "api"));
    }
}