1use std::{
2 fmt::Debug,
3 sync::{Arc, RwLock},
4};
5
6use oauth2::{
7 basic::{
8 BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
9 BasicTokenType,
10 },
11 reqwest::async_http_client,
12 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
13 RefreshToken, StandardRevocableToken, TokenUrl,
14};
15use reqwest::{header::CONTENT_LENGTH, Method, Url};
16use serde::{
17 de::{value::BytesDeserializer, DeserializeOwned, IntoDeserializer},
18 Serialize,
19};
20use tracing::info;
21
22use crate::{
23 auth::{
24 AuthCodeFlow, AuthCodePkceFlow, AuthFlow, AuthenticationState, ClientCredsFlow, Scopes,
25 Token, Unauthenticated, UnknownFlow,
26 },
27 error::{Error, Result, SpotifyError},
28};
29
/// Address passed to `AuthUrl::new` whenever an OAuth client is built in this module.
const AUTHORISATION_URL: &str = "https://accounts.spotify.com/authorize";
/// Address passed to `TokenUrl::new` whenever an OAuth client is built in this module.
const TOKEN_URL: &str = "https://accounts.spotify.com/api/token";
/// Prefix that is prepended to every endpoint path when an API request is built.
pub(crate) const API_URL: &str = "https://api.spotify.com/v1";
33
34pub(crate) type OAuthClient = oauth2::Client<
35 BasicErrorResponse,
36 Token,
37 BasicTokenType,
38 BasicTokenIntrospectionResponse,
39 StandardRevocableToken,
40 BasicRevocationErrorResponse,
41>;
42
43pub type AuthCodeClient<A> = Client<A, AuthCodeFlow>;
45
46pub type AuthCodePkceClient<A> = Client<A, AuthCodePkceFlow>;
48
49pub type ClientCredsClient<A> = Client<A, ClientCredsFlow>;
51
52#[doc(hidden)]
53#[derive(Debug)]
54pub(crate) enum Body<P: Serialize = ()> {
55 Json(P),
56 File(Vec<u8>),
57}
58
59#[derive(Clone, Debug)]
64pub struct Client<A: AuthenticationState, F: AuthFlow> {
65 pub auto_refresh: bool,
70 pub(crate) auth_state: Arc<RwLock<A>>,
73 pub(crate) auth_flow: F,
76 pub(crate) oauth: OAuthClient,
78 pub(crate) http: reqwest::Client,
80}
81
82impl Client<Token, UnknownFlow> {
83 pub async fn from_refresh_token(
87 client_id: impl Into<String>,
88 client_secret: Option<&str>,
89 scopes: Option<Scopes>,
90 auto_refresh: bool,
91 refresh_token: String,
92 ) -> Result<Self> {
93 let client_id = ClientId::new(client_id.into());
94 let client_secret = client_secret.map(|s| ClientSecret::new(s.to_owned()));
95
96 let oauth_client = OAuthClient::new(
97 client_id,
98 client_secret,
99 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
100 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
101 );
102
103 let refresh_token = RefreshToken::new(refresh_token);
104 let mut req = oauth_client.exchange_refresh_token(&refresh_token);
105
106 if let Some(scopes) = scopes {
107 req = req.add_scopes(scopes.0);
108 }
109
110 let mut token = req.request_async(async_http_client).await?.set_timestamps();
111 if token.refresh_token.is_none() {
112 token.refresh_token = Some(refresh_token);
115 }
116
117 Ok(Self {
118 auto_refresh,
119 auth_state: Arc::new(RwLock::new(token)),
120 auth_flow: UnknownFlow,
121 oauth: oauth_client,
122 http: reqwest::Client::new(),
123 })
124 }
125}
126
127impl<F: AuthFlow> Client<Token, F> {
128 pub fn token(&self) -> Arc<RwLock<Token>> {
133 self.auth_state.clone()
134 }
135
136 pub fn access_token(&self) -> Result<String> {
143 let token = self
144 .auth_state
145 .read()
146 .expect("The lock holding the token has been poisoned.");
147
148 Ok(token.access_token.secret().clone())
149 }
150
151 pub fn refresh_token(&self) -> Result<Option<String>> {
158 let token = self
159 .auth_state
160 .read()
161 .expect("The lock holding the token has been poisoned.");
162
163 let refresh_token = token.refresh_token.as_ref().map(|t| t.secret().clone());
164
165 Ok(refresh_token)
166 }
167
168 pub async fn exchange_refresh_token(&self) -> Result<()> {
171 let refresh_token = {
172 let lock = self.auth_state.read().unwrap_or_else(|e| e.into_inner());
173
174 let Some(refresh_token) = &lock.refresh_token else {
175 return Err(Error::RefreshUnavailable);
176 };
177
178 refresh_token.clone()
179 };
180
181 let token = self
182 .oauth
183 .exchange_refresh_token(&refresh_token)
184 .request_async(async_http_client)
185 .await?
186 .set_timestamps();
187
188 let mut lock = self
189 .auth_state
190 .write()
191 .expect("The lock holding the token has been poisoned.");
192 *lock = token;
193 Ok(())
194 }
195
196 pub(crate) async fn request<P: Serialize + Debug, T: DeserializeOwned>(
197 &self,
198 method: Method,
199 endpoint: String,
200 query: Option<P>,
201 body: Option<Body<P>>,
202 ) -> Result<T> {
203 let (token_expired, secret) = {
204 let lock = self
205 .auth_state
206 .read()
207 .expect("The lock holding the token has been poisoned.");
208
209 (lock.is_expired(), lock.access_token.secret().to_owned())
210 };
211
212 if token_expired {
213 if self.auto_refresh {
214 info!("The token has expired, attempting to refresh...");
215
216 self.exchange_refresh_token().await?;
217
218 let lock = self
219 .auth_state
220 .read()
221 .expect("The lock holding the token has been poisoned.");
222
223 info!("The token has been successfully refreshed. The new token will expire in {} seconds", lock.expires_in);
224 } else {
225 info!("The token has expired, automatic refresh is disabled.");
226 return Err(Error::ExpiredToken);
227 }
228 }
229
230 let mut req = {
231 self.http
232 .request(method, format!("{API_URL}{endpoint}"))
233 .bearer_auth(secret)
234 };
235
236 if let Some(q) = query {
237 req = req.query(&q);
238 }
239
240 if let Some(b) = body {
241 match b {
242 Body::Json(j) => req = req.json(&j),
243 Body::File(f) => req = req.body(f),
244 }
245 } else {
246 req = req.header(CONTENT_LENGTH, 0);
250 }
251
252 let req = req.build()?;
253 info!(headers = ?req.headers(), "{} request sent to {}", req.method(), req.url());
254
255 let res = self.http.execute(req).await?;
256
257 if res.status().is_success() {
258 let bytes = res.bytes().await?;
259
260 let deserialized = serde_json::from_slice::<T>(&bytes).or_else(|e| {
262 let de: BytesDeserializer<'_, serde::de::value::Error> =
265 bytes.as_ref().into_deserializer();
266
267 T::deserialize(de).map_err(|_| e)
270 });
271 match deserialized {
274 Ok(content) => Ok(content),
275 Err(err) => {
276 let body = std::str::from_utf8(&bytes).map_err(|_| Error::InvalidResponse)?;
277
278 tracing::error!(
279 %body,
280 "Failed to deserialize the response body into an object or Nil."
281 );
282
283 Err(Error::Deserialization {
284 source: err,
285 body: body.to_owned(),
286 })
287 }
288 }
289 } else {
290 Err(res.json::<SpotifyError>().await?.into())
291 }
292 }
293
294 pub(crate) async fn get<P: Serialize + Debug, T: DeserializeOwned>(
295 &self,
296 endpoint: String,
297 query: impl Into<Option<P>>,
298 ) -> Result<T> {
299 self.request(Method::GET, endpoint, query.into(), None)
300 .await
301 }
302
303 pub(crate) async fn post<P: Serialize + Debug, T: DeserializeOwned>(
304 &self,
305 endpoint: String,
306 body: impl Into<Option<Body<P>>>,
307 ) -> Result<T> {
308 self.request(Method::POST, endpoint, None, body.into())
309 .await
310 }
311
312 pub(crate) async fn put<P: Serialize + Debug, T: DeserializeOwned>(
313 &self,
314 endpoint: String,
315 body: impl Into<Option<Body<P>>>,
316 ) -> Result<T> {
317 self.request(Method::PUT, endpoint, None, body.into()).await
318 }
319
320 pub(crate) async fn delete<P: Serialize + Debug, T: DeserializeOwned>(
321 &self,
322 endpoint: String,
323 body: impl Into<Option<Body<P>>>,
324 ) -> Result<T> {
325 self.request(Method::DELETE, endpoint, None, body.into())
326 .await
327 }
328}
329
330impl AuthCodeClient<Unauthenticated> {
331 pub fn new<S>(
338 client_id: impl Into<String>,
339 client_secret: impl Into<String>,
340 scopes: S,
341 redirect_uri: RedirectUrl,
342 auto_refresh: bool,
343 ) -> (Self, Url)
344 where
345 S: Into<Scopes>,
346 {
347 let client_id = ClientId::new(client_id.into());
348 let client_secret = Some(ClientSecret::new(client_secret.into()));
349
350 let oauth = OAuthClient::new(
351 client_id,
352 client_secret,
353 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
354 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
355 )
356 .set_redirect_uri(redirect_uri);
357
358 let (auth_url, csrf_token) = oauth
359 .authorize_url(CsrfToken::new_random)
360 .add_scopes(scopes.into().0)
361 .url();
362
363 (
364 Client {
365 auto_refresh,
366 auth_state: Arc::new(RwLock::new(Unauthenticated)),
367 auth_flow: AuthCodeFlow { csrf_token },
368 oauth,
369 http: reqwest::Client::new(),
370 },
371 auth_url,
372 )
373 }
374
375 pub async fn authenticate(
380 self,
381 auth_code: impl Into<String>,
382 csrf_state: impl AsRef<str>,
383 ) -> Result<Client<Token, AuthCodeFlow>> {
384 let auth_code = auth_code.into().trim().to_owned();
385 let csrf_state = csrf_state.as_ref().trim();
386
387 if csrf_state != self.auth_flow.csrf_token.secret() {
388 return Err(Error::InvalidStateParameter);
389 }
390
391 let token = self
392 .oauth
393 .exchange_code(AuthorizationCode::new(auth_code))
394 .request_async(async_http_client)
395 .await?
396 .set_timestamps();
397
398 Ok(Client {
399 auto_refresh: self.auto_refresh,
400 auth_state: Arc::new(RwLock::new(token)),
401 auth_flow: self.auth_flow,
402 oauth: self.oauth,
403 http: self.http,
404 })
405 }
406}
407
408impl AuthCodePkceClient<Unauthenticated> {
409 pub fn new<T, S>(
416 client_id: T,
417 scopes: S,
418 redirect_uri: RedirectUrl,
419 auto_refresh: bool,
420 ) -> (Self, Url)
421 where
422 T: Into<String>,
423 S: Into<Scopes>,
424 {
425 let client_id = ClientId::new(client_id.into());
426
427 let oauth = OAuthClient::new(
428 client_id,
429 None,
430 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
431 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
432 )
433 .set_redirect_uri(redirect_uri);
434
435 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
436
437 let (auth_url, csrf_token) = oauth
438 .authorize_url(CsrfToken::new_random)
439 .add_scopes(scopes.into().0)
440 .set_pkce_challenge(pkce_challenge)
441 .url();
442
443 (
444 Client {
445 auto_refresh,
446 auth_state: Arc::new(RwLock::new(Unauthenticated)),
447 auth_flow: AuthCodePkceFlow {
448 csrf_token,
449 pkce_verifier: Some(pkce_verifier),
450 },
451 oauth,
452 http: reqwest::Client::new(),
453 },
454 auth_url,
455 )
456 }
457
458 pub async fn authenticate(
463 mut self,
464 auth_code: impl Into<String>,
465 csrf_state: impl AsRef<str>,
466 ) -> Result<Client<Token, AuthCodePkceFlow>> {
467 let auth_code = auth_code.into().trim().to_owned();
468 let csrf_state = csrf_state.as_ref().trim();
469
470 if csrf_state != self.auth_flow.csrf_token.secret() {
471 return Err(Error::InvalidStateParameter);
472 }
473
474 let Some(pkce_verifier) = self.auth_flow.pkce_verifier.take() else {
475 tracing::error!(client = ?self, "No PKCE code verifier present when authenticating the client.");
478 return Err(Error::InvalidClientState);
479 };
480
481 let token = self
482 .oauth
483 .exchange_code(AuthorizationCode::new(auth_code))
484 .set_pkce_verifier(pkce_verifier)
485 .request_async(async_http_client)
486 .await?
487 .set_timestamps();
488
489 Ok(Client {
490 auto_refresh: self.auto_refresh,
491 auth_state: Arc::new(RwLock::new(token)),
492 auth_flow: self.auth_flow,
493 oauth: self.oauth,
494 http: self.http,
495 })
496 }
497}
498
499impl ClientCredsClient<Unauthenticated> {
500 pub async fn authenticate(
506 client_id: impl Into<String>,
507 client_secret: impl Into<String>,
508 ) -> Result<ClientCredsClient<Token>> {
509 let client_id = ClientId::new(client_id.into());
510 let client_secret = Some(ClientSecret::new(client_secret.into()));
511
512 let oauth = OAuthClient::new(
513 client_id,
514 client_secret,
515 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
516 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
517 );
518
519 let token = oauth
520 .exchange_client_credentials()
521 .request_async(async_http_client)
522 .await?
523 .set_timestamps();
524
525 Ok(Client {
526 auto_refresh: false,
527 auth_state: Arc::new(RwLock::new(token)),
528 auth_flow: ClientCredsFlow,
529 oauth,
530 http: reqwest::Client::new(),
531 })
532 }
533}
534
535impl AuthCodeClient<Token> {
536 pub async fn from_access_token(
542 client_id: impl Into<String>,
543 client_secret: impl Into<String>,
544 auto_refresh: bool,
545 token: Token,
546 ) -> Result<Self> {
547 let client_id = ClientId::new(client_id.into());
548 let client_secret = Some(ClientSecret::new(client_secret.into()));
550
551 let oauth_client = OAuthClient::new(
552 client_id,
553 client_secret,
554 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
555 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
556 );
557
558 let http = reqwest::Client::new();
559
560 let res = http
562 .get(format!("{API_URL}/markets"))
563 .bearer_auth(token.secret())
564 .header(CONTENT_LENGTH, 0)
565 .send()
566 .await?;
567
568 if !res.status().is_success() {
569 return Err(res.json::<SpotifyError>().await?.into());
570 }
571
572 let auth_flow = AuthCodeFlow {
573 csrf_token: CsrfToken::new("not needed".to_owned()),
574 };
575
576 let auto_refresh = auto_refresh && token.refresh_token.is_some();
577
578 Ok(Self {
579 auto_refresh,
580 auth_state: Arc::new(RwLock::new(token)),
581 auth_flow,
582 oauth: oauth_client,
583 http,
584 })
585 }
586}
587
588impl AuthCodePkceClient<Token> {
589 pub async fn from_access_token(
595 client_id: impl Into<String>,
596 auto_refresh: bool,
597 token: Token,
598 ) -> Result<Self> {
599 let client_id = ClientId::new(client_id.into());
600
601 let oauth_client = OAuthClient::new(
602 client_id,
603 None,
604 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
605 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
606 );
607
608 let http = reqwest::Client::new();
609
610 let res = http
612 .get(format!("{API_URL}/recommendations/available-genre-seeds"))
613 .bearer_auth(token.secret())
614 .header(CONTENT_LENGTH, 0)
615 .send()
616 .await?;
617
618 if !res.status().is_success() {
619 return Err(res.json::<SpotifyError>().await?.into());
620 }
621
622 let auth_flow = AuthCodePkceFlow {
623 csrf_token: CsrfToken::new("not needed".to_owned()),
624 pkce_verifier: None,
625 };
626
627 let auto_refresh = auto_refresh && token.refresh_token.is_some();
628
629 Ok(Self {
630 auto_refresh,
631 auth_state: Arc::new(RwLock::new(token)),
632 auth_flow,
633 oauth: oauth_client,
634 http,
635 })
636 }
637}
638
639impl ClientCredsClient<Token> {
640 pub async fn from_access_token(
646 client_id: impl Into<String>,
647 client_secret: impl Into<String>,
648 token: Token,
649 ) -> Result<Self> {
650 let client_id = ClientId::new(client_id.into());
651 let client_secret = Some(ClientSecret::new(client_secret.into()));
652
653 let oauth_client = OAuthClient::new(
654 client_id,
655 client_secret,
656 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
657 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
658 );
659
660 let http = reqwest::Client::new();
661
662 let res = http
664 .get(format!("{API_URL}/recommendations/available-genre-seeds"))
665 .bearer_auth(token.secret())
666 .header(CONTENT_LENGTH, 0)
667 .send()
668 .await?;
669
670 if !res.status().is_success() {
671 return Err(res.json::<SpotifyError>().await?.into());
672 }
673
674 Ok(Self {
675 auto_refresh: false,
676 auth_state: Arc::new(RwLock::new(token)),
677 auth_flow: ClientCredsFlow,
678 oauth: oauth_client,
679 http,
680 })
681 }
682}