Skip to main content

stynx_code_bridge/infrastructure/
websocket_transport.rs

1use base64::{Engine as _, engine::general_purpose::STANDARD};
2use stynx_code_errors::{AppError, AppResult};
3use tokio::net::TcpStream;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use crate::domain::bridge_types::BridgeMessage;
6
7pub struct WebSocketTransport;
8
9impl WebSocketTransport {
10    pub fn new() -> Self {
11        Self
12    }
13
14    pub async fn accept_connection(&self, mut stream: TcpStream) -> AppResult<BridgeConnection> {
15
16        let mut buf = vec![0u8; 4096];
17        let n = stream.read(&mut buf).await
18            .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
19        let request = String::from_utf8_lossy(&buf[..n]);
20
21        let key = request.lines()
22            .find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
23            .and_then(|l| l.splitn(2, ':').nth(1))
24            .map(|s| s.trim())
25            .ok_or_else(|| AppError::BadRequest("missing Sec-WebSocket-Key".into()))?;
26
27        let accept = compute_accept_key(key);
28
29        let response = format!(
30            "HTTP/1.1 101 Switching Protocols\r\n\
31             Upgrade: websocket\r\n\
32             Connection: Upgrade\r\n\
33             Sec-WebSocket-Accept: {}\r\n\
34             \r\n",
35            accept
36        );
37
38        stream.write_all(response.as_bytes()).await
39            .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
40
41        Ok(BridgeConnection { stream })
42    }
43}
44
45impl Default for WebSocketTransport {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51fn compute_accept_key(key: &str) -> String {
52    const MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
53    let input = format!("{}{}", key, MAGIC);
54    let hash = sha1_bytes(input.as_bytes());
55    STANDARD.encode(hash)
56}
57
/// Computes the 20-byte SHA-1 digest of `data` (FIPS 180-4).
///
/// Hand-rolled so the handshake needs no hashing dependency. SHA-1 is used
/// only because RFC 6455 mandates it for `Sec-WebSocket-Accept`; it is not
/// suitable for security-sensitive hashing.
fn sha1_bytes(data: &[u8]) -> [u8; 20] {
    // One additive constant per group of 20 rounds.
    const ROUND_CONSTANTS: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6];

    let mut state: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];
    let message_bits = (data.len() as u64) * 8;

    // Padding: a 0x80 byte, zeros until the length is 56 mod 64, then the
    // original message length in bits as a big-endian u64.
    let mut padded = Vec::with_capacity(data.len() + 72);
    padded.extend_from_slice(data);
    padded.push(0x80);
    let zero_count = (120 - padded.len() % 64) % 64;
    padded.resize(padded.len() + zero_count, 0x00);
    padded.extend_from_slice(&message_bits.to_be_bytes());

    for block in padded.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut schedule = [0u32; 80];
        for (word, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            schedule[t] =
                (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16]).rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;

        for (t, &w) in schedule.iter().enumerate() {
            let mix = if t < 20 {
                (b & c) | (!b & d)
            } else if t < 40 {
                b ^ c ^ d
            } else if t < 60 {
                (b & c) | (b & d) | (c & d)
            } else {
                b ^ c ^ d
            };
            let next = a
                .rotate_left(5)
                .wrapping_add(mix)
                .wrapping_add(e)
                .wrapping_add(ROUND_CONSTANTS[t / 20])
                .wrapping_add(w);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }

        for (slot, &value) in state.iter_mut().zip([a, b, c, d, e].iter()) {
            *slot = slot.wrapping_add(value);
        }
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
111
112pub struct BridgeConnection {
113    stream: TcpStream,
114}
115
116impl BridgeConnection {
117    pub async fn send(&mut self, msg: BridgeMessage) -> AppResult<()> {
118        let payload = serde_json::to_vec(&msg)
119            .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
120        let frame = encode_websocket_frame(&payload);
121        self.stream.write_all(&frame).await
122            .map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
123        Ok(())
124    }
125
126    pub async fn recv(&mut self) -> Option<BridgeMessage> {
127        let payload = decode_websocket_frame(&mut self.stream).await?;
128        serde_json::from_slice(&payload).ok()
129    }
130}
131
/// Encodes `payload` as a single final (FIN set), unmasked text frame
/// (opcode 0x1).
///
/// Server-to-client frames are sent unmasked, as RFC 6455 §5.1 requires. The
/// payload length uses the shortest encoding that fits:
/// - up to 125: 7-bit inline length;
/// - up to 65535: marker 126 followed by a big-endian `u16`;
/// - otherwise: marker 127 followed by a big-endian `u64`.
fn encode_websocket_frame(payload: &[u8]) -> Vec<u8> {
    const FIN_AND_TEXT_OPCODE: u8 = 0x81;

    let len = payload.len();
    // Header is at most 10 bytes: 2 fixed + up to 8 extended-length bytes.
    let mut frame = Vec::with_capacity(len + 10);
    frame.push(FIN_AND_TEXT_OPCODE);

    if len <= 125 {
        frame.push(len as u8);
    } else if len <= 0xFFFF {
        frame.push(126);
        frame.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        frame.push(127);
        // Widen to u64 before serializing. Shifting a 32-bit usize by 32 or
        // more bits overflows: it panics in debug builds and gives wrong bytes
        // in release.
        frame.extend_from_slice(&(len as u64).to_be_bytes());
    }

    frame.extend_from_slice(payload);
    frame
}
152
153async fn decode_websocket_frame(stream: &mut TcpStream) -> Option<Vec<u8>> {
154    let mut header = [0u8; 2];
155    stream.read_exact(&mut header).await.ok()?;
156
157    let _fin = (header[0] & 0x80) != 0;
158    let opcode = header[0] & 0x0F;
159
160    if opcode == 8 {
161        return None;
162    }
163
164    let masked = (header[1] & 0x80) != 0;
165    let mut payload_len = (header[1] & 0x7F) as usize;
166
167    if payload_len == 126 {
168        let mut ext = [0u8; 2];
169        stream.read_exact(&mut ext).await.ok()?;
170        payload_len = u16::from_be_bytes(ext) as usize;
171    } else if payload_len == 127 {
172        let mut ext = [0u8; 8];
173        stream.read_exact(&mut ext).await.ok()?;
174        payload_len = u64::from_be_bytes(ext) as usize;
175    }
176
177    let mask = if masked {
178        let mut m = [0u8; 4];
179        stream.read_exact(&mut m).await.ok()?;
180        Some(m)
181    } else {
182        None
183    };
184
185    let mut payload = vec![0u8; payload_len];
186    stream.read_exact(&mut payload).await.ok()?;
187
188    if let Some(mask) = mask {
189        for (i, byte) in payload.iter_mut().enumerate() {
190            *byte ^= mask[i % 4];
191        }
192    }
193
194    Some(payload)
195}