spotify_rs/
client.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, RwLock},
4};
5
6use oauth2::{
7    basic::{
8        BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
9        BasicTokenType,
10    },
11    reqwest::async_http_client,
12    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
13    RefreshToken, StandardRevocableToken, TokenUrl,
14};
15use reqwest::{header::CONTENT_LENGTH, Method, Url};
16use serde::{
17    de::{value::BytesDeserializer, DeserializeOwned, IntoDeserializer},
18    Serialize,
19};
20use tracing::info;
21
22use crate::{
23    auth::{
24        AuthCodeFlow, AuthCodePkceFlow, AuthFlow, AuthenticationState, ClientCredsFlow, Scopes,
25        Token, Unauthenticated, UnknownFlow,
26    },
27    error::{Error, Result, SpotifyError},
28};
29
/// Spotify's OAuth 2.0 authorisation endpoint; users are sent here to grant access.
const AUTHORISATION_URL: &str = "https://accounts.spotify.com/authorize";

/// Spotify's OAuth 2.0 token endpoint, used to obtain and refresh access tokens.
const TOKEN_URL: &str = "https://accounts.spotify.com/api/token";

/// Base URL of the Spotify Web API. Endpoint paths (starting with `/`) are
/// appended to it verbatim.
pub(crate) const API_URL: &str = "https://api.spotify.com/v1";
33
34pub(crate) type OAuthClient = oauth2::Client<
35    BasicErrorResponse,
36    Token,
37    BasicTokenType,
38    BasicTokenIntrospectionResponse,
39    StandardRevocableToken,
40    BasicRevocationErrorResponse,
41>;
42
43/// A client created using the Authorisation Code Flow.
44pub type AuthCodeClient<A> = Client<A, AuthCodeFlow>;
45
46/// A client created using the Authorisation Code with PKCE Flow.
47pub type AuthCodePkceClient<A> = Client<A, AuthCodePkceFlow>;
48
49/// A client created using the Client Credentials Flow.
50pub type ClientCredsClient<A> = Client<A, ClientCredsFlow>;
51
/// The payload attached to an API request.
///
/// `P` is the type of a JSON payload and defaults to `()`. No trait bound is placed
/// on the type itself; the code that actually sends the body (`Client::request`)
/// requires `P: Serialize`, which is where the bound belongs.
#[doc(hidden)]
#[derive(Debug)]
pub(crate) enum Body<P = ()> {
    /// A value serialised to JSON and sent as the request body.
    Json(P),
    /// Raw bytes sent as the request body without any transformation.
    File(Vec<u8>),
}
58
59/// The client which handles the authentication and all the Spotify API requests.
60///
61/// It is recommended to use one of the following: [`AuthCodeClient`], [`AuthCodePkceClient`]
62/// or [`ClientCredsClient`], depending on the chosen authentication flow.
63#[derive(Clone, Debug)]
64pub struct Client<A: AuthenticationState, F: AuthFlow> {
65    /// Dictates whether or not the client will request a new token when the
66    /// current one is about the expire.
67    ///
68    /// It will check if the token has expired in every request.
69    pub auto_refresh: bool,
70    // This is used for the typestate pattern, to differentiate an authenticated
71    // client from an unauthenticated one, but it also holds the Token.
72    pub(crate) auth_state: Arc<RwLock<A>>,
73    // This is used for the typestate pattern to differentiate between different
74    // authorisation flows, as well as hold the CSRF/PKCE verifiers.
75    pub(crate) auth_flow: F,
76    // The OAuth2 client.
77    pub(crate) oauth: OAuthClient,
78    // The HTTP client.
79    pub(crate) http: reqwest::Client,
80}
81
82impl Client<Token, UnknownFlow> {
83    /// Create a new authenticated and authorised client from a refresh token.
84    ///
85    /// This method will fail if the refresh token is invalid or a new one cannot be obtained.
86    pub async fn from_refresh_token(
87        client_id: impl Into<String>,
88        client_secret: Option<&str>,
89        scopes: Option<Scopes>,
90        auto_refresh: bool,
91        refresh_token: String,
92    ) -> Result<Self> {
93        let client_id = ClientId::new(client_id.into());
94        let client_secret = client_secret.map(|s| ClientSecret::new(s.to_owned()));
95
96        let oauth_client = OAuthClient::new(
97            client_id,
98            client_secret,
99            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
100            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
101        );
102
103        let refresh_token = RefreshToken::new(refresh_token);
104        let mut req = oauth_client.exchange_refresh_token(&refresh_token);
105
106        if let Some(scopes) = scopes {
107            req = req.add_scopes(scopes.0);
108        }
109
110        let token = req.request_async(async_http_client).await?.set_timestamps();
111
112        Ok(Self {
113            auto_refresh,
114            auth_state: Arc::new(RwLock::new(token)),
115            auth_flow: UnknownFlow,
116            oauth: oauth_client,
117            http: reqwest::Client::new(),
118        })
119    }
120}
121
122impl<F: AuthFlow> Client<Token, F> {
123    /// Get a reference to the client's token.
124    ///
125    /// Please note that the [RwLock] used here is **not** async-aware, and thus
126    /// the read/write guard should not be held across await points.
127    pub fn token(&self) -> Arc<RwLock<Token>> {
128        self.auth_state.clone()
129    }
130
131    /// Get the access token secret as an owned (cloned) string.
132    /// If you only need a reference, you can use [`token`](Self::token)
133    /// yourself and get a reference from the returned [RwLock].
134    ///
135    /// This method will fail if the `RwLock` that holds the token has
136    /// been poisoned.
137    pub fn access_token(&self) -> Result<String> {
138        let token = self
139            .auth_state
140            .read()
141            .expect("The lock holding the token has been poisoned.");
142
143        Ok(token.access_token.secret().clone())
144    }
145
146    /// Get the refresh token secret as an owned (cloned) string.
147    /// If you only need a reference, you can use [`token`](Self::token)
148    /// yourself and get a reference from the returned [RwLock].
149    ///
150    /// This method will fail if the `RwLock` that holds the token has
151    /// been poisoned.
152    pub fn refresh_token(&self) -> Result<Option<String>> {
153        let token = self
154            .auth_state
155            .read()
156            .expect("The lock holding the token has been poisoned.");
157
158        let refresh_token = token.refresh_token.as_ref().map(|t| t.secret().clone());
159
160        Ok(refresh_token)
161    }
162
163    /// Exchange the refresh token for a new access token and updates it in the client.
164    /// Only some auth flows allow for token refreshing.
165    pub async fn exchange_refresh_token(&self) -> Result<()> {
166        let refresh_token = {
167            let lock = self.auth_state.read().unwrap_or_else(|e| e.into_inner());
168
169            let Some(refresh_token) = &lock.refresh_token else {
170                return Err(Error::RefreshUnavailable);
171            };
172
173            refresh_token.clone()
174        };
175
176        let token = self
177            .oauth
178            .exchange_refresh_token(&refresh_token)
179            .request_async(async_http_client)
180            .await?
181            .set_timestamps();
182
183        let mut lock = self
184            .auth_state
185            .write()
186            .expect("The lock holding the token has been poisoned.");
187        *lock = token;
188        Ok(())
189    }
190
191    pub(crate) async fn request<P: Serialize + Debug, T: DeserializeOwned>(
192        &self,
193        method: Method,
194        endpoint: String,
195        query: Option<P>,
196        body: Option<Body<P>>,
197    ) -> Result<T> {
198        let (token_expired, secret) = {
199            let lock = self
200                .auth_state
201                .read()
202                .expect("The lock holding the token has been poisoned.");
203
204            (lock.is_expired(), lock.access_token.secret().to_owned())
205        };
206
207        if token_expired {
208            if self.auto_refresh {
209                info!("The token has expired, attempting to refresh...");
210
211                self.exchange_refresh_token().await?;
212
213                let lock = self
214                    .auth_state
215                    .read()
216                    .expect("The lock holding the token has been poisoned.");
217
218                info!("The token has been successfully refreshed. The new token will expire in {} seconds", lock.expires_in);
219            } else {
220                info!("The token has expired, automatic refresh is disabled.");
221                return Err(Error::ExpiredToken);
222            }
223        }
224
225        let mut req = {
226            self.http
227                .request(method, format!("{API_URL}{endpoint}"))
228                .bearer_auth(secret)
229        };
230
231        if let Some(q) = query {
232            req = req.query(&q);
233        }
234
235        if let Some(b) = body {
236            match b {
237                Body::Json(j) => req = req.json(&j),
238                Body::File(f) => req = req.body(f),
239            }
240        } else {
241            // Used because Spotify wants a Content-Length header for the PUT /audiobooks/me endpoint even though there is no body
242            // If not supplied, it will return an error in the form of HTML (not JSON), which I believe to be an issue on their end.
243            // No other endpoints so far behave this way.
244            req = req.header(CONTENT_LENGTH, 0);
245        }
246
247        let req = req.build()?;
248        info!(headers = ?req.headers(), "{} request sent to {}", req.method(), req.url());
249
250        let res = self.http.execute(req).await?;
251
252        if res.status().is_success() {
253            let bytes = res.bytes().await?;
254
255            // Try to deserialize from bytes of JSON text;
256            let deserialized = serde_json::from_slice::<T>(&bytes).or_else(|e| {
257                // if the previous operation fails, try deserializing straight
258                // from the bytes, which works for Nil.
259                let de: BytesDeserializer<'_, serde::de::value::Error> =
260                    bytes.as_ref().into_deserializer();
261
262                // This line also converts the serde::de::value::Error to a serde_json::Error
263                // to make it clearer to the end user that deserialization failed.
264                T::deserialize(de).map_err(|_| e)
265            });
266            // .context(DeserializationSnafu { body });
267
268            match deserialized {
269                Ok(content) => Ok(content),
270                Err(err) => {
271                    let body = std::str::from_utf8(&bytes).map_err(|_| Error::InvalidResponse)?;
272
273                    tracing::error!(
274                        %body,
275                        "Failed to deserialize the response body into an object or Nil."
276                    );
277
278                    Err(Error::Deserialization {
279                        source: err,
280                        body: body.to_owned(),
281                    })
282                }
283            }
284        } else {
285            Err(res.json::<SpotifyError>().await?.into())
286        }
287    }
288
289    pub(crate) async fn get<P: Serialize + Debug, T: DeserializeOwned>(
290        &self,
291        endpoint: String,
292        query: impl Into<Option<P>>,
293    ) -> Result<T> {
294        self.request(Method::GET, endpoint, query.into(), None)
295            .await
296    }
297
298    pub(crate) async fn post<P: Serialize + Debug, T: DeserializeOwned>(
299        &self,
300        endpoint: String,
301        body: impl Into<Option<Body<P>>>,
302    ) -> Result<T> {
303        self.request(Method::POST, endpoint, None, body.into())
304            .await
305    }
306
307    pub(crate) async fn put<P: Serialize + Debug, T: DeserializeOwned>(
308        &self,
309        endpoint: String,
310        body: impl Into<Option<Body<P>>>,
311    ) -> Result<T> {
312        self.request(Method::PUT, endpoint, None, body.into()).await
313    }
314
315    pub(crate) async fn delete<P: Serialize + Debug, T: DeserializeOwned>(
316        &self,
317        endpoint: String,
318        body: impl Into<Option<Body<P>>>,
319    ) -> Result<T> {
320        self.request(Method::DELETE, endpoint, None, body.into())
321            .await
322    }
323}
324
325impl AuthCodeClient<Unauthenticated> {
326    /// Create a new client and generate an authorisation URL
327    ///
328    /// You must redirect the user to the returned URL, which in turn redirects them to
329    /// the `redirect_uri` you provided, along with a `code` and `state` parameter in the URl.
330    ///
331    /// They are required for the next step in the auth process.
332    pub fn new<S>(
333        client_id: impl Into<String>,
334        client_secret: impl Into<String>,
335        scopes: S,
336        redirect_uri: RedirectUrl,
337        auto_refresh: bool,
338    ) -> (Self, Url)
339    where
340        S: Into<Scopes>,
341    {
342        let client_id = ClientId::new(client_id.into());
343        let client_secret = Some(ClientSecret::new(client_secret.into()));
344
345        let oauth = OAuthClient::new(
346            client_id,
347            client_secret,
348            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
349            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
350        )
351        .set_redirect_uri(redirect_uri);
352
353        let (auth_url, csrf_token) = oauth
354            .authorize_url(CsrfToken::new_random)
355            .add_scopes(scopes.into().0)
356            .url();
357
358        (
359            Client {
360                auto_refresh,
361                auth_state: Arc::new(RwLock::new(Unauthenticated)),
362                auth_flow: AuthCodeFlow { csrf_token },
363                oauth,
364                http: reqwest::Client::new(),
365            },
366            auth_url,
367        )
368    }
369
370    /// This will exchange the `auth_code` for a token which will allow the client
371    /// to make requests.
372    ///
373    /// `csrf_state` is used for CSRF protection.
374    pub async fn authenticate(
375        self,
376        auth_code: impl Into<String>,
377        csrf_state: impl AsRef<str>,
378    ) -> Result<Client<Token, AuthCodeFlow>> {
379        let auth_code = auth_code.into().trim().to_owned();
380        let csrf_state = csrf_state.as_ref().trim();
381
382        if csrf_state != self.auth_flow.csrf_token.secret() {
383            return Err(Error::InvalidStateParameter);
384        }
385
386        let token = self
387            .oauth
388            .exchange_code(AuthorizationCode::new(auth_code))
389            .request_async(async_http_client)
390            .await?
391            .set_timestamps();
392
393        Ok(Client {
394            auto_refresh: self.auto_refresh,
395            auth_state: Arc::new(RwLock::new(token)),
396            auth_flow: self.auth_flow,
397            oauth: self.oauth,
398            http: self.http,
399        })
400    }
401}
402
403impl AuthCodePkceClient<Unauthenticated> {
404    /// Create a new client and generate an authorisation URL
405    ///
406    /// You must redirect the user to the received URL, which in turn redirects them to
407    /// the redirect URI you provided, along with a `code` and `state` parameter in the URl.
408    ///
409    /// They are required for the next step in the auth process.
410    pub fn new<T, S>(
411        client_id: T,
412        scopes: S,
413        redirect_uri: RedirectUrl,
414        auto_refresh: bool,
415    ) -> (Self, Url)
416    where
417        T: Into<String>,
418        S: Into<Scopes>,
419    {
420        let client_id = ClientId::new(client_id.into());
421
422        let oauth = OAuthClient::new(
423            client_id,
424            None,
425            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
426            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
427        )
428        .set_redirect_uri(redirect_uri);
429
430        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
431
432        let (auth_url, csrf_token) = oauth
433            .authorize_url(CsrfToken::new_random)
434            .add_scopes(scopes.into().0)
435            .set_pkce_challenge(pkce_challenge)
436            .url();
437
438        (
439            Client {
440                auto_refresh,
441                auth_state: Arc::new(RwLock::new(Unauthenticated)),
442                auth_flow: AuthCodePkceFlow {
443                    csrf_token,
444                    pkce_verifier: Some(pkce_verifier),
445                },
446                oauth,
447                http: reqwest::Client::new(),
448            },
449            auth_url,
450        )
451    }
452
453    /// This will exchange the `auth_code` for a token which will allow the client
454    /// to make requests.
455    ///
456    /// `csrf_state` is used for CSRF protection.
457    pub async fn authenticate(
458        mut self,
459        auth_code: impl Into<String>,
460        csrf_state: impl AsRef<str>,
461    ) -> Result<Client<Token, AuthCodePkceFlow>> {
462        let auth_code = auth_code.into().trim().to_owned();
463        let csrf_state = csrf_state.as_ref().trim();
464
465        if csrf_state != self.auth_flow.csrf_token.secret() {
466            return Err(Error::InvalidStateParameter);
467        }
468
469        let Some(pkce_verifier) = self.auth_flow.pkce_verifier.take() else {
470            // This should never be reached realistically, but an error
471            // will be thrown and log issued just in case.
472            tracing::error!(client = ?self, "No PKCE code verifier present when authenticating the client.");
473            return Err(Error::InvalidClientState);
474        };
475
476        let token = self
477            .oauth
478            .exchange_code(AuthorizationCode::new(auth_code))
479            .set_pkce_verifier(pkce_verifier)
480            .request_async(async_http_client)
481            .await?
482            .set_timestamps();
483
484        Ok(Client {
485            auto_refresh: self.auto_refresh,
486            auth_state: Arc::new(RwLock::new(token)),
487            auth_flow: self.auth_flow,
488            oauth: self.oauth,
489            http: self.http,
490        })
491    }
492}
493
494impl ClientCredsClient<Unauthenticated> {
495    /// This will exchange the client credentials for an access token used
496    /// to make requests.
497    ///
498    /// This authentication method doesn't allow for token refreshing or to access
499    /// user resources.
500    pub async fn authenticate(
501        client_id: impl Into<String>,
502        client_secret: impl Into<String>,
503    ) -> Result<ClientCredsClient<Token>> {
504        let client_id = ClientId::new(client_id.into());
505        let client_secret = Some(ClientSecret::new(client_secret.into()));
506
507        let oauth = OAuthClient::new(
508            client_id,
509            client_secret,
510            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
511            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
512        );
513
514        let token = oauth
515            .exchange_client_credentials()
516            .request_async(async_http_client)
517            .await?
518            .set_timestamps();
519
520        Ok(Client {
521            auto_refresh: false,
522            auth_state: Arc::new(RwLock::new(token)),
523            auth_flow: ClientCredsFlow,
524            oauth,
525            http: reqwest::Client::new(),
526        })
527    }
528}
529
530impl AuthCodeClient<Token> {
531    /// Create a new authenticated client from an access token.
532    /// This client will be able to access user data.
533    ///
534    /// This method will fail if the access token is invalid (a request will
535    /// be sent to check the token).
536    pub async fn from_access_token(
537        client_id: impl Into<String>,
538        client_secret: impl Into<String>,
539        auto_refresh: bool,
540        token: Token,
541    ) -> Result<Self> {
542        let client_id = ClientId::new(client_id.into());
543        // client_secret.map(|s| ClientSecret::new(s.to_owned()));
544        let client_secret = Some(ClientSecret::new(client_secret.into()));
545
546        let oauth_client = OAuthClient::new(
547            client_id,
548            client_secret,
549            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
550            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
551        );
552
553        let http = reqwest::Client::new();
554
555        // This is just a bogus request to check if the token is valid.
556        let res = http
557            .get(format!("{API_URL}/recommendations/available-genre-seeds"))
558            .bearer_auth(token.secret())
559            .header(CONTENT_LENGTH, 0)
560            .send()
561            .await?;
562
563        if !res.status().is_success() {
564            return Err(res.json::<SpotifyError>().await?.into());
565        }
566
567        let auth_flow = AuthCodeFlow {
568            csrf_token: CsrfToken::new("not needed".to_owned()),
569        };
570
571        let auto_refresh = auto_refresh && token.refresh_token.is_some();
572
573        Ok(Self {
574            auto_refresh,
575            auth_state: Arc::new(RwLock::new(token)),
576            auth_flow,
577            oauth: oauth_client,
578            http,
579        })
580    }
581}
582
583impl AuthCodePkceClient<Token> {
584    /// Create a new authenticated client from an access token.
585    /// This client will be able to access user data.
586    ///
587    /// This method will fail if the access token is invalid (a request will
588    /// be sent to check the token).
589    pub async fn from_access_token(
590        client_id: impl Into<String>,
591        auto_refresh: bool,
592        token: Token,
593    ) -> Result<Self> {
594        let client_id = ClientId::new(client_id.into());
595
596        let oauth_client = OAuthClient::new(
597            client_id,
598            None,
599            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
600            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
601        );
602
603        let http = reqwest::Client::new();
604
605        // This is just a bogus request to check if the token is valid.
606        let res = http
607            .get(format!("{API_URL}/recommendations/available-genre-seeds"))
608            .bearer_auth(token.secret())
609            .header(CONTENT_LENGTH, 0)
610            .send()
611            .await?;
612
613        if !res.status().is_success() {
614            return Err(res.json::<SpotifyError>().await?.into());
615        }
616
617        let auth_flow = AuthCodePkceFlow {
618            csrf_token: CsrfToken::new("not needed".to_owned()),
619            pkce_verifier: None,
620        };
621
622        let auto_refresh = auto_refresh && token.refresh_token.is_some();
623
624        Ok(Self {
625            auto_refresh,
626            auth_state: Arc::new(RwLock::new(token)),
627            auth_flow,
628            oauth: oauth_client,
629            http,
630        })
631    }
632}
633
634impl ClientCredsClient<Token> {
635    /// Create a new authenticated client from an access token.
636    /// This client will not be able to access user data.
637    ///
638    /// This method will fail if the access token is invalid (a request will
639    /// be sent to check the token).
640    pub async fn from_access_token(
641        client_id: impl Into<String>,
642        client_secret: impl Into<String>,
643        token: Token,
644    ) -> Result<Self> {
645        let client_id = ClientId::new(client_id.into());
646        let client_secret = Some(ClientSecret::new(client_secret.into()));
647
648        let oauth_client = OAuthClient::new(
649            client_id,
650            client_secret,
651            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
652            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
653        );
654
655        let http = reqwest::Client::new();
656
657        // This is just a bogus request to check if the token is valid.
658        let res = http
659            .get(format!("{API_URL}/recommendations/available-genre-seeds"))
660            .bearer_auth(token.secret())
661            .header(CONTENT_LENGTH, 0)
662            .send()
663            .await?;
664
665        if !res.status().is_success() {
666            return Err(res.json::<SpotifyError>().await?.into());
667        }
668
669        Ok(Self {
670            auto_refresh: false,
671            auth_state: Arc::new(RwLock::new(token)),
672            auth_flow: ClientCredsFlow,
673            oauth: oauth_client,
674            http,
675        })
676    }
677}