Skip to main content

wavekat_platform_client/
oauth.rs

1//! Loopback OAuth handshake.
2//!
3//! Mirrors the flow that `wavekat-cli`'s `wk login` runs today:
4//!
5//!   1. Bind a TCP listener on `127.0.0.1:<random ephemeral port>`.
6//!   2. Generate a one-shot CSRF state.
7//!   3. Compute the platform's `/cli-login` URL with
8//!      `?callback=…&state=…&client=…[&source=…]` and hand it back to
9//!      the caller via [`PendingHandshake::url`].
10//!   4. Caller opens the URL however they want — `webbrowser::open` for a
11//!      CLI, `shell.openExternal` for an Electron app, plain `println!`
12//!      for a remote host with no browser.
13//!   5. Caller awaits [`PendingHandshake::wait`], which blocks (on a
14//!      worker thread) until the platform redirects the browser back to
15//!      the loopback URL with `?token=…&state=…` (success) or
16//!      `?error=…&state=…` (cancel).
17//!
18//! The crate intentionally does **not** call `webbrowser::open` itself —
19//! that decision is the consumer's. See `docs/01-initial-port.md` for why.
20//!
21//! The loopback HTTP server is hand-rolled (`std::net`) for the same
22//! reason as the CLI: one request, one response, no need to drag in a
23//! framework. Reads are bounded so a stray local probe can't tie us up.
24
25use rand::RngCore;
26use std::io::{BufRead, BufReader, Read, Write};
27use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
28use std::time::{Duration, Instant};
29
30use crate::error::{Error, Result};
31use crate::token::Token;
32
/// Knobs for [`loopback_handshake`]. Start from `HandshakeOptions::default()`
/// and override only the fields you care about.
#[derive(Debug, Clone)]
pub struct HandshakeOptions {
    /// App identifier sent as the `client` query parameter. The platform
    /// uses it as the consent-screen title and in the user's "Active
    /// sessions" list. `None` sends `"wavekat-platform-client"`, so most
    /// consumers should set their own product name (e.g. `"wavekat-voice"`).
    pub client: Option<String>,
    /// Origin label sent as the `source` query parameter, rendered as
    /// "from `<source>`" beside the title and in the session list. `None`
    /// falls back to the machine's hostname. `Some("")` sends nothing.
    /// [`Self::omit_source`] also suppresses it.
    pub source: Option<String>,
    /// When `true`, no `source` parameter is sent at all, whatever
    /// [`Self::source`] says. The platform then stores `null` and the consent
    /// UI shows no "from …" line. Useful for apps that must not disclose
    /// the hostname.
    pub omit_source: bool,
    /// How long to wait for the browser callback before giving up.
    /// Defaults to five minutes.
    pub timeout: Duration,
}

impl Default for HandshakeOptions {
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        Self {
            timeout: FIVE_MINUTES,
            omit_source: false,
            source: None,
            client: None,
        }
    }
}
66
67/// Result of a successful handshake.
68#[derive(Debug)]
69pub struct HandshakeOutcome {
70    /// The signed-in token. Hand to [`crate::Client::new`].
71    pub token: Token,
72    /// Echoed back from the platform — typically the user's login.
73    /// Useful so callers can skip an extra `/api/me` round-trip right
74    /// after sign-in. `None` when the platform didn't include it (older
75    /// platforms, or the `error=` path).
76    pub login: Option<String>,
77}
78
/// A sign-in that is ready to start. The loopback listener is bound and the
/// platform URL is built, but the browser has not called back yet.
///
/// Returned by [`loopback_handshake`] so the caller can decide how to show
/// the URL before awaiting [`PendingHandshake::wait`]. Dropping the value
/// without waiting abandons the flow; the listener is closed along with it.
pub struct PendingHandshake {
    /// Loopback socket the platform will redirect the browser to.
    listener: TcpListener,
    /// Fully built `/cli-login` URL for the browser.
    url: String,
    /// CSRF token embedded in `url`; the callback must echo it back.
    state: String,
    /// How long `wait` listens before giving up.
    timeout: Duration,
}
90
91impl PendingHandshake {
92    /// The URL to open in the user's browser. Hand this to whatever
93    /// "open external URL" facility the consumer has — `webbrowser::open`,
94    /// `shell.openExternal`, or `println!`.
95    pub fn url(&self) -> &str {
96        &self.url
97    }
98
99    /// The CSRF state we generated. Mostly useful for tests / debugging;
100    /// callers don't normally need to look at it.
101    pub fn state(&self) -> &str {
102        &self.state
103    }
104
105    /// Block (on a worker thread) until the browser redirects back, or
106    /// the timeout fires.
107    pub async fn wait(self) -> Result<HandshakeOutcome> {
108        let PendingHandshake {
109            listener,
110            state,
111            timeout,
112            ..
113        } = self;
114
115        // The accept loop is sync-by-nature (`std::net::TcpListener`),
116        // so we run it on a blocking worker. The deadline lives inside
117        // the worker so we never strand a thread on a permanently-open
118        // listener — set_nonblocking + sleep poll until the deadline.
119        let join = tokio::task::spawn_blocking(move || accept_loop(listener, &state, timeout));
120        match join.await {
121            Ok(result) => result,
122            Err(e) => Err(Error::BadRequest(format!(
123                "loopback worker panicked or was cancelled: {e}"
124            ))),
125        }
126    }
127}
128
129/// Bind the loopback listener and compute the platform sign-in URL.
130///
131/// This is sync because binding a `std::net::TcpListener` is sync — there's
132/// no `.await` to be had. The async work (waiting for the browser
133/// redirect) lives on [`PendingHandshake::wait`].
134pub fn loopback_handshake(base_url: &str, options: HandshakeOptions) -> Result<PendingHandshake> {
135    // Loopback only — never bind 0.0.0.0. Anything bound to a non-loopback
136    // interface could be reached by another host on the LAN for the brief
137    // window we listen, and the token in the redirect URL is a credential.
138    let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
139    let port = listener.local_addr()?.port();
140
141    let state = random_state();
142    let client = options
143        .client
144        .unwrap_or_else(|| "wavekat-platform-client".to_string());
145    let callback = format!("http://127.0.0.1:{port}/callback");
146
147    let base = base_url.trim_end_matches('/');
148    let mut url = format!(
149        "{base}/cli-login?callback={cb}&state={st}&client={cl}",
150        cb = url_encode(&callback),
151        st = url_encode(&state),
152        cl = url_encode(&client),
153    );
154    if !options.omit_source {
155        let source = options.source.unwrap_or_else(default_source);
156        if !source.is_empty() {
157            url.push_str("&source=");
158            url.push_str(&url_encode(&source));
159        }
160    }
161
162    Ok(PendingHandshake {
163        listener,
164        url,
165        state,
166        timeout: options.timeout,
167    })
168}
169
170fn accept_loop(
171    listener: TcpListener,
172    expected_state: &str,
173    timeout: Duration,
174) -> Result<HandshakeOutcome> {
175    listener.set_nonblocking(true)?;
176    let deadline = Instant::now() + timeout;
177
178    loop {
179        if Instant::now() >= deadline {
180            return Err(Error::Timeout(timeout));
181        }
182        match listener.accept() {
183            Ok((stream, _addr)) => {
184                // Switch the accepted stream back to blocking for the
185                // tiny request/response we're about to do.
186                stream.set_nonblocking(false)?;
187                match handle_callback(stream, expected_state) {
188                    Ok(HandlerResult::Got(outcome)) => return Ok(outcome),
189                    Ok(HandlerResult::KeepListening) => continue,
190                    Err(e @ Error::StateMismatch { .. }) => return Err(e),
191                    Err(e @ Error::Cancelled(_)) => return Err(e),
192                    Err(_) => {
193                        // A stray probe (favicon, devtools HEAD, etc.)
194                        // shouldn't break the real one. Keep listening.
195                        continue;
196                    }
197                }
198            }
199            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
200                std::thread::sleep(Duration::from_millis(50));
201                continue;
202            }
203            Err(e) => return Err(e.into()),
204        }
205    }
206}
207
208enum HandlerResult {
209    Got(HandshakeOutcome),
210    KeepListening,
211}
212
213fn handle_callback(mut stream: TcpStream, expected_state: &str) -> Result<HandlerResult> {
214    stream.set_read_timeout(Some(Duration::from_secs(5))).ok();
215    stream.set_write_timeout(Some(Duration::from_secs(5))).ok();
216
217    // Read just the request line + headers. Cap at 8 KiB — the URL we
218    // care about is well under that, and we never need the body.
219    let mut reader = BufReader::new(stream.try_clone()?);
220    let mut request_line = String::new();
221    reader.read_line(&mut request_line)?;
222
223    let mut header_bytes = 0usize;
224    let mut line = String::new();
225    loop {
226        line.clear();
227        let n = reader.read_line(&mut line)?;
228        if n == 0 || line == "\r\n" || line == "\n" {
229            break;
230        }
231        header_bytes += n;
232        if header_bytes > 8192 {
233            return Err(Error::BadRequest("request headers too large".into()));
234        }
235    }
236
237    let mut parts = request_line.split_whitespace();
238    let method = parts.next().unwrap_or("");
239    let target = parts.next().unwrap_or("");
240    if method != "GET" {
241        respond(&mut stream, 405, "method not allowed", "method not allowed")?;
242        return Ok(HandlerResult::KeepListening);
243    }
244    if !target.starts_with("/callback") {
245        // Browsers fetch /favicon.ico; some OSes probe /. Reply 404 and
246        // keep listening — only /callback matters.
247        respond(&mut stream, 404, "not found", "not found")?;
248        return Ok(HandlerResult::KeepListening);
249    }
250
251    let query = target.split_once('?').map(|(_, q)| q).unwrap_or("");
252    let mut token: Option<String> = None;
253    let mut state: Option<String> = None;
254    let mut error: Option<String> = None;
255    let mut login: Option<String> = None;
256    for (k, v) in parse_query(query) {
257        match k.as_str() {
258            "token" => token = Some(v),
259            "state" => state = Some(v),
260            "error" => error = Some(v),
261            "login" => login = Some(v),
262            _ => {}
263        }
264    }
265
266    if state.as_deref() != Some(expected_state) {
267        respond(
268            &mut stream,
269            400,
270            "bad state",
271            "<h1>State mismatch</h1><p>Re-run the sign-in to start over.</p>",
272        )?;
273        return Err(Error::StateMismatch {
274            actual: state,
275            expected: expected_state.to_string(),
276        });
277    }
278
279    if let Some(err) = error {
280        respond(
281            &mut stream,
282            200,
283            "OK",
284            &format!(
285                "<h1>Login cancelled</h1><p>You can close this tab and try again.</p><p style='color:#888'>reason: {}</p>",
286                html_escape(&err),
287            ),
288        )?;
289        return Err(Error::Cancelled(err));
290    }
291
292    let Some(tok) = token else {
293        respond(&mut stream, 400, "missing token", "missing token")?;
294        return Err(Error::BadRequest("callback missing token".into()));
295    };
296
297    respond(
298        &mut stream,
299        200,
300        "OK",
301        "<!doctype html><html><head><meta charset=utf-8><title>WaveKat sign-in complete</title><style>body{font-family:system-ui,sans-serif;max-width:32rem;margin:4rem auto;padding:0 1rem;color:#1a1a1a}</style></head><body><h1>You're signed in.</h1><p>You can close this tab and return to the app.</p></body></html>",
302    )?;
303    Ok(HandlerResult::Got(HandshakeOutcome {
304        token: Token::new(tok),
305        login,
306    }))
307}
308
309fn respond(stream: &mut TcpStream, status: u16, reason: &str, body: &str) -> Result<()> {
310    let body_bytes = body.as_bytes();
311    let resp = format!(
312        "HTTP/1.1 {status} {reason}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {len}\r\nConnection: close\r\n\r\n",
313        len = body_bytes.len(),
314    );
315    stream.write_all(resp.as_bytes())?;
316    stream.write_all(body_bytes)?;
317    let _ = stream.flush();
318    // Drain a tiny amount of any unread bytes so the kernel doesn't RST
319    // the connection before the browser reads our response.
320    let mut sink = [0u8; 64];
321    let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
322    let _ = stream.read(&mut sink);
323    Ok(())
324}
325
326fn random_state() -> String {
327    let mut bytes = [0u8; 24];
328    rand::thread_rng().fill_bytes(&mut bytes);
329    base64url(&bytes)
330}
331
// URL-safe base64 without padding (RFC 4648 §5). Hand-rolled so we don't
// pull in another crate just for 24 bytes of CSRF state.
fn base64url(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((bytes.len() * 4).div_ceil(3));
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes big-endian into a 24-bit group. Missing
        // trailing bytes stay zero.
        let group = chunk
            .iter()
            .enumerate()
            .fold(0u32, |acc, (idx, &byte)| acc | (u32::from(byte) << (16 - 8 * idx)));
        // k input bytes yield k + 1 sextets when no padding is emitted.
        for sextet in 0..=chunk.len() {
            let shift = 18 - 6 * sextet;
            encoded.push(char::from(ALPHABET[((group >> shift) & 0x3f) as usize]));
        }
    }
    encoded
}
359
360fn default_source() -> String {
361    std::env::var("HOSTNAME")
362        .ok()
363        .or_else(|| hostname().ok())
364        .unwrap_or_else(|| "unknown-host".to_string())
365}
366
367#[cfg(unix)]
368fn hostname() -> Result<String> {
369    let out = std::process::Command::new("hostname").output()?;
370    if !out.status.success() {
371        return Err(Error::BadRequest("hostname exited non-zero".into()));
372    }
373    Ok(String::from_utf8_lossy(&out.stdout).trim().to_string())
374}
375
/// Machine hostname taken from the `COMPUTERNAME` environment variable.
#[cfg(not(unix))]
fn hostname() -> Result<String> {
    match std::env::var("COMPUTERNAME") {
        Ok(name) => Ok(name),
        Err(e) => Err(Error::BadRequest(format!("COMPUTERNAME not set: {e}"))),
    }
}
381
382fn url_encode(s: &str) -> String {
383    url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
384}
385
386fn parse_query(q: &str) -> Vec<(String, String)> {
387    url::form_urlencoded::parse(q.as_bytes())
388        .map(|(k, v)| (k.into_owned(), v.into_owned()))
389        .collect()
390}
391
/// Escape the five HTML metacharacters so untrusted text can be embedded in
/// element content or quoted attribute values.
fn html_escape(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                other => escaped.push(other),
            }
            escaped
        })
}
406
#[cfg(test)]
mod tests {
    use super::*;

    // Carried verbatim from `wavekat-cli/src/commands/login.rs`.

    #[test]
    fn base64url_rfc_vectors() {
        assert_eq!(base64url(b""), "");
        assert_eq!(base64url(b"f"), "Zg");
        assert_eq!(base64url(b"fo"), "Zm8");
        assert_eq!(base64url(b"foo"), "Zm9v");
        assert_eq!(base64url(b"foob"), "Zm9vYg");
        assert_eq!(base64url(b"fooba"), "Zm9vYmE");
        assert_eq!(base64url(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn base64url_uses_url_safe_alphabet() {
        assert_eq!(base64url(&[0xfb, 0xff, 0xff]), "-___");
        let big: Vec<u8> = (0u8..=255).collect();
        let out = base64url(&big);
        assert!(!out.contains('+'));
        assert!(!out.contains('/'));
        assert!(!out.contains('='));
    }

    #[test]
    fn random_state_shape() {
        let s = random_state();
        assert_eq!(s.len(), 32);
        let alpha: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
        for b in s.as_bytes() {
            assert!(alpha.contains(b), "unexpected byte {b:#x} in state");
        }
    }

    #[test]
    fn random_state_is_not_constant() {
        assert_ne!(random_state(), random_state());
    }

    #[test]
    fn html_escape_handles_metacharacters() {
        assert_eq!(
            html_escape("<a href=\"x\">it's & ok</a>"),
            "&lt;a href=&quot;x&quot;&gt;it&#39;s &amp; ok&lt;/a&gt;",
        );
        assert_eq!(html_escape("plain text"), "plain text");
        assert_eq!(html_escape(""), "");
    }

    // New tests for the v0.0.1 surface.

    #[test]
    fn handshake_options_default_has_sensible_timeout() {
        let opts = HandshakeOptions::default();
        assert_eq!(opts.timeout, Duration::from_secs(300));
        assert!(opts.client.is_none());
        assert!(opts.source.is_none());
        assert!(!opts.omit_source);
    }

    #[test]
    fn default_source_falls_back_to_a_hostname() {
        let s = default_source();
        assert!(!s.is_empty(), "should never produce an empty source");
    }

    #[test]
    fn loopback_handshake_returns_url_with_loopback_callback() {
        let pending =
            loopback_handshake("https://platform.wavekat.com", HandshakeOptions::default())
                .expect("bind loopback");
        let url = pending.url();
        assert!(url.starts_with("https://platform.wavekat.com/cli-login?"));
        assert!(url.contains("127.0.0.1"), "{url}");
        assert!(url.contains(&format!(
            "state={}",
            url::form_urlencoded::byte_serialize(pending.state().as_bytes()).collect::<String>()
        )));
        // Default `client` is the crate name.
        assert!(
            url.contains("client=wavekat-platform-client"),
            "expected client=wavekat-platform-client in {url}",
        );
        // Default options send a source (the hostname).
        assert!(url.contains("&source="), "expected &source=... in {url}");
    }

    #[test]
    fn loopback_handshake_uses_explicit_client_and_source() {
        let pending = loopback_handshake(
            "https://platform.wavekat.com",
            HandshakeOptions {
                client: Some("wavekat-voice".into()),
                source: Some("studio-mac".into()),
                ..Default::default()
            },
        )
        .expect("bind loopback");
        let url = pending.url();
        assert!(url.contains("client=wavekat-voice"), "{url}");
        assert!(url.contains("source=studio-mac"), "{url}");
    }

    #[test]
    fn loopback_handshake_omits_source_when_requested() {
        let pending = loopback_handshake(
            "https://platform.wavekat.com",
            HandshakeOptions {
                client: Some("wavekat-voice".into()),
                omit_source: true,
                ..Default::default()
            },
        )
        .expect("bind loopback");
        let url = pending.url();
        assert!(url.contains("client=wavekat-voice"), "{url}");
        assert!(!url.contains("source="), "should not include source: {url}");
    }

    #[test]
    fn loopback_handshake_strips_trailing_slash() {
        let pending = loopback_handshake("https://platform.wavekat.com/", Default::default())
            .expect("bind loopback");
        assert!(pending
            .url()
            .starts_with("https://platform.wavekat.com/cli-login?"));
    }

    // End-to-end coverage of the loopback server. A real client thread talks
    // to a real listener while `accept_loop` runs synchronously on the test
    // thread.

    /// Minimal HTTP/1.1 GET request for `target`.
    fn get(target: &str) -> String {
        format!("GET {target} HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
    }

    /// Bind a handshake and build requests with `build`, which receives the
    /// CSRF state. A helper thread sends the requests one connection at a
    /// time, in order; the accept loop runs to completion.
    fn run_handshake(build: impl FnOnce(&str) -> Vec<String>) -> Result<HandshakeOutcome> {
        let pending = loopback_handshake(
            "https://platform.wavekat.com",
            HandshakeOptions {
                omit_source: true,
                timeout: Duration::from_secs(10),
                ..Default::default()
            },
        )
        .expect("bind loopback");
        let port = pending.listener.local_addr().expect("local addr").port();
        let requests = build(pending.state());

        let client = std::thread::spawn(move || {
            for request in requests {
                let mut conn = TcpStream::connect(("127.0.0.1", port)).expect("connect");
                conn.write_all(request.as_bytes()).expect("send request");
                // Wait for the server to answer and close, so requests are
                // handled strictly in order.
                let mut response = Vec::new();
                let _ = conn.read_to_end(&mut response);
            }
        });

        let PendingHandshake {
            listener,
            state,
            timeout,
            ..
        } = pending;
        let result = accept_loop(listener, &state, timeout);
        client.join().expect("client thread panicked");
        result
    }

    #[test]
    fn accept_loop_skips_stray_requests_and_returns_outcome() {
        let outcome = run_handshake(|state| {
            vec![
                get("/favicon.ico"),
                "POST /callback HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Length: 0\r\n\r\n"
                    .to_string(),
                get(&format!("/callback?token=tok-123&state={state}&login=octocat")),
            ]
        })
        .expect("handshake should succeed");
        assert_eq!(outcome.login.as_deref(), Some("octocat"));
    }

    #[test]
    fn accept_loop_reports_cancellation() {
        let err = run_handshake(|state| {
            vec![get(&format!("/callback?error=access_denied&state={state}"))]
        })
        .expect_err("a cancelled sign-in must not yield a token");
        assert!(
            matches!(err, Error::Cancelled(ref reason) if reason == "access_denied"),
            "{err:?}"
        );
    }

    #[test]
    fn accept_loop_aborts_on_state_mismatch() {
        let err = run_handshake(|_| vec![get("/callback?token=stolen&state=forged")])
            .expect_err("a forged state must not yield a token");
        assert!(matches!(err, Error::StateMismatch { .. }), "{err:?}");
    }

    #[test]
    fn accept_loop_times_out_without_callback() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).expect("bind");
        let err = accept_loop(listener, "state", Duration::from_millis(100))
            .expect_err("nothing connected, so this must time out");
        assert!(matches!(err, Error::Timeout(_)), "{err:?}");
    }
}