// spotify_cli/oauth/callback_server.rs
use std::time::{Duration, Instant};

use thiserror::Error;
use tiny_http::{Response, Server};
use url::Url;

use crate::constants::{DEFAULT_OAUTH_PORT, OAUTH_CALLBACK_PATH, OAUTH_CALLBACK_TIMEOUT_SECS};
12
13pub const DEFAULT_PORT: u16 = DEFAULT_OAUTH_PORT;
15pub const CALLBACK_PATH: &str = OAUTH_CALLBACK_PATH;
17
18#[derive(Debug, Error)]
19pub enum CallbackError {
20 #[error("Failed to start server: {0}")]
21 ServerStart(String),
22
23 #[error("Timeout waiting for callback")]
24 Timeout,
25
26 #[error("Missing authorization code")]
27 MissingCode,
28
29 #[error("Authorization denied: {0}")]
30 Denied(String),
31
32 #[error("Invalid callback request")]
33 InvalidRequest,
34}
35
/// Configuration for a one-shot HTTP listener on `127.0.0.1` that receives the
/// OAuth authorization redirect.
///
/// Nothing is bound until [`CallbackServer::wait_for_callback`] is called.
#[derive(Debug, Clone)]
pub struct CallbackServer {
    /// Loopback port to bind; also embedded in the redirect URI.
    port: u16,
    /// How long `wait_for_callback` waits for the redirect request.
    timeout: Duration,
}
41
/// Query parameters extracted from a successful OAuth redirect.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallbackResult {
    /// Value of the redirect's `code` query parameter (the authorization code).
    pub code: String,
    /// Value of the redirect's `state` query parameter, if present. It is passed through
    /// unchecked; callers should compare it against the value they sent.
    pub state: Option<String>,
}
49
50impl CallbackServer {
51 pub fn new(port: u16) -> Self {
55 Self {
56 port,
57 timeout: Duration::from_secs(OAUTH_CALLBACK_TIMEOUT_SECS),
58 }
59 }
60
61 pub fn with_timeout(mut self, timeout: Duration) -> Self {
63 self.timeout = timeout;
64 self
65 }
66
67 pub fn redirect_uri(&self) -> String {
69 format!("http://127.0.0.1:{}{}", self.port, CALLBACK_PATH)
70 }
71
72 pub fn wait_for_callback(self) -> Result<CallbackResult, CallbackError> {
76 let addr = format!("127.0.0.1:{}", self.port);
77 let server = Server::http(&addr).map_err(|e| CallbackError::ServerStart(e.to_string()))?;
78
79 loop {
80 let request = match server.recv_timeout(self.timeout) {
81 Ok(Some(req)) => req,
82 Ok(None) => return Err(CallbackError::Timeout),
83 Err(_) => return Err(CallbackError::Timeout),
84 };
85
86 let url_str = format!("http://127.0.0.1{}", request.url());
87 let url = Url::parse(&url_str).map_err(|_| CallbackError::InvalidRequest)?;
88
89 if url.path() != CALLBACK_PATH {
90 let response = Response::from_string("Not found").with_status_code(404);
91 let _ = request.respond(response);
92 continue;
93 }
94
95 let params: std::collections::HashMap<_, _> = url.query_pairs().collect();
96
97 if let Some(error) = params.get("error") {
98 let description = params
99 .get("error_description")
100 .map(|s| s.to_string())
101 .unwrap_or_else(|| error.to_string());
102
103 let response = Response::from_string(error_html(&description)).with_header(
104 "Content-Type: text/html; charset=utf-8"
105 .parse::<tiny_http::Header>()
106 .expect("static header string is valid"),
107 );
108 let _ = request.respond(response);
109
110 return Err(CallbackError::Denied(description));
111 }
112
113 let code = params
114 .get("code")
115 .map(|s| s.to_string())
116 .ok_or(CallbackError::MissingCode)?;
117
118 let state = params.get("state").map(|s| s.to_string());
119
120 let response = Response::from_string(success_html()).with_header(
121 "Content-Type: text/html; charset=utf-8"
122 .parse::<tiny_http::Header>()
123 .expect("static header string is valid"),
124 );
125 let _ = request.respond(response);
126
127 return Ok(CallbackResult { code, state });
128 }
129 }
130}
131
/// Builds the static page shown in the browser after the authorization code is received.
fn success_html() -> String {
    const PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>spotify-cli</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: #191414;
            color: #1DB954;
        }
        .container { text-align: center; }
        h1 { font-size: 2rem; margin-bottom: 1rem; }
        p { color: #b3b3b3; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Authenticated!</h1>
        <p>You can close this window and return to your terminal.</p>
    </div>
</body>
</html>"#;

    String::from(PAGE)
}
162
/// Builds the page shown in the browser when authorization fails.
///
/// `message` may come straight from the redirect's `error_description` query parameter,
/// which is attacker-controllable. It is therefore HTML-escaped before interpolation, so
/// a crafted redirect cannot inject markup or script into the page.
fn error_html(message: &str) -> String {
    format!(
        r#"<!DOCTYPE html>
<html>
<head>
    <title>spotify-cli - Error</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: #191414;
            color: #e22134;
        }}
        .container {{ text-align: center; }}
        h1 {{ font-size: 2rem; margin-bottom: 1rem; }}
        p {{ color: #b3b3b3; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>Authentication Failed</h1>
        <p>{message}</p>
    </div>
</body>
</html>"#,
        message = escape_html(message)
    )
}

/// Escapes the HTML-significant characters (`& < > " '`) so `text` renders literally.
fn escape_html(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            _ => escaped.push(ch),
        }
    }
    escaped
}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn redirect_uri_uses_correct_format() {
        let server = CallbackServer::new(8888);
        assert_eq!(server.redirect_uri(), "http://127.0.0.1:8888/callback");
    }

    #[test]
    fn can_customize_port() {
        let server = CallbackServer::new(9999);
        assert_eq!(server.redirect_uri(), "http://127.0.0.1:9999/callback");
    }

    #[test]
    fn with_timeout_sets_custom_timeout() {
        let server = CallbackServer::new(8888).with_timeout(Duration::from_secs(60));
        assert_eq!(server.timeout, Duration::from_secs(60));
    }

    // Previously named `default_timeout_is_five_minutes`, but it only compares against
    // the configured constant and never checks a five-minute value.
    #[test]
    fn default_timeout_uses_configured_constant() {
        let server = CallbackServer::new(8888);
        assert_eq!(
            server.timeout,
            Duration::from_secs(OAUTH_CALLBACK_TIMEOUT_SECS)
        );
    }

    #[test]
    fn callback_error_display_server_start() {
        let err = CallbackError::ServerStart("port in use".to_string());
        let display = format!("{}", err);
        assert!(display.contains("server"));
        assert!(display.contains("port in use"));
    }

    #[test]
    fn callback_error_display_timeout() {
        let err = CallbackError::Timeout;
        let display = format!("{}", err);
        assert!(display.contains("Timeout"));
    }

    #[test]
    fn callback_error_display_missing_code() {
        let err = CallbackError::MissingCode;
        let display = format!("{}", err);
        assert!(display.contains("authorization code"));
    }

    #[test]
    fn callback_error_display_denied() {
        let err = CallbackError::Denied("access_denied".to_string());
        let display = format!("{}", err);
        assert!(display.contains("denied"));
        assert!(display.contains("access_denied"));
    }

    #[test]
    fn callback_error_display_invalid_request() {
        let err = CallbackError::InvalidRequest;
        let display = format!("{}", err);
        assert!(display.contains("Invalid"));
    }

    #[test]
    fn success_html_contains_authenticated() {
        let html = success_html();
        assert!(html.contains("Authenticated"));
        assert!(html.contains("html"));
        assert!(html.contains("spotify-cli"));
    }

    #[test]
    fn error_html_contains_message() {
        let html = error_html("Test error message");
        assert!(html.contains("Test error message"));
        assert!(html.contains("Authentication Failed"));
        assert!(html.contains("html"));
    }

    #[test]
    fn callback_result_stores_code_and_state() {
        let result = CallbackResult {
            code: "test_code".to_string(),
            state: Some("test_state".to_string()),
        };
        assert_eq!(result.code, "test_code");
        assert_eq!(result.state, Some("test_state".to_string()));
    }

    #[test]
    fn callback_result_state_can_be_none() {
        let result = CallbackResult {
            code: "code".to_string(),
            state: None,
        };
        assert!(result.state.is_none());
    }

    #[test]
    fn default_port_constant() {
        assert_eq!(DEFAULT_PORT, 8888);
    }

    #[test]
    fn callback_path_constant() {
        assert_eq!(CALLBACK_PATH, "/callback");
    }

    #[test]
    fn chained_with_timeout() {
        let server = CallbackServer::new(8080).with_timeout(Duration::from_secs(120));
        assert_eq!(server.port, 8080);
        assert_eq!(server.timeout, Duration::from_secs(120));
    }

    #[test]
    fn wait_for_callback_times_out_without_requests() {
        // Port 0 lets the OS pick a free ephemeral port that nothing will connect to.
        let server = CallbackServer::new(0).with_timeout(Duration::from_millis(50));
        assert!(matches!(
            server.wait_for_callback(),
            Err(CallbackError::Timeout)
        ));
    }
}