Skip to main content

zlayer_proxy/
tunnel.rs

1//! WebSocket and upgrade tunneling
2//!
3//! This module provides functionality for handling HTTP upgrade requests,
4//! including WebSocket connections, and bidirectional tunneling.
5
6use crate::error::{ProxyError, Result};
7use http::{header, Request, Response, StatusCode};
8use hyper::upgrade::OnUpgrade;
9use hyper_util::rt::TokioIo;
10use tokio::io::{AsyncRead, AsyncWrite};
11use tracing::{debug, error, info, warn};
12
13/// Check if a request is an upgrade request
14///
15/// Returns true if the request has a Connection: upgrade header
16pub fn is_upgrade_request<B>(req: &Request<B>) -> bool {
17    req.headers()
18        .get(header::CONNECTION)
19        .and_then(|h| h.to_str().ok())
20        .is_some_and(|v| {
21            v.split(',')
22                .any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
23        })
24}
25
26/// Check if a request is a WebSocket upgrade request
27///
28/// Returns true if the request has both Connection: upgrade and Upgrade: websocket headers
29pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
30    if !is_upgrade_request(req) {
31        return false;
32    }
33
34    req.headers()
35        .get(header::UPGRADE)
36        .and_then(|h| h.to_str().ok())
37        .is_some_and(|v| v.eq_ignore_ascii_case("websocket"))
38}
39
40/// Get the upgrade protocol from a request
41///
42/// Returns the value of the Upgrade header if present
43pub fn get_upgrade_protocol<B>(req: &Request<B>) -> Option<&str> {
44    req.headers()
45        .get(header::UPGRADE)
46        .and_then(|h| h.to_str().ok())
47}
48
49/// Check if a response indicates a successful upgrade
50pub fn is_upgrade_response<B>(res: &Response<B>) -> bool {
51    res.status() == StatusCode::SWITCHING_PROTOCOLS
52}
53
54/// Proxy a tunnel connection between two upgraded connections
55///
56/// This function performs bidirectional copying between the client
57/// and server connections after an upgrade.
58///
59/// # Errors
60///
61/// Returns an error if the bidirectional copy fails with a non-reset
62/// IO error.
63pub async fn proxy_tunnel<C, S>(mut client: C, mut server: S) -> Result<()>
64where
65    C: AsyncRead + AsyncWrite + Unpin + Send,
66    S: AsyncRead + AsyncWrite + Unpin + Send,
67{
68    match tokio::io::copy_bidirectional(&mut client, &mut server).await {
69        Ok((client_to_server, server_to_client)) => {
70            debug!(
71                client_to_server = client_to_server,
72                server_to_client = server_to_client,
73                "Tunnel closed"
74            );
75            Ok(())
76        }
77        Err(e) => {
78            // Connection reset is common and not really an error
79            if e.kind() == std::io::ErrorKind::ConnectionReset {
80                debug!("Tunnel connection reset");
81                Ok(())
82            } else {
83                warn!(error = %e, "Tunnel error");
84                Err(ProxyError::Io(e))
85            }
86        }
87    }
88}
89
90/// Handle upgrade with explicit upgrade futures
91///
92/// This is a higher-level function that takes `OnUpgrade` futures from hyper
93/// and handles the bidirectional copying between them.
94///
95/// # Errors
96///
97/// Returns an error if either upgrade fails or if the bidirectional
98/// tunnel encounters a fatal IO error.
99pub async fn proxy_upgrade(client_upgrade: OnUpgrade, server_upgrade: OnUpgrade) -> Result<()> {
100    // Wait for both upgrades to complete
101    let (client_result, server_result) = tokio::join!(client_upgrade, server_upgrade);
102
103    let client_io = client_result.map_err(|e| {
104        error!(error = %e, "Client upgrade failed");
105        ProxyError::Internal(format!("Client upgrade failed: {e}"))
106    })?;
107
108    let server_io = server_result.map_err(|e| {
109        error!(error = %e, "Server upgrade failed");
110        ProxyError::Internal(format!("Server upgrade failed: {e}"))
111    })?;
112
113    info!("Upgrade successful, starting bidirectional tunnel");
114
115    // Use hyper_util's TokioIo wrapper for the upgraded connections
116    let client = TokioIo::new(client_io);
117    let server = TokioIo::new(server_io);
118
119    proxy_tunnel(client, server).await
120}
121
122/// Extract WebSocket key from a request for validation
123pub fn get_websocket_key<B>(req: &Request<B>) -> Option<&str> {
124    req.headers()
125        .get("sec-websocket-key")
126        .and_then(|h| h.to_str().ok())
127}
128
129/// Extract WebSocket version from a request
130pub fn get_websocket_version<B>(req: &Request<B>) -> Option<&str> {
131    req.headers()
132        .get("sec-websocket-version")
133        .and_then(|h| h.to_str().ok())
134}
135
/// Headers that should be forwarded for WebSocket upgrades
///
/// Lowercase names of the WebSocket handshake request headers. Two functions
/// consult this list:
/// - `is_websocket_header`, for a case-insensitive membership test.
/// - `copy_upgrade_headers`, which copies these names from a source request
///   to a destination request.
///
/// `sec-websocket-accept` is not part of this list.
const WEBSOCKET_HEADERS: &[&str] = &[
    "sec-websocket-key",
    "sec-websocket-version",
    "sec-websocket-protocol",
    "sec-websocket-extensions",
];
143
144/// Check if a header should be preserved for WebSocket upgrades
145#[must_use]
146pub fn is_websocket_header(name: &str) -> bool {
147    WEBSOCKET_HEADERS
148        .iter()
149        .any(|h| h.eq_ignore_ascii_case(name))
150}
151
152/// Copy upgrade-related headers from source to destination request parts
153pub fn copy_upgrade_headers(src: &http::request::Parts, dst: &mut http::request::Parts) {
154    // Copy Connection and Upgrade headers
155    if let Some(conn) = src.headers.get(header::CONNECTION) {
156        dst.headers.insert(header::CONNECTION, conn.clone());
157    }
158    if let Some(upgrade) = src.headers.get(header::UPGRADE) {
159        dst.headers.insert(header::UPGRADE, upgrade.clone());
160    }
161
162    // Copy WebSocket-specific headers
163    for header_name in WEBSOCKET_HEADERS {
164        if let Some(value) = src.headers.get(*header_name) {
165            if let Ok(name) = header::HeaderName::from_bytes(header_name.as_bytes()) {
166                dst.headers.insert(name, value.clone());
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;
    use std::io::{Error, ErrorKind};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    /// Build a body-less request carrying the given `(name, value)` headers in order.
    fn request_with(headers: &[(&str, &str)]) -> Request<()> {
        headers
            .iter()
            .fold(Request::builder(), |builder, (name, value)| {
                builder.header(*name, *value)
            })
            .body(())
            .unwrap()
    }

    #[test]
    fn test_is_upgrade_request() {
        // Plain, mixed-case, and list-valued Connection headers all qualify.
        assert!(is_upgrade_request(&request_with(&[("Connection", "upgrade")])));
        assert!(is_upgrade_request(&request_with(&[("connection", "Upgrade")])));
        assert!(is_upgrade_request(&request_with(&[(
            "Connection",
            "keep-alive, upgrade"
        )])));

        // No upgrade token, or no Connection header at all.
        assert!(!is_upgrade_request(&request_with(&[("Connection", "keep-alive")])));
        assert!(!is_upgrade_request(&request_with(&[])));
    }

    #[test]
    fn test_is_websocket_upgrade() {
        assert!(is_websocket_upgrade(&request_with(&[
            ("Connection", "upgrade"),
            ("Upgrade", "websocket"),
        ])));
        assert!(is_websocket_upgrade(&request_with(&[
            ("Connection", "Upgrade"),
            ("Upgrade", "WebSocket"),
        ])));

        // Upgrade header missing entirely.
        assert!(!is_websocket_upgrade(&request_with(&[("Connection", "upgrade")])));
        // Upgrade to a different protocol.
        assert!(!is_websocket_upgrade(&request_with(&[
            ("Connection", "upgrade"),
            ("Upgrade", "h2c"),
        ])));
    }

    #[test]
    fn test_get_upgrade_protocol() {
        let ws = request_with(&[("Upgrade", "websocket")]);
        let h2c = request_with(&[("Upgrade", "h2c")]);
        let none = request_with(&[]);

        assert_eq!(get_upgrade_protocol(&ws), Some("websocket"));
        assert_eq!(get_upgrade_protocol(&h2c), Some("h2c"));
        assert_eq!(get_upgrade_protocol(&none), None);
    }

    #[test]
    fn test_is_upgrade_response() {
        let respond = |status: StatusCode| Response::builder().status(status).body(()).unwrap();

        assert!(is_upgrade_response(&respond(StatusCode::SWITCHING_PROTOCOLS)));
        assert!(!is_upgrade_response(&respond(StatusCode::OK)));
    }

    #[test]
    fn test_get_websocket_key() {
        let keyed = request_with(&[("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")]);
        assert_eq!(get_websocket_key(&keyed), Some("dGhlIHNhbXBsZSBub25jZQ=="));
        assert_eq!(get_websocket_key(&request_with(&[])), None);
    }

    #[test]
    fn test_get_websocket_version() {
        let versioned = request_with(&[("sec-websocket-version", "13")]);
        assert_eq!(get_websocket_version(&versioned), Some("13"));
    }

    #[test]
    fn test_is_websocket_header() {
        for name in [
            "sec-websocket-key",
            "Sec-WebSocket-Key",
            "sec-websocket-version",
            "sec-websocket-protocol",
            "sec-websocket-extensions",
        ] {
            assert!(is_websocket_header(name), "{name} should be a websocket header");
        }
        for name in ["content-type", "host"] {
            assert!(!is_websocket_header(name), "{name} should not be a websocket header");
        }
    }

    #[test]
    fn test_copy_upgrade_headers() {
        let (src, _src_body) = request_with(&[
            ("Connection", "upgrade"),
            ("Upgrade", "websocket"),
            ("sec-websocket-key", "test-key"),
            ("sec-websocket-version", "13"),
            ("content-type", "text/plain"),
        ])
        .into_parts();
        let (mut dst, _dst_body) = request_with(&[]).into_parts();

        copy_upgrade_headers(&src, &mut dst);

        for name in ["connection", "upgrade", "sec-websocket-key", "sec-websocket-version"] {
            assert!(dst.headers.contains_key(name), "{name} should have been copied");
        }
        // Unrelated headers stay behind.
        assert!(!dst.headers.contains_key("content-type"));
    }

    #[tokio::test]
    async fn test_proxy_tunnel_connection_reset() {
        /// Stream whose reads always fail with `ConnectionReset` and whose
        /// writes, flushes and shutdowns always succeed immediately. It models
        /// a peer that vanished mid-tunnel.
        struct ResetStream;

        impl AsyncRead for ResetStream {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Ready(Err(Error::new(
                    ErrorKind::ConnectionReset,
                    "connection reset",
                )))
            }
        }

        impl AsyncWrite for ResetStream {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                Poll::Ready(Ok(()))
            }

            fn poll_shutdown(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Ready(Ok(()))
            }
        }

        // A reset from the peer must end the tunnel gracefully, not as an error.
        let outcome = proxy_tunnel(ResetStream, ResetStream).await;
        assert!(outcome.is_ok());
    }
}
388}