1use super::types::OAuthError;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpListener;
10use tokio::time::{timeout, Duration};
11
/// Data extracted from an accepted OAuth redirect by [`run_callback_server`].
///
/// It is only constructed after the callback's `state` has been compared
/// against the expected value, so both fields come from a validated request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallbackResult {
    /// Value of the `code` query parameter of the redirect.
    pub code: String,
    /// Value of the `state` query parameter (equal to the expected state).
    // `dead_code` is allowed because no code in this module reads the field;
    // it is kept so callers can inspect the validated state if they need to.
    #[allow(dead_code)]
    pub state: String,
}
18
19pub async fn run_callback_server(
28 port: u16,
29 expected_state: String,
30 timeout_secs: u64,
31) -> Result<CallbackResult, OAuthError> {
32 let bind_addr = format!("127.0.0.1:{}", port);
33 let listener = TcpListener::bind(&bind_addr)
34 .await
35 .map_err(|e| OAuthError::ServerError(format!("Failed to bind to port {}: {}", port, e)))?;
36
37 let actual_port = listener.local_addr().map(|a| a.port()).unwrap_or(port);
38 println!(
39 "Listening for OAuth callback on http://127.0.0.1:{}",
40 actual_port
41 );
42
43 let result: Arc<Mutex<Option<Result<CallbackResult, OAuthError>>>> = Arc::new(Mutex::new(None));
44
45 let server_result = result.clone();
46 let server_task = async move {
47 loop {
48 let (mut socket, _) = match listener.accept().await {
49 Ok(conn) => conn,
50 Err(e) => {
51 let mut res = server_result.lock().unwrap();
52 *res = Some(Err(OAuthError::ServerError(format!(
53 "Failed to accept connection: {}",
54 e
55 ))));
56 break;
57 }
58 };
59
60 let mut buffer = vec![0; 4096];
61 let n = match socket.read(&mut buffer).await {
62 Ok(n) if n > 0 => n,
63 _ => continue,
64 };
65
66 let request = String::from_utf8_lossy(&buffer[..n]);
67
68 if let Some(first_line) = request.lines().next() {
70 if let Some(path_part) = first_line.split_whitespace().nth(1) {
71 if let Some(query_start) = path_part.find('?') {
72 let query = &path_part[query_start + 1..];
73 let params = parse_query_string(query);
74
75 let response = if let (Some(code), Some(state)) =
76 (params.get("code"), params.get("state"))
77 {
78 if state != &expected_state {
80 let mut res = server_result.lock().unwrap();
81 *res = Some(Err(OAuthError::StateMismatch {
82 expected: expected_state.clone(),
83 actual: state.clone(),
84 }));
85 create_error_response("State mismatch - possible CSRF attack")
86 } else {
87 let mut res = server_result.lock().unwrap();
88 *res = Some(Ok(CallbackResult {
89 code: code.clone(),
90 state: state.clone(),
91 }));
92 create_success_response()
93 }
94 } else if let Some(error) = params.get("error") {
95 let mut res = server_result.lock().unwrap();
96 *res = Some(Err(OAuthError::SlackError(error.clone())));
97 create_error_response(&format!("OAuth error: {}", error))
98 } else {
99 create_error_response("Missing required parameters")
100 };
101
102 let _ = socket.write_all(response.as_bytes()).await;
103 let _ = socket.flush().await;
104 break;
105 }
106 }
107 }
108 }
109 };
110
111 match timeout(Duration::from_secs(timeout_secs), server_task).await {
113 Ok(_) => {
114 let res = result.lock().unwrap();
115 match res.as_ref() {
116 Some(Ok(callback_result)) => Ok(callback_result.clone()),
117 Some(Err(e)) => Err(format_oauth_error(e)),
118 None => Err(OAuthError::ServerError("No result received".to_string())),
119 }
120 }
121 Err(_) => Err(OAuthError::ServerError(format!(
122 "Timeout after {} seconds waiting for callback",
123 timeout_secs
124 ))),
125 }
126}
127
128fn format_oauth_error(err: &OAuthError) -> OAuthError {
130 match err {
131 OAuthError::ConfigError(msg) => OAuthError::ConfigError(msg.clone()),
132 OAuthError::NetworkError(msg) => OAuthError::NetworkError(msg.clone()),
133 OAuthError::HttpError(code, msg) => OAuthError::HttpError(*code, msg.clone()),
134 OAuthError::ParseError(msg) => OAuthError::ParseError(msg.clone()),
135 OAuthError::SlackError(msg) => OAuthError::SlackError(msg.clone()),
136 OAuthError::StateMismatch { expected, actual } => OAuthError::StateMismatch {
137 expected: expected.clone(),
138 actual: actual.clone(),
139 },
140 OAuthError::ServerError(msg) => OAuthError::ServerError(msg.clone()),
141 OAuthError::BrowserError(msg) => OAuthError::BrowserError(msg.clone()),
142 }
143}
144
145fn parse_query_string(query: &str) -> HashMap<String, String> {
147 query
148 .split('&')
149 .filter_map(|pair| {
150 let mut parts = pair.split('=');
151 match (parts.next(), parts.next()) {
152 (Some(key), Some(value)) => Some((
153 key.to_string(),
154 urlencoding::decode(value).ok()?.to_string(),
155 )),
156 _ => None,
157 }
158 })
159 .collect()
160}
161
/// Builds the complete HTTP/1.1 `200 OK` response shown in the browser after
/// a successful callback.
///
/// The response consists of the status line, headers, a blank line, and a
/// small HTML page.
fn create_success_response() -> String {
    let head = concat!(
        "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/html; charset=utf-8\r\n",
        "Connection: close\r\n",
        "\r\n",
    );
    let page = concat!(
        "<html>",
        "<head><title>Authentication Successful</title></head>",
        "<body>",
        "<h1>✓ Authentication Successful</h1>",
        "<p>You can close this window and return to the CLI.</p>",
        "</body>",
        "</html>",
    );
    [head, page].concat()
}
176
/// Builds a complete HTTP/1.1 `400 Bad Request` response whose HTML body shows
/// `message`.
///
/// `message` is HTML-escaped before it is embedded. In this file one caller
/// passes text taken from the callback's query string (the `error` value).
/// Embedding it raw would let a crafted redirect inject markup or script into
/// the page served from localhost.
fn create_error_response(message: &str) -> String {
    format!(
        "HTTP/1.1 400 Bad Request\r\n\
         Content-Type: text/html; charset=utf-8\r\n\
         Connection: close\r\n\
         \r\n\
         <html>\
         <head><title>Authentication Failed</title></head>\
         <body>\
         <h1>✗ Authentication Failed</h1>\
         <p>{}</p>\
         </body>\
         </html>",
        escape_html_text(message)
    )
}

/// Escapes the HTML-significant characters (`& < > " '`) so `text` renders
/// literally inside an element body.
fn escape_html_text(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
193
/// Minimal percent-decoding for `application/x-www-form-urlencoded` components.
mod urlencoding {
    /// Decodes one query-string component.
    ///
    /// Decoding rules:
    /// - `+` becomes a space;
    /// - `%XX` (two hex digits, either case) becomes the byte `0xXX`.
    ///
    /// The resulting bytes are decoded as UTF-8; invalid sequences become
    /// U+FFFD, so `%C3%A9` yields `é` rather than two Latin-1 characters.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if a `%` is not followed by exactly two hex digits.
    pub fn decode(s: &str) -> Result<String, ()> {
        let mut bytes = Vec::with_capacity(s.len());
        let mut input = s.bytes();
        while let Some(b) = input.next() {
            match b {
                b'%' => {
                    let hi = input.next().and_then(hex_value).ok_or(())?;
                    let lo = input.next().and_then(hex_value).ok_or(())?;
                    bytes.push((hi << 4) | lo);
                }
                b'+' => bytes.push(b' '),
                // Non-ASCII input bytes are copied verbatim, so multi-byte
                // characters of `s` are reassembled intact below.
                other => bytes.push(other),
            }
        }
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }

    /// Value of a single ASCII hex digit, or `None` for any other byte
    /// (including the sign characters `u8::from_str_radix` would accept).
    fn hex_value(b: u8) -> Option<u8> {
        (b as char).to_digit(16).map(|d| d as u8)
    }
}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_query_string() {
        let query = "code=test_code&state=test_state&foo=bar";
        let params = parse_query_string(query);

        assert_eq!(params.get("code"), Some(&"test_code".to_string()));
        assert_eq!(params.get("state"), Some(&"test_state".to_string()));
        assert_eq!(params.get("foo"), Some(&"bar".to_string()));
    }

    #[test]
    fn test_parse_query_string_with_encoding() {
        let query = "message=hello+world&name=test%20user";
        let params = parse_query_string(query);

        assert_eq!(params.get("message"), Some(&"hello world".to_string()));
        assert_eq!(params.get("name"), Some(&"test user".to_string()));
    }

    #[test]
    fn test_parse_query_string_skips_pairs_without_value() {
        let params = parse_query_string("flag&&code=abc");

        assert_eq!(params.len(), 1);
        assert_eq!(params.get("code"), Some(&"abc".to_string()));
    }

    #[test]
    fn test_error_response_is_bad_request() {
        let response = create_error_response("Missing required parameters");

        assert!(response.starts_with("HTTP/1.1 400 Bad Request\r\n"));
        assert!(response.contains("<p>Missing required parameters</p>"));
    }

    #[tokio::test]
    async fn test_callback_server_timeout() {
        let state = "test_state".to_string();
        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // another process or a parallel test run holding a fixed port.
        let result = run_callback_server(0, state, 1).await;

        match result {
            Err(OAuthError::ServerError(msg)) => assert!(msg.contains("Timeout")),
            _ => panic!("Expected ServerError with timeout"),
        }
    }
}