// twitter_v2/authorization/oauth2.rs
use super::Authorization;
use crate::error::{Error, Result};
use async_trait::async_trait;
use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse};
use oauth2::{
    AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
    PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, StandardRevocableToken,
    TokenResponse, TokenUrl,
};
use reqwest::header::HeaderValue;
use reqwest::Request;
use serde::{Deserialize, Serialize};
use std::convert::{TryFrom, TryInto};
use std::future::Future;
use std::sync::Arc;
use strum::{Display, EnumString};
use time::OffsetDateTime;
use tokio::sync::{RwLock, RwLockReadGuard};
use url::Url;

21#[derive(Copy, Clone, Debug, EnumString, Display, Serialize, Deserialize)]
22#[strum(serialize_all = "snake_case")]
23pub enum Scope {
24 #[strum(serialize = "tweet.read")]
25 #[serde(rename = "tweet.read")]
26 TweetRead,
27 #[strum(serialize = "tweet.write")]
28 #[serde(rename = "tweet.write")]
29 TweetWrite,
30 #[strum(serialize = "tweet.moderate.write")]
31 #[serde(rename = "tweet.moderate.write")]
32 TweetModerateWrite,
33 #[strum(serialize = "users.read")]
34 #[serde(rename = "users.read")]
35 UsersRead,
36 #[strum(serialize = "follows.read")]
37 #[serde(rename = "follows.read")]
38 FollowsRead,
39 #[strum(serialize = "follows.write")]
40 #[serde(rename = "follows.write")]
41 FollowsWrite,
42 #[strum(serialize = "offline.access")]
43 #[serde(rename = "offline.access")]
44 OfflineAccess,
45 #[strum(serialize = "space.read")]
46 #[serde(rename = "space.read")]
47 SpaceRead,
48 #[strum(serialize = "mute.read")]
49 #[serde(rename = "mute.read")]
50 MuteRead,
51 #[strum(serialize = "mute.write")]
52 #[serde(rename = "mute.write")]
53 MuteWrite,
54 #[strum(serialize = "like.read")]
55 #[serde(rename = "like.read")]
56 LikeRead,
57 #[strum(serialize = "like.write")]
58 #[serde(rename = "like.write")]
59 LikeWrite,
60 #[strum(serialize = "list.read")]
61 #[serde(rename = "list.read")]
62 ListRead,
63 #[strum(serialize = "list.write")]
64 #[serde(rename = "list.write")]
65 ListWrite,
66 #[strum(serialize = "block.read")]
67 #[serde(rename = "block.read")]
68 BlockRead,
69 #[strum(serialize = "block.write")]
70 #[serde(rename = "block.write")]
71 BlockWrite,
72 #[strum(serialize = "bookmark.read")]
73 #[serde(rename = "bookmark.read")]
74 BookmarkRead,
75 #[strum(serialize = "bookmark.write")]
76 #[serde(rename = "bookmark.write")]
77 BookmarkWrite,
78}
79
80impl From<Scope> for oauth2::Scope {
81 fn from(scope: Scope) -> Self {
82 oauth2::Scope::new(scope.to_string())
83 }
84}
85
86#[derive(Clone, Debug)]
87pub struct Oauth2Client(BasicClient);
88
89impl Oauth2Client {
90 pub fn new(client_id: impl ToString, client_secret: impl ToString, callback_url: Url) -> Self {
91 Self(
92 BasicClient::new(
93 ClientId::new(client_id.to_string()),
94 Some(ClientSecret::new(client_secret.to_string())),
95 AuthUrl::from_url("https://twitter.com/i/oauth2/authorize".parse().unwrap()),
96 Some(TokenUrl::from_url(
97 "https://api.twitter.com/2/oauth2/token".parse().unwrap(),
98 )),
99 )
100 .set_revocation_uri(RevocationUrl::from_url(
101 "https://api.twitter.com/2/oauth2/revoke".parse().unwrap(),
102 ))
103 .set_redirect_uri(RedirectUrl::from_url(callback_url)),
104 )
105 }
106
107 pub fn auth_url(
108 &self,
109 challenge: PkceCodeChallenge,
110 scopes: impl IntoIterator<Item = Scope>,
111 ) -> (Url, CsrfToken) {
112 self.0
113 .authorize_url(CsrfToken::new_random)
114 .set_pkce_challenge(challenge)
115 .add_scopes(scopes.into_iter().map(|s| s.into()))
116 .url()
117 }
118
119 pub async fn request_token(
120 &self,
121 code: AuthorizationCode,
122 verifier: PkceCodeVerifier,
123 ) -> Result<Oauth2Token> {
124 let res = self
125 .0
126 .exchange_code(code)
127 .set_pkce_verifier(verifier)
128 .request_async(oauth2::reqwest::async_http_client)
129 .await?;
130 res.try_into()
131 }
132
133 pub async fn revoke_token(&self, token: StandardRevocableToken) -> Result<()> {
134 Ok(self
135 .0
136 .revoke_token(token)
137 .unwrap()
138 .request_async(oauth2::reqwest::async_http_client)
139 .await?)
140 }
141
142 pub async fn refresh_token(&self, token: &RefreshToken) -> Result<Oauth2Token> {
143 self.0
144 .exchange_refresh_token(token)
145 .request_async(oauth2::reqwest::async_http_client)
146 .await?
147 .try_into()
148 }
149
150 pub async fn refresh_token_if_expired(&self, token: &mut Oauth2Token) -> Result<bool> {
151 if token.is_expired() {
152 if let Some(refresh_token) = token.refresh_token() {
153 *token = self.refresh_token(refresh_token).await?;
154 Ok(true)
155 } else {
156 Err(Error::NoRefreshToken)
157 }
158 } else {
159 Ok(false)
160 }
161 }
162}
163
164#[derive(Clone, Debug, Serialize, Deserialize)]
165pub struct Oauth2Token {
166 access_token: AccessToken,
167 refresh_token: Option<RefreshToken>,
168 #[serde(with = "time::serde::rfc3339")]
169 expires: OffsetDateTime,
170 scopes: Vec<Scope>,
171}
172
173impl Oauth2Token {
174 pub fn access_token(&self) -> &AccessToken {
175 &self.access_token
176 }
177 pub fn refresh_token(&self) -> Option<&RefreshToken> {
178 self.refresh_token.as_ref()
179 }
180 pub fn expires(&self) -> OffsetDateTime {
181 self.expires
182 }
183 pub fn is_expired(&self) -> bool {
184 self.expires < OffsetDateTime::now_utc()
185 }
186 pub fn scopes(&self) -> &[Scope] {
187 &self.scopes
188 }
189 pub fn revokable_token(&self) -> StandardRevocableToken {
190 if let Some(refresh_token) = self.refresh_token.as_ref() {
191 StandardRevocableToken::RefreshToken(refresh_token.clone())
192 } else {
193 StandardRevocableToken::AccessToken(self.access_token.clone())
194 }
195 }
196}
197
198impl TryFrom<BasicTokenResponse> for Oauth2Token {
199 type Error = Error;
200 fn try_from(token: BasicTokenResponse) -> Result<Self, Self::Error> {
201 Ok(Self {
202 access_token: token.access_token().clone(),
203 refresh_token: token.refresh_token().cloned(),
204 expires: OffsetDateTime::now_utc()
205 + token.expires_in().ok_or_else(|| {
206 Error::Oauth2TokenError(BasicRequestTokenError::Other(
207 "Missing expiration".to_string(),
208 ))
209 })?,
210 scopes: token
211 .scopes()
212 .ok_or_else(|| {
213 Error::Oauth2TokenError(BasicRequestTokenError::Other(
214 "Missing scopes".to_string(),
215 ))
216 })?
217 .iter()
218 .map(|s| {
219 s.parse().map_err(|err| {
220 Error::Oauth2TokenError(BasicRequestTokenError::Other(format!(
221 "Invalid scope: {}",
222 err
223 )))
224 })
225 })
226 .collect::<Result<Vec<_>>>()?,
227 })
228 }
229}
230
231#[async_trait]
232impl Authorization for Oauth2Token {
233 async fn header(&self, _request: &Request) -> Result<HeaderValue> {
234 format!("Bearer {}", self.access_token().secret())
235 .parse()
236 .map_err(Error::InvalidAuthorizationHeader)
237 }
238}
239
240fn no_op(_: Oauth2Token) -> futures::future::Ready<Result<()>> {
241 futures::future::ok(())
242}
243pub type NoCallback = fn(Oauth2Token) -> futures::future::Ready<Result<()>>;
244
245#[derive(Clone, Debug)]
246pub struct RefreshableOauth2Token<C> {
247 oauth_client: Oauth2Client,
248 token: Arc<RwLock<Oauth2Token>>,
249 callback: C,
250}
251
252impl RefreshableOauth2Token<NoCallback> {
253 pub fn new(oauth_client: Oauth2Client, token: Oauth2Token) -> Self {
254 Self {
255 oauth_client,
256 token: Arc::new(RwLock::new(token)),
257 callback: no_op,
258 }
259 }
260}
261
262impl<C> RefreshableOauth2Token<C> {
263 pub fn with_callback<T>(&self, callback: T) -> RefreshableOauth2Token<T> {
264 RefreshableOauth2Token {
265 oauth_client: self.oauth_client.clone(),
266 token: self.token.clone(),
267 callback,
268 }
269 }
270 pub async fn token(&self) -> RwLockReadGuard<'_, Oauth2Token> {
271 self.token.read().await
272 }
273
274 pub async fn revoke(&self) -> Result<()> {
275 self.oauth_client
276 .revoke_token(self.token.read().await.revokable_token())
277 .await
278 }
279}
280
281impl<C, F> RefreshableOauth2Token<C>
282where
283 C: Fn(Oauth2Token) -> F + Send + Sync,
284 F: Future<Output = Result<()>>,
285{
286 pub async fn refresh(&self) -> Result<()> {
287 let mut token = self.token.write().await;
288 *token = self
289 .oauth_client
290 .refresh_token(token.refresh_token.as_ref().ok_or(Error::NoRefreshToken)?)
291 .await?;
292 (self.callback)(token.clone()).await?;
293 Ok(())
294 }
295}
296
297#[async_trait]
298impl<C, F> Authorization for RefreshableOauth2Token<C>
299where
300 C: Fn(Oauth2Token) -> F + Send + Sync,
301 F: Future<Output = Result<()>> + Send,
302{
303 async fn header(&self, request: &Request) -> Result<HeaderValue> {
304 let mut token = self.token.write().await;
305 if self
306 .oauth_client
307 .refresh_token_if_expired(&mut token)
308 .await?
309 {
310 (self.callback)(token.clone()).await?;
311 }
312 token.header(request).await
313 }
314}