1use rand::RngCore;
26use std::io::{BufRead, BufReader, Read, Write};
27use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
28use std::time::{Duration, Instant};
29
30use crate::error::{Error, Result};
31use crate::token::Token;
32
/// Caller-tunable settings for [`loopback_handshake`].
#[derive(Debug, Clone)]
pub struct HandshakeOptions {
    /// Value sent as the `client` query parameter of the sign-in URL.
    /// When `None`, `loopback_handshake` sends `"wavekat-platform-client"`.
    pub client: Option<String>,
    /// Value sent as the `source` query parameter of the sign-in URL.
    /// When `None`, `loopback_handshake` derives one from the host name
    /// (see `default_source`). An empty string results in no `source`
    /// parameter being appended.
    pub source: Option<String>,
    /// When `true`, no `source` parameter is appended at all, regardless of
    /// the `source` field.
    pub omit_source: bool,
    /// How long the accept loop waits for the callback before failing with a
    /// timeout error.
    pub timeout: Duration,
}
55
56impl Default for HandshakeOptions {
57 fn default() -> Self {
58 Self {
59 client: None,
60 source: None,
61 omit_source: false,
62 timeout: Duration::from_secs(5 * 60),
63 }
64 }
65}
66
/// Result of a successful handshake, built from the callback's query string.
#[derive(Debug)]
pub struct HandshakeOutcome {
    /// Token built from the `token` query parameter of the callback request.
    pub token: Token,
    /// Value of the `login` query parameter of the callback request, if the
    /// parameter was present.
    pub login: Option<String>,
}
78
/// A handshake that has been prepared but not yet completed.
///
/// Produced by [`loopback_handshake`]; consumed by [`PendingHandshake::wait`].
pub struct PendingHandshake {
    // Loopback listener that will receive the callback request.
    listener: TcpListener,
    // Full sign-in URL, including the callback address and the state value.
    url: String,
    // Random value the callback must echo back in its `state` parameter.
    state: String,
    // Maximum time `wait` lets the accept loop run.
    timeout: Duration,
}
90
91impl PendingHandshake {
92 pub fn url(&self) -> &str {
96 &self.url
97 }
98
99 pub fn state(&self) -> &str {
102 &self.state
103 }
104
105 pub async fn wait(self) -> Result<HandshakeOutcome> {
108 let PendingHandshake {
109 listener,
110 state,
111 timeout,
112 ..
113 } = self;
114
115 let join = tokio::task::spawn_blocking(move || accept_loop(listener, &state, timeout));
120 match join.await {
121 Ok(result) => result,
122 Err(e) => Err(Error::BadRequest(format!(
123 "loopback worker panicked or was cancelled: {e}"
124 ))),
125 }
126 }
127}
128
129pub fn loopback_handshake(base_url: &str, options: HandshakeOptions) -> Result<PendingHandshake> {
135 let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
139 let port = listener.local_addr()?.port();
140
141 let state = random_state();
142 let client = options
143 .client
144 .unwrap_or_else(|| "wavekat-platform-client".to_string());
145 let callback = format!("http://127.0.0.1:{port}/callback");
146
147 let base = base_url.trim_end_matches('/');
148 let mut url = format!(
149 "{base}/cli-login?callback={cb}&state={st}&client={cl}",
150 cb = url_encode(&callback),
151 st = url_encode(&state),
152 cl = url_encode(&client),
153 );
154 if !options.omit_source {
155 let source = options.source.unwrap_or_else(default_source);
156 if !source.is_empty() {
157 url.push_str("&source=");
158 url.push_str(&url_encode(&source));
159 }
160 }
161
162 Ok(PendingHandshake {
163 listener,
164 url,
165 state,
166 timeout: options.timeout,
167 })
168}
169
170fn accept_loop(
171 listener: TcpListener,
172 expected_state: &str,
173 timeout: Duration,
174) -> Result<HandshakeOutcome> {
175 listener.set_nonblocking(true)?;
176 let deadline = Instant::now() + timeout;
177
178 loop {
179 if Instant::now() >= deadline {
180 return Err(Error::Timeout(timeout));
181 }
182 match listener.accept() {
183 Ok((stream, _addr)) => {
184 stream.set_nonblocking(false)?;
187 match handle_callback(stream, expected_state) {
188 Ok(HandlerResult::Got(outcome)) => return Ok(outcome),
189 Ok(HandlerResult::KeepListening) => continue,
190 Err(e @ Error::StateMismatch { .. }) => return Err(e),
191 Err(e @ Error::Cancelled(_)) => return Err(e),
192 Err(_) => {
193 continue;
196 }
197 }
198 }
199 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
200 std::thread::sleep(Duration::from_millis(50));
201 continue;
202 }
203 Err(e) => return Err(e.into()),
204 }
205 }
206}
207
/// What `handle_callback` tells the accept loop to do after a request that
/// did not fail.
enum HandlerResult {
    /// The callback carried a token; the handshake is complete.
    Got(HandshakeOutcome),
    /// The request was not the callback (wrong method or path); keep
    /// accepting connections.
    KeepListening,
}
212
213fn handle_callback(mut stream: TcpStream, expected_state: &str) -> Result<HandlerResult> {
214 stream.set_read_timeout(Some(Duration::from_secs(5))).ok();
215 stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
216
217 let mut reader = BufReader::new(stream.try_clone()?);
220 let mut request_line = String::new();
221 reader.read_line(&mut request_line)?;
222
223 let mut header_bytes = 0usize;
224 let mut line = String::new();
225 loop {
226 line.clear();
227 let n = reader.read_line(&mut line)?;
228 if n == 0 || line == "\r\n" || line == "\n" {
229 break;
230 }
231 header_bytes += n;
232 if header_bytes > 8192 {
233 return Err(Error::BadRequest("request headers too large".into()));
234 }
235 }
236
237 let mut parts = request_line.split_whitespace();
238 let method = parts.next().unwrap_or("");
239 let target = parts.next().unwrap_or("");
240 if method != "GET" {
241 respond(&mut stream, 405, "method not allowed", "method not allowed")?;
242 return Ok(HandlerResult::KeepListening);
243 }
244 if !target.starts_with("/callback") {
245 respond(&mut stream, 404, "not found", "not found")?;
248 return Ok(HandlerResult::KeepListening);
249 }
250
251 let query = target.split_once('?').map(|(_, q)| q).unwrap_or("");
252 let mut token: Option<String> = None;
253 let mut state: Option<String> = None;
254 let mut error: Option<String> = None;
255 let mut login: Option<String> = None;
256 for (k, v) in parse_query(query) {
257 match k.as_str() {
258 "token" => token = Some(v),
259 "state" => state = Some(v),
260 "error" => error = Some(v),
261 "login" => login = Some(v),
262 _ => {}
263 }
264 }
265
266 if state.as_deref() != Some(expected_state) {
267 respond(
268 &mut stream,
269 400,
270 "bad state",
271 "<h1>State mismatch</h1><p>Re-run the sign-in to start over.</p>",
272 )?;
273 return Err(Error::StateMismatch {
274 actual: state,
275 expected: expected_state.to_string(),
276 });
277 }
278
279 if let Some(err) = error {
280 respond(
281 &mut stream,
282 200,
283 "OK",
284 &format!(
285 "<h1>Login cancelled</h1><p>You can close this tab and try again.</p><p style='color:#888'>reason: {}</p>",
286 html_escape(&err),
287 ),
288 )?;
289 return Err(Error::Cancelled(err));
290 }
291
292 let Some(tok) = token else {
293 respond(&mut stream, 400, "missing token", "missing token")?;
294 return Err(Error::BadRequest("callback missing token".into()));
295 };
296
297 respond(
298 &mut stream,
299 200,
300 "OK",
301 "<!doctype html><html><head><meta charset=utf-8><title>WaveKat sign-in complete</title><style>body{font-family:system-ui,sans-serif;max-width:32rem;margin:4rem auto;padding:0 1rem;color:#1a1a1a}</style></head><body><h1>You're signed in.</h1><p>You can close this tab and return to the app.</p></body></html>",
302 )?;
303 Ok(HandlerResult::Got(HandshakeOutcome {
304 token: Token::new(tok),
305 login,
306 }))
307}
308
309fn respond(stream: &mut TcpStream, status: u16, reason: &str, body: &str) -> Result<()> {
310 let body_bytes = body.as_bytes();
311 let resp = format!(
312 "HTTP/1.1 {status} {reason}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n",
313 len = body_bytes.len(),
314 );
315 stream.write_all(resp.as_bytes())?;
316 stream.write_all(body_bytes)?;
317 let _ = stream.flush();
318 let mut sink = [0u8; 64];
321 let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
322 let _ = stream.read(&mut sink);
323 Ok(())
324}
325
326fn random_state() -> String {
327 let mut bytes = [0u8; 24];
328 rand::thread_rng().fill_bytes(&mut bytes);
329 base64url(&bytes)
330}
331
/// Encodes `bytes` as base64 with the URL-safe alphabet (`-` and `_` instead
/// of `+` and `/`) and without `=` padding.
fn base64url(bytes: &[u8]) -> String {
    const ALPHA: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    // Picks the 6-bit digit of `group` that starts at bit `shift`.
    let digit = |group: u32, shift: u32| ALPHA[((group >> shift) & 0x3f) as usize] as char;

    let mut out = String::with_capacity((bytes.len() * 4).div_ceil(3));
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes into the top 24 bits, most significant first;
        // bytes missing from a short final chunk stay zero.
        let group = chunk
            .iter()
            .enumerate()
            .fold(0u32, |acc, (idx, &b)| acc | (u32::from(b) << (16 - 8 * idx)));

        // k input bytes produce k + 1 output characters (4 for a full chunk,
        // 2 or 3 for a trailing partial one).
        for pos in 0..=chunk.len() {
            out.push(digit(group, 18 - 6 * pos as u32));
        }
    }
    out
}
359
360fn default_source() -> String {
361 std::env::var("HOSTNAME")
362 .ok()
363 .or_else(|| hostname().ok())
364 .unwrap_or_else(|| "unknown-host".to_string())
365}
366
367#[cfg(unix)]
368fn hostname() -> Result<String> {
369 let out = std::process::Command::new("hostname").output()?;
370 if !out.status.success() {
371 return Err(Error::BadRequest("hostname exited non-zero".into()));
372 }
373 Ok(String::from_utf8_lossy(&out.stdout).trim().to_string())
374}
375
/// Non-Unix: returns the value of the `COMPUTERNAME` environment variable.
///
/// # Errors
///
/// Returns `Error::BadRequest` if the variable cannot be read.
#[cfg(not(unix))]
fn hostname() -> Result<String> {
    match std::env::var("COMPUTERNAME") {
        Ok(name) => Ok(name),
        Err(e) => Err(Error::BadRequest(format!("COMPUTERNAME not set: {e}"))),
    }
}
381
382fn url_encode(s: &str) -> String {
383 url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
384}
385
386fn parse_query(q: &str) -> Vec<(String, String)> {
387 url::form_urlencoded::parse(q.as_bytes())
388 .map(|(k, v)| (k.into_owned(), v.into_owned()))
389 .collect()
390}
391
/// Escapes the five HTML metacharacters (`<`, `>`, `&`, `"`, `'`) so `s` can
/// be embedded in HTML text or a quoted attribute value.
///
/// Every other character is copied through unchanged.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        // Each metacharacter must map to its entity; mapping it to itself
        // would let request-supplied text inject markup into the page.
        match c {
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '&' => out.push_str("&amp;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(c),
        }
    }
    out
}
406
#[cfg(test)]
mod tests {
    use super::*;

    // Unpadded encodings of the RFC 4648 "foobar" test vectors.
    #[test]
    fn base64url_rfc_vectors() {
        assert_eq!(base64url(b""), "");
        assert_eq!(base64url(b"f"), "Zg");
        assert_eq!(base64url(b"fo"), "Zm8");
        assert_eq!(base64url(b"foo"), "Zm9v");
        assert_eq!(base64url(b"foob"), "Zm9vYg");
        assert_eq!(base64url(b"fooba"), "Zm9vYmE");
        assert_eq!(base64url(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn base64url_uses_url_safe_alphabet() {
        assert_eq!(base64url(&[0xfb, 0xff, 0xff]), "-___");
        let big: Vec<u8> = (0u8..=255).collect();
        let out = base64url(&big);
        assert!(!out.contains('+'));
        assert!(!out.contains('/'));
        assert!(!out.contains('='));
    }

    #[test]
    fn random_state_shape() {
        let s = random_state();
        // 24 bytes encode to exactly 32 unpadded base64 characters.
        assert_eq!(s.len(), 32);
        let alpha: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
        for b in s.as_bytes() {
            assert!(alpha.contains(b), "unexpected byte {b:#x} in state");
        }
    }

    #[test]
    fn random_state_is_not_constant() {
        assert_ne!(random_state(), random_state());
    }

    #[test]
    fn html_escape_handles_metacharacters() {
        // The expected value must contain entities, not the raw characters.
        assert_eq!(
            html_escape("<a href=\"x\">it's & ok</a>"),
            "&lt;a href=&quot;x&quot;&gt;it&#39;s &amp; ok&lt;/a&gt;",
        );
        assert_eq!(html_escape("plain text"), "plain text");
        assert_eq!(html_escape(""), "");
    }

    #[test]
    fn handshake_options_default_has_sensible_timeout() {
        let opts = HandshakeOptions::default();
        assert_eq!(opts.timeout, Duration::from_secs(300));
        assert!(opts.client.is_none());
        assert!(opts.source.is_none());
        assert!(!opts.omit_source);
    }

    #[test]
    fn default_source_falls_back_to_a_hostname() {
        let s = default_source();
        assert!(!s.is_empty(), "should never produce an empty source");
    }

    #[test]
    fn loopback_handshake_returns_url_with_loopback_callback() {
        let pending =
            loopback_handshake("https://platform.wavekat.com", HandshakeOptions::default())
                .expect("bind loopback");
        let url = pending.url();
        assert!(url.starts_with("https://platform.wavekat.com/cli-login?"));
        assert!(url.contains("127.0.0.1"), "{url}");
        assert!(url.contains(&format!(
            "state={}",
            url::form_urlencoded::byte_serialize(pending.state().as_bytes()).collect::<String>()
        )));
        assert!(
            url.contains("client=wavekat-platform-client"),
            "expected client=wavekat-platform-client in {url}",
        );
        assert!(url.contains("&source="), "expected &source=... in {url}");
    }

    #[test]
    fn loopback_handshake_uses_explicit_client_and_source() {
        let pending = loopback_handshake(
            "https://platform.wavekat.com",
            HandshakeOptions {
                client: Some("wavekat-voice".into()),
                source: Some("studio-mac".into()),
                ..Default::default()
            },
        )
        .expect("bind loopback");
        let url = pending.url();
        assert!(url.contains("client=wavekat-voice"), "{url}");
        assert!(url.contains("source=studio-mac"), "{url}");
    }

    #[test]
    fn loopback_handshake_omits_source_when_requested() {
        let pending = loopback_handshake(
            "https://platform.wavekat.com",
            HandshakeOptions {
                client: Some("wavekat-voice".into()),
                omit_source: true,
                ..Default::default()
            },
        )
        .expect("bind loopback");
        let url = pending.url();
        assert!(url.contains("client=wavekat-voice"), "{url}");
        assert!(!url.contains("source="), "should not include source: {url}");
    }

    #[test]
    fn loopback_handshake_strips_trailing_slash() {
        let pending = loopback_handshake("https://platform.wavekat.com/", Default::default())
            .expect("bind loopback");
        assert!(pending
            .url()
            .starts_with("https://platform.wavekat.com/cli-login?"));
    }
}