1use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9 WebSocketStream, accept_hdr_async_with_config,
10 tungstenite::{
11 handshake::server::{Request, Response},
12 protocol::WebSocketConfig,
13 },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21pub use trojan_core::transport::WsIo;
23
24pub const INITIAL_BUFFER_SIZE: usize = 2048;
26
27const HTTP_HEADER_END: &[u8] = b"\r\n\r\n";
28
29pub enum WsInspect {
31 NeedMore,
33 NotHttp,
35 HttpFallback,
37 Upgrade,
39 Reject(&'static str),
41}
42
43pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
45 if buf.len() >= 3 && !could_be_http_method(buf) {
49 return WsInspect::NotHttp;
50 }
51
52 let header_end = find_header_end(buf);
53 if header_end.is_none() {
54 if buf.len() >= 256 {
58 return WsInspect::NotHttp;
59 }
60 return WsInspect::NeedMore;
61 }
62 let header_end = header_end.unwrap();
63 let header_bytes = &buf[..header_end];
64 let header_str = match std::str::from_utf8(header_bytes) {
65 Ok(v) => v,
66 Err(_) => return WsInspect::NotHttp,
67 };
68 let mut lines = header_str.split("\r\n");
69 let request_line = match lines.next() {
70 Some(v) => v,
71 None => return WsInspect::NotHttp,
72 };
73 let mut parts = request_line.split_whitespace();
74 let method = parts.next().unwrap_or("");
75 let path = parts.next().unwrap_or("");
76 let version = parts.next().unwrap_or("");
77 if !version.starts_with("HTTP/") {
78 return WsInspect::NotHttp;
79 }
80 if method != "GET" {
81 return WsInspect::HttpFallback;
82 }
83
84 let mut upgrade = false;
85 let mut connection_upgrade = false;
86 let mut ws_key = false;
87 let mut host: Option<&str> = None;
88
89 for line in lines {
90 if let Some((name, value)) = line.split_once(':') {
91 let name = name.trim().to_ascii_lowercase();
92 let value_trim = value.trim();
93 let value_lower = value_trim.to_ascii_lowercase();
94 match name.as_str() {
95 "upgrade" => {
96 if value_lower.contains("websocket") {
97 upgrade = true;
98 }
99 }
100 "connection" => {
101 if value_lower.contains("upgrade") {
102 connection_upgrade = true;
103 }
104 }
105 "sec-websocket-key" => {
106 if !value_trim.is_empty() {
107 ws_key = true;
108 }
109 }
110 "host" => {
111 host = Some(value_trim);
112 }
113 _ => {}
114 }
115 }
116 }
117
118 if !upgrade || !connection_upgrade || !ws_key {
119 return WsInspect::HttpFallback;
120 }
121
122 if !path_matches(cfg, path) || !host_matches(cfg, host) {
123 return WsInspect::Reject("websocket path/host mismatch");
124 }
125
126 WsInspect::Upgrade
127}
128
129pub async fn accept_ws<S>(
131 stream: S,
132 initial: Bytes,
133 cfg: &WsCfg,
134) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
135where
136 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
137{
138 let max_frame = if cfg.max_frame_bytes == 0 {
139 None
140 } else {
141 Some(cfg.max_frame_bytes)
142 };
143 let ws_cfg = WebSocketConfig {
144 max_frame_size: max_frame,
145 max_message_size: max_frame,
146 ..WebSocketConfig::default()
147 };
148 let prefixed = PrefixedStream::new(initial, stream);
149 let ws = accept_hdr_async_with_config(
150 prefixed,
151 |req: &Request, resp: Response| {
152 debug!(path = %req.uri().path(), "websocket upgrade");
153 Ok(resp)
154 },
155 Some(ws_cfg),
156 )
157 .await
158 .map_err(|e| {
159 ServerError::Io(std::io::Error::new(
160 std::io::ErrorKind::InvalidData,
161 format!("websocket handshake failed: {e}"),
162 ))
163 })?;
164 Ok(ws)
165}
166
167pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
169where
170 S: AsyncWrite + Unpin,
171{
172 warn!(reason, "websocket rejected");
173 let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
174 tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
175 Ok(())
176}
177
178fn could_be_http_method(buf: &[u8]) -> bool {
181 buf.starts_with(b"GET")
182 || buf.starts_with(b"POS")
183 || buf.starts_with(b"PUT")
184 || buf.starts_with(b"DEL")
185 || buf.starts_with(b"HEA")
186 || buf.starts_with(b"OPT")
187 || buf.starts_with(b"PAT")
188 || buf.starts_with(b"CON")
189 || buf.starts_with(b"TRA")
190}
191
192fn find_header_end(buf: &[u8]) -> Option<usize> {
193 buf.windows(HTTP_HEADER_END.len())
194 .position(|w| w == HTTP_HEADER_END)
195 .map(|idx| idx + HTTP_HEADER_END.len())
196}
197
198fn path_matches(cfg: &WsCfg, path: &str) -> bool {
199 let path_only = path.split('?').next().unwrap_or("");
200 path_only == cfg.path
201}
202
203fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
204 let expected = match cfg.host.as_deref() {
205 Some(v) => v,
206 None => return true,
207 };
208 let host = match host {
209 Some(v) => v,
210 None => return false,
211 };
212 let host_only = host.split(':').next().unwrap_or("");
213 host_only.eq_ignore_ascii_case(expected)
214}