tiny_proxy/proxy/proxy.rs
1use hyper::body::Incoming;
2use hyper::service::service_fn;
3use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
4use hyper_util::client::legacy::connect::HttpConnector;
5use hyper_util::client::legacy::Client;
6use hyper_util::rt::TokioExecutor;
7use hyper_util::rt::TokioIo;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::net::TcpListener;
12use tokio::sync::{RwLock, Semaphore};
13use tracing::{info, warn};
14
15use crate::config::Config;
16use crate::proxy::handler::proxy;
17
/// HTTP proxy server that can be embedded into other applications.
///
/// This struct encapsulates the proxy state and allows programmatic control
/// over the proxy lifecycle. Configuration is stored in an `Arc<RwLock<Config>>`
/// so it can be hot-reloaded at runtime (e.g. via the API server).
///
/// # Example
///
/// ```no_run
/// use tiny_proxy::{Config, Proxy};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let config = Config::from_file("file.caddy")?;
///     let proxy = Proxy::new(config);
///     proxy.start("127.0.0.1:8080").await?;
///     Ok(())
/// }
/// ```
///
/// # Hot-reload Example
///
/// ```no_run
/// use tiny_proxy::{Config, Proxy};
/// use std::sync::Arc;
/// use tokio::sync::RwLock;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let config = Config::from_file("config.caddy")?;
///     let proxy = Proxy::new(config);
///
///     // Get a handle to the shared config for hot-reload
///     let config_handle = proxy.shared_config();
///
///     // Spawn proxy in background
///     let handle = tokio::spawn(async move {
///         if let Err(e) = proxy.start("127.0.0.1:8080").await {
///             eprintln!("Proxy error: {}", e);
///         }
///     });
///
///     // Later, update config at runtime
///     let new_config = Config::from_file("updated-config.caddy")?;
///     {
///         let mut guard = config_handle.write().await;
///         *guard = new_config;
///     }
///
///     handle.await?;
///     Ok(())
/// }
/// ```
pub struct Proxy {
    /// Shared configuration; cloned handles (see `shared_config`) can replace
    /// or mutate it while the proxy is running.
    config: Arc<RwLock<Config>>,
    /// Pooled outbound client handed to the request handler for every request.
    client: Client<HttpsConnector<HttpConnector>, Incoming>,
    /// Number of permits the current `semaphore` was created with.
    max_concurrency: usize,
    /// Admission control: the accept loop takes one permit per accepted
    /// connection and the permit is held until that connection's task ends.
    semaphore: Arc<Semaphore>,
}
77
78impl Proxy {
79 /// Create a new proxy instance with the given configuration
80 ///
81 /// The configuration is internally wrapped in `Arc<RwLock<Config>>`
82 /// so it can be shared with an API server for hot-reload.
83 ///
84 /// # Arguments
85 ///
86 /// * `config` - Configuration loaded from file or constructed programmatically
87 ///
88 /// # Returns
89 ///
90 /// A new `Proxy` instance ready to be started
91 pub fn new(config: Config) -> Self {
92 let mut http = HttpConnector::new();
93 http.set_keepalive(Some(Duration::from_secs(60)));
94 http.set_nodelay(true);
95 let https = HttpsConnectorBuilder::new()
96 .with_native_roots()
97 .unwrap()
98 .https_or_http()
99 .enable_http1()
100 .wrap_connector(http);
101
102 let client = Client::builder(TokioExecutor::new())
103 .pool_max_idle_per_host(100)
104 .pool_idle_timeout(Duration::from_secs(90))
105 .build::<_, Incoming>(https);
106
107 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
108 .ok()
109 .and_then(|v| v.parse().ok())
110 .unwrap_or_else(|| num_cpus::get() * 256);
111
112 let semaphore = Arc::new(Semaphore::new(max_concurrency));
113
114 info!(
115 "Proxy initialized with max_concurrency={} (default: {})",
116 max_concurrency,
117 num_cpus::get() * 256
118 );
119
120 Self {
121 config: Arc::new(RwLock::new(config)),
122 client,
123 max_concurrency,
124 semaphore,
125 }
126 }
127
128 /// Create a new proxy instance from an already shared configuration
129 ///
130 /// Use this when you already have an `Arc<RwLock<Config>>` that is
131 /// shared with an API server or other component.
132 ///
133 /// # Arguments
134 ///
135 /// * `config` - Shared configuration wrapped in `Arc<RwLock<Config>>`
136 pub fn from_shared(config: Arc<RwLock<Config>>) -> Self {
137 let mut http = HttpConnector::new();
138 http.set_keepalive(Some(Duration::from_secs(60)));
139 http.set_nodelay(true);
140 let https = HttpsConnectorBuilder::new()
141 .with_native_roots()
142 .unwrap()
143 .https_or_http()
144 .enable_http1()
145 .wrap_connector(http);
146
147 let client = Client::builder(TokioExecutor::new())
148 .pool_max_idle_per_host(100)
149 .pool_idle_timeout(Duration::from_secs(90))
150 .build::<_, Incoming>(https);
151
152 let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
153 .ok()
154 .and_then(|v| v.parse().ok())
155 .unwrap_or_else(|| num_cpus::get() * 256);
156
157 let semaphore = Arc::new(Semaphore::new(max_concurrency));
158
159 info!(
160 "Proxy initialized with max_concurrency={} (default: {})",
161 max_concurrency,
162 num_cpus::get() * 256
163 );
164
165 Self {
166 config,
167 client,
168 max_concurrency,
169 semaphore,
170 }
171 }
172
173 /// Start the proxy server on the specified address
174 ///
175 /// This method blocks indefinitely, handling incoming connections.
176 /// To run the proxy in the background, spawn it in a tokio task.
177 ///
178 /// # Arguments
179 ///
180 /// * `addr` - Address to listen on (e.g., "127.0.0.1:8080" or "0.0.0.0:8080")
181 ///
182 /// # Example
183 ///
184 /// ```no_run
185 /// # use tiny_proxy::{Config, Proxy};
186 /// # #[tokio::main]
187 /// # async fn main() -> anyhow::Result<()> {
188 /// # let config = Config::from_file("config.caddy")?;
189 /// # let proxy = Proxy::new(config);
190 /// proxy.start("127.0.0.1:8080").await?;
191 /// # Ok(())
192 /// # }
193 /// ```
194 ///
195 /// To run in background:
196 /// ```no_run
197 /// # use tiny_proxy::{Config, Proxy};
198 /// # #[tokio::main]
199 /// # async fn main() -> anyhow::Result<()> {
200 /// # let config = Config::from_file("config.caddy")?;
201 /// # let proxy = std::sync::Arc::new(Proxy::new(config));
202 /// let handle = tokio::spawn(async move {
203 /// if let Err(e) = proxy.start("127.0.0.1:8080").await {
204 /// eprintln!("Proxy error: {}", e);
205 /// }
206 /// });
207 /// # handle.await?;
208 /// # Ok(())
209 /// # }
210 /// ```
211 pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
212 let addr: SocketAddr = addr.parse()?;
213 self.start_with_addr(addr).await
214 }
215
216 /// Start the proxy server with a parsed SocketAddr
217 ///
218 /// This is a convenience method if you already have a parsed SocketAddr.
219 ///
220 /// # Arguments
221 ///
222 /// * `addr` - Parsed SocketAddr to listen on
223 pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
224 let listener = TcpListener::bind(&addr).await?;
225
226 info!("Tiny Proxy listening on http://{}", addr);
227 info!(
228 "Max concurrency: {} ({})",
229 self.max_concurrency,
230 if self.max_concurrency == num_cpus::get() * 256 {
231 "default"
232 } else {
233 "custom"
234 }
235 );
236
237 loop {
238 let (stream, remote_addr) = listener.accept().await?;
239 let io = TokioIo::new(stream);
240 let client = self.client.clone();
241 let config = self.config.clone();
242 let semaphore = self.semaphore.clone();
243
244 match semaphore.try_acquire_owned() {
245 Ok(permit) => {
246 tokio::task::spawn(async move {
247 let _permit = permit;
248 let service = service_fn(move |req| {
249 let client = client.clone();
250 let config = config.clone();
251
252 let config_clone = config.clone();
253 async move {
254 let config_guard = config_clone.read().await;
255 let config_snapshot = Arc::new(config_guard.clone());
256 drop(config_guard);
257 proxy(req, client, config_snapshot, remote_addr).await
258 }
259 });
260
261 let mut builder = hyper::server::conn::http1::Builder::new();
262 builder.keep_alive(true).pipeline_flush(false);
263
264 builder.serve_connection(io, service).await
265 });
266 }
267 Err(_) => {
268 warn!(
269 "Concurrency limit exceeded ({}), rejecting connection",
270 self.max_concurrency
271 );
272 }
273 }
274 }
275 }
276
277 /// Get a reference to the shared configuration handle
278 ///
279 /// This returns a clone of the `Arc<RwLock<Config>>`, allowing
280 /// external code (e.g. an API server) to read and update the
281 /// configuration at runtime.
282 ///
283 /// # Returns
284 ///
285 /// A cloned `Arc<RwLock<Config>>`
286 pub fn shared_config(&self) -> Arc<RwLock<Config>> {
287 self.config.clone()
288 }
289
290 /// Get a snapshot of the current configuration
291 ///
292 /// Reads the current configuration and returns an owned clone.
293 /// This is useful for inspecting config without holding a lock.
294 ///
295 /// # Returns
296 ///
297 /// A cloned `Config`
298 pub async fn config_snapshot(&self) -> Config {
299 self.config.read().await.clone()
300 }
301
302 /// Get current concurrency limit
303 ///
304 /// # Returns
305 ///
306 /// Current maximum number of concurrent connections
307 pub fn max_concurrency(&self) -> usize {
308 self.max_concurrency
309 }
310
311 /// Update concurrency limit at runtime
312 ///
313 /// # Arguments
314 ///
315 /// * `max` - New maximum number of concurrent connections
316 ///
317 /// # Note
318 ///
319 /// This updates the semaphore immediately. New connections will use
320 /// the new limit, but existing connections are not affected.
321 pub fn set_max_concurrency(&mut self, max: usize) {
322 self.max_concurrency = max;
323 self.semaphore = Arc::new(Semaphore::new(max));
324 info!("Max concurrency updated to {}", max);
325 }
326
327 /// Update the configuration at runtime (hot-reload)
328 ///
329 /// This atomically replaces the configuration. New connections will
330 /// use the updated configuration immediately. Existing connections
331 /// will continue to use their original configuration snapshot.
332 ///
333 /// # Arguments
334 ///
335 /// * `config` - New configuration to use
336 pub async fn update_config(&self, config: Config) {
337 let mut guard = self.config.write().await;
338 info!("Configuration updated ({} sites)", config.sites.len());
339 *guard = config;
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Configuration without any sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// Site entry for `address` with no directives.
    fn bare_site(address: &str) -> crate::config::SiteConfig {
        crate::config::SiteConfig {
            address: address.to_string(),
            directives: vec![],
        }
    }

    /// A freshly built proxy exposes the configuration it was given.
    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());

        // The snapshot accessor is async, so drive it on a dedicated runtime.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let current = runtime.block_on(proxy.config_snapshot());
        assert_eq!(current.sites.len(), 0);
    }

    /// Sites present at construction time are visible through the snapshot.
    #[tokio::test]
    async fn test_config_access() {
        let mut initial = empty_config();
        initial
            .sites
            .insert("localhost:8080".to_string(), bare_site("localhost:8080"));

        let proxy = Proxy::new(initial);

        let current = proxy.config_snapshot().await;
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("localhost:8080"));
    }

    /// `update_config` replaces the whole configuration.
    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        let before = proxy.config_snapshot().await;
        assert_eq!(before.sites.len(), 0);

        let mut replacement = empty_config();
        replacement
            .sites
            .insert("test.local".to_string(), bare_site("test.local"));
        proxy.update_config(replacement).await;

        let after = proxy.config_snapshot().await;
        assert_eq!(after.sites.len(), 1);
        assert!(after.sites.contains_key("test.local"));
    }

    /// Writes through the handle from `shared_config` are seen by the proxy.
    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());
        let shared = proxy.shared_config();

        // Mutate through the shared handle; the guard is a temporary and is
        // released at the end of the statement.
        shared
            .write()
            .await
            .sites
            .insert("shared.local".to_string(), bare_site("shared.local"));

        // The proxy must observe the change.
        let current = proxy.config_snapshot().await;
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("shared.local"));
    }

    /// `from_shared` keeps pointing at the caller's configuration.
    #[test]
    fn test_from_shared() {
        let shared = Arc::new(RwLock::new(empty_config()));
        let proxy = Proxy::from_shared(Arc::clone(&shared));

        let runtime = tokio::runtime::Runtime::new().unwrap();

        // Write through the caller's handle...
        {
            let mut writer = runtime.block_on(shared.write());
            writer.sites.insert(
                "from-shared.local".to_string(),
                bare_site("from-shared.local"),
            );
        }

        // ...and read it back through the proxy.
        let current = runtime.block_on(proxy.config_snapshot());
        assert_eq!(current.sites.len(), 1);
        assert!(current.sites.contains_key("from-shared.local"));
    }
}