1use anyhow::{Context, Result, anyhow};
32use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
33use ring::rand::{SecureRandom, SystemRandom};
34use serde::{Deserialize, Serialize};
35use std::fs;
36use std::path::PathBuf;
37
38pub use super::credentials::AuthCredentialsStoreMode;
39use super::pkce::PkceChallenge;
40use crate::storage_paths::{auth_storage_dir, write_private_file};
41
/// Base URL of the OpenRouter authorization page; `get_auth_url` appends the
/// callback and PKCE query parameters to it.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that `exchange_code_for_token` POSTs the authorization code to in
/// order to obtain an API key.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Callback port used by `OpenRouterOAuthConfig::default()`.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
48
/// User-facing settings for the OpenRouter OAuth flow.
///
/// Because of `#[serde(default)]`, any field missing from the serialized form
/// is filled in from this type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(default)]
pub struct OpenRouterOAuthConfig {
    /// Presumably switches OAuth login on or off (not read in this file —
    /// confirm against callers). Defaults to `false`.
    pub use_oauth: bool,
    /// Local port for the OAuth callback. Defaults to `DEFAULT_CALLBACK_PORT`.
    pub callback_port: u16,
    /// NOTE(review): looks like a "refresh tokens automatically" toggle; it is
    /// not consulted in this file — confirm. Defaults to `true`.
    pub auto_refresh: bool,
    /// Presumably the maximum duration of the login flow, in seconds (per the
    /// field name; not read in this file). Defaults to `300`.
    pub flow_timeout_secs: u64,
}
63
64impl Default for OpenRouterOAuthConfig {
65 fn default() -> Self {
66 Self {
67 use_oauth: false,
68 callback_port: DEFAULT_CALLBACK_PORT,
69 auto_refresh: true,
70 flow_timeout_secs: 300,
71 }
72 }
73}
74
/// An OpenRouter API key obtained through the OAuth flow, plus bookkeeping
/// metadata. This is the value that gets serialized into the keyring or the
/// encrypted token file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterToken {
    /// The API key itself.
    pub api_key: String,
    /// When the key was obtained, in seconds since the Unix epoch (it is
    /// subtracted from the current epoch time to compute the token's age).
    pub obtained_at: u64,
    /// Expiry time in seconds since the Unix epoch; `None` means the token is
    /// never considered expired by `is_expired`.
    pub expires_at: Option<u64>,
    /// Optional label shown in parentheses by `AuthStatus::display_string`.
    pub label: Option<String>,
}
87
88impl OpenRouterToken {
89 pub fn is_expired(&self) -> bool {
91 if let Some(expires_at) = self.expires_at {
92 let now = std::time::SystemTime::now()
93 .duration_since(std::time::UNIX_EPOCH)
94 .map(|d| d.as_secs())
95 .unwrap_or(0);
96 now >= expires_at
97 } else {
98 false
99 }
100 }
101}
102
/// On-disk representation of an encrypted `OpenRouterToken`, as produced by
/// `encrypt_token` and consumed by `decrypt_token`.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
    /// Base64 (standard alphabet) encoding of the AES-GCM nonce.
    nonce: String,
    /// Base64 (standard alphabet) encoding of the ciphertext with the
    /// authentication tag appended.
    ciphertext: String,
    /// Format version; `decrypt_token` only accepts `1`.
    version: u8,
}
113
114pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
123 let callback_url = format!("http://localhost:{}/callback", callback_port);
124 format!(
125 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
126 OPENROUTER_AUTH_URL,
127 urlencoding::encode(&callback_url),
128 urlencoding::encode(&challenge.code_challenge),
129 challenge.code_challenge_method
130 )
131}
132
133pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
145 let client = reqwest::Client::new();
146
147 let payload = serde_json::json!({
148 "code": code,
149 "code_verifier": challenge.code_verifier,
150 "code_challenge_method": challenge.code_challenge_method
151 });
152
153 let response = client
154 .post(OPENROUTER_KEYS_URL)
155 .header("Content-Type", "application/json")
156 .json(&payload)
157 .send()
158 .await
159 .context("Failed to send token exchange request")?;
160
161 let status = response.status();
162 let body = response
163 .text()
164 .await
165 .context("Failed to read response body")?;
166
167 if !status.is_success() {
168 if status.as_u16() == 400 {
170 return Err(anyhow!(
171 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
172 ));
173 } else if status.as_u16() == 403 {
174 return Err(anyhow!(
175 "Invalid code or code_verifier. The authorization code may have expired."
176 ));
177 } else if status.as_u16() == 405 {
178 return Err(anyhow!(
179 "Method not allowed. Ensure you're using POST over HTTPS."
180 ));
181 }
182 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
183 }
184
185 let response_json: serde_json::Value =
187 serde_json::from_str(&body).context("Failed to parse token response")?;
188
189 let api_key = response_json
190 .get("key")
191 .and_then(|v| v.as_str())
192 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
193 .to_string();
194
195 Ok(api_key)
196}
197
198fn get_token_path() -> Result<PathBuf> {
200 Ok(auth_storage_dir()?.join("openrouter.json"))
201}
202
203fn derive_encryption_key() -> Result<LessSafeKey> {
205 use ring::digest::{SHA256, digest};
206
207 let mut key_material = Vec::new();
209
210 if let Ok(hostname) = hostname::get() {
212 key_material.extend_from_slice(hostname.as_encoded_bytes());
213 }
214
215 #[cfg(unix)]
217 {
218 key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
219 }
220 #[cfg(not(unix))]
221 {
222 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
223 key_material.extend_from_slice(user.as_bytes());
224 }
225 }
226
227 key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");
229
230 let hash = digest(&SHA256, &key_material);
232 let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;
233
234 let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
235 .map_err(|_| anyhow!("Invalid key length"))?;
236
237 Ok(LessSafeKey::new(unbound_key))
238}
239
240fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
242 let key = derive_encryption_key()?;
243 let rng = SystemRandom::new();
244
245 let mut nonce_bytes = [0u8; NONCE_LEN];
247 rng.fill(&mut nonce_bytes)
248 .map_err(|_| anyhow!("Failed to generate nonce"))?;
249
250 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
252
253 let mut ciphertext = plaintext;
255 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
256 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
257 .map_err(|_| anyhow!("Encryption failed"))?;
258
259 use base64::{Engine, engine::general_purpose::STANDARD};
260
261 Ok(EncryptedToken {
262 nonce: STANDARD.encode(nonce_bytes),
263 ciphertext: STANDARD.encode(&ciphertext),
264 version: 1,
265 })
266}
267
268fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
270 if encrypted.version != 1 {
271 return Err(anyhow!(
272 "Unsupported token format version: {}",
273 encrypted.version
274 ));
275 }
276
277 use base64::{Engine, engine::general_purpose::STANDARD};
278
279 let key = derive_encryption_key()?;
280
281 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
282 .decode(&encrypted.nonce)
283 .context("Invalid nonce encoding")?
284 .try_into()
285 .map_err(|_| anyhow!("Invalid nonce length"))?;
286
287 let mut ciphertext = STANDARD
288 .decode(&encrypted.ciphertext)
289 .context("Invalid ciphertext encoding")?;
290
291 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
292 let plaintext = key
293 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
294 .map_err(|_| {
295 anyhow!("Decryption failed - token may be corrupted or from different machine")
296 })?;
297
298 serde_json::from_slice(plaintext).context("Failed to deserialize token")
299}
300
301pub fn save_oauth_token_with_mode(
307 token: &OpenRouterToken,
308 mode: AuthCredentialsStoreMode,
309) -> Result<()> {
310 let effective_mode = mode.effective_mode();
311
312 match effective_mode {
313 AuthCredentialsStoreMode::Keyring => save_oauth_token_keyring(token),
314 AuthCredentialsStoreMode::File => save_oauth_token_file(token),
315 _ => unreachable!(),
316 }
317}
318
319fn save_oauth_token_keyring(token: &OpenRouterToken) -> Result<()> {
321 let entry =
322 keyring::Entry::new("vtcode", "openrouter_oauth").context("Failed to access OS keyring")?;
323
324 let token_json =
326 serde_json::to_string(token).context("Failed to serialize token for keyring")?;
327
328 entry
329 .set_password(&token_json)
330 .context("Failed to store token in OS keyring")?;
331
332 tracing::info!("OAuth token saved to OS keyring");
333 Ok(())
334}
335
336fn save_oauth_token_file(token: &OpenRouterToken) -> Result<()> {
338 let path = get_token_path()?;
339 let encrypted = encrypt_token(token)?;
340 let json =
341 serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
342 write_private_file(&path, json.as_bytes()).context("Failed to write token file")?;
343
344 tracing::info!("OAuth token saved to {}", path.display());
345 Ok(())
346}
347
348pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
353 save_oauth_token_with_mode(token, AuthCredentialsStoreMode::default())
354}
355
356pub fn load_oauth_token_with_mode(
360 mode: AuthCredentialsStoreMode,
361) -> Result<Option<OpenRouterToken>> {
362 let effective_mode = mode.effective_mode();
363
364 match effective_mode {
365 AuthCredentialsStoreMode::Keyring => load_oauth_token_keyring(),
366 AuthCredentialsStoreMode::File => load_oauth_token_file(),
367 _ => unreachable!(),
368 }
369}
370
371fn load_oauth_token_keyring() -> Result<Option<OpenRouterToken>> {
373 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
374 Ok(e) => e,
375 Err(_) => return Ok(None),
376 };
377
378 let token_json = match entry.get_password() {
379 Ok(json) => json,
380 Err(keyring::Error::NoEntry) => return Ok(None),
381 Err(e) => return Err(anyhow!("Failed to read from keyring: {}", e)),
382 };
383
384 let token: OpenRouterToken =
385 serde_json::from_str(&token_json).context("Failed to parse token from keyring")?;
386
387 if token.is_expired() {
389 tracing::warn!("OAuth token has expired, removing...");
390 clear_oauth_token_keyring()?;
391 return Ok(None);
392 }
393
394 Ok(Some(token))
395}
396
397fn load_oauth_token_file() -> Result<Option<OpenRouterToken>> {
399 let path = get_token_path()?;
400
401 if !path.exists() {
402 return Ok(None);
403 }
404
405 let json = fs::read_to_string(&path).context("Failed to read token file")?;
406 let encrypted: EncryptedToken =
407 serde_json::from_str(&json).context("Failed to parse token file")?;
408
409 let token = decrypt_token(&encrypted)?;
410
411 if token.is_expired() {
413 tracing::warn!("OAuth token has expired, removing...");
414 clear_oauth_token_file()?;
415 return Ok(None);
416 }
417
418 Ok(Some(token))
419}
420
421pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
433 match load_oauth_token_keyring() {
434 Ok(Some(token)) => return Ok(Some(token)),
435 Ok(None) => {
436 tracing::debug!("No token in keyring, checking file storage");
438 }
439 Err(e) => {
440 let error_str = e.to_string().to_lowercase();
442 if error_str.contains("no entry") || error_str.contains("not found") {
443 tracing::debug!("Keyring entry not found, checking file storage");
444 } else {
445 return Err(e);
448 }
449 }
450 }
451
452 load_oauth_token_file()
454}
455
456fn clear_oauth_token_keyring() -> Result<()> {
458 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
459 Ok(e) => e,
460 Err(_) => return Ok(()),
461 };
462
463 match entry.delete_credential() {
464 Ok(_) => tracing::info!("OAuth token cleared from keyring"),
465 Err(keyring::Error::NoEntry) => {}
466 Err(e) => return Err(anyhow!("Failed to clear keyring entry: {}", e)),
467 }
468
469 Ok(())
470}
471
472fn clear_oauth_token_file() -> Result<()> {
474 let path = get_token_path()?;
475
476 if path.exists() {
477 fs::remove_file(&path).context("Failed to remove token file")?;
478 tracing::info!("OAuth token cleared from file");
479 }
480
481 Ok(())
482}
483
484pub fn clear_oauth_token_with_mode(mode: AuthCredentialsStoreMode) -> Result<()> {
486 match mode.effective_mode() {
487 AuthCredentialsStoreMode::Keyring => clear_oauth_token_keyring(),
488 AuthCredentialsStoreMode::File => clear_oauth_token_file(),
489 AuthCredentialsStoreMode::Auto => {
490 let _ = clear_oauth_token_keyring();
491 let _ = clear_oauth_token_file();
492 Ok(())
493 }
494 }
495}
496
497pub fn clear_oauth_token() -> Result<()> {
498 let _ = clear_oauth_token_keyring();
500 let _ = clear_oauth_token_file();
501
502 tracing::info!("OAuth token cleared from all storage");
503 Ok(())
504}
505
506pub fn get_auth_status_with_mode(mode: AuthCredentialsStoreMode) -> Result<AuthStatus> {
508 match load_oauth_token_with_mode(mode)? {
509 Some(token) => {
510 let now = std::time::SystemTime::now()
511 .duration_since(std::time::UNIX_EPOCH)
512 .map(|d| d.as_secs())
513 .unwrap_or(0);
514
515 let age_seconds = now.saturating_sub(token.obtained_at);
516
517 Ok(AuthStatus::Authenticated {
518 label: token.label,
519 age_seconds,
520 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
521 })
522 }
523 None => Ok(AuthStatus::NotAuthenticated),
524 }
525}
526
527pub fn get_auth_status() -> Result<AuthStatus> {
528 match load_oauth_token()? {
529 Some(token) => {
530 let now = std::time::SystemTime::now()
531 .duration_since(std::time::UNIX_EPOCH)
532 .map(|d| d.as_secs())
533 .unwrap_or(0);
534
535 let age_seconds = now.saturating_sub(token.obtained_at);
536
537 Ok(AuthStatus::Authenticated {
538 label: token.label,
539 age_seconds,
540 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
541 })
542 }
543 None => Ok(AuthStatus::NotAuthenticated),
544 }
545}
546
/// Result of checking whether a usable OAuth token is stored.
#[derive(Debug, Clone)]
pub enum AuthStatus {
    /// A token is available.
    Authenticated {
        /// Optional label attached to the token.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds remaining until the token expires, if it has an expiry.
        expires_in: Option<u64>,
    },
    /// No token is available.
    NotAuthenticated,
}

impl AuthStatus {
    /// Returns `true` for the `Authenticated` variant.
    pub fn is_authenticated(&self) -> bool {
        matches!(self, AuthStatus::Authenticated { .. })
    }

    /// Renders a one-line, human-readable summary of the status.
    ///
    /// For an authenticated status the format is
    /// `Authenticated[ (label)], obtained <age> ago[, expires in <remaining>]`,
    /// for example `Authenticated (My App), obtained 1h ago, expires in 23h`.
    pub fn display_string(&self) -> String {
        match self {
            AuthStatus::Authenticated {
                label,
                age_seconds,
                expires_in,
            } => {
                let label_str = label
                    .as_ref()
                    .map(|l| format!(" ({})", l))
                    .unwrap_or_default();
                let age_str = humanize_duration(*age_seconds);
                // The remaining lifetime lies in the future, so it must not
                // carry the "ago" suffix that `humanize_duration` appends.
                let expiry_str = expires_in
                    .map(|e| format!(", expires in {}", format_duration_compact(e)))
                    .unwrap_or_default();
                format!(
                    "Authenticated{}, obtained {}{}",
                    label_str, age_str, expiry_str
                )
            }
            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
        }
    }
}

/// Formats a number of seconds using its largest whole unit (`s`, `m`, `h`
/// or `d`), e.g. `59s`, `1m`, `23h`, `2d`. Smaller units are truncated.
fn format_duration_compact(seconds: u64) -> String {
    match seconds {
        0..=59 => format!("{}s", seconds),
        60..=3599 => format!("{}m", seconds / 60),
        3600..=86399 => format!("{}h", seconds / 3600),
        _ => format!("{}d", seconds / 86400),
    }
}

/// Formats an elapsed time in seconds as a relative past duration, e.g.
/// `5s ago`, `3m ago`, `1h ago`, `2d ago`.
fn humanize_duration(seconds: u64) -> String {
    format!("{} ago", format_duration_compact(seconds))
}
607
// Unit tests for URL generation, expiry checks, token encryption and the
// file-based storage backend.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_fs::TempDir;
    use serial_test::serial;

    /// Guard that points the auth storage directory at a fresh temporary
    /// directory for the duration of a test and restores the previous
    /// override when dropped.
    struct TestAuthDirGuard {
        // Held in an `Option` so `drop` can take ownership and close it.
        temp_dir: Option<TempDir>,
        // Override that was active before the guard was created.
        previous: Option<PathBuf>,
    }

    impl TestAuthDirGuard {
        /// Creates the temporary directory and installs it as the auth
        /// storage override, remembering the previous override.
        fn new() -> Self {
            let temp_dir = TempDir::new().expect("create temp auth dir");
            let previous = crate::storage_paths::auth_storage_dir_override_for_tests()
                .expect("read auth dir override");
            crate::storage_paths::set_auth_storage_dir_override_for_tests(Some(
                temp_dir.path().to_path_buf(),
            ))
            .expect("set temp auth dir override");
            Self {
                temp_dir: Some(temp_dir),
                previous,
            }
        }
    }

    impl Drop for TestAuthDirGuard {
        /// Restores the previous override, then removes the temporary
        /// directory.
        fn drop(&mut self) {
            crate::storage_paths::set_auth_storage_dir_override_for_tests(self.previous.clone())
                .expect("restore auth dir override");
            if let Some(temp_dir) = self.temp_dir.take() {
                temp_dir.close().expect("remove temp auth dir");
            }
        }
    }

    // The generated URL targets the auth endpoint and carries all three
    // query parameters.
    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        assert!(url.contains("callback_url="));
        assert!(url.contains("code_challenge=test_challenge"));
        assert!(url.contains("code_challenge_method=S256"));
    }

    // Covers the three expiry cases: future expiry, past expiry, no expiry.
    #[test]
    fn test_token_expiry_check() {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        // Expires one hour from now: still valid.
        let token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now + 3600),
            label: None,
        };
        assert!(!token.is_expired());

        // Expired one hour ago.
        let expired_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now - 7200,
            expires_at: Some(now - 3600),
            label: None,
        };
        assert!(expired_token.is_expired());

        // No expiry set: never expires.
        let no_expiry_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: None,
            label: None,
        };
        assert!(!no_expiry_token.is_expired());
    }

    // Encrypting and then decrypting must return every field unchanged.
    #[test]
    fn test_encryption_roundtrip() {
        let token = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: 1234567890,
            expires_at: Some(1234567890 + 86400),
            label: Some("Test Token".to_string()),
        };

        let encrypted = encrypt_token(&token).unwrap();
        let decrypted = decrypt_token(&encrypted).unwrap();

        assert_eq!(decrypted.api_key, token.api_key);
        assert_eq!(decrypted.obtained_at, token.obtained_at);
        assert_eq!(decrypted.expires_at, token.expires_at);
        assert_eq!(decrypted.label, token.label);
    }

    // The display string mentions the authenticated state and the label.
    #[test]
    fn test_auth_status_display() {
        let status = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        };

        let display = status.display_string();
        assert!(display.contains("Authenticated"));
        assert!(display.contains("My App"));
    }

    // Saving and loading through the file backend round-trips the API key,
    // and the key never appears in clear text in the stored file.
    // `#[serial]` because the test changes the process-wide auth dir override.
    #[test]
    #[serial]
    fn file_storage_round_trips_without_plaintext() {
        let _guard = TestAuthDirGuard::new();
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let token = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: now,
            expires_at: Some(now + 86400),
            label: Some("Test Token".to_string()),
        };

        save_oauth_token_with_mode(&token, AuthCredentialsStoreMode::File).expect("save token");
        let loaded =
            load_oauth_token_with_mode(AuthCredentialsStoreMode::File).expect("load token");
        assert_eq!(
            loaded.as_ref().map(|value| &value.api_key),
            Some(&token.api_key)
        );

        let stored =
            fs::read_to_string(get_token_path().expect("token path")).expect("read token file");
        assert!(!stored.contains(&token.api_key));
    }

    // On Unix the token file must be readable and writable by the owner only.
    #[test]
    #[serial]
    #[cfg(unix)]
    fn file_storage_uses_private_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let _guard = TestAuthDirGuard::new();
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let token = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: now,
            expires_at: Some(now + 86400),
            label: Some("Test Token".to_string()),
        };

        save_oauth_token_with_mode(&token, AuthCredentialsStoreMode::File).expect("save token");

        let metadata =
            fs::metadata(get_token_path().expect("token path")).expect("read token metadata");
        // Mask off the file-type bits and compare only the permission bits.
        assert_eq!(metadata.permissions().mode() & 0o777, 0o600);
    }
}