Skip to main content

trojan_server/
ws.rs

1//! WebSocket transport support.
2//!
3//! This module provides WebSocket upgrade handling for the server.
4//! The `WsIo` adapter is provided by `trojan-core::transport`.
5
6use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9    WebSocketStream, accept_hdr_async_with_config,
10    tungstenite::{
11        handshake::server::{Request, Response},
12        protocol::WebSocketConfig,
13    },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21// Re-export WsIo from trojan-core for convenience
22pub use trojan_core::transport::WsIo;
23
/// Initial buffer size, in bytes, for reading HTTP headers during a WebSocket upgrade.
pub const INITIAL_BUFFER_SIZE: usize = 2 * 1024;

/// Byte sequence that terminates an HTTP header block (CRLF followed by an empty line).
const HTTP_HEADER_END: &[u8] = b"\r\n\r\n";
28
/// Result of inspecting buffered bytes for WebSocket upgrade.
///
/// Derives `Debug` and `PartialEq` so callers can log an outcome and compare
/// outcomes directly; it is `Copy` because every variant is plain data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsInspect {
    /// Need more data to determine protocol.
    NeedMore,
    /// Not HTTP traffic, proceed as raw Trojan.
    NotHttp,
    /// HTTP but not WebSocket upgrade, fallback to HTTP backend.
    HttpFallback,
    /// Valid WebSocket upgrade request.
    Upgrade,
    /// Reject with reason (e.g., path/host mismatch).
    Reject(&'static str),
}
42
43/// Inspect buffered bytes for WebSocket upgrade in mixed mode.
44pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
45    // Quick check: if the buffer doesn't start with a plausible HTTP method,
46    // it's definitely not HTTP. Trojan headers start with a hex hash which
47    // will never match these prefixes (except edge cases caught below).
48    if buf.len() >= 3 && !could_be_http_method(buf) {
49        return WsInspect::NotHttp;
50    }
51
52    let header_end = find_header_end(buf);
53    if header_end.is_none() {
54        // If we've read a significant amount of data without finding \r\n\r\n,
55        // and it doesn't look like HTTP is still being received, treat as not HTTP.
56        // HTTP request lines are typically under 8KB. Trojan headers are ~70 bytes.
57        if buf.len() >= 256 {
58            return WsInspect::NotHttp;
59        }
60        return WsInspect::NeedMore;
61    }
62    let header_end = header_end.unwrap();
63    let header_bytes = &buf[..header_end];
64    let header_str = match std::str::from_utf8(header_bytes) {
65        Ok(v) => v,
66        Err(_) => return WsInspect::NotHttp,
67    };
68    let mut lines = header_str.split("\r\n");
69    let request_line = match lines.next() {
70        Some(v) => v,
71        None => return WsInspect::NotHttp,
72    };
73    let mut parts = request_line.split_whitespace();
74    let method = parts.next().unwrap_or("");
75    let path = parts.next().unwrap_or("");
76    let version = parts.next().unwrap_or("");
77    if !version.starts_with("HTTP/") {
78        return WsInspect::NotHttp;
79    }
80    if method != "GET" {
81        return WsInspect::HttpFallback;
82    }
83
84    let mut upgrade = false;
85    let mut connection_upgrade = false;
86    let mut ws_key = false;
87    let mut host: Option<&str> = None;
88
89    for line in lines {
90        if let Some((name, value)) = line.split_once(':') {
91            let name = name.trim().to_ascii_lowercase();
92            let value_trim = value.trim();
93            let value_lower = value_trim.to_ascii_lowercase();
94            match name.as_str() {
95                "upgrade" => {
96                    if value_lower.contains("websocket") {
97                        upgrade = true;
98                    }
99                }
100                "connection" => {
101                    if value_lower.contains("upgrade") {
102                        connection_upgrade = true;
103                    }
104                }
105                "sec-websocket-key" => {
106                    if !value_trim.is_empty() {
107                        ws_key = true;
108                    }
109                }
110                "host" => {
111                    host = Some(value_trim);
112                }
113                _ => {}
114            }
115        }
116    }
117
118    if !upgrade || !connection_upgrade || !ws_key {
119        return WsInspect::HttpFallback;
120    }
121
122    if !path_matches(cfg, path) || !host_matches(cfg, host) {
123        return WsInspect::Reject("websocket path/host mismatch");
124    }
125
126    WsInspect::Upgrade
127}
128
129/// Accept a WebSocket upgrade on the given stream.
130pub async fn accept_ws<S>(
131    stream: S,
132    initial: Bytes,
133    cfg: &WsCfg,
134) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
135where
136    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
137{
138    let max_frame = if cfg.max_frame_bytes == 0 {
139        None
140    } else {
141        Some(cfg.max_frame_bytes)
142    };
143    let ws_cfg = WebSocketConfig {
144        max_frame_size: max_frame,
145        max_message_size: max_frame,
146        ..WebSocketConfig::default()
147    };
148    let prefixed = PrefixedStream::new(initial, stream);
149    let ws = accept_hdr_async_with_config(
150        prefixed,
151        |req: &Request, resp: Response| {
152            debug!(path = %req.uri().path(), "websocket upgrade");
153            Ok(resp)
154        },
155        Some(ws_cfg),
156    )
157    .await
158    .map_err(|e| {
159        ServerError::Io(std::io::Error::new(
160            std::io::ErrorKind::InvalidData,
161            format!("websocket handshake failed: {e}"),
162        ))
163    })?;
164    Ok(ws)
165}
166
167/// Send an HTTP 400 Bad Request response to reject the connection.
168pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
169where
170    S: AsyncWrite + Unpin,
171{
172    warn!(reason, "websocket rejected");
173    let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
174    tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
175    Ok(())
176}
177
/// Report whether `buf` begins with the first three bytes of a standard HTTP method.
///
/// The recognized prefixes cover GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH,
/// CONNECT and TRACE. Matching is case-sensitive, and buffers shorter than
/// three bytes never match.
fn could_be_http_method(buf: &[u8]) -> bool {
    const METHOD_PREFIXES: [&[u8]; 9] = [
        b"GET", b"POS", b"PUT", b"DEL", b"HEA", b"OPT", b"PAT", b"CON", b"TRA",
    ];
    METHOD_PREFIXES
        .iter()
        .any(|prefix| buf.starts_with(prefix))
}
191
192fn find_header_end(buf: &[u8]) -> Option<usize> {
193    buf.windows(HTTP_HEADER_END.len())
194        .position(|w| w == HTTP_HEADER_END)
195        .map(|idx| idx + HTTP_HEADER_END.len())
196}
197
198fn path_matches(cfg: &WsCfg, path: &str) -> bool {
199    let path_only = path.split('?').next().unwrap_or("");
200    path_only == cfg.path
201}
202
203fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
204    let expected = match cfg.host.as_deref() {
205        Some(v) => v,
206        None => return true,
207    };
208    let host = match host {
209        Some(v) => v,
210        None => return false,
211    };
212    let host_only = host.split(':').next().unwrap_or("");
213    host_only.eq_ignore_ascii_case(expected)
214}