turbomcp_auth/oauth2/
client.rs

//! OAuth 2.1 Client Implementation
//!
//! This module provides an OAuth 2.1 client wrapper that supports:
//! - Authorization Code flow (with PKCE)
//! - Client Credentials flow (server-to-server)
//! - Device Authorization flow (CLI/IoT)
//!
//! The client handles provider-specific configurations and quirks for
//! Google, Microsoft, GitHub, GitLab, and generic OAuth providers.
10
11use std::collections::HashMap;
12
13use oauth2::{
14    AuthUrl, ClientId, ClientSecret, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
15    RefreshToken, RevocationUrl, Scope, TokenResponse, TokenUrl,
16    basic::{BasicClient, BasicTokenType},
17    revocation::StandardRevocableToken,
18};
19use secrecy::ExposeSecret;
20
21use turbomcp_protocol::{Error as McpError, Result as McpResult};
22
23use super::super::config::{OAuth2Config, ProviderConfig, ProviderType, RefreshBehavior};
24use super::super::types::TokenInfo;
25
26/// OAuth 2.1 client wrapper supporting all modern flows
27#[derive(Debug, Clone)]
28pub struct OAuth2Client {
29    /// Authorization code flow client (most common)
30    pub(crate) auth_code_client: BasicClient,
31    /// Client credentials client (server-to-server)
32    pub(crate) client_credentials_client: Option<BasicClient>,
33    /// Device code client (for CLI/IoT applications)
34    pub(crate) device_code_client: Option<BasicClient>,
35    /// Provider-specific configuration
36    pub provider_config: ProviderConfig,
37}
38
39impl OAuth2Client {
40    /// Create an OAuth 2.1 client supporting all flows
41    pub fn new(config: &OAuth2Config, provider_type: ProviderType) -> McpResult<Self> {
42        // Validate URLs
43        let auth_url = AuthUrl::new(config.auth_url.clone())
44            .map_err(|_| McpError::validation("Invalid authorization URL".to_string()))?;
45
46        let token_url = TokenUrl::new(config.token_url.clone())
47            .map_err(|_| McpError::validation("Invalid token URL".to_string()))?;
48
49        // Redirect URI validation with security checks
50        let redirect_url = Self::validate_redirect_uri(&config.redirect_uri)?;
51
52        // Create authorization code flow client (primary)
53        let client_secret = if config.client_secret.expose_secret().is_empty() {
54            None
55        } else {
56            Some(ClientSecret::new(
57                config.client_secret.expose_secret().clone(),
58            ))
59        };
60
61        let mut auth_code_client = BasicClient::new(
62            ClientId::new(config.client_id.clone()),
63            client_secret.clone(),
64            auth_url.clone(),
65            Some(token_url.clone()),
66        )
67        .set_redirect_uri(redirect_url);
68
69        // Set revocation endpoint if provided (RFC 7009)
70        if let Some(ref revocation_url_str) = config.revocation_url {
71            let revocation_url = RevocationUrl::new(revocation_url_str.clone())
72                .map_err(|_| McpError::validation("Invalid revocation URL".to_string()))?;
73            auth_code_client = auth_code_client.set_revocation_uri(revocation_url);
74        }
75
76        // Create client credentials client if we have a secret (server-to-server)
77        let client_credentials_client = if client_secret.is_some() {
78            Some(BasicClient::new(
79                ClientId::new(config.client_id.clone()),
80                client_secret.clone(),
81                auth_url.clone(),
82                Some(token_url.clone()),
83            ))
84        } else {
85            None
86        };
87
88        // Device code client (for CLI/IoT apps) - uses same configuration
89        let device_code_client = Some(BasicClient::new(
90            ClientId::new(config.client_id.clone()),
91            client_secret,
92            auth_url,
93            Some(token_url),
94        ));
95
96        // Provider-specific configuration
97        let provider_config = Self::build_provider_config(provider_type);
98
99        Ok(Self {
100            auth_code_client,
101            client_credentials_client,
102            device_code_client,
103            provider_config,
104        })
105    }
106
107    /// Build provider-specific configuration
108    fn build_provider_config(provider_type: ProviderType) -> ProviderConfig {
109        match provider_type {
110            ProviderType::Google => ProviderConfig {
111                provider_type,
112                default_scopes: vec![
113                    "openid".to_string(),
114                    "email".to_string(),
115                    "profile".to_string(),
116                ],
117                refresh_behavior: RefreshBehavior::Proactive,
118                userinfo_endpoint: Some(
119                    "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
120                ),
121                additional_params: HashMap::new(),
122            },
123            ProviderType::Microsoft => ProviderConfig {
124                provider_type,
125                default_scopes: vec![
126                    "openid".to_string(),
127                    "profile".to_string(),
128                    "email".to_string(),
129                    "User.Read".to_string(),
130                ],
131                refresh_behavior: RefreshBehavior::Proactive,
132                userinfo_endpoint: Some("https://graph.microsoft.com/v1.0/me".to_string()),
133                additional_params: HashMap::new(),
134            },
135            ProviderType::GitHub => ProviderConfig {
136                provider_type,
137                default_scopes: vec!["user:email".to_string(), "read:user".to_string()],
138                refresh_behavior: RefreshBehavior::Reactive,
139                userinfo_endpoint: Some("https://api.github.com/user".to_string()),
140                additional_params: HashMap::new(),
141            },
142            ProviderType::GitLab => ProviderConfig {
143                provider_type,
144                default_scopes: vec!["read_user".to_string(), "openid".to_string()],
145                refresh_behavior: RefreshBehavior::Proactive,
146                userinfo_endpoint: Some("https://gitlab.com/api/v4/user".to_string()),
147                additional_params: HashMap::new(),
148            },
149            ProviderType::Generic | ProviderType::Custom(_) => ProviderConfig {
150                provider_type,
151                default_scopes: vec!["openid".to_string(), "profile".to_string()],
152                refresh_behavior: RefreshBehavior::Proactive,
153                userinfo_endpoint: None,
154                additional_params: HashMap::new(),
155            },
156        }
157    }
158
159    /// Redirect URI validation with security checks
160    ///
161    /// Security considerations:
162    /// - Prevents open redirect attacks
163    /// - Validates URL format and structure
164    /// - Environment-aware validation (localhost for development)
165    fn validate_redirect_uri(uri: &str) -> McpResult<RedirectUrl> {
166        use url::Url;
167
168        // Parse and validate URL structure
169        let parsed = Url::parse(uri)
170            .map_err(|e| McpError::validation(format!("Invalid redirect URI format: {e}")))?;
171
172        // Security: Validate scheme
173        match parsed.scheme() {
174            "http" => {
175                // Only allow http for localhost/127.0.0.1/0.0.0.0 in development
176                if let Some(host) = parsed.host_str() {
177                    // Allow localhost, 127.0.0.1, 0.0.0.0 (bind all interfaces)
178                    let is_localhost = host == "localhost"
179                        || host.starts_with("localhost:")
180                        || host == "127.0.0.1"
181                        || host.starts_with("127.0.0.1:")
182                        || host == "0.0.0.0"
183                        || host.starts_with("0.0.0.0:");
184
185                    if !is_localhost {
186                        return Err(McpError::validation(
187                            "HTTP redirect URIs only allowed for localhost in development"
188                                .to_string(),
189                        ));
190                    }
191                } else {
192                    return Err(McpError::validation(
193                        "Redirect URI must have a valid host".to_string(),
194                    ));
195                }
196            }
197            "https" => {
198                // HTTPS is always allowed
199            }
200            "com.example.app" | "msauth" => {
201                // Allow custom schemes for mobile apps (common patterns)
202            }
203            scheme if scheme.starts_with("app.") || scheme.ends_with(".app") => {
204                // Allow app-specific custom schemes
205            }
206            _ => {
207                return Err(McpError::validation(format!(
208                    "Unsupported redirect URI scheme: {}. Use https, http (localhost only), or app-specific schemes",
209                    parsed.scheme()
210                )));
211            }
212        }
213
214        // Security: Prevent fragment in redirect URI (per OAuth 2.0 spec)
215        if parsed.fragment().is_some() {
216            return Err(McpError::validation(
217                "Redirect URI must not contain URL fragment".to_string(),
218            ));
219        }
220
221        // Security: Check for path traversal in PATH component only
222        // Note: url::Url::parse() already normalizes paths and removes .. segments
223        // We check the final path to ensure no traversal remains after normalization
224        if let Some(path) = parsed.path_segments() {
225            for segment in path {
226                if segment == ".." {
227                    return Err(McpError::validation(
228                        "Redirect URI path must not contain traversal sequences".to_string(),
229                    ));
230                }
231            }
232        }
233
234        // Use oauth2 crate's RedirectUrl for validation
235        // This provides URL validation per OAuth 2.1 specifications
236        // For production security, implement exact whitelist matching of allowed URIs
237        RedirectUrl::new(uri.to_string())
238            .map_err(|_| McpError::validation("Failed to create redirect URL".to_string()))
239    }
240
241    /// Get access to the authorization code client
242    #[must_use]
243    pub fn auth_code_client(&self) -> &BasicClient {
244        &self.auth_code_client
245    }
246
247    /// Get access to the client credentials client (if available)
248    #[must_use]
249    pub fn client_credentials_client(&self) -> Option<&BasicClient> {
250        self.client_credentials_client.as_ref()
251    }
252
253    /// Get access to the device code client (if available)
254    #[must_use]
255    pub fn device_code_client(&self) -> Option<&BasicClient> {
256        self.device_code_client.as_ref()
257    }
258
259    /// Get the provider configuration
260    #[must_use]
261    pub fn provider_config(&self) -> &ProviderConfig {
262        &self.provider_config
263    }
264
265    /// Start authorization code flow with PKCE
266    ///
267    /// This initiates the OAuth 2.1 authorization code flow with PKCE (RFC 7636)
268    /// for enhanced security, especially for public clients.
269    ///
270    /// # PKCE Code Verifier Storage (CRITICAL SECURITY REQUIREMENT)
271    ///
272    /// The returned code_verifier MUST be securely stored and associated with the
273    /// state parameter until the authorization code is exchanged for tokens.
274    ///
275    /// **Storage Options (from most to least secure):**
276    ///
277    /// 1. **Server-side encrypted session** (RECOMMENDED for web apps)
278    ///    - Store in server session with HttpOnly, Secure, SameSite=Lax cookies
279    ///    - Associate with state parameter for CSRF protection
280    ///    - Automatic cleanup after exchange or timeout
281    ///
282    /// 2. **Redis/Database with TTL** (RECOMMENDED for distributed systems)
283    ///    - Key: state parameter, Value: encrypted code_verifier
284    ///    - Set TTL to match authorization timeout (typically 10 minutes)
285    ///    - Use server-side encryption at rest
286    ///
287    /// 3. **In-memory for SPAs** (ACCEPTABLE for public clients only)
288    ///    - Store in JavaScript closure or React state (NOT localStorage/sessionStorage)
289    ///    - Clear immediately after token exchange
290    ///    - Risk: XSS can steal verifier
291    ///
292    /// **NEVER:**
293    /// - Store in localStorage or sessionStorage (XSS risk)
294    /// - Send to client in URL or query parameters
295    /// - Log or expose in error messages
296    ///
297    /// # Arguments
298    /// * `scopes` - Requested OAuth scopes
299    /// * `state` - CSRF protection state parameter (use cryptographically random value)
300    ///
301    /// # Returns
302    /// Tuple of (authorization_url, PKCE code_verifier for secure storage)
303    ///
304    /// # Example
305    /// ```ignore
306    /// // Server-side web app (RECOMMENDED)
307    /// let state = generate_csrf_token();  // Cryptographically random
308    /// let (auth_url, code_verifier) = client.authorization_code_flow(scopes, state.clone());
309    ///
310    /// // Store securely server-side
311    /// session.insert("oauth_state", state);
312    /// session.insert("pkce_verifier", code_verifier);  // Encrypted session
313    ///
314    /// // Redirect user
315    /// redirect_to(auth_url);
316    /// ```
317    pub fn authorization_code_flow(&self, scopes: Vec<String>, state: String) -> (String, String) {
318        // Generate PKCE challenge
319        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
320
321        // Build authorization URL with PKCE
322        let (auth_url, _state) = self
323            .auth_code_client
324            .authorize_url(|| oauth2::CsrfToken::new(state))
325            .add_scopes(scopes.into_iter().map(Scope::new))
326            .set_pkce_challenge(pkce_challenge)
327            .url();
328
329        (auth_url.to_string(), pkce_verifier.secret().to_string())
330    }
331
332    /// Exchange authorization code for access token
333    ///
334    /// This exchanges the authorization code received from the OAuth provider
335    /// for an access token using PKCE (RFC 7636).
336    ///
337    /// # Arguments
338    /// * `code` - Authorization code from OAuth provider
339    /// * `code_verifier` - PKCE code verifier (from authorization_code_flow)
340    ///
341    /// # Returns
342    /// TokenInfo containing access token and refresh token (if available)
343    pub async fn exchange_code_for_token(
344        &self,
345        code: String,
346        code_verifier: String,
347    ) -> McpResult<TokenInfo> {
348        let http_client = reqwest::Client::new();
349        let token_response = self
350            .auth_code_client
351            .exchange_code(oauth2::AuthorizationCode::new(code))
352            .set_pkce_verifier(PkceCodeVerifier::new(code_verifier))
353            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
354            .await
355            .map_err(|e| McpError::internal(format!("Token exchange failed: {e}")))?;
356
357        Ok(self.token_response_to_token_info(token_response))
358    }
359
360    /// Refresh an access token with automatic refresh token rotation
361    ///
362    /// This uses a refresh token to obtain a new access token without
363    /// requiring user interaction. OAuth 2.1 and RFC 9700 recommend refresh
364    /// token rotation where the server issues a new refresh token with each
365    /// refresh request.
366    ///
367    /// # Refresh Token Rotation (OAuth 2.1 / RFC 9700 Best Practice)
368    ///
369    /// When the server supports rotation:
370    /// - A new refresh token is returned in the response
371    /// - The old refresh token should be discarded immediately
372    /// - Store and use the new refresh token for future requests
373    /// - This prevents token theft detection
374    ///
375    /// **Important:** Always check if `token_info.refresh_token` is present in
376    /// the response. If present, you MUST replace your stored refresh token
377    /// with the new one. If absent, continue using the current refresh token.
378    ///
379    /// # Arguments
380    /// * `refresh_token` - The current refresh token
381    ///
382    /// # Returns
383    /// New TokenInfo with:
384    /// - Fresh access token (always present)
385    /// - New refresh token (if server supports rotation)
386    ///
387    /// # Example
388    /// ```ignore
389    /// let mut stored_refresh_token = "current_refresh_token";
390    /// let new_tokens = client.refresh_access_token(stored_refresh_token).await?;
391    ///
392    /// // Check for refresh token rotation
393    /// if let Some(new_refresh_token) = &new_tokens.refresh_token {
394    ///     // Server rotated the token - update storage
395    ///     stored_refresh_token = new_refresh_token;
396    ///     println!("Refresh token rotated (security best practice)");
397    /// }
398    /// // Use new access token
399    /// let access_token = new_tokens.access_token;
400    /// ```
401    pub async fn refresh_access_token(&self, refresh_token: &str) -> McpResult<TokenInfo> {
402        let http_client = reqwest::Client::new();
403        let token_response = self
404            .auth_code_client
405            .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
406            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
407            .await
408            .map_err(|e| McpError::internal(format!("Token refresh failed: {e}")))?;
409
410        Ok(self.token_response_to_token_info(token_response))
411    }
412
413    /// Client credentials flow for server-to-server authentication
414    ///
415    /// This implements the OAuth 2.1 Client Credentials flow for
416    /// service-to-service communication without user involvement.
417    ///
418    /// # Arguments
419    /// * `scopes` - Requested OAuth scopes
420    ///
421    /// # Returns
422    /// TokenInfo with access token (typically without refresh token)
423    pub async fn client_credentials_flow(&self, scopes: Vec<String>) -> McpResult<TokenInfo> {
424        let client = self.client_credentials_client.as_ref().ok_or_else(|| {
425            McpError::internal("Client credentials flow requires client secret".to_string())
426        })?;
427
428        let http_client = reqwest::Client::new();
429        let token_response = client
430            .exchange_client_credentials()
431            .add_scopes(scopes.into_iter().map(Scope::new))
432            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
433            .await
434            .map_err(|e| McpError::internal(format!("Client credentials flow failed: {e}")))?;
435
436        Ok(self.token_response_to_token_info(token_response))
437    }
438
439    /// Convert oauth2 token response to TokenInfo
440    fn token_response_to_token_info(
441        &self,
442        response: oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, BasicTokenType>,
443    ) -> TokenInfo {
444        let expires_in = response.expires_in().map(|duration| duration.as_secs());
445
446        TokenInfo {
447            access_token: response.access_token().secret().clone(),
448            token_type: format!("{:?}", response.token_type()),
449            refresh_token: response.refresh_token().map(|t| t.secret().clone()),
450            expires_in,
451            scope: response.scopes().map(|scopes| {
452                scopes
453                    .iter()
454                    .map(|s| s.as_str())
455                    .collect::<Vec<_>>()
456                    .join(" ")
457            }),
458        }
459    }
460
461    /// Revoke a token using RFC 7009 Token Revocation
462    ///
463    /// Per RFC 7009 Section 2, prefer revoking refresh tokens (which MUST be supported
464    /// by the server if issued) over access tokens (which MAY be supported).
465    ///
466    /// # Arguments
467    /// * `token_info` - Token information containing access and/or refresh token
468    ///
469    /// # Returns
470    /// Ok if revocation succeeded or token was already invalid (per RFC 7009)
471    ///
472    /// # Errors
473    /// Returns error if:
474    /// - No revocation endpoint was configured
475    /// - Network/HTTP error occurred
476    /// - Server returned an error response
477    pub async fn revoke_token(&self, token_info: &TokenInfo) -> McpResult<()> {
478        let http_client = reqwest::Client::new();
479
480        // Per RFC 7009 Section 2: Prefer refresh token, fallback to access token
481        let token_to_revoke: StandardRevocableToken =
482            if let Some(ref refresh_token) = token_info.refresh_token {
483                RefreshToken::new(refresh_token.clone()).into()
484            } else {
485                oauth2::AccessToken::new(token_info.access_token.clone()).into()
486            };
487
488        self.auth_code_client
489            .revoke_token(token_to_revoke)
490            .map_err(|e| McpError::internal(format!("Token revocation not configured: {e}")))?
491            .request_async(|request| async { execute_oauth_request(&http_client, request).await })
492            .await
493            .map_err(|e| McpError::internal(format!("Token revocation failed: {e}")))?;
494
495        Ok(())
496    }
497
498    /// Validate that an access token is still valid
499    ///
500    /// This checks if a token has expired based on expiration time.
501    /// Note: This is a client-side check only; servers may have revoked the token.
502    pub fn is_token_expired(&self, token: &TokenInfo) -> bool {
503        if let Some(expires_in) = token.expires_in {
504            // Assume token was valid "now" - in production, store issued_at timestamp
505            expires_in == 0
506        } else {
507            false
508        }
509    }
510}
511
512/// Execute OAuth request using reqwest HTTP client
513/// Converts between oauth2 and reqwest types
514async fn execute_oauth_request(
515    client: &reqwest::Client,
516    request: oauth2::HttpRequest,
517) -> Result<oauth2::HttpResponse, oauth2::reqwest::Error<reqwest::Error>> {
518    let method_str = format!("{}", request.method);
519    let url = request.url.clone();
520
521    // Build the request
522    let mut req_builder = match method_str.to_uppercase().as_str() {
523        "GET" => client.get(url),
524        "POST" => client.post(url),
525        m => {
526            return Err(oauth2::reqwest::Error::Other(format!(
527                "Unsupported HTTP method: {}",
528                m
529            )));
530        }
531    };
532
533    // Add body (always present, even if empty)
534    if !request.body.is_empty() {
535        req_builder = req_builder.body(request.body);
536    }
537
538    // Add headers - convert from oauth2 HeaderName/HeaderValue to reqwest types
539    for (name, value) in &request.headers {
540        let name_str = format!("{:?}", name); // Use debug format for HeaderName
541        // HeaderValue as_bytes should work
542        let value_bytes = value.as_bytes();
543
544        if let (Ok(header_name), Ok(header_value)) = (
545            reqwest::header::HeaderName::from_bytes(name_str.as_bytes()),
546            reqwest::header::HeaderValue::from_bytes(value_bytes),
547        ) {
548            req_builder = req_builder.header(header_name, header_value);
549        }
550    }
551
552    // Send request
553    let response = req_builder
554        .send()
555        .await
556        .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?;
557
558    let status = response.status();
559    let body = response
560        .bytes()
561        .await
562        .map_err(|e| oauth2::reqwest::Error::Other(e.to_string()))?
563        .to_vec();
564
565    // Convert reqwest status code to oauth2 status code
566    let oauth_status =
567        oauth2::http::StatusCode::from_u16(status.as_u16()).unwrap_or(oauth2::http::StatusCode::OK);
568
569    Ok(oauth2::HttpResponse {
570        status_code: oauth_status,
571        body,
572        headers: Default::default(),
573    })
574}