1use anyhow::{Context, Result, anyhow};
32use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
33use ring::rand::{SecureRandom, SystemRandom};
34use serde::{Deserialize, Serialize};
35use std::fs;
36use std::path::PathBuf;
37
38pub use super::credentials::AuthCredentialsStoreMode;
39use super::credentials::keyring_entry;
40use super::pkce::PkceChallenge;
41use crate::storage_paths::{auth_storage_dir, write_private_file};
42
/// OpenRouter authorization page the user is sent to; `get_auth_url` appends the
/// `callback_url`, `code_challenge` and `code_challenge_method` query parameters (PKCE step 1).
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that `exchange_code_for_token` POSTs the authorization code and PKCE verifier to,
/// receiving an API key in the `key` field of the JSON response.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Default port for the `http://localhost:<port>/callback` redirect URL passed to OpenRouter.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
53#[serde(default)]
54pub struct OpenRouterOAuthConfig {
55 pub use_oauth: bool,
57 pub callback_port: u16,
59 pub auto_refresh: bool,
61 pub flow_timeout_secs: u64,
63}
64
65impl Default for OpenRouterOAuthConfig {
66 fn default() -> Self {
67 Self {
68 use_oauth: false,
69 callback_port: DEFAULT_CALLBACK_PORT,
70 auto_refresh: true,
71 flow_timeout_secs: 300,
72 }
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct OpenRouterToken {
79 pub api_key: String,
81 pub obtained_at: u64,
83 pub expires_at: Option<u64>,
85 pub label: Option<String>,
87}
88
89impl OpenRouterToken {
90 pub fn is_expired(&self) -> bool {
92 if let Some(expires_at) = self.expires_at {
93 let now = std::time::SystemTime::now()
94 .duration_since(std::time::UNIX_EPOCH)
95 .map(|d| d.as_secs())
96 .unwrap_or(0);
97 now >= expires_at
98 } else {
99 false
100 }
101 }
102}
103
/// JSON envelope written to the token file: an [`OpenRouterToken`] sealed with AES-256-GCM
/// (see `encrypt_token` / `decrypt_token`). Field order is part of the on-disk format.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
    /// Standard-alphabet base64 of the `NONCE_LEN`-byte AEAD nonce.
    nonce: String,
    /// Standard-alphabet base64 of the ciphertext with the AEAD tag appended.
    ciphertext: String,
    /// Envelope format version; `decrypt_token` only accepts `1`.
    version: u8,
}
114
115pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
124 let callback_url = format!("http://localhost:{}/callback", callback_port);
125 format!(
126 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
127 OPENROUTER_AUTH_URL,
128 urlencoding::encode(&callback_url),
129 urlencoding::encode(&challenge.code_challenge),
130 challenge.code_challenge_method
131 )
132}
133
134pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
146 let client = reqwest::Client::new();
147
148 let payload = serde_json::json!({
149 "code": code,
150 "code_verifier": challenge.code_verifier,
151 "code_challenge_method": challenge.code_challenge_method
152 });
153
154 let response = client
155 .post(OPENROUTER_KEYS_URL)
156 .header("Content-Type", "application/json")
157 .json(&payload)
158 .send()
159 .await
160 .context("Failed to send token exchange request")?;
161
162 let status = response.status();
163 let body = response
164 .text()
165 .await
166 .context("Failed to read response body")?;
167
168 if !status.is_success() {
169 if status.as_u16() == 400 {
171 return Err(anyhow!(
172 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
173 ));
174 } else if status.as_u16() == 403 {
175 return Err(anyhow!(
176 "Invalid code or code_verifier. The authorization code may have expired."
177 ));
178 } else if status.as_u16() == 405 {
179 return Err(anyhow!(
180 "Method not allowed. Ensure you're using POST over HTTPS."
181 ));
182 }
183 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
184 }
185
186 let response_json: serde_json::Value =
188 serde_json::from_str(&body).context("Failed to parse token response")?;
189
190 let api_key = response_json
191 .get("key")
192 .and_then(|v| v.as_str())
193 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
194 .to_string();
195
196 Ok(api_key)
197}
198
199fn get_token_path() -> Result<PathBuf> {
201 Ok(auth_storage_dir()?.join("openrouter.json"))
202}
203
/// Derives the AES-256-GCM key used to encrypt the on-disk token file.
///
/// The key is `SHA-256(hostname || user identity || b"vtcode-openrouter-oauth-v1")`. The user
/// identity is the little-endian bytes of `getuid().as_raw()` on Unix, and the `USER` (or
/// `USERNAME`) environment variable elsewhere.
///
/// NOTE(review): every input is non-secret and predictable (hostname, uid/username and a fixed
/// label). Anyone who can read the token file and learn those values can derive the same key.
/// This is obfuscation against casual disclosure, not protection from a local attacker.
///
/// The exact byte order of the key material is part of the storage format. Changing it, or the
/// label, makes previously saved token files undecryptable.
///
/// # Errors
/// Returns an error if the digest is shorter than 32 bytes or `ring` rejects the key. Neither
/// is expected for SHA-256 with AES-256-GCM.
fn derive_encryption_key() -> Result<LessSafeKey> {
    use ring::digest::{SHA256, digest};

    let mut key_material = Vec::new();

    // Machine component. Silently omitted if the hostname cannot be read, which yields a
    // different key than when it can.
    if let Ok(hostname) = hostname::get() {
        key_material.extend_from_slice(hostname.as_encoded_bytes());
    }

    // User component: numeric uid on Unix.
    #[cfg(unix)]
    {
        key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
    }
    // User component elsewhere: USER, else USERNAME; omitted entirely if neither is set.
    #[cfg(not(unix))]
    {
        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
            key_material.extend_from_slice(user.as_bytes());
        }
    }

    // Fixed application/version label mixed into the key material.
    key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");

    let hash = digest(&SHA256, &key_material);
    // SHA-256 output is 32 bytes, matching the AES-256 key size.
    let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;

    let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
        .map_err(|_| anyhow!("Invalid key length"))?;

    Ok(LessSafeKey::new(unbound_key))
}
240
241fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
243 let key = derive_encryption_key()?;
244 let rng = SystemRandom::new();
245
246 let mut nonce_bytes = [0u8; NONCE_LEN];
248 rng.fill(&mut nonce_bytes)
249 .map_err(|_| anyhow!("Failed to generate nonce"))?;
250
251 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
253
254 let mut ciphertext = plaintext;
256 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
257 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
258 .map_err(|_| anyhow!("Encryption failed"))?;
259
260 use base64::{Engine, engine::general_purpose::STANDARD};
261
262 Ok(EncryptedToken {
263 nonce: STANDARD.encode(nonce_bytes),
264 ciphertext: STANDARD.encode(&ciphertext),
265 version: 1,
266 })
267}
268
269fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
271 if encrypted.version != 1 {
272 return Err(anyhow!(
273 "Unsupported token format version: {}",
274 encrypted.version
275 ));
276 }
277
278 use base64::{Engine, engine::general_purpose::STANDARD};
279
280 let key = derive_encryption_key()?;
281
282 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
283 .decode(&encrypted.nonce)
284 .context("Invalid nonce encoding")?
285 .try_into()
286 .map_err(|_| anyhow!("Invalid nonce length"))?;
287
288 let mut ciphertext = STANDARD
289 .decode(&encrypted.ciphertext)
290 .context("Invalid ciphertext encoding")?;
291
292 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
293 let plaintext = key
294 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
295 .map_err(|_| {
296 anyhow!("Decryption failed - token may be corrupted or from different machine")
297 })?;
298
299 serde_json::from_slice(plaintext).context("Failed to deserialize token")
300}
301
302pub fn save_oauth_token_with_mode(
308 token: &OpenRouterToken,
309 mode: AuthCredentialsStoreMode,
310) -> Result<()> {
311 let effective_mode = mode.effective_mode();
312
313 match effective_mode {
314 AuthCredentialsStoreMode::Keyring => save_oauth_token_keyring(token),
315 AuthCredentialsStoreMode::File => save_oauth_token_file(token),
316 _ => unreachable!(),
317 }
318}
319
320fn save_oauth_token_keyring(token: &OpenRouterToken) -> Result<()> {
322 let entry =
323 keyring_entry("vtcode", "openrouter_oauth").context("Failed to access OS keyring")?;
324
325 let token_json =
327 serde_json::to_string(token).context("Failed to serialize token for keyring")?;
328
329 entry
330 .set_password(&token_json)
331 .context("Failed to store token in OS keyring")?;
332
333 tracing::info!("OAuth token saved to OS keyring");
334 Ok(())
335}
336
337fn save_oauth_token_file(token: &OpenRouterToken) -> Result<()> {
339 let path = get_token_path()?;
340 let encrypted = encrypt_token(token)?;
341 let json =
342 serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
343 write_private_file(&path, json.as_bytes()).context("Failed to write token file")?;
344
345 tracing::info!("OAuth token saved to {}", path.display());
346 Ok(())
347}
348
349pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
354 save_oauth_token_with_mode(token, AuthCredentialsStoreMode::default())
355}
356
357pub fn load_oauth_token_with_mode(
361 mode: AuthCredentialsStoreMode,
362) -> Result<Option<OpenRouterToken>> {
363 let effective_mode = mode.effective_mode();
364
365 match effective_mode {
366 AuthCredentialsStoreMode::Keyring => load_oauth_token_keyring(),
367 AuthCredentialsStoreMode::File => load_oauth_token_file(),
368 _ => unreachable!(),
369 }
370}
371
372fn load_oauth_token_keyring() -> Result<Option<OpenRouterToken>> {
374 let entry = match keyring_entry("vtcode", "openrouter_oauth") {
375 Ok(e) => e,
376 Err(_) => return Ok(None),
377 };
378
379 let token_json = match entry.get_password() {
380 Ok(json) => json,
381 Err(keyring_core::Error::NoEntry) => return Ok(None),
382 Err(e) => return Err(anyhow!("Failed to read from keyring: {}", e)),
383 };
384
385 let token: OpenRouterToken =
386 serde_json::from_str(&token_json).context("Failed to parse token from keyring")?;
387
388 if token.is_expired() {
390 tracing::warn!("OAuth token has expired, removing...");
391 clear_oauth_token_keyring()?;
392 return Ok(None);
393 }
394
395 Ok(Some(token))
396}
397
398fn load_oauth_token_file() -> Result<Option<OpenRouterToken>> {
400 let path = get_token_path()?;
401
402 if !path.exists() {
403 return Ok(None);
404 }
405
406 let json = fs::read_to_string(&path).context("Failed to read token file")?;
407 let encrypted: EncryptedToken =
408 serde_json::from_str(&json).context("Failed to parse token file")?;
409
410 let token = decrypt_token(&encrypted)?;
411
412 if token.is_expired() {
414 tracing::warn!("OAuth token has expired, removing...");
415 clear_oauth_token_file()?;
416 return Ok(None);
417 }
418
419 Ok(Some(token))
420}
421
422pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
434 match load_oauth_token_keyring() {
435 Ok(Some(token)) => return Ok(Some(token)),
436 Ok(None) => {
437 tracing::debug!("No token in keyring, checking file storage");
439 }
440 Err(e) => {
441 let error_str = e.to_string().to_lowercase();
443 if error_str.contains("no entry") || error_str.contains("not found") {
444 tracing::debug!("Keyring entry not found, checking file storage");
445 } else {
446 return Err(e);
449 }
450 }
451 }
452
453 load_oauth_token_file()
455}
456
457fn clear_oauth_token_keyring() -> Result<()> {
459 let entry = match keyring_entry("vtcode", "openrouter_oauth") {
460 Ok(e) => e,
461 Err(_) => return Ok(()),
462 };
463
464 match entry.delete_credential() {
465 Ok(_) => tracing::info!("OAuth token cleared from keyring"),
466 Err(keyring_core::Error::NoEntry) => {}
467 Err(e) => return Err(anyhow!("Failed to clear keyring entry: {}", e)),
468 }
469
470 Ok(())
471}
472
473fn clear_oauth_token_file() -> Result<()> {
475 let path = get_token_path()?;
476
477 if path.exists() {
478 fs::remove_file(&path).context("Failed to remove token file")?;
479 tracing::info!("OAuth token cleared from file");
480 }
481
482 Ok(())
483}
484
485pub fn clear_oauth_token_with_mode(mode: AuthCredentialsStoreMode) -> Result<()> {
487 match mode.effective_mode() {
488 AuthCredentialsStoreMode::Keyring => clear_oauth_token_keyring(),
489 AuthCredentialsStoreMode::File => clear_oauth_token_file(),
490 AuthCredentialsStoreMode::Auto => {
491 let _ = clear_oauth_token_keyring();
492 let _ = clear_oauth_token_file();
493 Ok(())
494 }
495 }
496}
497
498pub fn clear_oauth_token() -> Result<()> {
499 let _ = clear_oauth_token_keyring();
501 let _ = clear_oauth_token_file();
502
503 tracing::info!("OAuth token cleared from all storage");
504 Ok(())
505}
506
507pub fn get_auth_status_with_mode(mode: AuthCredentialsStoreMode) -> Result<AuthStatus> {
509 match load_oauth_token_with_mode(mode)? {
510 Some(token) => {
511 let now = std::time::SystemTime::now()
512 .duration_since(std::time::UNIX_EPOCH)
513 .map(|d| d.as_secs())
514 .unwrap_or(0);
515
516 let age_seconds = now.saturating_sub(token.obtained_at);
517
518 Ok(AuthStatus::Authenticated {
519 label: token.label,
520 age_seconds,
521 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
522 })
523 }
524 None => Ok(AuthStatus::NotAuthenticated),
525 }
526}
527
528pub fn get_auth_status() -> Result<AuthStatus> {
529 match load_oauth_token()? {
530 Some(token) => {
531 let now = std::time::SystemTime::now()
532 .duration_since(std::time::UNIX_EPOCH)
533 .map(|d| d.as_secs())
534 .unwrap_or(0);
535
536 let age_seconds = now.saturating_sub(token.obtained_at);
537
538 Ok(AuthStatus::Authenticated {
539 label: token.label,
540 age_seconds,
541 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
542 })
543 }
544 None => Ok(AuthStatus::NotAuthenticated),
545 }
546}
547
/// Summary of the stored OpenRouter OAuth credentials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthStatus {
    /// A usable token was found.
    Authenticated {
        /// Optional human-readable label stored with the token.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds until the token expires, or `None` if it has no expiry.
        expires_in: Option<u64>,
    },
    /// No usable token was found.
    NotAuthenticated,
}

impl AuthStatus {
    /// Returns `true` for [`AuthStatus::Authenticated`].
    pub fn is_authenticated(&self) -> bool {
        matches!(self, AuthStatus::Authenticated { .. })
    }

    /// Renders a one-line, human-readable description of the status.
    ///
    /// For example: `"Authenticated (My App), obtained 1h ago, expires in 23h"`.
    pub fn display_string(&self) -> String {
        match self {
            AuthStatus::Authenticated {
                label,
                age_seconds,
                expires_in,
            } => {
                let label_str = label
                    .as_ref()
                    .map(|l| format!(" ({})", l))
                    .unwrap_or_default();
                let age_str = humanize_duration(*age_seconds);
                // Expiry lies in the future, so it must not carry the "ago" suffix that
                // `humanize_duration` adds.
                let expiry_str = expires_in
                    .map(|e| format!(", expires in {}", format_duration(e)))
                    .unwrap_or_default();
                format!(
                    "Authenticated{}, obtained {}{}",
                    label_str, age_str, expiry_str
                )
            }
            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
        }
    }
}

/// Formats a past duration in its largest whole unit with an "ago" suffix, e.g. `"3h ago"`.
fn humanize_duration(seconds: u64) -> String {
    format!("{} ago", format_duration(seconds))
}

/// Formats a duration in its largest whole unit (s, m, h or d), truncating, e.g. `"3h"`.
fn format_duration(seconds: u64) -> String {
    match seconds {
        0..=59 => format!("{}s", seconds),
        60..=3599 => format!("{}m", seconds / 60),
        3600..=86399 => format!("{}h", seconds / 3600),
        _ => format!("{}d", seconds / 86400),
    }
}
608
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::TempDir;
    use serial_test::serial;

    /// Redirects auth storage to a fresh temporary directory for the guard's lifetime. On drop
    /// it restores the previous override and deletes the directory.
    struct TestAuthDirGuard {
        temp_dir: Option<TempDir>,
        previous: Option<PathBuf>,
    }

    impl TestAuthDirGuard {
        fn new() -> Self {
            let temp_dir = TempDir::new().expect("create temp auth dir");
            let previous = crate::storage_paths::auth_storage_dir_override_for_tests()
                .expect("read auth dir override");
            crate::storage_paths::set_auth_storage_dir_override_for_tests(Some(
                temp_dir.path().to_path_buf(),
            ))
            .expect("set temp auth dir override");
            Self {
                temp_dir: Some(temp_dir),
                previous,
            }
        }
    }

    impl Drop for TestAuthDirGuard {
        fn drop(&mut self) {
            let restored = crate::storage_paths::set_auth_storage_dir_override_for_tests(
                self.previous.clone(),
            );
            let removed = self.temp_dir.take().map(TempDir::close);
            // Panicking again while a failed test is already unwinding would abort the whole
            // test binary and hide the original failure, so only assert on the happy path.
            if std::thread::panicking() {
                return;
            }
            restored.expect("restore auth dir override");
            if let Some(result) = removed {
                result.expect("remove temp auth dir");
            }
        }
    }

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// A labelled test token obtained at `obtained_at` that expires one day later.
    fn sample_token(obtained_at: u64) -> OpenRouterToken {
        OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at,
            expires_at: Some(obtained_at + 86400),
            label: Some("Test Token".to_string()),
        }
    }

    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        assert!(url.contains("callback_url="));
        assert!(url.contains("code_challenge=test_challenge"));
        assert!(url.contains("code_challenge_method=S256"));
    }

    #[test]
    fn test_token_expiry_check() {
        let now = unix_now();

        let token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now + 3600),
            label: None,
        };
        assert!(!token.is_expired());

        let expired_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now - 7200,
            expires_at: Some(now - 3600),
            label: None,
        };
        assert!(expired_token.is_expired());

        let no_expiry_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: None,
            label: None,
        };
        assert!(!no_expiry_token.is_expired());
    }

    #[test]
    fn test_encryption_roundtrip() {
        let token = sample_token(1234567890);

        let encrypted = encrypt_token(&token).unwrap();
        let decrypted = decrypt_token(&encrypted).unwrap();

        assert_eq!(decrypted.api_key, token.api_key);
        assert_eq!(decrypted.obtained_at, token.obtained_at);
        assert_eq!(decrypted.expires_at, token.expires_at);
        assert_eq!(decrypted.label, token.label);
    }

    #[test]
    fn decrypt_rejects_unsupported_version() {
        let mut encrypted = encrypt_token(&sample_token(1234567890)).expect("encrypt token");
        encrypted.version = 2;

        let err = decrypt_token(&encrypted).expect_err("version 2 must be rejected");
        assert!(
            err.to_string()
                .contains("Unsupported token format version: 2")
        );
    }

    #[test]
    fn decrypt_detects_tampered_ciphertext() {
        use base64::{Engine, engine::general_purpose::STANDARD};

        let mut encrypted = encrypt_token(&sample_token(1234567890)).expect("encrypt token");
        let mut bytes = STANDARD
            .decode(&encrypted.ciphertext)
            .expect("decode ciphertext");
        // Any single-bit change must fail AES-GCM authentication.
        bytes[0] ^= 0x01;
        encrypted.ciphertext = STANDARD.encode(&bytes);

        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_auth_status_display() {
        let status = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        };

        let display = status.display_string();
        assert!(display.contains("Authenticated"));
        assert!(display.contains("My App"));
    }

    #[test]
    #[serial]
    fn file_storage_round_trips_without_plaintext() {
        let _guard = TestAuthDirGuard::new();
        let token = sample_token(unix_now());

        save_oauth_token_with_mode(&token, AuthCredentialsStoreMode::File).expect("save token");
        let loaded =
            load_oauth_token_with_mode(AuthCredentialsStoreMode::File).expect("load token");
        assert_eq!(
            loaded.as_ref().map(|value| &value.api_key),
            Some(&token.api_key)
        );

        let stored =
            fs::read_to_string(get_token_path().expect("token path")).expect("read token file");
        assert!(!stored.contains(&token.api_key));
    }

    #[test]
    #[serial]
    fn file_storage_missing_token_loads_as_none_and_clears_cleanly() {
        let _guard = TestAuthDirGuard::new();

        let loaded =
            load_oauth_token_with_mode(AuthCredentialsStoreMode::File).expect("load token");
        assert!(loaded.is_none());

        clear_oauth_token_with_mode(AuthCredentialsStoreMode::File)
            .expect("clearing a missing token file succeeds");
    }

    #[test]
    #[serial]
    fn expired_file_token_is_discarded_on_load() {
        let _guard = TestAuthDirGuard::new();
        let now = unix_now();
        let token = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: now - 7200,
            expires_at: Some(now - 3600),
            label: None,
        };

        save_oauth_token_with_mode(&token, AuthCredentialsStoreMode::File).expect("save token");
        let loaded =
            load_oauth_token_with_mode(AuthCredentialsStoreMode::File).expect("load token");

        assert!(loaded.is_none());
        assert!(!get_token_path().expect("token path").exists());
    }

    #[test]
    #[serial]
    #[cfg(unix)]
    fn file_storage_uses_private_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let _guard = TestAuthDirGuard::new();
        let token = sample_token(unix_now());

        save_oauth_token_with_mode(&token, AuthCredentialsStoreMode::File).expect("save token");

        let metadata =
            fs::metadata(get_token_path().expect("token path")).expect("read token metadata");
        assert_eq!(metadata.permissions().mode() & 0o777, 0o600);
    }
}
780}