Skip to main content

vtcode_auth/
mcp_oauth.rs

1//! OAuth support for HTTP MCP providers.
2
3use anyhow::{Context, Result, anyhow, bail};
4use base64::Engine;
5use base64::engine::general_purpose::URL_SAFE_NO_PAD;
6use reqwest::{Client, Url};
7use ring::rand::{SecureRandom, SystemRandom};
8use serde::{Deserialize, Serialize};
9use std::collections::BTreeMap;
10
11use crate::credentials::{AuthCredentialsStoreMode, CredentialStorage};
12use crate::pkce::{PkceChallenge, generate_pkce_challenge};
13
/// Callback port used by `McpOAuthConfig::default()`.
const DEFAULT_CALLBACK_PORT: u16 = 8768;
/// Browser-flow timeout, in seconds, used by `McpOAuthConfig::default()` (five minutes).
const DEFAULT_FLOW_TIMEOUT_SECS: u64 = 5 * 60;
/// Lead time, in seconds, before `expires_at` at which a token is reported as due for refresh.
const REFRESH_SKEW_SECS: u64 = 60;
17
18/// Configuration for OAuth-enabled MCP HTTP providers.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
21#[serde(default)]
22pub struct McpOAuthConfig {
23    /// OAuth authorization endpoint.
24    pub authorization_url: String,
25    /// OAuth token endpoint.
26    pub token_url: String,
27    /// OAuth client identifier.
28    pub client_id: String,
29    /// Requested scopes.
30    #[serde(default)]
31    pub scopes: Vec<String>,
32    /// Optional audience/resource hint sent with the auth and token requests.
33    #[serde(default)]
34    pub audience: Option<String>,
35    /// Local callback server port.
36    pub callback_port: u16,
37    /// Browser-flow timeout in seconds.
38    pub flow_timeout_secs: u64,
39    /// Credential storage backend for this provider's token.
40    #[serde(default)]
41    pub credentials_store_mode: AuthCredentialsStoreMode,
42    /// Extra query parameters appended to the authorization URL.
43    #[serde(default)]
44    pub extra_auth_params: BTreeMap<String, String>,
45    /// Extra form fields appended to token exchanges and refreshes.
46    #[serde(default)]
47    pub extra_token_params: BTreeMap<String, String>,
48}
49
50impl Default for McpOAuthConfig {
51    fn default() -> Self {
52        Self {
53            authorization_url: String::new(),
54            token_url: String::new(),
55            client_id: String::new(),
56            scopes: Vec::new(),
57            audience: None,
58            callback_port: DEFAULT_CALLBACK_PORT,
59            flow_timeout_secs: DEFAULT_FLOW_TIMEOUT_SECS,
60            credentials_store_mode: AuthCredentialsStoreMode::default(),
61            extra_auth_params: BTreeMap::new(),
62            extra_token_params: BTreeMap::new(),
63        }
64    }
65}
66
67impl McpOAuthConfig {
68    pub fn validate(&self, provider_name: &str) -> Result<()> {
69        if self.authorization_url.trim().is_empty() {
70            bail!(
71                "MCP provider '{}' is missing oauth.authorization_url",
72                provider_name
73            );
74        }
75        if self.token_url.trim().is_empty() {
76            bail!(
77                "MCP provider '{}' is missing oauth.token_url",
78                provider_name
79            );
80        }
81        if self.client_id.trim().is_empty() {
82            bail!(
83                "MCP provider '{}' is missing oauth.client_id",
84                provider_name
85            );
86        }
87        Ok(())
88    }
89
90    fn callback_url(&self) -> String {
91        format!("http://localhost:{}/auth/callback", self.callback_port)
92    }
93}
94
95/// Stored OAuth token for an MCP HTTP provider.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpOAuthToken {
98    pub access_token: String,
99    pub refresh_token: Option<String>,
100    pub token_type: Option<String>,
101    pub scope: Option<String>,
102    pub obtained_at: u64,
103    pub expires_at: Option<u64>,
104}
105
106impl McpOAuthToken {
107    pub fn is_refresh_due(&self) -> bool {
108        self.expires_at
109            .is_some_and(|expires_at| now_secs().saturating_add(REFRESH_SKEW_SECS) >= expires_at)
110    }
111}
112
/// Authentication state derived from an MCP provider's stored OAuth token.
#[derive(Debug, Clone)]
pub enum McpOAuthStatus {
    /// A token is stored for the provider.
    Authenticated {
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds left until expiry, when the token carries an expiry time.
        expires_in: Option<u64>,
    },
    /// No token is stored for the provider.
    NotAuthenticated,
}
122
123/// Prepared browser-login flow for an MCP OAuth provider.
124#[derive(Debug, Clone)]
125pub struct McpOAuthPreparedLogin {
126    pub auth_url: String,
127    pub callback_port: u16,
128    pub timeout_secs: u64,
129    pkce: PkceChallenge,
130    state: String,
131}
132
133impl McpOAuthPreparedLogin {
134    #[must_use]
135    pub fn expected_state(&self) -> &str {
136        &self.state
137    }
138}
139
/// Outcome of a login or logout operation.
///
/// The shape is intentionally kept close to the Codex app-server payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpOAuthLoginCompletion {
    /// Name of the MCP provider the operation applied to.
    pub name: String,
    /// Whether the operation succeeded.
    pub success: bool,
    /// Error description, if any.
    pub error: Option<String>,
}
147
148/// Service for loading, refreshing, and persisting MCP OAuth tokens.
149#[derive(Debug, Clone, Default)]
150pub struct McpOAuthService;
151
152impl McpOAuthService {
153    #[must_use]
154    pub fn new() -> Self {
155        Self
156    }
157
158    pub fn prepare_login(
159        &self,
160        provider_name: &str,
161        config: &McpOAuthConfig,
162    ) -> Result<McpOAuthPreparedLogin> {
163        config.validate(provider_name)?;
164        let pkce = generate_pkce_challenge()?;
165        let state = generate_state()?;
166        let auth_url = build_auth_url(config, &pkce, &state)?;
167        Ok(McpOAuthPreparedLogin {
168            auth_url,
169            callback_port: config.callback_port,
170            timeout_secs: config.flow_timeout_secs,
171            pkce,
172            state,
173        })
174    }
175
176    pub async fn complete_login(
177        &self,
178        provider_name: &str,
179        config: &McpOAuthConfig,
180        prepared: &McpOAuthPreparedLogin,
181        code: &str,
182    ) -> Result<McpOAuthLoginCompletion> {
183        config.validate(provider_name)?;
184        let token = exchange_code_for_token(config, code, &prepared.pkce).await?;
185        save_token(provider_name, &token, config.credentials_store_mode)?;
186        Ok(McpOAuthLoginCompletion {
187            name: provider_name.to_string(),
188            success: true,
189            error: None,
190        })
191    }
192
193    pub fn status(
194        &self,
195        provider_name: &str,
196        storage_mode: AuthCredentialsStoreMode,
197    ) -> Result<McpOAuthStatus> {
198        let Some(token) = load_token(provider_name, storage_mode)? else {
199            return Ok(McpOAuthStatus::NotAuthenticated);
200        };
201        let now = now_secs();
202        Ok(McpOAuthStatus::Authenticated {
203            age_seconds: now.saturating_sub(token.obtained_at),
204            expires_in: token
205                .expires_at
206                .map(|expires_at| expires_at.saturating_sub(now)),
207        })
208    }
209
210    pub fn load_token(
211        &self,
212        provider_name: &str,
213        storage_mode: AuthCredentialsStoreMode,
214    ) -> Result<Option<McpOAuthToken>> {
215        load_token(provider_name, storage_mode)
216    }
217
218    pub async fn resolve_access_token(
219        &self,
220        provider_name: &str,
221        config: &McpOAuthConfig,
222    ) -> Result<Option<String>> {
223        let Some(mut token) = load_token(provider_name, config.credentials_store_mode)? else {
224            return Ok(None);
225        };
226
227        if token.is_refresh_due() {
228            if token.refresh_token.is_some() {
229                token = refresh_token(config, &token).await?;
230                save_token(provider_name, &token, config.credentials_store_mode)?;
231            } else {
232                bail!(
233                    "Stored MCP OAuth token for '{}' expired and cannot be refreshed. Run `vtcode mcp login {}` again.",
234                    provider_name,
235                    provider_name
236                );
237            }
238        }
239
240        Ok(Some(token.access_token))
241    }
242
243    pub fn logout(
244        &self,
245        provider_name: &str,
246        storage_mode: AuthCredentialsStoreMode,
247    ) -> Result<McpOAuthLoginCompletion> {
248        clear_token(provider_name, storage_mode)?;
249        Ok(McpOAuthLoginCompletion {
250            name: provider_name.to_string(),
251            success: true,
252            error: None,
253        })
254    }
255}
256
257fn build_auth_url(
258    config: &McpOAuthConfig,
259    challenge: &PkceChallenge,
260    state: &str,
261) -> Result<String> {
262    let mut url =
263        Url::parse(&config.authorization_url).context("invalid oauth.authorization_url")?;
264    {
265        let mut query = url.query_pairs_mut();
266        query.append_pair("response_type", "code");
267        query.append_pair("client_id", &config.client_id);
268        query.append_pair("redirect_uri", &config.callback_url());
269        query.append_pair("code_challenge", &challenge.code_challenge);
270        query.append_pair("code_challenge_method", &challenge.code_challenge_method);
271        query.append_pair("state", state);
272        if !config.scopes.is_empty() {
273            query.append_pair("scope", &config.scopes.join(" "));
274        }
275        if let Some(audience) = config.audience.as_deref()
276            && !audience.trim().is_empty()
277        {
278            query.append_pair("audience", audience);
279        }
280        for (key, value) in &config.extra_auth_params {
281            if !key.trim().is_empty() {
282                query.append_pair(key, value);
283            }
284        }
285    }
286    Ok(url.to_string())
287}
288
289async fn exchange_code_for_token(
290    config: &McpOAuthConfig,
291    code: &str,
292    challenge: &PkceChallenge,
293) -> Result<McpOAuthToken> {
294    let mut form = vec![
295        ("grant_type".to_string(), "authorization_code".to_string()),
296        ("client_id".to_string(), config.client_id.clone()),
297        ("code".to_string(), code.to_string()),
298        ("redirect_uri".to_string(), config.callback_url()),
299        (
300            "code_verifier".to_string(),
301            challenge.code_verifier.to_string(),
302        ),
303    ];
304    if let Some(audience) = config.audience.as_deref()
305        && !audience.trim().is_empty()
306    {
307        form.push(("audience".to_string(), audience.to_string()));
308    }
309    form.extend(
310        config
311            .extra_token_params
312            .iter()
313            .map(|(key, value)| (key.clone(), value.clone())),
314    );
315    send_token_request(&config.token_url, &form).await
316}
317
318async fn refresh_token(config: &McpOAuthConfig, current: &McpOAuthToken) -> Result<McpOAuthToken> {
319    let refresh_token = current
320        .refresh_token
321        .as_deref()
322        .filter(|value| !value.trim().is_empty())
323        .ok_or_else(|| anyhow!("Stored MCP OAuth token does not include a refresh token"))?;
324    let mut form = vec![
325        ("grant_type".to_string(), "refresh_token".to_string()),
326        ("client_id".to_string(), config.client_id.clone()),
327        ("refresh_token".to_string(), refresh_token.to_string()),
328    ];
329    if let Some(audience) = config.audience.as_deref()
330        && !audience.trim().is_empty()
331    {
332        form.push(("audience".to_string(), audience.to_string()));
333    }
334    form.extend(
335        config
336            .extra_token_params
337            .iter()
338            .map(|(key, value)| (key.clone(), value.clone())),
339    );
340
341    let refreshed = send_token_request(&config.token_url, &form).await?;
342    Ok(McpOAuthToken {
343        refresh_token: refreshed
344            .refresh_token
345            .or_else(|| current.refresh_token.clone()),
346        ..refreshed
347    })
348}
349
350async fn send_token_request(token_url: &str, form: &[(String, String)]) -> Result<McpOAuthToken> {
351    let response = Client::new()
352        .post(token_url)
353        .header("Content-Type", "application/x-www-form-urlencoded")
354        .form(form)
355        .send()
356        .await
357        .with_context(|| format!("failed to send MCP OAuth request to {token_url}"))?;
358    let status = response.status();
359    let body = response
360        .text()
361        .await
362        .context("failed to read MCP OAuth response body")?;
363
364    if !status.is_success() {
365        bail!("MCP OAuth request failed (HTTP {}): {}", status, body);
366    }
367
368    let payload: TokenResponse =
369        serde_json::from_str(&body).context("failed to parse MCP OAuth token response")?;
370    let now = now_secs();
371    Ok(McpOAuthToken {
372        access_token: payload.access_token,
373        refresh_token: payload.refresh_token,
374        token_type: payload.token_type,
375        scope: payload.scope,
376        obtained_at: now,
377        expires_at: payload.expires_in.map(|secs| now.saturating_add(secs)),
378    })
379}
380
381#[derive(Debug, Deserialize)]
382struct TokenResponse {
383    access_token: String,
384    #[serde(default)]
385    refresh_token: Option<String>,
386    #[serde(default)]
387    token_type: Option<String>,
388    #[serde(default)]
389    scope: Option<String>,
390    #[serde(default)]
391    expires_in: Option<u64>,
392}
393
394fn generate_state() -> Result<String> {
395    let mut state_bytes = [0_u8; 32];
396    SystemRandom::new()
397        .fill(&mut state_bytes)
398        .map_err(|_| anyhow!("failed to generate MCP OAuth state"))?;
399    Ok(URL_SAFE_NO_PAD.encode(state_bytes))
400}
401
402fn save_token(
403    provider_name: &str,
404    token: &McpOAuthToken,
405    storage_mode: AuthCredentialsStoreMode,
406) -> Result<()> {
407    let serialized = serde_json::to_string(token).context("failed to serialize MCP OAuth token")?;
408    token_storage(provider_name).store_with_mode(&serialized, storage_mode)
409}
410
411fn load_token(
412    provider_name: &str,
413    storage_mode: AuthCredentialsStoreMode,
414) -> Result<Option<McpOAuthToken>> {
415    let Some(serialized) = token_storage(provider_name).load_with_mode(storage_mode)? else {
416        return Ok(None);
417    };
418    serde_json::from_str(&serialized)
419        .context("failed to parse stored MCP OAuth token")
420        .map(Some)
421}
422
423fn clear_token(provider_name: &str, storage_mode: AuthCredentialsStoreMode) -> Result<()> {
424    token_storage(provider_name).clear_with_mode(storage_mode)
425}
426
427fn token_storage(provider_name: &str) -> CredentialStorage {
428    let normalized_provider = provider_name
429        .chars()
430        .map(|ch| {
431            if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
432                ch
433            } else {
434                '_'
435            }
436        })
437        .collect::<String>();
438    CredentialStorage::new("vtcode", format!("mcp_oauth_{normalized_provider}"))
439}
440
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::TempDir;
    use serial_test::serial;
    use std::path::PathBuf;

    /// Points auth storage at a fresh temp directory and restores the previous
    /// override (then removes the directory) when dropped.
    struct TestAuthDirGuard {
        previous: Option<PathBuf>,
        temp_dir: Option<TempDir>,
    }

    impl TestAuthDirGuard {
        fn new() -> Self {
            let temp_dir = TempDir::new().expect("temp dir");
            let previous = crate::storage_paths::auth_storage_dir_override_for_tests()
                .expect("read previous auth dir override");
            let override_dir = temp_dir.path().to_path_buf();
            crate::storage_paths::set_auth_storage_dir_override_for_tests(Some(override_dir))
                .expect("set auth dir override");
            TestAuthDirGuard {
                previous,
                temp_dir: Some(temp_dir),
            }
        }
    }

    impl Drop for TestAuthDirGuard {
        fn drop(&mut self) {
            let restored = self.previous.clone();
            crate::storage_paths::set_auth_storage_dir_override_for_tests(restored)
                .expect("restore auth dir override");
            if let Some(temp_dir) = self.temp_dir.take() {
                // Best-effort cleanup; a failure here must not fail the test.
                let _ = temp_dir.close();
            }
        }
    }

    /// Fully populated configuration shared by the tests.
    fn sample_config() -> McpOAuthConfig {
        let scopes = ["mcp:read", "mcp:write"].map(String::from).to_vec();
        let mut extra_auth_params = BTreeMap::new();
        extra_auth_params.insert("prompt".to_string(), "consent".to_string());

        McpOAuthConfig {
            authorization_url: "https://example.com/oauth/authorize".to_string(),
            token_url: "https://example.com/oauth/token".to_string(),
            client_id: "client-123".to_string(),
            scopes,
            audience: Some("mcp-api".to_string()),
            callback_port: 8123,
            flow_timeout_secs: 120,
            credentials_store_mode: AuthCredentialsStoreMode::File,
            extra_auth_params,
            extra_token_params: BTreeMap::new(),
        }
    }

    /// Bearer token obtained "now", with the given optional parts.
    fn bearer_token(
        refresh_token: Option<&str>,
        scope: Option<&str>,
        lifetime_secs: Option<u64>,
    ) -> McpOAuthToken {
        McpOAuthToken {
            access_token: "access".to_string(),
            refresh_token: refresh_token.map(str::to_string),
            token_type: Some("Bearer".to_string()),
            scope: scope.map(str::to_string),
            obtained_at: now_secs(),
            expires_at: lifetime_secs.map(|secs| now_secs() + secs),
        }
    }

    #[test]
    fn prepare_login_builds_expected_auth_url() {
        let prepared = McpOAuthService::new()
            .prepare_login("demo", &sample_config())
            .expect("prepare login");

        let expected_fragments = [
            "response_type=code",
            "client_id=client-123",
            "scope=mcp%3Aread+mcp%3Awrite",
            "audience=mcp-api",
            "prompt=consent",
            "code_challenge=",
            "state=",
        ];
        for fragment in expected_fragments {
            assert!(prepared.auth_url.contains(fragment));
        }
        assert_eq!(prepared.callback_port, 8123);
        assert_eq!(prepared.timeout_secs, 120);
    }

    #[test]
    #[serial]
    fn status_reflects_stored_token() {
        let _guard = TestAuthDirGuard::new();
        let service = McpOAuthService::new();
        let storage_mode = AuthCredentialsStoreMode::File;

        let before = service.status("demo", storage_mode).expect("status");
        assert!(matches!(before, McpOAuthStatus::NotAuthenticated));

        let token = bearer_token(Some("refresh"), Some("mcp:read"), Some(3600));
        save_token("demo", &token, storage_mode).expect("save token");

        let after = service.status("demo", storage_mode).expect("status");
        assert!(matches!(
            after,
            McpOAuthStatus::Authenticated {
                expires_in: Some(_),
                ..
            }
        ));
    }

    #[test]
    #[serial]
    fn logout_clears_stored_token() {
        let _guard = TestAuthDirGuard::new();
        let service = McpOAuthService::new();
        let storage_mode = AuthCredentialsStoreMode::File;

        let token = bearer_token(None, None, None);
        save_token("demo", &token, storage_mode).expect("save token");

        service.logout("demo", storage_mode).expect("logout");
        let remaining = load_token("demo", storage_mode).expect("load");
        assert!(remaining.is_none());
    }
}