1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::fmt;
12use std::path::PathBuf;
13
/// OAuth 2.0 authorization endpoint; `build_auth_url` appends the query string to it.
pub const X_AUTH_URL: &str = "https://twitter.com/i/oauth2/authorize";

/// OAuth 2.0 token endpoint; `exchange_auth_code` POSTs the authorization-code grant here.
pub const X_TOKEN_URL: &str = "https://api.twitter.com/2/oauth2/token";

/// Authenticated-user endpoint; `verify_credentials` GETs it to confirm a token works.
pub const X_USERS_ME_URL: &str = "https://api.twitter.com/2/users/me";

/// Space-separated OAuth scopes sent (URL-encoded) as `scope` in the authorization URL.
pub const OAUTH_SCOPES: &str = "tweet.read tweet.write users.read offline.access";
29
/// X API access tier of the account; [`TierCapabilities::for_tier`] maps it to features.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiTier {
    /// Free tier (posting only).
    Free,
    /// Basic tier.
    Basic,
    /// Pro tier.
    Pro,
}

impl fmt::Display for ApiTier {
    /// Writes the tier's capitalized name (`Free`, `Basic`, `Pro`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Free => "Free",
            Self::Basic => "Basic",
            Self::Pro => "Pro",
        };
        f.write_str(label)
    }
}
54
55#[derive(Debug, Clone)]
57pub struct TierCapabilities {
58 pub mentions: bool,
60 pub discovery: bool,
62 pub posting: bool,
64 pub search: bool,
66}
67
68impl TierCapabilities {
69 pub fn for_tier(tier: ApiTier) -> Self {
71 match tier {
72 ApiTier::Free => Self {
73 mentions: false,
74 discovery: false,
75 posting: true,
76 search: false,
77 },
78 ApiTier::Basic | ApiTier::Pro => Self {
79 mentions: true,
80 discovery: true,
81 posting: true,
82 search: true,
83 },
84 }
85 }
86
87 pub fn enabled_loop_names(&self) -> Vec<&'static str> {
89 let mut loops = Vec::new();
90 if self.mentions {
91 loops.push("mentions");
92 }
93 if self.discovery {
94 loops.push("discovery");
95 }
96 loops.push("content");
98 loops.push("threads");
99 loops
100 }
101
102 pub fn format_status(&self) -> String {
104 let status = |enabled: bool| if enabled { "enabled" } else { "DISABLED" };
105 format!(
106 "Mentions: {}, Discovery: {}, Content: enabled, Threads: enabled",
107 status(self.mentions),
108 status(self.discovery),
109 )
110 }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct StoredTokens {
120 pub access_token: String,
122
123 #[serde(default)]
125 pub refresh_token: Option<String>,
126
127 #[serde(default)]
129 pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
130}
131
132impl StoredTokens {
133 pub fn is_expired(&self) -> bool {
135 match self.expires_at {
136 Some(expires) => chrono::Utc::now() >= expires,
137 None => false,
138 }
139 }
140
141 pub fn time_until_expiry(&self) -> Option<chrono::TimeDelta> {
143 self.expires_at.map(|expires| expires - chrono::Utc::now())
144 }
145
146 pub fn format_expiry(&self) -> String {
148 match self.time_until_expiry() {
149 Some(duration) if duration.num_seconds() > 0 => {
150 let hours = duration.num_hours();
151 let minutes = duration.num_minutes() % 60;
152 if hours > 0 {
153 format!("{hours}h {minutes}m")
154 } else {
155 format!("{minutes}m")
156 }
157 }
158 Some(_) => "expired".to_string(),
159 None => "no expiry set".to_string(),
160 }
161 }
162}
163
/// Errors surfaced during bot startup and the `tuitbot auth` flow.
#[derive(Debug, thiserror::Error)]
pub enum StartupError {
    /// Invalid or missing configuration; the payload describes the problem.
    #[error("configuration error: {0}")]
    Config(String),

    /// No stored tokens; `load_tokens_from_file` returns this when the token file does not exist.
    #[error("authentication required: run `tuitbot auth` first")]
    AuthRequired,

    /// Stored tokens have expired and the user must re-authenticate.
    #[error("authentication expired: run `tuitbot auth` to re-authenticate")]
    AuthExpired,

    /// Refreshing the access token failed; the payload holds the reason.
    #[error("token refresh failed: {0}")]
    TokenRefreshFailed(String),

    /// A database operation failed.
    #[error("database error: {0}")]
    Database(String),

    /// A call to the LLM provider failed.
    #[error("LLM provider error: {0}")]
    LlmError(String),

    /// An X API request failed, returned a non-success status, or could not be parsed.
    #[error("X API error: {0}")]
    XApiError(String),

    /// Underlying I/O failure; `#[from]` lets `?` convert `std::io::Error` automatically.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Catch-all, e.g. token-file (de)serialization failures; displayed verbatim.
    #[error("{0}")]
    Other(String),
}
207
208pub fn data_dir() -> PathBuf {
214 dirs::home_dir()
215 .unwrap_or_else(|| PathBuf::from("."))
216 .join(".tuitbot")
217}
218
219pub fn token_file_path() -> PathBuf {
221 data_dir().join("tokens.json")
222}
223
224pub fn load_tokens_from_file() -> Result<StoredTokens, StartupError> {
226 let path = token_file_path();
227 let contents = std::fs::read_to_string(&path).map_err(|e| {
228 if e.kind() == std::io::ErrorKind::NotFound {
229 StartupError::AuthRequired
230 } else {
231 StartupError::Io(e)
232 }
233 })?;
234 serde_json::from_str(&contents)
235 .map_err(|e| StartupError::Other(format!("failed to parse tokens file: {e}")))
236}
237
238pub fn save_tokens_to_file(tokens: &StoredTokens) -> Result<(), StartupError> {
243 let dir = data_dir();
244 std::fs::create_dir_all(&dir)?;
245
246 let path = token_file_path();
247 let json = serde_json::to_string_pretty(tokens)
248 .map_err(|e| StartupError::Other(format!("failed to serialize tokens: {e}")))?;
249 std::fs::write(&path, json)?;
250
251 #[cfg(unix)]
253 {
254 use std::os::unix::fs::PermissionsExt;
255 let perms = std::fs::Permissions::from_mode(0o600);
256 std::fs::set_permissions(&path, perms)?;
257 }
258
259 Ok(())
260}
261
/// PKCE material for one OAuth 2.0 authorization attempt, produced by `generate_pkce`.
// NOTE(review): the derived `Debug` prints `verifier` in clear text; avoid logging this value.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// Random code verifier (unpadded base64url of 32 random bytes); presumably the
    /// value later passed as `code_verifier` to `exchange_auth_code` — confirm at call site.
    pub verifier: String,
    /// Unpadded base64url SHA-256 digest of `verifier` (the S256 code challenge).
    pub challenge: String,
    /// Random `state` value (unpadded base64url of 16 random bytes).
    pub state: String,
}
276
277pub fn generate_pkce() -> PkceChallenge {
279 use rand::Rng;
280 let random_bytes: [u8; 32] = rand::thread_rng().gen();
281 let verifier = URL_SAFE_NO_PAD.encode(random_bytes);
282 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
283 let state_bytes: [u8; 16] = rand::thread_rng().gen();
284 let state = URL_SAFE_NO_PAD.encode(state_bytes);
285 PkceChallenge {
286 verifier,
287 challenge,
288 state,
289 }
290}
291
/// Percent-encodes `s` byte-by-byte, leaving only RFC 3986 unreserved
/// characters (`A-Z a-z 0-9 - _ . ~`) as-is and writing every other byte
/// as uppercase `%XX`.
fn url_encode(s: &str) -> String {
    use std::fmt::Write;

    let mut out = String::with_capacity(s.len() * 3);
    for b in s.bytes() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(b));
        } else {
            // Formatting into a `String` cannot fail.
            let _ = write!(out, "%{b:02X}");
        }
    }
    out
}
308
309pub fn build_auth_url(
311 client_id: &str,
312 redirect_uri: &str,
313 state: &str,
314 code_challenge: &str,
315) -> String {
316 format!(
317 "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
318 X_AUTH_URL,
319 url_encode(client_id),
320 url_encode(redirect_uri),
321 url_encode(OAUTH_SCOPES),
322 url_encode(state),
323 url_encode(code_challenge),
324 )
325}
326
/// Builds the local OAuth callback URI: `http://{callback_host}:{callback_port}/callback`.
pub fn build_redirect_uri(callback_host: &str, callback_port: u16) -> String {
    format!("http://{}:{}/callback", callback_host, callback_port)
}
331
332pub async fn exchange_auth_code(
334 client_id: &str,
335 code: &str,
336 redirect_uri: &str,
337 code_verifier: &str,
338) -> Result<StoredTokens, StartupError> {
339 let client = reqwest::Client::new();
340 let resp = client
341 .post(X_TOKEN_URL)
342 .form(&[
343 ("grant_type", "authorization_code"),
344 ("code", code),
345 ("redirect_uri", redirect_uri),
346 ("code_verifier", code_verifier),
347 ("client_id", client_id),
348 ])
349 .send()
350 .await
351 .map_err(|e| StartupError::XApiError(format!("token exchange request failed: {e}")))?;
352
353 if !resp.status().is_success() {
354 let status = resp.status();
355 let body = resp.text().await.unwrap_or_default();
356 return Err(StartupError::XApiError(format!(
357 "token exchange failed (HTTP {status}): {body}"
358 )));
359 }
360
361 #[derive(Deserialize)]
362 struct TokenResponse {
363 access_token: String,
364 refresh_token: Option<String>,
365 expires_in: Option<i64>,
366 }
367
368 let token_resp: TokenResponse = resp
369 .json()
370 .await
371 .map_err(|e| StartupError::XApiError(format!("failed to parse token response: {e}")))?;
372
373 let expires_at = token_resp
374 .expires_in
375 .map(|secs| chrono::Utc::now() + chrono::TimeDelta::seconds(secs));
376
377 Ok(StoredTokens {
378 access_token: token_resp.access_token,
379 refresh_token: token_resp.refresh_token,
380 expires_at,
381 })
382}
383
384pub async fn verify_credentials(access_token: &str) -> Result<String, StartupError> {
388 let client = reqwest::Client::new();
389 let resp = client
390 .get(X_USERS_ME_URL)
391 .bearer_auth(access_token)
392 .send()
393 .await
394 .map_err(|e| {
395 StartupError::XApiError(format!("credential verification request failed: {e}"))
396 })?;
397
398 if !resp.status().is_success() {
399 let status = resp.status();
400 let body = resp.text().await.unwrap_or_default();
401 return Err(StartupError::XApiError(format!(
402 "credential verification failed (HTTP {status}): {body}"
403 )));
404 }
405
406 #[derive(Deserialize)]
407 struct UserResponse {
408 data: UserData,
409 }
410
411 #[derive(Deserialize)]
412 struct UserData {
413 username: String,
414 }
415
416 let user: UserResponse = resp
417 .json()
418 .await
419 .map_err(|e| StartupError::XApiError(format!("failed to parse user response: {e}")))?;
420
421 Ok(user.data.username)
422}
423
/// Extracts the OAuth authorization code from what the user pasted.
///
/// Accepts either the full callback URL the browser was redirected to (for
/// example `http://127.0.0.1:8080/callback?code=abc&state=xyz`) or a bare code.
/// Surrounding whitespace is trimmed.
///
/// For a URL, the `code` query parameter is found regardless of its position.
/// Any `#fragment` is ignored, and `%XX` escapes in the value are decoded. If
/// the value contains a malformed escape, or decodes to invalid UTF-8, it is
/// returned undecoded. When no `code` parameter exists, the trimmed input is
/// returned unchanged.
pub fn extract_auth_code(input: &str) -> String {
    /// Decodes `%XX` escapes; `None` on a malformed escape or non-UTF-8 result.
    fn percent_decode(raw: &str) -> Option<String> {
        let bytes = raw.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' {
                // `to_digit(16)` accepts only 0-9/a-f/A-F, unlike `from_str_radix`,
                // which would also accept a leading '+'.
                let hi = char::from(*bytes.get(i + 1)?).to_digit(16)?;
                let lo = char::from(*bytes.get(i + 2)?).to_digit(16)?;
                out.push(u8::try_from(hi * 16 + lo).ok()?);
                i += 3;
            } else {
                out.push(bytes[i]);
                i += 1;
            }
        }
        String::from_utf8(out).ok()
    }

    let trimmed = input.trim();
    // Everything after the first '?', with any fragment removed.
    let query = trimmed
        .split_once('?')
        .map(|(_, rest)| rest.split_once('#').map_or(rest, |(q, _)| q));
    if let Some(query) = query {
        if let Some(value) = query.split('&').find_map(|pair| pair.strip_prefix("code=")) {
            return percent_decode(value).unwrap_or_else(|| value.to_string());
        }
    }
    trimmed.to_string()
}
442
443pub fn format_startup_banner(
449 tier: ApiTier,
450 capabilities: &TierCapabilities,
451 status_interval: u64,
452) -> String {
453 let loops = capabilities.enabled_loop_names().join(", ");
454 let status = if status_interval > 0 {
455 format!("every {status_interval}s")
456 } else {
457 "disabled".to_string()
458 };
459 format!(
460 "Tuitbot v{version}\n\
461 Tier: {tier} | Loops: {loops}\n\
462 Status summary: {status}\n\
463 Press Ctrl+C to stop.",
464 version = env!("CARGO_PKG_VERSION"),
465 )
466}
467
468pub fn expand_tilde(path: &str) -> PathBuf {
474 if let Some(rest) = path.strip_prefix("~/") {
475 if let Some(home) = dirs::home_dir() {
476 return home.join(rest);
477 }
478 } else if path == "~" {
479 if let Some(home) = dirs::home_dir() {
480 return home;
481 }
482 }
483 PathBuf::from(path)
484}
485
#[cfg(test)]
mod tests {
    // Unit tests for this module: tier capabilities, token expiry and
    // persistence, PKCE generation, URL building/parsing, and startup output.
    use super::*;

    // Display output of each tier.
    #[test]
    fn api_tier_display() {
        assert_eq!(ApiTier::Free.to_string(), "Free");
        assert_eq!(ApiTier::Basic.to_string(), "Basic");
        assert_eq!(ApiTier::Pro.to_string(), "Pro");
    }

    // Free tier: posting only.
    #[test]
    fn free_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        assert!(!caps.mentions);
        assert!(!caps.discovery);
        assert!(caps.posting);
        assert!(!caps.search);
    }

    #[test]
    fn basic_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        assert!(caps.mentions);
        assert!(caps.discovery);
        assert!(caps.posting);
        assert!(caps.search);
    }

    #[test]
    fn pro_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Pro);
        assert!(caps.mentions);
        assert!(caps.discovery);
        assert!(caps.posting);
        assert!(caps.search);
    }

    // Loop names are ordered: optional loops first, then content/threads.
    #[test]
    fn free_tier_enabled_loops() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let loops = caps.enabled_loop_names();
        assert_eq!(loops, vec!["content", "threads"]);
    }

    #[test]
    fn basic_tier_enabled_loops() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let loops = caps.enabled_loop_names();
        assert_eq!(loops, vec!["mentions", "discovery", "content", "threads"]);
    }

    #[test]
    fn tier_capabilities_format_status() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let status = caps.format_status();
        assert!(status.contains("Mentions: DISABLED"));
        assert!(status.contains("Discovery: DISABLED"));

        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let status = caps.format_status();
        assert!(status.contains("Mentions: enabled"));
        assert!(status.contains("Discovery: enabled"));
    }

    // StoredTokens expiry checks; times are relative to Utc::now(), with
    // margins of an hour so test timing cannot flip the result.
    #[test]
    fn stored_tokens_not_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::hours(1)),
        };
        assert!(!tokens.is_expired());
    }

    #[test]
    fn stored_tokens_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() - chrono::TimeDelta::hours(1)),
        };
        assert!(tokens.is_expired());
    }

    #[test]
    fn stored_tokens_no_expiry_is_not_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
        };
        assert!(!tokens.is_expired());
    }

    // 102 minutes ahead formats with both an hour and a minute component.
    #[test]
    fn stored_tokens_format_expiry_hours() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::minutes(102)),
        };
        let formatted = tokens.format_expiry();
        assert!(formatted.contains("h"));
        assert!(formatted.contains("m"));
    }

    // Under an hour, no hour component is shown.
    #[test]
    fn stored_tokens_format_expiry_minutes_only() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::minutes(30)),
        };
        let formatted = tokens.format_expiry();
        assert!(formatted.contains("m"));
        assert!(!formatted.contains("h"));
    }

    #[test]
    fn stored_tokens_format_expiry_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() - chrono::TimeDelta::hours(1)),
        };
        assert_eq!(tokens.format_expiry(), "expired");
    }

    #[test]
    fn stored_tokens_format_expiry_no_expiry() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
        };
        assert_eq!(tokens.format_expiry(), "no expiry set");
    }

    // JSON round-trip preserves every field.
    #[test]
    fn stored_tokens_serialization_roundtrip() {
        let tokens = StoredTokens {
            access_token: "access123".to_string(),
            refresh_token: Some("refresh456".to_string()),
            expires_at: Some(
                chrono::DateTime::parse_from_rfc3339("2026-06-01T12:00:00Z")
                    .expect("valid datetime")
                    .with_timezone(&chrono::Utc),
            ),
        };
        let json = serde_json::to_string(&tokens).expect("serialize");
        let deserialized: StoredTokens = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.access_token, "access123");
        assert_eq!(deserialized.refresh_token.as_deref(), Some("refresh456"));
        assert!(deserialized.expires_at.is_some());
    }

    // NOTE(review): this writes/reads JSON in a temp dir with std::fs directly;
    // it does not call save_tokens_to_file / load_tokens_from_file (which
    // always use token_file_path()), so those functions are not exercised here.
    #[test]
    fn save_and_load_tokens() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("tokens.json");

        let tokens = StoredTokens {
            access_token: "test_access".to_string(),
            refresh_token: Some("test_refresh".to_string()),
            expires_at: None,
        };

        let json = serde_json::to_string_pretty(&tokens).expect("serialize");
        std::fs::write(&path, &json).expect("write");

        let contents = std::fs::read_to_string(&path).expect("read");
        let loaded: StoredTokens = serde_json::from_str(&contents).expect("deserialize");
        assert_eq!(loaded.access_token, "test_access");
        assert_eq!(loaded.refresh_token.as_deref(), Some("test_refresh"));
    }

    // NOTE(review): likewise replicates the chmod-to-0o600 step instead of
    // calling save_tokens_to_file, so it only checks the std permission API.
    #[cfg(unix)]
    #[test]
    fn save_tokens_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("tokens.json");
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
        };
        let json = serde_json::to_string_pretty(&tokens).expect("serialize");
        std::fs::write(&path, &json).expect("write");
        let perms = std::fs::Permissions::from_mode(0o600);
        std::fs::set_permissions(&path, perms).expect("set perms");

        let meta = std::fs::metadata(&path).expect("metadata");
        assert_eq!(meta.permissions().mode() & 0o777, 0o600);
    }

    // Display strings of selected error variants.
    #[test]
    fn startup_error_display() {
        let err = StartupError::AuthRequired;
        assert_eq!(
            err.to_string(),
            "authentication required: run `tuitbot auth` first"
        );

        let err = StartupError::AuthExpired;
        assert!(err.to_string().contains("expired"));

        let err = StartupError::Config("bad field".to_string());
        assert_eq!(err.to_string(), "configuration error: bad field");

        let err = StartupError::XApiError("timeout".to_string());
        assert_eq!(err.to_string(), "X API error: timeout");
    }

    // Lengths follow from unpadded base64: 32 bytes -> 43 chars (verifier and
    // the 32-byte SHA-256 challenge), 16 bytes -> 22 chars (state).
    #[test]
    fn generate_pkce_produces_valid_challenge() {
        let pkce = generate_pkce();
        assert_eq!(pkce.verifier.len(), 43);
        assert_eq!(pkce.challenge.len(), 43);
        assert_eq!(pkce.state.len(), 22);
        // The challenge must be exactly the S256 transform of the verifier.
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    #[test]
    fn generate_pkce_unique_each_time() {
        let a = generate_pkce();
        let b = generate_pkce();
        assert_ne!(a.verifier, b.verifier);
        assert_ne!(a.challenge, b.challenge);
        assert_ne!(a.state, b.state);
    }

    // All parameters present; redirect_uri is percent-encoded.
    #[test]
    fn build_auth_url_contains_required_params() {
        let url = build_auth_url(
            "client123",
            "http://localhost:8080/callback",
            "state456",
            "challenge789",
        );
        assert!(url.starts_with(X_AUTH_URL));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=client123"));
        assert!(url.contains("code_challenge=challenge789"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state=state456"));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback"));
    }

    #[test]
    fn build_redirect_uri_format() {
        let uri = build_redirect_uri("127.0.0.1", 8080);
        assert_eq!(uri, "http://127.0.0.1:8080/callback");
    }

    // extract_auth_code accepts both a full callback URL and a bare code.
    #[test]
    fn extract_code_from_full_url() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=abc123&state=xyz");
        assert_eq!(code, "abc123");
    }

    #[test]
    fn extract_code_from_bare_code() {
        let code = extract_auth_code("  abc123  ");
        assert_eq!(code, "abc123");
    }

    #[test]
    fn extract_code_from_url_without_state() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=mycode");
        assert_eq!(code, "mycode");
    }

    // Unreserved characters pass through; everything else becomes %XX.
    #[test]
    fn url_encode_basic() {
        assert_eq!(url_encode("hello"), "hello");
        assert_eq!(url_encode("hello world"), "hello%20world");
        assert_eq!(
            url_encode("http://localhost:8080/callback"),
            "http%3A%2F%2Flocalhost%3A8080%2Fcallback"
        );
    }

    // Banner lists only the loops the tier enables and the status cadence.
    #[test]
    fn startup_banner_free_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 300);
        assert!(banner.contains("Tuitbot v"));
        assert!(banner.contains("Tier: Free"));
        assert!(!banner.contains("mentions"));
        assert!(banner.contains("content"));
        assert!(banner.contains("threads"));
        assert!(!banner.contains("discovery"));
        assert!(banner.contains("every 300s"));
    }

    // A status interval of 0 is reported as "disabled".
    #[test]
    fn startup_banner_basic_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let banner = format_startup_banner(ApiTier::Basic, &caps, 0);
        assert!(banner.contains("Tier: Basic"));
        assert!(banner.contains("discovery"));
        assert!(banner.contains("disabled"));
    }

    #[test]
    fn startup_banner_contains_ctrl_c_hint() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 0);
        assert!(banner.contains("Ctrl+C"));
    }

    // NOTE(review): only holds when a home directory is resolvable in the test environment.
    #[test]
    fn expand_tilde_works() {
        let expanded = expand_tilde("~/.tuitbot/config.toml");
        assert!(!expanded.to_string_lossy().starts_with('~'));
    }

    #[test]
    fn expand_tilde_no_tilde() {
        let expanded = expand_tilde("/absolute/path");
        assert_eq!(expanded, PathBuf::from("/absolute/path"));
    }

    #[test]
    fn data_dir_under_home() {
        let dir = data_dir();
        assert!(dir.to_string_lossy().contains(".tuitbot"));
    }

    #[test]
    fn token_file_path_under_data_dir() {
        let path = token_file_path();
        assert!(path.to_string_lossy().contains("tokens.json"));
        assert!(path.to_string_lossy().contains(".tuitbot"));
    }
}