1use async_trait::async_trait;
4use std::convert::Infallible;
5use std::fmt::{Debug, Display};
6
// The three `refreshing-token-*` feature flags below are mutually exclusive: enabling any
// two of them (and therefore also all three) aborts compilation with an explanatory error.
#[cfg(any(
    all(
        feature = "refreshing-token-rustls-native-roots",
        feature = "refreshing-token-native-tls"
    ),
    all(
        feature = "refreshing-token-rustls-webpki-roots",
        feature = "refreshing-token-rustls-native-roots"
    ),
    all(
        feature = "refreshing-token-rustls-webpki-roots",
        feature = "refreshing-token-native-tls"
    ),
))]
compile_error!("`refreshing-token-native-tls`, `refreshing-token-rustls-native-roots` and `refreshing-token-rustls-webpki-roots` feature flags are mutually exclusive, enable at most one of them");
22
23#[cfg(feature = "__refreshing-token")]
24use {
25 chrono::DateTime,
26 chrono::Utc,
27 std::{sync::Arc, time::Duration},
28 thiserror::Error,
29 tokio::sync::Mutex,
30};
31
32#[cfg(feature = "with-serde")]
33use {serde::Deserialize, serde::Serialize};
34
/// A login name paired with an optional OAuth token.
///
/// The `Debug` output shows `[redacted]` in place of the token so that the secret does not
/// leak into logs; the login name is printed as-is.
#[derive(Clone)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct CredentialsPair {
    /// Login name to authenticate as.
    pub login: String,
    /// OAuth token, or `None` to log in without one.
    pub token: Option<String>,
}

// Hand-written instead of derived so the token never appears in `{:?}` output.
impl Debug for CredentialsPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CredentialsPair")
            .field("login", &self.login)
            .field("token", &self.token.as_ref().map(|_| "[redacted]"))
            .finish()
    }
}
46
47#[async_trait]
50pub trait LoginCredentials: Debug + Send + Sync + 'static {
51 type Error: Send + Sync + Debug + Display;
53
54 async fn get_credentials(&self) -> Result<CredentialsPair, Self::Error>;
56}
57
58#[derive(Debug, Clone)]
61#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
62pub struct StaticLoginCredentials {
63 pub credentials: CredentialsPair,
65}
66
67impl StaticLoginCredentials {
68 pub fn new(login: String, token: Option<String>) -> StaticLoginCredentials {
71 StaticLoginCredentials {
72 credentials: CredentialsPair { login, token },
73 }
74 }
75
76 pub fn anonymous() -> StaticLoginCredentials {
78 StaticLoginCredentials::new("justinfan12345".to_owned(), None)
79 }
80}
81
82#[async_trait]
83impl LoginCredentials for StaticLoginCredentials {
84 type Error = Infallible;
85
86 async fn get_credentials(&self) -> Result<CredentialsPair, Infallible> {
87 Ok(self.credentials.clone())
88 }
89}
90
/// A user access token together with the refresh token needed to renew it.
///
/// The `Debug` output shows `[redacted]` in place of both tokens so that secrets do not leak
/// into logs; the timestamps are printed as-is.
#[cfg(feature = "__refreshing-token")]
#[derive(Clone, Serialize, Deserialize)]
pub struct UserAccessToken {
    /// The access token sent when authenticating.
    pub access_token: String,
    /// The token used to obtain a fresh access token.
    pub refresh_token: String,
    /// When this token was created; the starting point of its lifetime.
    pub created_at: DateTime<Utc>,
    /// When this token expires, or `None` if that is not known.
    pub expires_at: Option<DateTime<Utc>>,
}

// Hand-written instead of derived so neither token appears in `{:?}` output.
#[cfg(feature = "__refreshing-token")]
impl Debug for UserAccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserAccessToken")
            .field("access_token", &"[redacted]")
            .field("refresh_token", &"[redacted]")
            .field("created_at", &self.created_at)
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
105
/// Response body of the token endpoint (`https://id.twitch.tv/oauth2/token`), as decoded
/// when a token is refreshed.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize, Serialize)]
pub struct GetAccessTokenResponse {
    /// The newly issued access token.
    pub access_token: String,
    /// The refresh token to use for the next refresh.
    pub refresh_token: String,
    /// Lifetime of `access_token` in seconds, if the response included one.
    pub expires_in: Option<u64>,
}
141
#[cfg(feature = "__refreshing-token")]
impl From<GetAccessTokenResponse> for UserAccessToken {
    /// Turns a token endpoint response into a [`UserAccessToken`] created at the current time.
    ///
    /// `expires_at` is `now + expires_in`. It is `None` in two cases, instead of panicking:
    /// - the response has no `expires_in`;
    /// - `expires_in` is so large that the resulting timestamp cannot be represented.
    fn from(response: GetAccessTokenResponse) -> Self {
        let now = Utc::now();
        // `expires_in` comes from the server. Both the std -> chrono conversion and the
        // timestamp addition can overflow for absurd values, so use the fallible forms.
        let expires_at = response.expires_in.and_then(|secs| {
            chrono::Duration::from_std(Duration::from_secs(secs))
                .ok()
                .and_then(|lifetime| now.checked_add_signed(lifetime))
        });

        UserAccessToken {
            access_token: response.access_token,
            refresh_token: response.refresh_token,
            created_at: now,
            expires_at,
        }
    }
}
156
/// Storage backend holding the [`UserAccessToken`] used by refreshing credentials.
///
/// Implementors must be `Debug`, `Send` and `'static`.
#[cfg(feature = "__refreshing-token")]
#[async_trait]
pub trait TokenStorage: Debug + Send + 'static {
    /// Error produced by [`TokenStorage::load_token`].
    type LoadError: Debug + Display + Send + Sync;
    /// Error produced by [`TokenStorage::update_token`].
    type UpdateError: Debug + Display + Send + Sync;

    /// Loads the currently stored token.
    async fn load_token(&mut self) -> Result<UserAccessToken, Self::LoadError>;

    /// Replaces the stored token with `token` (used after a refresh).
    async fn update_token(&mut self, token: &UserAccessToken) -> Result<(), Self::UpdateError>;
}
173
/// Login credentials backed by a refreshable user access token kept in a [`TokenStorage`].
///
/// Clones share the same token storage and cached login name, because both live behind
/// `Arc`s. For that reason `Clone` does not require `S: Clone`.
/// The `Debug` output redacts the client secret.
#[cfg(feature = "__refreshing-token")]
pub struct RefreshingLoginCredentials<S: TokenStorage> {
    http_client: reqwest::Client,
    user_login: Arc<Mutex<Option<String>>>,
    client_id: String,
    client_secret: String,
    token_storage: Arc<Mutex<S>>,
}

// Hand-written because `#[derive(Clone)]` would add an unnecessary `S: Clone` bound:
// the storage is shared through the `Arc`, never cloned itself.
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> Clone for RefreshingLoginCredentials<S> {
    fn clone(&self) -> Self {
        RefreshingLoginCredentials {
            http_client: self.http_client.clone(),
            user_login: Arc::clone(&self.user_login),
            client_id: self.client_id.clone(),
            client_secret: self.client_secret.clone(),
            token_storage: Arc::clone(&self.token_storage),
        }
    }
}

// Hand-written so the client secret never appears in `{:?}` output.
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> Debug for RefreshingLoginCredentials<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshingLoginCredentials")
            .field("http_client", &self.http_client)
            .field("user_login", &self.user_login)
            .field("client_id", &self.client_id)
            .field("client_secret", &"[redacted]")
            .field("token_storage", &self.token_storage)
            .finish()
    }
}
186
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> RefreshingLoginCredentials<S> {
    /// Creates refreshing credentials without a known login name.
    ///
    /// The login name is then looked up using the access token the first time credentials
    /// are requested.
    pub fn init(
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        Self::init_with_username(None, client_id, client_secret, token_storage)
    }

    /// Creates refreshing credentials, optionally with an already-known login name.
    ///
    /// Passing `None` for `user_login` makes the login name be looked up (and then cached)
    /// on the first credentials request.
    pub fn init_with_username(
        user_login: Option<String>,
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        let shared_login = Arc::new(Mutex::new(user_login));
        let shared_storage = Arc::new(Mutex::new(token_storage));

        RefreshingLoginCredentials {
            http_client: reqwest::Client::new(),
            user_login: shared_login,
            client_id,
            client_secret,
            token_storage: shared_storage,
        }
    }
}
223
/// Error returned when [`RefreshingLoginCredentials`] cannot produce credentials.
#[cfg(feature = "__refreshing-token")]
#[derive(Debug, Error)]
pub enum RefreshingLoginError<S: TokenStorage> {
    /// The token could not be loaded from the [`TokenStorage`].
    #[error("Failed to retrieve token from storage: {0}")]
    LoadError(S::LoadError),
    /// An HTTP request failed, or its response could not be decoded.
    /// This covers both the token refresh and the login-name lookup.
    #[error("Failed to refresh token: {0}")]
    RefreshError(reqwest::Error),
    /// The refreshed token could not be written back to the [`TokenStorage`].
    #[error("Failed to update token in storage: {0}")]
    UpdateError(S::UpdateError),
}
238
/// Fraction of a token's lifetime after which it is refreshed proactively, so that a new
/// token is fetched shortly before the old one expires.
#[cfg(feature = "__refreshing-token")]
const SHOULD_REFRESH_AFTER_FACTOR: f64 = 0.9;

/// Lifetime assumed for tokens whose `expires_at` is unknown.
#[cfg(feature = "__refreshing-token")]
const DEFAULT_TOKEN_LIFETIME: Duration = Duration::from_secs(24 * 60 * 60);

/// Decides whether `token` has used up [`SHOULD_REFRESH_AFTER_FACTOR`] of its lifetime as of
/// `now`.
///
/// Inconsistent timestamps are handled without panicking:
/// - an `expires_at` earlier than `created_at` counts as a zero lifetime, so the token is
///   refreshed;
/// - a `created_at` later than `now` (e.g. after the system clock was set back) counts as
///   zero age.
#[cfg(feature = "__refreshing-token")]
fn token_needs_refresh(token: &UserAccessToken, now: DateTime<Utc>) -> bool {
    // `to_std` fails for negative durations; clamp those to zero instead of unwrapping.
    let lifetime = match token.expires_at {
        Some(expires_at) => (expires_at - token.created_at)
            .to_std()
            .unwrap_or(Duration::from_secs(0)),
        None => DEFAULT_TOKEN_LIFETIME,
    };
    let age = (now - token.created_at)
        .to_std()
        .unwrap_or(Duration::from_secs(0));

    age >= lifetime.mul_f64(SHOULD_REFRESH_AFTER_FACTOR)
}

#[cfg(feature = "__refreshing-token")]
#[async_trait]
impl<S: TokenStorage> LoginCredentials for RefreshingLoginCredentials<S> {
    type Error = RefreshingLoginError<S>;

    /// Loads the stored token and returns it with the login name.
    ///
    /// Before returning:
    /// - if the token is close to expiry, it is refreshed first and the new token is written
    ///   back to storage;
    /// - if the login name is not known yet, it is fetched from the Helix users endpoint and
    ///   cached.
    ///
    /// # Errors
    ///
    /// * [`RefreshingLoginError::LoadError`] if the storage cannot load the token.
    /// * [`RefreshingLoginError::RefreshError`] if an HTTP request fails, answers with an
    ///   error status, or returns a body that cannot be decoded.
    /// * [`RefreshingLoginError::UpdateError`] if the refreshed token cannot be stored.
    ///
    /// # Panics
    ///
    /// Panics if the users endpoint answers successfully but with an empty `data` list.
    async fn get_credentials(&self) -> Result<CredentialsPair, RefreshingLoginError<S>> {
        // Kept locked for the whole call so concurrent callers cannot refresh (and
        // overwrite) the same token at the same time.
        let mut token_storage = self.token_storage.lock().await;

        let mut current_token = token_storage
            .load_token()
            .await
            .map_err(RefreshingLoginError::LoadError)?;

        if token_needs_refresh(&current_token, Utc::now()) {
            let response = self
                .http_client
                .post("https://id.twitch.tv/oauth2/token")
                .query(&[
                    ("grant_type", "refresh_token"),
                    ("refresh_token", &current_token.refresh_token),
                    ("client_id", &self.client_id),
                    ("client_secret", &self.client_secret),
                ])
                .send()
                .await
                // Surface HTTP error statuses directly rather than as an opaque JSON
                // decoding failure of the error body.
                .and_then(reqwest::Response::error_for_status)
                .map_err(RefreshingLoginError::RefreshError)?
                .json::<GetAccessTokenResponse>()
                .await
                .map_err(RefreshingLoginError::RefreshError)?;

            current_token = UserAccessToken::from(response);

            token_storage
                .update_token(&current_token)
                .await
                .map_err(RefreshingLoginError::UpdateError)?;
        }

        let mut cached_login = self.user_login.lock().await;

        let login = match &*cached_login {
            Some(login) => login.clone(),
            None => {
                let users_response = self
                    .http_client
                    .get("https://api.twitch.tv/helix/users")
                    .header("Client-Id", &self.client_id)
                    .bearer_auth(&current_token.access_token)
                    .send()
                    .await
                    .and_then(reqwest::Response::error_for_status)
                    .map_err(RefreshingLoginError::RefreshError)?
                    .json::<UsersResponse>()
                    .await
                    .map_err(RefreshingLoginError::RefreshError)?;

                // NOTE(review): a dedicated error variant would be preferable to a panic
                // here, but adding one would change the public `RefreshingLoginError` enum.
                let user = users_response
                    .data
                    .into_iter()
                    .next()
                    .expect("users endpoint returned no user for the current access token");

                tracing::info!(
                    "Fetched login name `{}` for provided auth token",
                    &user.login
                );

                *cached_login = Some(user.login.clone());
                user.login
            }
        };

        Ok(CredentialsPair {
            login,
            token: Some(current_token.access_token),
        })
    }
}
334
/// Response body of the Helix `users` endpoint; only the `data` list is decoded.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize)]
struct UsersResponse {
    /// The returned users.
    data: Vec<UserObject>,
}
342
/// One user entry of a Helix `users` response; only the login name is decoded.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize)]
struct UserObject {
    /// The user's login name.
    login: String,
}