1use hyper::body::Incoming;
2use hyper::service::service_fn;
3use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
4use hyper_util::client::legacy::connect::HttpConnector;
5use hyper_util::client::legacy::Client;
6use hyper_util::rt::TokioExecutor;
7use hyper_util::rt::TokioIo;
8use std::collections::{HashMap, HashSet};
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpListener;
13use tokio::sync::{RwLock, Semaphore};
14use tracing::{error, info, warn};
15
16#[cfg(feature = "tls")]
17use crate::proxy::tls::{build_tls_acceptor, listen_http_redirect, listen_tls};
18
19use crate::config::{extract_hostname, resolve_listen_addr, tls_redirect_port, Config};
20use crate::proxy::handler::proxy;
21
22pub struct Proxy {
76 config: Arc<RwLock<Config>>,
77 client: Client<HttpsConnector<HttpConnector>, Incoming>,
78 max_concurrency: usize,
79 semaphore: Arc<Semaphore>,
80}
81
82impl Proxy {
83 pub fn new(config: Config) -> Self {
96 let mut http = HttpConnector::new();
97 http.set_keepalive(Some(Duration::from_secs(60)));
98 http.set_nodelay(true);
99 let https = HttpsConnectorBuilder::new()
100 .with_native_roots()
101 .expect("Failed to load native TLS root certificates")
102 .https_or_http()
103 .enable_http1()
104 .wrap_connector(http);
105
106 let client = Client::builder(TokioExecutor::new())
107 .pool_max_idle_per_host(100)
108 .pool_idle_timeout(Duration::from_secs(90))
109 .build::<_, Incoming>(https);
110
111 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
112 .ok()
113 .and_then(|v| v.parse().ok())
114 .unwrap_or_else(|| num_cpus::get() * 256);
115
116 let semaphore = Arc::new(Semaphore::new(max_concurrency));
117
118 info!(
119 "Proxy initialized with max_concurrency={} (default: {})",
120 max_concurrency,
121 num_cpus::get() * 256
122 );
123
124 Self {
125 config: Arc::new(RwLock::new(config)),
126 client,
127 max_concurrency,
128 semaphore,
129 }
130 }
131
132 pub fn from_shared(config: Arc<RwLock<Config>>) -> Self {
141 let mut http = HttpConnector::new();
142 http.set_keepalive(Some(Duration::from_secs(60)));
143 http.set_nodelay(true);
144 let https = HttpsConnectorBuilder::new()
145 .with_native_roots()
146 .expect("Failed to load native TLS root certificates")
147 .https_or_http()
148 .enable_http1()
149 .wrap_connector(http);
150
151 let client = Client::builder(TokioExecutor::new())
152 .pool_max_idle_per_host(100)
153 .pool_idle_timeout(Duration::from_secs(90))
154 .build::<_, Incoming>(https);
155
156 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
157 .ok()
158 .and_then(|v| v.parse().ok())
159 .unwrap_or_else(|| num_cpus::get() * 256);
160
161 let semaphore = Arc::new(Semaphore::new(max_concurrency));
162
163 info!(
164 "Proxy initialized with max_concurrency={} (default: {})",
165 max_concurrency,
166 num_cpus::get() * 256
167 );
168
169 Self {
170 config,
171 client,
172 max_concurrency,
173 semaphore,
174 }
175 }
176
177 pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
203 let addr: SocketAddr = addr.parse()?;
204 self.start_with_addr(addr).await
205 }
206
207 pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
216 let config_snapshot = self.config.read().await.clone();
218 let tls_sites: Vec<(String, crate::config::TlsConfig)> = config_snapshot
219 .sites
220 .values()
221 .filter(|site| {
222 site_addr_matches(&site.address, &addr) && site.tls.is_some()
225 })
226 .filter_map(|site| {
227 let hostname = extract_hostname(&site.address);
229 site.tls.clone().map(|tls| (hostname.to_string(), tls))
230 })
231 .collect();
232
233 if !tls_sites.is_empty() {
234 #[cfg(feature = "tls")]
235 {
236 self.start_tls(addr, tls_sites).await
237 }
238 #[cfg(not(feature = "tls"))]
239 {
240 anyhow::bail!(
241 "TLS configuration found for {} but 'tls' feature is disabled. \
242 Refusing to start as plain HTTP (security risk). \
243 Rebuild with --features tls or remove 'tls' from config.",
244 addr
245 );
246 }
247 } else {
248 self.start_http(addr).await
249 }
250 }
251
252 pub async fn start_all(&self) -> anyhow::Result<()> {
280 let config_snapshot = self.config.read().await.clone();
281
282 let mut socket_groups: HashMap<SocketAddr, Vec<&crate::config::SiteConfig>> =
284 HashMap::new();
285 for site in config_snapshot.sites.values() {
286 let listen_addr = resolve_listen_addr(&site.address)?;
287 socket_groups.entry(listen_addr).or_default().push(site);
288 }
289
290 let mut http_handles = Vec::new();
291 let mut tls_redirects: HashSet<(SocketAddr, u16)> = HashSet::new(); for (listen_addr, sites) in socket_groups {
294 let tls_sites: Vec<_> = sites.iter().copied().filter(|s| s.tls.is_some()).collect();
295 let has_tls = !tls_sites.is_empty();
296 let has_plain = tls_sites.len() != sites.len();
297
298 if has_tls && has_plain {
299 anyhow::bail!(
300 "Mixed TLS and non-TLS sites on the same listen address {} is not supported",
301 listen_addr
302 );
303 }
304
305 if has_tls {
306 #[cfg(feature = "tls")]
307 {
308 let tls_entries: Vec<(String, crate::config::TlsConfig)> = tls_sites
309 .iter()
310 .filter_map(|s| {
311 let hostname = extract_hostname(&s.address);
312 s.tls.clone().map(|tls| (hostname.to_string(), tls))
313 })
314 .collect();
315
316 let tls_port = listen_addr.port();
317
318 let client = self.client.clone();
319 let config = self.config.clone();
320 let semaphore = self.semaphore.clone();
321
322 let acceptor = build_tls_acceptor(&tls_entries, None)?;
323 info!(
324 "Starting HTTPS listener on {} ({} domain(s))",
325 listen_addr,
326 tls_entries.len()
327 );
328
329 let handle = tokio::spawn(async move {
330 if let Err(e) =
331 listen_tls(listen_addr, acceptor, semaphore, move |req, remote_addr| {
332 let client = client.clone();
333 let config = config.clone();
334 async move {
335 let config_guard = config.read().await;
336 let config_snapshot = Arc::new(config_guard.clone());
337 drop(config_guard);
338 proxy(req, client, config_snapshot, remote_addr, true).await
339 }
340 })
341 .await
342 {
343 error!("TLS listener error: {}", e);
344 }
345 });
346 http_handles.push(handle);
347
348 tls_redirects.insert((
349 SocketAddr::new(listen_addr.ip(), tls_redirect_port(tls_port)),
350 tls_port,
351 ));
352 }
353
354 #[cfg(not(feature = "tls"))]
355 {
356 anyhow::bail!(
357 "TLS configuration found for {} but 'tls' feature is disabled. \
358 Refusing to start as plain HTTP (security risk). \
359 Rebuild with --features tls or remove 'tls' from config.",
360 listen_addr
361 );
362 }
363 } else {
364 let client = self.client.clone();
365 let config = self.config.clone();
366 let semaphore = self.semaphore.clone();
367 let max_concurrency = self.max_concurrency;
368
369 let handle = tokio::spawn(async move {
370 if let Err(e) =
371 Self::run_http_loop(listen_addr, client, config, semaphore, max_concurrency)
372 .await
373 {
374 error!("HTTP listener error: {}", e);
375 }
376 });
377 http_handles.push(handle);
378 }
379 }
380
381 #[cfg(feature = "tls")]
382 for (redirect_addr, tls_port) in tls_redirects {
383 info!(
384 "Starting HTTP→HTTPS redirect on http://{} → :{}",
385 redirect_addr, tls_port
386 );
387 let handle = tokio::spawn(async move {
388 match listen_http_redirect(redirect_addr, tls_port).await {
389 Ok(()) => {}
390 Err(e) => {
391 warn!(
392 "HTTP redirect on port {} failed (HTTPS on :{} still active): {}",
393 redirect_addr.port(),
394 tls_port,
395 e
396 );
397 }
398 }
399 });
400 http_handles.push(handle);
401 }
402
403 if http_handles.is_empty() {
404 warn!("No listeners configured — proxy has no sites");
405 return Ok(());
406 }
407
408 info!(
409 "Started {} listener(s), max concurrency: {} ({})",
410 http_handles.len(),
411 self.max_concurrency,
412 if self.max_concurrency == num_cpus::get() * 256 {
413 "default"
414 } else {
415 "custom"
416 }
417 );
418
419 for handle in http_handles {
422 if let Err(e) = handle.await {
423 error!("Listener task panicked: {}", e);
424 }
425 }
426
427 Ok(())
428 }
429
430 async fn start_http(&self, addr: SocketAddr) -> anyhow::Result<()> {
432 Self::run_http_loop(
433 addr,
434 self.client.clone(),
435 self.config.clone(),
436 self.semaphore.clone(),
437 self.max_concurrency,
438 )
439 .await
440 }
441
442 async fn run_http_loop(
444 addr: SocketAddr,
445 client: Client<HttpsConnector<HttpConnector>, Incoming>,
446 config: Arc<RwLock<Config>>,
447 semaphore: Arc<Semaphore>,
448 max_concurrency: usize,
449 ) -> anyhow::Result<()> {
450 let listener = TcpListener::bind(&addr).await?;
451 info!("Tiny Proxy listening on http://{}", addr);
452
453 loop {
454 let (stream, remote_addr) = listener.accept().await?;
455 let io = TokioIo::new(stream);
456 let client = client.clone();
457 let config = config.clone();
458 let semaphore = semaphore.clone();
459
460 match semaphore.try_acquire_owned() {
461 Ok(permit) => {
462 tokio::task::spawn(async move {
463 let _permit = permit;
464 let service = service_fn(move |req| {
465 let client = client.clone();
466 let config = config.clone();
467
468 let config_clone = config.clone();
469 async move {
470 let config_guard = config_clone.read().await;
471 let config_snapshot = Arc::new(config_guard.clone());
472 drop(config_guard);
473 proxy(req, client, config_snapshot, remote_addr, false).await
474 }
475 });
476
477 let mut builder = hyper::server::conn::http1::Builder::new();
478 builder.keep_alive(true).pipeline_flush(false);
479
480 builder.serve_connection(io, service).await
481 });
482 }
483 Err(_) => {
484 warn!(
485 "Concurrency limit exceeded ({}), rejecting connection",
486 max_concurrency
487 );
488 }
489 }
490 }
491 }
492
493 #[cfg(feature = "tls")]
495 async fn start_tls(
496 &self,
497 addr: SocketAddr,
498 tls_sites: Vec<(String, crate::config::TlsConfig)>,
499 ) -> anyhow::Result<()> {
500 let acceptor = build_tls_acceptor(&tls_sites, None)?;
501 info!(
502 "Starting HTTPS listener on https://{} ({} domain(s))",
503 addr,
504 tls_sites.len()
505 );
506
507 let client = self.client.clone();
508 let config = self.config.clone();
509 let semaphore = self.semaphore.clone();
510
511 listen_tls(addr, acceptor, semaphore, move |req, remote_addr| {
512 let client = client.clone();
513 let config = config.clone();
514 async move {
515 let config_guard = config.read().await;
516 let config_snapshot = Arc::new(config_guard.clone());
517 drop(config_guard);
518 proxy(req, client, config_snapshot, remote_addr, true).await
519 }
520 })
521 .await
522 }
523
524 pub fn shared_config(&self) -> Arc<RwLock<Config>> {
534 self.config.clone()
535 }
536
537 pub async fn config_snapshot(&self) -> Config {
546 self.config.read().await.clone()
547 }
548
549 pub fn max_concurrency(&self) -> usize {
555 self.max_concurrency
556 }
557
558 pub fn set_max_concurrency(&mut self, max: usize) {
569 self.max_concurrency = max;
570 self.semaphore = Arc::new(Semaphore::new(max));
571 info!("Max concurrency updated to {}", max);
572 }
573
574 pub async fn update_config(&self, config: Config) {
587 let mut guard = self.config.write().await;
588 info!("Configuration updated ({} sites)", config.sites.len());
589 *guard = config;
590 }
591}
592
/// Returns `true` when a site declared as `site_address` (`host:port`) should
/// be served by a listener bound to `listen_addr`.
///
/// Rules, in order:
/// - the text after the last `:` must parse as a port equal to
///   `listen_addr.port()` (no port, or a different port, means no match);
/// - an empty host, `0.0.0.0`, `::` or `[::]` matches any listen IP;
/// - `localhost` matches only `127.0.0.1`;
/// - an IP literal must equal the listen IP exactly (IPv6 literals may be
///   bracketed, e.g. `[::1]:8443`);
/// - any other host is treated as a DNS name and matches on port alone.
fn site_addr_matches(site_address: &str, listen_addr: &SocketAddr) -> bool {
    // No ':' at all means there is no port; the whole string then fails to
    // parse as a port below.
    let (host, port) = site_address.rsplit_once(':').unwrap_or(("", site_address));

    let site_port: u16 = match port.parse() {
        Ok(p) => p,
        Err(_) => return false,
    };
    if site_port != listen_addr.port() {
        return false;
    }

    // `host:port` syntax requires brackets around IPv6 literals. Strip them so
    // `[::1]` parses as an address rather than falling through to the
    // hostname rule, which would match every IP on this port.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);

    if host.is_empty() || host == "0.0.0.0" || host == "::" {
        return true;
    }

    let site_ip = if host == "localhost" {
        std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
    } else if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        ip
    } else {
        // A DNS name: the port already matched.
        return true;
    };

    site_ip == listen_addr.ip()
}
631
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A configuration with no sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// A non-TLS site with no directives at `address`.
    fn plain_site(address: &str) -> crate::config::SiteConfig {
        crate::config::SiteConfig {
            address: address.to_string(),
            directives: vec![],
            tls: None,
        }
    }

    /// Parses a socket-address literal for the matcher tests.
    fn sock(literal: &str) -> SocketAddr {
        literal.parse().unwrap()
    }

    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let snapshot = runtime.block_on(proxy.config_snapshot());
        assert!(snapshot.sites.is_empty());
    }

    #[tokio::test]
    async fn test_config_access() {
        let key = "localhost:8080";
        let mut config = empty_config();
        config.sites.insert(key.to_string(), plain_site(key));

        let proxy = Proxy::new(config);
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key(key));
    }

    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        assert!(proxy.config_snapshot().await.sites.is_empty());

        let mut replacement = empty_config();
        replacement
            .sites
            .insert("test.local".to_string(), plain_site("test.local"));
        proxy.update_config(replacement).await;

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("test.local"));
    }

    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());
        let handle = proxy.shared_config();

        // The write guard is a temporary, released at the end of the statement.
        handle
            .write()
            .await
            .sites
            .insert("shared.local".to_string(), plain_site("shared.local"));

        // A write through the shared handle must be visible via the proxy.
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("shared.local"));
    }

    #[test]
    fn test_from_shared() {
        let shared = Arc::new(RwLock::new(empty_config()));
        let proxy = Proxy::from_shared(Arc::clone(&shared));
        let runtime = tokio::runtime::Runtime::new().unwrap();

        // Mutate through the original handle; the guard drops with the statement.
        runtime
            .block_on(shared.write())
            .sites
            .insert("from-shared.local".to_string(), plain_site("from-shared.local"));

        let snapshot = runtime.block_on(proxy.config_snapshot());
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("from-shared.local"));
    }

    #[test]
    fn test_site_addr_matches_localhost() {
        assert!(site_addr_matches("localhost:8080", &sock("127.0.0.1:8080")));
    }

    #[test]
    fn test_site_addr_matches_ip() {
        assert!(site_addr_matches("0.0.0.0:443", &sock("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_hostname_by_port() {
        assert!(site_addr_matches("example.com:443", &sock("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_port_mismatch() {
        assert!(!site_addr_matches("example.com:8443", &sock("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_wildcard_host() {
        assert!(site_addr_matches(":9090", &sock("0.0.0.0:9090")));
    }
}