use oauth2::{
    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
    EndpointNotSet, EndpointSet, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use std::time::Duration;
use thiserror::Error;

use crate::api::TwapiOptions;
9
/// An OAuth 2.0 scope that can be requested from X (Twitter).
///
/// The `Display` implementation renders each variant as its scope string
/// (e.g. `TweetRead` -> `tweet.read`); `TwitterOauth::new` uses that string to
/// build the `oauth2::Scope` values sent in the authorization request.
///
/// The enum is a plain field-less tag, so it is cheap to copy, compare and hash
/// (e.g. to de-duplicate a requested scope set).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TwitterScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}
33
34impl TwitterScope {
35 pub fn all() -> Vec<Self> {
36 vec![
37 Self::TweetRead,
38 Self::TweetWrite,
39 Self::TweetModerateWrite,
40 Self::UsersRead,
41 Self::FollowsRead,
42 Self::FollowsWrite,
43 Self::OfflineAccess,
44 Self::SpaceRead,
45 Self::MuteRead,
46 Self::MuteWrite,
47 Self::LikeRead,
48 Self::LikeWrite,
49 Self::ListRead,
50 Self::ListWrite,
51 Self::BlockRead,
52 Self::BlockWrite,
53 Self::BookmarkRead,
54 Self::BookmarkWrite,
55 Self::DmRead,
56 Self::DmWrite,
57 Self::MediaWrite,
58 ]
59 }
60}
61
62impl std::fmt::Display for TwitterScope {
63 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64 match self {
65 Self::TweetRead => write!(f, "tweet.read"),
66 Self::TweetWrite => write!(f, "tweet.write"),
67 Self::TweetModerateWrite => write!(f, "tweet.moderate.write"),
68 Self::UsersRead => write!(f, "users.read"),
69 Self::FollowsRead => write!(f, "follows.read"),
70 Self::FollowsWrite => write!(f, "follows.write"),
71 Self::OfflineAccess => write!(f, "offline.access"),
72 Self::SpaceRead => write!(f, "space.read"),
73 Self::MuteRead => write!(f, "mute.read"),
74 Self::MuteWrite => write!(f, "mute.write"),
75 Self::LikeRead => write!(f, "like.read"),
76 Self::LikeWrite => write!(f, "like.write"),
77 Self::ListRead => write!(f, "list.read"),
78 Self::ListWrite => write!(f, "list.write"),
79 Self::BlockRead => write!(f, "block.read"),
80 Self::BlockWrite => write!(f, "block.write"),
81 Self::BookmarkRead => write!(f, "bookmark.read"),
82 Self::BookmarkWrite => write!(f, "bookmark.write"),
83 Self::DmRead => write!(f, "dm.read"),
84 Self::DmWrite => write!(f, "dm.write"),
85 Self::MediaWrite => write!(f, "media.write"),
86 }
87 }
88}
89
/// X (Twitter) OAuth 2.0 authorization endpoint; used as the `AuthUrl` that
/// authorization URLs are built from.
const AUTH_URL: &str = "https://x.com/i/oauth2/authorize";
/// X (Twitter) OAuth 2.0 token endpoint; used as the `TokenUrl` that
/// authorization codes are exchanged against.
const TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";
92
93#[derive(Error, Debug)]
94pub enum OAuthError {
95 #[error("Url {0}")]
96 Url(#[from] oauth2::url::ParseError),
97
98 #[error("Reqwest {0}")]
99 Reqwest(#[from] reqwest::Error),
100
101 #[error("Token {0}")]
102 Token(String),
103}
104
/// Result of generating an authorization URL.
///
/// Note: the derived `Debug` prints `pkce_verifier` in plain text, so avoid
/// logging this value.
#[derive(Clone, Debug)]
pub struct OAuthUrlResult {
    /// Authorization URL to send the user to.
    pub oauth_url: String,
    /// PKCE code verifier matching the challenge embedded in `oauth_url`.
    /// Keep it until the callback and pass it to `TwitterOauth::token`
    /// together with the returned authorization code.
    pub pkce_verifier: String,
}
110
/// Tokens obtained from a successful authorization-code exchange.
///
/// Note: the derived `Debug` prints both tokens in plain text.
#[derive(Clone, Debug)]
pub struct TokenResult {
    /// Access token string from the token response.
    pub access_token: String,
    /// Refresh token, if the token response contained one.
    pub refresh_token: Option<String>,
    /// Access-token lifetime, if the token response reported one.
    pub expires_in: Option<Duration>,
}
117
118pub struct TwitterOauth {
119 client_id: ClientId,
120 client_secret: ClientSecret,
121 auth_url: AuthUrl,
122 token_url: TokenUrl,
123 redirect_url: RedirectUrl,
124 scopes: Vec<Scope>,
125}
126
127impl TwitterOauth {
128 pub fn new(
129 api_key_code: &str,
130 api_secret_code: &str,
131 callback_url: &str,
132 scopes: Vec<TwitterScope>,
133 ) -> Result<Self, OAuthError> {
134 let redirect_url = RedirectUrl::new(callback_url.to_string())?;
135 let scopes: Vec<Scope> = scopes
136 .into_iter()
137 .map(|it| Scope::new(it.to_string()))
138 .collect();
139 Ok(Self {
140 client_id: ClientId::new(api_key_code.to_owned()),
141 client_secret: ClientSecret::new(api_secret_code.to_owned()),
142 auth_url: AuthUrl::new(AUTH_URL.to_owned())?,
143 token_url: TokenUrl::new(TOKEN_URL.to_owned())?,
144 redirect_url,
145 scopes,
146 })
147 }
148
149 pub fn oauth_url(&self) -> OAuthUrlResult {
150 self.oauth_url_with_state(None)
151 }
152
153 pub fn oauth_url_with_state(&self, state: Option<String>) -> OAuthUrlResult {
154 let (pkce_challenge, pkce_verifier) = oauth2::PkceCodeChallenge::new_random_sha256();
155 let csrf_token = match state {
156 Some(ref state_value) => CsrfToken::new(state_value.clone()),
157 None => CsrfToken::new_random(),
158 };
159
160 let client = BasicClient::new(self.client_id.clone())
161 .set_client_secret(self.client_secret.clone())
162 .set_auth_uri(self.auth_url.clone())
163 .set_token_uri(self.token_url.clone());
164
165 let (auth_url, _csrf_token) = client
166 .clone()
167 .set_redirect_uri(self.redirect_url.clone())
168 .authorize_url(|| csrf_token)
169 .add_scopes(self.scopes.clone())
170 .set_pkce_challenge(pkce_challenge)
171 .url();
172
173 OAuthUrlResult {
174 oauth_url: auth_url.to_string(),
175 pkce_verifier: pkce_verifier.secret().to_string(),
176 }
177 }
178
179 pub async fn token(
180 &self,
181 pkce_verifier_str: &str,
182 code: &str,
183 twapi_options: Option<&TwapiOptions>,
184 ) -> Result<TokenResult, OAuthError> {
185 let pkce_verifier = oauth2::PkceCodeVerifier::new(pkce_verifier_str.to_owned());
186
187 let client = BasicClient::new(self.client_id.clone())
188 .set_client_secret(self.client_secret.clone())
189 .set_auth_uri(self.auth_url.clone())
190 .set_token_uri(self.token_url.clone());
191
192 let mut client_builder =
193 reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none());
194
195 if let Some(twapi_options) = twapi_options {
196 if let Some(timeout) = twapi_options.timeout {
197 client_builder = client_builder.timeout(timeout);
198 }
199 }
200
201 let http_client = client_builder.build()?;
202
203 let token = client
204 .clone()
205 .set_redirect_uri(self.redirect_url.clone())
206 .exchange_code(AuthorizationCode::new(code.to_owned()))
207 .set_pkce_verifier(pkce_verifier)
208 .request_async(&http_client)
209 .await
210 .map_err(|e| OAuthError::Token(format!("{:?}", e)))?;
211 Ok(TokenResult {
212 access_token: token.access_token().secret().to_string(),
213 refresh_token: token.refresh_token().map(|it| it.secret().to_string()),
214 expires_in: token.expires_in(),
215 })
216 }
217}