spotify_cli/oauth/
flow.rs

1//! OAuth 2.0 Authorization Code flow with PKCE.
2//!
3//! Orchestrates the full authentication flow: browser authorization, callback handling,
4//! and token exchange.
5
6use thiserror::Error;
7use url::Url;
8
9use super::callback_server::{CallbackError, CallbackResult, CallbackServer, DEFAULT_PORT};
10use super::pkce::PkceChallenge;
11use super::token::{SpotifyTokenResponse, Token};
12use crate::http::auth::SpotifyAuth;
13
/// Path of Spotify's authorization endpoint; `SpotifyAuth::url` turns it into
/// the absolute URL that the user's browser is sent to.
const AUTHORIZE_ENDPOINT: &str = "/authorize";
15
16#[derive(Debug, Error)]
17pub enum OAuthError {
18    #[error("Callback error: {0}")]
19    Callback(#[from] CallbackError),
20
21    #[error("Auth error: {0}")]
22    Auth(#[from] crate::http::auth::AuthError),
23
24    #[error("Failed to open browser: {0}")]
25    Browser(String),
26
27    #[error("Failed to parse token response")]
28    TokenParse,
29}
30
/// OAuth flow configuration and execution.
///
/// Handles the complete OAuth 2.0 Authorization Code flow with PKCE: create it
/// with `OAuthFlow::new`, optionally adjust it with `with_scopes` /
/// `with_port`, then call `authenticate` (or `refresh` with an existing
/// refresh token).
#[derive(Debug, Clone)]
pub struct OAuthFlow {
    /// Spotify application client ID, sent with the authorization, code
    /// exchange and refresh requests.
    client_id: String,
    /// Redirect URI sent to Spotify; the constructors keep it pointing at
    /// `http://127.0.0.1:<port>/callback`.
    redirect_uri: String,
    /// Scopes requested from the user, sent space-separated in the auth URL.
    scopes: Vec<String>,
    /// Local port the callback server listens on.
    port: u16,
}
40
41impl OAuthFlow {
42    /// Create a new OAuth flow with the given Spotify client ID.
43    ///
44    /// Uses default scopes and port 8888 for the callback server.
45    pub fn new(client_id: String) -> Self {
46        let port = DEFAULT_PORT;
47        let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
48
49        Self {
50            client_id,
51            redirect_uri,
52            scopes: default_scopes(),
53            port,
54        }
55    }
56
57    /// Override the default scopes.
58    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
59        self.scopes = scopes;
60        self
61    }
62
63    /// Override the default callback port.
64    pub fn with_port(mut self, port: u16) -> Self {
65        self.port = port;
66        self.redirect_uri = format!("http://127.0.0.1:{}/callback", port);
67        self
68    }
69
70    /// Execute the full OAuth flow.
71    ///
72    /// 1. Generates PKCE challenge
73    /// 2. Opens browser to Spotify authorization page
74    /// 3. Waits for callback with authorization code
75    /// 4. Exchanges code for tokens
76    pub async fn authenticate(&self) -> Result<Token, OAuthError> {
77        let pkce = PkceChallenge::generate();
78
79        let auth_url = self.build_auth_url(&pkce);
80
81        open_browser(&auth_url)?;
82
83        let callback_result = self.wait_for_callback()?;
84
85        let token = self
86            .exchange_code(&callback_result.code, &pkce.verifier)
87            .await?;
88
89        Ok(token)
90    }
91
92    /// Refresh an expired access token using a refresh token.
93    pub async fn refresh(&self, refresh_token: &str) -> Result<Token, OAuthError> {
94        let auth = SpotifyAuth::new();
95
96        let response = auth.refresh_token(&self.client_id, refresh_token).await?;
97
98        let token_response: SpotifyTokenResponse =
99            serde_json::from_value(response).map_err(|_| OAuthError::TokenParse)?;
100
101        Ok(Token::from_response(token_response))
102    }
103
104    fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
105        let mut url = Url::parse(&SpotifyAuth::url(AUTHORIZE_ENDPOINT))
106            .expect("AUTHORIZE_ENDPOINT is a valid URL");
107
108        url.query_pairs_mut()
109            .append_pair("client_id", &self.client_id)
110            .append_pair("response_type", "code")
111            .append_pair("redirect_uri", &self.redirect_uri)
112            .append_pair("scope", &self.scopes.join(" "))
113            .append_pair("code_challenge_method", "S256")
114            .append_pair("code_challenge", &pkce.challenge);
115
116        url.to_string()
117    }
118
119    fn wait_for_callback(&self) -> Result<CallbackResult, OAuthError> {
120        let server = CallbackServer::new(self.port);
121        let result = server.wait_for_callback()?;
122        Ok(result)
123    }
124
125    async fn exchange_code(&self, code: &str, verifier: &str) -> Result<Token, OAuthError> {
126        let auth = SpotifyAuth::new();
127
128        let response = auth
129            .exchange_code(&self.client_id, code, &self.redirect_uri, verifier)
130            .await?;
131
132        let token_response: SpotifyTokenResponse =
133            serde_json::from_value(response).map_err(|_| OAuthError::TokenParse)?;
134
135        Ok(Token::from_response(token_response))
136    }
137}
138
/// Scopes requested when the caller does not override them via
/// `OAuthFlow::with_scopes`.
///
/// Covers playback state and control, the user's library, playlists, profile,
/// top items, recently played tracks, and follows. The order below is the
/// order of the returned vector.
fn default_scopes() -> Vec<String> {
    const SCOPES: [&str; 15] = [
        "user-read-playback-state",
        "user-modify-playback-state",
        "user-read-currently-playing",
        "user-library-read",
        "user-library-modify",
        "playlist-read-private",
        "playlist-read-collaborative",
        "playlist-modify-private",
        "playlist-modify-public",
        "user-read-private",
        "user-read-email",
        "user-top-read",
        "user-read-recently-played",
        "user-follow-read",
        "user-follow-modify",
    ];

    SCOPES.iter().copied().map(String::from).collect()
}
158
159fn open_browser(url: &str) -> Result<(), OAuthError> {
160    #[cfg(target_os = "macos")]
161    {
162        std::process::Command::new("open")
163            .arg(url)
164            .spawn()
165            .map_err(|e| OAuthError::Browser(e.to_string()))?;
166    }
167
168    #[cfg(target_os = "linux")]
169    {
170        std::process::Command::new("xdg-open")
171            .arg(url)
172            .spawn()
173            .map_err(|e| OAuthError::Browser(e.to_string()))?;
174    }
175
176    #[cfg(target_os = "windows")]
177    {
178        std::process::Command::new("cmd")
179            .args(["/C", "start", "", url])
180            .spawn()
181            .map_err(|e| OAuthError::Browser(e.to_string()))?;
182    }
183
184    Ok(())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Parse an authorization URL and return its percent-decoded query parameters.
    fn query_params(auth_url: &str) -> HashMap<String, String> {
        Url::parse(auth_url)
            .expect("build_auth_url should produce an absolute URL")
            .query_pairs()
            .into_owned()
            .collect()
    }

    #[test]
    fn oauth_flow_new_creates_with_defaults() {
        let flow = OAuthFlow::new("test_client_id".to_string());
        assert_eq!(flow.client_id, "test_client_id");
        assert_eq!(flow.port, DEFAULT_PORT);
        assert_eq!(
            flow.redirect_uri,
            format!("http://127.0.0.1:{}/callback", DEFAULT_PORT)
        );
        assert_eq!(flow.scopes, default_scopes());
    }

    #[test]
    fn oauth_flow_with_scopes() {
        let flow = OAuthFlow::new("client".to_string())
            .with_scopes(vec!["scope1".to_string(), "scope2".to_string()]);
        assert_eq!(flow.scopes.len(), 2);
        assert!(flow.scopes.contains(&"scope1".to_string()));
        assert!(flow.scopes.contains(&"scope2".to_string()));
    }

    #[test]
    fn oauth_flow_with_port() {
        let flow = OAuthFlow::new("client".to_string()).with_port(9999);
        assert_eq!(flow.port, 9999);
        assert!(flow.redirect_uri.contains("9999"));
    }

    #[test]
    fn oauth_flow_port_updates_redirect_uri() {
        let flow = OAuthFlow::new("client".to_string()).with_port(3000);
        assert_eq!(flow.redirect_uri, "http://127.0.0.1:3000/callback");
    }

    #[test]
    fn default_scopes_contains_required_scopes() {
        let scopes = default_scopes();
        assert!(scopes.contains(&"user-read-playback-state".to_string()));
        assert!(scopes.contains(&"user-modify-playback-state".to_string()));
        assert!(scopes.contains(&"user-library-read".to_string()));
        assert!(scopes.contains(&"user-library-modify".to_string()));
        assert!(scopes.contains(&"playlist-read-private".to_string()));
        assert!(scopes.contains(&"user-read-private".to_string()));
    }

    #[test]
    fn default_scopes_count() {
        let scopes = default_scopes();
        assert_eq!(scopes.len(), 15);
    }

    #[test]
    fn default_scopes_are_unique() {
        let scopes = default_scopes();
        let unique: HashSet<&str> = scopes.iter().map(String::as_str).collect();
        assert_eq!(unique.len(), scopes.len());
    }

    #[test]
    fn oauth_error_display_callback() {
        let err = OAuthError::Callback(CallbackError::Timeout);
        let display = format!("{}", err);
        assert!(display.contains("Callback"));
    }

    #[test]
    fn oauth_error_display_browser() {
        let err = OAuthError::Browser("failed to open".to_string());
        let display = format!("{}", err);
        assert!(display.contains("browser"));
        assert!(display.contains("failed to open"));
    }

    #[test]
    fn oauth_error_display_token_parse() {
        let err = OAuthError::TokenParse;
        let display = format!("{}", err);
        assert!(display.contains("token"));
    }

    #[test]
    fn oauth_error_from_callback_error() {
        let callback_err = CallbackError::Timeout;
        let oauth_err: OAuthError = callback_err.into();
        match oauth_err {
            OAuthError::Callback(_) => {}
            _ => panic!("Expected Callback variant"),
        }
    }

    #[test]
    fn build_auth_url_contains_required_params() {
        let flow = OAuthFlow::new("test_client".to_string());
        let pkce = PkceChallenge::generate();
        let url = flow.build_auth_url(&pkce);

        assert!(url.contains("client_id=test_client"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("scope="));
    }

    #[test]
    fn build_auth_url_round_trips_parameter_values() {
        let flow = OAuthFlow::new("client".to_string())
            .with_port(4321)
            .with_scopes(vec!["scope-a".to_string(), "scope-b".to_string()]);
        let pkce = PkceChallenge::generate();
        let params = query_params(&flow.build_auth_url(&pkce));

        assert_eq!(params.get("client_id").map(String::as_str), Some("client"));
        assert_eq!(params.get("response_type").map(String::as_str), Some("code"));
        assert_eq!(
            params.get("code_challenge_method").map(String::as_str),
            Some("S256")
        );
        assert_eq!(
            params.get("redirect_uri").map(String::as_str),
            Some("http://127.0.0.1:4321/callback")
        );
        // Scopes travel as one space-separated value, not repeated parameters.
        assert_eq!(
            params.get("scope").map(String::as_str),
            Some("scope-a scope-b")
        );
    }

    #[test]
    fn build_auth_url_includes_pkce_challenge() {
        let flow = OAuthFlow::new("client".to_string());
        let pkce = PkceChallenge::generate();
        let params = query_params(&flow.build_auth_url(&pkce));

        let expected: &str = &pkce.challenge;
        assert_eq!(
            params.get("code_challenge").map(String::as_str),
            Some(expected)
        );
    }

    #[test]
    fn oauth_flow_chaining_works() {
        let flow = OAuthFlow::new("client".to_string())
            .with_port(5000)
            .with_scopes(vec!["scope1".to_string()]);

        assert_eq!(flow.port, 5000);
        assert_eq!(flow.scopes.len(), 1);
    }
}