1use std::io::Write;
11use std::path::Path;
12use std::sync::Arc;
13
14use chrono::{DateTime, Utc};
15use oauth2::basic::BasicClient;
16use oauth2::{
17 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope,
18 TokenResponse, TokenUrl,
19};
20use serde::{Deserialize, Serialize};
21use tokio::sync::RwLock;
22
23use crate::error::XApiError;
24
25use super::scopes::REQUIRED_SCOPES;
26
/// X OAuth 2.0 authorization endpoint: the page the user opens to approve access.
const AUTH_URL: &str = "https://x.com/i/oauth2/authorize";

/// X OAuth 2.0 token endpoint, used both for the code exchange and for refreshes.
const TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";

/// Refresh the access token once fewer than this many seconds (five minutes)
/// remain before it expires.
const REFRESH_WINDOW_SECS: i64 = 5 * 60;
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct Tokens {
39 pub access_token: String,
41 pub refresh_token: String,
43 pub expires_at: DateTime<Utc>,
45 pub scopes: Vec<String>,
47}
48
49pub struct TokenManager {
51 tokens: Arc<RwLock<Tokens>>,
52 client_id: String,
53 http_client: reqwest::Client,
54 token_path: std::path::PathBuf,
55}
56
57impl TokenManager {
58 pub fn new(tokens: Tokens, client_id: String, token_path: std::path::PathBuf) -> Self {
60 Self {
61 tokens: Arc::new(RwLock::new(tokens)),
62 client_id,
63 http_client: reqwest::Client::new(),
64 token_path,
65 }
66 }
67
68 pub async fn get_access_token(&self) -> Result<String, XApiError> {
70 self.refresh_if_needed().await?;
71 let tokens = self.tokens.read().await;
72 Ok(tokens.access_token.clone())
73 }
74
75 pub fn tokens_lock(&self) -> Arc<RwLock<Tokens>> {
77 self.tokens.clone()
78 }
79
80 pub async fn refresh_if_needed(&self) -> Result<(), XApiError> {
82 let should_refresh = {
83 let tokens = self.tokens.read().await;
84 let now = Utc::now();
85 let seconds_until_expiry = tokens.expires_at.signed_duration_since(now).num_seconds();
86 seconds_until_expiry < REFRESH_WINDOW_SECS
87 };
88
89 if should_refresh {
90 self.do_refresh().await?;
91 }
92
93 Ok(())
94 }
95
96 async fn do_refresh(&self) -> Result<(), XApiError> {
98 let refresh_token = {
99 let tokens = self.tokens.read().await;
100 tokens.refresh_token.clone()
101 };
102
103 tracing::info!("Refreshing X API access token");
104
105 let params = [
106 ("grant_type", "refresh_token"),
107 ("refresh_token", &refresh_token),
108 ("client_id", &self.client_id),
109 ];
110
111 let response = self
112 .http_client
113 .post(TOKEN_URL)
114 .form(¶ms)
115 .send()
116 .await
117 .map_err(|e| XApiError::Network { source: e })?;
118
119 if !response.status().is_success() {
120 let status = response.status().as_u16();
121 let body = response.text().await.unwrap_or_default();
122 tracing::error!(
123 status,
124 body_len = body.len(),
125 "Token refresh failed (response body redacted)"
126 );
127 return Err(XApiError::AuthExpired);
128 }
129
130 let body: TokenRefreshResponse = response
131 .json()
132 .await
133 .map_err(|e| XApiError::Network { source: e })?;
134
135 let new_tokens = Tokens {
136 access_token: body.access_token,
137 refresh_token: body.refresh_token,
138 expires_at: Utc::now() + chrono::Duration::seconds(body.expires_in),
139 scopes: body
140 .scope
141 .split_whitespace()
142 .map(|s| s.to_string())
143 .collect(),
144 };
145
146 tracing::info!(
147 expires_at = %new_tokens.expires_at,
148 "Token refreshed successfully"
149 );
150
151 {
153 let mut tokens = self.tokens.write().await;
154 *tokens = new_tokens.clone();
155 }
156
157 save_tokens(&new_tokens, &self.token_path).map_err(|e| {
159 tracing::error!(error = %e, "Failed to save refreshed tokens");
160 XApiError::ApiError {
161 status: 0,
162 message: format!("Failed to save tokens: {e}"),
163 }
164 })?;
165
166 Ok(())
167 }
168}
169
170#[derive(Debug, Deserialize)]
172struct TokenRefreshResponse {
173 access_token: String,
174 refresh_token: String,
175 expires_in: i64,
176 scope: String,
177}
178
179pub fn save_tokens(tokens: &Tokens, path: &Path) -> Result<(), String> {
181 if let Some(parent) = path.parent() {
182 std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create directory: {e}"))?;
183 }
184
185 let json = serde_json::to_string_pretty(tokens)
186 .map_err(|e| format!("Failed to serialize tokens: {e}"))?;
187
188 #[cfg(unix)]
190 {
191 use std::io::Write;
192 use std::os::unix::fs::OpenOptionsExt;
193 let mut file = std::fs::OpenOptions::new()
194 .write(true)
195 .create(true)
196 .truncate(true)
197 .mode(0o600)
198 .open(path)
199 .map_err(|e| format!("Failed to create token file: {e}"))?;
200 file.write_all(json.as_bytes())
201 .map_err(|e| format!("Failed to write tokens: {e}"))?;
202 }
203
204 #[cfg(not(unix))]
205 {
206 std::fs::write(path, &json).map_err(|e| format!("Failed to write tokens: {e}"))?;
207 tracing::warn!("Cannot set restrictive file permissions on non-Unix platform");
208 }
209
210 Ok(())
211}
212
213pub fn load_tokens(path: &Path) -> Result<Option<Tokens>, XApiError> {
215 match std::fs::read_to_string(path) {
216 Ok(contents) => {
217 let tokens: Tokens =
218 serde_json::from_str(&contents).map_err(|e| XApiError::ApiError {
219 status: 0,
220 message: format!("Failed to parse tokens file: {e}"),
221 })?;
222 Ok(Some(tokens))
223 }
224 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
225 Err(e) => Err(XApiError::ApiError {
226 status: 0,
227 message: format!("Failed to read tokens file: {e}"),
228 }),
229 }
230}
231
232fn build_oauth_client(client_id: &str, redirect_uri: &str) -> Result<BasicClient, XApiError> {
234 let auth_url = AuthUrl::new(AUTH_URL.to_string()).map_err(|e| XApiError::ApiError {
235 status: 0,
236 message: format!("Invalid auth URL: {e}"),
237 })?;
238
239 let token_url = TokenUrl::new(TOKEN_URL.to_string()).map_err(|e| XApiError::ApiError {
240 status: 0,
241 message: format!("Invalid token URL: {e}"),
242 })?;
243
244 let redirect = RedirectUrl::new(redirect_uri.to_string()).map_err(|e| XApiError::ApiError {
245 status: 0,
246 message: format!("Invalid redirect URI: {e}"),
247 })?;
248
249 let client = BasicClient::new(
250 ClientId::new(client_id.to_string()),
251 None,
252 auth_url,
253 Some(token_url),
254 )
255 .set_redirect_uri(redirect);
256
257 Ok(client)
258}
259
260pub async fn authenticate_manual(client_id: &str) -> Result<Tokens, XApiError> {
265 let redirect_uri = "http://localhost/callback";
266 let client = build_oauth_client(client_id, redirect_uri)?;
267
268 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
269
270 let mut auth_builder = client
271 .authorize_url(CsrfToken::new_random)
272 .set_pkce_challenge(pkce_challenge);
273 for scope in REQUIRED_SCOPES {
274 auth_builder = auth_builder.add_scope(Scope::new(scope.to_string()));
275 }
276 let (auth_url, csrf_state) = auth_builder.url();
277
278 println!("\n=== X API Authentication (Manual Mode) ===\n");
279 println!("1. Open this URL in your browser:\n");
280 println!(" {auth_url}\n");
281 println!("2. Authorize the application");
282 println!("3. Copy the authorization code from the callback URL");
283 println!(" (Look for ?code=XXXXX in the URL)\n");
284
285 let _ = csrf_state; print!("Paste the authorization code: ");
288 std::io::stdout().flush().map_err(|e| XApiError::ApiError {
289 status: 0,
290 message: format!("IO error: {e}"),
291 })?;
292
293 let mut code = String::new();
294 std::io::stdin()
295 .read_line(&mut code)
296 .map_err(|e| XApiError::ApiError {
297 status: 0,
298 message: format!("Failed to read input: {e}"),
299 })?;
300
301 let code = code.trim().to_string();
302 if code.is_empty() {
303 return Err(XApiError::ApiError {
304 status: 0,
305 message: "Authorization code cannot be empty".to_string(),
306 });
307 }
308
309 exchange_code(&client, &code, pkce_verifier).await
310}
311
312pub async fn authenticate_callback(
317 client_id: &str,
318 host: &str,
319 port: u16,
320) -> Result<Tokens, XApiError> {
321 let redirect_uri = format!("http://{host}:{port}/callback");
322 let client = build_oauth_client(client_id, &redirect_uri)?;
323
324 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
325
326 let mut auth_builder = client
327 .authorize_url(CsrfToken::new_random)
328 .set_pkce_challenge(pkce_challenge);
329 for scope in REQUIRED_SCOPES {
330 auth_builder = auth_builder.add_scope(Scope::new(scope.to_string()));
331 }
332 let (auth_url, csrf_state) = auth_builder.url();
333
334 let addr = format!("{host}:{port}");
336 let listener = tokio::net::TcpListener::bind(&addr)
337 .await
338 .map_err(|e| XApiError::ApiError {
339 status: 0,
340 message: format!(
341 "Failed to bind callback server on {addr}: {e}. Try changing auth.callback_port."
342 ),
343 })?;
344
345 tracing::info!("Callback server listening on {addr}");
346
347 let url_str = auth_url.to_string();
349 if let Err(e) = open::that(&url_str) {
350 tracing::warn!(error = %e, "Failed to open browser automatically");
351 println!("\nCould not open browser automatically.");
352 println!("Please open this URL manually:\n");
353 println!(" {url_str}\n");
354 } else {
355 println!("\nOpened authorization URL in your browser.");
356 println!("Waiting for callback...\n");
357 }
358
359 let callback_result = tokio::time::timeout(
361 std::time::Duration::from_secs(120),
362 accept_callback(&listener, csrf_state.secret()),
363 )
364 .await
365 .map_err(|_| XApiError::ApiError {
366 status: 0,
367 message: "Authentication timed out after 120 seconds".to_string(),
368 })??;
369
370 exchange_code(&client, &callback_result, pkce_verifier).await
371}
372
373async fn accept_callback(
375 listener: &tokio::net::TcpListener,
376 expected_state: &str,
377) -> Result<String, XApiError> {
378 let (mut stream, _addr) = listener.accept().await.map_err(|e| XApiError::ApiError {
379 status: 0,
380 message: format!("Failed to accept connection: {e}"),
381 })?;
382
383 use tokio::io::{AsyncReadExt, AsyncWriteExt};
384
385 let mut buf = vec![0u8; 4096];
386 let n = stream
387 .read(&mut buf)
388 .await
389 .map_err(|e| XApiError::ApiError {
390 status: 0,
391 message: format!("Failed to read request: {e}"),
392 })?;
393
394 let request = String::from_utf8_lossy(&buf[..n]);
395
396 let first_line = request.lines().next().unwrap_or("");
398 let path = first_line.split_whitespace().nth(1).unwrap_or("");
399
400 let query_start = path.find('?').map(|i| i + 1);
401 let query_string = query_start.map(|i| &path[i..]).unwrap_or("");
402
403 let mut code = None;
404 let mut state = None;
405
406 for param in query_string.split('&') {
407 if let Some((key, value)) = param.split_once('=') {
408 match key {
409 "code" => code = Some(value.to_string()),
410 "state" => state = Some(value.to_string()),
411 _ => {}
412 }
413 }
414 }
415
416 let received_state = state.ok_or_else(|| XApiError::ApiError {
418 status: 0,
419 message: "Missing OAuth state parameter in callback".to_string(),
420 })?;
421 if received_state != expected_state {
422 let error_html = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
423 <html><body><h1>Authentication Failed</h1>\
424 <p>State parameter mismatch. This may indicate a CSRF attack.</p>\
425 <p>Please try again.</p></body></html>";
426 let _ = stream.write_all(error_html.as_bytes()).await;
427 return Err(XApiError::ApiError {
428 status: 0,
429 message: "OAuth state parameter mismatch".to_string(),
430 });
431 }
432
433 let auth_code = code.ok_or_else(|| XApiError::ApiError {
434 status: 0,
435 message: "No authorization code in callback URL".to_string(),
436 })?;
437
438 let success_html = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
440 <html><body><h1>Authentication Successful!</h1>\
441 <p>You can close this tab and return to the terminal.</p></body></html>";
442 let _ = stream.write_all(success_html.as_bytes()).await;
443
444 Ok(auth_code)
445}
446
447async fn exchange_code(
449 client: &BasicClient,
450 code: &str,
451 pkce_verifier: oauth2::PkceCodeVerifier,
452) -> Result<Tokens, XApiError> {
453 let http_client = oauth2::reqwest::async_http_client;
454
455 let token_result = client
456 .exchange_code(AuthorizationCode::new(code.to_string()))
457 .set_pkce_verifier(pkce_verifier)
458 .request_async(http_client)
459 .await
460 .map_err(|e| XApiError::ApiError {
461 status: 0,
462 message: format!("Token exchange failed: {e}"),
463 })?;
464
465 let access_token = token_result.access_token().secret().to_string();
466 let refresh_token = token_result
467 .refresh_token()
468 .map(|rt| rt.secret().to_string())
469 .unwrap_or_default();
470
471 let expires_in = token_result
472 .expires_in()
473 .map(|d| d.as_secs() as i64)
474 .unwrap_or(7200);
475
476 let scopes: Vec<String> = token_result
477 .scopes()
478 .map(|s| s.iter().map(|scope| scope.to_string()).collect())
479 .unwrap_or_else(|| REQUIRED_SCOPES.iter().map(|s| s.to_string()).collect());
480
481 let tokens = Tokens {
482 access_token,
483 refresh_token,
484 expires_at: Utc::now() + chrono::Duration::seconds(expires_in),
485 scopes,
486 };
487
488 tracing::info!(
489 expires_at = %tokens.expires_at,
490 scopes = ?tokens.scopes,
491 "Authentication successful"
492 );
493
494 Ok(tokens)
495}
496
#[cfg(test)]
mod tests {
    //! Unit tests for token (de)serialization, on-disk persistence, and
    //! `TokenManager` refresh behavior.

    use super::*;
    use std::path::PathBuf;

    /// `Tokens` survives a JSON round trip with its fields intact.
    #[test]
    fn tokens_serialize_deserialize() {
        let tokens = Tokens {
            access_token: "test_access".to_string(),
            refresh_token: "test_refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec!["tweet.read".to_string(), "tweet.write".to_string()],
        };

        let json = serde_json::to_string(&tokens).expect("serialize");
        let parsed: Tokens = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.access_token, "test_access");
        assert_eq!(parsed.refresh_token, "test_refresh");
        assert_eq!(parsed.scopes.len(), 2);
    }

    /// Tokens written by `save_tokens` are read back by `load_tokens`.
    #[test]
    fn save_and_load_tokens() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let tokens = Tokens {
            access_token: "acc".to_string(),
            refresh_token: "ref".to_string(),
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec!["tweet.read".to_string()],
        };

        save_tokens(&tokens, &path).expect("save");

        let loaded = load_tokens(&path).expect("load").expect("some");
        assert_eq!(loaded.access_token, "acc");
        assert_eq!(loaded.refresh_token, "ref");
    }

    /// A missing token file yields `Ok(None)` rather than an error.
    #[test]
    fn load_tokens_file_not_found_returns_none() {
        let path = PathBuf::from("/nonexistent/tokens.json");
        let result = load_tokens(&path).expect("load");
        assert!(result.is_none());
    }

    /// A token file that is not valid JSON yields an error.
    #[test]
    fn load_tokens_malformed_returns_error() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");
        std::fs::write(&path, "not valid json").expect("write");

        let result = load_tokens(&path);
        assert!(result.is_err());
    }

    /// On Unix, the saved token file is readable/writable by the owner only.
    #[cfg(unix)]
    #[test]
    fn save_tokens_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let tokens = Tokens {
            access_token: "a".to_string(),
            refresh_token: "r".to_string(),
            expires_at: Utc::now(),
            scopes: vec![],
        };

        save_tokens(&tokens, &path).expect("save");

        let metadata = std::fs::metadata(&path).expect("metadata");
        // Mask off file-type bits; only the permission bits are compared.
        let mode = metadata.permissions().mode() & 0o777;
        assert_eq!(mode, 0o600, "token file should have 600 permissions");
    }

    /// `save_tokens` creates missing parent directories.
    #[test]
    fn save_tokens_creates_parent_dirs() {
        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("nested").join("dir").join("tokens.json");

        let tokens = Tokens {
            access_token: "a".to_string(),
            refresh_token: "r".to_string(),
            expires_at: Utc::now(),
            scopes: vec![],
        };

        save_tokens(&tokens, &path).expect("save");
        assert!(path.exists());
    }

    /// A token expiring inside the refresh window triggers a refresh attempt.
    #[tokio::test]
    async fn token_manager_refresh_detects_expiry() {
        let tokens = Tokens {
            access_token: "old_token".to_string(),
            refresh_token: "refresh".to_string(),
            // 60 s remaining is below REFRESH_WINDOW_SECS (300), so a refresh is attempted.
            expires_at: Utc::now() + chrono::Duration::seconds(60),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);

        let result = manager.refresh_if_needed().await;
        // The refresh goes to the real TOKEN_URL with a dummy client id and refresh
        // token, so it is expected to fail (network error or non-success status).
        // NOTE(review): this depends on the live endpoint / network; consider a local mock.
        assert!(result.is_err());
    }

    /// A token well outside the refresh window is returned unchanged, without a refresh.
    #[tokio::test]
    async fn token_manager_no_refresh_when_fresh() {
        let tokens = Tokens {
            access_token: "fresh_token".to_string(),
            refresh_token: "refresh".to_string(),
            // Two hours left: far outside the refresh window, so no network call is made.
            expires_at: Utc::now() + chrono::Duration::hours(2),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);

        let result = manager.refresh_if_needed().await;
        assert!(result.is_ok());

        let token = manager.get_access_token().await.expect("get token");
        assert_eq!(token, "fresh_token");
    }

    /// Intended to exercise a successful refresh against a mock token endpoint.
    #[tokio::test]
    async fn token_manager_refresh_with_mock() {
        use wiremock::matchers::{body_string_contains, method};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_access",
                "refresh_token": "new_refresh",
                "expires_in": 7200,
                "scope": "tweet.read tweet.write"
            })))
            .mount(&server)
            .await;

        let tokens = Tokens {
            access_token: "old_token".to_string(),
            refresh_token: "old_refresh".to_string(),
            expires_at: Utc::now() + chrono::Duration::seconds(60),
            scopes: vec![],
        };

        let dir = tempfile::tempdir().expect("temp dir");
        let path = dir.path().join("tokens.json");

        let manager = TokenManager::new(tokens, "client_id".to_string(), path);
        // NOTE(review): `server` is never wired into `manager` and no refresh is
        // triggered, so the mock is unused and this only checks the initial token.
        // TODO: point the manager's token endpoint at `server.uri()`, call
        // `refresh_if_needed`, and assert the access token becomes "new_access".
        let token = manager.tokens.read().await;
        assert_eq!(token.access_token, "old_token");
    }
}