Skip to main content

soth_mitm/
proxy.rs

1use std::fs;
2#[cfg(unix)]
3use std::io::Write;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::server::RuntimeGovernor;
9use tokio::sync::Mutex;
10use tokio::task::JoinHandle;
11
12use crate::ca::CertificateAuthority;
13use crate::config::MitmConfig;
14use crate::errors::MitmError;
15use crate::handler::InterceptHandler;
16use crate::metrics::{ProxyMetrics, ProxyMetricsStore};
17use crate::runtime::{build_runtime_server, RuntimeConfigHandle};
18
19/// The intercepting proxy instance, ready to run or start as a background task.
20///
21/// Constructed via [`MitmProxyBuilder`](crate::MitmProxyBuilder).
22pub struct MitmProxy<H: InterceptHandler> {
23    config: MitmConfig,
24    handler: Arc<H>,
25    ca: Option<CertificateAuthority>,
26    metrics_store: Arc<ProxyMetricsStore>,
27}
28
29impl<H: InterceptHandler> MitmProxy<H> {
30    pub(crate) fn new(
31        config: MitmConfig,
32        handler: Arc<H>,
33        ca: Option<CertificateAuthority>,
34        metrics_store: Arc<ProxyMetricsStore>,
35    ) -> Self {
36        Self {
37            config,
38            handler,
39            ca,
40            metrics_store,
41        }
42    }
43
44    /// Runs the proxy, blocking until it shuts down or encounters a fatal error.
45    pub async fn run(self) -> Result<(), MitmError> {
46        self.prepare_ca_material().await?;
47        let runtime_bundle = build_runtime_server(
48            &self.config,
49            Arc::clone(&self.handler),
50            Arc::clone(&self.metrics_store),
51        )?;
52        runtime_bundle.server.run().await.map_err(MitmError::from)
53    }
54
55    /// Starts the proxy as a background task and returns a handle for shutdown and metrics.
56    pub async fn start(self) -> Result<MitmProxyHandle, MitmError> {
57        self.prepare_ca_material().await?;
58        let runtime_bundle = build_runtime_server(
59            &self.config,
60            Arc::clone(&self.handler),
61            Arc::clone(&self.metrics_store),
62        )?;
63        let runtime_config = runtime_bundle.config_handle.clone();
64        let runtime_governor = runtime_bundle.server.runtime_observability_handle();
65        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
66
67        let join_handle = tokio::spawn(async move {
68            runtime_bundle
69                .server
70                .run_until_shutdown(shutdown_rx)
71                .await
72                .map_err(MitmError::from)
73        });
74        Ok(MitmProxyHandle {
75            join_handle: Arc::new(Mutex::new(Some(join_handle))),
76            metrics_store: Arc::clone(&self.metrics_store),
77            runtime_config,
78            runtime_governor,
79            shutdown_tx,
80        })
81    }
82
83    /// Like [`start`] but uses a pre-bound TCP listener inherited from a
84    /// supervisor process, enabling zero-downtime restarts.
85    pub async fn start_with_listener(
86        self,
87        listener: tokio::net::TcpListener,
88    ) -> Result<MitmProxyHandle, MitmError> {
89        self.prepare_ca_material().await?;
90        let runtime_bundle = build_runtime_server(
91            &self.config,
92            Arc::clone(&self.handler),
93            Arc::clone(&self.metrics_store),
94        )?;
95        let runtime_config = runtime_bundle.config_handle.clone();
96        let runtime_governor = runtime_bundle.server.runtime_observability_handle();
97        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
98
99        let join_handle = tokio::spawn(async move {
100            runtime_bundle
101                .server
102                .run_until_shutdown_with_listener(listener, shutdown_rx)
103                .await
104                .map_err(MitmError::from)
105        });
106        Ok(MitmProxyHandle {
107            join_handle: Arc::new(Mutex::new(Some(join_handle))),
108            metrics_store: Arc::clone(&self.metrics_store),
109            runtime_config,
110            runtime_governor,
111            shutdown_tx,
112        })
113    }
114
115    async fn prepare_ca_material(&self) -> Result<(), MitmError> {
116        let ca = self.ca.clone();
117        let cert_path = self.config.tls.ca_cert_path.clone();
118        let key_path = self.config.tls.ca_key_path.clone();
119
120        if ca.is_none() {
121            return Ok(());
122        }
123        let ca = ca.unwrap();
124
125        tokio::task::spawn_blocking(move || -> Result<(), MitmError> {
126            if let Some(parent) = cert_path.parent() {
127                if !parent.as_os_str().is_empty() {
128                    fs::create_dir_all(parent)?;
129                }
130            }
131            if let Some(parent) = key_path.parent() {
132                if !parent.as_os_str().is_empty() {
133                    fs::create_dir_all(parent)?;
134                }
135            }
136            fs::write(&cert_path, &ca.cert_pem)?;
137            write_private_key_file(&key_path, &ca.key_pem)?;
138            Ok(())
139        })
140        .await
141        .map_err(MitmError::Join)?
142    }
143}
144
145/// Handle to a running proxy, providing shutdown, config reload, and metrics access.
146pub struct MitmProxyHandle {
147    join_handle: Arc<Mutex<Option<JoinHandle<Result<(), MitmError>>>>>,
148    metrics_store: Arc<ProxyMetricsStore>,
149    runtime_config: RuntimeConfigHandle,
150    runtime_governor: Arc<RuntimeGovernor>,
151    shutdown_tx: tokio::sync::watch::Sender<bool>,
152}
153
154impl MitmProxyHandle {
155    /// Hot-reloads the proxy configuration. Currently only interception scope
156    /// changes are supported; other field changes will return an error.
157    pub async fn reload(&self, next_config: MitmConfig) -> Result<(), MitmError> {
158        self.runtime_config.apply_reload(&next_config)?;
159        Ok(())
160    }
161
162    /// Returns a snapshot of the currently active configuration.
163    pub async fn current_config(&self) -> MitmConfig {
164        self.runtime_config.current_config()
165    }
166
167    /// Gracefully shuts down the proxy, draining active flows before the deadline.
168    pub async fn shutdown(self, timeout: Duration) -> Result<(), MitmError> {
169        let mut guard = self.join_handle.lock().await;
170        let Some(handle) = guard.take() else {
171            return Ok(());
172        };
173        drop(guard);
174        let mut handle = handle;
175
176        let _ = self.shutdown_tx.send(true);
177        let deadline = tokio::time::Instant::now() + timeout;
178        let drained =
179            wait_for_active_flows_to_drain(Arc::clone(&self.runtime_governor), deadline).await;
180        if !drained {
181            handle.abort();
182            let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
183            return Err(shutdown_timeout_error());
184        }
185
186        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
187        match tokio::time::timeout(remaining, &mut handle).await {
188            Ok(join_result) => match join_result {
189                Ok(result) => result,
190                Err(error) if error.is_cancelled() => Ok(()),
191                Err(error) => Err(MitmError::Join(error)),
192            },
193            Err(_) => {
194                handle.abort();
195                let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
196                Err(shutdown_timeout_error())
197            }
198        }
199    }
200
201    pub fn metrics(&self) -> ProxyMetrics {
202        self.metrics_store.snapshot()
203    }
204}
205
206fn shutdown_timeout_error() -> MitmError {
207    MitmError::Io(std::io::Error::new(
208        std::io::ErrorKind::TimedOut,
209        "timed out waiting for proxy shutdown",
210    ))
211}
212
213async fn wait_for_active_flows_to_drain(
214    runtime_governor: Arc<RuntimeGovernor>,
215    deadline: tokio::time::Instant,
216) -> bool {
217    loop {
218        if runtime_governor.snapshot().active_flows == 0 {
219            return true;
220        }
221
222        let now = tokio::time::Instant::now();
223        if now >= deadline {
224            return false;
225        }
226
227        let sleep_for = (deadline - now).min(Duration::from_millis(25));
228        tokio::time::sleep(sleep_for).await;
229    }
230}
231
/// Writes `key_pem` to `path`, restricting access to the file owner where supported.
///
/// On Unix, a new file is created with mode `0o600`. An existing file is truncated on
/// open and re-permissioned to `0o600` *before* any key bytes are written, so the secret
/// never lands in a file with broader permissions. On other platforms this is a plain
/// [`fs::write`] with no permission changes.
///
/// # Errors
///
/// Propagates any I/O error from opening, re-permissioning, or writing the file.
fn write_private_key_file(path: &Path, key_pem: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        // `mode` only takes effect when the file is newly created. A pre-existing file
        // keeps its old permissions until they are changed explicitly below.
        let mut file = fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .mode(0o600)
            .open(path)?;
        // Tighten permissions on the open handle (not by path) and before writing, so
        // there is no window in which the key is readable by others and no chance of
        // re-permissioning a different file if `path` is swapped after opening.
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(key_pem)?;
        file.flush()?;
        Ok(())
    }

    #[cfg(not(unix))]
    {
        fs::write(path, key_pem)
    }
}
255
#[cfg(test)]
mod tests {
    use std::fs;
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;
    use std::sync::Arc;
    use std::time::Duration;

    use crate::server::{RuntimeBudgetConfig, RuntimeGovernor};
    use tokio::sync::Mutex;

    use super::{write_private_key_file, MitmProxyHandle};
    use crate::config::MitmConfig;
    use crate::errors::MitmError;
    use crate::metrics::ProxyMetricsStore;
    use crate::runtime::RuntimeConfigHandle;

    /// Join handle type matching the one stored inside `MitmProxyHandle`.
    type RuntimeTask = tokio::task::JoinHandle<Result<(), MitmError>>;

    /// A governor with default budgets.
    fn default_governor() -> Arc<RuntimeGovernor> {
        Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()))
    }

    /// Stand-in for the runtime task: finishes with `Ok(())` once the shutdown
    /// channel changes (or its sender is dropped).
    fn spawn_fake_runtime(mut shutdown_rx: tokio::sync::watch::Receiver<bool>) -> RuntimeTask {
        tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        })
    }

    /// Assembles a `MitmProxyHandle` directly from test doubles, bypassing `MitmProxy`.
    fn build_handle(
        runtime_governor: Arc<RuntimeGovernor>,
        shutdown_tx: tokio::sync::watch::Sender<bool>,
        join_handle: Option<RuntimeTask>,
    ) -> MitmProxyHandle {
        let mut config = MitmConfig::default();
        let destination = "api.example.com:443".to_string();
        config.interception.destinations.push(destination);
        let runtime_config =
            RuntimeConfigHandle::from_config(&config).expect("runtime config handle must build");

        MitmProxyHandle {
            join_handle: Arc::new(Mutex::new(join_handle)),
            metrics_store: Arc::new(ProxyMetricsStore::default()),
            runtime_config,
            runtime_governor,
            shutdown_tx,
        }
    }

    #[tokio::test]
    async fn shutdown_noop_when_handle_already_consumed() {
        let (shutdown_tx, _) = tokio::sync::watch::channel(false);
        let proxy_handle = build_handle(default_governor(), shutdown_tx, None);

        proxy_handle
            .shutdown(Duration::from_millis(10))
            .await
            .expect("shutdown should be a no-op when handle is empty");
    }

    #[tokio::test]
    async fn shutdown_drains_active_flows_before_joining_runtime() {
        let governor = default_governor();
        let permit = Arc::clone(&governor)
            .try_acquire_flow_permit()
            .expect("flow permit");
        let in_flight = governor.begin_flow(permit);

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let runtime_task = spawn_fake_runtime(shutdown_rx);
        // Release the in-flight flow ~60 ms from now, while shutdown is draining.
        let release_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(60)).await;
            drop(in_flight);
        });

        let proxy_handle = build_handle(Arc::clone(&governor), shutdown_tx, Some(runtime_task));
        let clock = std::time::Instant::now();
        proxy_handle
            .shutdown(Duration::from_millis(250))
            .await
            .expect("shutdown should wait for active flow to drain");

        let waited = clock.elapsed();
        assert!(
            waited >= Duration::from_millis(55),
            "shutdown must wait for in-flight flow drain window"
        );
        release_task.await.expect("guard task should complete");
    }

    #[tokio::test]
    async fn shutdown_returns_timeout_when_active_flows_do_not_drain() {
        let governor = default_governor();
        let permit = Arc::clone(&governor)
            .try_acquire_flow_permit()
            .expect("flow permit");
        // Held for the whole test so the flow never drains.
        let _flow_guard = governor.begin_flow(permit);

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let runtime_task = spawn_fake_runtime(shutdown_rx);
        let proxy_handle = build_handle(governor, shutdown_tx, Some(runtime_task));

        let failure = proxy_handle
            .shutdown(Duration::from_millis(5))
            .await
            .expect_err("active flow not draining must force timeout");

        match failure {
            MitmError::Io(io_error) => {
                assert_eq!(io_error.kind(), std::io::ErrorKind::TimedOut);
                let message = io_error.to_string();
                assert!(message.contains("timed out waiting for proxy shutdown"));
            }
            other => panic!("expected timeout IO error, got {other}"),
        }
    }

    #[cfg(unix)]
    #[test]
    fn private_key_permissions_are_owner_only_on_unix() {
        let scratch_dir =
            std::env::temp_dir().join(format!("soth-mitm-key-perm-{}", uuid::Uuid::new_v4()));
        fs::create_dir_all(&scratch_dir).expect("temp dir");
        let key_file = scratch_dir.join("ca-key.pem");

        write_private_key_file(&key_file, b"key-material").expect("write private key");
        let permission_bits = fs::metadata(&key_file)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(
            permission_bits, 0o600,
            "private key file must be owner-readable only"
        );

        let _ = fs::remove_file(&key_file);
        let _ = fs::remove_dir(&scratch_dir);
    }
}