1use anyhow::{Context, Result, anyhow};
32use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
33use ring::rand::{SecureRandom, SystemRandom};
34use serde::{Deserialize, Serialize};
35use std::fs;
36use std::path::PathBuf;
37
38pub use super::credentials::AuthCredentialsStoreMode;
39use super::pkce::PkceChallenge;
40
41const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
43const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";
44
45pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(default)]
51pub struct OpenRouterOAuthConfig {
52 pub use_oauth: bool,
54 pub callback_port: u16,
56 pub auto_refresh: bool,
58}
59
60impl Default for OpenRouterOAuthConfig {
61 fn default() -> Self {
62 Self {
63 use_oauth: false,
64 callback_port: DEFAULT_CALLBACK_PORT,
65 auto_refresh: true,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct OpenRouterToken {
73 pub api_key: String,
75 pub obtained_at: u64,
77 pub expires_at: Option<u64>,
79 pub label: Option<String>,
81}
82
83impl OpenRouterToken {
84 pub fn is_expired(&self) -> bool {
86 if let Some(expires_at) = self.expires_at {
87 let now = std::time::SystemTime::now()
88 .duration_since(std::time::UNIX_EPOCH)
89 .map(|d| d.as_secs())
90 .unwrap_or(0);
91 now >= expires_at
92 } else {
93 false
94 }
95 }
96}
97
98#[derive(Debug, Serialize, Deserialize)]
100struct EncryptedToken {
101 nonce: String,
103 ciphertext: String,
105 version: u8,
107}
108
109pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
118 let callback_url = format!("http://localhost:{}/callback", callback_port);
119 format!(
120 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
121 OPENROUTER_AUTH_URL,
122 urlencoding::encode(&callback_url),
123 urlencoding::encode(&challenge.code_challenge),
124 challenge.code_challenge_method
125 )
126}
127
128pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
140 let client = reqwest::Client::new();
141
142 let payload = serde_json::json!({
143 "code": code,
144 "code_verifier": challenge.code_verifier,
145 "code_challenge_method": challenge.code_challenge_method
146 });
147
148 let response = client
149 .post(OPENROUTER_KEYS_URL)
150 .header("Content-Type", "application/json")
151 .json(&payload)
152 .send()
153 .await
154 .context("Failed to send token exchange request")?;
155
156 let status = response.status();
157 let body = response
158 .text()
159 .await
160 .context("Failed to read response body")?;
161
162 if !status.is_success() {
163 if status.as_u16() == 400 {
165 return Err(anyhow!(
166 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
167 ));
168 } else if status.as_u16() == 403 {
169 return Err(anyhow!(
170 "Invalid code or code_verifier. The authorization code may have expired."
171 ));
172 } else if status.as_u16() == 405 {
173 return Err(anyhow!(
174 "Method not allowed. Ensure you're using POST over HTTPS."
175 ));
176 }
177 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
178 }
179
180 let response_json: serde_json::Value =
182 serde_json::from_str(&body).context("Failed to parse token response")?;
183
184 let api_key = response_json
185 .get("key")
186 .and_then(|v| v.as_str())
187 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
188 .to_string();
189
190 Ok(api_key)
191}
192
193fn get_token_path() -> Result<PathBuf> {
195 let vtcode_dir = dirs::home_dir()
196 .ok_or_else(|| anyhow!("Could not determine home directory"))?
197 .join(".vtcode")
198 .join("auth");
199
200 fs::create_dir_all(&vtcode_dir).context("Failed to create auth directory")?;
201
202 Ok(vtcode_dir.join("openrouter.json"))
203}
204
205fn derive_encryption_key() -> Result<LessSafeKey> {
207 use ring::digest::{SHA256, digest};
208
209 let mut key_material = Vec::new();
211
212 if let Ok(hostname) = hostname::get() {
214 key_material.extend_from_slice(hostname.as_encoded_bytes());
215 }
216
217 #[cfg(unix)]
219 {
220 key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
221 }
222 #[cfg(not(unix))]
223 {
224 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
225 key_material.extend_from_slice(user.as_bytes());
226 }
227 }
228
229 key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");
231
232 let hash = digest(&SHA256, &key_material);
234 let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;
235
236 let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
237 .map_err(|_| anyhow!("Invalid key length"))?;
238
239 Ok(LessSafeKey::new(unbound_key))
240}
241
242fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
244 let key = derive_encryption_key()?;
245 let rng = SystemRandom::new();
246
247 let mut nonce_bytes = [0u8; NONCE_LEN];
249 rng.fill(&mut nonce_bytes)
250 .map_err(|_| anyhow!("Failed to generate nonce"))?;
251
252 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
254
255 let mut ciphertext = plaintext;
257 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
258 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
259 .map_err(|_| anyhow!("Encryption failed"))?;
260
261 use base64::{Engine, engine::general_purpose::STANDARD};
262
263 Ok(EncryptedToken {
264 nonce: STANDARD.encode(nonce_bytes),
265 ciphertext: STANDARD.encode(&ciphertext),
266 version: 1,
267 })
268}
269
270fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
272 if encrypted.version != 1 {
273 return Err(anyhow!(
274 "Unsupported token format version: {}",
275 encrypted.version
276 ));
277 }
278
279 use base64::{Engine, engine::general_purpose::STANDARD};
280
281 let key = derive_encryption_key()?;
282
283 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
284 .decode(&encrypted.nonce)
285 .context("Invalid nonce encoding")?
286 .try_into()
287 .map_err(|_| anyhow!("Invalid nonce length"))?;
288
289 let mut ciphertext = STANDARD
290 .decode(&encrypted.ciphertext)
291 .context("Invalid ciphertext encoding")?;
292
293 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
294 let plaintext = key
295 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
296 .map_err(|_| {
297 anyhow!("Decryption failed - token may be corrupted or from different machine")
298 })?;
299
300 serde_json::from_slice(plaintext).context("Failed to deserialize token")
301}
302
303pub fn save_oauth_token_with_mode(
309 token: &OpenRouterToken,
310 mode: AuthCredentialsStoreMode,
311) -> Result<()> {
312 let effective_mode = mode.effective_mode();
313
314 match effective_mode {
315 AuthCredentialsStoreMode::Keyring => save_oauth_token_keyring(token),
316 AuthCredentialsStoreMode::File => save_oauth_token_file(token),
317 _ => unreachable!(),
318 }
319}
320
321fn save_oauth_token_keyring(token: &OpenRouterToken) -> Result<()> {
323 let entry =
324 keyring::Entry::new("vtcode", "openrouter_oauth").context("Failed to access OS keyring")?;
325
326 let token_json =
328 serde_json::to_string(token).context("Failed to serialize token for keyring")?;
329
330 entry
331 .set_password(&token_json)
332 .context("Failed to store token in OS keyring")?;
333
334 tracing::info!("OAuth token saved to OS keyring");
335 Ok(())
336}
337
338fn save_oauth_token_file(token: &OpenRouterToken) -> Result<()> {
340 let path = get_token_path()?;
341 let encrypted = encrypt_token(token)?;
342 let json =
343 serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
344
345 fs::write(&path, json).context("Failed to write token file")?;
346
347 #[cfg(unix)]
349 {
350 use std::os::unix::fs::PermissionsExt;
351 let perms = std::fs::Permissions::from_mode(0o600);
352 fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
353 }
354
355 tracing::info!("OAuth token saved to {}", path.display());
356 Ok(())
357}
358
359pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
364 save_oauth_token_with_mode(token, AuthCredentialsStoreMode::default())
365}
366
367pub fn load_oauth_token_with_mode(
371 mode: AuthCredentialsStoreMode,
372) -> Result<Option<OpenRouterToken>> {
373 let effective_mode = mode.effective_mode();
374
375 match effective_mode {
376 AuthCredentialsStoreMode::Keyring => load_oauth_token_keyring(),
377 AuthCredentialsStoreMode::File => load_oauth_token_file(),
378 _ => unreachable!(),
379 }
380}
381
382fn load_oauth_token_keyring() -> Result<Option<OpenRouterToken>> {
384 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
385 Ok(e) => e,
386 Err(_) => return Ok(None),
387 };
388
389 let token_json = match entry.get_password() {
390 Ok(json) => json,
391 Err(keyring::Error::NoEntry) => return Ok(None),
392 Err(e) => return Err(anyhow!("Failed to read from keyring: {}", e)),
393 };
394
395 let token: OpenRouterToken =
396 serde_json::from_str(&token_json).context("Failed to parse token from keyring")?;
397
398 if token.is_expired() {
400 tracing::warn!("OAuth token has expired, removing...");
401 clear_oauth_token_keyring()?;
402 return Ok(None);
403 }
404
405 Ok(Some(token))
406}
407
408fn load_oauth_token_file() -> Result<Option<OpenRouterToken>> {
410 let path = get_token_path()?;
411
412 if !path.exists() {
413 return Ok(None);
414 }
415
416 let json = fs::read_to_string(&path).context("Failed to read token file")?;
417 let encrypted: EncryptedToken =
418 serde_json::from_str(&json).context("Failed to parse token file")?;
419
420 let token = decrypt_token(&encrypted)?;
421
422 if token.is_expired() {
424 tracing::warn!("OAuth token has expired, removing...");
425 clear_oauth_token_file()?;
426 return Ok(None);
427 }
428
429 Ok(Some(token))
430}
431
432pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
444 match load_oauth_token_keyring() {
445 Ok(Some(token)) => return Ok(Some(token)),
446 Ok(None) => {
447 tracing::debug!("No token in keyring, checking file storage");
449 }
450 Err(e) => {
451 let error_str = e.to_string().to_lowercase();
453 if error_str.contains("no entry") || error_str.contains("not found") {
454 tracing::debug!("Keyring entry not found, checking file storage");
455 } else {
456 return Err(e);
459 }
460 }
461 }
462
463 load_oauth_token_file()
465}
466
467fn clear_oauth_token_keyring() -> Result<()> {
469 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
470 Ok(e) => e,
471 Err(_) => return Ok(()),
472 };
473
474 match entry.delete_credential() {
475 Ok(_) => tracing::info!("OAuth token cleared from keyring"),
476 Err(keyring::Error::NoEntry) => {}
477 Err(e) => return Err(anyhow!("Failed to clear keyring entry: {}", e)),
478 }
479
480 Ok(())
481}
482
483fn clear_oauth_token_file() -> Result<()> {
485 let path = get_token_path()?;
486
487 if path.exists() {
488 fs::remove_file(&path).context("Failed to remove token file")?;
489 tracing::info!("OAuth token cleared from file");
490 }
491
492 Ok(())
493}
494
495pub fn clear_oauth_token() -> Result<()> {
497 let _ = clear_oauth_token_keyring();
499 let _ = clear_oauth_token_file();
500
501 tracing::info!("OAuth token cleared from all storage");
502 Ok(())
503}
504
505pub fn get_auth_status() -> Result<AuthStatus> {
507 match load_oauth_token()? {
508 Some(token) => {
509 let now = std::time::SystemTime::now()
510 .duration_since(std::time::UNIX_EPOCH)
511 .map(|d| d.as_secs())
512 .unwrap_or(0);
513
514 let age_seconds = now.saturating_sub(token.obtained_at);
515
516 Ok(AuthStatus::Authenticated {
517 label: token.label,
518 age_seconds,
519 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
520 })
521 }
522 None => Ok(AuthStatus::NotAuthenticated),
523 }
524}
525
526#[derive(Debug, Clone)]
528pub enum AuthStatus {
529 Authenticated {
531 label: Option<String>,
533 age_seconds: u64,
535 expires_in: Option<u64>,
537 },
538 NotAuthenticated,
540}
541
542impl AuthStatus {
543 pub fn is_authenticated(&self) -> bool {
545 matches!(self, AuthStatus::Authenticated { .. })
546 }
547
548 pub fn display_string(&self) -> String {
550 match self {
551 AuthStatus::Authenticated {
552 label,
553 age_seconds,
554 expires_in,
555 } => {
556 let label_str = label
557 .as_ref()
558 .map(|l| format!(" ({})", l))
559 .unwrap_or_default();
560 let age_str = humanize_duration(*age_seconds);
561 let expiry_str = expires_in
562 .map(|e| format!(", expires in {}", humanize_duration(e)))
563 .unwrap_or_default();
564 format!(
565 "Authenticated{}, obtained {}{}",
566 label_str, age_str, expiry_str
567 )
568 }
569 AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
570 }
571 }
572}
573
574fn humanize_duration(seconds: u64) -> String {
576 if seconds < 60 {
577 format!("{}s ago", seconds)
578 } else if seconds < 3600 {
579 format!("{}m ago", seconds / 60)
580 } else if seconds < 86400 {
581 format!("{}h ago", seconds / 3600)
582 } else {
583 format!("{}d ago", seconds / 86400)
584 }
585}
586
587#[cfg(test)]
588mod tests {
589 use super::*;
590
591 #[test]
592 fn test_auth_url_generation() {
593 let challenge = PkceChallenge {
594 code_verifier: "test_verifier".to_string(),
595 code_challenge: "test_challenge".to_string(),
596 code_challenge_method: "S256".to_string(),
597 };
598
599 let url = get_auth_url(&challenge, 8484);
600
601 assert!(url.starts_with("https://openrouter.ai/auth"));
602 assert!(url.contains("callback_url="));
603 assert!(url.contains("code_challenge=test_challenge"));
604 assert!(url.contains("code_challenge_method=S256"));
605 }
606
607 #[test]
608 fn test_token_expiry_check() {
609 let now = std::time::SystemTime::now()
610 .duration_since(std::time::UNIX_EPOCH)
611 .unwrap()
612 .as_secs();
613
614 let token = OpenRouterToken {
616 api_key: "test".to_string(),
617 obtained_at: now,
618 expires_at: Some(now + 3600),
619 label: None,
620 };
621 assert!(!token.is_expired());
622
623 let expired_token = OpenRouterToken {
625 api_key: "test".to_string(),
626 obtained_at: now - 7200,
627 expires_at: Some(now - 3600),
628 label: None,
629 };
630 assert!(expired_token.is_expired());
631
632 let no_expiry_token = OpenRouterToken {
634 api_key: "test".to_string(),
635 obtained_at: now,
636 expires_at: None,
637 label: None,
638 };
639 assert!(!no_expiry_token.is_expired());
640 }
641
642 #[test]
643 fn test_encryption_roundtrip() {
644 let token = OpenRouterToken {
645 api_key: "sk-test-key-12345".to_string(),
646 obtained_at: 1234567890,
647 expires_at: Some(1234567890 + 86400),
648 label: Some("Test Token".to_string()),
649 };
650
651 let encrypted = encrypt_token(&token).unwrap();
652 let decrypted = decrypt_token(&encrypted).unwrap();
653
654 assert_eq!(decrypted.api_key, token.api_key);
655 assert_eq!(decrypted.obtained_at, token.obtained_at);
656 assert_eq!(decrypted.expires_at, token.expires_at);
657 assert_eq!(decrypted.label, token.label);
658 }
659
660 #[test]
661 fn test_auth_status_display() {
662 let status = AuthStatus::Authenticated {
663 label: Some("My App".to_string()),
664 age_seconds: 3700,
665 expires_in: Some(86000),
666 };
667
668 let display = status.display_string();
669 assert!(display.contains("Authenticated"));
670 assert!(display.contains("My App"));
671 }
672}