1use arc_swap::ArcSwap;
2use hyper::body::Incoming;
3use hyper::service::service_fn;
4use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
5use hyper_util::client::legacy::connect::HttpConnector;
6use hyper_util::client::legacy::Client;
7use hyper_util::rt::TokioExecutor;
8use hyper_util::rt::TokioIo;
9use std::collections::HashMap;
10#[cfg(feature = "tls")]
11use std::collections::HashSet;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::net::TcpListener;
16use tokio::sync::Semaphore;
17use tracing::{error, info, warn};
18
19#[cfg(feature = "tls")]
20use crate::proxy::tls::{build_tls_acceptor, listen_http_redirect, listen_tls};
21
22#[cfg(feature = "tls")]
23use crate::config::tls_redirect_port;
24use crate::config::{extract_hostname, resolve_listen_addr, Config};
25use crate::proxy::handler::proxy;
26
27pub struct Proxy {
77 config: Arc<ArcSwap<Config>>,
78 client: Client<HttpsConnector<HttpConnector>, Incoming>,
79 max_concurrency: usize,
80 semaphore: Arc<Semaphore>,
81}
82
83impl Proxy {
84 pub fn new(config: Config) -> Self {
97 let mut http = HttpConnector::new();
98 http.set_keepalive(Some(Duration::from_secs(60)));
99 http.set_nodelay(true);
100 let https = HttpsConnectorBuilder::new()
101 .with_native_roots()
102 .expect("Failed to load native TLS root certificates")
103 .https_or_http()
104 .enable_http1()
105 .wrap_connector(http);
106
107 let client = Client::builder(TokioExecutor::new())
108 .pool_max_idle_per_host(100)
109 .pool_idle_timeout(Duration::from_secs(90))
110 .build::<_, Incoming>(https);
111
112 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
113 .ok()
114 .and_then(|v| v.parse().ok())
115 .unwrap_or_else(|| num_cpus::get() * 256);
116
117 let semaphore = Arc::new(Semaphore::new(max_concurrency));
118
119 info!(
120 "Proxy initialized with max_concurrency={} (default: {})",
121 max_concurrency,
122 num_cpus::get() * 256
123 );
124
125 Self {
126 config: Arc::new(ArcSwap::from_pointee(config)),
127 client,
128 max_concurrency,
129 semaphore,
130 }
131 }
132
133 pub fn from_shared(config: Arc<ArcSwap<Config>>) -> Self {
142 let mut http = HttpConnector::new();
143 http.set_keepalive(Some(Duration::from_secs(60)));
144 http.set_nodelay(true);
145 let https = HttpsConnectorBuilder::new()
146 .with_native_roots()
147 .expect("Failed to load native TLS root certificates")
148 .https_or_http()
149 .enable_http1()
150 .wrap_connector(http);
151
152 let client = Client::builder(TokioExecutor::new())
153 .pool_max_idle_per_host(100)
154 .pool_idle_timeout(Duration::from_secs(90))
155 .build::<_, Incoming>(https);
156
157 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
158 .ok()
159 .and_then(|v| v.parse().ok())
160 .unwrap_or_else(|| num_cpus::get() * 256);
161
162 let semaphore = Arc::new(Semaphore::new(max_concurrency));
163
164 info!(
165 "Proxy initialized with max_concurrency={} (default: {})",
166 max_concurrency,
167 num_cpus::get() * 256
168 );
169
170 Self {
171 config,
172 client,
173 max_concurrency,
174 semaphore,
175 }
176 }
177
178 pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
204 let addr: SocketAddr = addr.parse()?;
205 self.start_with_addr(addr).await
206 }
207
208 pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
217 let config_snapshot = self.config.load_full();
219 let tls_sites: Vec<(String, crate::config::TlsConfig)> = config_snapshot
220 .sites
221 .values()
222 .filter(|site| {
223 site_addr_matches(&site.address, &addr) && site.tls.is_some()
226 })
227 .filter_map(|site| {
228 let hostname = extract_hostname(&site.address);
230 site.tls.clone().map(|tls| (hostname.to_string(), tls))
231 })
232 .collect();
233
234 if !tls_sites.is_empty() {
235 #[cfg(feature = "tls")]
236 {
237 self.start_tls(addr, tls_sites).await
238 }
239 #[cfg(not(feature = "tls"))]
240 {
241 anyhow::bail!(
242 "TLS configuration found for {} but 'tls' feature is disabled. \
243 Refusing to start as plain HTTP (security risk). \
244 Rebuild with --features tls or remove 'tls' from config.",
245 addr
246 );
247 }
248 } else {
249 self.start_http(addr).await
250 }
251 }
252
253 pub async fn start_all(&self) -> anyhow::Result<()> {
281 let config_snapshot = self.config.load_full();
282
283 let mut socket_groups: HashMap<SocketAddr, Vec<&crate::config::SiteConfig>> =
285 HashMap::new();
286 for site in config_snapshot.sites.values() {
287 let listen_addr = resolve_listen_addr(&site.address)?;
288 socket_groups.entry(listen_addr).or_default().push(site);
289 }
290
291 let mut http_handles = Vec::new();
292
293 #[cfg(feature = "tls")]
294 let mut tls_redirects: HashSet<(SocketAddr, u16)> = HashSet::new(); for (listen_addr, sites) in socket_groups {
297 let tls_sites: Vec<_> = sites.iter().copied().filter(|s| s.tls.is_some()).collect();
298 let has_tls = !tls_sites.is_empty();
299 let has_plain = tls_sites.len() != sites.len();
300
301 if has_tls && has_plain {
302 anyhow::bail!(
303 "Mixed TLS and non-TLS sites on the same listen address {} is not supported",
304 listen_addr
305 );
306 }
307
308 if has_tls {
309 #[cfg(feature = "tls")]
310 {
311 let tls_entries: Vec<(String, crate::config::TlsConfig)> = tls_sites
312 .iter()
313 .filter_map(|s| {
314 let hostname = extract_hostname(&s.address);
315 s.tls.clone().map(|tls| (hostname.to_string(), tls))
316 })
317 .collect();
318
319 let tls_port = listen_addr.port();
320
321 let client = self.client.clone();
322 let config = self.config.clone();
323 let semaphore = self.semaphore.clone();
324
325 let acceptor = build_tls_acceptor(&tls_entries, None)?;
326 info!(
327 "Starting HTTPS listener on {} ({} domain(s))",
328 listen_addr,
329 tls_entries.len()
330 );
331
332 let handle = tokio::spawn(async move {
333 if let Err(e) =
334 listen_tls(listen_addr, acceptor, semaphore, move |req, remote_addr| {
335 let client = client.clone();
336 let config = config.clone();
337 async move {
338 let config_snapshot = config.load_full();
339 proxy(req, client, config_snapshot, remote_addr, true).await
340 }
341 })
342 .await
343 {
344 error!("TLS listener error: {}", e);
345 }
346 });
347 http_handles.push(handle);
348
349 tls_redirects.insert((
350 SocketAddr::new(listen_addr.ip(), tls_redirect_port(tls_port)),
351 tls_port,
352 ));
353 }
354
355 #[cfg(not(feature = "tls"))]
356 {
357 anyhow::bail!(
358 "TLS configuration found for {} but 'tls' feature is disabled. \
359 Refusing to start as plain HTTP (security risk). \
360 Rebuild with --features tls or remove 'tls' from config.",
361 listen_addr
362 );
363 }
364 } else {
365 let client = self.client.clone();
366 let config = self.config.clone();
367 let semaphore = self.semaphore.clone();
368 let max_concurrency = self.max_concurrency;
369
370 let handle = tokio::spawn(async move {
371 if let Err(e) =
372 Self::run_http_loop(listen_addr, client, config, semaphore, max_concurrency)
373 .await
374 {
375 error!("HTTP listener error: {}", e);
376 }
377 });
378 http_handles.push(handle);
379 }
380 }
381
382 #[cfg(feature = "tls")]
383 for (redirect_addr, tls_port) in tls_redirects {
384 info!(
385 "Starting HTTP→HTTPS redirect on http://{} → :{}",
386 redirect_addr, tls_port
387 );
388 let handle = tokio::spawn(async move {
389 match listen_http_redirect(redirect_addr, tls_port).await {
390 Ok(()) => {}
391 Err(e) => {
392 warn!(
393 "HTTP redirect on port {} failed (HTTPS on :{} still active): {}",
394 redirect_addr.port(),
395 tls_port,
396 e
397 );
398 }
399 }
400 });
401 http_handles.push(handle);
402 }
403
404 if http_handles.is_empty() {
405 warn!("No listeners configured — proxy has no sites");
406 return Ok(());
407 }
408
409 info!(
410 "Started {} listener(s), max concurrency: {} ({})",
411 http_handles.len(),
412 self.max_concurrency,
413 if self.max_concurrency == num_cpus::get() * 256 {
414 "default"
415 } else {
416 "custom"
417 }
418 );
419
420 for handle in http_handles {
423 if let Err(e) = handle.await {
424 error!("Listener task panicked: {}", e);
425 }
426 }
427
428 Ok(())
429 }
430
431 async fn start_http(&self, addr: SocketAddr) -> anyhow::Result<()> {
433 Self::run_http_loop(
434 addr,
435 self.client.clone(),
436 self.config.clone(),
437 self.semaphore.clone(),
438 self.max_concurrency,
439 )
440 .await
441 }
442
443 async fn run_http_loop(
445 addr: SocketAddr,
446 client: Client<HttpsConnector<HttpConnector>, Incoming>,
447 config: Arc<ArcSwap<Config>>,
448 semaphore: Arc<Semaphore>,
449 max_concurrency: usize,
450 ) -> anyhow::Result<()> {
451 let listener = TcpListener::bind(&addr).await?;
452 info!("Tiny Proxy listening on http://{}", addr);
453
454 loop {
455 let (stream, remote_addr) = listener.accept().await?;
456 let io = TokioIo::new(stream);
457 let client = client.clone();
458 let config = config.clone();
459 let semaphore = semaphore.clone();
460
461 match semaphore.try_acquire_owned() {
462 Ok(permit) => {
463 tokio::task::spawn(async move {
464 let _permit = permit;
465
466 let service = service_fn(move |req| {
467 let client = client.clone();
468 let config = config.clone();
469
470 async move {
471 let config_snapshot = config.load_full();
472 proxy(req, client, config_snapshot, remote_addr, false).await
473 }
474 });
475
476 let mut builder = hyper::server::conn::http1::Builder::new();
477 builder.keep_alive(true).pipeline_flush(false);
478
479 builder.serve_connection(io, service).await
480 });
481 }
482 Err(_) => {
483 warn!(
484 "Concurrency limit exceeded ({}), rejecting connection",
485 max_concurrency
486 );
487 }
488 }
489 }
490 }
491
492 #[cfg(feature = "tls")]
494 async fn start_tls(
495 &self,
496 addr: SocketAddr,
497 tls_sites: Vec<(String, crate::config::TlsConfig)>,
498 ) -> anyhow::Result<()> {
499 let acceptor = build_tls_acceptor(&tls_sites, None)?;
500 info!(
501 "Starting HTTPS listener on https://{} ({} domain(s))",
502 addr,
503 tls_sites.len()
504 );
505
506 let client = self.client.clone();
507 let config = self.config.clone();
508 let semaphore = self.semaphore.clone();
509
510 listen_tls(addr, acceptor, semaphore, move |req, remote_addr| {
511 let client = client.clone();
512 let config = config.clone();
513
514 async move {
515 let config_snapshot = config.load_full();
516 proxy(req, client, config_snapshot, remote_addr, true).await
517 }
518 })
519 .await
520 }
521
522 pub fn shared_config(&self) -> Arc<ArcSwap<Config>> {
532 self.config.clone()
533 }
534
535 pub fn config_snapshot(&self) -> Arc<Config> {
544 self.config.load_full()
545 }
546
547 pub fn max_concurrency(&self) -> usize {
553 self.max_concurrency
554 }
555
556 pub fn set_max_concurrency(&mut self, max: usize) {
567 self.max_concurrency = max;
568 self.semaphore = Arc::new(Semaphore::new(max));
569 info!("Max concurrency updated to {}", max);
570 }
571
572 pub fn update_config(&self, config: Config) {
585 info!("Configuration updated ({} sites)", config.sites.len());
586 self.config.store(Arc::new(config));
587 }
588}
589
/// Reports whether a site configured at `site_address` (`host:port`) belongs to a
/// listener bound to `listen_addr`.
///
/// The rules are:
/// * The port must be numeric and equal to `listen_addr`'s port. An address
///   without any `:` is treated as a bare port with an empty host.
/// * An empty host, `0.0.0.0`, `::` or `[::]` is a wildcard and matches any IP.
/// * `localhost` matches only `127.0.0.1`.
/// * An IP literal must equal the listen IP. IPv6 literals may be bracketed, as
///   in `[::1]:8080`.
/// * Any other host is treated as a DNS name and matches on the port alone.
fn site_addr_matches(site_address: &str, listen_addr: &SocketAddr) -> bool {
    // Split at the *last* colon so IPv6 hosts keep their internal colons.
    let (host, port) = site_address.rsplit_once(':').unwrap_or(("", site_address));

    if port.parse::<u16>().ok() != Some(listen_addr.port()) {
        return false;
    }

    // Bracketed IPv6 literals would otherwise fail to parse as an IP and fall
    // through to the hostname case, matching every listen IP on the port.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);

    if host.is_empty() || host == "0.0.0.0" || host == "::" {
        return true;
    }

    let site_ip = if host == "localhost" {
        std::net::IpAddr::from(std::net::Ipv4Addr::LOCALHOST)
    } else if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        ip
    } else {
        // Presumably a DNS name: it cannot be compared with the listen IP here,
        // so the port match alone decides.
        return true;
    };

    site_ip == listen_addr.ip()
}
628
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A configuration with no sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// A plain-HTTP site at `address` with no directives.
    fn plain_site(address: &str) -> crate::config::SiteConfig {
        crate::config::SiteConfig {
            address: address.to_string(),
            directives: Vec::new(),
            tls: None,
        }
    }

    /// A configuration holding exactly one plain site, keyed by its own address.
    fn config_with_site(address: &str) -> Config {
        let mut sites = HashMap::new();
        sites.insert(address.to_string(), plain_site(address));
        Config { sites }
    }

    /// Parses a socket address literal used as a listener address in tests.
    fn listen(addr: &str) -> SocketAddr {
        addr.parse().unwrap()
    }

    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());
        assert_eq!(proxy.config_snapshot().sites.len(), 0);
    }

    #[tokio::test]
    async fn test_config_access() {
        let proxy = Proxy::new(config_with_site("localhost:8080"));

        let current = proxy.config_snapshot();
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("localhost:8080"));
    }

    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        assert_eq!(proxy.config_snapshot().sites.len(), 0);

        proxy.update_config(config_with_site("test.local"));

        let current = proxy.config_snapshot();
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("test.local"));
    }

    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());

        // Writing through the shared handle must be visible via the proxy itself.
        proxy
            .shared_config()
            .store(Arc::new(config_with_site("shared.local")));

        let current = proxy.config_snapshot();
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("shared.local"));
    }

    #[test]
    fn test_from_shared() {
        let shared = Arc::new(ArcSwap::from_pointee(empty_config()));
        let proxy = Proxy::from_shared(Arc::clone(&shared));

        // The proxy must observe stores made through the caller's original handle.
        shared.store(Arc::new(config_with_site("from-shared.local")));

        let current = proxy.config_snapshot();
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("from-shared.local"));
    }

    #[test]
    fn test_site_addr_matches_localhost() {
        assert!(site_addr_matches("localhost:8080", &listen("127.0.0.1:8080")));
    }

    #[test]
    fn test_site_addr_matches_ip() {
        assert!(site_addr_matches("0.0.0.0:443", &listen("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_hostname_by_port() {
        assert!(site_addr_matches("example.com:443", &listen("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_port_mismatch() {
        assert!(!site_addr_matches("example.com:8443", &listen("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_wildcard_host() {
        assert!(site_addr_matches(":9090", &listen("0.0.0.0:9090")));
    }
}