twitter_v2/authorization/
oauth2.rs

1use super::Authorization;
2use crate::error::{Error, Result};
3use async_trait::async_trait;
4use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse};
5use oauth2::{
6    AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
7    PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, StandardRevocableToken,
8    TokenResponse, TokenUrl,
9};
10use reqwest::header::HeaderValue;
11use reqwest::Request;
12use serde::{Deserialize, Serialize};
13use std::convert::{TryFrom, TryInto};
14use std::future::Future;
15use std::sync::Arc;
16use strum::{Display, EnumString};
17use time::OffsetDateTime;
18use tokio::sync::{RwLock, RwLockReadGuard};
19use url::Url;
20
21#[derive(Copy, Clone, Debug, EnumString, Display, Serialize, Deserialize)]
22#[strum(serialize_all = "snake_case")]
23pub enum Scope {
24    #[strum(serialize = "tweet.read")]
25    #[serde(rename = "tweet.read")]
26    TweetRead,
27    #[strum(serialize = "tweet.write")]
28    #[serde(rename = "tweet.write")]
29    TweetWrite,
30    #[strum(serialize = "tweet.moderate.write")]
31    #[serde(rename = "tweet.moderate.write")]
32    TweetModerateWrite,
33    #[strum(serialize = "users.read")]
34    #[serde(rename = "users.read")]
35    UsersRead,
36    #[strum(serialize = "follows.read")]
37    #[serde(rename = "follows.read")]
38    FollowsRead,
39    #[strum(serialize = "follows.write")]
40    #[serde(rename = "follows.write")]
41    FollowsWrite,
42    #[strum(serialize = "offline.access")]
43    #[serde(rename = "offline.access")]
44    OfflineAccess,
45    #[strum(serialize = "space.read")]
46    #[serde(rename = "space.read")]
47    SpaceRead,
48    #[strum(serialize = "mute.read")]
49    #[serde(rename = "mute.read")]
50    MuteRead,
51    #[strum(serialize = "mute.write")]
52    #[serde(rename = "mute.write")]
53    MuteWrite,
54    #[strum(serialize = "like.read")]
55    #[serde(rename = "like.read")]
56    LikeRead,
57    #[strum(serialize = "like.write")]
58    #[serde(rename = "like.write")]
59    LikeWrite,
60    #[strum(serialize = "list.read")]
61    #[serde(rename = "list.read")]
62    ListRead,
63    #[strum(serialize = "list.write")]
64    #[serde(rename = "list.write")]
65    ListWrite,
66    #[strum(serialize = "block.read")]
67    #[serde(rename = "block.read")]
68    BlockRead,
69    #[strum(serialize = "block.write")]
70    #[serde(rename = "block.write")]
71    BlockWrite,
72    #[strum(serialize = "bookmark.read")]
73    #[serde(rename = "bookmark.read")]
74    BookmarkRead,
75    #[strum(serialize = "bookmark.write")]
76    #[serde(rename = "bookmark.write")]
77    BookmarkWrite,
78}
79
80impl From<Scope> for oauth2::Scope {
81    fn from(scope: Scope) -> Self {
82        oauth2::Scope::new(scope.to_string())
83    }
84}
85
86#[derive(Clone, Debug)]
87pub struct Oauth2Client(BasicClient);
88
89impl Oauth2Client {
90    pub fn new(client_id: impl ToString, client_secret: impl ToString, callback_url: Url) -> Self {
91        Self(
92            BasicClient::new(
93                ClientId::new(client_id.to_string()),
94                Some(ClientSecret::new(client_secret.to_string())),
95                AuthUrl::from_url("https://twitter.com/i/oauth2/authorize".parse().unwrap()),
96                Some(TokenUrl::from_url(
97                    "https://api.twitter.com/2/oauth2/token".parse().unwrap(),
98                )),
99            )
100            .set_revocation_uri(RevocationUrl::from_url(
101                "https://api.twitter.com/2/oauth2/revoke".parse().unwrap(),
102            ))
103            .set_redirect_uri(RedirectUrl::from_url(callback_url)),
104        )
105    }
106
107    pub fn auth_url(
108        &self,
109        challenge: PkceCodeChallenge,
110        scopes: impl IntoIterator<Item = Scope>,
111    ) -> (Url, CsrfToken) {
112        self.0
113            .authorize_url(CsrfToken::new_random)
114            .set_pkce_challenge(challenge)
115            .add_scopes(scopes.into_iter().map(|s| s.into()))
116            .url()
117    }
118
119    pub async fn request_token(
120        &self,
121        code: AuthorizationCode,
122        verifier: PkceCodeVerifier,
123    ) -> Result<Oauth2Token> {
124        let res = self
125            .0
126            .exchange_code(code)
127            .set_pkce_verifier(verifier)
128            .request_async(oauth2::reqwest::async_http_client)
129            .await?;
130        res.try_into()
131    }
132
133    pub async fn revoke_token(&self, token: StandardRevocableToken) -> Result<()> {
134        Ok(self
135            .0
136            .revoke_token(token)
137            .unwrap()
138            .request_async(oauth2::reqwest::async_http_client)
139            .await?)
140    }
141
142    pub async fn refresh_token(&self, token: &RefreshToken) -> Result<Oauth2Token> {
143        self.0
144            .exchange_refresh_token(token)
145            .request_async(oauth2::reqwest::async_http_client)
146            .await?
147            .try_into()
148    }
149
150    pub async fn refresh_token_if_expired(&self, token: &mut Oauth2Token) -> Result<bool> {
151        if token.is_expired() {
152            if let Some(refresh_token) = token.refresh_token() {
153                *token = self.refresh_token(refresh_token).await?;
154                Ok(true)
155            } else {
156                Err(Error::NoRefreshToken)
157            }
158        } else {
159            Ok(false)
160        }
161    }
162}
163
164#[derive(Clone, Debug, Serialize, Deserialize)]
165pub struct Oauth2Token {
166    access_token: AccessToken,
167    refresh_token: Option<RefreshToken>,
168    #[serde(with = "time::serde::rfc3339")]
169    expires: OffsetDateTime,
170    scopes: Vec<Scope>,
171}
172
173impl Oauth2Token {
174    pub fn access_token(&self) -> &AccessToken {
175        &self.access_token
176    }
177    pub fn refresh_token(&self) -> Option<&RefreshToken> {
178        self.refresh_token.as_ref()
179    }
180    pub fn expires(&self) -> OffsetDateTime {
181        self.expires
182    }
183    pub fn is_expired(&self) -> bool {
184        self.expires < OffsetDateTime::now_utc()
185    }
186    pub fn scopes(&self) -> &[Scope] {
187        &self.scopes
188    }
189    pub fn revokable_token(&self) -> StandardRevocableToken {
190        if let Some(refresh_token) = self.refresh_token.as_ref() {
191            StandardRevocableToken::RefreshToken(refresh_token.clone())
192        } else {
193            StandardRevocableToken::AccessToken(self.access_token.clone())
194        }
195    }
196}
197
198impl TryFrom<BasicTokenResponse> for Oauth2Token {
199    type Error = Error;
200    fn try_from(token: BasicTokenResponse) -> Result<Self, Self::Error> {
201        Ok(Self {
202            access_token: token.access_token().clone(),
203            refresh_token: token.refresh_token().cloned(),
204            expires: OffsetDateTime::now_utc()
205                + token.expires_in().ok_or_else(|| {
206                    Error::Oauth2TokenError(BasicRequestTokenError::Other(
207                        "Missing expiration".to_string(),
208                    ))
209                })?,
210            scopes: token
211                .scopes()
212                .ok_or_else(|| {
213                    Error::Oauth2TokenError(BasicRequestTokenError::Other(
214                        "Missing scopes".to_string(),
215                    ))
216                })?
217                .iter()
218                .map(|s| {
219                    s.parse().map_err(|err| {
220                        Error::Oauth2TokenError(BasicRequestTokenError::Other(format!(
221                            "Invalid scope: {}",
222                            err
223                        )))
224                    })
225                })
226                .collect::<Result<Vec<_>>>()?,
227        })
228    }
229}
230
231#[async_trait]
232impl Authorization for Oauth2Token {
233    async fn header(&self, _request: &Request) -> Result<HeaderValue> {
234        format!("Bearer {}", self.access_token().secret())
235            .parse()
236            .map_err(Error::InvalidAuthorizationHeader)
237    }
238}
239
240fn no_op(_: Oauth2Token) -> futures::future::Ready<Result<()>> {
241    futures::future::ok(())
242}
243pub type NoCallback = fn(Oauth2Token) -> futures::future::Ready<Result<()>>;
244
245#[derive(Clone, Debug)]
246pub struct RefreshableOauth2Token<C> {
247    oauth_client: Oauth2Client,
248    token: Arc<RwLock<Oauth2Token>>,
249    callback: C,
250}
251
252impl RefreshableOauth2Token<NoCallback> {
253    pub fn new(oauth_client: Oauth2Client, token: Oauth2Token) -> Self {
254        Self {
255            oauth_client,
256            token: Arc::new(RwLock::new(token)),
257            callback: no_op,
258        }
259    }
260}
261
262impl<C> RefreshableOauth2Token<C> {
263    pub fn with_callback<T>(&self, callback: T) -> RefreshableOauth2Token<T> {
264        RefreshableOauth2Token {
265            oauth_client: self.oauth_client.clone(),
266            token: self.token.clone(),
267            callback,
268        }
269    }
270    pub async fn token(&self) -> RwLockReadGuard<'_, Oauth2Token> {
271        self.token.read().await
272    }
273
274    pub async fn revoke(&self) -> Result<()> {
275        self.oauth_client
276            .revoke_token(self.token.read().await.revokable_token())
277            .await
278    }
279}
280
281impl<C, F> RefreshableOauth2Token<C>
282where
283    C: Fn(Oauth2Token) -> F + Send + Sync,
284    F: Future<Output = Result<()>>,
285{
286    pub async fn refresh(&self) -> Result<()> {
287        let mut token = self.token.write().await;
288        *token = self
289            .oauth_client
290            .refresh_token(token.refresh_token.as_ref().ok_or(Error::NoRefreshToken)?)
291            .await?;
292        (self.callback)(token.clone()).await?;
293        Ok(())
294    }
295}
296
297#[async_trait]
298impl<C, F> Authorization for RefreshableOauth2Token<C>
299where
300    C: Fn(Oauth2Token) -> F + Send + Sync,
301    F: Future<Output = Result<()>> + Send,
302{
303    async fn header(&self, request: &Request) -> Result<HeaderValue> {
304        let mut token = self.token.write().await;
305        if self
306            .oauth_client
307            .refresh_token_if_expired(&mut token)
308            .await?
309        {
310            (self.callback)(token.clone()).await?;
311        }
312        token.header(request).await
313    }
314}