spotify_cli/oauth/
callback_server.rs

1//! Local HTTP server for OAuth callback.
2//!
3//! Starts a temporary HTTP server on localhost to receive the OAuth callback
4//! containing the authorization code after user approval.
5
6use std::time::Duration;
7use thiserror::Error;
8use tiny_http::{Response, Server};
9use url::Url;
10
11use crate::constants::{DEFAULT_OAUTH_PORT, OAUTH_CALLBACK_PATH, OAUTH_CALLBACK_TIMEOUT_SECS};
12
13/// Re-export for backward compatibility.
14pub const DEFAULT_PORT: u16 = DEFAULT_OAUTH_PORT;
15/// Re-export for backward compatibility.
16pub const CALLBACK_PATH: &str = OAUTH_CALLBACK_PATH;
17
18#[derive(Debug, Error)]
19pub enum CallbackError {
20    #[error("Failed to start server: {0}")]
21    ServerStart(String),
22
23    #[error("Timeout waiting for callback")]
24    Timeout,
25
26    #[error("Missing authorization code")]
27    MissingCode,
28
29    #[error("Authorization denied: {0}")]
30    Denied(String),
31
32    #[error("Invalid callback request")]
33    InvalidRequest,
34}
35
/// HTTP server that listens for the OAuth callback.
///
/// Holds only configuration; no socket is opened until
/// `wait_for_callback` is called.
#[derive(Debug, Clone)]
pub struct CallbackServer {
    /// Loopback port the server binds to.
    port: u16,
    /// How long to wait for the callback before giving up.
    timeout: Duration,
}
41
/// Result from a successful OAuth callback.
///
/// Note that the derived `Debug` output includes the authorization code, so
/// avoid writing values of this type to persistent logs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallbackResult {
    /// The authorization code to exchange for tokens.
    pub code: String,
    /// Optional state parameter for CSRF protection, exactly as received in
    /// the callback query string.
    pub state: Option<String>,
}
49
50impl CallbackServer {
51    /// Create a new callback server on the given port.
52    ///
53    /// Default timeout is 5 minutes.
54    pub fn new(port: u16) -> Self {
55        Self {
56            port,
57            timeout: Duration::from_secs(OAUTH_CALLBACK_TIMEOUT_SECS),
58        }
59    }
60
61    /// Set a custom timeout for waiting for the callback.
62    pub fn with_timeout(mut self, timeout: Duration) -> Self {
63        self.timeout = timeout;
64        self
65    }
66
67    /// Get the redirect URI for this server.
68    pub fn redirect_uri(&self) -> String {
69        format!("http://127.0.0.1:{}{}", self.port, CALLBACK_PATH)
70    }
71
72    /// Start the server and wait for the OAuth callback.
73    ///
74    /// Blocks until callback is received or timeout expires.
75    pub fn wait_for_callback(self) -> Result<CallbackResult, CallbackError> {
76        let addr = format!("127.0.0.1:{}", self.port);
77        let server = Server::http(&addr).map_err(|e| CallbackError::ServerStart(e.to_string()))?;
78
79        loop {
80            let request = match server.recv_timeout(self.timeout) {
81                Ok(Some(req)) => req,
82                Ok(None) => return Err(CallbackError::Timeout),
83                Err(_) => return Err(CallbackError::Timeout),
84            };
85
86            let url_str = format!("http://127.0.0.1{}", request.url());
87            let url = Url::parse(&url_str).map_err(|_| CallbackError::InvalidRequest)?;
88
89            if url.path() != CALLBACK_PATH {
90                let response = Response::from_string("Not found").with_status_code(404);
91                let _ = request.respond(response);
92                continue;
93            }
94
95            let params: std::collections::HashMap<_, _> = url.query_pairs().collect();
96
97            if let Some(error) = params.get("error") {
98                let description = params
99                    .get("error_description")
100                    .map(|s| s.to_string())
101                    .unwrap_or_else(|| error.to_string());
102
103                let response = Response::from_string(error_html(&description)).with_header(
104                    "Content-Type: text/html; charset=utf-8"
105                        .parse::<tiny_http::Header>()
106                        .expect("static header string is valid"),
107                );
108                let _ = request.respond(response);
109
110                return Err(CallbackError::Denied(description));
111            }
112
113            let code = params
114                .get("code")
115                .map(|s| s.to_string())
116                .ok_or(CallbackError::MissingCode)?;
117
118            let state = params.get("state").map(|s| s.to_string());
119
120            let response = Response::from_string(success_html()).with_header(
121                "Content-Type: text/html; charset=utf-8"
122                    .parse::<tiny_http::Header>()
123                    .expect("static header string is valid"),
124            );
125            let _ = request.respond(response);
126
127            return Ok(CallbackResult { code, state });
128        }
129    }
130}
131
/// Return the static HTML page shown in the browser after a successful
/// authentication.
fn success_html() -> String {
    // The page has no dynamic parts, so it lives in a constant and is copied
    // into an owned `String` on each call.
    const SUCCESS_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>spotify-cli</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: #191414;
            color: #1DB954;
        }
        .container { text-align: center; }
        h1 { font-size: 2rem; margin-bottom: 1rem; }
        p { color: #b3b3b3; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Authenticated!</h1>
        <p>You can close this window and return to your terminal.</p>
    </div>
</body>
</html>"#;

    String::from(SUCCESS_PAGE)
}
162
/// Render the HTML page shown in the browser when authentication fails.
///
/// `message` is HTML-escaped before being inserted into the page: callers
/// pass text taken from the callback's query string, which must never be
/// interpreted as markup.
fn error_html(message: &str) -> String {
    format!(
        r#"<!DOCTYPE html>
<html>
<head>
    <title>spotify-cli - Error</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: #191414;
            color: #e22134;
        }}
        .container {{ text-align: center; }}
        h1 {{ font-size: 2rem; margin-bottom: 1rem; }}
        p {{ color: #b3b3b3; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>Authentication Failed</h1>
        <p>{}</p>
    </div>
</body>
</html>"#,
        escape_html(message)
    )
}

/// Replace the characters that are significant in HTML (`&`, `<`, `>`, `"`,
/// `'`) with their entity equivalents; all other characters are kept as-is.
fn escape_html(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
195
/// Unit tests for the configuration helpers, error messages, HTML pages and
/// constants of this module. None of them open a socket.
#[cfg(test)]
mod tests {
    use super::*;

    // --- CallbackServer configuration -------------------------------------

    #[test]
    fn redirect_uri_uses_correct_format() {
        let uri = CallbackServer::new(8888).redirect_uri();
        assert_eq!(uri, "http://127.0.0.1:8888/callback");
    }

    #[test]
    fn can_customize_port() {
        let uri = CallbackServer::new(9999).redirect_uri();
        assert_eq!(uri, "http://127.0.0.1:9999/callback");
    }

    #[test]
    fn with_timeout_sets_custom_timeout() {
        let one_minute = Duration::from_secs(60);
        let server = CallbackServer::new(8888).with_timeout(one_minute);
        assert_eq!(server.timeout, one_minute);
    }

    #[test]
    fn default_timeout_is_five_minutes() {
        let expected = Duration::from_secs(OAUTH_CALLBACK_TIMEOUT_SECS);
        let server = CallbackServer::new(8888);
        assert_eq!(server.timeout, expected);
    }

    #[test]
    fn chained_with_timeout() {
        let two_minutes = Duration::from_secs(120);
        let server = CallbackServer::new(8080).with_timeout(two_minutes);
        assert_eq!(server.port, 8080);
        assert_eq!(server.timeout, two_minutes);
    }

    // --- CallbackError display ---------------------------------------------

    #[test]
    fn callback_error_display_server_start() {
        let text = CallbackError::ServerStart("port in use".to_string()).to_string();
        assert!(text.contains("server"));
        assert!(text.contains("port in use"));
    }

    #[test]
    fn callback_error_display_timeout() {
        let text = CallbackError::Timeout.to_string();
        assert!(text.contains("Timeout"));
    }

    #[test]
    fn callback_error_display_missing_code() {
        let text = CallbackError::MissingCode.to_string();
        assert!(text.contains("authorization code"));
    }

    #[test]
    fn callback_error_display_denied() {
        let text = CallbackError::Denied("access_denied".to_string()).to_string();
        assert!(text.contains("denied"));
        assert!(text.contains("access_denied"));
    }

    #[test]
    fn callback_error_display_invalid_request() {
        let text = CallbackError::InvalidRequest.to_string();
        assert!(text.contains("Invalid"));
    }

    // --- HTML pages ----------------------------------------------------------

    #[test]
    fn success_html_contains_authenticated() {
        let page = success_html();
        for needle in ["Authenticated", "html", "spotify-cli"] {
            assert!(page.contains(needle));
        }
    }

    #[test]
    fn error_html_contains_message() {
        let page = error_html("Test error message");
        for needle in ["Test error message", "Authentication Failed", "html"] {
            assert!(page.contains(needle));
        }
    }

    // --- CallbackResult ------------------------------------------------------

    #[test]
    fn callback_result_stores_code_and_state() {
        let result = CallbackResult {
            code: String::from("test_code"),
            state: Some(String::from("test_state")),
        };
        assert_eq!(result.code, "test_code");
        assert_eq!(result.state.as_deref(), Some("test_state"));
    }

    #[test]
    fn callback_result_state_can_be_none() {
        let result = CallbackResult {
            code: String::from("code"),
            state: None,
        };
        assert!(result.state.is_none());
    }

    // --- Constants -----------------------------------------------------------

    #[test]
    fn default_port_constant() {
        assert_eq!(DEFAULT_PORT, 8888);
    }

    #[test]
    fn callback_path_constant() {
        assert_eq!(CALLBACK_PATH, "/callback");
    }
}
315}