websocket_relay/
proxy.rs

1use anyhow::{Context, Result, anyhow};
2use futures_util::{SinkExt, StreamExt};
3use std::{
4    collections::HashMap,
5    hash::BuildHasher,
6    sync::{Arc, Mutex},
7};
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt},
10    net::TcpStream,
11};
12use tokio_tungstenite::{
13    WebSocketStream, accept_hdr_async,
14    tungstenite::{
15        Error as TungsteniteError, Message,
16        error::ProtocolError,
17        handshake::server::{Request, Response},
18    },
19};
20use tracing::{debug, error, info, warn};
21
22use crate::config::TargetConfig;
23use crate::security::parse_original_client_ip;
24use crate::stream::StreamType;
25
/// Capacity, in bytes, of the buffer `handle_socket` reads into on each read
/// from the target TCP connection (8 KiB).
pub const BUFFER_SIZE: usize = 8 * 1024;
27
28#[tracing::instrument(skip(stream, targets), fields(client_addr = %stream.peer_addr().unwrap_or_else(|_| "unknown".parse().unwrap())))]
29pub async fn handle_connection<S: BuildHasher + Sync>(
30    stream: StreamType,
31    targets: &HashMap<String, TargetConfig, S>,
32) -> Result<()> {
33    let host_header = Arc::new(Mutex::new(None::<String>));
34    let host_header_clone = host_header.clone();
35    // Get client address before moving the stream
36    let client_addr = stream
37        .peer_addr()
38        .unwrap_or_else(|_| "unknown".parse().unwrap());
39
40    let client_ip = Arc::new(Mutex::new(None::<String>));
41    let client_ip_clone = client_ip.clone();
42
43    let callback = move |req: &Request, response: Response| {
44        if let Some(host) = req.headers().get("host") {
45            if let Ok(host_str) = host.to_str() {
46                if let Ok(mut guard) = host_header_clone.lock() {
47                    *guard = Some(host_str.to_string());
48                }
49            }
50        }
51
52        // Extract original client IP from X-Forwarded-For header
53        if let Some(xff) = req.headers().get("x-forwarded-for") {
54            if let Ok(xff_str) = xff.to_str() {
55                if let Some(original_ip) = parse_original_client_ip(xff_str) {
56                    if let Ok(mut guard) = client_ip_clone.lock() {
57                        *guard = Some(original_ip);
58                    }
59                }
60            }
61        }
62
63        Ok(response)
64    };
65
66    let ws_stream = accept_hdr_async(stream, callback)
67        .await
68        .context("Failed to perform WebSocket handshake")?;
69
70    let host = host_header
71        .lock()
72        .unwrap()
73        .as_ref()
74        .ok_or_else(|| anyhow!("No Host header found in request"))?
75        .clone();
76
77    let original_client_ip = client_ip.lock().unwrap().clone();
78
79    let target_config = targets
80        .get(&host)
81        .ok_or_else(|| anyhow!("No target configured for domain: {}", host))?;
82
83    // Log with original client IP if available, otherwise use direct connection IP
84    match original_client_ip {
85        Some(ref ip) => {
86            info!(
87                host = %host,
88                target_host = %target_config.host,
89                target_port = target_config.port,
90                client_ip = %ip,
91                direct_addr = %client_addr,
92                "Routing request"
93            );
94        }
95        None => {
96            info!(
97                host = %host,
98                target_host = %target_config.host,
99                target_port = target_config.port,
100                client_ip = %client_addr,
101                "Routing request"
102            );
103        }
104    }
105
106    handle_socket(ws_stream, target_config, original_client_ip).await?;
107    Ok(())
108}
109
110#[tracing::instrument(skip(websocket, target_config, client_ip))]
111pub async fn handle_socket(
112    websocket: WebSocketStream<StreamType>,
113    target_config: &TargetConfig,
114    client_ip: Option<String>,
115) -> Result<()> {
116    let target_addr = format!("{}:{}", target_config.host, target_config.port);
117
118    if let Some(ref ip) = client_ip {
119        debug!(target_addr = %target_addr, client_ip = %ip, "Attempting to connect to target server");
120    } else {
121        debug!(target_addr = %target_addr, "Attempting to connect to target server");
122    }
123
124    let tcp_stream = TcpStream::connect(&target_addr)
125        .await
126        .with_context(|| format!("Failed to connect to target {target_addr}"))?;
127
128    if let Some(ref ip) = client_ip {
129        info!(target_addr = %target_addr, client_ip = %ip, "Connected to target server");
130    } else {
131        info!(target_addr = %target_addr, "Connected to target server");
132    }
133
134    let (mut ws_sender, mut ws_receiver) = websocket.split();
135    let (mut tcp_reader, mut tcp_writer) = tcp_stream.into_split();
136
137    let ws_to_tcp = async {
138        while let Some(msg) = ws_receiver.next().await {
139            match msg {
140                Ok(Message::Binary(data)) => {
141                    debug!(bytes = data.len(), "Forwarding data from WebSocket to TCP");
142                    if let Err(e) = tcp_writer.write_all(&data).await {
143                        error!(error = %e, bytes = data.len(), "Failed to write to TCP");
144                        return Err(e).context("Failed to write WebSocket data to TCP connection");
145                    }
146                }
147                Ok(Message::Text(_)) => {
148                    warn!("Dropping text message (binary only)");
149                }
150                Ok(Message::Close(_)) => {
151                    info!("WebSocket connection closed");
152                    break;
153                }
154                Err(e) => {
155                    match e {
156                        TungsteniteError::ConnectionClosed
157                        | TungsteniteError::Protocol(ProtocolError::ResetWithoutClosingHandshake) =>
158                        {
159                            debug!("Client disconnected: {e}");
160                        }
161                        _ => {
162                            error!("WebSocket error: {e}");
163                        }
164                    }
165                    break;
166                }
167                _ => {}
168            }
169        }
170        Ok(())
171    };
172
173    let tcp_to_ws = async {
174        let mut buffer = [0u8; BUFFER_SIZE];
175
176        loop {
177            match tcp_reader.read(&mut buffer).await {
178                Ok(0) => {
179                    info!("TCP connection closed");
180                    break;
181                }
182                Ok(n) => {
183                    let data = &buffer[..n];
184                    debug!(bytes = n, "Forwarding data from TCP to WebSocket");
185                    if let Err(e) = ws_sender.send(Message::Binary(data.to_vec().into())).await {
186                        error!(error = %e, bytes = data.len(), "Failed to send WebSocket message");
187                        return Err(e).context("Failed to send TCP data via WebSocket");
188                    }
189                }
190                Err(e) => {
191                    error!("Failed to read from TCP: {e}");
192                    break;
193                }
194            }
195        }
196        Ok(())
197    };
198
199    tokio::select! {
200        result = ws_to_tcp => result?,
201        result = tcp_to_ws => result?,
202    }
203
204    if let Some(ref ip) = client_ip {
205        info!(client_ip = %ip, "Proxy connection closed");
206    } else {
207        info!("Proxy connection closed");
208    }
209    Ok(())
210}