// twapi_v2/oauth.rs

1use oauth2::{
2    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl,
3    Scope, TokenResponse, TokenUrl,
4};
5use std::time::Duration;
6use thiserror::Error;
7
8use crate::api::TwapiOptions;
9
/// OAuth 2.0 scopes that can be requested from the X (Twitter) API.
///
/// Each variant renders (via [`TwitterScope::as_str`] or `Display`) as the scope
/// string placed in the authorization URL, e.g. `TweetRead` -> `"tweet.read"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TwitterScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl TwitterScope {
    /// Returns every scope defined by this enum, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Returns the scope string sent to the authorization endpoint
    /// (e.g. `"tweet.read"`), without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }
}

impl std::fmt::Display for TwitterScope {
    /// Writes the scope string returned by [`TwitterScope::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
89
/// Authorization endpoint: the URL the user is redirected to in order to grant access.
const AUTH_URL: &str = "https://x.com/i/oauth2/authorize";
/// Token endpoint: where an authorization code is exchanged for tokens.
const TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";
92
93#[derive(Error, Debug)]
94pub enum OAuthError {
95    #[error("Url {0}")]
96    Url(#[from] oauth2::url::ParseError),
97
98    #[error("Reqwest {0}")]
99    Reqwest(#[from] reqwest::Error),
100
101    #[error("Token {0}")]
102    Token(String),
103}
104
/// Result of building an authorization URL with [`TwitterOauth::oauth_url`] or
/// [`TwitterOauth::oauth_url_with_state`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthUrlResult {
    /// Authorization URL the user should be sent to.
    pub oauth_url: String,
    /// PKCE code verifier matching the challenge embedded in `oauth_url`.
    /// Pass it to [`TwitterOauth::token`] together with the returned authorization code.
    pub pkce_verifier: String,
}
110
/// Tokens obtained from [`TwitterOauth::token`].
///
/// Note: the derived `Debug` implementation prints the token values as-is, so avoid
/// logging this value with `{:?}` in production.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenResult {
    /// Access token returned by the token endpoint.
    pub access_token: String,
    /// Refresh token, if the token endpoint returned one.
    pub refresh_token: Option<String>,
    /// Lifetime of the access token, if the token endpoint reported one.
    pub expires_in: Option<Duration>,
}
117
118pub struct TwitterOauth {
119    client_id: ClientId,
120    client_secret: ClientSecret,
121    auth_url: AuthUrl,
122    token_url: TokenUrl,
123    redirect_url: RedirectUrl,
124    scopes: Vec<Scope>,
125}
126
127impl TwitterOauth {
128    pub fn new(
129        api_key_code: &str,
130        api_secret_code: &str,
131        callback_url: &str,
132        scopes: Vec<TwitterScope>,
133    ) -> Result<Self, OAuthError> {
134        let redirect_url = RedirectUrl::new(callback_url.to_string())?;
135        let scopes: Vec<Scope> = scopes
136            .into_iter()
137            .map(|it| Scope::new(it.to_string()))
138            .collect();
139        Ok(Self {
140            client_id: ClientId::new(api_key_code.to_owned()),
141            client_secret: ClientSecret::new(api_secret_code.to_owned()),
142            auth_url: AuthUrl::new(AUTH_URL.to_owned())?,
143            token_url: TokenUrl::new(TOKEN_URL.to_owned())?,
144            redirect_url,
145            scopes,
146        })
147    }
148
149    pub fn oauth_url(&self) -> OAuthUrlResult {
150        self.oauth_url_with_state(None)
151    }
152
153    pub fn oauth_url_with_state(&self, state: Option<String>) -> OAuthUrlResult {
154        let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
155        let csrf_token = match state {
156            Some(ref state_value) => CsrfToken::new(state_value.clone()),
157            None => CsrfToken::new_random(),
158        };
159
160        let client = BasicClient::new(self.client_id.clone())
161            .set_client_secret(self.client_secret.clone())
162            .set_auth_uri(self.auth_url.clone())
163            .set_token_uri(self.token_url.clone());
164
165        let (auth_url, _csrf_token) = client
166            .clone()
167            .set_redirect_uri(self.redirect_url.clone())
168            .authorize_url(|| csrf_token)
169            .add_scopes(self.scopes.clone())
170            .set_pkce_challenge(pkce_challenge)
171            .url();
172
173        OAuthUrlResult {
174            oauth_url: auth_url.to_string(),
175            pkce_verifier: pkce_verifier.secret().to_string(),
176        }
177    }
178
179    pub async fn token(
180        &self,
181        pkce_verifier_str: &str,
182        code: &str,
183        twapi_options: Option<&TwapiOptions>,
184    ) -> Result<TokenResult, OAuthError> {
185        let pkce_verifier = oauth2::PkceCodeVerifier::new(pkce_verifier_str.to_owned());
186
187        let client = BasicClient::new(self.client_id.clone())
188            .set_client_secret(self.client_secret.clone())
189            .set_auth_uri(self.auth_url.clone())
190            .set_token_uri(self.token_url.clone());
191
192        let mut client_builder =
193            reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none());
194
195        if let Some(twapi_options) = twapi_options {
196            if let Some(timeout) = twapi_options.timeout {
197                client_builder = client_builder.timeout(timeout);
198            }
199        }
200
201        let http_client = client_builder.build()?;
202
203        let token = client
204            .clone()
205            .set_redirect_uri(self.redirect_url.clone())
206            .exchange_code(AuthorizationCode::new(code.to_owned()))
207            .set_pkce_verifier(pkce_verifier)
208            .request_async(&http_client)
209            .await
210            .map_err(|e| OAuthError::Token(format!("{:?}", e)))?;
211        Ok(TokenResult {
212            access_token: token.access_token().secret().to_string(),
213            refresh_token: token.refresh_token().map(|it| it.secret().to_string()),
214            expires_in: token.expires_in(),
215        })
216    }
217}