1use crate::ca::CertificateManager;
4use crate::error::{Error, Result};
5use crate::interceptor::InterceptorHandler;
6use crate::proxy::MitmConfig;
7use bytes::Bytes;
8use http::Version;
9use slinger::{Client, ClientBuilder, Request, Response};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
13use tokio::net::{TcpListener, TcpStream};
14use tokio::sync::RwLock;
15use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
16use tokio_rustls::rustls::ServerConfig;
17use tokio_rustls::TlsAcceptor;
18
19pub struct ProxyServer {
21 config: MitmConfig,
22 cert_manager: Arc<CertificateManager>,
23 interceptor_handler: Arc<RwLock<InterceptorHandler>>,
24 client: Client,
25}
26
27#[derive(Default)]
31pub struct ProxyServerBuilder {
32 config: Option<MitmConfig>,
33 cert_manager: Option<Arc<CertificateManager>>,
34 interceptor_handler: Option<Arc<RwLock<InterceptorHandler>>>,
35 client: Option<Client>,
36 client_config: Option<Box<dyn Fn(ClientBuilder) -> ClientBuilder + Send + Sync>>,
38}
39
40impl ProxyServerBuilder {
41 pub fn from_server(server: &ProxyServer) -> Self {
43 Self {
44 config: Some(server.config.clone()),
45 cert_manager: Some(server.cert_manager.clone()),
46 interceptor_handler: Some(server.interceptor_handler.clone()),
47 client: Some(server.client.clone()),
48 client_config: None,
49 }
50 }
51
52 pub fn config(mut self, config: MitmConfig) -> Self {
54 self.config = Some(config);
55 self
56 }
57
58 pub fn cert_manager(mut self, cert_manager: Arc<CertificateManager>) -> Self {
60 self.cert_manager = Some(cert_manager);
61 self
62 }
63
64 pub fn interceptor_handler(mut self, handler: Arc<RwLock<InterceptorHandler>>) -> Self {
66 self.interceptor_handler = Some(handler);
67 self
68 }
69
70 pub fn client(mut self, client: Client) -> Self {
72 self.client = Some(client);
73 self
74 }
75
76 pub fn configure_client<F>(mut self, f: F) -> Self
79 where
80 F: Fn(ClientBuilder) -> ClientBuilder + Send + Sync + 'static,
81 {
82 self.client_config = Some(Box::new(f));
83 self
84 }
85
86 pub fn build(self) -> Result<ProxyServer> {
94 let config = self.config.unwrap_or_default();
96
97 let cert_manager = match self.cert_manager {
100 Some(c) => c,
101 None => {
102 return Err(Error::proxy_error(
103 "CertificateManager not provided; use ProxyServer::builder().build_async().await to create one automatically".to_string(),
104 ))
105 }
106 };
107
108 let interceptor_handler = self
110 .interceptor_handler
111 .unwrap_or_else(|| Arc::new(RwLock::new(InterceptorHandler::new())));
112
113 let client = if let Some(client) = self.client {
115 client
116 } else if let Some(cfg_fn) = self.client_config {
117 let builder = Client::builder();
118 let configured = cfg_fn(builder);
119 configured
120 .build()
121 .map_err(|e| Error::proxy_error(format!("Failed to build client: {}", e)))?
122 } else {
123 if let Some(proxy) = &config.upstream_proxy {
125 Client::builder()
126 .timeout(Some(Duration::from_secs(60)))
127 .keepalive(true)
128 .proxy(proxy.clone())
129 .build()
130 .map_err(|e| {
131 Error::proxy_error(format!(
132 "Failed to build client with proxy {}: {}",
133 proxy.uri(),
134 e
135 ))
136 })?
137 } else {
138 Client::builder()
139 .keepalive(true)
140 .build()
141 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
142 }
143 };
144
145 Ok(ProxyServer {
146 config,
147 cert_manager,
148 interceptor_handler,
149 client,
150 })
151 }
152}
153
154impl ProxyServer {
155 pub fn new(
157 config: MitmConfig,
158 cert_manager: Arc<CertificateManager>,
159 interceptor_handler: Arc<RwLock<InterceptorHandler>>,
160 ) -> Result<Self> {
161 let client = if let Some(proxy) = &config.upstream_proxy {
162 Client::builder()
164 .timeout(Some(Duration::from_secs(60)))
165 .keepalive(true)
166 .proxy(proxy.clone())
167 .build()
168 .map_err(|e| {
169 Error::proxy_error(format!(
170 "Failed to build client with proxy {}: {}",
171 proxy.uri(),
172 e
173 ))
174 })?
175 } else {
176 Client::builder()
178 .keepalive(true)
179 .build()
180 .map_err(|e| Error::proxy_error(format!("Failed to build default client: {}", e)))?
181 };
182 Ok(Self {
183 config,
184 cert_manager,
185 interceptor_handler,
186 client,
187 })
188 }
189
190 pub async fn run(&self, addr: &str) -> Result<()> {
192 let listener = TcpListener::bind(addr)
193 .await
194 .map_err(|e| Error::proxy_error(format!("Failed to bind to {}: {}", addr, e)))?;
195 loop {
196 match listener.accept().await {
197 Ok((stream, _peer_addr)) => {
198 let config = self.config.clone();
199 let cert_manager = self.cert_manager.clone();
200 let interceptor = self.interceptor_handler.clone();
201 let client = self.client.clone();
202
203 tokio::spawn(async move {
204 if let Err(e) =
205 Self::handle_connection(stream, config, cert_manager, interceptor, client).await
206 {
207 tracing::error!("[MITM] Error handling connection: {}", e);
208 }
209 });
210 }
211 Err(e) => {
212 tracing::error!("[MITM] Failed to accept connection: {}", e);
213 }
214 }
215 }
216 }
217
218 async fn handle_connection(
220 mut stream: TcpStream,
221 config: MitmConfig,
222 cert_manager: Arc<CertificateManager>,
223 interceptor: Arc<RwLock<InterceptorHandler>>,
224 client: Client,
225 ) -> Result<()> {
226 use crate::socks5::Socks5Server;
227
228 let mut first_byte = [0u8; 1];
230 stream.read_exact(&mut first_byte).await?;
231
232 if first_byte[0] == 0x05 {
234 match Socks5Server::handle_handshake_with_version(&mut stream).await {
237 Ok(target_addr) => {
238 let target_host_port = target_addr.to_host_port();
239 match target_addr {
240 crate::socks5::TargetAddr::Domain(_domain, _port) => {
241 if config.enable_https_interception {
242 Self::handle_https_connect_socks5(
243 stream,
244 &target_host_port,
245 cert_manager,
246 interceptor,
247 client,
248 )
249 .await
250 } else {
251 Self::handle_tcp_tunnel(stream, &target_host_port).await
252 }
253 }
254 _ => {
255 if config.enable_https_interception {
256 Self::handle_https_connect_socks5(
257 stream,
258 &target_host_port,
259 cert_manager,
260 interceptor,
261 client,
262 )
263 .await
264 } else {
265 Self::handle_tcp_tunnel(stream, &target_host_port).await
266 }
267 }
268 }
269 }
270 Err(e) => Err(e),
271 }
272 } else {
273 let mut request_line = vec![first_byte[0]];
274 let mut buffer = [0u8; 1];
275 loop {
276 stream.read_exact(&mut buffer).await?;
277 request_line.push(buffer[0]);
278 if buffer[0] == b'\n' {
279 break;
280 }
281 if request_line.len() > 8192 {
282 return Err(Error::invalid_request("Request line too long".to_string()));
283 }
284 }
285
286 let request_line_str = String::from_utf8_lossy(&request_line);
287 let parts: Vec<&str> = request_line_str.split_whitespace().collect();
288 if parts.len() < 3 {
289 return Err(Error::invalid_request("Invalid request line".to_string()));
290 }
291
292 let method = parts[0].to_string();
293 let uri = parts[1].to_string();
294 if method == "CONNECT" {
295 let mut reader = BufReader::new(stream);
296 const MAX_CONNECT_HEADERS: usize = 16 * 1024; let mut headers_acc = 0usize;
298 loop {
299 let mut line = String::new();
300 let n = reader.read_line(&mut line).await?;
301 if n == 0 {
303 break;
304 }
305 headers_acc += n;
306 if headers_acc > MAX_CONNECT_HEADERS {
307 return Err(Error::invalid_request(
308 "CONNECT headers size exceeds maximum allowed".to_string(),
309 ));
310 }
311 if line == "\r\n" || line == "\n" || line.is_empty() {
313 break;
314 }
315 }
316 let stream = reader.into_inner();
317 if config.enable_https_interception {
318 Self::handle_https_connect(stream, &uri, cert_manager, interceptor, client).await
319 } else {
320 Self::handle_https_tunnel(stream, &uri).await
321 }
322 } else {
323 let buf_reader = BufReader::new(stream);
324 Self::handle_http_request(&method, &uri, buf_reader, interceptor, client).await
325 }
326 }
327 }
328
329 async fn handle_https_connect(
331 client_stream: TcpStream,
332 uri: &str,
333 cert_manager: Arc<CertificateManager>,
334 interceptor: Arc<RwLock<InterceptorHandler>>,
335 slinger_client: Client,
336 ) -> Result<()> {
337 let (domain, port) = Self::parse_host_port(uri)?;
339 Self::accept_tls_and_handle(
341 client_stream,
342 &domain,
343 port,
344 true,
345 cert_manager,
346 interceptor,
347 slinger_client,
348 )
349 .await
350 }
351
352 async fn handle_https_tunnel(client_stream: TcpStream, uri: &str) -> Result<()> {
354 Self::tcp_tunnel(client_stream, uri, true).await
355 }
356
357 async fn handle_tcp_tunnel(client_stream: TcpStream, target_addr: &str) -> Result<()> {
360 Self::tcp_tunnel(client_stream, target_addr, false).await
361 }
362
363 async fn handle_https_connect_socks5(
366 client_stream: TcpStream,
367 uri: &str,
368 cert_manager: Arc<CertificateManager>,
369 interceptor: Arc<RwLock<InterceptorHandler>>,
370 slinger_client: Client,
371 ) -> Result<()> {
372 let (domain, port) = Self::parse_host_port(uri)?;
374
375 Self::accept_tls_and_handle(
377 client_stream,
378 &domain,
379 port,
380 false,
381 cert_manager,
382 interceptor,
383 slinger_client,
384 )
385 .await
386 }
387
388 async fn accept_tls_and_handle(
391 mut client_stream: TcpStream,
392 domain: &str,
393 port: u16,
394 send_response: bool,
395 cert_manager: Arc<CertificateManager>,
396 interceptor: Arc<RwLock<InterceptorHandler>>,
397 slinger_client: Client,
398 ) -> Result<()> {
399 if send_response {
400 client_stream
401 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
402 .await?;
403 client_stream
404 .flush() .await
406 .map_err(Error::Io)?;
407 }
408
409 let (cert_chain, key) = cert_manager.get_server_cert(domain).await?;
411 let tls_config = Self::create_tls_server_config(cert_chain, key)?;
413 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
414 let tls_stream = acceptor
416 .accept(client_stream)
417 .await
418 .map_err(|e| Error::tls_error(format!("TLS handshake failed: {}", e)))?;
419 let domain_with_port = format!("{}:{}", domain, port);
420 Self::handle_https_stream(tls_stream, domain_with_port, interceptor, slinger_client).await
421 }
422
423 async fn tcp_tunnel(mut client_stream: TcpStream, uri: &str, send_response: bool) -> Result<()> {
425 let (host, port) = Self::parse_host_port(uri)?;
426 let addr = format!("{}:{}", host, port);
427
428 let mut target_stream = TcpStream::connect(&addr)
430 .await
431 .map_err(|e| Error::connection_error(format!("Failed to connect to {}: {}", addr, e)))?;
432
433 if send_response {
434 client_stream
435 .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
436 .await?;
437 }
438
439 let (mut client_read, mut client_write) = client_stream.split();
440 let (mut target_read, mut target_write) = target_stream.split();
441
442 let client_to_target = tokio::io::copy(&mut client_read, &mut target_write);
443 let target_to_client = tokio::io::copy(&mut target_read, &mut client_write);
444
445 tokio::select! {
446 _ = client_to_target => {},
447 _ = target_to_client => {},
448 }
449
450 Ok(())
451 }
452
453 async fn forward_request_via_client(
457 interceptor: Arc<RwLock<InterceptorHandler>>,
458 client: &Client,
459 request: Request,
460 ) -> Result<Option<Vec<u8>>> {
461 let handler = interceptor.read().await;
462 if let Some(modified_req) = handler.process_request(request).await? {
463 let uri = modified_req.uri().clone();
464 let method = modified_req.method().clone();
465 let headers = modified_req.headers().clone();
466 let body_data = if let Some(body) = modified_req.body() {
467 body.to_vec()
468 } else {
469 Vec::new()
470 };
471 let mut req_builder = client.request(method, uri);
472 for (name, value) in headers.iter() {
473 req_builder = req_builder.header(name, value);
474 }
475 req_builder = req_builder.body(body_data);
476 match req_builder.send().await {
477 Ok(response) => {
478 if let Some(final_response) = handler.process_response(response).await? {
479 let response_bytes = Self::serialize_http_response(&final_response)?;
480 return Ok(Some(response_bytes));
481 }
482 }
483 Err(_e) => {
484 return Ok(Some(b"HTTP/1.1 502 Bad Gateway\r\n\r\n".to_vec()));
485 }
486 }
487 }
488 Ok(None)
489 }
490
491 async fn handle_https_stream<S>(
493 mut tls_stream: S,
494 domain: String,
495 interceptor: Arc<RwLock<InterceptorHandler>>,
496 client: Client,
497 ) -> Result<()>
498 where
499 S: AsyncReadExt + AsyncWriteExt + Unpin,
500 {
501 const MAX_REQUEST_SIZE: usize = 1024 * 1024; let mut buffer = Vec::new();
504 let mut temp_buf = [0u8; 8192];
505
506 loop {
507 match tls_stream.read(&mut temp_buf).await {
508 Ok(0) => break,
509 Ok(n) => {
510 buffer.extend_from_slice(&temp_buf[..n]);
511 if buffer.len() > MAX_REQUEST_SIZE {
512 return Err(Error::invalid_request(
513 "Request size exceeds maximum allowed".to_string(),
514 ));
515 }
516 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
517 break;
518 }
519 }
520 Err(e) => return Err(Error::Io(e)),
521 }
522 }
523
524 if let Ok(request) = Self::parse_http_request(&buffer, &domain) {
526 if let Some(response_bytes) =
527 Self::forward_request_via_client(interceptor, &client, request).await?
528 {
529 tls_stream.write_all(&response_bytes).await?;
530 }
531 }
532
533 Ok(())
534 }
535
536 async fn handle_http_request<R>(
538 method: &str,
539 uri: &str,
540 mut reader: BufReader<R>,
541 interceptor: Arc<RwLock<InterceptorHandler>>,
542 client: Client,
543 ) -> Result<()>
544 where
545 R: AsyncReadExt + AsyncWriteExt + Unpin,
546 {
547 const MAX_HEADERS_SIZE: usize = 64 * 1024; let mut headers_buf = Vec::new();
550 loop {
551 let mut line = String::new();
552 reader.read_line(&mut line).await?;
553 if line == "\r\n" || line == "\n" {
554 break;
555 }
556 headers_buf.extend_from_slice(line.as_bytes());
557
558 if headers_buf.len() > MAX_HEADERS_SIZE {
560 return Err(Error::invalid_request(
561 "Headers size exceeds maximum allowed".to_string(),
562 ));
563 }
564 }
565
566 let mut request_builder = http::Request::builder()
568 .method(method)
569 .uri(uri)
570 .version(Version::HTTP_11);
571
572 for line in String::from_utf8_lossy(&headers_buf).lines() {
574 if let Some(idx) = line.find(':') {
575 let (name, value) = line.split_at(idx);
576 let value = value[1..].trim();
577 request_builder = request_builder.header(name.trim(), value);
578 }
579 }
580
581 let http_request = request_builder.body(Bytes::new())?;
582 let request: Request = http_request.into();
583
584 if let Some(response_bytes) =
586 Self::forward_request_via_client(interceptor, &client, request).await?
587 {
588 let mut stream = reader.into_inner();
589 stream.write_all(&response_bytes).await?;
590 }
591
592 Ok(())
593 }
594
595 fn create_tls_server_config(
597 cert_chain: Vec<CertificateDer<'static>>,
598 key: PrivateKeyDer<'static>,
599 ) -> Result<ServerConfig> {
600 let config = ServerConfig::builder()
601 .with_no_client_auth()
602 .with_single_cert(cert_chain, key)
603 .map_err(|e| Error::tls_error(format!("Failed to create TLS config: {}", e)))?;
604
605 Ok(config)
606 }
607
608 fn parse_host_port(uri: &str) -> Result<(String, u16)> {
610 let parts: Vec<&str> = uri.split(':').collect();
611 if parts.len() != 2 {
612 return Err(Error::invalid_request(format!("Invalid URI: {}", uri)));
613 }
614
615 let host = parts[0].to_string();
616 let port = parts[1]
617 .parse::<u16>()
618 .map_err(|_| Error::invalid_request(format!("Invalid port: {}", parts[1])))?;
619
620 Ok((host, port))
621 }
622
623 fn parse_http_request(buffer: &[u8], domain: &str) -> Result<Request> {
625 let request_str = String::from_utf8_lossy(buffer);
626 let mut lines = request_str.lines();
627
628 let request_line = lines
629 .next()
630 .ok_or_else(|| Error::invalid_request("Empty request".to_string()))?;
631 let parts: Vec<&str> = request_line.split_whitespace().collect();
632 if parts.len() < 3 {
633 return Err(Error::invalid_request("Invalid request line".to_string()));
634 }
635
636 let method = parts[0];
637 let path = parts[1];
638 let uri = if path.starts_with("http://") || path.starts_with("https://") {
639 path.to_string()
640 } else {
641 format!("https://{}{}", domain, path)
642 };
643
644 let mut request_builder = http::Request::builder()
645 .method(method)
646 .uri(uri)
647 .version(Version::HTTP_11);
648
649 for line in lines {
650 if line.is_empty() {
651 break;
652 }
653 if let Some(idx) = line.find(':') {
654 let (name, value) = line.split_at(idx);
655 let value = value[1..].trim();
656 request_builder = request_builder.header(name.trim(), value);
657 }
658 }
659
660 let http_request = request_builder.body(Bytes::new())?;
661 Ok(http_request.into())
662 }
663
664 fn serialize_http_response(response: &Response) -> Result<Vec<u8>> {
666 let mut buf = Vec::new();
667
668 let status = response.status_code();
670 let status_line = format!(
671 "HTTP/1.1 {} {}\r\n",
672 status.as_u16(),
673 status.canonical_reason().unwrap_or("Unknown")
674 );
675 buf.extend_from_slice(status_line.as_bytes());
676
677 for (name, value) in response.headers() {
679 buf.extend_from_slice(name.as_str().as_bytes());
680 buf.extend_from_slice(b": ");
681 buf.extend_from_slice(value.as_bytes());
682 buf.extend_from_slice(b"\r\n");
683 }
684 buf.extend_from_slice(b"\r\n");
686 if let Some(body) = response.body() {
688 buf.extend_from_slice(body.as_ref());
689 }
690 Ok(buf)
691 }
692}