Skip to main content

twapi_oauth2/
oauth2.rs

1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13    pub access_token: String,
14    pub refresh_token: String,
15    pub expires_in: u64,
16    pub scope: String,
17    pub token_type: String,
18}
19
/// Value of the `response_type` query parameter in an authorization request.
enum ResponseType {
    Code,
    // Not constructed anywhere in this module; kept so the full set of values is representable.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the wire token (`code` / `token`) without applying any width/fill options.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            Self::Code => "code",
            Self::Token => "token",
        };
        f.write_str(wire)
    }
}
34
/// Value of the `code_challenge_method` query parameter (PKCE transform).
enum CodeChallengeMethod {
    S256,
    // Only referenced from the test module; kept so the `plain` method is representable.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the wire token (`S256` / `plain`) without applying any width/fill options.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        })
    }
}
49
50pub(crate) struct PkceS256 {
51    pub code_challenge: String,
52    pub code_verifier: String,
53}
54
55impl PkceS256 {
56    pub fn new() -> Self {
57        let size = 32;
58        let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59        let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60        let code_challenge = {
61            let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62            BASE64_URL_SAFE_NO_PAD.encode(hash)
63        };
64        Self {
65            code_challenge,
66            code_verifier,
67        }
68    }
69}
70
71impl Default for PkceS256 {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79    url: &str,
80    response_type: ResponseType,
81    client_id: &str,
82    redirect_uri: &str,
83    scopes: &str,
84    state: &str,
85    code_challenge: &str,
86    code_challenge_method: CodeChallengeMethod,
87) -> String {
88    let qs = QueryString::dynamic()
89        .with_value("response_type", response_type.to_string())
90        .with_value("client_id", client_id)
91        .with_value("redirect_uri", redirect_uri)
92        .with_value("scope", scopes)
93        .with_value("state", state)
94        .with_value("code_challenge", code_challenge)
95        .with_value("code_challenge_method", code_challenge_method.to_string());
96    format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101    url: &str,
102    client_id: &str,
103    client_secret: &str,
104    redirect_uri: &str,
105    code: &str,
106    code_verifier: &str,
107    grant_type: &str,
108    timeout: Duration,
109    try_count: usize,
110    retry_millis: u64,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112    let params = [
113        ("grant_type", grant_type),
114        ("code", code),
115        ("redirect_uri", redirect_uri),
116        ("client_id", client_id),
117        ("code_verifier", code_verifier),
118    ];
119
120    let client = reqwest::Client::new();
121
122    execute_retry(
123        || {
124            client
125                .post(url)
126                .form(&params)
127                .basic_auth(client_id, Some(client_secret))
128                .timeout(timeout)
129        },
130        try_count,
131        retry_millis,
132    )
133    .await
134}
135
/// OAuth 2.0 scopes that can be requested from X.
///
/// The [`Display`](std::fmt::Display) form of each variant (also available without
/// allocation via [`XScope::as_str`]) is the exact token placed in the `scope` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Returns every scope variant, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Returns the scope token for this variant, e.g. `"tweet.read"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }

    /// Joins scope tokens with single spaces, as used for the `scope` query parameter.
    ///
    /// Returns an empty string for an empty slice.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        // Borrow the static tokens instead of allocating a String per scope.
        scopes
            .iter()
            .map(XScope::as_str)
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

impl std::fmt::Display for XScope {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
226
/// X OAuth 2.0 authorization endpoint; the base of the URL built by `XClient::authorize_url`.
pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";
/// X OAuth 2.0 token endpoint; `XClient::token` exchanges authorization codes here.
pub const X_TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";
229
230pub struct XClient {
231    client_id: String,
232    client_secret: String,
233    redirect_uri: String,
234    scopes: Vec<XScope>,
235    try_count: usize,
236    retry_millis: u64,
237    timeout: Duration,
238}
239
240impl XClient {
241    pub fn new(
242        client_id: &str,
243        client_secret: &str,
244        redirect_uri: &str,
245        scopes: Vec<XScope>,
246    ) -> Self {
247        Self::new_with_token_options(
248            client_id,
249            client_secret,
250            redirect_uri,
251            scopes,
252            3,
253            500,
254            Duration::from_secs(10),
255        )
256    }
257
258    pub fn new_with_token_options(
259        client_id: &str,
260        client_secret: &str,
261        redirect_uri: &str,
262        scopes: Vec<XScope>,
263        try_count: usize,
264        retry_millis: u64,
265        timeout: Duration,
266    ) -> Self {
267        Self {
268            client_id: client_id.to_string(),
269            client_secret: client_secret.to_string(),
270            redirect_uri: redirect_uri.to_string(),
271            scopes,
272            try_count,
273            retry_millis,
274            timeout,
275        }
276    }
277
278    pub fn authorize_url(&self, state: &str) -> (String, String) {
279        let pkce = PkceS256::new();
280
281        let scopes_str = XScope::scopes_to_string(&self.scopes);
282        (
283            authorize_url(
284                X_AUTHORIZE_URL,
285                ResponseType::Code,
286                &self.client_id,
287                &self.redirect_uri,
288                &scopes_str,
289                state,
290                &pkce.code_challenge,
291                CodeChallengeMethod::S256,
292            ),
293            pkce.code_verifier,
294        )
295    }
296
297    pub async fn token(
298        &self,
299        code: &str,
300        code_verifier: &str,
301    ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
302        let (token_json, status_code, headers) = token(
303            X_TOKEN_URL,
304            &self.client_id,
305            &self.client_secret,
306            &self.redirect_uri,
307            code,
308            code_verifier,
309            "authorization_code",
310            self.timeout,
311            self.try_count,
312            self.retry_millis,
313        )
314        .await?;
315        Ok((token_json, status_code, headers))
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    // Prints a real authorize URL for a manual end-to-end check. It needs app credentials,
    // so it is ignored by default. Run it with:
    // CLIENT_ID=xxx CLIENT_SECRET=xxx REDIRECT_URL=http://localhost:8000/callback cargo test test_x_authorize -- --ignored --nocapture
    #[test]
    #[ignore = "requires CLIENT_ID, CLIENT_SECRET and REDIRECT_URL environment variables"]
    fn test_x_authorize() {
        let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID must be set");
        let client_secret = std::env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set");
        let redirect_url = std::env::var("REDIRECT_URL").expect("REDIRECT_URL must be set");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);
        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
    }

    // Runs offline. CLIENT_ID / REDIRECT_URL may optionally be set; placeholders are used
    // otherwise.
    #[test]
    fn test_authorize() {
        let client_id =
            std::env::var("CLIENT_ID").unwrap_or_else(|_| "test_client_id".to_string());
        let redirect_url = std::env::var("REDIRECT_URL")
            .unwrap_or_else(|_| "http://localhost:8000/callback".to_string());
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);
        assert!(res.starts_with(&format!("{X_AUTHORIZE_URL}?")), "{res}");
        assert!(res.contains("response_type=code"), "{res}");
        assert!(res.contains("state=test_state"), "{res}");
        assert!(res.contains("code_challenge=test_code_challenge"), "{res}");
        assert!(res.contains("code_challenge_method=plain"), "{res}");
    }

    // The S256 challenge must be the unpadded base64url SHA-256 of the verifier string.
    #[test]
    fn test_pkce_s256() {
        let pkce = PkceS256::new();
        // 32 random bytes -> 43 unpadded base64url characters.
        assert_eq!(pkce.code_verifier.len(), 43);
        let expected =
            BASE64_URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(pkce.code_verifier.as_bytes()));
        assert_eq!(pkce.code_challenge, expected);
    }
}