1use anyhow::{Context, Result, anyhow};
2use futures_util::{SinkExt, StreamExt};
3use std::{
4 collections::HashMap,
5 hash::BuildHasher,
6 sync::{Arc, Mutex},
7};
8use tokio::{
9 io::{AsyncReadExt, AsyncWriteExt},
10 net::TcpStream,
11};
12use tokio_tungstenite::{
13 WebSocketStream, accept_hdr_async,
14 tungstenite::{
15 Error as TungsteniteError, Message,
16 error::ProtocolError,
17 handshake::server::{Request, Response},
18 },
19};
20use tracing::{debug, error, info, warn};
21
22use crate::config::TargetConfig;
23use crate::security::parse_original_client_ip;
24use crate::stream::StreamType;
25
/// Size in bytes of the scratch buffer used for each read from the target TCP socket
/// (8 KiB). In `handle_socket`, each successful read is forwarded as one binary
/// WebSocket message, so this also caps the payload of those messages.
pub const BUFFER_SIZE: usize = 8 * 1024;
27
28#[tracing::instrument(skip(stream, targets), fields(client_addr = %stream.peer_addr().unwrap_or_else(|_| "unknown".parse().unwrap())))]
29pub async fn handle_connection<S: BuildHasher + Sync>(
30 stream: StreamType,
31 targets: &HashMap<String, TargetConfig, S>,
32) -> Result<()> {
33 let host_header = Arc::new(Mutex::new(None::<String>));
34 let host_header_clone = host_header.clone();
35 let client_addr = stream
37 .peer_addr()
38 .unwrap_or_else(|_| "unknown".parse().unwrap());
39
40 let client_ip = Arc::new(Mutex::new(None::<String>));
41 let client_ip_clone = client_ip.clone();
42
43 let callback = move |req: &Request, response: Response| {
44 if let Some(host) = req.headers().get("host") {
45 if let Ok(host_str) = host.to_str() {
46 if let Ok(mut guard) = host_header_clone.lock() {
47 *guard = Some(host_str.to_string());
48 }
49 }
50 }
51
52 if let Some(xff) = req.headers().get("x-forwarded-for") {
54 if let Ok(xff_str) = xff.to_str() {
55 if let Some(original_ip) = parse_original_client_ip(xff_str) {
56 if let Ok(mut guard) = client_ip_clone.lock() {
57 *guard = Some(original_ip);
58 }
59 }
60 }
61 }
62
63 Ok(response)
64 };
65
66 let ws_stream = accept_hdr_async(stream, callback)
67 .await
68 .context("Failed to perform WebSocket handshake")?;
69
70 let host = host_header
71 .lock()
72 .unwrap()
73 .as_ref()
74 .ok_or_else(|| anyhow!("No Host header found in request"))?
75 .clone();
76
77 let original_client_ip = client_ip.lock().unwrap().clone();
78
79 let target_config = targets
80 .get(&host)
81 .ok_or_else(|| anyhow!("No target configured for domain: {}", host))?;
82
83 match original_client_ip {
85 Some(ref ip) => {
86 info!(
87 host = %host,
88 target_host = %target_config.host,
89 target_port = target_config.port,
90 client_ip = %ip,
91 direct_addr = %client_addr,
92 "Routing request"
93 );
94 }
95 None => {
96 info!(
97 host = %host,
98 target_host = %target_config.host,
99 target_port = target_config.port,
100 client_ip = %client_addr,
101 "Routing request"
102 );
103 }
104 }
105
106 handle_socket(ws_stream, target_config, original_client_ip).await?;
107 Ok(())
108}
109
110#[tracing::instrument(skip(websocket, target_config, client_ip))]
111pub async fn handle_socket(
112 websocket: WebSocketStream<StreamType>,
113 target_config: &TargetConfig,
114 client_ip: Option<String>,
115) -> Result<()> {
116 let target_addr = format!("{}:{}", target_config.host, target_config.port);
117
118 if let Some(ref ip) = client_ip {
119 debug!(target_addr = %target_addr, client_ip = %ip, "Attempting to connect to target server");
120 } else {
121 debug!(target_addr = %target_addr, "Attempting to connect to target server");
122 }
123
124 let tcp_stream = TcpStream::connect(&target_addr)
125 .await
126 .with_context(|| format!("Failed to connect to target {target_addr}"))?;
127
128 if let Some(ref ip) = client_ip {
129 info!(target_addr = %target_addr, client_ip = %ip, "Connected to target server");
130 } else {
131 info!(target_addr = %target_addr, "Connected to target server");
132 }
133
134 let (mut ws_sender, mut ws_receiver) = websocket.split();
135 let (mut tcp_reader, mut tcp_writer) = tcp_stream.into_split();
136
137 let ws_to_tcp = async {
138 while let Some(msg) = ws_receiver.next().await {
139 match msg {
140 Ok(Message::Binary(data)) => {
141 debug!(bytes = data.len(), "Forwarding data from WebSocket to TCP");
142 if let Err(e) = tcp_writer.write_all(&data).await {
143 error!(error = %e, bytes = data.len(), "Failed to write to TCP");
144 return Err(e).context("Failed to write WebSocket data to TCP connection");
145 }
146 }
147 Ok(Message::Text(_)) => {
148 warn!("Dropping text message (binary only)");
149 }
150 Ok(Message::Close(_)) => {
151 info!("WebSocket connection closed");
152 break;
153 }
154 Err(e) => {
155 match e {
156 TungsteniteError::ConnectionClosed
157 | TungsteniteError::Protocol(ProtocolError::ResetWithoutClosingHandshake) =>
158 {
159 debug!("Client disconnected: {e}");
160 }
161 _ => {
162 error!("WebSocket error: {e}");
163 }
164 }
165 break;
166 }
167 _ => {}
168 }
169 }
170 Ok(())
171 };
172
173 let tcp_to_ws = async {
174 let mut buffer = [0u8; BUFFER_SIZE];
175
176 loop {
177 match tcp_reader.read(&mut buffer).await {
178 Ok(0) => {
179 info!("TCP connection closed");
180 break;
181 }
182 Ok(n) => {
183 let data = &buffer[..n];
184 debug!(bytes = n, "Forwarding data from TCP to WebSocket");
185 if let Err(e) = ws_sender.send(Message::Binary(data.to_vec().into())).await {
186 error!(error = %e, bytes = data.len(), "Failed to send WebSocket message");
187 return Err(e).context("Failed to send TCP data via WebSocket");
188 }
189 }
190 Err(e) => {
191 error!("Failed to read from TCP: {e}");
192 break;
193 }
194 }
195 }
196 Ok(())
197 };
198
199 tokio::select! {
200 result = ws_to_tcp => result?,
201 result = tcp_to_ws => result?,
202 }
203
204 if let Some(ref ip) = client_ip {
205 info!(client_ip = %ip, "Proxy connection closed");
206 } else {
207 info!("Proxy connection closed");
208 }
209 Ok(())
210}