1use anyhow::{Context, Result, anyhow};
18use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, NONCE_LEN};
19use ring::rand::{SecureRandom, SystemRandom};
20use serde::{Deserialize, Serialize};
21use std::fs;
22use std::path::PathBuf;
23
24use super::pkce::PkceChallenge;
25
/// Base URL of OpenRouter's authorization page; `get_auth_url` appends the
/// callback URL and PKCE challenge to it as query parameters.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint to which `exchange_code_for_token` POSTs the authorization code and
/// PKCE verifier; the JSON response carries the API key in its `key` field.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Default local port embedded in the `http://localhost:{port}/callback`
/// redirect URL (used by `OpenRouterOAuthConfig::default`).
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(default)]
36pub struct OpenRouterOAuthConfig {
37 pub use_oauth: bool,
39 pub callback_port: u16,
41 pub auto_refresh: bool,
43}
44
45impl Default for OpenRouterOAuthConfig {
46 fn default() -> Self {
47 Self {
48 use_oauth: false,
49 callback_port: DEFAULT_CALLBACK_PORT,
50 auto_refresh: true,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct OpenRouterToken {
58 pub api_key: String,
60 pub obtained_at: u64,
62 pub expires_at: Option<u64>,
64 pub label: Option<String>,
66}
67
68impl OpenRouterToken {
69 pub fn is_expired(&self) -> bool {
71 if let Some(expires_at) = self.expires_at {
72 let now = std::time::SystemTime::now()
73 .duration_since(std::time::UNIX_EPOCH)
74 .map(|d| d.as_secs())
75 .unwrap_or(0);
76 now >= expires_at
77 } else {
78 false
79 }
80 }
81}
82
/// On-disk form of an [`OpenRouterToken`] after AES-256-GCM encryption.
///
/// Field names and order define the serialized JSON layout of the token file.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
    /// Standard-base64 encoding of the `NONCE_LEN`-byte nonce used to seal this token.
    nonce: String,
    /// Standard-base64 ciphertext with the GCM authentication tag appended.
    ciphertext: String,
    /// Format version; `decrypt_token` currently accepts only `1`.
    version: u8,
}
93
94pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
103 let callback_url = format!("http://localhost:{}/callback", callback_port);
104 format!(
105 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
106 OPENROUTER_AUTH_URL,
107 urlencoding::encode(&callback_url),
108 urlencoding::encode(&challenge.code_challenge),
109 challenge.code_challenge_method
110 )
111}
112
113pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
125 let client = reqwest::Client::new();
126
127 let payload = serde_json::json!({
128 "code": code,
129 "code_verifier": challenge.code_verifier,
130 "code_challenge_method": challenge.code_challenge_method
131 });
132
133 let response = client
134 .post(OPENROUTER_KEYS_URL)
135 .header("Content-Type", "application/json")
136 .json(&payload)
137 .send()
138 .await
139 .context("Failed to send token exchange request")?;
140
141 let status = response.status();
142 let body = response.text().await.context("Failed to read response body")?;
143
144 if !status.is_success() {
145 if status.as_u16() == 400 {
147 return Err(anyhow!(
148 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
149 ));
150 } else if status.as_u16() == 403 {
151 return Err(anyhow!(
152 "Invalid code or code_verifier. The authorization code may have expired."
153 ));
154 } else if status.as_u16() == 405 {
155 return Err(anyhow!(
156 "Method not allowed. Ensure you're using POST over HTTPS."
157 ));
158 }
159 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
160 }
161
162 let response_json: serde_json::Value =
164 serde_json::from_str(&body).context("Failed to parse token response")?;
165
166 let api_key = response_json
167 .get("key")
168 .and_then(|v| v.as_str())
169 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
170 .to_string();
171
172 Ok(api_key)
173}
174
175fn get_token_path() -> Result<PathBuf> {
177 let vtcode_dir = dirs::home_dir()
178 .ok_or_else(|| anyhow!("Could not determine home directory"))?
179 .join(".vtcode")
180 .join("auth");
181
182 fs::create_dir_all(&vtcode_dir).context("Failed to create auth directory")?;
183
184 Ok(vtcode_dir.join("openrouter.json"))
185}
186
/// Derives the AES-256-GCM key used to encrypt the stored OAuth token.
///
/// The key is the SHA-256 digest of the following, concatenated in order:
/// - the hostname bytes, if available;
/// - the user identity: the numeric UID on Unix, or the `USER`/`USERNAME`
///   environment variable elsewhere, if set;
/// - a fixed domain-separation string.
///
/// The byte order of this key material is effectively part of the on-disk
/// format. Any change here produces a different key and makes previously saved
/// tokens undecryptable.
///
/// NOTE(review): all inputs are non-secret (hostname, UID/username, a constant).
/// This binds the file to a machine and user but is obfuscation rather than
/// strong confidentiality; file permissions remain the real protection.
fn derive_encryption_key() -> Result<LessSafeKey> {
    use ring::digest::{SHA256, digest};

    let mut key_material = Vec::new();

    // Host-binding component; silently omitted if the hostname lookup fails.
    if let Ok(hostname) = hostname::get() {
        key_material.extend_from_slice(hostname.as_encoded_bytes());
    }

    // User-binding component.
    #[cfg(unix)]
    {
        key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
    }
    #[cfg(not(unix))]
    {
        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
            key_material.extend_from_slice(user.as_bytes());
        }
    }

    // Fixed, versioned suffix so this digest differs from any other use of the same host/user data.
    key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");

    let hash = digest(&SHA256, &key_material);
    // SHA-256 output is 32 bytes, exactly the AES-256 key size.
    let key_bytes: &[u8; 32] = hash.as_ref()[..32]
        .try_into()
        .context("Hash too short")?;

    let unbound_key =
        UnboundKey::new(&aead::AES_256_GCM, key_bytes).map_err(|_| anyhow!("Invalid key length"))?;

    Ok(LessSafeKey::new(unbound_key))
}
225
226fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
228 let key = derive_encryption_key()?;
229 let rng = SystemRandom::new();
230
231 let mut nonce_bytes = [0u8; NONCE_LEN];
233 rng.fill(&mut nonce_bytes)
234 .map_err(|_| anyhow!("Failed to generate nonce"))?;
235
236 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
238
239 let mut ciphertext = plaintext;
241 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
242 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
243 .map_err(|_| anyhow!("Encryption failed"))?;
244
245 use base64::{Engine, engine::general_purpose::STANDARD};
246
247 Ok(EncryptedToken {
248 nonce: STANDARD.encode(nonce_bytes),
249 ciphertext: STANDARD.encode(&ciphertext),
250 version: 1,
251 })
252}
253
254fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
256 if encrypted.version != 1 {
257 return Err(anyhow!(
258 "Unsupported token format version: {}",
259 encrypted.version
260 ));
261 }
262
263 use base64::{Engine, engine::general_purpose::STANDARD};
264
265 let key = derive_encryption_key()?;
266
267 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
268 .decode(&encrypted.nonce)
269 .context("Invalid nonce encoding")?
270 .try_into()
271 .map_err(|_| anyhow!("Invalid nonce length"))?;
272
273 let mut ciphertext = STANDARD
274 .decode(&encrypted.ciphertext)
275 .context("Invalid ciphertext encoding")?;
276
277 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
278 let plaintext = key
279 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
280 .map_err(|_| anyhow!("Decryption failed - token may be corrupted or from different machine"))?;
281
282 serde_json::from_slice(plaintext).context("Failed to deserialize token")
283}
284
285pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
287 let path = get_token_path()?;
288 let encrypted = encrypt_token(token)?;
289 let json = serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
290
291 fs::write(&path, json).context("Failed to write token file")?;
292
293 #[cfg(unix)]
295 {
296 use std::os::unix::fs::PermissionsExt;
297 let perms = std::fs::Permissions::from_mode(0o600);
298 fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
299 }
300
301 tracing::info!("OAuth token saved to {}", path.display());
302 Ok(())
303}
304
305pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
309 let path = get_token_path()?;
310
311 if !path.exists() {
312 return Ok(None);
313 }
314
315 let json = fs::read_to_string(&path).context("Failed to read token file")?;
316 let encrypted: EncryptedToken =
317 serde_json::from_str(&json).context("Failed to parse token file")?;
318
319 let token = decrypt_token(&encrypted)?;
320
321 if token.is_expired() {
323 tracing::warn!("OAuth token has expired, removing...");
324 clear_oauth_token()?;
325 return Ok(None);
326 }
327
328 Ok(Some(token))
329}
330
331pub fn clear_oauth_token() -> Result<()> {
333 let path = get_token_path()?;
334
335 if path.exists() {
336 fs::remove_file(&path).context("Failed to remove token file")?;
337 tracing::info!("OAuth token cleared");
338 }
339
340 Ok(())
341}
342
343pub fn get_auth_status() -> Result<AuthStatus> {
345 match load_oauth_token()? {
346 Some(token) => {
347 let now = std::time::SystemTime::now()
348 .duration_since(std::time::UNIX_EPOCH)
349 .map(|d| d.as_secs())
350 .unwrap_or(0);
351
352 let age_seconds = now.saturating_sub(token.obtained_at);
353
354 Ok(AuthStatus::Authenticated {
355 label: token.label,
356 age_seconds,
357 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
358 })
359 }
360 None => Ok(AuthStatus::NotAuthenticated),
361 }
362}
363
/// Snapshot of the OpenRouter OAuth state, as produced by `get_auth_status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthStatus {
    /// A stored token was found.
    Authenticated {
        /// Optional label recorded with the token.
        label: Option<String>,
        /// Seconds since the token was obtained.
        age_seconds: u64,
        /// Seconds until the token expires, or `None` if it has no expiry.
        expires_in: Option<u64>,
    },
    /// No stored token was found.
    NotAuthenticated,
}

impl AuthStatus {
    /// Returns `true` for [`AuthStatus::Authenticated`].
    pub fn is_authenticated(&self) -> bool {
        matches!(self, AuthStatus::Authenticated { .. })
    }

    /// Renders a one-line, human-readable description of the status, e.g.
    /// `"Authenticated (My App), obtained 1h ago, expires in 23h"`.
    pub fn display_string(&self) -> String {
        match self {
            AuthStatus::Authenticated {
                label,
                age_seconds,
                expires_in,
            } => {
                let label_str = label
                    .as_ref()
                    .map(|l| format!(" ({})", l))
                    .unwrap_or_default();
                let age_str = humanize_duration(*age_seconds);
                // Expiry lies in the future, so it must not carry the " ago" suffix.
                let expiry_str = expires_in
                    .map(|e| format!(", expires in {}", format_duration(e)))
                    .unwrap_or_default();
                format!("Authenticated{}, obtained {}{}", label_str, age_str, expiry_str)
            }
            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
        }
    }
}

/// Formats `seconds` as a coarse duration in the largest unit that fits,
/// truncating toward zero: `"45s"`, `"5m"`, `"3h"`, `"2d"`.
fn format_duration(seconds: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    if seconds < MINUTE {
        format!("{}s", seconds)
    } else if seconds < HOUR {
        format!("{}m", seconds / MINUTE)
    } else if seconds < DAY {
        format!("{}h", seconds / HOUR)
    } else {
        format!("{}d", seconds / DAY)
    }
}

/// Formats an elapsed duration of `seconds` with an " ago" suffix, e.g. `"5m ago"`.
fn humanize_duration(seconds: u64) -> String {
    format!("{} ago", format_duration(seconds))
}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_secs()
    }

    /// Builds an unlabeled token with API key `"test"` and the given timestamps.
    fn sample_token(obtained_at: u64, expires_at: Option<u64>) -> OpenRouterToken {
        OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at,
            expires_at,
            label: None,
        }
    }

    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        for fragment in [
            "callback_url=",
            "code_challenge=test_challenge",
            "code_challenge_method=S256",
        ] {
            assert!(url.contains(fragment));
        }
    }

    #[test]
    fn test_token_expiry_check() {
        let now = unix_now();

        // Expires an hour from now: still valid.
        assert!(!sample_token(now, Some(now + 3600)).is_expired());
        // Expired an hour ago.
        assert!(sample_token(now - 7200, Some(now - 3600)).is_expired());
        // No expiry at all: never expires.
        assert!(!sample_token(now, None).is_expired());
    }

    #[test]
    fn test_encryption_roundtrip() {
        let original = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: 1234567890,
            expires_at: Some(1234567890 + 86400),
            label: Some("Test Token".to_string()),
        };

        let restored = decrypt_token(&encrypt_token(&original).unwrap()).unwrap();

        assert_eq!(restored.api_key, original.api_key);
        assert_eq!(restored.obtained_at, original.obtained_at);
        assert_eq!(restored.expires_at, original.expires_at);
        assert_eq!(restored.label, original.label);
    }

    #[test]
    fn test_auth_status_display() {
        let rendered = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        }
        .display_string();

        assert!(rendered.contains("Authenticated"));
        assert!(rendered.contains("My App"));
    }
}