1use anyhow::{Context, Result, anyhow};
18use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
19use ring::rand::{SecureRandom, SystemRandom};
20use serde::{Deserialize, Serialize};
21use std::fs;
22use std::path::PathBuf;
23
24use super::pkce::PkceChallenge;
25
/// OpenRouter authorization page the user opens to approve access (PKCE step 1, see `get_auth_url`).
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that exchanges an authorization code plus PKCE verifier for an API key
/// (PKCE step 2, see `exchange_code_for_token`).
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Default local port used to build the OAuth callback URL (`http://localhost:<port>/callback`).
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
32
/// User-facing configuration for OpenRouter OAuth login.
///
/// Because of `#[serde(default)]`, any field missing from the serialized form is taken
/// from `OpenRouterOAuthConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OpenRouterOAuthConfig {
    /// Whether the OAuth flow should be used to obtain the OpenRouter API key.
    pub use_oauth: bool,
    /// Local port embedded in the OAuth callback URL.
    pub callback_port: u16,
    /// Auto-refresh flag. Not read anywhere in this file; presumably it controls automatic
    /// re-authentication (confirm with callers).
    pub auto_refresh: bool,
}
44
45impl Default for OpenRouterOAuthConfig {
46 fn default() -> Self {
47 Self {
48 use_oauth: false,
49 callback_port: DEFAULT_CALLBACK_PORT,
50 auto_refresh: true,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct OpenRouterToken {
58 pub api_key: String,
60 pub obtained_at: u64,
62 pub expires_at: Option<u64>,
64 pub label: Option<String>,
66}
67
68impl OpenRouterToken {
69 pub fn is_expired(&self) -> bool {
71 if let Some(expires_at) = self.expires_at {
72 let now = std::time::SystemTime::now()
73 .duration_since(std::time::UNIX_EPOCH)
74 .map(|d| d.as_secs())
75 .unwrap_or(0);
76 now >= expires_at
77 } else {
78 false
79 }
80 }
81}
82
/// On-disk envelope for an encrypted [`OpenRouterToken`], as produced by `encrypt_token`.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
    /// Standard base64 of the AES-256-GCM nonce (`NONCE_LEN` bytes).
    nonce: String,
    /// Standard base64 of the sealed JSON-serialized token, with the GCM tag appended.
    ciphertext: String,
    /// Envelope format version; `decrypt_token` only accepts `1`.
    version: u8,
}
93
94pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
103 let callback_url = format!("http://localhost:{}/callback", callback_port);
104 format!(
105 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
106 OPENROUTER_AUTH_URL,
107 urlencoding::encode(&callback_url),
108 urlencoding::encode(&challenge.code_challenge),
109 challenge.code_challenge_method
110 )
111}
112
113pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
125 let client = reqwest::Client::new();
126
127 let payload = serde_json::json!({
128 "code": code,
129 "code_verifier": challenge.code_verifier,
130 "code_challenge_method": challenge.code_challenge_method
131 });
132
133 let response = client
134 .post(OPENROUTER_KEYS_URL)
135 .header("Content-Type", "application/json")
136 .json(&payload)
137 .send()
138 .await
139 .context("Failed to send token exchange request")?;
140
141 let status = response.status();
142 let body = response
143 .text()
144 .await
145 .context("Failed to read response body")?;
146
147 if !status.is_success() {
148 if status.as_u16() == 400 {
150 return Err(anyhow!(
151 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
152 ));
153 } else if status.as_u16() == 403 {
154 return Err(anyhow!(
155 "Invalid code or code_verifier. The authorization code may have expired."
156 ));
157 } else if status.as_u16() == 405 {
158 return Err(anyhow!(
159 "Method not allowed. Ensure you're using POST over HTTPS."
160 ));
161 }
162 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
163 }
164
165 let response_json: serde_json::Value =
167 serde_json::from_str(&body).context("Failed to parse token response")?;
168
169 let api_key = response_json
170 .get("key")
171 .and_then(|v| v.as_str())
172 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
173 .to_string();
174
175 Ok(api_key)
176}
177
178fn get_token_path() -> Result<PathBuf> {
180 let vtcode_dir = dirs::home_dir()
181 .ok_or_else(|| anyhow!("Could not determine home directory"))?
182 .join(".vtcode")
183 .join("auth");
184
185 fs::create_dir_all(&vtcode_dir).context("Failed to create auth directory")?;
186
187 Ok(vtcode_dir.join("openrouter.json"))
188}
189
/// Derives the AES-256-GCM key used to encrypt the token file at rest.
///
/// The key is SHA-256 over, in this order:
/// 1. the hostname bytes, when the hostname can be read (silently omitted otherwise);
/// 2. on Unix, the numeric uid as little-endian bytes; elsewhere, the `USER` (or, failing
///    that, `USERNAME`) environment variable, when set;
/// 3. the fixed label `vtcode-openrouter-oauth-v1`.
///
/// NOTE(review): all inputs are non-secret, so this obfuscates the token rather than
/// protecting it; file permissions are the real access control. If the hostname, uid or
/// username changes, previously saved tokens will no longer decrypt.
///
/// Changing the order or content of the key material breaks decryption of every existing
/// token file, so any such change needs a new `EncryptedToken` version.
fn derive_encryption_key() -> Result<LessSafeKey> {
    use ring::digest::{SHA256, digest};

    let mut key_material = Vec::new();

    // Machine binding: hostname, if available.
    if let Ok(hostname) = hostname::get() {
        key_material.extend_from_slice(hostname.as_encoded_bytes());
    }

    // User binding: uid on Unix, username environment variable elsewhere.
    #[cfg(unix)]
    {
        key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
    }
    #[cfg(not(unix))]
    {
        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
            key_material.extend_from_slice(user.as_bytes());
        }
    }

    // Fixed domain-separation label.
    key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");

    // SHA-256 output is 32 bytes, which is exactly the AES-256 key size.
    let hash = digest(&SHA256, &key_material);
    let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;

    let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
        .map_err(|_| anyhow!("Invalid key length"))?;

    Ok(LessSafeKey::new(unbound_key))
}
226
227fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
229 let key = derive_encryption_key()?;
230 let rng = SystemRandom::new();
231
232 let mut nonce_bytes = [0u8; NONCE_LEN];
234 rng.fill(&mut nonce_bytes)
235 .map_err(|_| anyhow!("Failed to generate nonce"))?;
236
237 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
239
240 let mut ciphertext = plaintext;
242 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
243 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
244 .map_err(|_| anyhow!("Encryption failed"))?;
245
246 use base64::{Engine, engine::general_purpose::STANDARD};
247
248 Ok(EncryptedToken {
249 nonce: STANDARD.encode(nonce_bytes),
250 ciphertext: STANDARD.encode(&ciphertext),
251 version: 1,
252 })
253}
254
255fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
257 if encrypted.version != 1 {
258 return Err(anyhow!(
259 "Unsupported token format version: {}",
260 encrypted.version
261 ));
262 }
263
264 use base64::{Engine, engine::general_purpose::STANDARD};
265
266 let key = derive_encryption_key()?;
267
268 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
269 .decode(&encrypted.nonce)
270 .context("Invalid nonce encoding")?
271 .try_into()
272 .map_err(|_| anyhow!("Invalid nonce length"))?;
273
274 let mut ciphertext = STANDARD
275 .decode(&encrypted.ciphertext)
276 .context("Invalid ciphertext encoding")?;
277
278 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
279 let plaintext = key
280 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
281 .map_err(|_| {
282 anyhow!("Decryption failed - token may be corrupted or from different machine")
283 })?;
284
285 serde_json::from_slice(plaintext).context("Failed to deserialize token")
286}
287
288pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
290 let path = get_token_path()?;
291 let encrypted = encrypt_token(token)?;
292 let json =
293 serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
294
295 fs::write(&path, json).context("Failed to write token file")?;
296
297 #[cfg(unix)]
299 {
300 use std::os::unix::fs::PermissionsExt;
301 let perms = std::fs::Permissions::from_mode(0o600);
302 fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
303 }
304
305 tracing::info!("OAuth token saved to {}", path.display());
306 Ok(())
307}
308
309pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
313 let path = get_token_path()?;
314
315 if !path.exists() {
316 return Ok(None);
317 }
318
319 let json = fs::read_to_string(&path).context("Failed to read token file")?;
320 let encrypted: EncryptedToken =
321 serde_json::from_str(&json).context("Failed to parse token file")?;
322
323 let token = decrypt_token(&encrypted)?;
324
325 if token.is_expired() {
327 tracing::warn!("OAuth token has expired, removing...");
328 clear_oauth_token()?;
329 return Ok(None);
330 }
331
332 Ok(Some(token))
333}
334
335pub fn clear_oauth_token() -> Result<()> {
337 let path = get_token_path()?;
338
339 if path.exists() {
340 fs::remove_file(&path).context("Failed to remove token file")?;
341 tracing::info!("OAuth token cleared");
342 }
343
344 Ok(())
345}
346
347pub fn get_auth_status() -> Result<AuthStatus> {
349 match load_oauth_token()? {
350 Some(token) => {
351 let now = std::time::SystemTime::now()
352 .duration_since(std::time::UNIX_EPOCH)
353 .map(|d| d.as_secs())
354 .unwrap_or(0);
355
356 let age_seconds = now.saturating_sub(token.obtained_at);
357
358 Ok(AuthStatus::Authenticated {
359 label: token.label,
360 age_seconds,
361 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
362 })
363 }
364 None => Ok(AuthStatus::NotAuthenticated),
365 }
366}
367
/// Authentication state, as reported by `get_auth_status`.
#[derive(Debug, Clone)]
pub enum AuthStatus {
    /// A token is stored. When produced by `get_auth_status` it is not expired.
    Authenticated {
        /// Optional label of the stored token.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds remaining until expiry, if the token has an expiry time.
        expires_in: Option<u64>,
    },
    /// No usable token is stored.
    NotAuthenticated,
}
383
384impl AuthStatus {
385 pub fn is_authenticated(&self) -> bool {
387 matches!(self, AuthStatus::Authenticated { .. })
388 }
389
390 pub fn display_string(&self) -> String {
392 match self {
393 AuthStatus::Authenticated {
394 label,
395 age_seconds,
396 expires_in,
397 } => {
398 let label_str = label
399 .as_ref()
400 .map(|l| format!(" ({})", l))
401 .unwrap_or_default();
402 let age_str = humanize_duration(*age_seconds);
403 let expiry_str = expires_in
404 .map(|e| format!(", expires in {}", humanize_duration(e)))
405 .unwrap_or_default();
406 format!(
407 "Authenticated{}, obtained {}{}",
408 label_str, age_str, expiry_str
409 )
410 }
411 AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
412 }
413 }
414}
415
416fn humanize_duration(seconds: u64) -> String {
418 if seconds < 60 {
419 format!("{}s ago", seconds)
420 } else if seconds < 3600 {
421 format!("{}m ago", seconds / 60)
422 } else if seconds < 86400 {
423 format!("{}h ago", seconds / 3600)
424 } else {
425 format!("{}d ago", seconds / 86400)
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Token with fixed metadata, used by the encryption tests.
    fn sample_token() -> OpenRouterToken {
        OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: 1234567890,
            expires_at: Some(1234567890 + 86400),
            label: Some("Test Token".to_string()),
        }
    }

    /// Current Unix time in seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        assert!(url.contains("callback_url="));
        // The callback URL must be percent-encoded so it stays a single query value.
        assert!(url.contains("callback_url=http%3A%2F%2Flocalhost%3A8484%2Fcallback"));
        assert!(url.contains("code_challenge=test_challenge"));
        assert!(url.contains("code_challenge_method=S256"));
    }

    #[test]
    fn test_token_expiry_check() {
        let now = unix_now();

        let token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now + 3600),
            label: None,
        };
        assert!(!token.is_expired());

        let expired_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now - 7200,
            expires_at: Some(now - 3600),
            label: None,
        };
        assert!(expired_token.is_expired());

        // Boundary: a token whose expiry equals "now" counts as expired.
        let boundary_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now),
            label: None,
        };
        assert!(boundary_token.is_expired());

        let no_expiry_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: None,
            label: None,
        };
        assert!(!no_expiry_token.is_expired());
    }

    #[test]
    fn test_encryption_roundtrip() {
        let token = sample_token();

        let encrypted = encrypt_token(&token).unwrap();
        let decrypted = decrypt_token(&encrypted).unwrap();

        assert_eq!(decrypted.api_key, token.api_key);
        assert_eq!(decrypted.obtained_at, token.obtained_at);
        assert_eq!(decrypted.expires_at, token.expires_at);
        assert_eq!(decrypted.label, token.label);
    }

    #[test]
    fn test_encryption_uses_fresh_nonce() {
        let token = sample_token();
        let first = encrypt_token(&token).unwrap();
        let second = encrypt_token(&token).unwrap();
        assert_ne!(first.nonce, second.nonce);
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        use base64::{Engine, engine::general_purpose::STANDARD};

        let mut encrypted = encrypt_token(&sample_token()).unwrap();
        let mut bytes = STANDARD.decode(&encrypted.ciphertext).unwrap();
        // Flipping a single bit must make GCM authentication fail.
        bytes[0] ^= 0x01;
        encrypted.ciphertext = STANDARD.encode(&bytes);

        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_unknown_version() {
        let mut encrypted = encrypt_token(&sample_token()).unwrap();
        encrypted.version = 2;
        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_auth_status_display() {
        let status = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        };

        let display = status.display_string();
        assert!(display.contains("Authenticated"));
        assert!(display.contains("My App"));
    }
}