scoopit_api/
access_token_store.rs

1use std::{
2    convert::TryInto,
3    sync::{Arc, RwLock},
4    time::{Duration, SystemTime, UNIX_EPOCH},
5};
6
7use anyhow::Context;
8use log::{debug, error};
9
10use crate::{
11    oauth::{AccessTokenRequest, AccessTokenResponse},
12    AccessToken, ScoopitAPI,
13};
14
15struct AccessTokenRenewer {
16    scoopit_api: ScoopitAPI,
17    client: reqwest::Client,
18    client_id: String,
19    client_secret: String,
20}
21
22impl AccessTokenRenewer {
23    async fn renew_token(&self, refresh_token: &str) -> anyhow::Result<AccessToken> {
24        let new_access_token = self
25            .client
26            .post(self.scoopit_api.access_token_endpoint.clone())
27            .form(&AccessTokenRequest {
28                client_id: &self.client_id,
29                client_secret: &self.client_secret,
30                grant_type: "refresh_token",
31                refresh_token: Some(refresh_token),
32            })
33            .send()
34            .await?
35            .error_for_status()?
36            .json::<AccessTokenResponse>()
37            .await?;
38
39        debug!("Got new token: {:?}", new_access_token);
40
41        Ok(new_access_token.try_into()?)
42    }
43}
44
45pub async fn authenticate_with_client_credentials(
46    client: &reqwest::Client,
47    scoopit_api: &ScoopitAPI,
48    client_id: &str,
49    client_secret: &str,
50) -> anyhow::Result<AccessToken> {
51    Ok(client
52        .post(scoopit_api.access_token_endpoint.clone())
53        .form(&AccessTokenRequest {
54            client_id: client_id,
55            client_secret: client_secret,
56            grant_type: "client_credentials",
57            refresh_token: None,
58        })
59        .send()
60        .await?
61        .error_for_status()?
62        .json::<AccessTokenResponse>()
63        .await?
64        .try_into()?)
65}
66
67pub struct AccessTokenStore {
68    renewer: Arc<AccessTokenRenewer>,
69    access_token: Arc<RwLock<AccessToken>>,
70}
71
72impl AccessTokenStore {
73    pub fn new(
74        token: AccessToken,
75        scoopit_api: ScoopitAPI,
76        client: reqwest::Client,
77        client_id: String,
78        client_secret: String,
79    ) -> Self {
80        let access_token = Arc::new(RwLock::new(token));
81        let renewer = Arc::new(AccessTokenRenewer {
82            scoopit_api,
83            client,
84            client_id,
85            client_secret,
86        });
87        AccessTokenStore::schedule_renewal(renewer.clone(), access_token.clone());
88        Self {
89            access_token,
90            renewer,
91        }
92    }
93
94    fn schedule_renewal(renewer: Arc<AccessTokenRenewer>, access_token: Arc<RwLock<AccessToken>>) {
95        let renew_date = {
96            let token = access_token.read().unwrap();
97            // schedule renew 5 minutes after token expiry so we will be sure the
98            // access token will get refreshed if it needs it, thus the refresh token will
99            // also be refreshed (refresh token also expires, which forces us to keep the token
100            // alive)
101            token
102                .renew
103                .as_ref()
104                .map(|renew| UNIX_EPOCH + Duration::from_secs(renew.expires_at + 300))
105        };
106        if let Some(renew_date) = renew_date {
107            let wait_time = renew_date.duration_since(SystemTime::now()).ok();
108            tokio::spawn(AccessTokenStore::renew_if_needed_log_error(
109                renewer,
110                access_token,
111                wait_time,
112            ));
113        }
114    }
115
116    async fn renew_if_needed_log_error(
117        renewer: Arc<AccessTokenRenewer>,
118        access_token: Arc<RwLock<AccessToken>>,
119        wait_time: Option<Duration>,
120    ) {
121        debug!("Access token renew scheduled!");
122        if let Some(wait_time) = wait_time {
123            tokio::time::sleep(wait_time).await;
124        }
125        if let Err(e) =
126            AccessTokenStore::renew_token_if_needed(renewer.clone(), access_token.clone()).await
127        {
128            error!("Unable to renew access token! {}", e);
129            tokio::time::sleep(Duration::from_secs(1)).await;
130            AccessTokenStore::schedule_renewal(renewer, access_token);
131        }
132    }
133
134    async fn renew_token_if_needed(
135        renewer: Arc<AccessTokenRenewer>,
136        access_token: Arc<RwLock<AccessToken>>,
137    ) -> anyhow::Result<()> {
138        let refresh_token = {
139            let token = access_token.read().unwrap();
140            match &token.renew {
141                Some(renew) => {
142                    let now_timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?;
143                    if now_timestamp.as_secs() < renew.expires_at {
144                        debug!("Access token: {}, no renew needed!", token.access_token);
145                        // no renew needed
146                        return Ok(());
147                    } else {
148                        debug!("Access token: {}, renew needed!", token.access_token);
149                        renew.refresh_token.clone()
150                    }
151                }
152                // no renew needed
153                None => return Ok(()),
154            }
155        };
156        // renew needed: lock lately to avoid having the lock guard being leaked in the future making
157        // the client not Send
158
159        let new_access_token = renewer.renew_token(&refresh_token).await?;
160
161        {
162            let mut token = access_token.write().unwrap();
163
164            *token = new_access_token;
165        }
166        AccessTokenStore::schedule_renewal(renewer, access_token);
167
168        Ok(())
169    }
170
171    pub async fn get_access_token(&self) -> anyhow::Result<String> {
172        AccessTokenStore::renew_token_if_needed(self.renewer.clone(), self.access_token.clone())
173            .await
174            .context("Cannot renew access token!")?;
175        Ok(self.access_token.read().unwrap().access_token.clone())
176    }
177}