Skip to main content

tiny_proxy/proxy/
proxy.rs

1use hyper::body::Incoming;
2use hyper::service::service_fn;
3use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
4use hyper_util::client::legacy::connect::HttpConnector;
5use hyper_util::client::legacy::Client;
6use hyper_util::rt::TokioExecutor;
7use hyper_util::rt::TokioIo;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::net::TcpListener;
12use tokio::sync::{RwLock, Semaphore};
13use tracing::{info, warn};
14
15use crate::config::Config;
16use crate::proxy::handler::proxy;
17
18/// HTTP Proxy server that can be embedded into other applications
19///
20/// This struct encapsulates the proxy state and allows programmatic control
21/// over the proxy lifecycle. Configuration is stored in an `Arc<RwLock<Config>>`
22/// so it can be hot-reloaded at runtime (e.g. via the API server).
23///
24/// # Example
25///
26/// ```no_run
27/// use tiny_proxy::{Config, Proxy};
28///
29/// #[tokio::main]
30/// async fn main() -> anyhow::Result<()> {
31///     let config = Config::from_file("file.caddy")?;
32///     let proxy = Proxy::new(config);
33///     proxy.start("127.0.0.1:8080").await?;
34///     Ok(())
35/// }
36/// ```
37///
38/// # Hot-reload Example
39///
40/// ```no_run
41/// use tiny_proxy::{Config, Proxy};
42/// use std::sync::Arc;
43/// use tokio::sync::RwLock;
44///
45/// #[tokio::main]
46/// async fn main() -> anyhow::Result<()> {
47///     let config = Config::from_file("config.caddy")?;
48///     let proxy = Proxy::new(config);
49///
50///     // Get a handle to the shared config for hot-reload
51///     let config_handle = proxy.shared_config();
52///
53///     // Spawn proxy in background
54///     let handle = tokio::spawn(async move {
55///         if let Err(e) = proxy.start("127.0.0.1:8080").await {
56///             eprintln!("Proxy error: {}", e);
57///         }
58///     });
59///
60///     // Later, update config at runtime
61///     let new_config = Config::from_file("updated-config.caddy")?;
62///     {
63///         let mut guard = config_handle.write().await;
64///         *guard = new_config;
65///     }
66///
67///     handle.await?;
68///     Ok(())
69/// }
70/// ```
71pub struct Proxy {
72    config: Arc<RwLock<Config>>,
73    client: Client<HttpsConnector<HttpConnector>, Incoming>,
74    max_concurrency: usize,
75    semaphore: Arc<Semaphore>,
76}
77
78impl Proxy {
79    /// Create a new proxy instance with the given configuration
80    ///
81    /// The configuration is internally wrapped in `Arc<RwLock<Config>>`
82    /// so it can be shared with an API server for hot-reload.
83    ///
84    /// # Arguments
85    ///
86    /// * `config` - Configuration loaded from file or constructed programmatically
87    ///
88    /// # Returns
89    ///
90    /// A new `Proxy` instance ready to be started
91    pub fn new(config: Config) -> Self {
92        let mut http = HttpConnector::new();
93        http.set_keepalive(Some(Duration::from_secs(60)));
94        http.set_nodelay(true);
95        let https = HttpsConnectorBuilder::new()
96            .with_native_roots()
97            .unwrap()
98            .https_or_http()
99            .enable_http1()
100            .wrap_connector(http);
101
102        let client = Client::builder(TokioExecutor::new())
103            .pool_max_idle_per_host(100)
104            .pool_idle_timeout(Duration::from_secs(90))
105            .build::<_, Incoming>(https);
106
107        let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
108            .ok()
109            .and_then(|v| v.parse().ok())
110            .unwrap_or_else(|| num_cpus::get() * 256);
111
112        let semaphore = Arc::new(Semaphore::new(max_concurrency));
113
114        info!(
115            "Proxy initialized with max_concurrency={} (default: {})",
116            max_concurrency,
117            num_cpus::get() * 256
118        );
119
120        Self {
121            config: Arc::new(RwLock::new(config)),
122            client,
123            max_concurrency,
124            semaphore,
125        }
126    }
127
128    /// Create a new proxy instance from an already shared configuration
129    ///
130    /// Use this when you already have an `Arc<RwLock<Config>>` that is
131    /// shared with an API server or other component.
132    ///
133    /// # Arguments
134    ///
135    /// * `config` - Shared configuration wrapped in `Arc<RwLock<Config>>`
136    pub fn from_shared(config: Arc<RwLock<Config>>) -> Self {
137        let mut http = HttpConnector::new();
138        http.set_keepalive(Some(Duration::from_secs(60)));
139        http.set_nodelay(true);
140        let https = HttpsConnectorBuilder::new()
141            .with_native_roots()
142            .unwrap()
143            .https_or_http()
144            .enable_http1()
145            .wrap_connector(http);
146
147        let client = Client::builder(TokioExecutor::new())
148            .pool_max_idle_per_host(100)
149            .pool_idle_timeout(Duration::from_secs(90))
150            .build::<_, Incoming>(https);
151
152        let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
153            .ok()
154            .and_then(|v| v.parse().ok())
155            .unwrap_or_else(|| num_cpus::get() * 256);
156
157        let semaphore = Arc::new(Semaphore::new(max_concurrency));
158
159        info!(
160            "Proxy initialized with max_concurrency={} (default: {})",
161            max_concurrency,
162            num_cpus::get() * 256
163        );
164
165        Self {
166            config,
167            client,
168            max_concurrency,
169            semaphore,
170        }
171    }
172
173    /// Start the proxy server on the specified address
174    ///
175    /// This method blocks indefinitely, handling incoming connections.
176    /// To run the proxy in the background, spawn it in a tokio task.
177    ///
178    /// # Arguments
179    ///
180    /// * `addr` - Address to listen on (e.g., "127.0.0.1:8080" or "0.0.0.0:8080")
181    ///
182    /// # Example
183    ///
184    /// ```no_run
185    /// # use tiny_proxy::{Config, Proxy};
186    /// # #[tokio::main]
187    /// # async fn main() -> anyhow::Result<()> {
188    /// # let config = Config::from_file("config.caddy")?;
189    /// # let proxy = Proxy::new(config);
190    /// proxy.start("127.0.0.1:8080").await?;
191    /// # Ok(())
192    /// # }
193    /// ```
194    ///
195    /// To run in background:
196    /// ```no_run
197    /// # use tiny_proxy::{Config, Proxy};
198    /// # #[tokio::main]
199    /// # async fn main() -> anyhow::Result<()> {
200    /// # let config = Config::from_file("config.caddy")?;
201    /// # let proxy = std::sync::Arc::new(Proxy::new(config));
202    /// let handle = tokio::spawn(async move {
203    ///     if let Err(e) = proxy.start("127.0.0.1:8080").await {
204    ///         eprintln!("Proxy error: {}", e);
205    ///     }
206    /// });
207    /// # handle.await?;
208    /// # Ok(())
209    /// # }
210    /// ```
211    pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
212        let addr: SocketAddr = addr.parse()?;
213        self.start_with_addr(addr).await
214    }
215
216    /// Start the proxy server with a parsed SocketAddr
217    ///
218    /// This is a convenience method if you already have a parsed SocketAddr.
219    ///
220    /// # Arguments
221    ///
222    /// * `addr` - Parsed SocketAddr to listen on
223    pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
224        let listener = TcpListener::bind(&addr).await?;
225
226        info!("Tiny Proxy listening on http://{}", addr);
227        info!(
228            "Max concurrency: {} ({})",
229            self.max_concurrency,
230            if self.max_concurrency == num_cpus::get() * 256 {
231                "default"
232            } else {
233                "custom"
234            }
235        );
236
237        loop {
238            let (stream, remote_addr) = listener.accept().await?;
239            let io = TokioIo::new(stream);
240            let client = self.client.clone();
241            let config = self.config.clone();
242            let semaphore = self.semaphore.clone();
243
244            match semaphore.try_acquire_owned() {
245                Ok(permit) => {
246                    tokio::task::spawn(async move {
247                        let _permit = permit;
248                        let service = service_fn(move |req| {
249                            let client = client.clone();
250                            let config = config.clone();
251
252                            let config_clone = config.clone();
253                            async move {
254                                let config_guard = config_clone.read().await;
255                                let config_snapshot = Arc::new(config_guard.clone());
256                                drop(config_guard);
257                                proxy(req, client, config_snapshot, remote_addr).await
258                            }
259                        });
260
261                        let mut builder = hyper::server::conn::http1::Builder::new();
262                        builder.keep_alive(true).pipeline_flush(false);
263
264                        builder.serve_connection(io, service).await
265                    });
266                }
267                Err(_) => {
268                    warn!(
269                        "Concurrency limit exceeded ({}), rejecting connection",
270                        self.max_concurrency
271                    );
272                }
273            }
274        }
275    }
276
277    /// Get a reference to the shared configuration handle
278    ///
279    /// This returns a clone of the `Arc<RwLock<Config>>`, allowing
280    /// external code (e.g. an API server) to read and update the
281    /// configuration at runtime.
282    ///
283    /// # Returns
284    ///
285    /// A cloned `Arc<RwLock<Config>>`
286    pub fn shared_config(&self) -> Arc<RwLock<Config>> {
287        self.config.clone()
288    }
289
290    /// Get a snapshot of the current configuration
291    ///
292    /// Reads the current configuration and returns an owned clone.
293    /// This is useful for inspecting config without holding a lock.
294    ///
295    /// # Returns
296    ///
297    /// A cloned `Config`
298    pub async fn config_snapshot(&self) -> Config {
299        self.config.read().await.clone()
300    }
301
302    /// Get current concurrency limit
303    ///
304    /// # Returns
305    ///
306    /// Current maximum number of concurrent connections
307    pub fn max_concurrency(&self) -> usize {
308        self.max_concurrency
309    }
310
311    /// Update concurrency limit at runtime
312    ///
313    /// # Arguments
314    ///
315    /// * `max` - New maximum number of concurrent connections
316    ///
317    /// # Note
318    ///
319    /// This updates the semaphore immediately. New connections will use
320    /// the new limit, but existing connections are not affected.
321    pub fn set_max_concurrency(&mut self, max: usize) {
322        self.max_concurrency = max;
323        self.semaphore = Arc::new(Semaphore::new(max));
324        info!("Max concurrency updated to {}", max);
325    }
326
327    /// Update the configuration at runtime (hot-reload)
328    ///
329    /// This atomically replaces the configuration. New connections will
330    /// use the updated configuration immediately. Existing connections
331    /// will continue to use their original configuration snapshot.
332    ///
333    /// # Arguments
334    ///
335    /// * `config` - New configuration to use
336    pub async fn update_config(&self, config: Config) {
337        let mut guard = self.config.write().await;
338        info!("Configuration updated ({} sites)", config.sites.len());
339        *guard = config;
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a configuration with no sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// Builds a site with no directives whose address is `address`.
    fn site(address: &str) -> crate::config::SiteConfig {
        crate::config::SiteConfig {
            address: address.to_string(),
            directives: vec![],
        }
    }

    /// `Proxy::new` is synchronous and must work without a running Tokio runtime.
    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());
        let handle = proxy.shared_config();
        let guard = handle
            .try_read()
            .expect("no writer holds the config lock right after construction");
        assert!(guard.sites.is_empty());
    }

    #[tokio::test]
    async fn test_config_access() {
        let mut config = empty_config();
        config
            .sites
            .insert("localhost:8080".to_string(), site("localhost:8080"));

        let proxy = Proxy::new(config);
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("localhost:8080"));
    }

    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        assert!(proxy.config_snapshot().await.sites.is_empty());

        let mut config2 = empty_config();
        config2
            .sites
            .insert("test.local".to_string(), site("test.local"));

        proxy.update_config(config2).await;
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("test.local"));
    }

    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());
        let handle = proxy.shared_config();

        // Update via the shared handle; the guard is dropped at the end of the statement.
        handle
            .write()
            .await
            .sites
            .insert("shared.local".to_string(), site("shared.local"));

        // Verify the proxy sees the update
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("shared.local"));
    }

    /// `from_shared` must keep the caller's lock rather than copying the config.
    #[tokio::test]
    async fn test_from_shared() {
        let shared = Arc::new(RwLock::new(empty_config()));
        let proxy = Proxy::from_shared(Arc::clone(&shared));

        assert!(Arc::ptr_eq(&shared, &proxy.shared_config()));

        shared
            .write()
            .await
            .sites
            .insert("from-shared.local".to_string(), site("from-shared.local"));

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("from-shared.local"));
    }

    #[test]
    fn test_set_max_concurrency() {
        let mut proxy = Proxy::new(empty_config());
        proxy.set_max_concurrency(7);
        assert_eq!(proxy.max_concurrency(), 7);
        assert_eq!(proxy.semaphore.available_permits(), 7);
    }
}