Skip to main content

tuitbot_core/
startup.rs

//! Startup types and helpers for Tuitbot CLI commands.
//!
//! Provides API tier detection types, OAuth token storage and file I/O,
//! PKCE authentication helpers, startup banner formatting, and path
//! helpers used by the `run`, `auth`, and `test` CLI commands.
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::fmt;
12use std::path::PathBuf;
13
14use crate::x_api::scopes::{self, ScopeAnalysis, REQUIRED_SCOPES};
15
16// ============================================================================
17// X API OAuth 2.0 endpoints
18// ============================================================================
19
/// X API OAuth 2.0 authorization endpoint.
///
/// Base of the URL built by [`build_auth_url`] that the user opens to grant access.
pub const X_AUTH_URL: &str = "https://twitter.com/i/oauth2/authorize";

/// X API OAuth 2.0 token endpoint.
///
/// [`exchange_auth_code`] POSTs the authorization code here to obtain tokens.
pub const X_TOKEN_URL: &str = "https://api.twitter.com/2/oauth2/token";

/// X API users/me endpoint for credential verification.
///
/// [`verify_credentials`] calls this with a bearer token to resolve the username.
pub const X_USERS_ME_URL: &str = "https://api.twitter.com/2/users/me";
28
29// ============================================================================
30// API Tier
31// ============================================================================
32
/// X API access tier detected for the configured credentials.
///
/// The tier decides which automation loops may run; see
/// [`TierCapabilities::for_tier`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiTier {
    /// Free tier -- posting only (no search, no mentions).
    Free,
    /// Basic tier -- adds search/discovery.
    Basic,
    /// Pro tier -- all features.
    Pro,
}

impl fmt::Display for ApiTier {
    /// Writes the tier's human-readable name: `Free`, `Basic`, or `Pro`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Free => "Free",
            Self::Basic => "Basic",
            Self::Pro => "Pro",
        };
        f.write_str(label)
    }
}
53
54/// Capabilities enabled by the current API tier.
55#[derive(Debug, Clone)]
56pub struct TierCapabilities {
57    /// Whether the mentions loop can run.
58    pub mentions: bool,
59    /// Whether the discovery/search loop can run.
60    pub discovery: bool,
61    /// Whether posting (tweets + threads) is available.
62    pub posting: bool,
63    /// Whether tweet search is available.
64    pub search: bool,
65}
66
67impl TierCapabilities {
68    /// Determine capabilities for a given tier.
69    pub fn for_tier(tier: ApiTier) -> Self {
70        match tier {
71            ApiTier::Free => Self {
72                mentions: false,
73                discovery: false,
74                posting: true,
75                search: false,
76            },
77            ApiTier::Basic | ApiTier::Pro => Self {
78                mentions: true,
79                discovery: true,
80                posting: true,
81                search: true,
82            },
83        }
84    }
85
86    /// List the names of enabled automation loops.
87    pub fn enabled_loop_names(&self) -> Vec<&'static str> {
88        let mut loops = Vec::new();
89        if self.mentions {
90            loops.push("mentions");
91        }
92        if self.discovery {
93            loops.push("discovery");
94        }
95        // Content and threads are always enabled (no special tier required).
96        loops.push("content");
97        loops.push("threads");
98        loops
99    }
100
101    /// Format the tier capabilities as a status line.
102    pub fn format_status(&self) -> String {
103        let status = |enabled: bool| if enabled { "enabled" } else { "DISABLED" };
104        format!(
105            "Mentions: {}, Discovery: {}, Content: enabled, Threads: enabled",
106            status(self.mentions),
107            status(self.discovery),
108        )
109    }
110}
111
112// ============================================================================
113// Stored Tokens
114// ============================================================================
115
116/// OAuth tokens persisted to disk at `~/.tuitbot/tokens.json`.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct StoredTokens {
119    /// OAuth 2.0 access token.
120    pub access_token: String,
121
122    /// OAuth 2.0 refresh token (for offline.access scope).
123    #[serde(default)]
124    pub refresh_token: Option<String>,
125
126    /// Token expiration timestamp.
127    #[serde(default)]
128    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
129
130    /// Granted OAuth scopes returned by X during token exchange.
131    #[serde(default)]
132    pub scopes: Vec<String>,
133}
134
135impl StoredTokens {
136    /// Check if the token has expired.
137    pub fn is_expired(&self) -> bool {
138        match self.expires_at {
139            Some(expires) => chrono::Utc::now() >= expires,
140            None => false,
141        }
142    }
143
144    /// Time remaining until token expires.
145    pub fn time_until_expiry(&self) -> Option<chrono::TimeDelta> {
146        self.expires_at.map(|expires| expires - chrono::Utc::now())
147    }
148
149    /// Format time until expiry as a human-readable string.
150    pub fn format_expiry(&self) -> String {
151        match self.time_until_expiry() {
152            Some(duration) if duration.num_seconds() > 0 => {
153                let hours = duration.num_hours();
154                let minutes = duration.num_minutes() % 60;
155                if hours > 0 {
156                    format!("{hours}h {minutes}m")
157                } else {
158                    format!("{minutes}m")
159                }
160            }
161            Some(_) => "expired".to_string(),
162            None => "no expiry set".to_string(),
163        }
164    }
165
166    /// Whether this token file includes scope metadata.
167    pub fn has_scope_info(&self) -> bool {
168        !self.scopes.is_empty()
169    }
170
171    /// Check whether a specific scope is granted.
172    pub fn has_scope(&self, scope: &str) -> bool {
173        self.scopes.iter().any(|granted| granted == scope)
174    }
175
176    /// Analyze granted scopes versus required Tuitbot scopes.
177    pub fn analyze_scopes(&self) -> ScopeAnalysis {
178        scopes::analyze_scopes(&self.scopes)
179    }
180}
181
182// ============================================================================
183// Startup Error
184// ============================================================================
185
/// Errors that can occur during startup operations.
///
/// Each variant's `Display` message comes from its `#[error]` attribute
/// (derived with `thiserror`); the authentication variants name the
/// `tuitbot auth` command that resolves them.
#[derive(Debug, thiserror::Error)]
pub enum StartupError {
    /// Configuration is invalid or missing.
    #[error("configuration error: {0}")]
    Config(String),

    /// No tokens found -- user needs to authenticate first.
    ///
    /// [`load_tokens_from_file`] returns this when the token file does not exist.
    #[error("authentication required: run `tuitbot auth` first")]
    AuthRequired,

    /// Tokens are expired and need re-authentication.
    #[error("authentication expired: run `tuitbot auth` to re-authenticate")]
    AuthExpired,

    /// Token refresh attempt failed.
    #[error("token refresh failed: {0}")]
    TokenRefreshFailed(String),

    /// Database initialization or access error.
    #[error("database error: {0}")]
    Database(String),

    /// LLM provider configuration or connectivity error.
    #[error("LLM provider error: {0}")]
    LlmError(String),

    /// X API communication error (request failure, non-success HTTP status,
    /// or unparseable response body).
    #[error("X API error: {0}")]
    XApiError(String),

    /// File I/O error.
    ///
    /// `#[from]` lets `?` convert a `std::io::Error` into this variant.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Any other error (e.g. token JSON that fails to parse or serialize).
    #[error("{0}")]
    Other(String),
}
225
226// ============================================================================
227// Token File I/O
228// ============================================================================
229
230/// Default directory for Tuitbot data files (`~/.tuitbot/`).
231pub fn data_dir() -> PathBuf {
232    dirs::home_dir()
233        .unwrap_or_else(|| PathBuf::from("."))
234        .join(".tuitbot")
235}
236
237/// Path to the token storage file (`~/.tuitbot/tokens.json`).
238pub fn token_file_path() -> PathBuf {
239    data_dir().join("tokens.json")
240}
241
242/// Load OAuth tokens from the default file path.
243pub fn load_tokens_from_file() -> Result<StoredTokens, StartupError> {
244    let path = token_file_path();
245    let contents = std::fs::read_to_string(&path).map_err(|e| {
246        if e.kind() == std::io::ErrorKind::NotFound {
247            StartupError::AuthRequired
248        } else {
249            StartupError::Io(e)
250        }
251    })?;
252    serde_json::from_str(&contents)
253        .map_err(|e| StartupError::Other(format!("failed to parse tokens file: {e}")))
254}
255
256/// Save OAuth tokens to the default file path with secure permissions.
257///
258/// Creates the `~/.tuitbot/` directory if it does not exist.
259/// On Unix, sets file permissions to 0600 (owner read/write only).
260pub fn save_tokens_to_file(tokens: &StoredTokens) -> Result<(), StartupError> {
261    let dir = data_dir();
262    std::fs::create_dir_all(&dir)?;
263
264    let path = token_file_path();
265    let json = serde_json::to_string_pretty(tokens)
266        .map_err(|e| StartupError::Other(format!("failed to serialize tokens: {e}")))?;
267    std::fs::write(&path, json)?;
268
269    // Set file permissions to 0600 on Unix (owner read/write only).
270    #[cfg(unix)]
271    {
272        use std::os::unix::fs::PermissionsExt;
273        let perms = std::fs::Permissions::from_mode(0o600);
274        std::fs::set_permissions(&path, perms)?;
275    }
276
277    Ok(())
278}
279
280// ============================================================================
281// PKCE Authentication
282// ============================================================================
283
/// PKCE code verifier and challenge pair, plus the OAuth `state` value.
///
/// Produced by [`generate_pkce`]. The challenge is the unpadded base64url
/// SHA-256 digest of the verifier, i.e. the `S256` method requested by
/// [`build_auth_url`].
///
/// NOTE(review): the derived `Debug` prints the verifier and state in full;
/// avoid logging this value.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// The code verifier (sent during token exchange).
    pub verifier: String,
    /// The code challenge (sent during authorization).
    pub challenge: String,
    /// CSRF state parameter.
    pub state: String,
}
294
295/// Generate a PKCE code verifier, challenge, and state parameter.
296pub fn generate_pkce() -> PkceChallenge {
297    use rand::Rng;
298    let random_bytes: [u8; 32] = rand::thread_rng().gen();
299    let verifier = URL_SAFE_NO_PAD.encode(random_bytes);
300    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
301    let state_bytes: [u8; 16] = rand::thread_rng().gen();
302    let state = URL_SAFE_NO_PAD.encode(state_bytes);
303    PkceChallenge {
304        verifier,
305        challenge,
306        state,
307    }
308}
309
/// Percent-encode `s` for use as a URL query value (RFC 3986).
///
/// Only the unreserved set (`A-Z a-z 0-9 - _ . ~`) passes through; every
/// other UTF-8 byte becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    use std::fmt::Write;

    s.bytes()
        .fold(String::with_capacity(s.len() * 3), |mut out, b| {
            if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(b));
            } else {
                // Writing into a String cannot fail.
                let _ = write!(out, "%{b:02X}");
            }
            out
        })
}
326
327/// Build the X API OAuth 2.0 authorization URL.
328pub fn build_auth_url(
329    client_id: &str,
330    redirect_uri: &str,
331    state: &str,
332    code_challenge: &str,
333) -> String {
334    let oauth_scopes = REQUIRED_SCOPES.join(" ");
335    format!(
336        "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256&prompt=consent",
337        X_AUTH_URL,
338        url_encode(client_id),
339        url_encode(redirect_uri),
340        url_encode(&oauth_scopes),
341        url_encode(state),
342        url_encode(code_challenge),
343    )
344}
345
/// Build the OAuth redirect URI (`http://<host>:<port>/callback`) from the
/// configured callback host and port.
///
/// IPv6 literal hosts (e.g. `::1`) are wrapped in brackets as RFC 3986
/// §3.2.2 requires (`http://[::1]:8080/callback`). Otherwise the port
/// separator would be ambiguous and the URI malformed. Hosts that are
/// already bracketed are used as-is.
pub fn build_redirect_uri(callback_host: &str, callback_port: u16) -> String {
    // Hostnames and IPv4 addresses never contain ':'; an IPv6 literal always does.
    let is_bare_ipv6 = callback_host.contains(':') && !callback_host.starts_with('[');
    if is_bare_ipv6 {
        format!("http://[{callback_host}]:{callback_port}/callback")
    } else {
        format!("http://{callback_host}:{callback_port}/callback")
    }
}
350
351/// Exchange an authorization code for OAuth tokens.
352pub async fn exchange_auth_code(
353    client_id: &str,
354    code: &str,
355    redirect_uri: &str,
356    code_verifier: &str,
357) -> Result<StoredTokens, StartupError> {
358    let client = reqwest::Client::new();
359    let resp = client
360        .post(X_TOKEN_URL)
361        .form(&[
362            ("grant_type", "authorization_code"),
363            ("code", code),
364            ("redirect_uri", redirect_uri),
365            ("code_verifier", code_verifier),
366            ("client_id", client_id),
367        ])
368        .send()
369        .await
370        .map_err(|e| StartupError::XApiError(format!("token exchange request failed: {e}")))?;
371
372    if !resp.status().is_success() {
373        let status = resp.status();
374        let body = resp.text().await.unwrap_or_default();
375        return Err(StartupError::XApiError(format!(
376            "token exchange failed (HTTP {status}): {body}"
377        )));
378    }
379
380    #[derive(Deserialize)]
381    struct TokenResponse {
382        access_token: String,
383        #[serde(default)]
384        refresh_token: Option<String>,
385        #[serde(default)]
386        expires_in: Option<i64>,
387        #[serde(default)]
388        scope: Option<String>,
389    }
390
391    let token_resp: TokenResponse = resp
392        .json()
393        .await
394        .map_err(|e| StartupError::XApiError(format!("failed to parse token response: {e}")))?;
395
396    let expires_at = token_resp
397        .expires_in
398        .map(|secs| chrono::Utc::now() + chrono::TimeDelta::seconds(secs));
399    let scopes = token_resp
400        .scope
401        .map(|s| s.split_whitespace().map(String::from).collect())
402        .unwrap_or_default();
403
404    Ok(StoredTokens {
405        access_token: token_resp.access_token,
406        refresh_token: token_resp.refresh_token,
407        expires_at,
408        scopes,
409    })
410}
411
412/// Verify OAuth credentials by calling the X API /2/users/me endpoint.
413///
414/// Returns the authenticated user's username on success.
415pub async fn verify_credentials(access_token: &str) -> Result<String, StartupError> {
416    let client = reqwest::Client::new();
417    let resp = client
418        .get(X_USERS_ME_URL)
419        .bearer_auth(access_token)
420        .send()
421        .await
422        .map_err(|e| {
423            StartupError::XApiError(format!("credential verification request failed: {e}"))
424        })?;
425
426    if !resp.status().is_success() {
427        let status = resp.status();
428        let body = resp.text().await.unwrap_or_default();
429        return Err(StartupError::XApiError(format!(
430            "credential verification failed (HTTP {status}): {body}"
431        )));
432    }
433
434    #[derive(Deserialize)]
435    struct UserResponse {
436        data: UserData,
437    }
438
439    #[derive(Deserialize)]
440    struct UserData {
441        username: String,
442    }
443
444    let user: UserResponse = resp
445        .json()
446        .await
447        .map_err(|e| StartupError::XApiError(format!("failed to parse user response: {e}")))?;
448
449    Ok(user.data.username)
450}
451
/// Extract the authorization code from a callback URL or raw code string.
///
/// Accepts a full callback URL (e.g.
/// `http://127.0.0.1:8080/callback?code=XXX&state=YYY`), just its query
/// string (`code=XXX&state=YYY`), or a bare authorization code. Surrounding
/// whitespace is ignored, a trailing `#fragment` is dropped, and the `code`
/// value is percent-decoded. Input without a `code` query parameter is
/// returned trimmed but otherwise unchanged.
pub fn extract_auth_code(input: &str) -> String {
    let trimmed = input.trim();
    if trimmed.contains("code=") {
        // The query is everything after the first `?`; without a `?` the input
        // is treated as a bare query string.
        let query = trimmed.split_once('?').map_or(trimmed, |(_, rest)| rest);
        // A fragment (e.g. `#_=_`) is not part of the query string.
        let query = query.split_once('#').map_or(query, |(before, _)| before);
        if let Some(value) = query.split('&').find_map(|pair| pair.strip_prefix("code=")) {
            return percent_decode(value);
        }
    }
    trimmed.to_string()
}

/// Decode `%XX` escapes in a URL query value; malformed escapes are kept as-is.
///
/// `+` is left untouched. Returns the raw input if the decoded bytes are not
/// valid UTF-8.
fn percent_decode(value: &str) -> String {
    let bytes = value.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            let hi = char::from(bytes[i + 1]).to_digit(16);
            let lo = char::from(bytes[i + 2]).to_digit(16);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                // Two hex digits are at most 0xFF, so the value fits in a byte.
                decoded.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8(decoded).unwrap_or_else(|_| value.to_string())
}
470
471// ============================================================================
472// Startup Banner
473// ============================================================================
474
475/// Format the startup banner printed when the agent starts.
476pub fn format_startup_banner(
477    tier: ApiTier,
478    capabilities: &TierCapabilities,
479    status_interval: u64,
480) -> String {
481    let loops = capabilities.enabled_loop_names().join(", ");
482    let status = if status_interval > 0 {
483        format!("every {status_interval}s")
484    } else {
485        "disabled".to_string()
486    };
487    format!(
488        "Tuitbot v{version}\n\
489         Tier: {tier} | Loops: {loops}\n\
490         Status summary: {status}\n\
491         Press Ctrl+C to stop.",
492        version = env!("CARGO_PKG_VERSION"),
493    )
494}
495
496// ============================================================================
497// Path Helpers
498// ============================================================================
499
500/// Expand `~` at the start of a path to the user's home directory.
501pub fn expand_tilde(path: &str) -> PathBuf {
502    if let Some(rest) = path.strip_prefix("~/") {
503        if let Some(home) = dirs::home_dir() {
504            return home.join(rest);
505        }
506    } else if path == "~" {
507        if let Some(home) = dirs::home_dir() {
508            return home;
509        }
510    }
511    PathBuf::from(path)
512}
513
514// ============================================================================
515// Tests
516// ============================================================================
517
#[cfg(test)]
mod tests {
    // Unit tests for the startup helpers. No test here calls the network-bound
    // functions (`exchange_auth_code`, `verify_credentials`); the file-I/O tests
    // write only inside a temporary directory.
    use super::*;

    // --- ApiTier ---

    #[test]
    fn api_tier_display() {
        assert_eq!(ApiTier::Free.to_string(), "Free");
        assert_eq!(ApiTier::Basic.to_string(), "Basic");
        assert_eq!(ApiTier::Pro.to_string(), "Pro");
    }

    // --- TierCapabilities ---

    #[test]
    fn free_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        assert!(!caps.mentions);
        assert!(!caps.discovery);
        assert!(caps.posting);
        assert!(!caps.search);
    }

    #[test]
    fn basic_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        assert!(caps.mentions);
        assert!(caps.discovery);
        assert!(caps.posting);
        assert!(caps.search);
    }

    #[test]
    fn pro_tier_capabilities() {
        let caps = TierCapabilities::for_tier(ApiTier::Pro);
        assert!(caps.mentions);
        assert!(caps.discovery);
        assert!(caps.posting);
        assert!(caps.search);
    }

    #[test]
    fn free_tier_enabled_loops() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let loops = caps.enabled_loop_names();
        assert_eq!(loops, vec!["content", "threads"]);
    }

    #[test]
    fn basic_tier_enabled_loops() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let loops = caps.enabled_loop_names();
        assert_eq!(loops, vec!["mentions", "discovery", "content", "threads"]);
    }

    #[test]
    fn tier_capabilities_format_status() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let status = caps.format_status();
        assert!(status.contains("Mentions: DISABLED"));
        assert!(status.contains("Discovery: DISABLED"));

        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let status = caps.format_status();
        assert!(status.contains("Mentions: enabled"));
        assert!(status.contains("Discovery: enabled"));
    }

    // --- StoredTokens ---

    #[test]
    fn stored_tokens_not_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::hours(1)),
            scopes: vec![],
        };
        assert!(!tokens.is_expired());
    }

    #[test]
    fn stored_tokens_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() - chrono::TimeDelta::hours(1)),
            scopes: vec![],
        };
        assert!(tokens.is_expired());
    }

    #[test]
    fn stored_tokens_no_expiry_is_not_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
            scopes: vec![],
        };
        assert!(!tokens.is_expired());
    }

    #[test]
    fn stored_tokens_format_expiry_hours() {
        // 102 minutes ahead: formats as "1h 41m" or "1h 42m" depending on elapsed time.
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::minutes(102)),
            scopes: vec![],
        };
        let formatted = tokens.format_expiry();
        assert!(formatted.contains("h"));
        assert!(formatted.contains("m"));
    }

    #[test]
    fn stored_tokens_format_expiry_minutes_only() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() + chrono::TimeDelta::minutes(30)),
            scopes: vec![],
        };
        let formatted = tokens.format_expiry();
        assert!(formatted.contains("m"));
        assert!(!formatted.contains("h"));
    }

    #[test]
    fn stored_tokens_format_expiry_expired() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: Some(chrono::Utc::now() - chrono::TimeDelta::hours(1)),
            scopes: vec![],
        };
        assert_eq!(tokens.format_expiry(), "expired");
    }

    #[test]
    fn stored_tokens_format_expiry_no_expiry() {
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
            scopes: vec![],
        };
        assert_eq!(tokens.format_expiry(), "no expiry set");
    }

    #[test]
    fn stored_tokens_serialization_roundtrip() {
        let tokens = StoredTokens {
            access_token: "access123".to_string(),
            refresh_token: Some("refresh456".to_string()),
            expires_at: Some(
                chrono::DateTime::parse_from_rfc3339("2026-06-01T12:00:00Z")
                    .expect("valid datetime")
                    .with_timezone(&chrono::Utc),
            ),
            scopes: vec!["tweet.read".to_string(), "tweet.write".to_string()],
        };
        let json = serde_json::to_string(&tokens).expect("serialize");
        let deserialized: StoredTokens = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.access_token, "access123");
        assert_eq!(deserialized.refresh_token.as_deref(), Some("refresh456"));
        assert!(deserialized.expires_at.is_some());
        assert_eq!(
            deserialized.scopes,
            vec!["tweet.read".to_string(), "tweet.write".to_string()]
        );
    }

    // Token files written without a `scopes` field must still load.
    #[test]
    fn stored_tokens_deserialize_without_scopes_defaults_empty() {
        let json = r#"{
            "access_token": "access123",
            "refresh_token": "refresh456",
            "expires_at": "2026-06-01T12:00:00Z"
        }"#;

        let tokens: StoredTokens = serde_json::from_str(json).expect("deserialize");
        assert!(tokens.scopes.is_empty());
        assert!(!tokens.has_scope_info());
    }

    #[test]
    fn stored_tokens_scope_helpers_work() {
        let tokens = StoredTokens {
            access_token: "access123".to_string(),
            refresh_token: Some("refresh456".to_string()),
            expires_at: None,
            scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
        };

        assert!(tokens.has_scope_info());
        assert!(tokens.has_scope("tweet.read"));
        assert!(!tokens.has_scope("tweet.write"));
    }

    // --- Token File I/O ---

    // NOTE(review): this exercises a JSON round-trip through a temp file, not
    // `save_tokens_to_file`/`load_tokens_from_file` themselves. Those always use
    // the real `token_file_path()` under the home directory.
    #[test]
    fn save_and_load_tokens() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("tokens.json");

        let tokens = StoredTokens {
            access_token: "test_access".to_string(),
            refresh_token: Some("test_refresh".to_string()),
            expires_at: None,
            scopes: vec!["tweet.read".to_string()],
        };

        let json = serde_json::to_string_pretty(&tokens).expect("serialize");
        std::fs::write(&path, &json).expect("write");

        let contents = std::fs::read_to_string(&path).expect("read");
        let loaded: StoredTokens = serde_json::from_str(&contents).expect("deserialize");
        assert_eq!(loaded.access_token, "test_access");
        assert_eq!(loaded.refresh_token.as_deref(), Some("test_refresh"));
        assert_eq!(loaded.scopes, vec!["tweet.read".to_string()]);
    }

    // NOTE(review): despite its name, this test sets 0600 itself and checks the
    // result; it does not call `save_tokens_to_file`, which cannot be pointed at
    // a temp directory.
    #[cfg(unix)]
    #[test]
    fn save_tokens_sets_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().expect("tempdir");
        // Write into the temp dir directly instead of the real data dir.
        let path = dir.path().join("tokens.json");
        let tokens = StoredTokens {
            access_token: "test".to_string(),
            refresh_token: None,
            expires_at: None,
            scopes: vec![],
        };
        let json = serde_json::to_string_pretty(&tokens).expect("serialize");
        std::fs::write(&path, &json).expect("write");
        let perms = std::fs::Permissions::from_mode(0o600);
        std::fs::set_permissions(&path, perms).expect("set perms");

        let meta = std::fs::metadata(&path).expect("metadata");
        assert_eq!(meta.permissions().mode() & 0o777, 0o600);
    }

    // --- Startup Error ---

    #[test]
    fn startup_error_display() {
        let err = StartupError::AuthRequired;
        assert_eq!(
            err.to_string(),
            "authentication required: run `tuitbot auth` first"
        );

        let err = StartupError::AuthExpired;
        assert!(err.to_string().contains("expired"));

        let err = StartupError::Config("bad field".to_string());
        assert_eq!(err.to_string(), "configuration error: bad field");

        let err = StartupError::XApiError("timeout".to_string());
        assert_eq!(err.to_string(), "X API error: timeout");
    }

    // --- PKCE ---

    #[test]
    fn generate_pkce_produces_valid_challenge() {
        let pkce = generate_pkce();
        // Verifier should be 43 characters (32 bytes base64url encoded).
        assert_eq!(pkce.verifier.len(), 43);
        // Challenge should be 43 characters (32 bytes SHA-256 hash, base64url).
        assert_eq!(pkce.challenge.len(), 43);
        // State should be 22 characters (16 bytes base64url encoded).
        assert_eq!(pkce.state.len(), 22);
        // Verify the challenge matches the verifier.
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    #[test]
    fn generate_pkce_unique_each_time() {
        let a = generate_pkce();
        let b = generate_pkce();
        assert_ne!(a.verifier, b.verifier);
        assert_ne!(a.challenge, b.challenge);
        assert_ne!(a.state, b.state);
    }

    // --- URL Building ---

    #[test]
    fn build_auth_url_contains_required_params() {
        let url = build_auth_url(
            "client123",
            "http://localhost:8080/callback",
            "state456",
            "challenge789",
        );
        assert!(url.starts_with(X_AUTH_URL));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=client123"));
        assert!(url.contains("code_challenge=challenge789"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state=state456"));
        // redirect_uri should be encoded.
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback"));
    }

    #[test]
    fn build_redirect_uri_format() {
        let uri = build_redirect_uri("127.0.0.1", 8080);
        assert_eq!(uri, "http://127.0.0.1:8080/callback");
    }

    // --- extract_auth_code ---

    #[test]
    fn extract_code_from_full_url() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=abc123&state=xyz");
        assert_eq!(code, "abc123");
    }

    #[test]
    fn extract_code_from_bare_code() {
        let code = extract_auth_code("  abc123  ");
        assert_eq!(code, "abc123");
    }

    #[test]
    fn extract_code_from_url_without_state() {
        let code = extract_auth_code("http://127.0.0.1:8080/callback?code=mycode");
        assert_eq!(code, "mycode");
    }

    // --- URL Encoding ---

    #[test]
    fn url_encode_basic() {
        assert_eq!(url_encode("hello"), "hello");
        assert_eq!(url_encode("hello world"), "hello%20world");
        assert_eq!(
            url_encode("http://localhost:8080/callback"),
            "http%3A%2F%2Flocalhost%3A8080%2Fcallback"
        );
    }

    // --- Startup Banner ---

    #[test]
    fn startup_banner_free_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 300);
        assert!(banner.contains("Tuitbot v"));
        assert!(banner.contains("Tier: Free"));
        assert!(!banner.contains("mentions"));
        assert!(banner.contains("content"));
        assert!(banner.contains("threads"));
        assert!(!banner.contains("discovery"));
        assert!(banner.contains("every 300s"));
    }

    #[test]
    fn startup_banner_basic_tier() {
        let caps = TierCapabilities::for_tier(ApiTier::Basic);
        let banner = format_startup_banner(ApiTier::Basic, &caps, 0);
        assert!(banner.contains("Tier: Basic"));
        assert!(banner.contains("discovery"));
        // status_interval == 0 renders the status summary as "disabled".
        assert!(banner.contains("disabled"));
    }

    #[test]
    fn startup_banner_contains_ctrl_c_hint() {
        let caps = TierCapabilities::for_tier(ApiTier::Free);
        let banner = format_startup_banner(ApiTier::Free, &caps, 0);
        assert!(banner.contains("Ctrl+C"));
    }

    // --- Path Helpers ---

    // Assumes the test environment has a resolvable home directory.
    #[test]
    fn expand_tilde_works() {
        let expanded = expand_tilde("~/.tuitbot/config.toml");
        assert!(!expanded.to_string_lossy().starts_with('~'));
    }

    #[test]
    fn expand_tilde_no_tilde() {
        let expanded = expand_tilde("/absolute/path");
        assert_eq!(expanded, PathBuf::from("/absolute/path"));
    }

    #[test]
    fn data_dir_under_home() {
        let dir = data_dir();
        assert!(dir.to_string_lossy().contains(".tuitbot"));
    }

    #[test]
    fn token_file_path_under_data_dir() {
        let path = token_file_path();
        assert!(path.to_string_lossy().contains("tokens.json"));
        assert!(path.to_string_lossy().contains(".tuitbot"));
    }
}