turbomcp_auth/providers/
oauth2.rs1use std::sync::Arc;
6use std::time::SystemTime;
7
8use async_trait::async_trait;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12use super::super::config::AuthProviderType;
13use super::super::context::AuthContext;
14use super::super::oauth2::OAuth2Client;
15use super::super::types::{AuthCredentials, AuthProvider, TokenInfo, UserInfo};
16use turbomcp_protocol::{Error as McpError, Result as McpResult};
17
18#[derive(Debug)]
20pub struct OAuth2Provider {
21 name: String,
23 client: Arc<OAuth2Client>,
25 #[allow(dead_code)]
27 resource_uri: String,
28 http_client: reqwest::Client,
30 token_cache: Arc<RwLock<std::collections::HashMap<String, CachedToken>>>,
32}
33
/// Token details held in `OAuth2Provider::token_cache`, stamped with when they were stored.
#[derive(Debug, Clone)]
struct CachedToken {
    /// The cached token details.
    token: TokenInfo,
    /// When the entry was stored; `validate_token` only uses entries younger than 300 seconds.
    cached_at: SystemTime,
}
42
43impl OAuth2Provider {
44 pub fn new(name: String, client: Arc<OAuth2Client>, resource_uri: String) -> Self {
57 Self {
58 name,
59 client,
60 resource_uri,
61 http_client: reqwest::Client::new(),
62 token_cache: Arc::new(RwLock::new(std::collections::HashMap::new())),
63 }
64 }
65
66 async fn fetch_user_info(&self, access_token: &str) -> McpResult<UserInfo> {
68 let provider_config = self.client.provider_config();
69 let userinfo_endpoint = provider_config.userinfo_endpoint.as_ref().ok_or_else(|| {
70 McpError::internal("Provider does not support userinfo endpoint".to_string())
71 })?;
72
73 let response = self
74 .http_client
75 .get(userinfo_endpoint)
76 .bearer_auth(access_token)
77 .send()
78 .await
79 .map_err(|e| McpError::internal(format!("Userinfo request failed: {e}")))?;
80
81 if !response.status().is_success() {
82 return Err(McpError::internal(format!(
83 "Userinfo endpoint returned status {}",
84 response.status()
85 )));
86 }
87
88 let user_data: serde_json::Value = response
89 .json()
90 .await
91 .map_err(|e| McpError::internal(format!("Failed to parse userinfo response: {e}")))?;
92
93 let user_id = user_data
95 .get("sub")
96 .or_else(|| user_data.get("id"))
97 .or_else(|| user_data.get("user_id"))
98 .and_then(|v| v.as_str())
99 .unwrap_or(&Uuid::new_v4().to_string())
100 .to_string();
101
102 let username = user_data
103 .get("name")
104 .or_else(|| user_data.get("login"))
105 .or_else(|| user_data.get("preferred_username"))
106 .and_then(|v| v.as_str())
107 .unwrap_or(&user_id)
108 .to_string();
109
110 let email = user_data
111 .get("email")
112 .and_then(|v| v.as_str())
113 .map(|s| s.to_string());
114
115 let display_name = user_data
116 .get("name")
117 .and_then(|v| v.as_str())
118 .map(|s| s.to_string());
119
120 let avatar_url = user_data
121 .get("picture")
122 .or_else(|| user_data.get("avatar_url"))
123 .and_then(|v| v.as_str())
124 .map(|s| s.to_string());
125
126 Ok(UserInfo {
127 id: user_id,
128 username,
129 email,
130 display_name,
131 avatar_url,
132 metadata: std::collections::HashMap::new(),
133 })
134 }
135}
136
137#[async_trait]
138impl AuthProvider for OAuth2Provider {
139 fn name(&self) -> &str {
140 &self.name
141 }
142
143 fn provider_type(&self) -> AuthProviderType {
144 AuthProviderType::OAuth2
145 }
146
147 async fn authenticate(&self, credentials: AuthCredentials) -> McpResult<AuthContext> {
148 match credentials {
149 AuthCredentials::OAuth2Code { code: _, state: _ } => {
150 Err(McpError::internal(
164 "OAuth2 authentication requires exchange_code_for_token() method. \
165 Use OAuth2Client.authorization_code_flow() and \
166 OAuth2Client.exchange_code_for_token() directly."
167 .to_string(),
168 ))
169 }
170 _ => Err(McpError::validation(
171 "OAuth2 provider only accepts OAuth2Code credentials".to_string(),
172 )),
173 }
174 }
175
176 async fn validate_token(&self, token: &str) -> McpResult<AuthContext> {
177 {
179 let cache = self.token_cache.read().await;
180 if let Some(cached) = cache.get(token) {
181 let elapsed = cached
182 .cached_at
183 .elapsed()
184 .unwrap_or(std::time::Duration::from_secs(0));
185 if elapsed < std::time::Duration::from_secs(300) {
187 let user_info = self.fetch_user_info(token).await?;
188 let request_id = Uuid::new_v4().to_string();
189 let mut builder = AuthContext::builder()
190 .subject(user_info.id.clone())
191 .user(user_info)
192 .roles(vec!["oauth_user".to_string()])
193 .permissions(vec!["api_access".to_string()])
194 .request_id(request_id)
195 .token(cached.token.clone())
196 .provider(self.name.clone())
197 .authenticated_at(SystemTime::now());
198
199 if let Some(secs) = cached.token.expires_in {
200 builder = builder
201 .expires_at(SystemTime::now() + std::time::Duration::from_secs(secs));
202 }
203
204 return builder
205 .build()
206 .map_err(|e| McpError::internal(e.to_string()));
207 }
208 }
209 }
210
211 let user_info = self.fetch_user_info(token).await?;
213 let request_id = Uuid::new_v4().to_string();
214
215 AuthContext::builder()
216 .subject(user_info.id.clone())
217 .user(user_info)
218 .roles(vec!["oauth_user".to_string()])
219 .permissions(vec!["api_access".to_string()])
220 .request_id(request_id)
221 .provider(self.name.clone())
222 .authenticated_at(SystemTime::now())
223 .build()
224 .map_err(|e| McpError::internal(e.to_string()))
225 }
226
227 async fn refresh_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
228 self.client.refresh_access_token(refresh_token).await
231 }
232
233 async fn revoke_token(&self, token: &str) -> McpResult<()> {
234 let cached_token = self.token_cache.write().await.remove(token);
236
237 if let Some(cached) = cached_token {
239 self.client.revoke_token(&cached.token).await?;
240 } else {
241 let token_info = TokenInfo {
243 access_token: token.to_string(),
244 token_type: "Bearer".to_string(),
245 refresh_token: None,
246 expires_in: None,
247 scope: None,
248 };
249 self.client.revoke_token(&token_info).await?;
250 }
251
252 Ok(())
253 }
254
255 async fn get_user_info(&self, token: &str) -> McpResult<UserInfo> {
256 self.fetch_user_info(token).await
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{OAuth2Config, ProviderType};

    /// Authorization-code configuration pointing at a fictional provider.
    fn sample_config() -> OAuth2Config {
        OAuth2Config {
            client_id: "test-client".to_string(),
            client_secret: "test-secret".to_string().into(),
            auth_url: "https://provider.example.com/oauth/authorize".to_string(),
            token_url: "https://provider.example.com/oauth/token".to_string(),
            revocation_url: Some("https://provider.example.com/oauth/revoke".to_string()),
            redirect_uri: "http://localhost:8080/callback".to_string(),
            scopes: vec!["openid".to_string(), "profile".to_string()],
            flow_type: crate::config::OAuth2FlowType::AuthorizationCode,
            additional_params: std::collections::HashMap::new(),
            security_level: Default::default(),
            #[cfg(feature = "dpop")]
            dpop_config: None,
            mcp_resource_uri: None,
            auto_resource_indicators: true,
        }
    }

    /// A newly built provider reports the name it was given and the OAuth2 type.
    #[test]
    fn test_oauth2_provider_creation() {
        let client = OAuth2Client::new(&sample_config(), ProviderType::Generic)
            .expect("Failed to create OAuth2Client");
        let shared_client = Arc::new(client);

        let provider = OAuth2Provider::new(
            "test".to_string(),
            shared_client,
            "https://mcp.example.com".to_string(),
        );

        assert_eq!(provider.name(), "test");
        assert_eq!(provider.provider_type(), AuthProviderType::OAuth2);
    }
}