turbomcp_auth/oauth2/client.rs
1//! OAuth 2.1 Client Implementation
2//!
3//! This module provides an OAuth 2.1 client wrapper that supports:
4//! - Authorization Code flow (with PKCE)
5//! - Client Credentials flow (server-to-server)
6//! - Device Authorization flow (CLI/IoT)
7//!
8//! The client handles provider-specific configurations and quirks for
9//! Google, Microsoft, GitHub, GitLab, and generic OAuth providers.
10
11use std::collections::HashMap;
12
13use oauth2::{
14 AuthUrl, ClientId, ClientSecret, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
15 RefreshToken, RevocationUrl, Scope, TokenResponse, TokenUrl,
16 basic::{BasicClient, BasicTokenType},
17 revocation::StandardRevocableToken,
18};
19use secrecy::ExposeSecret;
20
21use turbomcp_protocol::{Error as McpError, Result as McpResult};
22
23use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
24use super::super::types::TokenInfo;
25
26/// OAuth 2.1 client wrapper supporting all modern flows
27#[derive(Debug, Clone)]
28pub struct OAuth2Client {
29 /// Authorization code flow client (most common)
30 pub(crate) auth_code_client: BasicClient,
31 /// Client credentials client (server-to-server)
32 pub(crate) client_credentials_client: Option<BasicClient>,
33 /// Device code client (for CLI/IoT applications)
34 pub(crate) device_code_client: Option<BasicClient>,
35 /// Provider-specific configuration
36 pub provider_config: ProviderConfig,
37}
38
39impl OAuth2Client {
40 /// Create an OAuth 2.1 client supporting all flows
41 pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
42 // Validate URLs
43 let auth_url = AuthUrl::new(config.auth_url.clone())
44 .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
45
46 let token_url = TokenUrl::new(config.token_url.clone())
47 .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
48
49 // Redirect URI validation with security checks
50 let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
51
52 // Create authorization code flow client (primary)
53 let client_secret = if config.client_secret.expose_secret().is_empty() {
54 None
55 } else {
56 Some(ClientSecret::new(
57 config.client_secret.expose_secret().clone(),
58 ))
59 };
60
61 let mut auth_code_client = BasicClient::new(
62 ClientId::new(config.client_id.clone()),
63 client_secret.clone(),
64 auth_url.clone(),
65 Some(token_url.clone()),
66 )
67 .set_redirect_uri(redirect_url);
68
69 // Set revocation endpoint if provided (RFC 7009)
70 if let Some(ref revocation_url_str) = config.revocation_url {
71 let revocation_url = RevocationUrl::new(revocation_url_str.clone())
72 .map_err(|_| McpError::validation("Invalid revocation URL".to_string()))?;
73 auth_code_client = auth_code_client.set_revocation_uri(revocation_url);
74 }
75
76 // Create client credentials client if we have a secret (server-to-server)
77 let client_credentials_client = if client_secret.is_some() {
78 Some(BasicClient::new(
79 ClientId::new(config.client_id.clone()),
80 client_secret.clone(),
81 auth_url.clone(),
82 Some(token_url.clone()),
83 ))
84 } else {
85 None
86 };
87
88 // Device code client (for CLI/IoT apps) - uses same configuration
89 let device_code_client = Some(BasicClient::new(
90 ClientId::new(config.client_id.clone()),
91 client_secret,
92 auth_url,
93 Some(token_url),
94 ));
95
96 // Provider-specific configuration
97 let provider_config = Self::build_provider_config(provider_type);
98
99 Ok(Self {
100 auth_code_client,
101 client_credentials_client,
102 device_code_client,
103 provider_config,
104 })
105 }
106
107 /// Build provider-specific configuration
108 fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
109 match provider_type {
110 ProviderType::Google => ProviderConfig {
111 provider_type,
112 default_scopes: vec![
113 "openid".to_string(),
114 "email".to_string(),
115 "profile".to_string(),
116 ],
117 refresh_behavior: RefreshBehavior::Proactive,
118 userinfo_endpoint: Some(
119 "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
120 ),
121 additional_params: HashMap::new(),
122 },
123 ProviderType::Microsoft => ProviderConfig {
124 provider_type,
125 default_scopes: vec![
126 "openid".to_string(),
127 "profile".to_string(),
128 "email".to_string(),
129 "User.Read".to_string(),
130 ],
131 refresh_behavior: RefreshBehavior::Proactive,
132 userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
133 additional_params: HashMap::new(),
134 },
135 ProviderType::GitHub => ProviderConfig {
136 provider_type,
137 default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
138 refresh_behavior: RefreshBehavior::Reactive,
139 userinfo_endpoint: Some("https://api.github.com/user".to_string()),
140 additional_params: HashMap::new(),
141 },
142 ProviderType::GitLab => ProviderConfig {
143 provider_type,
144 default_scopes: vec!["read_user".to_string(), "openid".to_string()],
145 refresh_behavior: RefreshBehavior::Proactive,
146 userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
147 additional_params: HashMap::new(),
148 },
149 ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
150 provider_type,
151 default_scopes: vec!["openid".to_string(), "profile".to_string()],
152 refresh_behavior: RefreshBehavior::Proactive,
153 userinfo_endpoint: None,
154 additional_params: HashMap::new(),
155 },
156 }
157 }
158
159 /// Redirect URI validation with security checks
160 ///
161 /// Security considerations:
162 /// - Prevents open redirect attacks
163 /// - Validates URL format and structure
164 /// - Environment-aware validation (localhost for development)
165 fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
166 use url::Url;
167
168 // Parse and validate URL structure
169 let parsed = Url::parse(uri)
170 .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
171
172 // Security: Validate scheme
173 match parsed.scheme() {
174 "http" => {
175 // Only allow http for localhost/127.0.0.1/0.0.0.0 in development
176 if let Some(host) = parsed.host_str() {
177 // Allow localhost, 127.0.0.1, 0.0.0.0 (bind all interfaces)
178 let is_localhost = host == "localhost"
179 || host.starts_with("localhost:")
180 || host == "127.0.0.1"
181 || host.starts_with("127.0.0.1:")
182 || host == "0.0.0.0"
183 || host.starts_with("0.0.0.0:");
184
185 if !is_localhost {
186 return Err(McpError::validation(
187 "HTTP redirect URIs only allowed for localhost in development"
188 .to_string(),
189 ));
190 }
191 } else {
192 return Err(McpError::validation(
193 "Redirect URI must have a valid host".to_string(),
194 ));
195 }
196 }
197 "https" => {
198 // HTTPS is always allowed
199 }
200 "com.example.app" | "msauth" => {
201 // Allow custom schemes for mobile apps (common patterns)
202 }
203 scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
204 // Allow app-specific custom schemes
205 }
206 _ => {
207 return Err(McpError::validation(format!(
208 "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
209 parsed.scheme()
210 )));
211 }
212 }
213
214 // Security: Prevent fragment in redirect URI (per OAuth 2.0 spec)
215 if parsed.fragment().is_some() {
216 return Err(McpError::validation(
217 "Redirect URI must not contain URL fragment".to_string(),
218 ));
219 }
220
221 // Security: Check for path traversal in PATH component only
222 // Note: url::Url::parse() already normalizes paths and removes .. segments
223 // We check the final path to ensure no traversal remains after normalization
224 if let Some(path) = parsed.path_segments() {
225 for segment in path {
226 if segment == ".." {
227 return Err(McpError::validation(
228 "Redirect URI path must not contain traversal sequences".to_string(),
229 ));
230 }
231 }
232 }
233
234 // Use oauth2 crate's RedirectUrl for validation
235 // This provides URL validation per OAuth 2.1 specifications
236 // For production security, implement exact whitelist matching of allowed URIs
237 RedirectUrl::new(uri.to_string())
238 .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
239 }
240
241 /// Get access to the authorization code client
242 #[must_use]
243 pub fn auth_code_client(&self) -> &BasicClient {
244 &self.auth_code_client
245 }
246
247 /// Get access to the client credentials client (if available)
248 #[must_use]
249 pub fn client_credentials_client(&self) -> Option<&BasicClient> {
250 self.client_credentials_client.as_ref()
251 }
252
253 /// Get access to the device code client (if available)
254 #[must_use]
255 pub fn device_code_client(&self) -> Option<&BasicClient> {
256 self.device_code_client.as_ref()
257 }
258
259 /// Get the provider configuration
260 #[must_use]
261 pub fn provider_config(&self) -> &ProviderConfig {
262 &self.provider_config
263 }
264
265 /// Start authorization code flow with PKCE
266 ///
267 /// This initiates the OAuth 2.1 authorization code flow with PKCE (RFC 7636)
268 /// for enhanced security, especially for public clients.
269 ///
270 /// # PKCE Code Verifier Storage (CRITICAL SECURITY REQUIREMENT)
271 ///
272 /// The returned code_verifier MUST be securely stored and associated with the
273 /// state parameter until the authorization code is exchanged for tokens.
274 ///
275 /// **Storage Options (from most to least secure):**
276 ///
277 /// 1. **Server-side encrypted session** (RECOMMENDED for web apps)
278 /// - Store in server session with HttpOnly, Secure, SameSite=Lax cookies
279 /// - Associate with state parameter for CSRF protection
280 /// - Automatic cleanup after exchange or timeout
281 ///
282 /// 2. **Redis/Database with TTL** (RECOMMENDED for distributed systems)
283 /// - Key: state parameter, Value: encrypted code_verifier
284 /// - Set TTL to match authorization timeout (typically 10 minutes)
285 /// - Use server-side encryption at rest
286 ///
287 /// 3. **In-memory for SPAs** (ACCEPTABLE for public clients only)
288 /// - Store in JavaScript closure or React state (NOT localStorage/sessionStorage)
289 /// - Clear immediately after token exchange
290 /// - Risk: XSS can steal verifier
291 ///
292 /// **NEVER:**
293 /// - Store in localStorage or sessionStorage (XSS risk)
294 /// - Send to client in URL or query parameters
295 /// - Log or expose in error messages
296 ///
297 /// # Arguments
298 /// * `scopes` - Requested OAuth scopes
299 /// * `state` - CSRF protection state parameter (use cryptographically random value)
300 ///
301 /// # Returns
302 /// Tuple of (authorization_url, PKCE code_verifier for secure storage)
303 ///
304 /// # Example
305 /// ```ignore
306 /// // Server-side web app (RECOMMENDED)
307 /// let state = generate_csrf_token(); // Cryptographically random
308 /// let (auth_url, code_verifier) = client.authorization_code_flow(scopes, state.clone());
309 ///
310 /// // Store securely server-side
311 /// session.insert("oauth_state", state);
312 /// session.insert("pkce_verifier", code_verifier); // Encrypted session
313 ///
314 /// // Redirect user
315 /// redirect_to(auth_url);
316 /// ```
317 pub fn authorization_code_flow(&self, scopes: Vec<String>, state: String) -> (String, String) {
318 // Generate PKCE challenge
319 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
320
321 // Build authorization URL with PKCE
322 let (auth_url, _state) = self
323 .auth_code_client
324 .authorize_url(|| oauth2::CsrfToken::new(state))
325 .add_scopes(scopes.into_iter().map(Scope::new))
326 .set_pkce_challenge(pkce_challenge)
327 .url();
328
329 (auth_url.to_string(), pkce_verifier.secret().to_string())
330 }
331
332 /// Exchange authorization code for access token
333 ///
334 /// This exchanges the authorization code received from the OAuth provider
335 /// for an access token using PKCE (RFC 7636).
336 ///
337 /// # Arguments
338 /// * `code` - Authorization code from OAuth provider
339 /// * `code_verifier` - PKCE code verifier (from authorization_code_flow)
340 ///
341 /// # Returns
342 /// TokenInfo containing access token and refresh token (if available)
343 pub async fn exchange_code_for_token(
344 &self,
345 code: String,
346 code_verifier: String,
347 ) -> McpResult<TokenInfo> {
348 let http_client = reqwest::Client::new();
349 let token_response = self
350 .auth_code_client
351 .exchange_code(oauth2::AuthorizationCode::new(code))
352 .set_pkce_verifier(PkceCodeVerifier::new(code_verifier))
353 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
354 .await
355 .map_err(|e| McpError::internal(format!("Token exchange failed: {e}")))?;
356
357 Ok(self.token_response_to_token_info(token_response))
358 }
359
360 /// Refresh an access token with automatic refresh token rotation
361 ///
362 /// This uses a refresh token to obtain a new access token without
363 /// requiring user interaction. OAuth 2.1 and RFC 9700 recommend refresh
364 /// token rotation where the server issues a new refresh token with each
365 /// refresh request.
366 ///
367 /// # Refresh Token Rotation (OAuth 2.1 / RFC 9700 Best Practice)
368 ///
369 /// When the server supports rotation:
370 /// - A new refresh token is returned in the response
371 /// - The old refresh token should be discarded immediately
372 /// - Store and use the new refresh token for future requests
373 /// - This prevents token theft detection
374 ///
375 /// **Important:** Always check if `token_info.refresh_token` is present in
376 /// the response. If present, you MUST replace your stored refresh token
377 /// with the new one. If absent, continue using the current refresh token.
378 ///
379 /// # Arguments
380 /// * `refresh_token` - The current refresh token
381 ///
382 /// # Returns
383 /// New TokenInfo with:
384 /// - Fresh access token (always present)
385 /// - New refresh token (if server supports rotation)
386 ///
387 /// # Example
388 /// ```ignore
389 /// let mut stored_refresh_token = "current_refresh_token";
390 /// let new_tokens = client.refresh_access_token(stored_refresh_token).await?;
391 ///
392 /// // Check for refresh token rotation
393 /// if let Some(new_refresh_token) = &new_tokens.refresh_token {
394 /// // Server rotated the token - update storage
395 /// stored_refresh_token = new_refresh_token;
396 /// println!("Refresh token rotated (security best practice)");
397 /// }
398 /// // Use new access token
399 /// let access_token = new_tokens.access_token;
400 /// ```
401 pub async fn refresh_access_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
402 let http_client = reqwest::Client::new();
403 let token_response = self
404 .auth_code_client
405 .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
406 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
407 .await
408 .map_err(|e| McpError::internal(format!("Token refresh failed: {e}")))?;
409
410 Ok(self.token_response_to_token_info(token_response))
411 }
412
413 /// Client credentials flow for server-to-server authentication
414 ///
415 /// This implements the OAuth 2.1 Client Credentials flow for
416 /// service-to-service communication without user involvement.
417 ///
418 /// # Arguments
419 /// * `scopes` - Requested OAuth scopes
420 ///
421 /// # Returns
422 /// TokenInfo with access token (typically without refresh token)
423 pub async fn client_credentials_flow(&self, scopes: Vec<String>) -> McpResult<TokenInfo> {
424 let client = self.client_credentials_client.as_ref().ok_or_else(|| {
425 McpError::internal("Client credentials flow requires client secret".to_string())
426 })?;
427
428 let http_client = reqwest::Client::new();
429 let token_response = client
430 .exchange_client_credentials()
431 .add_scopes(scopes.into_iter().map(Scope::new))
432 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
433 .await
434 .map_err(|e| McpError::internal(format!("Client credentials flow failed: {e}")))?;
435
436 Ok(self.token_response_to_token_info(token_response))
437 }
438
439 /// Convert oauth2 token response to TokenInfo
440 fn token_response_to_token_info(
441 &self,
442 response: oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, BasicTokenType>,
443 ) -> TokenInfo {
444 let expires_in = response.expires_in().map(|duration| duration.as_secs());
445
446 TokenInfo {
447 access_token: response.access_token().secret().clone(),
448 token_type: format!("{:?}", response.token_type()),
449 refresh_token: response.refresh_token().map(|t| t.secret().clone()),
450 expires_in,
451 scope: response.scopes().map(|scopes| {
452 scopes
453 .iter()
454 .map(|s| s.as_str())
455 .collect::<Vec<_>>()
456 .join(" ")
457 }),
458 }
459 }
460
461 /// Revoke a token using RFC 7009 Token Revocation
462 ///
463 /// Per RFC 7009 Section 2, prefer revoking refresh tokens (which MUST be supported
464 /// by the server if issued) over access tokens (which MAY be supported).
465 ///
466 /// # Arguments
467 /// * `token_info` - Token information containing access and/or refresh token
468 ///
469 /// # Returns
470 /// Ok if revocation succeeded or token was already invalid (per RFC 7009)
471 ///
472 /// # Errors
473 /// Returns error if:
474 /// - No revocation endpoint was configured
475 /// - Network/HTTP error occurred
476 /// - Server returned an error response
477 pub async fn revoke_token(&self, token_info: &TokenInfo) -> McpResult<()> {
478 let http_client = reqwest::Client::new();
479
480 // Per RFC 7009 Section 2: Prefer refresh token, fallback to access token
481 let token_to_revoke: StandardRevocableToken =
482 if let Some(ref refresh_token) = token_info.refresh_token {
483 RefreshToken::new(refresh_token.clone()).into()
484 } else {
485 oauth2::AccessToken::new(token_info.access_token.clone()).into()
486 };
487
488 self.auth_code_client
489 .revoke_token(token_to_revoke)
490 .map_err(|e| McpError::internal(format!("Token revocation not configured: {e}")))?
491 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
492 .await
493 .map_err(|e| McpError::internal(format!("Token revocation failed: {e}")))?;
494
495 Ok(())
496 }
497
498 /// Validate that an access token is still valid
499 ///
500 /// This checks if a token has expired based on expiration time.
501 /// Note: This is a client-side check only; servers may have revoked the token.
502 pub fn is_token_expired(&self, token: &TokenInfo) -> bool {
503 if let Some(expires_in) = token.expires_in {
504 // Assume token was valid "now" - in production, store issued_at timestamp
505 expires_in == 0
506 } else {
507 false
508 }
509 }
510}
511
512/// Execute OAuth request using reqwest HTTP client
513/// Converts between oauth2 and reqwest types
514async fn execute_oauth_request(
515 client: &reqwest::Client,
516 request: oauth2::HttpRequest,
517) -> Result<oauth2::HttpResponse, oauth2::reqwest::Error<reqwest::Error>> {
518 let method_str = format!("{}", request.method);
519 let url = request.url.clone();
520
521 // Build the request
522 let mut req_builder = match method_str.to_uppercase().as_str() {
523 "GET" => client.get(url),
524 "POST" => client.post(url),
525 m => {
526 return Err(oauth2::reqwest::Error::Other(format!(
527 "Unsupported HTTP method: {}",
528 m
529 )));
530 }
531 };
532
533 // Add body (always present, even if empty)
534 if !request.body.is_empty() {
535 req_builder = req_builder.body(request.body);
536 }
537
538 // Add headers - convert from oauth2 HeaderName/HeaderValue to reqwest types
539 for (name, value) in &request.headers {
540 let name_str = format!("{:?}", name); // Use debug format for HeaderName
541 // HeaderValue as_bytes should work
542 let value_bytes = value.as_bytes();
543
544 if let (Ok(header_name), Ok(header_value)) = (
545 reqwest::header::HeaderName::from_bytes(name_str.as_bytes()),
546 reqwest::header::HeaderValue::from_bytes(value_bytes),
547 ) {
548 req_builder = req_builder.header(header_name, header_value);
549 }
550 }
551
552 // Send request
553 let response = req_builder
554 .send()
555 .await
556 .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?;
557
558 let status = response.status();
559 let body = response
560 .bytes()
561 .await
562 .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?
563 .to_vec();
564
565 // Convert reqwest status code to oauth2 status code
566 let oauth_status =
567 oauth2::http::StatusCode::from_u16(status.as_u16()).unwrap_or(oauth2::http::StatusCode::OK);
568
569 Ok(oauth2::HttpResponse {
570 status_code: oauth_status,
571 body,
572 headers: Default::default(),
573 })
574}