// turbomcp_auth/oauth2/client.rs

//! OAuth 2.1 Client Implementation
//!
//! This module provides an OAuth 2.1 client wrapper that supports:
//! - Authorization Code flow (with PKCE)
//! - Client Credentials flow (server-to-server)
//! - Device Authorization flow (CLI/IoT)
//!
//! The client handles provider-specific configurations and quirks for
//! Google, Microsoft, GitHub, GitLab, Apple, Okta, Auth0, Keycloak, and
//! generic/custom OAuth providers.

11use std::collections::HashMap;
12
13use oauth2::{
14    AuthUrl, ClientId, ClientSecret, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
15    RefreshToken, RevocationUrl, Scope, TokenResponse, TokenUrl,
16    basic::{BasicClient, BasicTokenType},
17    revocation::StandardRevocableToken,
18};
19use secrecy::ExposeSecret;
20
21use turbomcp_protocol::{Error as McpError, Result as McpResult};
22
23use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
24use super::super::types::TokenInfo;
25
/// OAuth 2.1 client wrapper supporting all modern flows.
///
/// Holds one `oauth2` [`BasicClient`] per grant type, plus the
/// provider-specific settings selected by [`OAuth2Client::new`].
#[derive(Debug, Clone)]
pub struct OAuth2Client {
    /// Authorization code flow client (most common).
    ///
    /// Carries the validated redirect URI and, when configured, the RFC 7009
    /// revocation endpoint. It is also the client used for refresh-token and
    /// revocation requests.
    pub(crate) auth_code_client: BasicClient,
    /// Client credentials client (server-to-server).
    ///
    /// `None` when no (non-empty) client secret is configured, because this
    /// grant requires client authentication.
    pub(crate) client_credentials_client: Option<BasicClient>,
    /// Device code client (for CLI/IoT applications).
    ///
    /// Always `Some` when built by `new`. `new` shares the same endpoints
    /// with it but does not configure a device authorization endpoint.
    pub(crate) device_code_client: Option<BasicClient>,
    /// Provider-specific configuration: default scopes, refresh behavior,
    /// userinfo endpoint and extra request parameters.
    pub provider_config: ProviderConfig,
}
38
39impl OAuth2Client {
40    /// Create an OAuth 2.1 client supporting all flows
41    pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
42        // Validate URLs
43        let auth_url = AuthUrl::new(config.auth_url.clone())
44            .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
45
46        let token_url = TokenUrl::new(config.token_url.clone())
47            .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
48
49        // Redirect URI validation with security checks
50        let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
51
52        // Create authorization code flow client (primary)
53        let client_secret = if config.client_secret.expose_secret().is_empty() {
54            None
55        } else {
56            Some(ClientSecret::new(
57                config.client_secret.expose_secret().clone(),
58            ))
59        };
60
61        let mut auth_code_client = BasicClient::new(
62            ClientId::new(config.client_id.clone()),
63            client_secret.clone(),
64            auth_url.clone(),
65            Some(token_url.clone()),
66        )
67        .set_redirect_uri(redirect_url);
68
69        // Set revocation endpoint if provided (RFC 7009)
70        if let Some(ref revocation_url_str) = config.revocation_url {
71            let revocation_url = RevocationUrl::new(revocation_url_str.clone())
72                .map_err(|_| McpError::validation("Invalid revocation URL".to_string()))?;
73            auth_code_client = auth_code_client.set_revocation_uri(revocation_url);
74        }
75
76        // Create client credentials client if we have a secret (server-to-server)
77        let client_credentials_client = if client_secret.is_some() {
78            Some(BasicClient::new(
79                ClientId::new(config.client_id.clone()),
80                client_secret.clone(),
81                auth_url.clone(),
82                Some(token_url.clone()),
83            ))
84        } else {
85            None
86        };
87
88        // Device code client (for CLI/IoT apps) - uses same configuration
89        let device_code_client = Some(BasicClient::new(
90            ClientId::new(config.client_id.clone()),
91            client_secret,
92            auth_url,
93            Some(token_url),
94        ));
95
96        // Provider-specific configuration
97        let provider_config = Self::build_provider_config(provider_type);
98
99        Ok(Self {
100            auth_code_client,
101            client_credentials_client,
102            device_code_client,
103            provider_config,
104        })
105    }
106
107    /// Build provider-specific configuration
108    fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
109        match provider_type {
110            ProviderType::Google => ProviderConfig {
111                provider_type,
112                default_scopes: vec![
113                    "openid".to_string(),
114                    "email".to_string(),
115                    "profile".to_string(),
116                ],
117                refresh_behavior: RefreshBehavior::Proactive,
118                userinfo_endpoint: Some(
119                    "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
120                ),
121                additional_params: HashMap::new(),
122            },
123            ProviderType::Microsoft => ProviderConfig {
124                provider_type,
125                default_scopes: vec![
126                    "openid".to_string(),
127                    "profile".to_string(),
128                    "email".to_string(),
129                    "User.Read".to_string(),
130                ],
131                refresh_behavior: RefreshBehavior::Proactive,
132                userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
133                additional_params: HashMap::new(),
134            },
135            ProviderType::GitHub => ProviderConfig {
136                provider_type,
137                default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
138                refresh_behavior: RefreshBehavior::Reactive,
139                userinfo_endpoint: Some("https://api.github.com/user".to_string()),
140                additional_params: HashMap::new(),
141            },
142            ProviderType::GitLab => ProviderConfig {
143                provider_type,
144                default_scopes: vec!["read_user".to_string(), "openid".to_string()],
145                refresh_behavior: RefreshBehavior::Proactive,
146                userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
147                additional_params: HashMap::new(),
148            },
149            ProviderType::Apple => ProviderConfig {
150                provider_type,
151                default_scopes: vec![
152                    "openid".to_string(),
153                    "email".to_string(),
154                    "name".to_string(),
155                ],
156                refresh_behavior: RefreshBehavior::Proactive,
157                userinfo_endpoint: Some("https://appleid.apple.com/auth/v1/user".to_string()),
158                additional_params: {
159                    let mut params = HashMap::new();
160                    // Apple requires response_mode=form_post for web apps
161                    params.insert("response_mode".to_string(), "form_post".to_string());
162                    params
163                },
164            },
165            ProviderType::Okta => ProviderConfig {
166                provider_type,
167                default_scopes: vec![
168                    "openid".to_string(),
169                    "email".to_string(),
170                    "profile".to_string(),
171                ],
172                refresh_behavior: RefreshBehavior::Proactive,
173                userinfo_endpoint: Some("/oauth2/v1/userinfo".to_string()), // Relative to Okta domain
174                additional_params: HashMap::new(),
175            },
176            ProviderType::Auth0 => ProviderConfig {
177                provider_type,
178                default_scopes: vec![
179                    "openid".to_string(),
180                    "email".to_string(),
181                    "profile".to_string(),
182                ],
183                refresh_behavior: RefreshBehavior::Proactive,
184                userinfo_endpoint: Some("/userinfo".to_string()), // Relative to Auth0 domain
185                additional_params: HashMap::new(),
186            },
187            ProviderType::Keycloak => ProviderConfig {
188                provider_type,
189                default_scopes: vec![
190                    "openid".to_string(),
191                    "email".to_string(),
192                    "profile".to_string(),
193                ],
194                refresh_behavior: RefreshBehavior::Proactive,
195                userinfo_endpoint: Some(
196                    "/realms/{realm}/protocol/openid-connect/userinfo".to_string(),
197                ),
198                additional_params: HashMap::new(),
199            },
200            ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
201                provider_type,
202                default_scopes: vec!["openid".to_string(), "profile".to_string()],
203                refresh_behavior: RefreshBehavior::Proactive,
204                userinfo_endpoint: None,
205                additional_params: HashMap::new(),
206            },
207        }
208    }
209
210    /// Redirect URI validation with security checks
211    ///
212    /// Security considerations:
213    /// - Prevents open redirect attacks
214    /// - Validates URL format and structure
215    /// - Environment-aware validation (localhost for development)
216    fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
217        use url::Url;
218
219        // Parse and validate URL structure
220        let parsed = Url::parse(uri)
221            .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
222
223        // Security: Validate scheme
224        match parsed.scheme() {
225            "http" => {
226                // Only allow http for localhost/127.0.0.1/0.0.0.0 in development
227                if let Some(host) = parsed.host_str() {
228                    // Allow localhost, 127.0.0.1, 0.0.0.0 (bind all interfaces)
229                    let is_localhost = host == "localhost"
230                        || host.starts_with("localhost:")
231                        || host == "127.0.0.1"
232                        || host.starts_with("127.0.0.1:")
233                        || host == "0.0.0.0"
234                        || host.starts_with("0.0.0.0:");
235
236                    if !is_localhost {
237                        return Err(McpError::validation(
238                            "HTTP redirect URIs only allowed for localhost in development"
239                                .to_string(),
240                        ));
241                    }
242                } else {
243                    return Err(McpError::validation(
244                        "Redirect URI must have a valid host".to_string(),
245                    ));
246                }
247            }
248            "https" => {
249                // HTTPS is always allowed
250            }
251            "com.example.app" | "msauth" => {
252                // Allow custom schemes for mobile apps (common patterns)
253            }
254            scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
255                // Allow app-specific custom schemes
256            }
257            _ => {
258                return Err(McpError::validation(format!(
259                    "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
260                    parsed.scheme()
261                )));
262            }
263        }
264
265        // Security: Prevent fragment in redirect URI (per OAuth 2.0 spec)
266        if parsed.fragment().is_some() {
267            return Err(McpError::validation(
268                "Redirect URI must not contain URL fragment".to_string(),
269            ));
270        }
271
272        // Security: Check for path traversal in PATH component only
273        // Note: url::Url::parse() already normalizes paths and removes .. segments
274        // We check the final path to ensure no traversal remains after normalization
275        if let Some(path) = parsed.path_segments() {
276            for segment in path {
277                if segment == ".." {
278                    return Err(McpError::validation(
279                        "Redirect URI path must not contain traversal sequences".to_string(),
280                    ));
281                }
282            }
283        }
284
285        // Use oauth2 crate's RedirectUrl for validation
286        // This provides URL validation per OAuth 2.1 specifications
287        // For production security, implement exact whitelist matching of allowed URIs
288        RedirectUrl::new(uri.to_string())
289            .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
290    }
291
292    /// Get access to the authorization code client
293    #[must_use]
294    pub fn auth_code_client(&self) -> &BasicClient {
295        &self.auth_code_client
296    }
297
298    /// Get access to the client credentials client (if available)
299    #[must_use]
300    pub fn client_credentials_client(&self) -> Option<&BasicClient> {
301        self.client_credentials_client.as_ref()
302    }
303
304    /// Get access to the device code client (if available)
305    #[must_use]
306    pub fn device_code_client(&self) -> Option<&BasicClient> {
307        self.device_code_client.as_ref()
308    }
309
310    /// Get the provider configuration
311    #[must_use]
312    pub fn provider_config(&self) -> &ProviderConfig {
313        &self.provider_config
314    }
315
316    /// Start authorization code flow with PKCE
317    ///
318    /// This initiates the OAuth 2.1 authorization code flow with PKCE (RFC 7636)
319    /// for enhanced security, especially for public clients.
320    ///
321    /// # PKCE Code Verifier Storage (CRITICAL SECURITY REQUIREMENT)
322    ///
323    /// The returned code_verifier MUST be securely stored and associated with the
324    /// state parameter until the authorization code is exchanged for tokens.
325    ///
326    /// **Storage Options (from most to least secure):**
327    ///
328    /// 1. **Server-side encrypted session** (RECOMMENDED for web apps)
329    ///    - Store in server session with HttpOnly, Secure, SameSite=Lax cookies
330    ///    - Associate with state parameter for CSRF protection
331    ///    - Automatic cleanup after exchange or timeout
332    ///
333    /// 2. **Redis/Database with TTL** (RECOMMENDED for distributed systems)
334    ///    - Key: state parameter, Value: encrypted code_verifier
335    ///    - Set TTL to match authorization timeout (typically 10 minutes)
336    ///    - Use server-side encryption at rest
337    ///
338    /// 3. **In-memory for SPAs** (ACCEPTABLE for public clients only)
339    ///    - Store in JavaScript closure or React state (NOT localStorage/sessionStorage)
340    ///    - Clear immediately after token exchange
341    ///    - Risk: XSS can steal verifier
342    ///
343    /// **NEVER:**
344    /// - Store in localStorage or sessionStorage (XSS risk)
345    /// - Send to client in URL or query parameters
346    /// - Log or expose in error messages
347    ///
348    /// # Arguments
349    /// * `scopes` - Requested OAuth scopes
350    /// * `state` - CSRF protection state parameter (use cryptographically random value)
351    ///
352    /// # Returns
353    /// Tuple of (authorization_url, PKCE code_verifier for secure storage)
354    ///
355    /// # Example
356    /// ```ignore
357    /// // Server-side web app (RECOMMENDED)
358    /// let state = generate_csrf_token();  // Cryptographically random
359    /// let (auth_url, code_verifier) = client.authorization_code_flow(scopes, state.clone());
360    ///
361    /// // Store securely server-side
362    /// session.insert("oauth_state", state);
363    /// session.insert("pkce_verifier", code_verifier);  // Encrypted session
364    ///
365    /// // Redirect user
366    /// redirect_to(auth_url);
367    /// ```
368    pub fn authorization_code_flow(&self, scopes: Vec<String>, state: String) -> (String, String) {
369        // Generate PKCE challenge
370        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
371
372        // Build authorization URL with PKCE
373        let (auth_url, _state) = self
374            .auth_code_client
375            .authorize_url(|| oauth2::CsrfToken::new(state))
376            .add_scopes(scopes.into_iter().map(Scope::new))
377            .set_pkce_challenge(pkce_challenge)
378            .url();
379
380        (auth_url.to_string(), pkce_verifier.secret().to_string())
381    }
382
383    /// Exchange authorization code for access token
384    ///
385    /// This exchanges the authorization code received from the OAuth provider
386    /// for an access token using PKCE (RFC 7636).
387    ///
388    /// # Arguments
389    /// * `code` - Authorization code from OAuth provider
390    /// * `code_verifier` - PKCE code verifier (from authorization_code_flow)
391    ///
392    /// # Returns
393    /// TokenInfo containing access token and refresh token (if available)
394    pub async fn exchange_code_for_token(
395        &self,
396        code: String,
397        code_verifier: String,
398    ) -> McpResult<TokenInfo> {
399        let http_client = reqwest::Client::new();
400        let token_response = self
401            .auth_code_client
402            .exchange_code(oauth2::AuthorizationCode::new(code))
403            .set_pkce_verifier(PkceCodeVerifier::new(code_verifier))
404            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
405            .await
406            .map_err(|e| McpError::internal(format!("Token exchange failed: {e}")))?;
407
408        Ok(self.token_response_to_token_info(token_response))
409    }
410
411    /// Refresh an access token with automatic refresh token rotation
412    ///
413    /// This uses a refresh token to obtain a new access token without
414    /// requiring user interaction. OAuth 2.1 and RFC 9700 recommend refresh
415    /// token rotation where the server issues a new refresh token with each
416    /// refresh request.
417    ///
418    /// # Refresh Token Rotation (OAuth 2.1 / RFC 9700 Best Practice)
419    ///
420    /// When the server supports rotation:
421    /// - A new refresh token is returned in the response
422    /// - The old refresh token should be discarded immediately
423    /// - Store and use the new refresh token for future requests
424    /// - This prevents token theft detection
425    ///
426    /// **Important:** Always check if `token_info.refresh_token` is present in
427    /// the response. If present, you MUST replace your stored refresh token
428    /// with the new one. If absent, continue using the current refresh token.
429    ///
430    /// # Arguments
431    /// * `refresh_token` - The current refresh token
432    ///
433    /// # Returns
434    /// New TokenInfo with:
435    /// - Fresh access token (always present)
436    /// - New refresh token (if server supports rotation)
437    ///
438    /// # Example
439    /// ```ignore
440    /// let mut stored_refresh_token = "current_refresh_token";
441    /// let new_tokens = client.refresh_access_token(stored_refresh_token).await?;
442    ///
443    /// // Check for refresh token rotation
444    /// if let Some(new_refresh_token) = &new_tokens.refresh_token {
445    ///     // Server rotated the token - update storage
446    ///     stored_refresh_token = new_refresh_token;
447    ///     println!("Refresh token rotated (security best practice)");
448    /// }
449    /// // Use new access token
450    /// let access_token = new_tokens.access_token;
451    /// ```
452    pub async fn refresh_access_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
453        let http_client = reqwest::Client::new();
454        let token_response = self
455            .auth_code_client
456            .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
457            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
458            .await
459            .map_err(|e| McpError::internal(format!("Token refresh failed: {e}")))?;
460
461        Ok(self.token_response_to_token_info(token_response))
462    }
463
464    /// Client credentials flow for server-to-server authentication
465    ///
466    /// This implements the OAuth 2.1 Client Credentials flow for
467    /// service-to-service communication without user involvement.
468    ///
469    /// # Arguments
470    /// * `scopes` - Requested OAuth scopes
471    ///
472    /// # Returns
473    /// TokenInfo with access token (typically without refresh token)
474    pub async fn client_credentials_flow(&self, scopes: Vec<String>) -> McpResult<TokenInfo> {
475        let client = self.client_credentials_client.as_ref().ok_or_else(|| {
476            McpError::internal("Client credentials flow requires client secret".to_string())
477        })?;
478
479        let http_client = reqwest::Client::new();
480        let token_response = client
481            .exchange_client_credentials()
482            .add_scopes(scopes.into_iter().map(Scope::new))
483            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
484            .await
485            .map_err(|e| McpError::internal(format!("Client credentials flow failed: {e}")))?;
486
487        Ok(self.token_response_to_token_info(token_response))
488    }
489
490    /// Convert oauth2 token response to TokenInfo
491    fn token_response_to_token_info(
492        &self,
493        response: oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, BasicTokenType>,
494    ) -> TokenInfo {
495        let expires_in = response.expires_in().map(|duration| duration.as_secs());
496
497        TokenInfo {
498            access_token: response.access_token().secret().clone(),
499            token_type: format!("{:?}", response.token_type()),
500            refresh_token: response.refresh_token().map(|t| t.secret().clone()),
501            expires_in,
502            scope: response.scopes().map(|scopes| {
503                scopes
504                    .iter()
505                    .map(|s| s.as_str())
506                    .collect::<Vec<_>>()
507                    .join(" ")
508            }),
509        }
510    }
511
512    /// Revoke a token using RFC 7009 Token Revocation
513    ///
514    /// Per RFC 7009 Section 2, prefer revoking refresh tokens (which MUST be supported
515    /// by the server if issued) over access tokens (which MAY be supported).
516    ///
517    /// # Arguments
518    /// * `token_info` - Token information containing access and/or refresh token
519    ///
520    /// # Returns
521    /// Ok if revocation succeeded or token was already invalid (per RFC 7009)
522    ///
523    /// # Errors
524    /// Returns error if:
525    /// - No revocation endpoint was configured
526    /// - Network/HTTP error occurred
527    /// - Server returned an error response
528    pub async fn revoke_token(&self, token_info: &TokenInfo) -> McpResult<()> {
529        let http_client = reqwest::Client::new();
530
531        // Per RFC 7009 Section 2: Prefer refresh token, fallback to access token
532        let token_to_revoke: StandardRevocableToken =
533            if let Some(ref refresh_token) = token_info.refresh_token {
534                RefreshToken::new(refresh_token.clone()).into()
535            } else {
536                oauth2::AccessToken::new(token_info.access_token.clone()).into()
537            };
538
539        self.auth_code_client
540            .revoke_token(token_to_revoke)
541            .map_err(|e| McpError::internal(format!("Token revocation not configured: {e}")))?
542            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
543            .await
544            .map_err(|e| McpError::internal(format!("Token revocation failed: {e}")))?;
545
546        Ok(())
547    }
548
549    /// Validate that an access token is still valid
550    ///
551    /// This checks if a token has expired based on expiration time.
552    /// Note: This is a client-side check only; servers may have revoked the token.
553    pub fn is_token_expired(&self, token: &TokenInfo) -> bool {
554        if let Some(expires_in) = token.expires_in {
555            // Assume token was valid "now" - in production, store issued_at timestamp
556            expires_in == 0
557        } else {
558            false
559        }
560    }
561}
562
563/// Execute OAuth request using reqwest HTTP client
564/// Converts between oauth2 and reqwest types
565async fn execute_oauth_request(
566    client: &reqwest::Client,
567    request: oauth2::HttpRequest,
568) -> Result<oauth2::HttpResponse, oauth2::reqwest::Error<reqwest::Error>> {
569    let method_str = format!("{}", request.method);
570    let url = request.url.clone();
571
572    // Build the request
573    let mut req_builder = match method_str.to_uppercase().as_str() {
574        "GET" => client.get(url),
575        "POST" => client.post(url),
576        m => {
577            return Err(oauth2::reqwest::Error::Other(format!(
578                "Unsupported HTTP method: {}",
579                m
580            )));
581        }
582    };
583
584    // Add body (always present, even if empty)
585    if !request.body.is_empty() {
586        req_builder = req_builder.body(request.body);
587    }
588
589    // Add headers - convert from oauth2 HeaderName/HeaderValue to reqwest types
590    for (name, value) in &request.headers {
591        let name_str = format!("{:?}", name); // Use debug format for HeaderName
592        // HeaderValue as_bytes should work
593        let value_bytes = value.as_bytes();
594
595        if let (Ok(header_name), Ok(header_value)) = (
596            reqwest::header::HeaderName::from_bytes(name_str.as_bytes()),
597            reqwest::header::HeaderValue::from_bytes(value_bytes),
598        ) {
599            req_builder = req_builder.header(header_name, header_value);
600        }
601    }
602
603    // Send request
604    let response = req_builder
605        .send()
606        .await
607        .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?;
608
609    let status = response.status();
610    let body = response
611        .bytes()
612        .await
613        .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?
614        .to_vec();
615
616    // Convert reqwest status code to oauth2 status code
617    let oauth_status =
618        oauth2::http::StatusCode::from_u16(status.as_u16()).unwrap_or(oauth2::http::StatusCode::OK);
619
620    Ok(oauth2::HttpResponse {
621        status_code: oauth_status,
622        body,
623        headers: Default::default(),
624    })
625}