Skip to main content

vtcode_config/auth/
openrouter_oauth.rs

1//! OpenRouter OAuth PKCE authentication flow.
2//!
3//! This module implements the OAuth PKCE flow for OpenRouter, allowing users
4//! to authenticate with their OpenRouter account securely.
5//!
6//! ## Security Model
7//!
8//! Tokens are encrypted at rest using AES-256-GCM with a machine-derived key.
9//! The key is derived from:
10//! - Machine hostname
11//! - User ID (where available)
12//! - A static salt
13//!
14//! This provides reasonable protection against casual access while remaining
15//! portable across the same user's sessions on the same machine.
16
17use anyhow::{Context, Result, anyhow};
18use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, NONCE_LEN};
19use ring::rand::{SecureRandom, SystemRandom};
20use serde::{Deserialize, Serialize};
21use std::fs;
22use std::path::PathBuf;
23
24use super::pkce::PkceChallenge;
25
/// Base URL of the OpenRouter authorization page; `get_auth_url` appends the query string.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that `exchange_code_for_token` POSTs to in order to obtain an API key.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Default port for the localhost OAuth callback server.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
32
33/// Configuration for OpenRouter OAuth authentication.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(default)]
36pub struct OpenRouterOAuthConfig {
37    /// Whether to use OAuth instead of API key
38    pub use_oauth: bool,
39    /// Port for the local callback server
40    pub callback_port: u16,
41    /// Whether to automatically refresh tokens
42    pub auto_refresh: bool,
43}
44
45impl Default for OpenRouterOAuthConfig {
46    fn default() -> Self {
47        Self {
48            use_oauth: false,
49            callback_port: DEFAULT_CALLBACK_PORT,
50            auto_refresh: true,
51        }
52    }
53}
54
55/// Stored OAuth token with metadata.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct OpenRouterToken {
58    /// The API key obtained via OAuth
59    pub api_key: String,
60    /// When the token was obtained (Unix timestamp)
61    pub obtained_at: u64,
62    /// Optional expiry time (Unix timestamp)
63    pub expires_at: Option<u64>,
64    /// User-friendly label for the token
65    pub label: Option<String>,
66}
67
68impl OpenRouterToken {
69    /// Check if the token has expired.
70    pub fn is_expired(&self) -> bool {
71        if let Some(expires_at) = self.expires_at {
72            let now = std::time::SystemTime::now()
73                .duration_since(std::time::UNIX_EPOCH)
74                .map(|d| d.as_secs())
75                .unwrap_or(0);
76            now >= expires_at
77        } else {
78            false
79        }
80    }
81}
82
83/// Encrypted token wrapper for storage.
84#[derive(Debug, Serialize, Deserialize)]
85struct EncryptedToken {
86    /// Base64-encoded nonce
87    nonce: String,
88    /// Base64-encoded ciphertext (includes auth tag)
89    ciphertext: String,
90    /// Version for future format changes
91    version: u8,
92}
93
94/// Generate the OAuth authorization URL.
95///
96/// # Arguments
97/// * `challenge` - PKCE challenge containing the code_challenge
98/// * `callback_port` - Port for the localhost callback server
99///
100/// # Returns
101/// The full authorization URL to redirect the user to.
102pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
103    let callback_url = format!("http://localhost:{}/callback", callback_port);
104    format!(
105        "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
106        OPENROUTER_AUTH_URL,
107        urlencoding::encode(&callback_url),
108        urlencoding::encode(&challenge.code_challenge),
109        challenge.code_challenge_method
110    )
111}
112
113/// Exchange an authorization code for an API key.
114///
115/// This makes a POST request to OpenRouter's token endpoint with the
116/// authorization code and PKCE verifier.
117///
118/// # Arguments
119/// * `code` - The authorization code from the callback URL
120/// * `challenge` - The PKCE challenge used during authorization
121///
122/// # Returns
123/// The obtained API key on success.
124pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
125    let client = reqwest::Client::new();
126
127    let payload = serde_json::json!({
128        "code": code,
129        "code_verifier": challenge.code_verifier,
130        "code_challenge_method": challenge.code_challenge_method
131    });
132
133    let response = client
134        .post(OPENROUTER_KEYS_URL)
135        .header("Content-Type", "application/json")
136        .json(&payload)
137        .send()
138        .await
139        .context("Failed to send token exchange request")?;
140
141    let status = response.status();
142    let body = response.text().await.context("Failed to read response body")?;
143
144    if !status.is_success() {
145        // Parse error response for better messages
146        if status.as_u16() == 400 {
147            return Err(anyhow!(
148                "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
149            ));
150        } else if status.as_u16() == 403 {
151            return Err(anyhow!(
152                "Invalid code or code_verifier. The authorization code may have expired."
153            ));
154        } else if status.as_u16() == 405 {
155            return Err(anyhow!(
156                "Method not allowed. Ensure you're using POST over HTTPS."
157            ));
158        }
159        return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
160    }
161
162    // Parse the response to extract the key
163    let response_json: serde_json::Value =
164        serde_json::from_str(&body).context("Failed to parse token response")?;
165
166    let api_key = response_json
167        .get("key")
168        .and_then(|v| v.as_str())
169        .ok_or_else(|| anyhow!("Response missing 'key' field"))?
170        .to_string();
171
172    Ok(api_key)
173}
174
175/// Get the path to the token storage file.
176fn get_token_path() -> Result<PathBuf> {
177    let vtcode_dir = dirs::home_dir()
178        .ok_or_else(|| anyhow!("Could not determine home directory"))?
179        .join(".vtcode")
180        .join("auth");
181
182    fs::create_dir_all(&vtcode_dir).context("Failed to create auth directory")?;
183
184    Ok(vtcode_dir.join("openrouter.json"))
185}
186
187/// Derive encryption key from machine-specific data.
188fn derive_encryption_key() -> Result<LessSafeKey> {
189    use ring::digest::{SHA256, digest};
190
191    // Collect machine-specific entropy
192    let mut key_material = Vec::new();
193
194    // Hostname
195    if let Ok(hostname) = hostname::get() {
196        key_material.extend_from_slice(hostname.as_encoded_bytes());
197    }
198
199    // User ID (Unix) or username (cross-platform fallback)
200    #[cfg(unix)]
201    {
202        key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
203    }
204    #[cfg(not(unix))]
205    {
206        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
207            key_material.extend_from_slice(user.as_bytes());
208        }
209    }
210
211    // Static salt (not secret, just ensures consistent key derivation)
212    key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");
213
214    // Hash to get 32-byte key
215    let hash = digest(&SHA256, &key_material);
216    let key_bytes: &[u8; 32] = hash.as_ref()[..32]
217        .try_into()
218        .context("Hash too short")?;
219
220    let unbound_key =
221        UnboundKey::new(&aead::AES_256_GCM, key_bytes).map_err(|_| anyhow!("Invalid key length"))?;
222
223    Ok(LessSafeKey::new(unbound_key))
224}
225
226/// Encrypt token data for storage.
227fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
228    let key = derive_encryption_key()?;
229    let rng = SystemRandom::new();
230
231    // Generate random nonce
232    let mut nonce_bytes = [0u8; NONCE_LEN];
233    rng.fill(&mut nonce_bytes)
234        .map_err(|_| anyhow!("Failed to generate nonce"))?;
235
236    // Serialize token to JSON
237    let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
238
239    // Encrypt (includes authentication tag)
240    let mut ciphertext = plaintext;
241    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
242    key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
243        .map_err(|_| anyhow!("Encryption failed"))?;
244
245    use base64::{Engine, engine::general_purpose::STANDARD};
246
247    Ok(EncryptedToken {
248        nonce: STANDARD.encode(nonce_bytes),
249        ciphertext: STANDARD.encode(&ciphertext),
250        version: 1,
251    })
252}
253
254/// Decrypt stored token data.
255fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
256    if encrypted.version != 1 {
257        return Err(anyhow!(
258            "Unsupported token format version: {}",
259            encrypted.version
260        ));
261    }
262
263    use base64::{Engine, engine::general_purpose::STANDARD};
264
265    let key = derive_encryption_key()?;
266
267    let nonce_bytes: [u8; NONCE_LEN] = STANDARD
268        .decode(&encrypted.nonce)
269        .context("Invalid nonce encoding")?
270        .try_into()
271        .map_err(|_| anyhow!("Invalid nonce length"))?;
272
273    let mut ciphertext = STANDARD
274        .decode(&encrypted.ciphertext)
275        .context("Invalid ciphertext encoding")?;
276
277    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
278    let plaintext = key
279        .open_in_place(nonce, Aad::empty(), &mut ciphertext)
280        .map_err(|_| anyhow!("Decryption failed - token may be corrupted or from different machine"))?;
281
282    serde_json::from_slice(plaintext).context("Failed to deserialize token")
283}
284
285/// Save an OAuth token to encrypted storage.
286pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
287    let path = get_token_path()?;
288    let encrypted = encrypt_token(token)?;
289    let json = serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
290
291    fs::write(&path, json).context("Failed to write token file")?;
292
293    // Set restrictive permissions on Unix
294    #[cfg(unix)]
295    {
296        use std::os::unix::fs::PermissionsExt;
297        let perms = std::fs::Permissions::from_mode(0o600);
298        fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
299    }
300
301    tracing::info!("OAuth token saved to {}", path.display());
302    Ok(())
303}
304
305/// Load an OAuth token from encrypted storage.
306///
307/// Returns `None` if no token exists or the token has expired.
308pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
309    let path = get_token_path()?;
310
311    if !path.exists() {
312        return Ok(None);
313    }
314
315    let json = fs::read_to_string(&path).context("Failed to read token file")?;
316    let encrypted: EncryptedToken =
317        serde_json::from_str(&json).context("Failed to parse token file")?;
318
319    let token = decrypt_token(&encrypted)?;
320
321    // Check expiry
322    if token.is_expired() {
323        tracing::warn!("OAuth token has expired, removing...");
324        clear_oauth_token()?;
325        return Ok(None);
326    }
327
328    Ok(Some(token))
329}
330
331/// Clear the stored OAuth token.
332pub fn clear_oauth_token() -> Result<()> {
333    let path = get_token_path()?;
334
335    if path.exists() {
336        fs::remove_file(&path).context("Failed to remove token file")?;
337        tracing::info!("OAuth token cleared");
338    }
339
340    Ok(())
341}
342
343/// Get the current OAuth authentication status.
344pub fn get_auth_status() -> Result<AuthStatus> {
345    match load_oauth_token()? {
346        Some(token) => {
347            let now = std::time::SystemTime::now()
348                .duration_since(std::time::UNIX_EPOCH)
349                .map(|d| d.as_secs())
350                .unwrap_or(0);
351
352            let age_seconds = now.saturating_sub(token.obtained_at);
353
354            Ok(AuthStatus::Authenticated {
355                label: token.label,
356                age_seconds,
357                expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
358            })
359        }
360        None => Ok(AuthStatus::NotAuthenticated),
361    }
362}
363
/// Outcome of checking for a stored OAuth token.
#[derive(Debug, Clone)]
pub enum AuthStatus {
    /// A usable OAuth token is present.
    Authenticated {
        /// Label attached to the token, if any.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds remaining until expiry, when an expiry is known.
        expires_in: Option<u64>,
    },
    /// No usable OAuth token is present.
    NotAuthenticated,
}
379
380impl AuthStatus {
381    /// Check if the user is authenticated.
382    pub fn is_authenticated(&self) -> bool {
383        matches!(self, AuthStatus::Authenticated { .. })
384    }
385
386    /// Get a human-readable status string.
387    pub fn display_string(&self) -> String {
388        match self {
389            AuthStatus::Authenticated {
390                label,
391                age_seconds,
392                expires_in,
393            } => {
394                let label_str = label
395                    .as_ref()
396                    .map(|l| format!(" ({})", l))
397                    .unwrap_or_default();
398                let age_str = humanize_duration(*age_seconds);
399                let expiry_str = expires_in
400                    .map(|e| format!(", expires in {}", humanize_duration(e)))
401                    .unwrap_or_default();
402                format!("Authenticated{}, obtained {}{}", label_str, age_str, expiry_str)
403            }
404            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
405        }
406    }
407}
408
409/// Convert seconds to human-readable duration.
410fn humanize_duration(seconds: u64) -> String {
411    if seconds < 60 {
412        format!("{}s ago", seconds)
413    } else if seconds < 3600 {
414        format!("{}m ago", seconds / 60)
415    } else if seconds < 86400 {
416        format!("{}h ago", seconds / 3600)
417    } else {
418        format!("{}d ago", seconds / 86400)
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Seconds since the Unix epoch; panics if the clock predates the epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an unlabeled token whose key is the fixed string `"test"`.
    fn sample_token(obtained_at: u64, expires_at: Option<u64>) -> OpenRouterToken {
        OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at,
            expires_at,
            label: None,
        }
    }

    /// The authorization URL targets OpenRouter and carries all PKCE parameters.
    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        assert!(url.starts_with("https://openrouter.ai/auth"));
        for fragment in [
            "callback_url=",
            "code_challenge=test_challenge",
            "code_challenge_method=S256",
        ] {
            assert!(url.contains(fragment));
        }
    }

    /// Expiry depends on `expires_at` relative to now; no expiry means never expired.
    #[test]
    fn test_token_expiry_check() {
        let now = unix_now();

        // Expires an hour from now.
        let fresh = sample_token(now, Some(now + 3600));
        assert!(!fresh.is_expired());

        // Expired an hour ago.
        let stale = sample_token(now - 7200, Some(now - 3600));
        assert!(stale.is_expired());

        // No expiry recorded.
        let unbounded = sample_token(now, None);
        assert!(!unbounded.is_expired());
    }

    /// Encrypting and then decrypting a token yields the same field values.
    #[test]
    fn test_encryption_roundtrip() {
        let obtained_at = 1234567890;
        let original = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at,
            expires_at: Some(obtained_at + 86400),
            label: Some("Test Token".to_string()),
        };

        let sealed = encrypt_token(&original).unwrap();
        let restored = decrypt_token(&sealed).unwrap();

        assert_eq!(restored.api_key, original.api_key);
        assert_eq!(restored.obtained_at, original.obtained_at);
        assert_eq!(restored.expires_at, original.expires_at);
        assert_eq!(restored.label, original.label);
    }

    /// The display string mentions the authenticated state and the label.
    #[test]
    fn test_auth_status_display() {
        let status = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        };

        let rendered = status.display_string();
        for expected in ["Authenticated", "My App"] {
            assert!(rendered.contains(expected));
        }
    }
}