spotify_cli/oauth/
flow.rs1use thiserror::Error;
7use url::Url;
8
9use super::callback_server::{CallbackError, CallbackResult, CallbackServer, DEFAULT_PORT};
10use super::pkce::PkceChallenge;
11use super::token::{SpotifyTokenResponse, Token};
12use crate::http::auth::SpotifyAuth;
13
/// Path of the authorization endpoint, turned into an absolute URL by `SpotifyAuth::url`
/// when the authorize link is built.
const AUTHORIZE_ENDPOINT: &str = "/authorize";
15
16#[derive(Debug, Error)]
17pub enum OAuthError {
18 #[error("Callback error: {0}")]
19 Callback(#[from] CallbackError),
20
21 #[error("Auth error: {0}")]
22 Auth(#[from] crate::http::auth::AuthError),
23
24 #[error("Failed to open browser: {0}")]
25 Browser(String),
26
27 #[error("Failed to parse token response")]
28 TokenParse,
29}
30
/// Settings for Spotify's authorization-code flow with PKCE, completed through a local
/// loopback callback server.
///
/// Construct with [`OAuthFlow::new`], then optionally adjust with
/// [`OAuthFlow::with_scopes`] and [`OAuthFlow::with_port`].
#[derive(Debug, Clone)]
pub struct OAuthFlow {
    /// Application client ID sent with the authorize URL and with token requests.
    client_id: String,
    /// Loopback redirect URI (`http://127.0.0.1:<port>/callback`); kept in sync with `port`.
    redirect_uri: String,
    /// Scopes requested; joined with spaces in the authorize URL.
    scopes: Vec<String>,
    /// Local port the callback server listens on.
    port: u16,
}
40
41impl OAuthFlow {
42 pub fn new(client_id: String) -> Self {
46 let port = DEFAULT_PORT;
47 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
48
49 Self {
50 client_id,
51 redirect_uri,
52 scopes: default_scopes(),
53 port,
54 }
55 }
56
57 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
59 self.scopes = scopes;
60 self
61 }
62
63 pub fn with_port(mut self, port: u16) -> Self {
65 self.port = port;
66 self.redirect_uri = format!("http://127.0.0.1:{}/callback", port);
67 self
68 }
69
70 pub async fn authenticate(&self) -> Result<Token, OAuthError> {
77 let pkce = PkceChallenge::generate();
78
79 let auth_url = self.build_auth_url(&pkce);
80
81 open_browser(&auth_url)?;
82
83 let callback_result = self.wait_for_callback()?;
84
85 let token = self
86 .exchange_code(&callback_result.code, &pkce.verifier)
87 .await?;
88
89 Ok(token)
90 }
91
92 pub async fn refresh(&self, refresh_token: &str) -> Result<Token, OAuthError> {
94 let auth = SpotifyAuth::new();
95
96 let response = auth.refresh_token(&self.client_id, refresh_token).await?;
97
98 let token_response: SpotifyTokenResponse =
99 serde_json::from_value(response).map_err(|_| OAuthError::TokenParse)?;
100
101 Ok(Token::from_response(token_response))
102 }
103
104 fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
105 let mut url = Url::parse(&SpotifyAuth::url(AUTHORIZE_ENDPOINT))
106 .expect("AUTHORIZE_ENDPOINT is a valid URL");
107
108 url.query_pairs_mut()
109 .append_pair("client_id", &self.client_id)
110 .append_pair("response_type", "code")
111 .append_pair("redirect_uri", &self.redirect_uri)
112 .append_pair("scope", &self.scopes.join(" "))
113 .append_pair("code_challenge_method", "S256")
114 .append_pair("code_challenge", &pkce.challenge);
115
116 url.to_string()
117 }
118
119 fn wait_for_callback(&self) -> Result<CallbackResult, OAuthError> {
120 let server = CallbackServer::new(self.port);
121 let result = server.wait_for_callback()?;
122 Ok(result)
123 }
124
125 async fn exchange_code(&self, code: &str, verifier: &str) -> Result<Token, OAuthError> {
126 let auth = SpotifyAuth::new();
127
128 let response = auth
129 .exchange_code(&self.client_id, code, &self.redirect_uri, verifier)
130 .await?;
131
132 let token_response: SpotifyTokenResponse =
133 serde_json::from_value(response).map_err(|_| OAuthError::TokenParse)?;
134
135 Ok(Token::from_response(token_response))
136 }
137}
138
/// Scopes requested when the caller does not supply its own via `OAuthFlow::with_scopes`.
///
/// Covers playback state/control, library, playlists, profile, top and recently played
/// items, and follow management.
fn default_scopes() -> Vec<String> {
    const SCOPES: [&str; 15] = [
        "user-read-playback-state",
        "user-modify-playback-state",
        "user-read-currently-playing",
        "user-library-read",
        "user-library-modify",
        "playlist-read-private",
        "playlist-read-collaborative",
        "playlist-modify-private",
        "playlist-modify-public",
        "user-read-private",
        "user-read-email",
        "user-top-read",
        "user-read-recently-played",
        "user-follow-read",
        "user-follow-modify",
    ];

    SCOPES.iter().map(|&scope| String::from(scope)).collect()
}
158
159fn open_browser(url: &str) -> Result<(), OAuthError> {
160 #[cfg(target_os = "macos")]
161 {
162 std::process::Command::new("open")
163 .arg(url)
164 .spawn()
165 .map_err(|e| OAuthError::Browser(e.to_string()))?;
166 }
167
168 #[cfg(target_os = "linux")]
169 {
170 std::process::Command::new("xdg-open")
171 .arg(url)
172 .spawn()
173 .map_err(|e| OAuthError::Browser(e.to_string()))?;
174 }
175
176 #[cfg(target_os = "windows")]
177 {
178 std::process::Command::new("cmd")
179 .args(["/C", "start", "", url])
180 .spawn()
181 .map_err(|e| OAuthError::Browser(e.to_string()))?;
182 }
183
184 Ok(())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds `flow`'s authorize URL for `pkce` and returns its decoded query parameters,
    /// so assertions compare exact values instead of substrings.
    fn auth_url_params(flow: &OAuthFlow, pkce: &PkceChallenge) -> HashMap<String, String> {
        let url = Url::parse(&flow.build_auth_url(pkce)).expect("authorize URL must parse");
        url.query_pairs().into_owned().collect()
    }

    #[test]
    fn oauth_flow_new_creates_with_defaults() {
        let flow = OAuthFlow::new("test_client_id".to_string());
        assert_eq!(flow.client_id, "test_client_id");
        assert_eq!(flow.port, DEFAULT_PORT);
        assert_eq!(
            flow.redirect_uri,
            format!("http://127.0.0.1:{}/callback", DEFAULT_PORT)
        );
        assert_eq!(flow.scopes, default_scopes());
    }

    #[test]
    fn oauth_flow_with_scopes() {
        let flow = OAuthFlow::new("client".to_string())
            .with_scopes(vec!["scope1".to_string(), "scope2".to_string()]);
        assert_eq!(flow.scopes.len(), 2);
        assert!(flow.scopes.contains(&"scope1".to_string()));
        assert!(flow.scopes.contains(&"scope2".to_string()));
    }

    #[test]
    fn oauth_flow_with_port() {
        let flow = OAuthFlow::new("client".to_string()).with_port(9999);
        assert_eq!(flow.port, 9999);
        assert!(flow.redirect_uri.contains("9999"));
    }

    #[test]
    fn oauth_flow_port_updates_redirect_uri() {
        let flow = OAuthFlow::new("client".to_string()).with_port(3000);
        assert_eq!(flow.redirect_uri, "http://127.0.0.1:3000/callback");
    }

    #[test]
    fn with_port_preserves_previously_set_scopes() {
        let flow = OAuthFlow::new("client".to_string())
            .with_scopes(vec!["scope1".to_string()])
            .with_port(5000);
        assert_eq!(flow.scopes, vec!["scope1".to_string()]);
        assert_eq!(flow.redirect_uri, "http://127.0.0.1:5000/callback");
    }

    #[test]
    fn default_scopes_contains_required_scopes() {
        let scopes = default_scopes();
        assert!(scopes.contains(&"user-read-playback-state".to_string()));
        assert!(scopes.contains(&"user-modify-playback-state".to_string()));
        assert!(scopes.contains(&"user-library-read".to_string()));
        assert!(scopes.contains(&"user-library-modify".to_string()));
        assert!(scopes.contains(&"playlist-read-private".to_string()));
        assert!(scopes.contains(&"user-read-private".to_string()));
    }

    #[test]
    fn default_scopes_count() {
        let scopes = default_scopes();
        assert_eq!(scopes.len(), 15);
    }

    #[test]
    fn oauth_error_display_callback() {
        let err = OAuthError::Callback(CallbackError::Timeout);
        let display = format!("{}", err);
        assert!(display.contains("Callback"));
    }

    #[test]
    fn oauth_error_display_browser() {
        let err = OAuthError::Browser("failed to open".to_string());
        let display = format!("{}", err);
        assert!(display.contains("browser"));
        assert!(display.contains("failed to open"));
    }

    #[test]
    fn oauth_error_display_token_parse() {
        let err = OAuthError::TokenParse;
        let display = format!("{}", err);
        assert!(display.contains("token"));
    }

    #[test]
    fn oauth_error_from_callback_error() {
        let callback_err = CallbackError::Timeout;
        let oauth_err: OAuthError = callback_err.into();
        match oauth_err {
            OAuthError::Callback(_) => {}
            _ => panic!("Expected Callback variant"),
        }
    }

    #[test]
    fn build_auth_url_contains_required_params() {
        let flow = OAuthFlow::new("test_client".to_string());
        let pkce = PkceChallenge::generate();
        let params = auth_url_params(&flow, &pkce);

        assert_eq!(params["client_id"], "test_client");
        assert_eq!(params["response_type"], "code");
        assert_eq!(params["code_challenge_method"], "S256");
        assert_eq!(params["redirect_uri"], flow.redirect_uri);
        assert_eq!(params["scope"], default_scopes().join(" "));
    }

    #[test]
    fn build_auth_url_includes_pkce_challenge() {
        let flow = OAuthFlow::new("client".to_string());
        let pkce = PkceChallenge::generate();
        let params = auth_url_params(&flow, &pkce);

        assert_eq!(params["code_challenge"], pkce.challenge);
    }

    #[test]
    fn build_auth_url_reflects_custom_port_and_scopes() {
        let flow = OAuthFlow::new("client".to_string())
            .with_port(4242)
            .with_scopes(vec!["scope-a".to_string(), "scope-b".to_string()]);
        let pkce = PkceChallenge::generate();
        let params = auth_url_params(&flow, &pkce);

        assert_eq!(params["redirect_uri"], "http://127.0.0.1:4242/callback");
        assert_eq!(params["scope"], "scope-a scope-b");
    }

    #[test]
    fn oauth_flow_chaining_works() {
        let flow = OAuthFlow::new("client".to_string())
            .with_port(5000)
            .with_scopes(vec!["scope1".to_string()]);

        assert_eq!(flow.port, 5000);
        assert_eq!(flow.scopes.len(), 1);
    }
}