1use serde::{Deserialize, Serialize};
12use tokio::sync::oneshot;
13
14mod pkce;
15mod callback;
16mod token;
17mod storage;
18mod browser;
19mod openai_codex;
20
21pub use pkce::{generate_code_verifier, generate_code_challenge, generate_state, build_auth_url};
24pub use callback::{CallbackServerHandle, start_callback_server};
25pub use token::{exchange_code_for_tokens, refresh_token, ensure_fresh_token, ensure_fresh_provider_token};
26pub use storage::{auth_file_path, load_auth, load_provider_auth, save_auth, save_provider_auth};
27pub use browser::open_browser;
28pub use openai_codex::{extract_account_id as extract_codex_account_id, login as login_openai_codex};
29
30pub(super) const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
33pub(super) const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
34pub(super) const TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
35pub(super) const CALLBACK_HOST: &str = "127.0.0.1";
36pub(super) const CALLBACK_PORT: u16 = 53692;
37pub(super) const SCOPES: &str = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload";
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct OAuthCredentials {
43 #[serde(rename = "type")]
44 pub auth_type: String,
45 pub refresh: String,
46 pub access: String,
47 pub expires: u64,
48 #[serde(rename = "accountId", skip_serializing_if = "Option::is_none")]
49 pub account_id: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct AuthFile {
54 pub anthropic: OAuthCredentials,
55 #[serde(rename = "openai-codex", default, skip_serializing_if = "Option::is_none")]
56 pub openai_codex: Option<OAuthCredentials>,
57}
58
59#[derive(Debug, Deserialize)]
60pub(crate) struct TokenResponse {
61 pub(crate) access_token: String,
62 pub(crate) refresh_token: String,
63 pub(crate) expires_in: u64,
64}
65
/// An authorization `code` with its accompanying `state`.
///
/// `login` obtains one either from the callback-server channel or from a code pasted on stdin.
#[derive(Debug, Clone)]
pub struct CallbackResult {
    /// Authorization code to exchange for tokens.
    pub code: String,
    /// State value, compared against the generated state before the exchange.
    pub state: String,
}
72
73pub fn is_token_expired(creds: &OAuthCredentials) -> bool {
75 now_millis() >= creds.expires
76}
77
78pub(crate) fn now_millis() -> u64 {
81 crate::epoch_millis()
82}
83
84fn parse_manual_input(input: &str) -> (Option<String>, Option<String>) {
86 let trimmed = input.trim();
87
88 if let Ok(url) = url::Url::parse(trimmed) {
90 let code = url.query_pairs().find(|(k, _)| k == "code").map(|(_, v)| v.to_string());
91 let state = url.query_pairs().find(|(k, _)| k == "state").map(|(_, v)| v.to_string());
92 if code.is_some() {
93 return (code, state);
94 }
95 }
96
97 if trimmed.contains('#') {
99 let parts: Vec<&str> = trimmed.splitn(2, '#').collect();
100 if parts.len() == 2 && !parts[0].is_empty() && !parts[1].is_empty() {
101 return (Some(parts[0].to_string()), Some(parts[1].to_string()));
102 }
103 }
104
105 if !trimmed.is_empty() {
107 return (Some(trimmed.to_string()), None);
108 }
109
110 (None, None)
111}
112
113pub async fn login() -> std::result::Result<OAuthCredentials, String> {
117 let port = CALLBACK_PORT;
118
119 let verifier = generate_code_verifier();
121 let challenge = generate_code_challenge(&verifier);
122 let state = generate_state();
123
124 let (rx, server_handle) = start_callback_server(state.clone(), port).await?;
126
127 let auth_url = build_auth_url(&challenge, &state, port);
129
130 eprintln!("\n\x1b[1mOpening browser to sign in...\x1b[0m\n");
131
132 if let Err(e) = open_browser(&auth_url) {
133 eprintln!("Could not open browser automatically: {}", e);
134 }
135
136 eprintln!("\x1b[2mIf the browser didn't open, visit this URL:\x1b[0m");
137 eprintln!("\x1b[36m{}\x1b[0m\n", auth_url);
138
139 let (manual_tx, manual_rx) = oneshot::channel::<CallbackResult>();
141 let manual_state = state.clone();
142 let stdin_task = tokio::spawn(async move {
143 eprintln!("\x1b[2mOr paste the authorization code here:\x1b[0m");
144
145 let mut line = String::new();
146 let result = tokio::task::spawn_blocking(move || {
147 std::io::stdin().read_line(&mut line).ok();
148 line.trim().to_string()
149 })
150 .await;
151
152 if let Ok(input) = result {
153 if !input.is_empty() {
154 let (code, parsed_state) = parse_manual_input(&input);
155 if let Some(code) = code {
156 let _ = manual_tx.send(CallbackResult {
157 code,
158 state: parsed_state.unwrap_or(manual_state),
159 });
160 }
161 }
162 }
163 });
164
165 let result = tokio::select! {
167 callback = rx => {
168 match callback {
169 Ok(result) => result,
170 Err(_) => return Err("Callback channel closed".to_string()),
171 }
172 }
173 manual = manual_rx => {
174 match manual {
175 Ok(result) => result,
176 Err(_) => return Err("Manual input channel closed".to_string()),
177 }
178 }
179 };
180
181 stdin_task.abort();
182
183 if result.state != state {
185 server_handle.shutdown().await;
186 return Err("OAuth state mismatch — possible CSRF attack".to_string());
187 }
188
189 eprintln!("\n\x1b[1mExchanging code for tokens...\x1b[0m");
190
191 let creds = exchange_code_for_tokens(&result.code, &result.state, &verifier, port).await?;
193
194 server_handle.shutdown().await;
196
197 save_auth(&creds)?;
199
200 Ok(creds)
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Builds test credentials that differ only in their `expires` timestamp.
    fn creds_with_expiry(expires: u64) -> OAuthCredentials {
        OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh".to_string(),
            access: "test_access".to_string(),
            expires,
            account_id: None,
        }
    }

    #[test]
    fn test_generate_code_verifier() {
        let verifier = generate_code_verifier();
        assert!(!verifier.is_empty(), "Code verifier should not be empty");
        assert!(verifier.len() > 20, "Code verifier should be longer than 20 characters");
        let verifier2 = generate_code_verifier();
        assert_ne!(verifier, verifier2, "Two calls should produce different verifiers");
    }

    #[test]
    fn test_generate_code_challenge() {
        let verifier = "test_verifier_123";
        let challenge = generate_code_challenge(verifier);
        assert!(!challenge.is_empty(), "Code challenge should not be empty");
        let challenge2 = generate_code_challenge(verifier);
        assert_eq!(challenge, challenge2, "Same verifier should produce same challenge");
        let different_challenge = generate_code_challenge("different_verifier_456");
        assert_ne!(challenge, different_challenge, "Different verifiers should produce different challenges");
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty(), "State should not be empty");
        let state2 = generate_state();
        assert_ne!(state, state2, "Two calls should produce different states");
    }

    #[test]
    fn test_build_auth_url() {
        let challenge = "test_challenge";
        let state = "test_state";
        let port = 8080;
        let url = build_auth_url(challenge, state, port);
        assert!(url.contains("claude.ai/oauth/authorize"));
        assert!(url.contains("client_id=9d1c250a-e61b-44d9-88ed-5944d1962f5e"));
        assert!(url.contains(&format!("code_challenge={}", challenge)));
        assert!(url.contains(&format!("state={}", state)));
        assert!(url.contains("localhost"));
        assert!(url.contains(&port.to_string()));
        assert!(url.contains("redirect_uri="));
    }

    #[test]
    fn test_is_token_expired() {
        let expired_creds = creds_with_expiry(0);
        assert!(is_token_expired(&expired_creds));

        let future_time = now_millis() + 3600000;
        let fresh_creds = creds_with_expiry(future_time);
        assert!(!is_token_expired(&fresh_creds));
        assert_eq!(fresh_creds.auth_type, "oauth");
    }

    #[test]
    fn test_pkce_challenge_sha256() {
        let verifier = "test_verifier_string";
        let challenge = generate_code_challenge(verifier);

        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let hash = hasher.finalize();
        let expected = URL_SAFE_NO_PAD.encode(hash);

        assert_eq!(challenge, expected);
    }

    #[test]
    fn test_code_verifier_length() {
        let verifier = generate_code_verifier();
        assert_eq!(verifier.len(), 43);
    }

    #[test]
    fn test_state_length() {
        let state = generate_state();
        assert_eq!(state.len(), 43);
    }

    #[test]
    fn test_build_auth_url_required_params() {
        let url = build_auth_url("test_challenge", "test_state", 8080);
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("scope="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("8080"));
    }

    #[test]
    fn test_is_token_expired_edge_cases() {
        // Read the clock through the same function `is_token_expired` uses, so both
        // sides of the comparison share a time source.
        let current_time = now_millis();

        // The comparison is inclusive: expiring exactly "now" already counts as expired.
        let exactly_now_creds = creds_with_expiry(current_time);
        assert!(is_token_expired(&exactly_now_creds));

        // A 1 ms margin is racy: the clock can tick between reading `current_time` and
        // the check. Use a margin far beyond any scheduling delay.
        let near_future_creds = creds_with_expiry(current_time + 60_000);
        assert!(!is_token_expired(&near_future_creds));
    }

    #[test]
    fn test_auth_file_path() {
        let path = auth_file_path();
        let path_str = path.to_string_lossy();
        assert!(path_str.ends_with("auth.json"));
    }

    #[test]
    fn test_oauth_credentials_serialization_roundtrip() {
        let original_creds = OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh_token".to_string(),
            access: "test_access_token".to_string(),
            expires: 1234567890,
            account_id: None,
        };

        let json = serde_json::to_string(&original_creds).expect("Should serialize");
        let deserialized_creds: OAuthCredentials = serde_json::from_str(&json).expect("Should deserialize");

        assert_eq!(original_creds.auth_type, deserialized_creds.auth_type);
        assert_eq!(original_creds.refresh, deserialized_creds.refresh);
        assert_eq!(original_creds.access, deserialized_creds.access);
        assert_eq!(original_creds.expires, deserialized_creds.expires);
    }

    #[test]
    fn test_oauth_credentials_account_id_serialization() {
        let mut creds = creds_with_expiry(1);

        let json = serde_json::to_string(&creds).expect("Should serialize");
        assert!(json.contains("\"type\":\"oauth\""));
        assert!(!json.contains("accountId"), "None account_id should be omitted");

        creds.account_id = Some("acct".to_string());
        let json = serde_json::to_string(&creds).expect("Should serialize");
        assert!(json.contains("\"accountId\":\"acct\""));

        let back: OAuthCredentials = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(back.account_id.as_deref(), Some("acct"));
    }

    #[test]
    fn test_parse_manual_input_forms() {
        assert_eq!(
            parse_manual_input("abc#xyz"),
            (Some("abc".to_string()), Some("xyz".to_string()))
        );
        assert_eq!(parse_manual_input("  abc123 \n"), (Some("abc123".to_string()), None));
        assert_eq!(
            parse_manual_input("http://localhost:53692/callback?code=C1&state=S1"),
            (Some("C1".to_string()), Some("S1".to_string()))
        );
        assert_eq!(
            parse_manual_input("http://localhost:53692/callback?code=C1"),
            (Some("C1".to_string()), None)
        );
        // A `#` with an empty half is not treated as code#state.
        assert_eq!(parse_manual_input("abc#"), (Some("abc#".to_string()), None));
    }

    #[test]
    fn test_parse_manual_input_blank() {
        assert_eq!(parse_manual_input(""), (None, None));
        assert_eq!(parse_manual_input("   \t\n"), (None, None));
    }
}