Skip to main content

tuitbot_core/
startup.rs

1//! Startup types and helpers for Tuitbot CLI commands.
2//!
3//! Provides API tier detection types, OAuth token management,
4//! PKCE authentication helpers, startup banner formatting, and
5//! diagnostic check types used by the `run`, `auth`, and `test`
6//! CLI commands.
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::fmt;
12use std::path::PathBuf;
13
14// ============================================================================
15// X API OAuth 2.0 endpoints
16// ============================================================================
17
/// Authorization endpoint of the X API OAuth 2.0 flow.
pub const X_AUTH_URL: &str = "https://twitter.com/i/oauth2/authorize";

/// Token endpoint of the X API OAuth 2.0 flow.
pub const X_TOKEN_URL: &str = "https://api.twitter.com/2/oauth2/token";

/// X API `users/me` endpoint, used for credential verification.
pub const X_USERS_ME_URL: &str = "https://api.twitter.com/2/users/me";

/// Space-separated OAuth scopes required by Tuitbot.
pub const OAUTH_SCOPES: &str = "tweet.read tweet.write users.read offline.access";
29
30// ============================================================================
31// API Tier
32// ============================================================================
33
/// X API access tier as detected by Tuitbot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiTier {
    /// Free tier: posting only, without search or mentions.
    Free,
    /// Basic tier: adds search/discovery.
    Basic,
    /// Pro tier: all features.
    Pro,
}

/// Renders the tier as its capitalised name (`Free`, `Basic` or `Pro`).
impl fmt::Display for ApiTier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Free => "Free",
            Self::Basic => "Basic",
            Self::Pro => "Pro",
        };
        f.write_str(name)
    }
}
54
55/// Capabilities enabled by the current API tier.
56#[derive(Debug, Clone)]
57pub struct TierCapabilities {
58    /// Whether the mentions loop can run.
59    pub mentions: bool,
60    /// Whether the discovery/search loop can run.
61    pub discovery: bool,
62    /// Whether posting (tweets + threads) is available.
63    pub posting: bool,
64    /// Whether tweet search is available.
65    pub search: bool,
66}
67
68impl TierCapabilities {
69    /// Determine capabilities for a given tier.
70    pub fn for_tier(tier: ApiTier) -> Self {
71        match tier {
72            ApiTier::Free => Self {
73                mentions: false,
74                discovery: false,
75                posting: true,
76                search: false,
77            },
78            ApiTier::Basic | ApiTier::Pro => Self {
79                mentions: true,
80                discovery: true,
81                posting: true,
82                search: true,
83            },
84        }
85    }
86
87    /// List the names of enabled automation loops.
88    pub fn enabled_loop_names(&self) -> Vec<&'static str> {
89        let mut loops = Vec::new();
90        if self.mentions {
91            loops.push("mentions");
92        }
93        if self.discovery {
94            loops.push("discovery");
95        }
96        // Content and threads are always enabled (no special tier required).
97        loops.push("content");
98        loops.push("threads");
99        loops
100    }
101
102    /// Format the tier capabilities as a status line.
103    pub fn format_status(&self) -> String {
104        let status = |enabled: bool| if enabled { "enabled" } else { "DISABLED" };
105        format!(
106            "Mentions: {}, Discovery: {}, Content: enabled, Threads: enabled",
107            status(self.mentions),
108            status(self.discovery),
109        )
110    }
111}
112
113// ============================================================================
114// Stored Tokens
115// ============================================================================
116
117/// OAuth tokens persisted to disk at `~/.tuitbot/tokens.json`.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct StoredTokens {
120    /// OAuth 2.0 access token.
121    pub access_token: String,
122
123    /// OAuth 2.0 refresh token (for offline.access scope).
124    #[serde(default)]
125    pub refresh_token: Option<String>,
126
127    /// Token expiration timestamp.
128    #[serde(default)]
129    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
130}
131
132impl StoredTokens {
133    /// Check if the token has expired.
134    pub fn is_expired(&self) -> bool {
135        match self.expires_at {
136            Some(expires) => chrono::Utc::now() >= expires,
137            None => false,
138        }
139    }
140
141    /// Time remaining until token expires.
142    pub fn time_until_expiry(&self) -> Option<chrono::TimeDelta> {
143        self.expires_at.map(|expires| expires - chrono::Utc::now())
144    }
145
146    /// Format time until expiry as a human-readable string.
147    pub fn format_expiry(&self) -> String {
148        match self.time_until_expiry() {
149            Some(duration) if duration.num_seconds() > 0 => {
150                let hours = duration.num_hours();
151                let minutes = duration.num_minutes() % 60;
152                if hours > 0 {
153                    format!("{hours}h {minutes}m")
154                } else {
155                    format!("{minutes}m")
156                }
157            }
158            Some(_) => "expired".to_string(),
159            None => "no expiry set".to_string(),
160        }
161    }
162}
163
164// ============================================================================
165// Startup Error
166// ============================================================================
167
168/// Errors that can occur during startup operations.
169#[derive(Debug, thiserror::Error)]
170pub enum StartupError {
171    /// Configuration is invalid or missing.
172    #[error("configuration error: {0}")]
173    Config(String),
174
175    /// No tokens found -- user needs to authenticate first.
176    #[error("authentication required: run `tuitbot auth` first")]
177    AuthRequired,
178
179    /// Tokens are expired and need re-authentication.
180    #[error("authentication expired: run `tuitbot auth` to re-authenticate")]
181    AuthExpired,
182
183    /// Token refresh attempt failed.
184    #[error("token refresh failed: {0}")]
185    TokenRefreshFailed(String),
186
187    /// Database initialization or access error.
188    #[error("database error: {0}")]
189    Database(String),
190
191    /// LLM provider configuration or connectivity error.
192    #[error("LLM provider error: {0}")]
193    LlmError(String),
194
195    /// X API communication error.
196    #[error("X API error: {0}")]
197    XApiError(String),
198
199    /// File I/O error.
200    #[error("I/O error: {0}")]
201    Io(#[from] std::io::Error),
202
203    /// Any other error.
204    #[error("{0}")]
205    Other(String),
206}
207
208// ============================================================================
209// Token File I/O
210// ============================================================================
211
212/// Default directory for Tuitbot data files (`~/.tuitbot/`).
213pub fn data_dir() -> PathBuf {
214    dirs::home_dir()
215        .unwrap_or_else(|| PathBuf::from("."))
216        .join(".tuitbot")
217}
218
219/// Path to the token storage file (`~/.tuitbot/tokens.json`).
220pub fn token_file_path() -> PathBuf {
221    data_dir().join("tokens.json")
222}
223
224/// Load OAuth tokens from the default file path.
225pub fn load_tokens_from_file() -> Result<StoredTokens, StartupError> {
226    let path = token_file_path();
227    let contents = std::fs::read_to_string(&path).map_err(|e| {
228        if e.kind() == std::io::ErrorKind::NotFound {
229            StartupError::AuthRequired
230        } else {
231            StartupError::Io(e)
232        }
233    })?;
234    serde_json::from_str(&contents)
235        .map_err(|e| StartupError::Other(format!("failed to parse tokens file: {e}")))
236}
237
238/// Save OAuth tokens to the default file path with secure permissions.
239///
240/// Creates the `~/.tuitbot/` directory if it does not exist.
241/// On Unix, sets file permissions to 0600 (owner read/write only).
242pub fn save_tokens_to_file(tokens: &StoredTokens) -> Result<(), StartupError> {
243    let dir = data_dir();
244    std::fs::create_dir_all(&dir)?;
245
246    let path = token_file_path();
247    let json = serde_json::to_string_pretty(tokens)
248        .map_err(|e| StartupError::Other(format!("failed to serialize tokens: {e}")))?;
249    std::fs::write(&path, json)?;
250
251    // Set file permissions to 0600 on Unix (owner read/write only).
252    #[cfg(unix)]
253    {
254        use std::os::unix::fs::PermissionsExt;
255        let perms = std::fs::Permissions::from_mode(0o600);
256        std::fs::set_permissions(&path, perms)?;
257    }
258
259    Ok(())
260}
261
262// ============================================================================
263// PKCE Authentication
264// ============================================================================
265
/// PKCE verifier/challenge pair together with the CSRF state value.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// Code verifier, sent during the token exchange.
    pub verifier: String,
    /// Code challenge, sent with the authorization request.
    pub challenge: String,
    /// CSRF `state` parameter.
    pub state: String,
}
276
277/// Generate a PKCE code verifier, challenge, and state parameter.
278pub fn generate_pkce() -> PkceChallenge {
279    use rand::Rng;
280    let random_bytes: [u8; 32] = rand::thread_rng().gen();
281    let verifier = URL_SAFE_NO_PAD.encode(random_bytes);
282    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
283    let state_bytes: [u8; 16] = rand::thread_rng().gen();
284    let state = URL_SAFE_NO_PAD.encode(state_bytes);
285    PkceChallenge {
286        verifier,
287        challenge,
288        state,
289    }
290}
291
/// Percent-encode `s` for use in a URL query parameter (RFC 3986).
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied verbatim; every
/// other byte of the UTF-8 representation becomes `%XX` with upper-case hex
/// digits.
fn url_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Worst case: every input byte expands to three output characters.
    let mut out = String::with_capacity(s.len() * 3);
    for &byte in s.as_bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    out
}
308
309/// Build the X API OAuth 2.0 authorization URL.
310pub fn build_auth_url(
311    client_id: &str,
312    redirect_uri: &str,
313    state: &str,
314    code_challenge: &str,
315) -> String {
316    format!(
317        "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
318        X_AUTH_URL,
319        url_encode(client_id),
320        url_encode(redirect_uri),
321        url_encode(OAUTH_SCOPES),
322        url_encode(state),
323        url_encode(code_challenge),
324    )
325}
326
/// Build the redirect URI from config auth settings.
///
/// Produces `http://{callback_host}:{callback_port}/callback`. A host that
/// contains `:` can only be an IPv6 literal, which a URI must wrap in square
/// brackets (RFC 3986 §3.2.2); such hosts are bracketed automatically unless
/// the caller already supplied the brackets.
pub fn build_redirect_uri(callback_host: &str, callback_port: u16) -> String {
    let is_bare_ipv6 = callback_host.contains(':') && !callback_host.starts_with('[');
    if is_bare_ipv6 {
        format!("http://[{callback_host}]:{callback_port}/callback")
    } else {
        format!("http://{callback_host}:{callback_port}/callback")
    }
}
331
332/// Exchange an authorization code for OAuth tokens.
333pub async fn exchange_auth_code(
334    client_id: &str,
335    code: &str,
336    redirect_uri: &str,
337    code_verifier: &str,
338) -> Result<StoredTokens, StartupError> {
339    let client = reqwest::Client::new();
340    let resp = client
341        .post(X_TOKEN_URL)
342        .form(&[
343            ("grant_type", "authorization_code"),
344            ("code", code),
345            ("redirect_uri", redirect_uri),
346            ("code_verifier", code_verifier),
347            ("client_id", client_id),
348        ])
349        .send()
350        .await
351        .map_err(|e| StartupError::XApiError(format!("token exchange request failed: {e}")))?;
352
353    if !resp.status().is_success() {
354        let status = resp.status();
355        let body = resp.text().await.unwrap_or_default();
356        return Err(StartupError::XApiError(format!(
357            "token exchange failed (HTTP {status}): {body}"
358        )));
359    }
360
361    #[derive(Deserialize)]
362    struct TokenResponse {
363        access_token: String,
364        refresh_token: Option<String>,
365        expires_in: Option<i64>,
366    }
367
368    let token_resp: TokenResponse = resp
369        .json()
370        .await
371        .map_err(|e| StartupError::XApiError(format!("failed to parse token response: {e}")))?;
372
373    let expires_at = token_resp
374        .expires_in
375        .map(|secs| chrono::Utc::now() + chrono::TimeDelta::seconds(secs));
376
377    Ok(StoredTokens {
378        access_token: token_resp.access_token,
379        refresh_token: token_resp.refresh_token,
380        expires_at,
381    })
382}
383
384/// Verify OAuth credentials by calling the X API /2/users/me endpoint.
385///
386/// Returns the authenticated user's username on success.
387pub async fn verify_credentials(access_token: &str) -> Result<String, StartupError> {
388    let client = reqwest::Client::new();
389    let resp = client
390        .get(X_USERS_ME_URL)
391        .bearer_auth(access_token)
392        .send()
393        .await
394        .map_err(|e| {
395            StartupError::XApiError(format!("credential verification request failed: {e}"))
396        })?;
397
398    if !resp.status().is_success() {
399        let status = resp.status();
400        let body = resp.text().await.unwrap_or_default();
401        return Err(StartupError::XApiError(format!(
402            "credential verification failed (HTTP {status}): {body}"
403        )));
404    }
405
406    #[derive(Deserialize)]
407    struct UserResponse {
408        data: UserData,
409    }
410
411    #[derive(Deserialize)]
412    struct UserData {
413        username: String,
414    }
415
416    let user: UserResponse = resp
417        .json()
418        .await
419        .map_err(|e| StartupError::XApiError(format!("failed to parse user response: {e}")))?;
420
421    Ok(user.data.username)
422}
423
/// Extract the authorization code from a callback URL or raw code string.
///
/// Accepts either a full URL (e.g., `http://127.0.0.1:8080/callback?code=XXX&state=YYY`)
/// or a bare authorization code. Surrounding whitespace is ignored.
///
/// For URL input, the value of the first `code` query parameter is returned.
/// Any URL fragment (`#...`) is dropped before the query is parsed, and the
/// query is taken to be everything after the first `?`, so a literal `?`
/// inside an earlier parameter value does not hide the code. The value is
/// returned as it appears in the URL (no percent-decoding is applied).
///
/// When no `code` query parameter is present, the trimmed input is returned
/// unchanged.
pub fn extract_auth_code(input: &str) -> String {
    let trimmed = input.trim();

    // The fragment is not part of the query; without this a trailing
    // `#...` would be glued onto the code when `code` is the last parameter.
    let without_fragment = trimmed
        .split_once('#')
        .map_or(trimmed, |(before_fragment, _)| before_fragment);

    let code_param = without_fragment
        .split_once('?')
        .and_then(|(_, query)| query.split('&').find_map(|pair| pair.strip_prefix("code=")));

    code_param.unwrap_or(trimmed).to_string()
}
442
443// ============================================================================
444// Startup Banner
445// ============================================================================
446
447/// Format the startup banner printed when the agent starts.
448pub fn format_startup_banner(
449    tier: ApiTier,
450    capabilities: &TierCapabilities,
451    status_interval: u64,
452) -> String {
453    let loops = capabilities.enabled_loop_names().join(", ");
454    let status = if status_interval > 0 {
455        format!("every {status_interval}s")
456    } else {
457        "disabled".to_string()
458    };
459    format!(
460        "Tuitbot v{version}\n\
461         Tier: {tier} | Loops: {loops}\n\
462         Status summary: {status}\n\
463         Press Ctrl+C to stop.",
464        version = env!("CARGO_PKG_VERSION"),
465    )
466}
467
468// ============================================================================
469// Path Helpers
470// ============================================================================
471
472/// Expand `~` at the start of a path to the user's home directory.
473pub fn expand_tilde(path: &str) -> PathBuf {
474    if let Some(rest) = path.strip_prefix("~/") {
475        if let Some(home) = dirs::home_dir() {
476            return home.join(rest);
477        }
478    } else if path == "~" {
479        if let Some(home) = dirs::home_dir() {
480            return home;
481        }
482    }
483    PathBuf::from(path)
484}
485
486// ============================================================================
487// Tests
488// ============================================================================
489
#[cfg(test)]
mod tests {
    use super::*;

    // --- Fixtures ---

    /// Token set with placeholder credentials and the given expiry.
    fn tokens_with_expiry(expires_at: Option<chrono::DateTime<chrono::Utc>>) -> StoredTokens {
        StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at,
        }
    }

    /// Token set whose expiry lies `offset` away from now (negative = in the past).
    fn tokens_expiring_in(offset: chrono::TimeDelta) -> StoredTokens {
        tokens_with_expiry(Some(chrono::Utc::now() + offset))
    }

    /// Capability flags as a `(mentions, discovery, posting, search)` tuple.
    fn capability_flags(tier: ApiTier) -> (bool, bool, bool, bool) {
        let caps = TierCapabilities::for_tier(tier);
        (caps.mentions, caps.discovery, caps.posting, caps.search)
    }

    // --- ApiTier ---

    #[test]
    fn api_tier_display() {
        let expected = [
            (ApiTier::Free, "Free"),
            (ApiTier::Basic, "Basic"),
            (ApiTier::Pro, "Pro"),
        ];
        for (tier, name) in expected {
            assert_eq!(tier.to_string(), name);
        }
    }

    // --- TierCapabilities ---

    #[test]
    fn free_tier_capabilities() {
        assert_eq!(capability_flags(ApiTier::Free), (false, false, true, false));
    }

    #[test]
    fn basic_tier_capabilities() {
        assert_eq!(capability_flags(ApiTier::Basic), (true, true, true, true));
    }

    #[test]
    fn pro_tier_capabilities() {
        assert_eq!(capability_flags(ApiTier::Pro), (true, true, true, true));
    }

    #[test]
    fn free_tier_enabled_loops() {
        let loops = TierCapabilities::for_tier(ApiTier::Free).enabled_loop_names();
        assert_eq!(loops, vec!["content", "threads"]);
    }

    #[test]
    fn basic_tier_enabled_loops() {
        let loops = TierCapabilities::for_tier(ApiTier::Basic).enabled_loop_names();
        assert_eq!(loops, vec!["mentions", "discovery", "content", "threads"]);
    }

    #[test]
    fn tier_capabilities_format_status() {
        let free = TierCapabilities::for_tier(ApiTier::Free).format_status();
        assert!(free.contains("Mentions: DISABLED"));
        assert!(free.contains("Discovery: DISABLED"));

        let basic = TierCapabilities::for_tier(ApiTier::Basic).format_status();
        assert!(basic.contains("Mentions: enabled"));
        assert!(basic.contains("Discovery: enabled"));
    }

    // --- StoredTokens ---

    #[test]
    fn stored_tokens_not_expired() {
        assert!(!tokens_expiring_in(chrono::TimeDelta::hours(1)).is_expired());
    }

    #[test]
    fn stored_tokens_expired() {
        assert!(tokens_expiring_in(chrono::TimeDelta::hours(-1)).is_expired());
    }

    #[test]
    fn stored_tokens_no_expiry_is_not_expired() {
        assert!(!tokens_with_expiry(None).is_expired());
    }

    #[test]
    fn stored_tokens_format_expiry_hours() {
        let formatted = tokens_expiring_in(chrono::TimeDelta::minutes(102)).format_expiry();
        assert!(formatted.contains("h"));
        assert!(formatted.contains("m"));
    }

    #[test]
    fn stored_tokens_format_expiry_minutes_only() {
        let formatted = tokens_expiring_in(chrono::TimeDelta::minutes(30)).format_expiry();
        assert!(formatted.contains("m"));
        assert!(!formatted.contains("h"));
    }

    #[test]
    fn stored_tokens_format_expiry_expired() {
        let tokens = tokens_expiring_in(chrono::TimeDelta::hours(-1));
        assert_eq!(tokens.format_expiry(), "expired");
    }

    #[test]
    fn stored_tokens_format_expiry_no_expiry() {
        assert_eq!(tokens_with_expiry(None).format_expiry(), "no expiry set");
    }

    #[test]
    fn stored_tokens_serialization_roundtrip() {
        let expiry = chrono::DateTime::parse_from_rfc3339("2026-06-01T12:00:00Z")
            .expect("valid datetime")
            .with_timezone(&chrono::Utc);
        let tokens = StoredTokens {
            access_token: "access123".to_string(),
            refresh_token: Some("refresh456".to_string()),
            expires_at: Some(expiry),
        };

        let json = serde_json::to_string(&tokens).expect("serialize");
        let restored: StoredTokens = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(restored.access_token, "access123");
        assert_eq!(restored.refresh_token.as_deref(), Some("refresh456"));
        assert!(restored.expires_at.is_some());
    }

    // --- Token File I/O ---

    // NOTE: the two tests below write to a temporary directory directly and
    // therefore exercise the JSON format and the permission call, not
    // `save_tokens_to_file` / `load_tokens_from_file` themselves.

    #[test]
    fn save_and_load_tokens() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("tokens.json");

        let tokens = StoredTokens {
            access_token: "test_access".to_string(),
            refresh_token: Some("test_refresh".to_string()),
            expires_at: None,
        };

        let json = serde_json::to_string_pretty(&tokens).expect("serialize");
        std::fs::write(&path, &json).expect("write");

        let contents = std::fs::read_to_string(&path).expect("read");
        let loaded: StoredTokens = serde_json::from_str(&contents).expect("deserialize");
        assert_eq!(loaded.access_token, "test_access");
        assert_eq!(loaded.refresh_token.as_deref(), Some("test_refresh"));
    }

    #[cfg(unix)]
    #[test]
    fn save_tokens_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("tokens.json");

        let json = serde_json::to_string_pretty(&tokens_with_expiry(None)).expect("serialize");
        std::fs::write(&path, &json).expect("write");
        std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))
            .expect("set perms");

        let mode = std::fs::metadata(&path)
            .expect("metadata")
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    // --- Startup Error ---

    #[test]
    fn startup_error_display() {
        assert_eq!(
            StartupError::AuthRequired.to_string(),
            "authentication required: run `tuitbot auth` first"
        );
        assert!(StartupError::AuthExpired.to_string().contains("expired"));
        assert_eq!(
            StartupError::Config("bad field".to_string()).to_string(),
            "configuration error: bad field"
        );
        assert_eq!(
            StartupError::XApiError("timeout".to_string()).to_string(),
            "X API error: timeout"
        );
    }

    // --- PKCE ---

    #[test]
    fn generate_pkce_produces_valid_challenge() {
        let pkce = generate_pkce();
        // 32 bytes in unpadded base64url -> 43 characters.
        assert_eq!(pkce.verifier.len(), 43);
        // SHA-256 digest (32 bytes) in unpadded base64url -> 43 characters.
        assert_eq!(pkce.challenge.len(), 43);
        // 16 bytes in unpadded base64url -> 22 characters.
        assert_eq!(pkce.state.len(), 22);
        // The challenge must be derived from the verifier.
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    #[test]
    fn generate_pkce_unique_each_time() {
        let first = generate_pkce();
        let second = generate_pkce();
        assert_ne!(first.verifier, second.verifier);
        assert_ne!(first.challenge, second.challenge);
        assert_ne!(first.state, second.state);
    }

    // --- URL Building ---

    #[test]
    fn build_auth_url_contains_required_params() {
        let url = build_auth_url(
            "client123",
            "http://localhost:8080/callback",
            "state456",
            "challenge789",
        );
        assert!(url.starts_with(X_AUTH_URL));
        let expected_fragments = [
            "response_type=code",
            "client_id=client123",
            "code_challenge=challenge789",
            "code_challenge_method=S256",
            "state=state456",
            // The redirect URI must be percent-encoded.
            "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback",
        ];
        for fragment in expected_fragments {
            assert!(url.contains(fragment));
        }
    }

    #[test]
    fn build_redirect_uri_format() {
        assert_eq!(
            build_redirect_uri("127.0.0.1", 8080),
            "http://127.0.0.1:8080/callback"
        );
    }

    // --- extract_auth_code ---

    #[test]
    fn extract_code_from_full_url() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=abc123&state=xyz");
        assert_eq!(code, "abc123");
    }

    #[test]
    fn extract_code_from_bare_code() {
        assert_eq!(extract_auth_code("  abc123  "), "abc123");
    }

    #[test]
    fn extract_code_from_url_without_state() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=mycode");
        assert_eq!(code, "mycode");
    }

    // --- URL Encoding ---

    #[test]
    fn url_encode_basic() {
        assert_eq!(url_encode("hello"), "hello");
        assert_eq!(url_encode("hello world"), "hello%20world");
        assert_eq!(
            url_encode("http://localhost:8080/callback"),
            "http%3A%2F%2Flocalhost%3A8080%2Fcallback"
        );
    }

    // --- Startup Banner ---

    #[test]
    fn startup_banner_free_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 300);
        assert!(banner.contains("Tuitbot v"));
        assert!(banner.contains("Tier: Free"));
        assert!(!banner.contains("mentions"));
        assert!(banner.contains("content"));
        assert!(banner.contains("threads"));
        assert!(!banner.contains("discovery"));
        assert!(banner.contains("every 300s"));
    }

    #[test]
    fn startup_banner_basic_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let banner = format_startup_banner(ApiTier::Basic, &caps, 0);
        assert!(banner.contains("Tier: Basic"));
        assert!(banner.contains("discovery"));
        assert!(banner.contains("disabled"));
    }

    #[test]
    fn startup_banner_contains_ctrl_c_hint() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 0);
        assert!(banner.contains("Ctrl+C"));
    }

    // --- Path Helpers ---

    #[test]
    fn expand_tilde_works() {
        let expanded = expand_tilde("~/.tuitbot/config.toml");
        assert!(!expanded.to_string_lossy().starts_with('~'));
    }

    #[test]
    fn expand_tilde_no_tilde() {
        assert_eq!(
            expand_tilde("/absolute/path"),
            PathBuf::from("/absolute/path")
        );
    }

    #[test]
    fn data_dir_under_home() {
        assert!(data_dir().to_string_lossy().contains(".tuitbot"));
    }

    #[test]
    fn token_file_path_under_data_dir() {
        let path = token_file_path();
        let rendered = path.to_string_lossy();
        assert!(rendered.contains("tokens.json"));
        assert!(rendered.contains(".tuitbot"));
    }
}
857}