1use crate::cert::CertificateAuthority;
2use crate::error::ProxyError;
3use crate::handler::{boxed_body, full_boxed_body, Buffered, BoxBody, Dropped, RequestHandler};
4use crate::tls;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Empty, Full};
7use hyper::client::conn::http1 as client_http1;
8use hyper::server::conn::http1 as server_http1;
9use hyper::service::service_fn;
10use hyper::upgrade::Upgraded;
11use hyper::{Method, Request, Response};
12use hyper_util::rt::TokioIo;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tokio::net::TcpStream;
16use tracing::{debug, error, info, warn};
17
/// Upper bound (10 MiB) on a body the proxy will buffer in memory for interception.
const MAX_INTERCEPT_BODY: usize = 10 << 20;
20
21fn should_intercept_body(headers: &hyper::HeaderMap) -> bool {
26 if let Some(cl) = headers.get(hyper::header::CONTENT_LENGTH) {
27 if let Ok(s) = cl.to_str() {
28 if let Ok(len) = s.parse::<usize>() {
29 return len <= MAX_INTERCEPT_BODY;
30 }
31 }
32 }
33 false
34}
35
36async fn try_collect_body<B>(body: B) -> Option<Bytes>
38where
39 B: hyper::body::Body<Data = Bytes, Error = hyper::Error>,
40{
41 use http_body_util::Limited;
42 let limited = Limited::new(body, MAX_INTERCEPT_BODY);
43 BodyExt::collect(limited)
44 .await
45 .ok()
46 .map(|c| c.to_bytes())
47}
48
49pub struct ProxyState {
51 pub ca: Arc<CertificateAuthority>,
52 pub mitm: bool,
53 pub intercept: bool,
54 pub handler: Arc<dyn RequestHandler>,
55}
56
57pub async fn handle_connection(
59 stream: TcpStream,
60 addr: SocketAddr,
61 state: Arc<ProxyState>,
62) {
63 debug!("New connection from {addr}");
64
65 let io = TokioIo::new(stream);
66 let state = state.clone();
67
68 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
69 let state = state.clone();
70 async move { handle_request(req, state).await }
71 });
72
73 if let Err(e) = server_http1::Builder::new()
74 .preserve_header_case(true)
75 .title_case_headers(true)
76 .serve_connection(io, service)
77 .with_upgrades()
78 .await
79 {
80 if !e.to_string().contains("early eof")
81 && !e.to_string().contains("connection closed")
82 {
83 error!("Connection error from {addr}: {e}");
84 }
85 }
86}
87
88async fn handle_request(
90 req: Request<hyper::body::Incoming>,
91 state: Arc<ProxyState>,
92) -> Result<Response<BoxBody>, hyper::Error> {
93 if req.method() == Method::CONNECT {
94 handle_connect(req, state).await
95 } else {
96 handle_forward(req, state).await
97 }
98}
99
100async fn handle_forward(
104 req: Request<hyper::body::Incoming>,
105 state: Arc<ProxyState>,
106) -> Result<Response<BoxBody>, hyper::Error> {
107 let uri = req.uri().clone();
108 let host = match uri.host() {
109 Some(h) => h.to_string(),
110 None => {
111 warn!("Request with no host: {uri}");
112 return Ok(bad_request("Missing host in URI"));
113 }
114 };
115 let port = uri.port_u16().unwrap_or(80);
116 let addr = format!("{host}:{port}");
117
118 let (mut parts, body) = req.into_parts();
120 let path = parts
121 .uri
122 .path_and_query()
123 .map(|pq| pq.as_str())
124 .unwrap_or("/");
125 parts.uri = match path.parse() {
126 Ok(uri) => uri,
127 Err(_) => {
128 warn!("Invalid path: {path}");
129 return Ok(bad_request("Invalid request URI"));
130 }
131 };
132
133 let do_intercept = state.intercept && should_intercept_body(&parts.headers);
136
137 strip_hop_by_hop_headers(&mut parts.headers);
138
139 let mut forwarded_req = if do_intercept {
140 match try_collect_body(body).await {
141 Some(bytes) => {
142 let mut req = Request::from_parts(parts, full_boxed_body(bytes));
143 req.extensions_mut().insert(Buffered);
144 req
145 }
146 None => {
147 error!("Request body collection failed despite acceptable Content-Length");
150 return Ok(bad_gateway("Request body read error"));
151 }
152 }
153 } else {
154 Request::from_parts(parts, boxed_body(body))
155 };
156
157 state.handler.handle_request(&mut forwarded_req);
158
159 if forwarded_req.extensions().get::<Dropped>().is_some() {
160 return Ok(bad_gateway("Request dropped by interceptor"));
161 }
162
163 let upstream = match TcpStream::connect(&addr).await {
165 Ok(s) => s,
166 Err(e) => {
167 error!("Failed to connect to {addr}: {e}");
168 return Ok(bad_gateway(&format!("Failed to connect to {addr}")));
169 }
170 };
171
172 let io = TokioIo::new(upstream);
173 let (mut sender, conn) = match client_http1::handshake(io).await {
174 Ok(r) => r,
175 Err(e) => {
176 error!("Handshake with {addr} failed: {e}");
177 return Ok(bad_gateway("Upstream handshake failed"));
178 }
179 };
180
181 tokio::spawn(async move {
182 if let Err(e) = conn.await {
183 error!("Upstream connection error: {e}");
184 }
185 });
186
187 match sender.send_request(forwarded_req).await {
188 Ok(res) => {
189 let (parts, body) = res.into_parts();
190 let mut response = if state.intercept && should_intercept_body(&parts.headers) {
191 match try_collect_body(body).await {
192 Some(bytes) => {
193 let mut res = Response::from_parts(parts, full_boxed_body(bytes));
194 res.extensions_mut().insert(Buffered);
195 res
196 }
197 None => {
198 error!("Response body collection failed");
199 return Ok(bad_gateway("Response body collection failed"));
200 }
201 }
202 } else {
203 Response::from_parts(parts, boxed_body(body))
204 };
205 state.handler.handle_response(&mut response);
206 if response.extensions().get::<Dropped>().is_some() {
207 return Ok(interceptor_dropped_response());
208 }
209 Ok(response)
210 }
211 Err(e) => {
212 error!("Upstream request failed: {e}");
213 Ok(bad_gateway("Upstream request failed"))
214 }
215 }
216}
217
218async fn handle_connect(
222 req: Request<hyper::body::Incoming>,
223 state: Arc<ProxyState>,
224) -> Result<Response<BoxBody>, hyper::Error> {
225 let target = match req.uri().authority() {
226 Some(auth) => auth.to_string(),
227 None => {
228 warn!("CONNECT without authority");
229 return Ok(bad_request("CONNECT target missing"));
230 }
231 };
232
233 let (host, port) = parse_host_port(&target);
234 let addr = format!("{host}:{port}");
235
236 info!("CONNECT {target}");
237
238 if state.mitm {
239 handle_mitm(req, host, addr, state).await
241 } else {
242 handle_tunnel(req, addr).await
244 }
245}
246
247async fn handle_tunnel(
249 req: Request<hyper::body::Incoming>,
250 addr: String,
251) -> Result<Response<BoxBody>, hyper::Error> {
252 tokio::spawn(async move {
253 match hyper::upgrade::on(req).await {
254 Ok(upgraded) => {
255 if let Err(e) = tunnel_bidirectional(upgraded, &addr).await {
256 error!("Tunnel error to {addr}: {e}");
257 }
258 }
259 Err(e) => {
260 error!("Upgrade failed: {e}");
261 }
262 }
263 });
264
265 Ok(Response::new(empty_body()))
267}
268
269async fn tunnel_bidirectional(
271 upgraded: Upgraded,
272 addr: &str,
273) -> crate::error::Result<()> {
274 let mut upstream = TcpStream::connect(addr).await?;
275
276 let mut client = TokioIo::new(upgraded);
277
278 let (client_to_server, server_to_client) =
279 tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;
280
281 debug!(
282 "Tunnel closed: {addr} (client→server: {client_to_server}B, server→client: {server_to_client}B)"
283 );
284 Ok(())
285}
286
287async fn handle_mitm(
289 req: Request<hyper::body::Incoming>,
290 host: String,
291 addr: String,
292 state: Arc<ProxyState>,
293) -> Result<Response<BoxBody>, hyper::Error> {
294 let state = state.clone();
295
296 tokio::spawn(async move {
297 match hyper::upgrade::on(req).await {
298 Ok(upgraded) => {
299 if let Err(e) =
300 mitm_intercept(upgraded, &host, &addr, state).await
301 {
302 error!("MITM error for {host}: {e}");
303 }
304 }
305 Err(e) => {
306 error!("MITM upgrade failed: {e}");
307 }
308 }
309 });
310
311 Ok(Response::new(empty_body()))
312}
313
314async fn mitm_intercept(
316 upgraded: Upgraded,
317 host: &str,
318 addr: &str,
319 state: Arc<ProxyState>,
320) -> crate::error::Result<()> {
321 let acceptor = tls::make_tls_acceptor(&state.ca, host).await?;
323
324 let client_io = TokioIo::new(upgraded);
326 let client_tls = acceptor
327 .accept(client_io)
328 .await
329 .map_err(|e| ProxyError::Other(format!("Client TLS accept failed: {e}")))?;
330
331 let client_tls = TokioIo::new(client_tls);
332
333 let host = host.to_string();
335 let addr = addr.to_string();
336
337 let service = service_fn(move |req: Request<hyper::body::Incoming>| {
338 let host = host.clone();
339 let addr = addr.clone();
340 let state = state.clone();
341 async move {
342 mitm_forward_request(req, &host, &addr, state).await
343 }
344 });
345
346 if let Err(e) = server_http1::Builder::new()
347 .preserve_header_case(true)
348 .title_case_headers(true)
349 .serve_connection(client_tls, service)
350 .await
351 {
352 if !e.to_string().contains("early eof")
353 && !e.to_string().contains("connection closed")
354 {
355 debug!("MITM connection closed: {e}");
356 }
357 }
358
359 Ok(())
360}
361
362async fn mitm_forward_request(
364 req: Request<hyper::body::Incoming>,
365 host: &str,
366 addr: &str,
367 state: Arc<ProxyState>,
368) -> Result<Response<BoxBody>, hyper::Error> {
369 let (mut parts, body) = req.into_parts();
370
371 let do_intercept = state.intercept && should_intercept_body(&parts.headers);
372 strip_hop_by_hop_headers(&mut parts.headers);
373
374 let mut forwarded_req = if do_intercept {
375 match try_collect_body(body).await {
376 Some(bytes) => {
377 let mut req = Request::from_parts(parts, full_boxed_body(bytes));
378 req.extensions_mut().insert(Buffered);
379 req
380 }
381 None => {
382 error!("MITM request body collection failed");
383 return Ok(bad_gateway("Request body read error"));
384 }
385 }
386 } else {
387 Request::from_parts(parts, boxed_body(body))
388 };
389
390 state.handler.handle_request(&mut forwarded_req);
391
392 if forwarded_req.extensions().get::<Dropped>().is_some() {
393 return Ok(bad_gateway("Request dropped by interceptor"));
394 }
395
396 let upstream_tls = match tls::connect_tls_upstream(host, addr).await {
398 Ok(s) => s,
399 Err(e) => {
400 error!("Failed TLS connect to {addr}: {e}");
401 return Ok(bad_gateway(&format!(
402 "Failed to connect to upstream: {e}"
403 )));
404 }
405 };
406
407 let io = TokioIo::new(upstream_tls);
408 let (mut sender, conn) = match client_http1::handshake(io).await {
409 Ok(r) => r,
410 Err(e) => {
411 error!("Upstream TLS handshake failed: {e}");
412 return Ok(bad_gateway("Upstream TLS handshake failed"));
413 }
414 };
415
416 tokio::spawn(async move {
417 if let Err(e) = conn.await {
418 debug!("Upstream TLS connection closed: {e}");
419 }
420 });
421
422 match sender.send_request(forwarded_req).await {
423 Ok(res) => {
424 let (parts, body) = res.into_parts();
425 let mut response = if state.intercept && should_intercept_body(&parts.headers) {
426 match try_collect_body(body).await {
427 Some(bytes) => {
428 let mut res = Response::from_parts(parts, full_boxed_body(bytes));
429 res.extensions_mut().insert(Buffered);
430 res
431 }
432 None => {
433 error!("MITM response body collection failed");
434 return Ok(bad_gateway("Response body collection failed"));
435 }
436 }
437 } else {
438 Response::from_parts(parts, boxed_body(body))
439 };
440 state.handler.handle_response(&mut response);
441 if response.extensions().get::<Dropped>().is_some() {
442 return Ok(interceptor_dropped_response());
443 }
444 Ok(response)
445 }
446 Err(e) => {
447 error!("Upstream TLS request failed: {e}");
448 Ok(bad_gateway("Upstream request failed"))
449 }
450 }
451}
452
/// Header fields that apply to a single connection and must not be forwarded
/// by a proxy (RFC 9110 §7.6.1). Any casing on the wire is matched, because
/// `HeaderMap` name lookups are case-insensitive.
///
/// - `trailer` is the real field name.
/// - `trailers` is kept from RFC 2616's list, which misspelled it (erratum 4522).
/// - `proxy-connection` is non-standard, but clients commonly send it to proxies.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
466
/// Splits a `CONNECT` target into host and port, defaulting the port to 443.
///
/// - Bracketed IPv6 (`[::1]:8443`) comes back without brackets. A missing or
///   unparsable port after `]` falls back to 443.
/// - Otherwise the text after the last `:` is the port if it parses as `u16`.
/// - If it does not parse, the whole target is returned as the host, with 443.
pub fn parse_host_port(target: &str) -> (String, u16) {
    const DEFAULT_PORT: u16 = 443;

    let bracketed = target
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'));
    if let Some((ip6, tail)) = bracketed {
        let port = match tail.strip_prefix(':').map(str::parse::<u16>) {
            Some(Ok(p)) => p,
            _ => DEFAULT_PORT,
        };
        return (ip6.to_owned(), port);
    }

    let split = target
        .rsplit_once(':')
        .and_then(|(h, p)| p.parse::<u16>().ok().map(|port| (h, port)));
    match split {
        Some((h, port)) => (h.to_owned(), port),
        None => (target.to_owned(), DEFAULT_PORT),
    }
}
488
489fn strip_hop_by_hop_headers(headers: &mut hyper::HeaderMap) {
490 if let Some(conn_val) = headers.get("connection").cloned() {
492 if let Ok(val) = conn_val.to_str() {
493 for name in val.split(',') {
494 let name = name.trim();
495 if !name.is_empty() {
496 headers.remove(name);
497 }
498 }
499 }
500 }
501
502 for name in HOP_BY_HOP_HEADERS {
503 headers.remove(*name);
504 }
505}
506
507fn empty_body() -> BoxBody {
508 Empty::<Bytes>::new()
509 .map_err(|never| match never {})
510 .boxed()
511}
512
513fn bad_request(msg: &str) -> Response<BoxBody> {
514 Response::builder()
515 .status(400)
516 .body(full_body(msg))
517 .unwrap()
518}
519
520fn bad_gateway(msg: &str) -> Response<BoxBody> {
521 Response::builder()
522 .status(502)
523 .body(full_body(msg))
524 .unwrap()
525}
526
527fn interceptor_dropped_response() -> Response<BoxBody> {
532 Response::builder()
533 .status(444)
534 .header("Connection", "close")
535 .header("X-RustGate-Interceptor", "response-dropped")
536 .body(full_body(
537 "Response dropped by interceptor. The upstream request was already executed. Do not retry.",
538 ))
539 .unwrap()
540}
541
542fn full_body(msg: &str) -> BoxBody {
543 Full::new(Bytes::from(msg.to_string()))
544 .map_err(|never| match never {})
545 .boxed()
546}