1use async_trait::async_trait;
4use std::convert::Infallible;
5use std::fmt::{Debug, Display};
6
// Fail the build when more than one of the three `refreshing-token-*` TLS feature
// flags is enabled. Each `all(...)` arm below matches one pair of flags, so any
// combination of two (or all three) triggers the `compile_error!`.
#[cfg(any(
    all(
        feature = "refreshing-token-native-tls",
        feature = "refreshing-token-rustls-native-roots"
    ),
    all(
        feature = "refreshing-token-native-tls",
        feature = "refreshing-token-rustls-webpki-roots"
    ),
    all(
        feature = "refreshing-token-rustls-native-roots",
        feature = "refreshing-token-rustls-webpki-roots"
    ),
))]
compile_error!(
    "`refreshing-token-native-tls`, `refreshing-token-rustls-native-roots` and `refreshing-token-rustls-webpki-roots` feature flags are mutually exclusive, enable at most one of them"
);
24
25#[cfg(feature = "__refreshing-token")]
26use {
27 chrono::DateTime,
28 chrono::Utc,
29 reqwest::ClientBuilder,
30 std::{sync::Arc, time::Duration},
31 thiserror::Error,
32 tokio::sync::Mutex,
33};
34
35#[cfg(feature = "with-serde")]
36use {serde::Deserialize, serde::Serialize};
37
/// A login name together with an optional token, as produced by
/// [`LoginCredentials::get_credentials`].
#[derive(Clone)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct CredentialsPair {
    /// The login name.
    pub login: String,
    /// The token that belongs to `login`, or `None` when there is no token.
    pub token: Option<String>,
}

/// Hand-written so that the token value itself never shows up in debug output.
impl Debug for CredentialsPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible while hiding the secret.
        let masked_token = match self.token {
            Some(_) => Some("[redacted]"),
            None => None,
        };

        let mut out = f.debug_struct("CredentialsPair");
        out.field("login", &self.login);
        out.field("token", &masked_token);
        out.finish()
    }
}
59
/// A source of login credentials.
///
/// Implementors hand out a [`CredentialsPair`] on demand; the lookup is asynchronous
/// and may fail with the implementor-specific [`LoginCredentials::Error`].
#[async_trait]
pub trait LoginCredentials: Debug + Send + Sync + 'static {
    /// Error returned when the credentials cannot be obtained.
    type Error: Send + Sync + Debug + Display;

    /// Returns the credentials to use right now.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor fails to produce credentials.
    async fn get_credentials(&self) -> Result<CredentialsPair, Self::Error>;
}
70
/// Login credentials that are fixed at construction time.
///
/// The [`LoginCredentials`] implementation for this type returns a clone of
/// `credentials` on every call and cannot fail.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct StaticLoginCredentials {
    /// The login/token pair that is handed out.
    pub credentials: CredentialsPair,
}
79
80impl StaticLoginCredentials {
81 #[must_use]
84 pub fn new(login: String, token: Option<String>) -> StaticLoginCredentials {
85 StaticLoginCredentials {
86 credentials: CredentialsPair { login, token },
87 }
88 }
89
90 #[must_use]
92 pub fn anonymous() -> StaticLoginCredentials {
93 StaticLoginCredentials::new("justinfan12345".to_owned(), None)
94 }
95}
96
97#[async_trait]
98impl LoginCredentials for StaticLoginCredentials {
99 type Error = Infallible;
100
101 async fn get_credentials(&self) -> Result<CredentialsPair, Infallible> {
102 Ok(self.credentials.clone())
103 }
104}
105
#[cfg(feature = "__refreshing-token")]
/// An access token together with the refresh token needed to renew it and the
/// timestamps used to decide when a renewal is due.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserAccessToken {
    /// The token handed out as `CredentialsPair::token`.
    pub access_token: String,
    /// The token sent to the token endpoint to obtain a new access token.
    pub refresh_token: String,
    /// When this token was obtained.
    pub created_at: DateTime<Utc>,
    /// When this token expires, if known. With `None`, `RefreshingLoginCredentials`
    /// assumes a lifetime of 24 hours counted from `created_at`.
    pub expires_at: Option<DateTime<Utc>>,
}
120
#[cfg(feature = "__refreshing-token")]
/// Hand-written so that neither token value appears in debug output.
impl Debug for UserAccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both secrets are replaced by the same placeholder; timestamps are safe to show.
        const MASK: &str = "[redacted]";

        let mut out = f.debug_struct("UserAccessToken");
        out.field("access_token", &MASK);
        out.field("refresh_token", &MASK);
        out.field("created_at", &self.created_at);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
133
#[cfg(feature = "__refreshing-token")]
/// JSON body of a token response, as decoded from the reply of
/// `https://id.twitch.tv/oauth2/token`.
///
/// Convert it into a [`UserAccessToken`] with `UserAccessToken::from`.
#[derive(Serialize, Deserialize)]
pub struct GetAccessTokenResponse {
    /// The newly issued access token.
    pub access_token: String,
    /// The refresh token to use for the next renewal.
    pub refresh_token: String,
    /// Lifetime of `access_token` in seconds, if the response states one.
    pub expires_in: Option<u64>,
}
169
#[cfg(feature = "__refreshing-token")]
impl From<GetAccessTokenResponse> for UserAccessToken {
    /// Stamps the response with the current time and turns the relative `expires_in`
    /// (seconds) into an absolute `expires_at`.
    ///
    /// `expires_in` comes straight from the server. If it is so large that it cannot be
    /// represented as a `chrono::Duration`, or adding it to the current time would
    /// overflow, `expires_at` is set to `None` (expiry unknown) instead of panicking.
    fn from(response: GetAccessTokenResponse) -> Self {
        let now = Utc::now();

        let expires_at = response.expires_in.and_then(|seconds| {
            chrono::Duration::from_std(Duration::from_secs(seconds))
                .ok()
                .and_then(|lifetime| now.checked_add_signed(lifetime))
        });

        UserAccessToken {
            access_token: response.access_token,
            refresh_token: response.refresh_token,
            created_at: now,
            expires_at,
        }
    }
}
184
#[cfg(feature = "__refreshing-token")]
/// Persistence backend for a [`UserAccessToken`], used by
/// [`RefreshingLoginCredentials`].
#[async_trait]
pub trait TokenStorage: Debug + Send + 'static {
    /// Error returned when the token cannot be loaded.
    type LoadError: Send + Sync + Debug + Display;
    /// Error returned when the token cannot be stored.
    type UpdateError: Send + Sync + Debug + Display;

    /// Loads the currently stored token.
    ///
    /// `RefreshingLoginCredentials` calls this on every `get_credentials` call.
    async fn load_token(&mut self) -> Result<UserAccessToken, Self::LoadError>;
    /// Stores `token`, replacing the previous one.
    ///
    /// `RefreshingLoginCredentials` calls this after it has refreshed the token.
    async fn update_token(&mut self, token: &UserAccessToken) -> Result<(), Self::UpdateError>;
}
201
#[cfg(feature = "__refreshing-token")]
/// Login credentials backed by a [`TokenStorage`], refreshing the stored token over
/// HTTP once it has used up most of its lifetime.
///
/// Clones share the cached login name and the token storage, since both live behind
/// an `Arc`.
#[derive(Clone)]
pub struct RefreshingLoginCredentials<S: TokenStorage> {
    // Client used for the token refresh and the login-name lookup.
    http_client: reqwest::Client,
    // Login name; `None` until it is supplied by the caller or fetched from the API.
    user_login: Arc<Mutex<Option<String>>>,
    // Sent with the refresh request and as the `Client-Id` header of the user lookup.
    client_id: String,
    // Sent with the refresh request; redacted in the `Debug` output.
    client_secret: String,
    // Where the current token is loaded from and refreshed tokens are written to.
    token_storage: Arc<Mutex<S>>,
}
214
#[cfg(feature = "__refreshing-token")]
/// Hand-written so that the client secret never appears in debug output.
impl<S: TokenStorage> Debug for RefreshingLoginCredentials<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the client secret is masked; every other field is printed as-is.
        const MASK: &str = "[redacted]";

        let mut out = f.debug_struct("RefreshingLoginCredentials");
        out.field("http_client", &self.http_client);
        out.field("user_login", &self.user_login);
        out.field("client_id", &self.client_id);
        out.field("client_secret", &MASK);
        out.field("token_storage", &self.token_storage);
        out.finish()
    }
}
228
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> RefreshingLoginCredentials<S> {
    /// Creates refreshing credentials without a known login name.
    ///
    /// The login name is looked up through the API on the first call to
    /// `get_credentials` and cached afterwards.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::init_with_username`].
    #[must_use]
    pub fn init(
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        Self::init_with_username(None, client_id, client_secret, token_storage)
    }

    /// Creates refreshing credentials, optionally with an already known login name.
    ///
    /// * `user_login` - login name to report; with `None` it is fetched from the API on
    ///   the first call to `get_credentials`.
    /// * `client_id` / `client_secret` - application credentials sent with the token
    ///   refresh request.
    /// * `token_storage` - where the token is loaded from and written back to.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built, or (with the
    /// `refreshing-token-rustls-webpki-roots` feature) if a bundled root certificate
    /// cannot be parsed.
    #[must_use]
    pub fn init_with_username(
        user_login: Option<String>,
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        let http_client = {
            // `builder` is only reassigned when the webpki-roots feature is enabled, so
            // the `mut` is unused (and would warn) in every other configuration.
            #[cfg_attr(
                not(feature = "refreshing-token-rustls-webpki-roots"),
                allow(unused_mut)
            )]
            let mut builder = ClientBuilder::new();

            #[cfg(feature = "refreshing-token-rustls-webpki-roots")]
            {
                builder = builder.tls_certs_only(
                    webpki_root_certs::TLS_SERVER_ROOT_CERTS
                        .iter()
                        .map(|cert| {
                            reqwest::tls::Certificate::from_der(cert)
                                .expect("bundled webpki root certificate must be valid DER")
                        }),
                );
            }

            builder
                .build()
                .expect("failed to build the HTTP client used for token refreshes")
        };

        RefreshingLoginCredentials {
            http_client,
            user_login: Arc::new(Mutex::new(user_login)),
            client_id,
            client_secret,
            token_storage: Arc::new(Mutex::new(token_storage)),
        }
    }
}
284
#[cfg(feature = "__refreshing-token")]
/// Errors returned by `get_credentials` on [`RefreshingLoginCredentials`].
#[derive(Error, Debug)]
pub enum RefreshingLoginError<S: TokenStorage> {
    /// `TokenStorage::load_token` failed.
    #[error("Failed to retrieve token from storage: {0}")]
    LoadError(S::LoadError),
    /// An HTTP request failed or its response could not be decoded. This covers both
    /// the token refresh request and the login-name lookup.
    #[error("Failed to refresh token: {0}")]
    RefreshError(reqwest::Error),
    /// `TokenStorage::update_token` failed after a successful refresh.
    #[error("Failed to update token in storage: {0}")]
    UpdateError(S::UpdateError),
}
299
#[cfg(feature = "__refreshing-token")]
// Fraction of a token's total lifetime after which `get_credentials` refreshes it,
// i.e. the token is renewed once 90% of its lifetime has elapsed.
const SHOULD_REFRESH_AFTER_FACTOR: f64 = 0.9;
302
#[cfg(feature = "__refreshing-token")]
#[async_trait]
impl<S: TokenStorage> LoginCredentials for RefreshingLoginCredentials<S> {
    type Error = RefreshingLoginError<S>;

    /// Returns the login name and a currently valid access token.
    ///
    /// The token is loaded from the storage on every call. Once it has used up
    /// `SHOULD_REFRESH_AFTER_FACTOR` of its lifetime (24 hours are assumed when
    /// `expires_at` is `None`), it is refreshed at `https://id.twitch.tv/oauth2/token`
    /// and written back to the storage. If no login name is known yet, it is fetched
    /// from `https://api.twitch.tv/helix/users` and cached for later calls.
    ///
    /// # Errors
    ///
    /// * `LoadError` - the storage failed to load the token.
    /// * `RefreshError` - an HTTP request failed, returned a non-success status, or its
    ///   body could not be decoded. Errors from the token endpoint have their URL
    ///   stripped because the query string contains secrets.
    /// * `UpdateError` - the storage failed to persist the refreshed token.
    ///
    /// # Panics
    ///
    /// Panics if the users endpoint answers successfully but with an empty user list.
    async fn get_credentials(&self) -> Result<CredentialsPair, RefreshingLoginError<S>> {
        // The guard is kept until the end of the call, so concurrent callers wait here
        // rather than refreshing the same token in parallel.
        let mut token_storage = self.token_storage.lock().await;

        let mut current_token = token_storage
            .load_token()
            .await
            .map_err(RefreshingLoginError::LoadError)?;

        // `to_std` fails on negative durations (timestamps out of order, e.g. after a
        // clock adjustment). Clamp to zero instead of panicking: a zero lifetime forces
        // a refresh, a zero age means "just created".
        let token_expires_after = match current_token.expires_at {
            Some(expires_at) => (expires_at - current_token.created_at)
                .to_std()
                .unwrap_or(Duration::ZERO),
            None => Duration::from_secs(24 * 60 * 60),
        };
        let token_age = (Utc::now() - current_token.created_at)
            .to_std()
            .unwrap_or(Duration::ZERO);
        let max_token_age = token_expires_after.mul_f64(SHOULD_REFRESH_AFTER_FACTOR);
        let is_token_expired = token_age >= max_token_age;

        if is_token_expired {
            // The request URL carries the client secret and the refresh token in its
            // query string, and `reqwest::Error` prints the URL. Strip it so the
            // secrets cannot end up in logs.
            let refresh_error = |err: reqwest::Error| -> RefreshingLoginError<S> {
                RefreshingLoginError::RefreshError(err.without_url())
            };

            let response = self
                .http_client
                .post("https://id.twitch.tv/oauth2/token")
                .query(&[
                    ("grant_type", "refresh_token"),
                    ("refresh_token", current_token.refresh_token.as_str()),
                    ("client_id", self.client_id.as_str()),
                    ("client_secret", self.client_secret.as_str()),
                ])
                .send()
                .await
                .map_err(refresh_error)?
                // Report a rejected refresh as a status error, not a JSON decode error.
                .error_for_status()
                .map_err(refresh_error)?
                .json::<GetAccessTokenResponse>()
                .await
                .map_err(refresh_error)?;

            current_token = UserAccessToken::from(response);

            token_storage
                .update_token(&current_token)
                .await
                .map_err(RefreshingLoginError::UpdateError)?;
        }

        let mut current_login = self.user_login.lock().await;

        let login = if let Some(login) = &*current_login {
            login.clone()
        } else {
            let response = self
                .http_client
                .get("https://api.twitch.tv/helix/users")
                .header("Client-Id", &self.client_id)
                .bearer_auth(&current_token.access_token)
                .send()
                .await
                .map_err(RefreshingLoginError::RefreshError)?
                .error_for_status()
                .map_err(RefreshingLoginError::RefreshError)?;

            let users_response = response
                .json::<UsersResponse>()
                .await
                .map_err(RefreshingLoginError::RefreshError)?;

            // The request names no user, so the reply is expected to contain exactly
            // the owner of the token. There is no error variant for an empty list.
            let user = users_response
                .data
                .into_iter()
                .next()
                .expect("users endpoint returned no user for the provided auth token");

            tracing::info!(
                "Fetched login name `{}` for provided auth token",
                &user.login
            );

            *current_login = Some(user.login.clone());

            user.login
        };

        Ok(CredentialsPair {
            login,
            token: Some(current_token.access_token),
        })
    }
}
394
#[cfg(feature = "__refreshing-token")]
/// JSON body decoded from the reply of `https://api.twitch.tv/helix/users`.
/// Only the fields read by this file are declared.
#[derive(Deserialize)]
struct UsersResponse {
    // The returned users; `get_credentials` only uses the first entry.
    data: Vec<UserObject>,
}
402
#[cfg(feature = "__refreshing-token")]
/// A single entry of [`UsersResponse`]; only the login name is declared.
#[derive(Deserialize)]
struct UserObject {
    // Login name, cached by `get_credentials` and returned as `CredentialsPair::login`.
    login: String,
}