stynx_code_auth/infrastructure/oauth/
callback_server.rs1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use stynx_code_errors::{AppError, AppResult};
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4use tokio::net::TcpListener;
5
6pub fn generate_state() -> String {
7 let mut bytes = [0u8; 32];
8 getrandom::getrandom(&mut bytes)
9 .expect("OS CSPRNG unavailable — cannot generate OAuth state safely");
10 URL_SAFE_NO_PAD.encode(bytes)
11}
12
13pub async fn run_callback_server(port: u16, expected_state: &str) -> AppResult<String> {
14 let addr = format!("127.0.0.1:{port}");
15 let listener = TcpListener::bind(&addr)
16 .await
17 .map_err(|e| AppError::Provider(format!("failed to bind callback server on {addr}: {e}")))?;
18
19 let (mut stream, _) = listener
20 .accept()
21 .await
22 .map_err(|e| AppError::Provider(format!("callback server accept error: {e}")))?;
23
24 let mut buf = vec![0u8; 4096];
25 let n = stream
26 .read(&mut buf)
27 .await
28 .map_err(|e| AppError::Provider(format!("callback server read error: {e}")))?;
29
30 let request = String::from_utf8_lossy(&buf[..n]);
31 let params = parse_query_params(&request);
32
33 let returned_state = params.iter()
34 .find_map(|(k, v)| if k == "state" { Some(v.as_str()) } else { None })
35 .ok_or_else(|| AppError::Provider(
36 "missing `state` parameter in OAuth callback (CSRF protection violated)".to_string()
37 ))?;
38
39 if !constant_time_eq(returned_state.as_bytes(), expected_state.as_bytes()) {
40 return Err(AppError::Provider(
41 "OAuth `state` mismatch — possible CSRF attack, refusing to continue".to_string(),
42 ));
43 }
44
45 let code = params.iter()
46 .find_map(|(k, v)| if k == "code" { Some(v.clone()) } else { None })
47 .ok_or_else(|| AppError::Provider("no `code` parameter in OAuth callback".to_string()))?;
48
49 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 2\r\n\r\nOK";
50 stream
51 .write_all(response.as_bytes())
52 .await
53 .map_err(|e| AppError::Provider(format!("callback server write error: {e}")))?;
54
55 Ok(code)
56}
57
58fn parse_query_params(request: &str) -> Vec<(String, String)> {
59 let Some(first_line) = request.lines().next() else { return Vec::new(); };
60 let Some(path) = first_line.split_whitespace().nth(1) else { return Vec::new(); };
61 let Some(query) = path.split_once('?').map(|(_, q)| q) else { return Vec::new(); };
62 query
63 .split('&')
64 .filter_map(|pair| {
65 let (k, v) = pair.split_once('=')?;
66 Some((k.to_string(), url_decode(v)))
67 })
68 .collect()
69}
70
/// Compares two byte strings without returning early at the first differing byte.
///
/// Slices of different lengths compare unequal immediately, so length is not treated as
/// secret. Otherwise every byte pair is XOR-ed into one accumulator, and the slices are equal
/// only if no bit ever differed.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
79
/// Decodes one `application/x-www-form-urlencoded` component.
///
/// Decoding rules:
/// - `+` becomes a space.
/// - `%XY` becomes the byte `0xXY` when `X` and `Y` are both ASCII hex digits.
/// - A `%` not followed by two hex digits is kept literally, and decoding resumes at the next
///   character, so `%%41` yields `%A`.
///
/// Decoding works on bytes, so multi-byte UTF-8 escapes such as `%C3%A9` become `é`. Byte
/// sequences that are not valid UTF-8 after decoding are replaced with U+FFFD.
fn url_decode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_digit(b: u8) -> Option<u8> {
        // `to_digit(16)` yields at most 15, so the narrowing cast is lossless.
        char::from(b).to_digit(16).map(|d| d as u8)
    }

    let input = s.as_bytes();
    let mut decoded = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        match input[i] {
            b'%' => {
                // `get` returns `None` when fewer than two bytes follow the `%`.
                let escape = match input.get(i + 1..i + 3) {
                    Some(&[hi, lo]) => hex_digit(hi).zip(hex_digit(lo)),
                    _ => None,
                };
                if let Some((hi, lo)) = escape {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                } else {
                    decoded.push(b'%');
                    i += 1;
                }
            }
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}