Skip to main content

twapi_oauth2/
oauth2.rs

1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry, make_url};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13    pub access_token: String,
14    pub refresh_token: String,
15    pub expires_in: u64,
16    pub scope: String,
17    pub token_type: String,
18}
19
/// OAuth 2.0 `response_type` values accepted by the authorization endpoint.
enum ResponseType {
    /// Authorization-code flow (`response_type=code`).
    Code,
    /// Implicit flow (`response_type=token`); not constructed in this file,
    /// hence the `unused` allowance.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the wire value placed in the authorize URL's query string.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            ResponseType::Code => "code",
            ResponseType::Token => "token",
        };
        f.write_str(wire)
    }
}
34
/// PKCE `code_challenge_method` values (RFC 7636 §4.3).
enum CodeChallengeMethod {
    /// Challenge is the SHA-256 hash of the verifier, base64url-encoded.
    S256,
    /// Challenge equals the verifier; only used by this file's tests,
    /// hence the `unused` allowance.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the exact (case-sensitive) wire value for the query string.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            CodeChallengeMethod::S256 => "S256",
            CodeChallengeMethod::Plain => "plain",
        };
        f.write_str(wire)
    }
}
49
50pub(crate) struct PkceS256 {
51    pub code_challenge: String,
52    pub code_verifier: String,
53}
54
55impl PkceS256 {
56    pub fn new() -> Self {
57        let size = 32;
58        let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59        let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60        let code_challenge = {
61            let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62            BASE64_URL_SAFE_NO_PAD.encode(hash)
63        };
64        Self {
65            code_challenge,
66            code_verifier,
67        }
68    }
69}
70
71impl Default for PkceS256 {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79    url: &str,
80    response_type: ResponseType,
81    client_id: &str,
82    redirect_uri: &str,
83    scopes: &str,
84    state: &str,
85    code_challenge: &str,
86    code_challenge_method: CodeChallengeMethod,
87) -> String {
88    let qs = QueryString::dynamic()
89        .with_value("response_type", response_type.to_string())
90        .with_value("client_id", client_id)
91        .with_value("redirect_uri", redirect_uri)
92        .with_value("scope", scopes)
93        .with_value("state", state)
94        .with_value("code_challenge", code_challenge)
95        .with_value("code_challenge_method", code_challenge_method.to_string());
96    format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101    url: &str,
102    client_id: &str,
103    client_secret: &str,
104    redirect_uri: &str,
105    code: &str,
106    code_verifier: &str,
107    grant_type: &str,
108    timeout: Duration,
109    try_count: usize,
110    retry_duration: Duration,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112    let params = [
113        ("grant_type", grant_type),
114        ("code", code),
115        ("redirect_uri", redirect_uri),
116        ("client_id", client_id),
117        ("code_verifier", code_verifier),
118    ];
119
120    let client = reqwest::Client::new();
121
122    execute_retry(
123        || {
124            client
125                .post(url)
126                .form(&params)
127                .basic_auth(client_id, Some(client_secret))
128                .timeout(timeout)
129        },
130        try_count,
131        retry_duration,
132    )
133    .await
134}
135
/// OAuth 2.0 scopes that can be requested from the X API.
///
/// `Display` (and [`XScope::as_str`]) render the wire name, e.g.
/// `tweet.read`, that goes into the space-separated `scope` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Every scope variant, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Returns the wire name of this scope as used in the OAuth `scope`
    /// parameter (the same text `Display` produces).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }

    /// Joins `scopes` with single spaces, the OAuth 2.0 `scope` format.
    ///
    /// Order is preserved and duplicates are kept. An empty slice yields an
    /// empty string.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        // Borrow the static names instead of allocating a String per scope.
        scopes
            .iter()
            .map(XScope::as_str)
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

impl std::fmt::Display for XScope {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
226
/// Default origin of the X API. Despite its name this is the base URL, not
/// a suffix. It is passed to `make_url` together with `X_TOKEN_URL_PREFIX`
/// to form the token endpoint.
const URL_POSTFIX: &str = "https://api.x.com";
/// Path of the X OAuth 2.0 token endpoint (relative to the API origin).
pub const X_TOKEN_URL_PREFIX: &str = "/2/oauth2/token";
/// X OAuth 2.0 authorization (user consent) endpoint.
pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";
231
232pub struct XClient {
233    client_id: String,
234    client_secret: String,
235    redirect_uri: String,
236    scopes: Vec<XScope>,
237    try_count: usize,
238    retry_duration: Duration,
239    timeout: Duration,
240    prefix_url: Option<String>,
241}
242
243impl XClient {
244    pub fn new(
245        client_id: &str,
246        client_secret: &str,
247        redirect_uri: &str,
248        scopes: Vec<XScope>,
249    ) -> Self {
250        Self::new_with_token_options(
251            client_id,
252            client_secret,
253            redirect_uri,
254            scopes,
255            3,
256            Duration::from_millis(100),
257            Duration::from_secs(10),
258            None,
259        )
260    }
261
262    #[allow(clippy::too_many_arguments)]
263    pub fn new_with_token_options(
264        client_id: &str,
265        client_secret: &str,
266        redirect_uri: &str,
267        scopes: Vec<XScope>,
268        try_count: usize,
269        retry_duration: Duration,
270        timeout: Duration,
271        prefix_url: Option<String>,
272    ) -> Self {
273        Self {
274            client_id: client_id.to_string(),
275            client_secret: client_secret.to_string(),
276            redirect_uri: redirect_uri.to_string(),
277            scopes,
278            try_count,
279            retry_duration,
280            timeout,
281            prefix_url,
282        }
283    }
284
285    pub fn authorize_url(&self, state: &str) -> (String, String) {
286        let pkce = PkceS256::new();
287
288        let scopes_str = XScope::scopes_to_string(&self.scopes);
289        (
290            authorize_url(
291                X_AUTHORIZE_URL,
292                ResponseType::Code,
293                &self.client_id,
294                &self.redirect_uri,
295                &scopes_str,
296                state,
297                &pkce.code_challenge,
298                CodeChallengeMethod::S256,
299            ),
300            pkce.code_verifier,
301        )
302    }
303
304    pub async fn token(
305        &self,
306        code: &str,
307        code_verifier: &str,
308    ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
309        let (token_json, status_code, headers) = token(
310            &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &self.prefix_url),
311            &self.client_id,
312            &self.client_secret,
313            &self.redirect_uri,
314            code,
315            code_verifier,
316            "authorization_code",
317            self.timeout,
318            self.try_count,
319            self.retry_duration,
320        )
321        .await?;
322        Ok((token_json, status_code, headers))
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads an environment variable, falling back to `default` so the tests
    /// also run (and check URL shape) without real credentials.
    fn env_or(key: &str, default: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| default.to_string())
    }

    // With real credentials this prints a usable authorize URL:
    // CLIENT_ID=xxx CLIENT_SECRET=xxx REDIRECT_URL=http://localhost:8000/callback cargo test test_x_authorize -- --nocapture
    // Without them, placeholder values are used.
    #[test]
    fn test_x_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let client_secret = env_or("CLIENT_SECRET", "test_client_secret");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);

        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
        assert!(auth_url.contains("response_type=code"));
        assert!(auth_url.contains("state=test_state"));
        assert!(auth_url.contains("code_challenge_method=S256"));
        // 32 random bytes, unpadded base64url => 43 characters.
        assert_eq!(code_verifier.len(), 43);
    }

    // REDIRECT_URL=http://localhost:8000/callback CLIENT_ID=xxx cargo test test_authorize -- --nocapture
    #[test]
    fn test_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("code_challenge=test_code_challenge"));
        assert!(res.contains("code_challenge_method=plain"));
        // The spaces separating scopes must be percent-encoded, never raw.
        assert!(!res.contains(' '));
    }

    #[test]
    fn test_pkce_s256_shape() {
        let pkce = PkceS256::new();
        let is_base64url =
            |s: &str| s.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');

        // Both the verifier and the SHA-256 challenge encode 32 bytes.
        assert_eq!(pkce.code_verifier.len(), 43);
        assert_eq!(pkce.code_challenge.len(), 43);
        assert!(is_base64url(&pkce.code_verifier));
        assert!(is_base64url(&pkce.code_challenge));
        // Each call must produce fresh randomness.
        assert_ne!(pkce.code_verifier, PkceS256::new().code_verifier);
    }
}