1use std::io::Write;
11use std::path::Path;
12use std::sync::Arc;
13
14use chrono::{DateTime, Utc};
15use oauth2::basic::BasicClient;
16use oauth2::{
17 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
18 TokenResponse, TokenUrl,
19};
20use serde::{Deserialize, Serialize};
21use tokio::sync::RwLock;
22
23use crate::error::XApiError;
24
/// OAuth 2.0 authorization endpoint; the user opens this URL (with query
/// parameters appended by `build_oauth_client`/`authorize_url`) in a browser.
const AUTH_URL: &str = "https://x.com/i/oauth2/authorize";

/// OAuth 2.0 token endpoint, used both for the authorization-code exchange
/// (`exchange_code`) and for refresh-token grants (`TokenManager::do_refresh`).
const TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";

/// Scopes requested during authorization; also the fallback scope list when the
/// token response does not report granted scopes.
/// NOTE(review): `offline.access` is presumably what makes X issue a refresh
/// token — confirm against the X OAuth 2.0 documentation.
const SCOPES: &[&str] = &[
    "tweet.read",
    "tweet.write",
    "users.read",
    "follows.read",
    "follows.write",
    "offline.access",
];

/// A refresh is triggered once fewer than this many seconds remain before the
/// access token's `expires_at` (checked in `TokenManager::refresh_if_needed`).
const REFRESH_WINDOW_SECS: i64 = 300;
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Tokens {
47 pub access_token: String,
49 pub refresh_token: String,
51 pub expires_at: DateTime<Utc>,
53 pub scopes: Vec<String>,
55}
56
57pub struct TokenManager {
59 tokens: Arc<RwLock<Tokens>>,
60 client_id: String,
61 http_client: reqwest::Client,
62 token_path: std::path::PathBuf,
63}
64
65impl TokenManager {
66 pub fn new(tokens: Tokens, client_id: String, token_path: std::path::PathBuf) -> Self {
68 Self {
69 tokens: Arc::new(RwLock::new(tokens)),
70 client_id,
71 http_client: reqwest::Client::new(),
72 token_path,
73 }
74 }
75
76 pub async fn get_access_token(&self) -> Result<String, XApiError> {
78 self.refresh_if_needed().await?;
79 let tokens = self.tokens.read().await;
80 Ok(tokens.access_token.clone())
81 }
82
83 pub fn tokens_lock(&self) -> Arc<RwLock<Tokens>> {
85 self.tokens.clone()
86 }
87
88 pub async fn refresh_if_needed(&self) -> Result<(), XApiError> {
90 let should_refresh = {
91 let tokens = self.tokens.read().await;
92 let now = Utc::now();
93 let seconds_until_expiry = tokens.expires_at.signed_duration_since(now).num_seconds();
94 seconds_until_expiry < REFRESH_WINDOW_SECS
95 };
96
97 if should_refresh {
98 self.do_refresh().await?;
99 }
100
101 Ok(())
102 }
103
104 async fn do_refresh(&self) -> Result<(), XApiError> {
106 let refresh_token = {
107 let tokens = self.tokens.read().await;
108 tokens.refresh_token.clone()
109 };
110
111 tracing::info!("Refreshing X API access token");
112
113 let params = [
114 ("grant_type", "refresh_token"),
115 ("refresh_token", &refresh_token),
116 ("client_id", &self.client_id),
117 ];
118
119 let response = self
120 .http_client
121 .post(TOKEN_URL)
122 .form(¶ms)
123 .send()
124 .await
125 .map_err(|e| XApiError::Network { source: e })?;
126
127 if !response.status().is_success() {
128 let status = response.status().as_u16();
129 let body = response.text().await.unwrap_or_default();
130 tracing::error!(
131 status,
132 body_len = body.len(),
133 "Token refresh failed (response body redacted)"
134 );
135 return Err(XApiError::AuthExpired);
136 }
137
138 let body: TokenRefreshResponse = response
139 .json()
140 .await
141 .map_err(|e| XApiError::Network { source: e })?;
142
143 let new_tokens = Tokens {
144 access_token: body.access_token,
145 refresh_token: body.refresh_token,
146 expires_at: Utc::now() + chrono::Duration::seconds(body.expires_in),
147 scopes: body
148 .scope
149 .split_whitespace()
150 .map(|s| s.to_string())
151 .collect(),
152 };
153
154 tracing::info!(
155 expires_at = %new_tokens.expires_at,
156 "Token refreshed successfully"
157 );
158
159 {
161 let mut tokens = self.tokens.write().await;
162 *tokens = new_tokens.clone();
163 }
164
165 save_tokens(&new_tokens, &self.token_path).map_err(|e| {
167 tracing::error!(error = %e, "Failed to save refreshed tokens");
168 XApiError::ApiError {
169 status: 0,
170 message: format!("Failed to save tokens: {e}"),
171 }
172 })?;
173
174 Ok(())
175 }
176}
177
178#[derive(Debug, Deserialize)]
180struct TokenRefreshResponse {
181 access_token: String,
182 refresh_token: String,
183 expires_in: i64,
184 scope: String,
185}
186
187pub fn save_tokens(tokens: &Tokens, path: &Path) -> Result<(), String> {
189 if let Some(parent) = path.parent() {
190 std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create directory: {e}"))?;
191 }
192
193 let json = serde_json::to_string_pretty(tokens)
194 .map_err(|e| format!("Failed to serialize tokens: {e}"))?;
195
196 #[cfg(unix)]
198 {
199 use std::io::Write;
200 use std::os::unix::fs::OpenOptionsExt;
201 let mut file = std::fs::OpenOptions::new()
202 .write(true)
203 .create(true)
204 .truncate(true)
205 .mode(0o600)
206 .open(path)
207 .map_err(|e| format!("Failed to create token file: {e}"))?;
208 file.write_all(json.as_bytes())
209 .map_err(|e| format!("Failed to write tokens: {e}"))?;
210 }
211
212 #[cfg(not(unix))]
213 {
214 std::fs::write(path, &json).map_err(|e| format!("Failed to write tokens: {e}"))?;
215 tracing::warn!("Cannot set restrictive file permissions on non-Unix platform");
216 }
217
218 Ok(())
219}
220
221pub fn load_tokens(path: &Path) -> Result<Option<Tokens>, XApiError> {
223 match std::fs::read_to_string(path) {
224 Ok(contents) => {
225 let tokens: Tokens =
226 serde_json::from_str(&contents).map_err(|e| XApiError::ApiError {
227 status: 0,
228 message: format!("Failed to parse tokens file: {e}"),
229 })?;
230 Ok(Some(tokens))
231 }
232 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
233 Err(e) => Err(XApiError::ApiError {
234 status: 0,
235 message: format!("Failed to read tokens file: {e}"),
236 }),
237 }
238}
239
240fn build_oauth_client(client_id: &str, redirect_uri: &str) -> Result<BasicClient, XApiError> {
242 let auth_url = AuthUrl::new(AUTH_URL.to_string()).map_err(|e| XApiError::ApiError {
243 status: 0,
244 message: format!("Invalid auth URL: {e}"),
245 })?;
246
247 let token_url = TokenUrl::new(TOKEN_URL.to_string()).map_err(|e| XApiError::ApiError {
248 status: 0,
249 message: format!("Invalid token URL: {e}"),
250 })?;
251
252 let redirect = RedirectUrl::new(redirect_uri.to_string()).map_err(|e| XApiError::ApiError {
253 status: 0,
254 message: format!("Invalid redirect URI: {e}"),
255 })?;
256
257 let client = BasicClient::new(
258 ClientId::new(client_id.to_string()),
259 None,
260 auth_url,
261 Some(token_url),
262 )
263 .set_redirect_uri(redirect);
264
265 Ok(client)
266}
267
268pub async fn authenticate_manual(client_id: &str) -> Result<Tokens, XApiError> {
273 let redirect_uri = "http://localhost/callback";
274 let client = build_oauth_client(client_id, redirect_uri)?;
275
276 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
277
278 let mut auth_builder = client
279 .authorize_url(CsrfToken::new_random)
280 .set_pkce_challenge(pkce_challenge);
281 for scope in SCOPES {
282 auth_builder = auth_builder.add_scope(Scope::new(scope.to_string()));
283 }
284 let (auth_url, csrf_state) = auth_builder.url();
285
286 println!("\n=== X API Authentication (Manual Mode) ===\n");
287 println!("1. Open this URL in your browser:\n");
288 println!(" {auth_url}\n");
289 println!("2. Authorize the application");
290 println!("3. Copy the authorization code from the callback URL");
291 println!(" (Look for ?code=XXXXX in the URL)\n");
292
293 let _ = csrf_state; print!("Paste the authorization code: ");
296 std::io::stdout().flush().map_err(|e| XApiError::ApiError {
297 status: 0,
298 message: format!("IO error: {e}"),
299 })?;
300
301 let mut code = String::new();
302 std::io::stdin()
303 .read_line(&mut code)
304 .map_err(|e| XApiError::ApiError {
305 status: 0,
306 message: format!("Failed to read input: {e}"),
307 })?;
308
309 let code = code.trim().to_string();
310 if code.is_empty() {
311 return Err(XApiError::ApiError {
312 status: 0,
313 message: "Authorization code cannot be empty".to_string(),
314 });
315 }
316
317 exchange_code(&client, &code, pkce_verifier).await
318}
319
320pub async fn authenticate_callback(
325 client_id: &str,
326 host: &str,
327 port: u16,
328) -> Result<Tokens, XApiError> {
329 let redirect_uri = format!("http://{host}:{port}/callback");
330 let client = build_oauth_client(client_id, &redirect_uri)?;
331
332 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
333
334 let mut auth_builder = client
335 .authorize_url(CsrfToken::new_random)
336 .set_pkce_challenge(pkce_challenge);
337 for scope in SCOPES {
338 auth_builder = auth_builder.add_scope(Scope::new(scope.to_string()));
339 }
340 let (auth_url, csrf_state) = auth_builder.url();
341
342 let addr = format!("{host}:{port}");
344 let listener = tokio::net::TcpListener::bind(&addr)
345 .await
346 .map_err(|e| XApiError::ApiError {
347 status: 0,
348 message: format!(
349 "Failed to bind callback server on {addr}: {e}. Try changing auth.callback_port."
350 ),
351 })?;
352
353 tracing::info!("Callback server listening on {addr}");
354
355 let url_str = auth_url.to_string();
357 if let Err(e) = open::that(&url_str) {
358 tracing::warn!(error = %e, "Failed to open browser automatically");
359 println!("\nCould not open browser automatically.");
360 println!("Please open this URL manually:\n");
361 println!(" {url_str}\n");
362 } else {
363 println!("\nOpened authorization URL in your browser.");
364 println!("Waiting for callback...\n");
365 }
366
367 let callback_result = tokio::time::timeout(
369 std::time::Duration::from_secs(120),
370 accept_callback(&listener, csrf_state.secret()),
371 )
372 .await
373 .map_err(|_| XApiError::ApiError {
374 status: 0,
375 message: "Authentication timed out after 120 seconds".to_string(),
376 })??;
377
378 exchange_code(&client, &callback_result, pkce_verifier).await
379}
380
381async fn accept_callback(
383 listener: &tokio::net::TcpListener,
384 expected_state: &str,
385) -> Result<String, XApiError> {
386 let (mut stream, _addr) = listener.accept().await.map_err(|e| XApiError::ApiError {
387 status: 0,
388 message: format!("Failed to accept connection: {e}"),
389 })?;
390
391 use tokio::io::{AsyncReadExt, AsyncWriteExt};
392
393 let mut buf = vec![0u8; 4096];
394 let n = stream
395 .read(&mut buf)
396 .await
397 .map_err(|e| XApiError::ApiError {
398 status: 0,
399 message: format!("Failed to read request: {e}"),
400 })?;
401
402 let request = String::from_utf8_lossy(&buf[..n]);
403
404 let first_line = request.lines().next().unwrap_or("");
406 let path = first_line.split_whitespace().nth(1).unwrap_or("");
407
408 let query_start = path.find('?').map(|i| i + 1);
409 let query_string = query_start.map(|i| &path[i..]).unwrap_or("");
410
411 let mut code = None;
412 let mut state = None;
413
414 for param in query_string.split('&') {
415 if let Some((key, value)) = param.split_once('=') {
416 match key {
417 "code" => code = Some(value.to_string()),
418 "state" => state = Some(value.to_string()),
419 _ => {}
420 }
421 }
422 }
423
424 let received_state = state.ok_or_else(|| XApiError::ApiError {
426 status: 0,
427 message: "Missing OAuth state parameter in callback".to_string(),
428 })?;
429 if received_state != expected_state {
430 let error_html = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
431 <html><body><h1>Authentication Failed</h1>\
432 <p>State parameter mismatch. This may indicate a CSRF attack.</p>\
433 <p>Please try again.</p></body></html>";
434 let _ = stream.write_all(error_html.as_bytes()).await;
435 return Err(XApiError::ApiError {
436 status: 0,
437 message: "OAuth state parameter mismatch".to_string(),
438 });
439 }
440
441 let auth_code = code.ok_or_else(|| XApiError::ApiError {
442 status: 0,
443 message: "No authorization code in callback URL".to_string(),
444 })?;
445
446 let success_html = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
448 <html><body><h1>Authentication Successful!</h1>\
449 <p>You can close this tab and return to the terminal.</p></body></html>";
450 let _ = stream.write_all(success_html.as_bytes()).await;
451
452 Ok(auth_code)
453}
454
455async fn exchange_code(
457 client: &BasicClient,
458 code: &str,
459 pkce_verifier: oauth2::PkceCodeVerifier,
460) -> Result<Tokens, XApiError> {
461 let http_client = oauth2::reqwest::async_http_client;
462
463 let token_result = client
464 .exchange_code(AuthorizationCode::new(code.to_string()))
465 .set_pkce_verifier(pkce_verifier)
466 .request_async(http_client)
467 .await
468 .map_err(|e| XApiError::ApiError {
469 status: 0,
470 message: format!("Token exchange failed: {e}"),
471 })?;
472
473 let access_token = token_result.access_token().secret().to_string();
474 let refresh_token = token_result
475 .refresh_token()
476 .map(|rt| rt.secret().to_string())
477 .unwrap_or_default();
478
479 let expires_in = token_result
480 .expires_in()
481 .map(|d| d.as_secs() as i64)
482 .unwrap_or(7200);
483
484 let scopes: Vec<String> = token_result
485 .scopes()
486 .map(|s| s.iter().map(|scope| scope.to_string()).collect())
487 .unwrap_or_else(|| SCOPES.iter().map(|s| s.to_string()).collect());
488
489 let tokens = Tokens {
490 access_token,
491 refresh_token,
492 expires_at: Utc::now() + chrono::Duration::seconds(expires_in),
493 scopes,
494 };
495
496 tracing::info!(
497 expires_at = %tokens.expires_at,
498 scopes = ?tokens.scopes,
499 "Authentication successful"
500 );
501
502 Ok(tokens)
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // Round-trips `Tokens` through serde_json and checks that the fields survive.
    #[test]
    fn tokens_serialize_deserialize() {
        let tokens = Tokens {
            access_token: "test_access".to_string(),
            refresh_token: "test_refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec!["tweet.read".to_string(), "tweet.write".to_string()],
        };

        let json = serde_json::to_string(&tokens).expect("serialize");
        let parsed: Tokens = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.access_token, "test_access");
        assert_eq!(parsed.refresh_token, "test_refresh");
        assert_eq!(parsed.scopes.len(), 2);
    }

    // `save_tokens` then `load_tokens` on the same path returns the saved values.
    #[test]
    fn save_and_load_tokens() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let tokens = Tokens {
            access_token: "acc".to_string(),
            refresh_token: "ref".to_string(),
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec!["tweet.read".to_string()],
        };

        save_tokens(&tokens, &path).expect("save");

        let loaded = load_tokens(&path).expect("load").expect("some");
        assert_eq!(loaded.access_token, "acc");
        assert_eq!(loaded.refresh_token, "ref");
    }

    // A missing token file is reported as `Ok(None)`, not as an error.
    #[test]
    fn load_tokens_file_not_found_returns_none() {
        let path = PathBuf::from("/nonexistent/tokens.json");
        let result = load_tokens(&path).expect("load");
        assert!(result.is_none());
    }

    // A file that exists but is not valid JSON is an error.
    #[test]
    fn load_tokens_malformed_returns_error() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");
        std::fs::write(&path, "not valid json").expect("write");

        let result = load_tokens(&path);
        assert!(result.is_err());
    }

    // On Unix the saved token file must be readable/writable by the owner only.
    #[cfg(unix)]
    #[test]
    fn save_tokens_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let tokens = Tokens {
            access_token: "a".to_string(),
            refresh_token: "r".to_string(),
            expires_at: Utc::now(),
            scopes: vec![],
        };

        save_tokens(&tokens, &path).expect("save");

        let metadata = std::fs::metadata(&path).expect("metadata");
        let mode = metadata.permissions().mode() & 0o777;
        assert_eq!(mode, 0o600, "token file should have 600 permissions");
    }

    // Missing parent directories are created by `save_tokens`.
    #[test]
    fn save_tokens_creates_parent_dirs() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("nested").join("dir").join("tokens.json");

        let tokens = Tokens {
            access_token: "a".to_string(),
            refresh_token: "r".to_string(),
            expires_at: Utc::now(),
            scopes: vec![],
        };

        save_tokens(&tokens, &path).expect("save");
        assert!(path.exists());
    }

    // An expiry 60 s away is inside REFRESH_WINDOW_SECS (300 s), so
    // `refresh_if_needed` attempts a refresh. With dummy credentials the refresh
    // is expected to fail.
    // NOTE(review): this sends a real request to the TOKEN_URL constant, or fails
    // at DNS/connect when offline. The assertion only checks that *some* error
    // came back, so it cannot tell a network failure from a rejected refresh.
    #[tokio::test]
    async fn token_manager_refresh_detects_expiry() {
        let tokens = Tokens {
            access_token: "old_token".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::seconds(60),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);

        let result = manager.refresh_if_needed().await;
        assert!(result.is_err());
    }

    // Two hours of validity is outside the refresh window, so no request is made
    // and the stored access token is returned unchanged.
    #[tokio::test]
    async fn token_manager_no_refresh_when_fresh() {
        let tokens = Tokens {
            access_token: "fresh_token".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);

        let result = manager.refresh_if_needed().await;
        assert!(result.is_ok());

        let token = manager.get_access_token().await.expect("get token");
        assert_eq!(token, "fresh_token");
    }

    // NOTE(review): despite its name, this test never exercises a refresh.
    // - The mock server's URI is never given to `TokenManager`, which always posts
    //   to the TOKEN_URL constant.
    // - `refresh_if_needed` is never called.
    // So the mounted mock is unused, and the assertion only checks the initial
    // token. Making the token endpoint injectable would let this test drive the
    // mock.
    #[tokio::test]
    async fn token_manager_refresh_with_mock() {
        use wiremock::matchers::{body_string_contains, method};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_access",
                "refresh_token": "new_refresh",
                "expires_in": 7200,
                "scope": "tweet.read tweet.write"
            })))
            .mount(&server)
            .await;

        let tokens = Tokens {
            access_token: "old_token".to_string(),
            refresh_token: "old_refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::seconds(60),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);
        let token = manager.tokens.read().await;
        assert_eq!(token.access_token, "old_token");
    }
}