Skip to main content

synaps_cli/core/auth/
mod.rs

1//! OAuth 2.0 Authorization Code + PKCE flow for Anthropic (Claude Pro/Max).
2//!
3//! Implements the same flow as Claude Code and Pi coding agent:
4//! 1. Generate PKCE verifier + challenge
5//! 2. Start localhost callback server
6//! 3. Open browser to claude.ai/oauth/authorize
7//! 4. Capture redirect with auth code
8//! 5. Exchange code for access + refresh tokens
9//! 6. Save to ~/.pi/agent/auth.json (shared with Pi)
10
11use serde::{Deserialize, Serialize};
12use tokio::sync::oneshot;
13
14mod pkce;
15mod callback;
16mod token;
17mod storage;
18mod browser;
19mod openai_codex;
20
21// ── Re-exports ──────────────────────────────────────────────────────────────────
22
23pub use pkce::{generate_code_verifier, generate_code_challenge, generate_state, build_auth_url};
24pub use callback::{CallbackServerHandle, start_callback_server};
25pub use token::{exchange_code_for_tokens, refresh_token, ensure_fresh_token, ensure_fresh_provider_token};
26pub use storage::{auth_file_path, load_auth, load_provider_auth, save_auth, save_provider_auth};
27pub use browser::open_browser;
28pub use openai_codex::{extract_account_id as extract_codex_account_id, login as login_openai_codex};
29
30// ── Constants (match Claude Code / Pi) ──────────────────────────────────────
31
32pub(super) const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
33pub(super) const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
34pub(super) const TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
35pub(super) const CALLBACK_HOST: &str = "127.0.0.1";
36pub(super) const CALLBACK_PORT: u16 = 53692;
37pub(super) const SCOPES: &str = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload";
38
39// ── Types ───────────────────────────────────────────────────────────────────
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct OAuthCredentials {
43    #[serde(rename = "type")]
44    pub auth_type: String,
45    pub refresh: String,
46    pub access: String,
47    pub expires: u64,
48    #[serde(rename = "accountId", skip_serializing_if = "Option::is_none")]
49    pub account_id: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct AuthFile {
54    pub anthropic: OAuthCredentials,
55    #[serde(rename = "openai-codex", default, skip_serializing_if = "Option::is_none")]
56    pub openai_codex: Option<OAuthCredentials>,
57}
58
59#[derive(Debug, Deserialize)]
60pub(crate) struct TokenResponse {
61    pub(crate) access_token: String,
62    pub(crate) refresh_token: String,
63    pub(crate) expires_in: u64,
64}
65
/// Authorization result, delivered either by the localhost callback server or by text the
/// user pasted manually during `login`.
#[derive(Clone, Debug)]
pub struct CallbackResult {
    /// Authorization code to exchange for tokens.
    pub code: String,
    /// `state` value returned with the code. `login` compares it against the generated
    /// state before exchanging the code.
    pub state: String,
}
72
73/// Check if the current token is expired (or will expire within 5 minutes).
74pub fn is_token_expired(creds: &OAuthCredentials) -> bool {
75    now_millis() >= creds.expires
76}
77
78// ── Helpers ─────────────────────────────────────────────────────────────────
79
80pub(crate) fn now_millis() -> u64 {
81    crate::epoch_millis()
82}
83
84/// Try to parse manual input as a redirect URL or raw code.
85fn parse_manual_input(input: &str) -> (Option<String>, Option<String>) {
86    let trimmed = input.trim();
87
88    // Try as full URL
89    if let Ok(url) = url::Url::parse(trimmed) {
90        let code = url.query_pairs().find(|(k, _)| k == "code").map(|(_, v)| v.to_string());
91        let state = url.query_pairs().find(|(k, _)| k == "state").map(|(_, v)| v.to_string());
92        if code.is_some() {
93            return (code, state);
94        }
95    }
96
97    // Try as "code#state" format (Claude Code manual flow)
98    if trimmed.contains('#') {
99        let parts: Vec<&str> = trimmed.splitn(2, '#').collect();
100        if parts.len() == 2 && !parts[0].is_empty() && !parts[1].is_empty() {
101            return (Some(parts[0].to_string()), Some(parts[1].to_string()));
102        }
103    }
104
105    // Treat as raw code
106    if !trimmed.is_empty() {
107        return (Some(trimmed.to_string()), None);
108    }
109
110    (None, None)
111}
112
113// ── High-level login flow ───────────────────────────────────────────────────
114
115/// Run the full OAuth login flow. Returns saved credentials.
116pub async fn login() -> std::result::Result<OAuthCredentials, String> {
117    let port = CALLBACK_PORT;
118
119    // 1. Generate PKCE
120    let verifier = generate_code_verifier();
121    let challenge = generate_code_challenge(&verifier);
122    let state = generate_state();
123
124    // 2. Start callback server
125    let (rx, server_handle) = start_callback_server(state.clone(), port).await?;
126
127    // 3. Build URL and open browser
128    let auth_url = build_auth_url(&challenge, &state, port);
129
130    eprintln!("\n\x1b[1mOpening browser to sign in...\x1b[0m\n");
131
132    if let Err(e) = open_browser(&auth_url) {
133        eprintln!("Could not open browser automatically: {}", e);
134    }
135
136    eprintln!("\x1b[2mIf the browser didn't open, visit this URL:\x1b[0m");
137    eprintln!("\x1b[36m{}\x1b[0m\n", auth_url);
138
139    // Also provide manual paste option
140    let (manual_tx, manual_rx) = oneshot::channel::<CallbackResult>();
141    let manual_state = state.clone();
142    let stdin_task = tokio::spawn(async move {
143        eprintln!("\x1b[2mOr paste the authorization code here:\x1b[0m");
144
145        let mut line = String::new();
146        let result = tokio::task::spawn_blocking(move || {
147            std::io::stdin().read_line(&mut line).ok();
148            line.trim().to_string()
149        })
150        .await;
151
152        if let Ok(input) = result {
153            if !input.is_empty() {
154                let (code, parsed_state) = parse_manual_input(&input);
155                if let Some(code) = code {
156                    let _ = manual_tx.send(CallbackResult {
157                        code,
158                        state: parsed_state.unwrap_or(manual_state),
159                    });
160                }
161            }
162        }
163    });
164
165    // 4. Wait for either callback or manual input
166    let result = tokio::select! {
167        callback = rx => {
168            match callback {
169                Ok(result) => result,
170                Err(_) => return Err("Callback channel closed".to_string()),
171            }
172        }
173        manual = manual_rx => {
174            match manual {
175                Ok(result) => result,
176                Err(_) => return Err("Manual input channel closed".to_string()),
177            }
178        }
179    };
180
181    stdin_task.abort();
182
183    // 5. Verify state
184    if result.state != state {
185        server_handle.shutdown().await;
186        return Err("OAuth state mismatch — possible CSRF attack".to_string());
187    }
188
189    eprintln!("\n\x1b[1mExchanging code for tokens...\x1b[0m");
190
191    // 6. Exchange code for tokens
192    let creds = exchange_code_for_tokens(&result.code, &result.state, &verifier, port).await?;
193
194    // 7. Shut down callback server
195    server_handle.shutdown().await;
196
197    // 8. Save to auth.json
198    save_auth(&creds)?;
199
200    Ok(creds)
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Credentials with dummy tokens, no account id, and the given expiry (epoch ms).
    fn creds_expiring_at(expires: u64) -> OAuthCredentials {
        OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh".to_string(),
            access: "test_access".to_string(),
            expires,
            account_id: None,
        }
    }

    #[test]
    fn test_generate_code_verifier() {
        let verifier = generate_code_verifier();
        assert!(!verifier.is_empty(), "Code verifier should not be empty");
        assert!(verifier.len() > 20, "Code verifier should be longer than 20 characters");
        let verifier2 = generate_code_verifier();
        assert_ne!(verifier, verifier2, "Two calls should produce different verifiers");
    }

    #[test]
    fn test_generate_code_challenge() {
        let verifier = "test_verifier_123";
        let challenge = generate_code_challenge(verifier);
        assert!(!challenge.is_empty(), "Code challenge should not be empty");
        let challenge2 = generate_code_challenge(verifier);
        assert_eq!(challenge, challenge2, "Same verifier should produce same challenge");
        let different_challenge = generate_code_challenge("different_verifier_456");
        assert_ne!(challenge, different_challenge, "Different verifiers should produce different challenges");
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty(), "State should not be empty");
        let state2 = generate_state();
        assert_ne!(state, state2, "Two calls should produce different states");
    }

    #[test]
    fn test_build_auth_url() {
        let challenge = "test_challenge";
        let state = "test_state";
        let port = 8080;
        let url = build_auth_url(challenge, state, port);
        assert!(url.contains("claude.ai/oauth/authorize"));
        assert!(url.contains("client_id=9d1c250a-e61b-44d9-88ed-5944d1962f5e"));
        assert!(url.contains(&format!("code_challenge={}", challenge)));
        assert!(url.contains(&format!("state={}", state)));
        assert!(url.contains("localhost"));
        assert!(url.contains(&port.to_string()));
        assert!(url.contains("redirect_uri="));
    }

    #[test]
    fn test_is_token_expired() {
        let expired_creds = creds_expiring_at(0);
        assert!(is_token_expired(&expired_creds));

        let fresh_creds = creds_expiring_at(now_millis() + 3_600_000);
        assert!(!is_token_expired(&fresh_creds));
        assert_eq!(fresh_creds.auth_type, "oauth");
    }

    #[test]
    fn test_pkce_challenge_sha256() {
        let verifier = "test_verifier_string";
        let challenge = generate_code_challenge(verifier);

        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let hash = hasher.finalize();
        let expected = URL_SAFE_NO_PAD.encode(hash);

        assert_eq!(challenge, expected);
    }

    #[test]
    fn test_code_verifier_length() {
        let verifier = generate_code_verifier();
        assert_eq!(verifier.len(), 43);
    }

    #[test]
    fn test_state_length() {
        let state = generate_state();
        assert_eq!(state.len(), 43);
    }

    #[test]
    fn test_build_auth_url_required_params() {
        let url = build_auth_url("test_challenge", "test_state", 8080);
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("scope="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("8080"));
    }

    #[test]
    fn test_is_token_expired_edge_cases() {
        let now = now_millis();

        // An expiry equal to an instant that has already arrived counts as expired.
        assert!(is_token_expired(&creds_expiring_at(now)));

        // Use a one-minute margin rather than 1 ms. The clock keeps running between reading
        // `now` and the check inside `is_token_expired`, so `now + 1` made this test flaky.
        assert!(!is_token_expired(&creds_expiring_at(now + 60_000)));
    }

    #[test]
    fn test_auth_file_path() {
        let path = auth_file_path();
        let path_str = path.to_string_lossy();
        assert!(path_str.ends_with("auth.json"));
    }

    #[test]
    fn test_oauth_credentials_serialization_roundtrip() {
        let original_creds = OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh_token".to_string(),
            access: "test_access_token".to_string(),
            expires: 1234567890,
            account_id: None,
        };

        let json = serde_json::to_string(&original_creds).expect("Should serialize");
        let deserialized_creds: OAuthCredentials = serde_json::from_str(&json).expect("Should deserialize");

        assert_eq!(original_creds.auth_type, deserialized_creds.auth_type);
        assert_eq!(original_creds.refresh, deserialized_creds.refresh);
        assert_eq!(original_creds.access, deserialized_creds.access);
        assert_eq!(original_creds.expires, deserialized_creds.expires);
        assert_eq!(original_creds.account_id, deserialized_creds.account_id);
    }

    #[test]
    fn test_oauth_credentials_json_keys() {
        let mut creds = creds_expiring_at(1);

        let json = serde_json::to_string(&creds).expect("Should serialize");
        assert!(json.contains("\"type\":\"oauth\""), "auth_type is stored as `type`");
        assert!(!json.contains("accountId"), "None account id should be omitted");

        creds.account_id = Some("acct_123".to_string());
        let json = serde_json::to_string(&creds).expect("Should serialize");
        assert!(json.contains("\"accountId\":\"acct_123\""));
        let back: OAuthCredentials = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(back.account_id.as_deref(), Some("acct_123"));

        let without_account = r#"{"type":"oauth","refresh":"r","access":"a","expires":5}"#;
        let parsed: OAuthCredentials =
            serde_json::from_str(without_account).expect("Missing accountId should be allowed");
        assert!(parsed.account_id.is_none());
    }

    #[test]
    fn test_auth_file_openai_codex_optional() {
        let json = r#"{"anthropic":{"type":"oauth","refresh":"r","access":"a","expires":5}}"#;
        let file: AuthFile = serde_json::from_str(json).expect("openai-codex should be optional");
        assert!(file.openai_codex.is_none());

        let out = serde_json::to_string(&file).expect("Should serialize");
        assert!(!out.contains("openai-codex"), "None provider entry should be omitted");
    }

    #[test]
    fn test_parse_manual_input_redirect_url() {
        let (code, state) =
            parse_manual_input("  http://localhost:53692/callback?code=abc123&state=xyz  ");
        assert_eq!(code.as_deref(), Some("abc123"));
        assert_eq!(state.as_deref(), Some("xyz"));
    }

    #[test]
    fn test_parse_manual_input_code_hash_state() {
        assert_eq!(
            parse_manual_input("abc123#xyz"),
            (Some("abc123".to_string()), Some("xyz".to_string()))
        );
    }

    #[test]
    fn test_parse_manual_input_raw_code_and_empty() {
        assert_eq!(parse_manual_input("abc123\n"), (Some("abc123".to_string()), None));
        assert_eq!(parse_manual_input(""), (None, None));
        assert_eq!(parse_manual_input("   \n"), (None, None));
    }
}