Skip to main content

tuitbot_core/x_api/auth/
token.rs

1//! Token manager: persistence, loading, and automatic refresh.
2
3use std::sync::Arc;
4
5use chrono::Utc;
6use tokio::sync::RwLock;
7
8use crate::error::XApiError;
9
10use super::{save_tokens, TokenRefreshResponse, Tokens, REFRESH_WINDOW_SECS, TOKEN_URL};
11
12/// Manages token persistence, loading, and automatic refresh.
13pub struct TokenManager {
14    tokens: Arc<RwLock<Tokens>>,
15    client_id: String,
16    http_client: reqwest::Client,
17    token_path: std::path::PathBuf,
18    /// Serializes refresh attempts so only one runs at a time.
19    /// X API refresh tokens are single-use, so concurrent refreshes
20    /// would invalidate the token used by the second caller.
21    refresh_lock: tokio::sync::Mutex<()>,
22}
23
24impl TokenManager {
25    /// Create a new token manager with the given tokens and client configuration.
26    pub fn new(tokens: Tokens, client_id: String, token_path: std::path::PathBuf) -> Self {
27        Self {
28            tokens: Arc::new(RwLock::new(tokens)),
29            client_id,
30            http_client: reqwest::Client::new(),
31            token_path,
32            refresh_lock: tokio::sync::Mutex::new(()),
33        }
34    }
35
36    /// Get the current access token, refreshing if needed.
37    pub async fn get_access_token(&self) -> Result<String, XApiError> {
38        self.refresh_if_needed().await?;
39        let tokens = self.tokens.read().await;
40        Ok(tokens.access_token.clone())
41    }
42
43    /// Get a shared reference to the tokens lock for direct access.
44    pub fn tokens_lock(&self) -> Arc<RwLock<Tokens>> {
45        self.tokens.clone()
46    }
47
48    /// Refresh the access token if it is within 5 minutes of expiring.
49    ///
50    /// Acquires `refresh_lock` to prevent concurrent refresh attempts.
51    /// X API refresh tokens are single-use, so a second concurrent refresh
52    /// with the old token would fail and revoke the session.
53    pub async fn refresh_if_needed(&self) -> Result<(), XApiError> {
54        // Fast path: no refresh needed.
55        {
56            let tokens = self.tokens.read().await;
57            let seconds_until_expiry = tokens
58                .expires_at
59                .signed_duration_since(Utc::now())
60                .num_seconds();
61            if seconds_until_expiry >= REFRESH_WINDOW_SECS {
62                return Ok(());
63            }
64        }
65
66        // Serialize refresh attempts.
67        let _guard = self.refresh_lock.lock().await;
68
69        // Re-check after acquiring the lock — another caller may have
70        // already refreshed while we were waiting.
71        {
72            let tokens = self.tokens.read().await;
73            let seconds_until_expiry = tokens
74                .expires_at
75                .signed_duration_since(Utc::now())
76                .num_seconds();
77            if seconds_until_expiry >= REFRESH_WINDOW_SECS {
78                return Ok(());
79            }
80        }
81
82        self.do_refresh().await
83    }
84
85    /// Perform the token refresh.
86    async fn do_refresh(&self) -> Result<(), XApiError> {
87        let refresh_token = {
88            let tokens = self.tokens.read().await;
89            tokens.refresh_token.clone()
90        };
91
92        tracing::info!("Refreshing X API access token");
93
94        let params = [
95            ("grant_type", "refresh_token"),
96            ("refresh_token", &refresh_token),
97            ("client_id", &self.client_id),
98        ];
99
100        let response = self
101            .http_client
102            .post(TOKEN_URL)
103            .form(&params)
104            .send()
105            .await
106            .map_err(|e| XApiError::Network { source: e })?;
107
108        if !response.status().is_success() {
109            let status = response.status().as_u16();
110            let body = response.text().await.unwrap_or_default();
111            tracing::error!(
112                status,
113                body_len = body.len(),
114                "Token refresh failed (response body redacted)"
115            );
116            return Err(XApiError::AuthExpired);
117        }
118
119        let body: TokenRefreshResponse = response
120            .json()
121            .await
122            .map_err(|e| XApiError::Network { source: e })?;
123
124        let new_tokens = Tokens {
125            access_token: body.access_token,
126            refresh_token: body.refresh_token,
127            expires_at: Utc::now() + chrono::Duration::seconds(body.expires_in),
128            scopes: body
129                .scope
130                .split_whitespace()
131                .map(|s| s.to_string())
132                .collect(),
133        };
134
135        tracing::info!(
136            expires_at = %new_tokens.expires_at,
137            "Token refreshed successfully"
138        );
139
140        // Update in memory
141        {
142            let mut tokens = self.tokens.write().await;
143            *tokens = new_tokens.clone();
144        }
145
146        // Persist to disk
147        save_tokens(&new_tokens, &self.token_path).map_err(|e| {
148            tracing::error!(error = %e, "Failed to save refreshed tokens");
149            XApiError::ApiError {
150                status: 0,
151                message: format!("Failed to save tokens: {e}"),
152            }
153        })?;
154
155        Ok(())
156    }
157}