1use crate::error::{ProxyError, Result};
7use http::{header, Request, Response, StatusCode};
8use hyper::upgrade::OnUpgrade;
9use hyper_util::rt::TokioIo;
10use tokio::io::{AsyncRead, AsyncWrite};
11use tracing::{debug, error, info, warn};
12
13pub fn is_upgrade_request<B>(req: &Request<B>) -> bool {
17 req.headers()
18 .get(header::CONNECTION)
19 .and_then(|h| h.to_str().ok())
20 .is_some_and(|v| {
21 v.split(',')
22 .any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
23 })
24}
25
26pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
30 if !is_upgrade_request(req) {
31 return false;
32 }
33
34 req.headers()
35 .get(header::UPGRADE)
36 .and_then(|h| h.to_str().ok())
37 .is_some_and(|v| v.eq_ignore_ascii_case("websocket"))
38}
39
40pub fn get_upgrade_protocol<B>(req: &Request<B>) -> Option<&str> {
44 req.headers()
45 .get(header::UPGRADE)
46 .and_then(|h| h.to_str().ok())
47}
48
49pub fn is_upgrade_response<B>(res: &Response<B>) -> bool {
51 res.status() == StatusCode::SWITCHING_PROTOCOLS
52}
53
54pub async fn proxy_tunnel<C, S>(mut client: C, mut server: S) -> Result<()>
64where
65 C: AsyncRead + AsyncWrite + Unpin + Send,
66 S: AsyncRead + AsyncWrite + Unpin + Send,
67{
68 match tokio::io::copy_bidirectional(&mut client, &mut server).await {
69 Ok((client_to_server, server_to_client)) => {
70 debug!(
71 client_to_server = client_to_server,
72 server_to_client = server_to_client,
73 "Tunnel closed"
74 );
75 Ok(())
76 }
77 Err(e) => {
78 if e.kind() == std::io::ErrorKind::ConnectionReset {
80 debug!("Tunnel connection reset");
81 Ok(())
82 } else {
83 warn!(error = %e, "Tunnel error");
84 Err(ProxyError::Io(e))
85 }
86 }
87 }
88}
89
90pub async fn proxy_upgrade(client_upgrade: OnUpgrade, server_upgrade: OnUpgrade) -> Result<()> {
100 let (client_result, server_result) = tokio::join!(client_upgrade, server_upgrade);
102
103 let client_io = client_result.map_err(|e| {
104 error!(error = %e, "Client upgrade failed");
105 ProxyError::Internal(format!("Client upgrade failed: {e}"))
106 })?;
107
108 let server_io = server_result.map_err(|e| {
109 error!(error = %e, "Server upgrade failed");
110 ProxyError::Internal(format!("Server upgrade failed: {e}"))
111 })?;
112
113 info!("Upgrade successful, starting bidirectional tunnel");
114
115 let client = TokioIo::new(client_io);
117 let server = TokioIo::new(server_io);
118
119 proxy_tunnel(client, server).await
120}
121
122pub fn get_websocket_key<B>(req: &Request<B>) -> Option<&str> {
124 req.headers()
125 .get("sec-websocket-key")
126 .and_then(|h| h.to_str().ok())
127}
128
129pub fn get_websocket_version<B>(req: &Request<B>) -> Option<&str> {
131 req.headers()
132 .get("sec-websocket-version")
133 .and_then(|h| h.to_str().ok())
134}
135
/// Lowercase names of the WebSocket handshake headers this module recognises.
///
/// [`is_websocket_header`] matches names against this list and
/// [`copy_upgrade_headers`] forwards each of these headers when present.
const WEBSOCKET_HEADERS: &[&str] = &[
    "sec-websocket-key",
    "sec-websocket-version",
    "sec-websocket-protocol",
    "sec-websocket-extensions",
];
143
144#[must_use]
146pub fn is_websocket_header(name: &str) -> bool {
147 WEBSOCKET_HEADERS
148 .iter()
149 .any(|h| h.eq_ignore_ascii_case(name))
150}
151
152pub fn copy_upgrade_headers(src: &http::request::Parts, dst: &mut http::request::Parts) {
154 if let Some(conn) = src.headers.get(header::CONNECTION) {
156 dst.headers.insert(header::CONNECTION, conn.clone());
157 }
158 if let Some(upgrade) = src.headers.get(header::UPGRADE) {
159 dst.headers.insert(header::UPGRADE, upgrade.clone());
160 }
161
162 for header_name in WEBSOCKET_HEADERS {
164 if let Some(value) = src.headers.get(*header_name) {
165 if let Ok(name) = header::HeaderName::from_bytes(header_name.as_bytes()) {
166 dst.headers.insert(name, value.clone());
167 }
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Header name/value pairs used to describe a test request.
    type HeaderPairs<'a> = &'a [(&'a str, &'a str)];

    /// Builds a body-less request carrying the given headers.
    fn request_with(headers: HeaderPairs<'_>) -> Request<()> {
        let mut builder = Request::builder();
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(()).unwrap()
    }

    #[test]
    fn test_is_upgrade_request() {
        let cases: [(HeaderPairs<'_>, bool); 5] = [
            (&[("Connection", "upgrade")], true),
            (&[("connection", "Upgrade")], true),
            (&[("Connection", "keep-alive, upgrade")], true),
            (&[("Connection", "keep-alive")], false),
            (&[], false),
        ];

        for (headers, expected) in cases {
            let req = request_with(headers);
            assert_eq!(is_upgrade_request(&req), expected, "headers: {headers:?}");
        }
    }

    #[test]
    fn test_is_websocket_upgrade() {
        let cases: [(HeaderPairs<'_>, bool); 4] = [
            (&[("Connection", "upgrade"), ("Upgrade", "websocket")], true),
            (&[("Connection", "Upgrade"), ("Upgrade", "WebSocket")], true),
            (&[("Connection", "upgrade")], false),
            (&[("Connection", "upgrade"), ("Upgrade", "h2c")], false),
        ];

        for (headers, expected) in cases {
            let req = request_with(headers);
            assert_eq!(is_websocket_upgrade(&req), expected, "headers: {headers:?}");
        }
    }

    #[test]
    fn test_get_upgrade_protocol() {
        let cases: [(HeaderPairs<'_>, Option<&str>); 3] = [
            (&[("Upgrade", "websocket")], Some("websocket")),
            (&[("Upgrade", "h2c")], Some("h2c")),
            (&[], None),
        ];

        for (headers, expected) in cases {
            let req = request_with(headers);
            assert_eq!(get_upgrade_protocol(&req), expected);
        }
    }

    #[test]
    fn test_is_upgrade_response() {
        let cases = [
            (StatusCode::SWITCHING_PROTOCOLS, true),
            (StatusCode::OK, false),
        ];

        for (status, expected) in cases {
            let res = Response::builder().status(status).body(()).unwrap();
            assert_eq!(is_upgrade_response(&res), expected);
        }
    }

    #[test]
    fn test_get_websocket_key() {
        let with_key = request_with(&[("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")]);
        assert_eq!(
            get_websocket_key(&with_key),
            Some("dGhlIHNhbXBsZSBub25jZQ==")
        );

        let without_key = request_with(&[]);
        assert_eq!(get_websocket_key(&without_key), None);
    }

    #[test]
    fn test_get_websocket_version() {
        let req = request_with(&[("sec-websocket-version", "13")]);
        assert_eq!(get_websocket_version(&req), Some("13"));
    }

    #[test]
    fn test_is_websocket_header() {
        let recognised = [
            "sec-websocket-key",
            "Sec-WebSocket-Key",
            "sec-websocket-version",
            "sec-websocket-protocol",
            "sec-websocket-extensions",
        ];
        for name in recognised {
            assert!(is_websocket_header(name), "{name} should be recognised");
        }

        for name in ["content-type", "host"] {
            assert!(!is_websocket_header(name), "{name} should be rejected");
        }
    }

    #[test]
    fn test_copy_upgrade_headers() {
        let (src, _) = request_with(&[
            ("Connection", "upgrade"),
            ("Upgrade", "websocket"),
            ("sec-websocket-key", "test-key"),
            ("sec-websocket-version", "13"),
            ("content-type", "text/plain"),
        ])
        .into_parts();
        let (mut dst, _) = request_with(&[]).into_parts();

        copy_upgrade_headers(&src, &mut dst);

        // Upgrade-related headers are forwarded ...
        assert!(dst.headers.contains_key(header::CONNECTION));
        assert!(dst.headers.contains_key(header::UPGRADE));
        assert!(dst.headers.contains_key("sec-websocket-key"));
        assert!(dst.headers.contains_key("sec-websocket-version"));
        // ... while unrelated headers are left behind.
        assert!(!dst.headers.contains_key("content-type"));
    }

    #[tokio::test]
    async fn test_proxy_tunnel_connection_reset() {
        use std::io;
        use std::pin::Pin;
        use std::task::{Context, Poll};
        use tokio::io::ReadBuf;

        /// Stream whose reads always fail with `ConnectionReset` and whose
        /// writes always succeed.
        struct MockStream;

        impl AsyncRead for MockStream {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                let reset = io::Error::new(io::ErrorKind::ConnectionReset, "connection reset");
                Poll::Ready(Err(reset))
            }
        }

        impl AsyncWrite for MockStream {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }

            fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                Poll::Ready(Ok(()))
            }
        }

        // A reset on either side must be reported as a clean shutdown.
        let outcome = proxy_tunnel(MockStream, MockStream).await;
        assert!(outcome.is_ok());
    }
}