Skip to main content

tuitbot_core/startup/
services.rs

1//! PKCE generation, OAuth URL building, token exchange, credential verification,
2//! and startup banner formatting.
3
4use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
5use serde::Deserialize;
6use sha2::{Digest, Sha256};
7
8use super::config::{ApiTier, StartupError, StoredTokens, TierCapabilities};
9
10// ============================================================================
11// X API OAuth 2.0 endpoints (re-exported via mod.rs)
12// ============================================================================
13
/// Authorization endpoint of X's OAuth 2.0 flow; the authorization URL built
/// from it carries the client id, redirect URI, scopes and PKCE challenge.
pub const X_AUTH_URL: &str = "https://twitter.com/i/oauth2/authorize";

/// Token endpoint of X's OAuth 2.0 flow, where an authorization code is
/// traded for access/refresh tokens.
pub const X_TOKEN_URL: &str = "https://api.twitter.com/2/oauth2/token";

/// `GET /2/users/me` endpoint, queried with a bearer token to confirm that the
/// token is accepted and to learn the account's username.
pub const X_USERS_ME_URL: &str = "https://api.twitter.com/2/users/me";
22
23// ============================================================================
24// PKCE Authentication
25// ============================================================================
26
/// PKCE material for a single authorization attempt, together with the
/// anti-CSRF `state` value sent alongside it.
///
/// NOTE(review): the derived `Debug` prints `verifier` and `state` verbatim;
/// avoid logging this value with `{:?}`.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// Secret code verifier, kept locally and presented only when the
    /// authorization code is exchanged for tokens.
    pub verifier: String,
    /// Derived challenge (base64url SHA-256 of `verifier` when produced by
    /// `generate_pkce`) that goes into the authorization URL.
    pub challenge: String,
    /// Random CSRF-protection value expected back on the OAuth callback.
    pub state: String,
}
37
38/// Generate a PKCE code verifier, challenge, and state parameter.
39pub fn generate_pkce() -> PkceChallenge {
40    use rand::Rng;
41    let random_bytes: [u8; 32] = rand::rng().random();
42    let verifier = URL_SAFE_NO_PAD.encode(random_bytes);
43    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
44    let state_bytes: [u8; 16] = rand::rng().random();
45    let state = URL_SAFE_NO_PAD.encode(state_bytes);
46    PkceChallenge {
47        verifier,
48        challenge,
49        state,
50    }
51}
52
53/// Percent-encode a string for use in URL query parameters (RFC 3986).
54pub(super) fn url_encode(s: &str) -> String {
55    let mut encoded = String::with_capacity(s.len() * 3);
56    for byte in s.bytes() {
57        match byte {
58            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
59                encoded.push(byte as char);
60            }
61            _ => {
62                use std::fmt::Write;
63                let _ = write!(encoded, "%{byte:02X}");
64            }
65        }
66    }
67    encoded
68}
69
70/// Build the X API OAuth 2.0 authorization URL.
71pub fn build_auth_url(
72    client_id: &str,
73    redirect_uri: &str,
74    state: &str,
75    code_challenge: &str,
76) -> String {
77    use crate::x_api::scopes::REQUIRED_SCOPES;
78    let oauth_scopes = REQUIRED_SCOPES.join(" ");
79    format!(
80        "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256&prompt=consent",
81        X_AUTH_URL,
82        url_encode(client_id),
83        url_encode(redirect_uri),
84        url_encode(&oauth_scopes),
85        url_encode(state),
86        url_encode(code_challenge),
87    )
88}
89
/// Build the OAuth redirect URI from the configured callback host and port.
///
/// Produces `http://{host}:{port}/callback`. When the host is a bare IPv6
/// address (e.g. `::1`), it is wrapped in brackets as RFC 3986 §3.2.2
/// requires, giving `http://[::1]:8080/callback` instead of an unparseable
/// URI. Host names, IPv4 addresses and already-bracketed hosts are used
/// verbatim.
pub fn build_redirect_uri(callback_host: &str, callback_port: u16) -> String {
    if callback_host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("http://[{callback_host}]:{callback_port}/callback")
    } else {
        format!("http://{callback_host}:{callback_port}/callback")
    }
}
94
95/// Exchange an authorization code for OAuth tokens.
96pub async fn exchange_auth_code(
97    client_id: &str,
98    code: &str,
99    redirect_uri: &str,
100    code_verifier: &str,
101) -> Result<StoredTokens, StartupError> {
102    let client = reqwest::Client::new();
103    let resp = client
104        .post(X_TOKEN_URL)
105        .form(&[
106            ("grant_type", "authorization_code"),
107            ("code", code),
108            ("redirect_uri", redirect_uri),
109            ("code_verifier", code_verifier),
110            ("client_id", client_id),
111        ])
112        .send()
113        .await
114        .map_err(|e| StartupError::XApiError(format!("token exchange request failed: {e}")))?;
115
116    if !resp.status().is_success() {
117        let status = resp.status();
118        let body = resp.text().await.unwrap_or_default();
119        return Err(StartupError::XApiError(format!(
120            "token exchange failed (HTTP {status}): {body}"
121        )));
122    }
123
124    #[derive(Deserialize)]
125    struct TokenResponse {
126        access_token: String,
127        #[serde(default)]
128        refresh_token: Option<String>,
129        #[serde(default)]
130        expires_in: Option<i64>,
131        #[serde(default)]
132        scope: Option<String>,
133    }
134
135    let token_resp: TokenResponse = resp
136        .json()
137        .await
138        .map_err(|e| StartupError::XApiError(format!("failed to parse token response: {e}")))?;
139
140    let expires_at = token_resp
141        .expires_in
142        .map(|secs| chrono::Utc::now() + chrono::TimeDelta::seconds(secs));
143    let scopes = token_resp
144        .scope
145        .map(|s| s.split_whitespace().map(String::from).collect())
146        .unwrap_or_default();
147
148    Ok(StoredTokens {
149        access_token: token_resp.access_token,
150        refresh_token: token_resp.refresh_token,
151        expires_at,
152        scopes,
153    })
154}
155
156/// Verify OAuth credentials by calling the X API /2/users/me endpoint.
157///
158/// Returns the authenticated user's username on success.
159pub async fn verify_credentials(access_token: &str) -> Result<String, StartupError> {
160    let client = reqwest::Client::new();
161    let resp = client
162        .get(X_USERS_ME_URL)
163        .bearer_auth(access_token)
164        .send()
165        .await
166        .map_err(|e| {
167            StartupError::XApiError(format!("credential verification request failed: {e}"))
168        })?;
169
170    if !resp.status().is_success() {
171        let status = resp.status();
172        let body = resp.text().await.unwrap_or_default();
173        return Err(StartupError::XApiError(format!(
174            "credential verification failed (HTTP {status}): {body}"
175        )));
176    }
177
178    #[derive(Deserialize)]
179    struct UserResponse {
180        data: UserData,
181    }
182
183    #[derive(Deserialize)]
184    struct UserData {
185        username: String,
186    }
187
188    let user: UserResponse = resp
189        .json()
190        .await
191        .map_err(|e| StartupError::XApiError(format!("failed to parse user response: {e}")))?;
192
193    Ok(user.data.username)
194}
195
/// Extract the authorization code from a callback URL or raw code string.
///
/// Accepts any of:
/// - a full callback URL, e.g. `http://127.0.0.1:8080/callback?code=XXX&state=YYY`;
/// - a bare query string, e.g. `code=XXX&state=YYY`;
/// - an HTTP request line, e.g. `GET /callback?state=YYY&code=XXX HTTP/1.1`;
/// - a bare authorization code, returned as-is after trimming.
///
/// The `code` value is returned verbatim; it is not percent-decoded. If the
/// input mentions `code=` but no `code` parameter can be found, the trimmed
/// input is returned unchanged.
pub fn extract_auth_code(input: &str) -> String {
    let trimmed = input.trim();
    if !trimmed.contains("code=") {
        return trimmed.to_string();
    }

    // Everything after the first '?' is the query. Without a '?', the input is
    // itself treated as a query string, as `extract_callback_state` does.
    let query = trimmed.split_once('?').map_or(trimmed, |(_, q)| q);
    // Drop a trailing " HTTP/1.1" (request line) and any `#fragment`, so the
    // last parameter's value does not absorb them.
    let query = query.split_whitespace().next().unwrap_or(query);
    let query = query.split('#').next().unwrap_or(query);

    query
        .split('&')
        .find_map(|pair| pair.strip_prefix("code="))
        .map_or_else(|| trimmed.to_string(), str::to_string)
}
214
/// Pull the `state` query parameter out of a callback URL, an HTTP request
/// line, or a bare query string.
///
/// The first `state=` pair wins. An empty string is returned when no such
/// pair exists.
pub fn extract_callback_state(input: &str) -> String {
    let query = match input.split('?').nth(1) {
        // A request line such as `GET /callback?... HTTP/1.1` ends with a
        // protocol token; keep only what precedes the first whitespace.
        Some(q) => q.split_whitespace().next().unwrap_or(q),
        None => input.trim(),
    };

    query
        .split('&')
        .find_map(|pair| pair.strip_prefix("state="))
        .map(str::to_owned)
        .unwrap_or_default()
}
232
233// ============================================================================
234// Startup Banner
235// ============================================================================
236
237/// Format the startup banner printed when the agent starts.
238pub fn format_startup_banner(
239    tier: ApiTier,
240    capabilities: &TierCapabilities,
241    status_interval: u64,
242) -> String {
243    let loops = capabilities.enabled_loop_names().join(", ");
244    let status = if status_interval > 0 {
245        format!("every {status_interval}s")
246    } else {
247        "disabled".to_string()
248    };
249    format!(
250        "Tuitbot v{version}\n\
251         Tier: {tier} | Loops: {loops}\n\
252         Status summary: {status}\n\
253         Press Ctrl+C to stop.",
254        version = env!("CARGO_PKG_VERSION"),
255    )
256}