Skip to main content

soth_mitm/
proxy.rs

1use std::fs;
2#[cfg(unix)]
3use std::io::Write;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::server::RuntimeGovernor;
9use tokio::sync::Mutex;
10use tokio::task::JoinHandle;
11
12use crate::ca::CertificateAuthority;
13use crate::config::MitmConfig;
14use crate::errors::MitmError;
15use crate::handler::InterceptHandler;
16use crate::metrics::{ProxyMetrics, ProxyMetricsStore};
17use crate::runtime::{build_runtime_server, RuntimeConfigHandle};
18
/// The intercepting proxy instance, ready to run or start as a background task.
///
/// Constructed via [`MitmProxyBuilder`](crate::MitmProxyBuilder). Use [`MitmProxy::run`]
/// to drive the proxy on the current task, or [`MitmProxy::start`] to spawn it and obtain
/// a [`MitmProxyHandle`].
pub struct MitmProxy<H: InterceptHandler> {
    /// Proxy configuration; also supplies the CA certificate/key output paths (`tls`).
    config: MitmConfig,
    /// User-supplied interception handler, shared with the runtime server.
    handler: Arc<H>,
    /// CA whose PEM certificate and key are written to disk before the server is built;
    /// `None` skips that step.
    ca: Option<CertificateAuthority>,
    /// Metrics store shared between the runtime server and the handle returned by `start`.
    metrics_store: Arc<ProxyMetricsStore>,
}
28
29impl<H: InterceptHandler> MitmProxy<H> {
30    pub(crate) fn new(
31        config: MitmConfig,
32        handler: Arc<H>,
33        ca: Option<CertificateAuthority>,
34        metrics_store: Arc<ProxyMetricsStore>,
35    ) -> Self {
36        Self {
37            config,
38            handler,
39            ca,
40            metrics_store,
41        }
42    }
43
44    /// Runs the proxy, blocking until it shuts down or encounters a fatal error.
45    pub async fn run(self) -> Result<(), MitmError> {
46        self.prepare_ca_material().await?;
47        let runtime_bundle = build_runtime_server(
48            &self.config,
49            Arc::clone(&self.handler),
50            Arc::clone(&self.metrics_store),
51        )?;
52        runtime_bundle.server.run().await.map_err(MitmError::from)
53    }
54
55    /// Starts the proxy as a background task and returns a handle for shutdown and metrics.
56    pub async fn start(self) -> Result<MitmProxyHandle, MitmError> {
57        self.prepare_ca_material().await?;
58        let runtime_bundle = build_runtime_server(
59            &self.config,
60            Arc::clone(&self.handler),
61            Arc::clone(&self.metrics_store),
62        )?;
63        let runtime_config = runtime_bundle.config_handle.clone();
64        let runtime_governor = runtime_bundle.server.runtime_observability_handle();
65        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
66
67        let join_handle = tokio::spawn(async move {
68            runtime_bundle
69                .server
70                .run_until_shutdown(shutdown_rx)
71                .await
72                .map_err(MitmError::from)
73        });
74        Ok(MitmProxyHandle {
75            join_handle: Arc::new(Mutex::new(Some(join_handle))),
76            metrics_store: Arc::clone(&self.metrics_store),
77            runtime_config,
78            runtime_governor,
79            shutdown_tx,
80        })
81    }
82
83    async fn prepare_ca_material(&self) -> Result<(), MitmError> {
84        let ca = self.ca.clone();
85        let cert_path = self.config.tls.ca_cert_path.clone();
86        let key_path = self.config.tls.ca_key_path.clone();
87
88        if ca.is_none() {
89            return Ok(());
90        }
91        let ca = ca.unwrap();
92
93        tokio::task::spawn_blocking(move || -> Result<(), MitmError> {
94            if let Some(parent) = cert_path.parent() {
95                if !parent.as_os_str().is_empty() {
96                    fs::create_dir_all(parent)?;
97                }
98            }
99            if let Some(parent) = key_path.parent() {
100                if !parent.as_os_str().is_empty() {
101                    fs::create_dir_all(parent)?;
102                }
103            }
104            fs::write(&cert_path, &ca.cert_pem)?;
105            write_private_key_file(&key_path, &ca.key_pem)?;
106            Ok(())
107        })
108        .await
109        .map_err(MitmError::Join)?
110    }
111}
112
/// Handle to a running proxy, providing shutdown, config reload, and metrics access.
///
/// Returned by [`MitmProxy::start`]; [`MitmProxyHandle::shutdown`] consumes it.
pub struct MitmProxyHandle {
    /// Join handle of the spawned runtime task; `shutdown` takes it out (leaving `None`).
    join_handle: Arc<Mutex<Option<JoinHandle<Result<(), MitmError>>>>>,
    /// Metrics store shared with the runtime server; snapshotted by `metrics()`.
    metrics_store: Arc<ProxyMetricsStore>,
    /// Live configuration handle used for hot reloads and config snapshots.
    runtime_config: RuntimeConfigHandle,
    /// Runtime governor whose `active_flows` count `shutdown` polls while draining.
    runtime_governor: Arc<RuntimeGovernor>,
    /// Shutdown signal observed by the runtime's `run_until_shutdown`; `shutdown` sends `true`.
    shutdown_tx: tokio::sync::watch::Sender<bool>,
}
121
122impl MitmProxyHandle {
123    /// Hot-reloads the proxy configuration. Currently only interception scope
124    /// changes are supported; other field changes will return an error.
125    pub async fn reload(&self, next_config: MitmConfig) -> Result<(), MitmError> {
126        self.runtime_config.apply_reload(&next_config)?;
127        Ok(())
128    }
129
130    /// Returns a snapshot of the currently active configuration.
131    pub async fn current_config(&self) -> MitmConfig {
132        self.runtime_config.current_config()
133    }
134
135    /// Gracefully shuts down the proxy, draining active flows before the deadline.
136    pub async fn shutdown(self, timeout: Duration) -> Result<(), MitmError> {
137        let mut guard = self.join_handle.lock().await;
138        let Some(handle) = guard.take() else {
139            return Ok(());
140        };
141        drop(guard);
142        let mut handle = handle;
143
144        let _ = self.shutdown_tx.send(true);
145        let deadline = tokio::time::Instant::now() + timeout;
146        let drained =
147            wait_for_active_flows_to_drain(Arc::clone(&self.runtime_governor), deadline).await;
148        if !drained {
149            handle.abort();
150            let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
151            return Err(shutdown_timeout_error());
152        }
153
154        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
155        match tokio::time::timeout(remaining, &mut handle).await {
156            Ok(join_result) => match join_result {
157                Ok(result) => result,
158                Err(error) if error.is_cancelled() => Ok(()),
159                Err(error) => Err(MitmError::Join(error)),
160            },
161            Err(_) => {
162                handle.abort();
163                let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
164                Err(shutdown_timeout_error())
165            }
166        }
167    }
168
169    pub fn metrics(&self) -> ProxyMetrics {
170        self.metrics_store.snapshot()
171    }
172}
173
174fn shutdown_timeout_error() -> MitmError {
175    MitmError::Io(std::io::Error::new(
176        std::io::ErrorKind::TimedOut,
177        "timed out waiting for proxy shutdown",
178    ))
179}
180
181async fn wait_for_active_flows_to_drain(
182    runtime_governor: Arc<RuntimeGovernor>,
183    deadline: tokio::time::Instant,
184) -> bool {
185    loop {
186        if runtime_governor.snapshot().active_flows == 0 {
187            return true;
188        }
189
190        let now = tokio::time::Instant::now();
191        if now >= deadline {
192            return false;
193        }
194
195        let sleep_for = (deadline - now).min(Duration::from_millis(25));
196        tokio::time::sleep(sleep_for).await;
197    }
198}
199
/// Writes PEM-encoded private key bytes to `path`, creating or truncating the file.
///
/// On Unix the file is opened with mode `0o600`, and its permissions are then forced to
/// `0o600` through the open handle *before* any key bytes are written. `mode(0o600)` only
/// takes effect when the file is newly created. Without the early chmod, a pre-existing
/// file with looser permissions (e.g. `0o644`) would expose the new key. Chmod-ing the
/// handle also avoids re-resolving `path`, so a concurrent rename or symlink swap cannot
/// redirect it.
///
/// On non-Unix targets the key is written with [`fs::write`] and permissions are left
/// unchanged.
///
/// # Errors
///
/// Returns any I/O error from opening, chmod-ing, or writing the file.
fn write_private_key_file(path: &Path, key_pem: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .mode(0o600)
            .open(path)?;
        // Tighten permissions on the open handle before any key material lands on disk.
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(key_pem)?;
        file.flush()?;
        Ok(())
    }

    #[cfg(not(unix))]
    {
        fs::write(path, key_pem)
    }
}
223
#[cfg(test)]
mod tests {
    use std::fs;
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;
    use std::sync::Arc;
    use std::time::Duration;

    use crate::server::{RuntimeBudgetConfig, RuntimeGovernor};
    use tokio::sync::Mutex;

    use super::{write_private_key_file, MitmProxyHandle};
    use crate::config::MitmConfig;
    use crate::errors::MitmError;
    use crate::metrics::ProxyMetricsStore;
    use crate::runtime::RuntimeConfigHandle;

    /// Builds a `MitmProxyHandle` directly from test-controlled parts, bypassing
    /// `MitmProxy::start`, so shutdown behavior can be exercised in isolation.
    fn build_handle(
        runtime_governor: Arc<RuntimeGovernor>,
        shutdown_tx: tokio::sync::watch::Sender<bool>,
        join_handle: Option<tokio::task::JoinHandle<Result<(), MitmError>>>,
    ) -> MitmProxyHandle {
        let mut config = MitmConfig::default();
        config
            .interception
            .destinations
            .push("api.example.com:443".to_string());
        MitmProxyHandle {
            join_handle: Arc::new(Mutex::new(join_handle)),
            metrics_store: Arc::new(ProxyMetricsStore::default()),
            runtime_config: RuntimeConfigHandle::from_config(&config)
                .expect("runtime config handle must build"),
            runtime_governor,
            shutdown_tx,
        }
    }

    #[tokio::test]
    async fn shutdown_noop_when_handle_already_consumed() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let (shutdown_tx, _) = tokio::sync::watch::channel(false);
        let handle = build_handle(runtime_governor, shutdown_tx, None);
        handle
            .shutdown(Duration::from_millis(10))
            .await
            .expect("shutdown should be a no-op when handle is empty");
    }

    #[tokio::test]
    async fn shutdown_drains_active_flows_before_joining_runtime() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        let flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        let guard_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(60)).await;
            drop(flow_guard);
        });

        let handle = build_handle(
            Arc::clone(&runtime_governor),
            shutdown_tx,
            Some(join_handle),
        );
        let started = std::time::Instant::now();
        handle
            .shutdown(Duration::from_millis(250))
            .await
            .expect("shutdown should wait for active flow to drain");
        assert!(
            started.elapsed() >= Duration::from_millis(55),
            "shutdown must wait for in-flight flow drain window"
        );
        guard_task.await.expect("guard task should complete");
    }

    #[tokio::test]
    async fn shutdown_returns_timeout_when_active_flows_do_not_drain() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        let _flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        let handle = build_handle(runtime_governor, shutdown_tx, Some(join_handle));
        let error = handle
            .shutdown(Duration::from_millis(5))
            .await
            .expect_err("active flow not draining must force timeout");
        match error {
            MitmError::Io(io_error) => {
                assert_eq!(io_error.kind(), std::io::ErrorKind::TimedOut);
                assert!(io_error
                    .to_string()
                    .contains("timed out waiting for proxy shutdown"));
            }
            other => panic!("expected timeout IO error, got {other}"),
        }
    }

    /// A runtime task that fails after receiving the shutdown signal must have its own
    /// error returned by `shutdown` rather than swallowed.
    #[tokio::test]
    async fn shutdown_propagates_runtime_task_error() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Err::<(), MitmError>(MitmError::Io(std::io::Error::new(
                std::io::ErrorKind::ConnectionReset,
                "runtime failed during shutdown",
            )))
        });
        let handle = build_handle(runtime_governor, shutdown_tx, Some(join_handle));
        let error = handle
            .shutdown(Duration::from_millis(250))
            .await
            .expect_err("runtime task error must be surfaced by shutdown");
        match error {
            MitmError::Io(io_error) => {
                assert_eq!(io_error.kind(), std::io::ErrorKind::ConnectionReset);
                assert!(io_error
                    .to_string()
                    .contains("runtime failed during shutdown"));
            }
            other => panic!("expected runtime IO error, got {other}"),
        }
    }

    #[cfg(unix)]
    #[test]
    fn private_key_permissions_are_owner_only_on_unix() {
        let temp_dir =
            std::env::temp_dir().join(format!("soth-mitm-key-perm-{}", uuid::Uuid::new_v4()));
        fs::create_dir_all(&temp_dir).expect("temp dir");
        let key_path = temp_dir.join("ca-key.pem");

        write_private_key_file(&key_path, b"key-material").expect("write private key");
        let mode = fs::metadata(&key_path)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "private key file must be owner-readable only");

        let _ = fs::remove_file(&key_path);
        let _ = fs::remove_dir(&temp_dir);
    }

    /// Rewriting over an existing, world-readable key file must leave it owner-only and
    /// fully replace (truncate) the stale contents.
    #[cfg(unix)]
    #[test]
    fn private_key_rewrite_tightens_existing_file_permissions() {
        let temp_dir = std::env::temp_dir()
            .join(format!("soth-mitm-key-rewrite-{}", uuid::Uuid::new_v4()));
        fs::create_dir_all(&temp_dir).expect("temp dir");
        let key_path = temp_dir.join("ca-key.pem");

        fs::write(&key_path, b"stale-key-material-longer-than-new").expect("seed key file");
        fs::set_permissions(&key_path, fs::Permissions::from_mode(0o644))
            .expect("loosen permissions");

        write_private_key_file(&key_path, b"fresh-key").expect("rewrite private key");
        let mode = fs::metadata(&key_path)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "rewritten private key must be owner-readable only");
        assert_eq!(
            fs::read(&key_path).expect("read key"),
            b"fresh-key".to_vec(),
            "rewrite must fully replace stale key bytes"
        );

        let _ = fs::remove_file(&key_path);
        let _ = fs::remove_dir(&temp_dir);
    }
}