Skip to main content

stynx_code_auth/infrastructure/oauth/
callback_server.rs

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use stynx_code_errors::{AppError, AppResult};
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpListener;
5
6pub fn generate_state() -> String {
7    let mut bytes = [0u8; 32];
8    getrandom::getrandom(&mut bytes)
9        .expect("OS CSPRNG unavailable — cannot generate OAuth state safely");
10    URL_SAFE_NO_PAD.encode(bytes)
11}
12
13pub async fn run_callback_server(port: u16, expected_state: &str) -> AppResult<String> {
14    let addr = format!("127.0.0.1:{port}");
15    let listener = TcpListener::bind(&addr)
16        .await
17        .map_err(|e| AppError::Provider(format!("failed to bind callback server on {addr}: {e}")))?;
18
19    let (mut stream, _) = listener
20        .accept()
21        .await
22        .map_err(|e| AppError::Provider(format!("callback server accept error: {e}")))?;
23
24    let mut buf = vec![0u8; 4096];
25    let n = stream
26        .read(&mut buf)
27        .await
28        .map_err(|e| AppError::Provider(format!("callback server read error: {e}")))?;
29
30    let request = String::from_utf8_lossy(&buf[..n]);
31    let params = parse_query_params(&request);
32
33    let returned_state = params.iter()
34        .find_map(|(k, v)| if k == "state" { Some(v.as_str()) } else { None })
35        .ok_or_else(|| AppError::Provider(
36            "missing `state` parameter in OAuth callback (CSRF protection violated)".to_string()
37        ))?;
38
39    if !constant_time_eq(returned_state.as_bytes(), expected_state.as_bytes()) {
40        return Err(AppError::Provider(
41            "OAuth `state` mismatch — possible CSRF attack, refusing to continue".to_string(),
42        ));
43    }
44
45    let code = params.iter()
46        .find_map(|(k, v)| if k == "code" { Some(v.clone()) } else { None })
47        .ok_or_else(|| AppError::Provider("no `code` parameter in OAuth callback".to_string()))?;
48
49    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 2\r\n\r\nOK";
50    stream
51        .write_all(response.as_bytes())
52        .await
53        .map_err(|e| AppError::Provider(format!("callback server write error: {e}")))?;
54
55    Ok(code)
56}
57
58fn parse_query_params(request: &str) -> Vec<(String, String)> {
59    let Some(first_line) = request.lines().next() else { return Vec::new(); };
60    let Some(path) = first_line.split_whitespace().nth(1) else { return Vec::new(); };
61    let Some(query) = path.split_once('?').map(|(_, q)| q) else { return Vec::new(); };
62    query
63        .split('&')
64        .filter_map(|pair| {
65            let (k, v) = pair.split_once('=')?;
66            Some((k.to_string(), url_decode(v)))
67        })
68        .collect()
69}
70
/// Compares two byte slices for equality without stopping at the first differing byte.
///
/// Slices of different length are rejected immediately. For equal-length slices every
/// byte pair is XOR-ed into a single accumulator, and the result is inspected only at
/// the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
79
/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space, and every `%XX` escape (two hex digits, either case) becomes the
/// byte it names. The decoded bytes are then interpreted as UTF-8, so multi-byte escapes
/// such as `%C3%A9` decode to `é`; invalid UTF-8 is replaced with U+FFFD. A `%` that is
/// not followed by two hex digits is kept literally, and the characters after it are
/// decoded normally.
fn url_decode(s: &str) -> String {
    /// Value of an ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let input = s.as_bytes();
    let mut decoded = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                // `get` yields `None` when fewer than two bytes follow the `%`.
                let escaped = input
                    .get(i + 1..i + 3)
                    .and_then(|hex| Some((hex_value(hex[0])? << 4) | hex_value(hex[1])?));
                match escaped {
                    Some(byte) => {
                        decoded.push(byte);
                        i += 3;
                    }
                    None => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}