Skip to main content

vtcode_config/auth/
openrouter_oauth.rs

//! OpenRouter OAuth PKCE authentication flow.
//!
//! This module implements the OAuth PKCE flow for OpenRouter, allowing users
//! to authenticate with their OpenRouter account securely.
//!
//! ## Security Model
//!
//! Tokens are encrypted at rest using AES-256-GCM with a machine-derived key.
//! The key is derived from:
//! - Machine hostname
//! - User ID (where available)
//! - A static salt
//!
//! None of these inputs is secret, so this is obfuscation rather than strong
//! protection: it deters casual inspection of the token file, and a copied file
//! generally cannot be decrypted under a different hostname or user. Any process
//! running as the same user on the same machine can re-derive the key. The key is
//! stable across the same user's sessions on the same machine.
16
17use anyhow::{Context, Result, anyhow};
18use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
19use ring::rand::{SecureRandom, SystemRandom};
20use serde::{Deserialize, Serialize};
21use std::fs;
22use std::path::PathBuf;
23
24use super::pkce::PkceChallenge;
25
/// Browser-facing OpenRouter authorization page the user is sent to.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that exchanges an authorization code (plus PKCE verifier) for an API key.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Port the localhost OAuth callback server uses unless configured otherwise.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
32
33/// Configuration for OpenRouter OAuth authentication.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(default)]
36pub struct OpenRouterOAuthConfig {
37    /// Whether to use OAuth instead of API key
38    pub use_oauth: bool,
39    /// Port for the local callback server
40    pub callback_port: u16,
41    /// Whether to automatically refresh tokens
42    pub auto_refresh: bool,
43}
44
45impl Default for OpenRouterOAuthConfig {
46    fn default() -> Self {
47        Self {
48            use_oauth: false,
49            callback_port: DEFAULT_CALLBACK_PORT,
50            auto_refresh: true,
51        }
52    }
53}
54
55/// Stored OAuth token with metadata.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct OpenRouterToken {
58    /// The API key obtained via OAuth
59    pub api_key: String,
60    /// When the token was obtained (Unix timestamp)
61    pub obtained_at: u64,
62    /// Optional expiry time (Unix timestamp)
63    pub expires_at: Option<u64>,
64    /// User-friendly label for the token
65    pub label: Option<String>,
66}
67
68impl OpenRouterToken {
69    /// Check if the token has expired.
70    pub fn is_expired(&self) -> bool {
71        if let Some(expires_at) = self.expires_at {
72            let now = std::time::SystemTime::now()
73                .duration_since(std::time::UNIX_EPOCH)
74                .map(|d| d.as_secs())
75                .unwrap_or(0);
76            now >= expires_at
77        } else {
78            false
79        }
80    }
81}
82
83/// Encrypted token wrapper for storage.
84#[derive(Debug, Serialize, Deserialize)]
85struct EncryptedToken {
86    /// Base64-encoded nonce
87    nonce: String,
88    /// Base64-encoded ciphertext (includes auth tag)
89    ciphertext: String,
90    /// Version for future format changes
91    version: u8,
92}
93
94/// Generate the OAuth authorization URL.
95///
96/// # Arguments
97/// * `challenge` - PKCE challenge containing the code_challenge
98/// * `callback_port` - Port for the localhost callback server
99///
100/// # Returns
101/// The full authorization URL to redirect the user to.
102pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
103    let callback_url = format!("http://localhost:{}/callback", callback_port);
104    format!(
105        "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
106        OPENROUTER_AUTH_URL,
107        urlencoding::encode(&callback_url),
108        urlencoding::encode(&challenge.code_challenge),
109        challenge.code_challenge_method
110    )
111}
112
113/// Exchange an authorization code for an API key.
114///
115/// This makes a POST request to OpenRouter's token endpoint with the
116/// authorization code and PKCE verifier.
117///
118/// # Arguments
119/// * `code` - The authorization code from the callback URL
120/// * `challenge` - The PKCE challenge used during authorization
121///
122/// # Returns
123/// The obtained API key on success.
124pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
125    let client = reqwest::Client::new();
126
127    let payload = serde_json::json!({
128        "code": code,
129        "code_verifier": challenge.code_verifier,
130        "code_challenge_method": challenge.code_challenge_method
131    });
132
133    let response = client
134        .post(OPENROUTER_KEYS_URL)
135        .header("Content-Type", "application/json")
136        .json(&payload)
137        .send()
138        .await
139        .context("Failed to send token exchange request")?;
140
141    let status = response.status();
142    let body = response
143        .text()
144        .await
145        .context("Failed to read response body")?;
146
147    if !status.is_success() {
148        // Parse error response for better messages
149        if status.as_u16() == 400 {
150            return Err(anyhow!(
151                "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
152            ));
153        } else if status.as_u16() == 403 {
154            return Err(anyhow!(
155                "Invalid code or code_verifier. The authorization code may have expired."
156            ));
157        } else if status.as_u16() == 405 {
158            return Err(anyhow!(
159                "Method not allowed. Ensure you're using POST over HTTPS."
160            ));
161        }
162        return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
163    }
164
165    // Parse the response to extract the key
166    let response_json: serde_json::Value =
167        serde_json::from_str(&body).context("Failed to parse token response")?;
168
169    let api_key = response_json
170        .get("key")
171        .and_then(|v| v.as_str())
172        .ok_or_else(|| anyhow!("Response missing 'key' field"))?
173        .to_string();
174
175    Ok(api_key)
176}
177
178/// Get the path to the token storage file.
179fn get_token_path() -> Result<PathBuf> {
180    let vtcode_dir = dirs::home_dir()
181        .ok_or_else(|| anyhow!("Could not determine home directory"))?
182        .join(".vtcode")
183        .join("auth");
184
185    fs::create_dir_all(&vtcode_dir).context("Failed to create auth directory")?;
186
187    Ok(vtcode_dir.join("openrouter.json"))
188}
189
190/// Derive encryption key from machine-specific data.
191fn derive_encryption_key() -> Result<LessSafeKey> {
192    use ring::digest::{SHA256, digest};
193
194    // Collect machine-specific entropy
195    let mut key_material = Vec::new();
196
197    // Hostname
198    if let Ok(hostname) = hostname::get() {
199        key_material.extend_from_slice(hostname.as_encoded_bytes());
200    }
201
202    // User ID (Unix) or username (cross-platform fallback)
203    #[cfg(unix)]
204    {
205        key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
206    }
207    #[cfg(not(unix))]
208    {
209        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
210            key_material.extend_from_slice(user.as_bytes());
211        }
212    }
213
214    // Static salt (not secret, just ensures consistent key derivation)
215    key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");
216
217    // Hash to get 32-byte key
218    let hash = digest(&SHA256, &key_material);
219    let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;
220
221    let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
222        .map_err(|_| anyhow!("Invalid key length"))?;
223
224    Ok(LessSafeKey::new(unbound_key))
225}
226
227/// Encrypt token data for storage.
228fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
229    let key = derive_encryption_key()?;
230    let rng = SystemRandom::new();
231
232    // Generate random nonce
233    let mut nonce_bytes = [0u8; NONCE_LEN];
234    rng.fill(&mut nonce_bytes)
235        .map_err(|_| anyhow!("Failed to generate nonce"))?;
236
237    // Serialize token to JSON
238    let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
239
240    // Encrypt (includes authentication tag)
241    let mut ciphertext = plaintext;
242    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
243    key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
244        .map_err(|_| anyhow!("Encryption failed"))?;
245
246    use base64::{Engine, engine::general_purpose::STANDARD};
247
248    Ok(EncryptedToken {
249        nonce: STANDARD.encode(nonce_bytes),
250        ciphertext: STANDARD.encode(&ciphertext),
251        version: 1,
252    })
253}
254
255/// Decrypt stored token data.
256fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
257    if encrypted.version != 1 {
258        return Err(anyhow!(
259            "Unsupported token format version: {}",
260            encrypted.version
261        ));
262    }
263
264    use base64::{Engine, engine::general_purpose::STANDARD};
265
266    let key = derive_encryption_key()?;
267
268    let nonce_bytes: [u8; NONCE_LEN] = STANDARD
269        .decode(&encrypted.nonce)
270        .context("Invalid nonce encoding")?
271        .try_into()
272        .map_err(|_| anyhow!("Invalid nonce length"))?;
273
274    let mut ciphertext = STANDARD
275        .decode(&encrypted.ciphertext)
276        .context("Invalid ciphertext encoding")?;
277
278    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
279    let plaintext = key
280        .open_in_place(nonce, Aad::empty(), &mut ciphertext)
281        .map_err(|_| {
282            anyhow!("Decryption failed - token may be corrupted or from different machine")
283        })?;
284
285    serde_json::from_slice(plaintext).context("Failed to deserialize token")
286}
287
288/// Save an OAuth token to encrypted storage.
289pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
290    let path = get_token_path()?;
291    let encrypted = encrypt_token(token)?;
292    let json =
293        serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
294
295    fs::write(&path, json).context("Failed to write token file")?;
296
297    // Set restrictive permissions on Unix
298    #[cfg(unix)]
299    {
300        use std::os::unix::fs::PermissionsExt;
301        let perms = std::fs::Permissions::from_mode(0o600);
302        fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
303    }
304
305    tracing::info!("OAuth token saved to {}", path.display());
306    Ok(())
307}
308
309/// Load an OAuth token from encrypted storage.
310///
311/// Returns `None` if no token exists or the token has expired.
312pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
313    let path = get_token_path()?;
314
315    if !path.exists() {
316        return Ok(None);
317    }
318
319    let json = fs::read_to_string(&path).context("Failed to read token file")?;
320    let encrypted: EncryptedToken =
321        serde_json::from_str(&json).context("Failed to parse token file")?;
322
323    let token = decrypt_token(&encrypted)?;
324
325    // Check expiry
326    if token.is_expired() {
327        tracing::warn!("OAuth token has expired, removing...");
328        clear_oauth_token()?;
329        return Ok(None);
330    }
331
332    Ok(Some(token))
333}
334
335/// Clear the stored OAuth token.
336pub fn clear_oauth_token() -> Result<()> {
337    let path = get_token_path()?;
338
339    if path.exists() {
340        fs::remove_file(&path).context("Failed to remove token file")?;
341        tracing::info!("OAuth token cleared");
342    }
343
344    Ok(())
345}
346
347/// Get the current OAuth authentication status.
348pub fn get_auth_status() -> Result<AuthStatus> {
349    match load_oauth_token()? {
350        Some(token) => {
351            let now = std::time::SystemTime::now()
352                .duration_since(std::time::UNIX_EPOCH)
353                .map(|d| d.as_secs())
354                .unwrap_or(0);
355
356            let age_seconds = now.saturating_sub(token.obtained_at);
357
358            Ok(AuthStatus::Authenticated {
359                label: token.label,
360                age_seconds,
361                expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
362            })
363        }
364        None => Ok(AuthStatus::NotAuthenticated),
365    }
366}
367
368/// OAuth authentication status.
369#[derive(Debug, Clone)]
370pub enum AuthStatus {
371    /// User is authenticated with OAuth
372    Authenticated {
373        /// Optional label for the token
374        label: Option<String>,
375        /// How long ago the token was obtained (seconds)
376        age_seconds: u64,
377        /// Time until expiry (seconds), if known
378        expires_in: Option<u64>,
379    },
380    /// User is not authenticated via OAuth
381    NotAuthenticated,
382}
383
384impl AuthStatus {
385    /// Check if the user is authenticated.
386    pub fn is_authenticated(&self) -> bool {
387        matches!(self, AuthStatus::Authenticated { .. })
388    }
389
390    /// Get a human-readable status string.
391    pub fn display_string(&self) -> String {
392        match self {
393            AuthStatus::Authenticated {
394                label,
395                age_seconds,
396                expires_in,
397            } => {
398                let label_str = label
399                    .as_ref()
400                    .map(|l| format!(" ({})", l))
401                    .unwrap_or_default();
402                let age_str = humanize_duration(*age_seconds);
403                let expiry_str = expires_in
404                    .map(|e| format!(", expires in {}", humanize_duration(e)))
405                    .unwrap_or_default();
406                format!(
407                    "Authenticated{}, obtained {}{}",
408                    label_str, age_str, expiry_str
409                )
410            }
411            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
412        }
413    }
414}
415
/// Render an elapsed number of seconds as a compact "time ago" phrase.
///
/// The largest whole unit is used, rounding down: seconds below one minute,
/// minutes below one hour, hours below one day, and days beyond that.
/// Examples: 59 gives `"59s ago"`, 3_600 gives `"1h ago"`.
fn humanize_duration(seconds: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    let (amount, unit) = match seconds {
        s if s < MINUTE => (s, "s"),
        s if s < HOUR => (s / MINUTE, "m"),
        s if s < DAY => (s / HOUR, "h"),
        s => (s / DAY, "d"),
    };
    format!("{amount}{unit} ago")
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in seconds.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// A fixed token used by the encryption tests.
    fn sample_token() -> OpenRouterToken {
        OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: 1234567890,
            expires_at: Some(1234567890 + 86400),
            label: Some("Test Token".to_string()),
        }
    }

    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        assert!(url.contains("callback_url="));
        assert!(url.contains("code_challenge=test_challenge"));
        assert!(url.contains("code_challenge_method=S256"));
    }

    #[test]
    fn test_token_expiry_check() {
        let now = now_secs();

        // Non-expired token
        let token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now + 3600),
            label: None,
        };
        assert!(!token.is_expired());

        // Expired token
        let expired_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now - 7200,
            expires_at: Some(now - 3600),
            label: None,
        };
        assert!(expired_token.is_expired());

        // No expiry
        let no_expiry_token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: None,
            label: None,
        };
        assert!(!no_expiry_token.is_expired());
    }

    #[test]
    fn test_token_expires_at_deadline() {
        // `is_expired` uses `now >= expires_at`, so a deadline of "now" is already expired.
        let now = now_secs();
        let token = OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at: now,
            expires_at: Some(now),
            label: None,
        };
        assert!(token.is_expired());
    }

    #[test]
    fn test_encryption_roundtrip() {
        let token = sample_token();

        let encrypted = encrypt_token(&token).unwrap();
        let decrypted = decrypt_token(&encrypted).unwrap();

        assert_eq!(decrypted.api_key, token.api_key);
        assert_eq!(decrypted.obtained_at, token.obtained_at);
        assert_eq!(decrypted.expires_at, token.expires_at);
        assert_eq!(decrypted.label, token.label);
    }

    #[test]
    fn test_encryption_uses_fresh_nonce() {
        let token = sample_token();
        let first = encrypt_token(&token).unwrap();
        let second = encrypt_token(&token).unwrap();
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        use base64::{Engine, engine::general_purpose::STANDARD};

        let mut encrypted = encrypt_token(&sample_token()).unwrap();
        let mut raw = STANDARD.decode(&encrypted.ciphertext).unwrap();
        raw[0] ^= 0x01;
        encrypted.ciphertext = STANDARD.encode(&raw);

        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_unknown_version() {
        let mut encrypted = encrypt_token(&sample_token()).unwrap();
        encrypted.version = 2;
        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_bad_nonce_length() {
        use base64::{Engine, engine::general_purpose::STANDARD};

        let mut encrypted = encrypt_token(&sample_token()).unwrap();
        encrypted.nonce = STANDARD.encode([0u8; 4]);
        assert!(decrypt_token(&encrypted).is_err());
    }

    #[test]
    fn test_auth_status_display() {
        let status = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        };

        let display = status.display_string();
        assert!(display.contains("Authenticated"));
        assert!(display.contains("My App"));
        assert!(status.is_authenticated());
    }

    #[test]
    fn test_not_authenticated_status() {
        let status = AuthStatus::NotAuthenticated;
        assert!(!status.is_authenticated());
        assert_eq!(status.display_string(), "Not authenticated");
    }

    #[test]
    fn test_humanize_duration_boundaries() {
        assert_eq!(humanize_duration(0), "0s ago");
        assert_eq!(humanize_duration(59), "59s ago");
        assert_eq!(humanize_duration(60), "1m ago");
        assert_eq!(humanize_duration(3599), "59m ago");
        assert_eq!(humanize_duration(3600), "1h ago");
        assert_eq!(humanize_duration(86399), "23h ago");
        assert_eq!(humanize_duration(86400), "1d ago");
    }
}