turbomcp_auth/oauth2/client.rs
1//! OAuth 2.1 Client Implementation
2//!
3//! This module provides an OAuth 2.1 client wrapper that supports:
4//! - Authorization Code flow (with PKCE)
5//! - Client Credentials flow (server-to-server)
6//! - Device Authorization flow (CLI/IoT)
7//!
//! The client builds provider-specific configuration (default scopes, refresh
//! behavior, userinfo endpoint, extra parameters) for Google, Microsoft, GitHub,
//! GitLab, Apple, Okta, Auth0, Keycloak, and generic/custom OAuth providers.
10
11use std::collections::HashMap;
12
13use oauth2::{
14 AuthUrl, ClientId, ClientSecret, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
15 RefreshToken, RevocationUrl, Scope, TokenResponse, TokenUrl,
16 basic::{BasicClient, BasicTokenType},
17 revocation::StandardRevocableToken,
18};
19use secrecy::ExposeSecret;
20
21use turbomcp_protocol::{Error as McpError, Result as McpResult};
22
23use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
24use super::super::types::TokenInfo;
25
26/// OAuth 2.1 client wrapper supporting all modern flows
27#[derive(Debug, Clone)]
28pub struct OAuth2Client {
29 /// Authorization code flow client (most common)
30 pub(crate) auth_code_client: BasicClient,
31 /// Client credentials client (server-to-server)
32 pub(crate) client_credentials_client: Option<BasicClient>,
33 /// Device code client (for CLI/IoT applications)
34 pub(crate) device_code_client: Option<BasicClient>,
35 /// Provider-specific configuration
36 pub provider_config: ProviderConfig,
37}
38
39impl OAuth2Client {
40 /// Create an OAuth 2.1 client supporting all flows
41 pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
42 // Validate URLs
43 let auth_url = AuthUrl::new(config.auth_url.clone())
44 .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
45
46 let token_url = TokenUrl::new(config.token_url.clone())
47 .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
48
49 // Redirect URI validation with security checks
50 let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
51
52 // Create authorization code flow client (primary)
53 let client_secret = if config.client_secret.expose_secret().is_empty() {
54 None
55 } else {
56 Some(ClientSecret::new(
57 config.client_secret.expose_secret().clone(),
58 ))
59 };
60
61 let mut auth_code_client = BasicClient::new(
62 ClientId::new(config.client_id.clone()),
63 client_secret.clone(),
64 auth_url.clone(),
65 Some(token_url.clone()),
66 )
67 .set_redirect_uri(redirect_url);
68
69 // Set revocation endpoint if provided (RFC 7009)
70 if let Some(ref revocation_url_str) = config.revocation_url {
71 let revocation_url = RevocationUrl::new(revocation_url_str.clone())
72 .map_err(|_| McpError::validation("Invalid revocation URL".to_string()))?;
73 auth_code_client = auth_code_client.set_revocation_uri(revocation_url);
74 }
75
76 // Create client credentials client if we have a secret (server-to-server)
77 let client_credentials_client = if client_secret.is_some() {
78 Some(BasicClient::new(
79 ClientId::new(config.client_id.clone()),
80 client_secret.clone(),
81 auth_url.clone(),
82 Some(token_url.clone()),
83 ))
84 } else {
85 None
86 };
87
88 // Device code client (for CLI/IoT apps) - uses same configuration
89 let device_code_client = Some(BasicClient::new(
90 ClientId::new(config.client_id.clone()),
91 client_secret,
92 auth_url,
93 Some(token_url),
94 ));
95
96 // Provider-specific configuration
97 let provider_config = Self::build_provider_config(provider_type);
98
99 Ok(Self {
100 auth_code_client,
101 client_credentials_client,
102 device_code_client,
103 provider_config,
104 })
105 }
106
107 /// Build provider-specific configuration
108 fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
109 match provider_type {
110 ProviderType::Google => ProviderConfig {
111 provider_type,
112 default_scopes: vec![
113 "openid".to_string(),
114 "email".to_string(),
115 "profile".to_string(),
116 ],
117 refresh_behavior: RefreshBehavior::Proactive,
118 userinfo_endpoint: Some(
119 "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
120 ),
121 additional_params: HashMap::new(),
122 },
123 ProviderType::Microsoft => ProviderConfig {
124 provider_type,
125 default_scopes: vec![
126 "openid".to_string(),
127 "profile".to_string(),
128 "email".to_string(),
129 "User.Read".to_string(),
130 ],
131 refresh_behavior: RefreshBehavior::Proactive,
132 userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
133 additional_params: HashMap::new(),
134 },
135 ProviderType::GitHub => ProviderConfig {
136 provider_type,
137 default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
138 refresh_behavior: RefreshBehavior::Reactive,
139 userinfo_endpoint: Some("https://api.github.com/user".to_string()),
140 additional_params: HashMap::new(),
141 },
142 ProviderType::GitLab => ProviderConfig {
143 provider_type,
144 default_scopes: vec!["read_user".to_string(), "openid".to_string()],
145 refresh_behavior: RefreshBehavior::Proactive,
146 userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
147 additional_params: HashMap::new(),
148 },
149 ProviderType::Apple => ProviderConfig {
150 provider_type,
151 default_scopes: vec![
152 "openid".to_string(),
153 "email".to_string(),
154 "name".to_string(),
155 ],
156 refresh_behavior: RefreshBehavior::Proactive,
157 userinfo_endpoint: Some("https://appleid.apple.com/auth/v1/user".to_string()),
158 additional_params: {
159 let mut params = HashMap::new();
160 // Apple requires response_mode=form_post for web apps
161 params.insert("response_mode".to_string(), "form_post".to_string());
162 params
163 },
164 },
165 ProviderType::Okta => ProviderConfig {
166 provider_type,
167 default_scopes: vec![
168 "openid".to_string(),
169 "email".to_string(),
170 "profile".to_string(),
171 ],
172 refresh_behavior: RefreshBehavior::Proactive,
173 userinfo_endpoint: Some("/oauth2/v1/userinfo".to_string()), // Relative to Okta domain
174 additional_params: HashMap::new(),
175 },
176 ProviderType::Auth0 => ProviderConfig {
177 provider_type,
178 default_scopes: vec![
179 "openid".to_string(),
180 "email".to_string(),
181 "profile".to_string(),
182 ],
183 refresh_behavior: RefreshBehavior::Proactive,
184 userinfo_endpoint: Some("/userinfo".to_string()), // Relative to Auth0 domain
185 additional_params: HashMap::new(),
186 },
187 ProviderType::Keycloak => ProviderConfig {
188 provider_type,
189 default_scopes: vec![
190 "openid".to_string(),
191 "email".to_string(),
192 "profile".to_string(),
193 ],
194 refresh_behavior: RefreshBehavior::Proactive,
195 userinfo_endpoint: Some(
196 "/realms/{realm}/protocol/openid-connect/userinfo".to_string(),
197 ),
198 additional_params: HashMap::new(),
199 },
200 ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
201 provider_type,
202 default_scopes: vec!["openid".to_string(), "profile".to_string()],
203 refresh_behavior: RefreshBehavior::Proactive,
204 userinfo_endpoint: None,
205 additional_params: HashMap::new(),
206 },
207 }
208 }
209
210 /// Redirect URI validation with security checks
211 ///
212 /// Security considerations:
213 /// - Prevents open redirect attacks
214 /// - Validates URL format and structure
215 /// - Environment-aware validation (localhost for development)
216 fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
217 use url::Url;
218
219 // Parse and validate URL structure
220 let parsed = Url::parse(uri)
221 .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
222
223 // Security: Validate scheme
224 match parsed.scheme() {
225 "http" => {
226 // Only allow http for localhost/127.0.0.1/0.0.0.0 in development
227 if let Some(host) = parsed.host_str() {
228 // Allow localhost, 127.0.0.1, 0.0.0.0 (bind all interfaces)
229 let is_localhost = host == "localhost"
230 || host.starts_with("localhost:")
231 || host == "127.0.0.1"
232 || host.starts_with("127.0.0.1:")
233 || host == "0.0.0.0"
234 || host.starts_with("0.0.0.0:");
235
236 if !is_localhost {
237 return Err(McpError::validation(
238 "HTTP redirect URIs only allowed for localhost in development"
239 .to_string(),
240 ));
241 }
242 } else {
243 return Err(McpError::validation(
244 "Redirect URI must have a valid host".to_string(),
245 ));
246 }
247 }
248 "https" => {
249 // HTTPS is always allowed
250 }
251 "com.example.app" | "msauth" => {
252 // Allow custom schemes for mobile apps (common patterns)
253 }
254 scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
255 // Allow app-specific custom schemes
256 }
257 _ => {
258 return Err(McpError::validation(format!(
259 "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
260 parsed.scheme()
261 )));
262 }
263 }
264
265 // Security: Prevent fragment in redirect URI (per OAuth 2.0 spec)
266 if parsed.fragment().is_some() {
267 return Err(McpError::validation(
268 "Redirect URI must not contain URL fragment".to_string(),
269 ));
270 }
271
272 // Security: Check for path traversal in PATH component only
273 // Note: url::Url::parse() already normalizes paths and removes .. segments
274 // We check the final path to ensure no traversal remains after normalization
275 if let Some(path) = parsed.path_segments() {
276 for segment in path {
277 if segment == ".." {
278 return Err(McpError::validation(
279 "Redirect URI path must not contain traversal sequences".to_string(),
280 ));
281 }
282 }
283 }
284
285 // Use oauth2 crate's RedirectUrl for validation
286 // This provides URL validation per OAuth 2.1 specifications
287 // For production security, implement exact whitelist matching of allowed URIs
288 RedirectUrl::new(uri.to_string())
289 .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
290 }
291
292 /// Get access to the authorization code client
293 #[must_use]
294 pub fn auth_code_client(&self) -> &BasicClient {
295 &self.auth_code_client
296 }
297
298 /// Get access to the client credentials client (if available)
299 #[must_use]
300 pub fn client_credentials_client(&self) -> Option<&BasicClient> {
301 self.client_credentials_client.as_ref()
302 }
303
304 /// Get access to the device code client (if available)
305 #[must_use]
306 pub fn device_code_client(&self) -> Option<&BasicClient> {
307 self.device_code_client.as_ref()
308 }
309
310 /// Get the provider configuration
311 #[must_use]
312 pub fn provider_config(&self) -> &ProviderConfig {
313 &self.provider_config
314 }
315
316 /// Start authorization code flow with PKCE
317 ///
318 /// This initiates the OAuth 2.1 authorization code flow with PKCE (RFC 7636)
319 /// for enhanced security, especially for public clients.
320 ///
321 /// # PKCE Code Verifier Storage (CRITICAL SECURITY REQUIREMENT)
322 ///
323 /// The returned code_verifier MUST be securely stored and associated with the
324 /// state parameter until the authorization code is exchanged for tokens.
325 ///
326 /// **Storage Options (from most to least secure):**
327 ///
328 /// 1. **Server-side encrypted session** (RECOMMENDED for web apps)
329 /// - Store in server session with HttpOnly, Secure, SameSite=Lax cookies
330 /// - Associate with state parameter for CSRF protection
331 /// - Automatic cleanup after exchange or timeout
332 ///
333 /// 2. **Redis/Database with TTL** (RECOMMENDED for distributed systems)
334 /// - Key: state parameter, Value: encrypted code_verifier
335 /// - Set TTL to match authorization timeout (typically 10 minutes)
336 /// - Use server-side encryption at rest
337 ///
338 /// 3. **In-memory for SPAs** (ACCEPTABLE for public clients only)
339 /// - Store in JavaScript closure or React state (NOT localStorage/sessionStorage)
340 /// - Clear immediately after token exchange
341 /// - Risk: XSS can steal verifier
342 ///
343 /// **NEVER:**
344 /// - Store in localStorage or sessionStorage (XSS risk)
345 /// - Send to client in URL or query parameters
346 /// - Log or expose in error messages
347 ///
348 /// # Arguments
349 /// * `scopes` - Requested OAuth scopes
350 /// * `state` - CSRF protection state parameter (use cryptographically random value)
351 ///
352 /// # Returns
353 /// Tuple of (authorization_url, PKCE code_verifier for secure storage)
354 ///
355 /// # Example
356 /// ```ignore
357 /// // Server-side web app (RECOMMENDED)
358 /// let state = generate_csrf_token(); // Cryptographically random
359 /// let (auth_url, code_verifier) = client.authorization_code_flow(scopes, state.clone());
360 ///
361 /// // Store securely server-side
362 /// session.insert("oauth_state", state);
363 /// session.insert("pkce_verifier", code_verifier); // Encrypted session
364 ///
365 /// // Redirect user
366 /// redirect_to(auth_url);
367 /// ```
368 pub fn authorization_code_flow(&self, scopes: Vec<String>, state: String) -> (String, String) {
369 // Generate PKCE challenge
370 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
371
372 // Build authorization URL with PKCE
373 let (auth_url, _state) = self
374 .auth_code_client
375 .authorize_url(|| oauth2::CsrfToken::new(state))
376 .add_scopes(scopes.into_iter().map(Scope::new))
377 .set_pkce_challenge(pkce_challenge)
378 .url();
379
380 (auth_url.to_string(), pkce_verifier.secret().to_string())
381 }
382
383 /// Exchange authorization code for access token
384 ///
385 /// This exchanges the authorization code received from the OAuth provider
386 /// for an access token using PKCE (RFC 7636).
387 ///
388 /// # Arguments
389 /// * `code` - Authorization code from OAuth provider
390 /// * `code_verifier` - PKCE code verifier (from authorization_code_flow)
391 ///
392 /// # Returns
393 /// TokenInfo containing access token and refresh token (if available)
394 pub async fn exchange_code_for_token(
395 &self,
396 code: String,
397 code_verifier: String,
398 ) -> McpResult<TokenInfo> {
399 let http_client = reqwest::Client::new();
400 let token_response = self
401 .auth_code_client
402 .exchange_code(oauth2::AuthorizationCode::new(code))
403 .set_pkce_verifier(PkceCodeVerifier::new(code_verifier))
404 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
405 .await
406 .map_err(|e| McpError::internal(format!("Token exchange failed: {e}")))?;
407
408 Ok(self.token_response_to_token_info(token_response))
409 }
410
411 /// Refresh an access token with automatic refresh token rotation
412 ///
413 /// This uses a refresh token to obtain a new access token without
414 /// requiring user interaction. OAuth 2.1 and RFC 9700 recommend refresh
415 /// token rotation where the server issues a new refresh token with each
416 /// refresh request.
417 ///
418 /// # Refresh Token Rotation (OAuth 2.1 / RFC 9700 Best Practice)
419 ///
420 /// When the server supports rotation:
421 /// - A new refresh token is returned in the response
422 /// - The old refresh token should be discarded immediately
423 /// - Store and use the new refresh token for future requests
424 /// - This prevents token theft detection
425 ///
426 /// **Important:** Always check if `token_info.refresh_token` is present in
427 /// the response. If present, you MUST replace your stored refresh token
428 /// with the new one. If absent, continue using the current refresh token.
429 ///
430 /// # Arguments
431 /// * `refresh_token` - The current refresh token
432 ///
433 /// # Returns
434 /// New TokenInfo with:
435 /// - Fresh access token (always present)
436 /// - New refresh token (if server supports rotation)
437 ///
438 /// # Example
439 /// ```ignore
440 /// let mut stored_refresh_token = "current_refresh_token";
441 /// let new_tokens = client.refresh_access_token(stored_refresh_token).await?;
442 ///
443 /// // Check for refresh token rotation
444 /// if let Some(new_refresh_token) = &new_tokens.refresh_token {
445 /// // Server rotated the token - update storage
446 /// stored_refresh_token = new_refresh_token;
447 /// println!("Refresh token rotated (security best practice)");
448 /// }
449 /// // Use new access token
450 /// let access_token = new_tokens.access_token;
451 /// ```
452 pub async fn refresh_access_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
453 let http_client = reqwest::Client::new();
454 let token_response = self
455 .auth_code_client
456 .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
457 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
458 .await
459 .map_err(|e| McpError::internal(format!("Token refresh failed: {e}")))?;
460
461 Ok(self.token_response_to_token_info(token_response))
462 }
463
464 /// Client credentials flow for server-to-server authentication
465 ///
466 /// This implements the OAuth 2.1 Client Credentials flow for
467 /// service-to-service communication without user involvement.
468 ///
469 /// # Arguments
470 /// * `scopes` - Requested OAuth scopes
471 ///
472 /// # Returns
473 /// TokenInfo with access token (typically without refresh token)
474 pub async fn client_credentials_flow(&self, scopes: Vec<String>) -> McpResult<TokenInfo> {
475 let client = self.client_credentials_client.as_ref().ok_or_else(|| {
476 McpError::internal("Client credentials flow requires client secret".to_string())
477 })?;
478
479 let http_client = reqwest::Client::new();
480 let token_response = client
481 .exchange_client_credentials()
482 .add_scopes(scopes.into_iter().map(Scope::new))
483 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
484 .await
485 .map_err(|e| McpError::internal(format!("Client credentials flow failed: {e}")))?;
486
487 Ok(self.token_response_to_token_info(token_response))
488 }
489
490 /// Convert oauth2 token response to TokenInfo
491 fn token_response_to_token_info(
492 &self,
493 response: oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, BasicTokenType>,
494 ) -> TokenInfo {
495 let expires_in = response.expires_in().map(|duration| duration.as_secs());
496
497 TokenInfo {
498 access_token: response.access_token().secret().clone(),
499 token_type: format!("{:?}", response.token_type()),
500 refresh_token: response.refresh_token().map(|t| t.secret().clone()),
501 expires_in,
502 scope: response.scopes().map(|scopes| {
503 scopes
504 .iter()
505 .map(|s| s.as_str())
506 .collect::<Vec<_>>()
507 .join(" ")
508 }),
509 }
510 }
511
512 /// Revoke a token using RFC 7009 Token Revocation
513 ///
514 /// Per RFC 7009 Section 2, prefer revoking refresh tokens (which MUST be supported
515 /// by the server if issued) over access tokens (which MAY be supported).
516 ///
517 /// # Arguments
518 /// * `token_info` - Token information containing access and/or refresh token
519 ///
520 /// # Returns
521 /// Ok if revocation succeeded or token was already invalid (per RFC 7009)
522 ///
523 /// # Errors
524 /// Returns error if:
525 /// - No revocation endpoint was configured
526 /// - Network/HTTP error occurred
527 /// - Server returned an error response
528 pub async fn revoke_token(&self, token_info: &TokenInfo) -> McpResult<()> {
529 let http_client = reqwest::Client::new();
530
531 // Per RFC 7009 Section 2: Prefer refresh token, fallback to access token
532 let token_to_revoke: StandardRevocableToken =
533 if let Some(ref refresh_token) = token_info.refresh_token {
534 RefreshToken::new(refresh_token.clone()).into()
535 } else {
536 oauth2::AccessToken::new(token_info.access_token.clone()).into()
537 };
538
539 self.auth_code_client
540 .revoke_token(token_to_revoke)
541 .map_err(|e| McpError::internal(format!("Token revocation not configured: {e}")))?
542 .request_async(|request| async { execute_oauth_request(&http_client, request).await })
543 .await
544 .map_err(|e| McpError::internal(format!("Token revocation failed: {e}")))?;
545
546 Ok(())
547 }
548
549 /// Validate that an access token is still valid
550 ///
551 /// This checks if a token has expired based on expiration time.
552 /// Note: This is a client-side check only; servers may have revoked the token.
553 pub fn is_token_expired(&self, token: &TokenInfo) -> bool {
554 if let Some(expires_in) = token.expires_in {
555 // Assume token was valid "now" - in production, store issued_at timestamp
556 expires_in == 0
557 } else {
558 false
559 }
560 }
561}
562
563/// Execute OAuth request using reqwest HTTP client
564/// Converts between oauth2 and reqwest types
565async fn execute_oauth_request(
566 client: &reqwest::Client,
567 request: oauth2::HttpRequest,
568) -> Result<oauth2::HttpResponse, oauth2::reqwest::Error<reqwest::Error>> {
569 let method_str = format!("{}", request.method);
570 let url = request.url.clone();
571
572 // Build the request
573 let mut req_builder = match method_str.to_uppercase().as_str() {
574 "GET" => client.get(url),
575 "POST" => client.post(url),
576 m => {
577 return Err(oauth2::reqwest::Error::Other(format!(
578 "Unsupported HTTP method: {}",
579 m
580 )));
581 }
582 };
583
584 // Add body (always present, even if empty)
585 if !request.body.is_empty() {
586 req_builder = req_builder.body(request.body);
587 }
588
589 // Add headers - convert from oauth2 HeaderName/HeaderValue to reqwest types
590 for (name, value) in &request.headers {
591 let name_str = format!("{:?}", name); // Use debug format for HeaderName
592 // HeaderValue as_bytes should work
593 let value_bytes = value.as_bytes();
594
595 if let (Ok(header_name), Ok(header_value)) = (
596 reqwest::header::HeaderName::from_bytes(name_str.as_bytes()),
597 reqwest::header::HeaderValue::from_bytes(value_bytes),
598 ) {
599 req_builder = req_builder.header(header_name, header_value);
600 }
601 }
602
603 // Send request
604 let response = req_builder
605 .send()
606 .await
607 .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?;
608
609 let status = response.status();
610 let body = response
611 .bytes()
612 .await
613 .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?
614 .to_vec();
615
616 // Convert reqwest status code to oauth2 status code
617 let oauth_status =
618 oauth2::http::StatusCode::from_u16(status.as_u16()).unwrap_or(oauth2::http::StatusCode::OK);
619
620 Ok(oauth2::HttpResponse {
621 status_code: oauth_status,
622 body,
623 headers: Default::default(),
624 })
625}