1use anyhow::{Context, Result, anyhow, bail};
4use base64::Engine;
5use base64::engine::general_purpose::URL_SAFE_NO_PAD;
6use reqwest::{Client, Url};
7use ring::rand::{SecureRandom, SystemRandom};
8use serde::{Deserialize, Serialize};
9use std::collections::BTreeMap;
10
11use crate::credentials::{AuthCredentialsStoreMode, CredentialStorage};
12use crate::pkce::{PkceChallenge, generate_pkce_challenge};
13
/// Loopback port used for the OAuth redirect when the config does not set one.
const DEFAULT_CALLBACK_PORT: u16 = 8768;
/// Default login timeout (five minutes), in seconds.
const DEFAULT_FLOW_TIMEOUT_SECS: u64 = 5 * 60;
/// A token is refreshed once it is within this many seconds of `expires_at`.
const REFRESH_SKEW_SECS: u64 = 60;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
21#[serde(default)]
22pub struct McpOAuthConfig {
23 pub authorization_url: String,
25 pub token_url: String,
27 pub client_id: String,
29 #[serde(default)]
31 pub scopes: Vec<String>,
32 #[serde(default)]
34 pub audience: Option<String>,
35 pub callback_port: u16,
37 pub flow_timeout_secs: u64,
39 #[serde(default)]
41 pub credentials_store_mode: AuthCredentialsStoreMode,
42 #[serde(default)]
44 pub extra_auth_params: BTreeMap<String, String>,
45 #[serde(default)]
47 pub extra_token_params: BTreeMap<String, String>,
48}
49
50impl Default for McpOAuthConfig {
51 fn default() -> Self {
52 Self {
53 authorization_url: String::new(),
54 token_url: String::new(),
55 client_id: String::new(),
56 scopes: Vec::new(),
57 audience: None,
58 callback_port: DEFAULT_CALLBACK_PORT,
59 flow_timeout_secs: DEFAULT_FLOW_TIMEOUT_SECS,
60 credentials_store_mode: AuthCredentialsStoreMode::default(),
61 extra_auth_params: BTreeMap::new(),
62 extra_token_params: BTreeMap::new(),
63 }
64 }
65}
66
67impl McpOAuthConfig {
68 pub fn validate(&self, provider_name: &str) -> Result<()> {
69 if self.authorization_url.trim().is_empty() {
70 bail!(
71 "MCP provider '{}' is missing oauth.authorization_url",
72 provider_name
73 );
74 }
75 if self.token_url.trim().is_empty() {
76 bail!(
77 "MCP provider '{}' is missing oauth.token_url",
78 provider_name
79 );
80 }
81 if self.client_id.trim().is_empty() {
82 bail!(
83 "MCP provider '{}' is missing oauth.client_id",
84 provider_name
85 );
86 }
87 Ok(())
88 }
89
90 fn callback_url(&self) -> String {
91 format!("http://localhost:{}/auth/callback", self.callback_port)
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct McpOAuthToken {
98 pub access_token: String,
99 pub refresh_token: Option<String>,
100 pub token_type: Option<String>,
101 pub scope: Option<String>,
102 pub obtained_at: u64,
103 pub expires_at: Option<u64>,
104}
105
106impl McpOAuthToken {
107 pub fn is_refresh_due(&self) -> bool {
108 self.expires_at
109 .is_some_and(|expires_at| now_secs().saturating_add(REFRESH_SKEW_SECS) >= expires_at)
110 }
111}
112
/// Authentication state of a provider as derived from its stored token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpOAuthStatus {
    /// A token is stored.
    Authenticated {
        /// Seconds elapsed since the token was obtained (saturating at 0).
        age_seconds: u64,
        /// Seconds until expiry (saturating at 0), or `None` when the token has no expiry.
        expires_in: Option<u64>,
    },
    /// No token is stored for the provider.
    NotAuthenticated,
}
122
123#[derive(Debug, Clone)]
125pub struct McpOAuthPreparedLogin {
126 pub auth_url: String,
127 pub callback_port: u16,
128 pub timeout_secs: u64,
129 pkce: PkceChallenge,
130 state: String,
131}
132
133impl McpOAuthPreparedLogin {
134 #[must_use]
135 pub fn expected_state(&self) -> &str {
136 &self.state
137 }
138}
139
/// Outcome of a login or logout for one provider.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct McpOAuthLoginCompletion {
    /// Provider name the operation applied to.
    pub name: String,
    /// Whether the operation succeeded.
    pub success: bool,
    /// Failure description when `success` is false.
    pub error: Option<String>,
}
147
/// Stateless entry point for MCP OAuth login, token resolution, status and logout.
#[derive(Clone, Debug, Default)]
pub struct McpOAuthService;
151
152impl McpOAuthService {
153 #[must_use]
154 pub fn new() -> Self {
155 Self
156 }
157
158 pub fn prepare_login(
159 &self,
160 provider_name: &str,
161 config: &McpOAuthConfig,
162 ) -> Result<McpOAuthPreparedLogin> {
163 config.validate(provider_name)?;
164 let pkce = generate_pkce_challenge()?;
165 let state = generate_state()?;
166 let auth_url = build_auth_url(config, &pkce, &state)?;
167 Ok(McpOAuthPreparedLogin {
168 auth_url,
169 callback_port: config.callback_port,
170 timeout_secs: config.flow_timeout_secs,
171 pkce,
172 state,
173 })
174 }
175
176 pub async fn complete_login(
177 &self,
178 provider_name: &str,
179 config: &McpOAuthConfig,
180 prepared: &McpOAuthPreparedLogin,
181 code: &str,
182 ) -> Result<McpOAuthLoginCompletion> {
183 config.validate(provider_name)?;
184 let token = exchange_code_for_token(config, code, &prepared.pkce).await?;
185 save_token(provider_name, &token, config.credentials_store_mode)?;
186 Ok(McpOAuthLoginCompletion {
187 name: provider_name.to_string(),
188 success: true,
189 error: None,
190 })
191 }
192
193 pub fn status(
194 &self,
195 provider_name: &str,
196 storage_mode: AuthCredentialsStoreMode,
197 ) -> Result<McpOAuthStatus> {
198 let Some(token) = load_token(provider_name, storage_mode)? else {
199 return Ok(McpOAuthStatus::NotAuthenticated);
200 };
201 let now = now_secs();
202 Ok(McpOAuthStatus::Authenticated {
203 age_seconds: now.saturating_sub(token.obtained_at),
204 expires_in: token
205 .expires_at
206 .map(|expires_at| expires_at.saturating_sub(now)),
207 })
208 }
209
210 pub fn load_token(
211 &self,
212 provider_name: &str,
213 storage_mode: AuthCredentialsStoreMode,
214 ) -> Result<Option<McpOAuthToken>> {
215 load_token(provider_name, storage_mode)
216 }
217
218 pub async fn resolve_access_token(
219 &self,
220 provider_name: &str,
221 config: &McpOAuthConfig,
222 ) -> Result<Option<String>> {
223 let Some(mut token) = load_token(provider_name, config.credentials_store_mode)? else {
224 return Ok(None);
225 };
226
227 if token.is_refresh_due() {
228 if token.refresh_token.is_some() {
229 token = refresh_token(config, &token).await?;
230 save_token(provider_name, &token, config.credentials_store_mode)?;
231 } else {
232 bail!(
233 "Stored MCP OAuth token for '{}' expired and cannot be refreshed. Run `vtcode mcp login {}` again.",
234 provider_name,
235 provider_name
236 );
237 }
238 }
239
240 Ok(Some(token.access_token))
241 }
242
243 pub fn logout(
244 &self,
245 provider_name: &str,
246 storage_mode: AuthCredentialsStoreMode,
247 ) -> Result<McpOAuthLoginCompletion> {
248 clear_token(provider_name, storage_mode)?;
249 Ok(McpOAuthLoginCompletion {
250 name: provider_name.to_string(),
251 success: true,
252 error: None,
253 })
254 }
255}
256
257fn build_auth_url(
258 config: &McpOAuthConfig,
259 challenge: &PkceChallenge,
260 state: &str,
261) -> Result<String> {
262 let mut url =
263 Url::parse(&config.authorization_url).context("invalid oauth.authorization_url")?;
264 {
265 let mut query = url.query_pairs_mut();
266 query.append_pair("response_type", "code");
267 query.append_pair("client_id", &config.client_id);
268 query.append_pair("redirect_uri", &config.callback_url());
269 query.append_pair("code_challenge", &challenge.code_challenge);
270 query.append_pair("code_challenge_method", &challenge.code_challenge_method);
271 query.append_pair("state", state);
272 if !config.scopes.is_empty() {
273 query.append_pair("scope", &config.scopes.join(" "));
274 }
275 if let Some(audience) = config.audience.as_deref()
276 && !audience.trim().is_empty()
277 {
278 query.append_pair("audience", audience);
279 }
280 for (key, value) in &config.extra_auth_params {
281 if !key.trim().is_empty() {
282 query.append_pair(key, value);
283 }
284 }
285 }
286 Ok(url.to_string())
287}
288
289async fn exchange_code_for_token(
290 config: &McpOAuthConfig,
291 code: &str,
292 challenge: &PkceChallenge,
293) -> Result<McpOAuthToken> {
294 let mut form = vec![
295 ("grant_type".to_string(), "authorization_code".to_string()),
296 ("client_id".to_string(), config.client_id.clone()),
297 ("code".to_string(), code.to_string()),
298 ("redirect_uri".to_string(), config.callback_url()),
299 (
300 "code_verifier".to_string(),
301 challenge.code_verifier.to_string(),
302 ),
303 ];
304 if let Some(audience) = config.audience.as_deref()
305 && !audience.trim().is_empty()
306 {
307 form.push(("audience".to_string(), audience.to_string()));
308 }
309 form.extend(
310 config
311 .extra_token_params
312 .iter()
313 .map(|(key, value)| (key.clone(), value.clone())),
314 );
315 send_token_request(&config.token_url, &form).await
316}
317
318async fn refresh_token(config: &McpOAuthConfig, current: &McpOAuthToken) -> Result<McpOAuthToken> {
319 let refresh_token = current
320 .refresh_token
321 .as_deref()
322 .filter(|value| !value.trim().is_empty())
323 .ok_or_else(|| anyhow!("Stored MCP OAuth token does not include a refresh token"))?;
324 let mut form = vec![
325 ("grant_type".to_string(), "refresh_token".to_string()),
326 ("client_id".to_string(), config.client_id.clone()),
327 ("refresh_token".to_string(), refresh_token.to_string()),
328 ];
329 if let Some(audience) = config.audience.as_deref()
330 && !audience.trim().is_empty()
331 {
332 form.push(("audience".to_string(), audience.to_string()));
333 }
334 form.extend(
335 config
336 .extra_token_params
337 .iter()
338 .map(|(key, value)| (key.clone(), value.clone())),
339 );
340
341 let refreshed = send_token_request(&config.token_url, &form).await?;
342 Ok(McpOAuthToken {
343 refresh_token: refreshed
344 .refresh_token
345 .or_else(|| current.refresh_token.clone()),
346 ..refreshed
347 })
348}
349
350async fn send_token_request(token_url: &str, form: &[(String, String)]) -> Result<McpOAuthToken> {
351 let response = Client::new()
352 .post(token_url)
353 .header("Content-Type", "application/x-www-form-urlencoded")
354 .form(form)
355 .send()
356 .await
357 .with_context(|| format!("failed to send MCP OAuth request to {token_url}"))?;
358 let status = response.status();
359 let body = response
360 .text()
361 .await
362 .context("failed to read MCP OAuth response body")?;
363
364 if !status.is_success() {
365 bail!("MCP OAuth request failed (HTTP {}): {}", status, body);
366 }
367
368 let payload: TokenResponse =
369 serde_json::from_str(&body).context("failed to parse MCP OAuth token response")?;
370 let now = now_secs();
371 Ok(McpOAuthToken {
372 access_token: payload.access_token,
373 refresh_token: payload.refresh_token,
374 token_type: payload.token_type,
375 scope: payload.scope,
376 obtained_at: now,
377 expires_at: payload.expires_in.map(|secs| now.saturating_add(secs)),
378 })
379}
380
381#[derive(Debug, Deserialize)]
382struct TokenResponse {
383 access_token: String,
384 #[serde(default)]
385 refresh_token: Option<String>,
386 #[serde(default)]
387 token_type: Option<String>,
388 #[serde(default)]
389 scope: Option<String>,
390 #[serde(default)]
391 expires_in: Option<u64>,
392}
393
394fn generate_state() -> Result<String> {
395 let mut state_bytes = [0_u8; 32];
396 SystemRandom::new()
397 .fill(&mut state_bytes)
398 .map_err(|_| anyhow!("failed to generate MCP OAuth state"))?;
399 Ok(URL_SAFE_NO_PAD.encode(state_bytes))
400}
401
402fn save_token(
403 provider_name: &str,
404 token: &McpOAuthToken,
405 storage_mode: AuthCredentialsStoreMode,
406) -> Result<()> {
407 let serialized = serde_json::to_string(token).context("failed to serialize MCP OAuth token")?;
408 token_storage(provider_name).store_with_mode(&serialized, storage_mode)
409}
410
411fn load_token(
412 provider_name: &str,
413 storage_mode: AuthCredentialsStoreMode,
414) -> Result<Option<McpOAuthToken>> {
415 let Some(serialized) = token_storage(provider_name).load_with_mode(storage_mode)? else {
416 return Ok(None);
417 };
418 serde_json::from_str(&serialized)
419 .context("failed to parse stored MCP OAuth token")
420 .map(Some)
421}
422
423fn clear_token(provider_name: &str, storage_mode: AuthCredentialsStoreMode) -> Result<()> {
424 token_storage(provider_name).clear_with_mode(storage_mode)
425}
426
427fn token_storage(provider_name: &str) -> CredentialStorage {
428 let normalized_provider = provider_name
429 .chars()
430 .map(|ch| {
431 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
432 ch
433 } else {
434 '_'
435 }
436 })
437 .collect::<String>();
438 CredentialStorage::new("vtcode", format!("mcp_oauth_{normalized_provider}"))
439}
440
/// Current Unix time in whole seconds; 0 if the clock reads before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::TempDir;
    use serial_test::serial;
    use std::path::PathBuf;

    /// Redirects auth storage to a fresh temp dir for the guard's lifetime and
    /// restores the previous override on drop.
    struct TestAuthDirGuard {
        previous: Option<PathBuf>,
        temp_dir: Option<TempDir>,
    }

    impl TestAuthDirGuard {
        fn new() -> Self {
            let temp_dir = TempDir::new().expect("temp dir");
            let previous = crate::storage_paths::auth_storage_dir_override_for_tests()
                .expect("read previous auth dir override");
            crate::storage_paths::set_auth_storage_dir_override_for_tests(Some(
                temp_dir.path().to_path_buf(),
            ))
            .expect("set auth dir override");
            Self {
                previous,
                temp_dir: Some(temp_dir),
            }
        }
    }

    impl Drop for TestAuthDirGuard {
        fn drop(&mut self) {
            crate::storage_paths::set_auth_storage_dir_override_for_tests(self.previous.clone())
                .expect("restore auth dir override");
            if let Some(temp_dir) = self.temp_dir.take() {
                let _ = temp_dir.close();
            }
        }
    }

    fn sample_config() -> McpOAuthConfig {
        McpOAuthConfig {
            authorization_url: "https://example.com/oauth/authorize".to_string(),
            token_url: "https://example.com/oauth/token".to_string(),
            client_id: "client-123".to_string(),
            scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
            audience: Some("mcp-api".to_string()),
            callback_port: 8123,
            flow_timeout_secs: 120,
            credentials_store_mode: AuthCredentialsStoreMode::File,
            extra_auth_params: BTreeMap::from([("prompt".to_string(), "consent".to_string())]),
            extra_token_params: BTreeMap::new(),
        }
    }

    #[test]
    fn prepare_login_builds_expected_auth_url() {
        let service = McpOAuthService::new();
        let prepared = service
            .prepare_login("demo", &sample_config())
            .expect("prepare login");

        assert!(prepared.auth_url.contains("response_type=code"));
        assert!(prepared.auth_url.contains("client_id=client-123"));
        assert!(prepared.auth_url.contains("scope=mcp%3Aread+mcp%3Awrite"));
        assert!(prepared.auth_url.contains("audience=mcp-api"));
        assert!(prepared.auth_url.contains("prompt=consent"));
        assert!(prepared.auth_url.contains("code_challenge="));
        assert!(prepared.auth_url.contains("state="));
        assert_eq!(prepared.callback_port, 8123);
        assert_eq!(prepared.timeout_secs, 120);
    }

    #[test]
    fn validate_rejects_blank_client_id() {
        let config = McpOAuthConfig {
            client_id: "   ".to_string(),
            ..sample_config()
        };
        let err = config
            .validate("demo")
            .expect_err("blank client_id must be rejected");
        assert!(err.to_string().contains("oauth.client_id"));
    }

    #[test]
    fn refresh_due_tracks_expiry_window() {
        let token_expiring_at = |expires_at: Option<u64>| McpOAuthToken {
            access_token: "access".to_string(),
            refresh_token: None,
            token_type: None,
            scope: None,
            obtained_at: now_secs(),
            expires_at,
        };
        // No expiry: never due.
        assert!(!token_expiring_at(None).is_refresh_due());
        // Far from expiry: not due.
        assert!(!token_expiring_at(Some(now_secs() + 3600)).is_refresh_due());
        // Inside the skew window: due.
        assert!(token_expiring_at(Some(now_secs() + REFRESH_SKEW_SECS / 2)).is_refresh_due());
        // Already expired: due.
        assert!(token_expiring_at(Some(0)).is_refresh_due());
    }

    #[test]
    #[serial]
    fn status_reflects_stored_token() {
        let _guard = TestAuthDirGuard::new();
        let service = McpOAuthService::new();
        let storage_mode = AuthCredentialsStoreMode::File;
        assert!(matches!(
            service.status("demo", storage_mode).expect("status"),
            McpOAuthStatus::NotAuthenticated
        ));

        save_token(
            "demo",
            &McpOAuthToken {
                access_token: "access".to_string(),
                refresh_token: Some("refresh".to_string()),
                token_type: Some("Bearer".to_string()),
                scope: Some("mcp:read".to_string()),
                obtained_at: now_secs(),
                expires_at: Some(now_secs() + 3600),
            },
            storage_mode,
        )
        .expect("save token");

        let status = service.status("demo", storage_mode).expect("status");
        assert!(matches!(
            status,
            McpOAuthStatus::Authenticated {
                expires_in: Some(_),
                ..
            }
        ));
    }

    #[test]
    #[serial]
    fn logout_clears_stored_token() {
        let _guard = TestAuthDirGuard::new();
        let service = McpOAuthService::new();
        let storage_mode = AuthCredentialsStoreMode::File;
        save_token(
            "demo",
            &McpOAuthToken {
                access_token: "access".to_string(),
                refresh_token: None,
                token_type: Some("Bearer".to_string()),
                scope: None,
                obtained_at: now_secs(),
                expires_at: None,
            },
            storage_mode,
        )
        .expect("save token");

        service.logout("demo", storage_mode).expect("logout");
        assert!(load_token("demo", storage_mode).expect("load").is_none());
    }
}