turbomcp_auth/providers/oauth2.rs

//! OAuth 2.1 Authentication Provider
//!
//! Implements the AuthProvider trait for OAuth 2.1 authorization flows.

5use std::sync::Arc;
6use std::time::SystemTime;
7
8use async_trait::async_trait;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12use super::super::config::AuthProviderType;
13use super::super::context::AuthContext;
14use super::super::oauth2::OAuth2Client;
15use super::super::types::{AuthCredentials, AuthProvider, TokenInfo, UserInfo};
16use turbomcp_protocol::{Error as McpError, Result as McpResult};
17
/// OAuth 2.1 authentication provider
///
/// Validates bearer tokens by calling the provider's userinfo endpoint through `http_client`,
/// and delegates token refresh and revocation to the wrapped [`OAuth2Client`].
#[derive(Debug)]
pub struct OAuth2Provider {
    /// Provider name (returned by `AuthProvider::name`)
    name: String,
    /// OAuth2 client used for provider configuration, token refresh, and token revocation
    client: Arc<OAuth2Client>,
    /// MCP server canonical URI (RFC 8707) - required for token binding
    // NOTE(review): stored but never read in this file (hence `allow(dead_code)`); the RFC 8707
    // resource parameter appears to be applied inside `OAuth2Client` instead — confirm before
    // relying on this field.
    #[allow(dead_code)]
    resource_uri: String,
    /// HTTP client for userinfo endpoint
    http_client: reqwest::Client,
    /// Cached token metadata keyed by the access-token string.
    ///
    /// `validate_token` attaches the `TokenInfo` of a sufficiently fresh entry to the context it
    /// builds (it still calls the userinfo endpoint), and `revoke_token` removes entries.
    /// NOTE(review): nothing in this file inserts entries — confirm where the cache is populated.
    token_cache: Arc<RwLock<std::collections::HashMap<String, CachedToken>>>,
}
33
/// Cached token with metadata
///
/// Stored in `OAuth2Provider::token_cache`, keyed by the access-token string.
#[derive(Debug, Clone)]
struct CachedToken {
    /// The token info; also used as the revocation payload when the entry is revoked
    token: TokenInfo,
    /// Wall-clock time the entry was cached; the entry's age is measured from here
    cached_at: SystemTime,
}
42
43impl OAuth2Provider {
44    /// Create a new OAuth2 provider with MCP server resource URI
45    ///
46    /// # Arguments
47    ///
48    /// * `name` - Provider name for identification
49    /// * `client` - OAuth2 client configured for the provider
50    /// * `resource_uri` - **MCP server canonical URI** (RFC 8707) - e.g., "https://mcp.example.com"
51    ///
52    /// # MCP Requirement
53    ///
54    /// The resource URI binds all tokens to the specific MCP server, preventing
55    /// token misuse across service boundaries per RFC 8707.
56    pub fn new(name: String, client: Arc<OAuth2Client>, resource_uri: String) -> Self {
57        Self {
58            name,
59            client,
60            resource_uri,
61            http_client: reqwest::Client::new(),
62            token_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
63        }
64    }
65
66    /// Get user info from the OAuth provider's userinfo endpoint
67    async fn fetch_user_info(&self, access_token: &str) -> McpResult<UserInfo> {
68        let provider_config = self.client.provider_config();
69        let userinfo_endpoint = provider_config.userinfo_endpoint.as_ref().ok_or_else(|| {
70            McpError::internal("Provider does not support userinfo endpoint".to_string())
71        })?;
72
73        let response = self
74            .http_client
75            .get(userinfo_endpoint)
76            .bearer_auth(access_token)
77            .send()
78            .await
79            .map_err(|e| McpError::internal(format!("Userinfo request failed: {e}")))?;
80
81        if !response.status().is_success() {
82            return Err(McpError::internal(format!(
83                "Userinfo endpoint returned status {}",
84                response.status()
85            )));
86        }
87
88        let user_data: serde_json::Value = response
89            .json()
90            .await
91            .map_err(|e| McpError::internal(format!("Failed to parse userinfo response: {e}")))?;
92
93        // Extract user information from response (varies by provider)
94        let user_id = user_data
95            .get("sub")
96            .or_else(|| user_data.get("id"))
97            .or_else(|| user_data.get("user_id"))
98            .and_then(|v| v.as_str())
99            .unwrap_or(&Uuid::new_v4().to_string())
100            .to_string();
101
102        let username = user_data
103            .get("name")
104            .or_else(|| user_data.get("login"))
105            .or_else(|| user_data.get("preferred_username"))
106            .and_then(|v| v.as_str())
107            .unwrap_or(&user_id)
108            .to_string();
109
110        let email = user_data
111            .get("email")
112            .and_then(|v| v.as_str())
113            .map(|s| s.to_string());
114
115        let display_name = user_data
116            .get("name")
117            .and_then(|v| v.as_str())
118            .map(|s| s.to_string());
119
120        let avatar_url = user_data
121            .get("picture")
122            .or_else(|| user_data.get("avatar_url"))
123            .and_then(|v| v.as_str())
124            .map(|s| s.to_string());
125
126        Ok(UserInfo {
127            id: user_id,
128            username,
129            email,
130            display_name,
131            avatar_url,
132            metadata: std::collections::HashMap::new(),
133        })
134    }
135}
136
137#[async_trait]
138impl AuthProvider for OAuth2Provider {
139    fn name(&self) -> &str {
140        &self.name
141    }
142
143    fn provider_type(&self) -> AuthProviderType {
144        AuthProviderType::OAuth2
145    }
146
147    async fn authenticate(&self, credentials: AuthCredentials) -> McpResult<AuthContext> {
148        match credentials {
149            AuthCredentials::OAuth2Code { code: _, state: _ } => {
150                // In a real implementation, we'd validate state parameter
151                // For now, we need the PKCE code verifier which should be stored
152                // This is a simplified implementation - in practice, code_verifier
153                // would come from session storage based on state parameter
154
155                // Exchange code for token using empty verifier (in real implementation,
156                // this would come from stored session state)
157                // For now, return an error - the flow should be:
158                // 1. Client calls authorization_code_flow() -> gets code_verifier
159                // 2. User redirects with code
160                // 3. Client calls exchange_code_for_token() with code_verifier
161                // 4. Provider stores token and creates AuthContext
162
163                Err(McpError::internal(
164                    "OAuth2 authentication requires exchange_code_for_token() method. \
165                     Use OAuth2Client.authorization_code_flow() and \
166                     OAuth2Client.exchange_code_for_token() directly."
167                        .to_string(),
168                ))
169            }
170            _ => Err(McpError::validation(
171                "OAuth2 provider only accepts OAuth2Code credentials".to_string(),
172            )),
173        }
174    }
175
176    async fn validate_token(&self, token: &str) -> McpResult<AuthContext> {
177        // Check cache first
178        {
179            let cache = self.token_cache.read().await;
180            if let Some(cached) = cache.get(token) {
181                let elapsed = cached
182                    .cached_at
183                    .elapsed()
184                    .unwrap_or(std::time::Duration::from_secs(0));
185                // Cache for 5 minutes
186                if elapsed < std::time::Duration::from_secs(300) {
187                    let user_info = self.fetch_user_info(token).await?;
188                    let request_id = Uuid::new_v4().to_string();
189                    let mut builder = AuthContext::builder()
190                        .subject(user_info.id.clone())
191                        .user(user_info)
192                        .roles(vec!["oauth_user".to_string()])
193                        .permissions(vec!["api_access".to_string()])
194                        .request_id(request_id)
195                        .token(cached.token.clone())
196                        .provider(self.name.clone())
197                        .authenticated_at(SystemTime::now());
198
199                    if let Some(secs) = cached.token.expires_in {
200                        builder = builder
201                            .expires_at(SystemTime::now() + std::time::Duration::from_secs(secs));
202                    }
203
204                    return builder
205                        .build()
206                        .map_err(|e| McpError::internal(e.to_string()));
207                }
208            }
209        }
210
211        // Token not in cache or cache expired - fetch user info to validate
212        let user_info = self.fetch_user_info(token).await?;
213        let request_id = Uuid::new_v4().to_string();
214
215        AuthContext::builder()
216            .subject(user_info.id.clone())
217            .user(user_info)
218            .roles(vec!["oauth_user".to_string()])
219            .permissions(vec!["api_access".to_string()])
220            .request_id(request_id)
221            .provider(self.name.clone())
222            .authenticated_at(SystemTime::now())
223            .build()
224            .map_err(|e| McpError::internal(e.to_string()))
225    }
226
227    async fn refresh_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
228        // Refresh token using the OAuth2 client
229        // Note: RFC 8707 resource parameter is handled in OAuth2Client::refresh_access_token
230        self.client.refresh_access_token(refresh_token).await
231    }
232
233    async fn revoke_token(&self, token: &str) -> McpResult<()> {
234        // Remove from cache first
235        let cached_token = self.token_cache.write().await.remove(token);
236
237        // If we have the full token info, revoke it at the provider (RFC 7009)
238        if let Some(cached) = cached_token {
239            self.client.revoke_token(&cached.token).await?;
240        } else {
241            // If not in cache, create a minimal TokenInfo for revocation
242            let token_info = TokenInfo {
243                access_token: token.to_string(),
244                token_type: "Bearer".to_string(),
245                refresh_token: None,
246                expires_in: None,
247                scope: None,
248            };
249            self.client.revoke_token(&token_info).await?;
250        }
251
252        Ok(())
253    }
254
255    async fn get_user_info(&self, token: &str) -> McpResult<UserInfo> {
256        self.fetch_user_info(token).await
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{OAuth2Config, ProviderType};

    /// Minimal authorization-code configuration pointing at a placeholder provider.
    fn generic_config() -> OAuth2Config {
        OAuth2Config {
            client_id: String::from("test-client"),
            client_secret: String::from("test-secret").into(),
            auth_url: String::from("https://provider.example.com/oauth/authorize"),
            token_url: String::from("https://provider.example.com/oauth/token"),
            revocation_url: Some(String::from("https://provider.example.com/oauth/revoke")),
            redirect_uri: String::from("http://localhost:8080/callback"),
            scopes: ["openid", "profile"].iter().map(|s| s.to_string()).collect(),
            flow_type: crate::config::OAuth2FlowType::AuthorizationCode,
            additional_params: std::collections::HashMap::new(),
            security_level: Default::default(),
            #[cfg(feature = "dpop")]
            dpop_config: None,
            mcp_resource_uri: None,
            auto_resource_indicators: true,
        }
    }

    #[test]
    fn test_oauth2_provider_creation() {
        let client = OAuth2Client::new(&generic_config(), ProviderType::Generic)
            .expect("Failed to create OAuth2Client");
        // Canonical MCP server URI the provider's tokens are bound to (RFC 8707).
        let mcp_resource = String::from("https://mcp.example.com");
        let provider = OAuth2Provider::new(String::from("test"), Arc::new(client), mcp_resource);

        assert_eq!(provider.name(), "test");
        assert_eq!(provider.provider_type(), AuthProviderType::OAuth2);
    }
}