1use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9 WebSocketStream, accept_hdr_async_with_config,
10 tungstenite::{
11 handshake::server::{Request, Response},
12 protocol::WebSocketConfig,
13 },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21pub use trojan_core::transport::WsIo;
23
/// Buffer size in bytes (2 KiB) exported for use by other modules.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// starting capacity of the read buffer handed to `inspect_mixed` — confirm
/// at the call sites.
pub const INITIAL_BUFFER_SIZE: usize = 2048;
26
/// Byte sequence (`CR LF CR LF`) that `find_header_end` searches for to locate
/// the end of an HTTP header block.
const HTTP_HEADER_END: &[u8] = b"\r\n\r\n";
28
/// Verdict produced by [`inspect_mixed`] for the bytes received so far.
///
/// The type is a plain tag (plus a static reason string), so it is cheap to
/// copy and can be compared directly with `==` / `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsInspect {
    /// No complete header block has been seen yet and the buffer is still
    /// short; more data is needed before a decision can be made.
    NeedMore,
    /// The bytes do not look like an HTTP request.
    NotHttp,
    /// A parseable HTTP request that is not a complete WebSocket upgrade
    /// (wrong method or missing upgrade headers).
    HttpFallback,
    /// A WebSocket upgrade request whose path and host match the configuration.
    Upgrade,
    /// A WebSocket upgrade request aimed at the wrong path or host; carries a
    /// static, human-readable reason.
    Reject(&'static str),
}
43
44pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
46 if buf.len() >= 3 && !could_be_http_method(buf) {
50 return WsInspect::NotHttp;
51 }
52
53 let header_end = find_header_end(buf);
54 if header_end.is_none() {
55 if buf.len() >= 256 {
59 return WsInspect::NotHttp;
60 }
61 return WsInspect::NeedMore;
62 }
63 let header_end = header_end.unwrap();
64 let header_bytes = &buf[..header_end];
65 let header_str = match std::str::from_utf8(header_bytes) {
66 Ok(v) => v,
67 Err(_) => return WsInspect::NotHttp,
68 };
69 let mut lines = header_str.split("\r\n");
70 let request_line = match lines.next() {
71 Some(v) => v,
72 None => return WsInspect::NotHttp,
73 };
74 let mut parts = request_line.split_whitespace();
75 let method = parts.next().unwrap_or("");
76 let path = parts.next().unwrap_or("");
77 let version = parts.next().unwrap_or("");
78 if !version.starts_with("HTTP/") {
79 return WsInspect::NotHttp;
80 }
81 if method != "GET" {
82 return WsInspect::HttpFallback;
83 }
84
85 let mut upgrade = false;
86 let mut connection_upgrade = false;
87 let mut ws_key = false;
88 let mut host: Option<&str> = None;
89
90 for line in lines {
91 if let Some((name, value)) = line.split_once(':') {
92 let name = name.trim().to_ascii_lowercase();
93 let value_trim = value.trim();
94 let value_lower = value_trim.to_ascii_lowercase();
95 match name.as_str() {
96 "upgrade" => {
97 if value_lower.contains("websocket") {
98 upgrade = true;
99 }
100 }
101 "connection" => {
102 if value_lower.contains("upgrade") {
103 connection_upgrade = true;
104 }
105 }
106 "sec-websocket-key" => {
107 if !value_trim.is_empty() {
108 ws_key = true;
109 }
110 }
111 "host" => {
112 host = Some(value_trim);
113 }
114 _ => {}
115 }
116 }
117 }
118
119 if !upgrade || !connection_upgrade || !ws_key {
120 return WsInspect::HttpFallback;
121 }
122
123 if !path_matches(cfg, path) || !host_matches(cfg, host) {
124 return WsInspect::Reject("websocket path/host mismatch");
125 }
126
127 WsInspect::Upgrade
128}
129
130pub async fn accept_ws<S>(
132 stream: S,
133 initial: Bytes,
134 cfg: &WsCfg,
135) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
136where
137 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
138{
139 let max_frame = if cfg.max_frame_bytes == 0 {
140 None
141 } else {
142 Some(cfg.max_frame_bytes)
143 };
144 let mut ws_cfg = WebSocketConfig::default();
145 ws_cfg.max_frame_size = max_frame;
146 ws_cfg.max_message_size = max_frame;
147 let prefixed = PrefixedStream::new(initial, stream);
148 let ws = accept_hdr_async_with_config(
149 prefixed,
150 |req: &Request, resp: Response| {
151 debug!(path = %req.uri().path(), "websocket upgrade");
152 Ok(resp)
153 },
154 Some(ws_cfg),
155 )
156 .await
157 .map_err(|e| {
158 ServerError::Io(std::io::Error::new(
159 std::io::ErrorKind::InvalidData,
160 format!("websocket handshake failed: {e}"),
161 ))
162 })?;
163 Ok(ws)
164}
165
166pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
168where
169 S: AsyncWrite + Unpin,
170{
171 warn!(reason, "websocket rejected");
172 let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
173 tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
174 Ok(())
175}
176
/// Reports whether `buf` begins with the first three bytes of an HTTP method
/// name (case-sensitive). Buffers shorter than three bytes never match.
fn could_be_http_method(buf: &[u8]) -> bool {
    // Three-byte prefixes of GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH,
    // CONNECT and TRACE.
    const METHOD_PREFIXES: [&[u8]; 9] = [
        b"GET", b"POS", b"PUT", b"DEL", b"HEA", b"OPT", b"PAT", b"CON", b"TRA",
    ];
    METHOD_PREFIXES.iter().any(|prefix| buf.starts_with(prefix))
}
190
191fn find_header_end(buf: &[u8]) -> Option<usize> {
192 buf.windows(HTTP_HEADER_END.len())
193 .position(|w| w == HTTP_HEADER_END)
194 .map(|idx| idx + HTTP_HEADER_END.len())
195}
196
197fn path_matches(cfg: &WsCfg, path: &str) -> bool {
198 let path_only = path.split('?').next().unwrap_or("");
199 path_only == cfg.path
200}
201
202fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
203 let expected = match cfg.host.as_deref() {
204 Some(v) => v,
205 None => return true,
206 };
207 let host = match host {
208 Some(v) => v,
209 None => return false,
210 };
211 let host_only = host.split(':').next().unwrap_or("");
212 host_only.eq_ignore_ascii_case(expected)
213}