1use std::fs;
2#[cfg(unix)]
3use std::io::Write;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::server::RuntimeGovernor;
9use tokio::sync::Mutex;
10use tokio::task::JoinHandle;
11
12use crate::ca::CertificateAuthority;
13use crate::config::MitmConfig;
14use crate::errors::MitmError;
15use crate::handler::InterceptHandler;
16use crate::metrics::{ProxyMetrics, ProxyMetricsStore};
17use crate::runtime::{build_runtime_server, RuntimeConfigHandle};
18
/// Unstarted MITM proxy: configuration, interception handler, optional CA and
/// metrics store, bundled until one of `run`, `start` or `start_with_listener`
/// consumes it.
pub struct MitmProxy<H: InterceptHandler> {
    /// Full proxy configuration; `tls.ca_cert_path` / `tls.ca_key_path` are
    /// where CA material is written before startup.
    config: MitmConfig,
    /// Interception handler, shared with the runtime server via `Arc`.
    handler: Arc<H>,
    /// CA to persist before startup; `None` skips writing any CA files.
    ca: Option<CertificateAuthority>,
    /// Metrics store shared between the runtime server and the returned handle.
    metrics_store: Arc<ProxyMetricsStore>,
}
28
29impl<H: InterceptHandler> MitmProxy<H> {
30 pub(crate) fn new(
31 config: MitmConfig,
32 handler: Arc<H>,
33 ca: Option<CertificateAuthority>,
34 metrics_store: Arc<ProxyMetricsStore>,
35 ) -> Self {
36 Self {
37 config,
38 handler,
39 ca,
40 metrics_store,
41 }
42 }
43
44 pub async fn run(self) -> Result<(), MitmError> {
46 self.prepare_ca_material().await?;
47 let runtime_bundle = build_runtime_server(
48 &self.config,
49 Arc::clone(&self.handler),
50 Arc::clone(&self.metrics_store),
51 )?;
52 runtime_bundle.server.run().await.map_err(MitmError::from)
53 }
54
55 pub async fn start(self) -> Result<MitmProxyHandle, MitmError> {
57 self.prepare_ca_material().await?;
58 let runtime_bundle = build_runtime_server(
59 &self.config,
60 Arc::clone(&self.handler),
61 Arc::clone(&self.metrics_store),
62 )?;
63 let runtime_config = runtime_bundle.config_handle.clone();
64 let runtime_governor = runtime_bundle.server.runtime_observability_handle();
65 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
66
67 let join_handle = tokio::spawn(async move {
68 runtime_bundle
69 .server
70 .run_until_shutdown(shutdown_rx)
71 .await
72 .map_err(MitmError::from)
73 });
74 Ok(MitmProxyHandle {
75 join_handle: Arc::new(Mutex::new(Some(join_handle))),
76 metrics_store: Arc::clone(&self.metrics_store),
77 runtime_config,
78 runtime_governor,
79 shutdown_tx,
80 })
81 }
82
83 pub async fn start_with_listener(
86 self,
87 listener: tokio::net::TcpListener,
88 ) -> Result<MitmProxyHandle, MitmError> {
89 self.prepare_ca_material().await?;
90 let runtime_bundle = build_runtime_server(
91 &self.config,
92 Arc::clone(&self.handler),
93 Arc::clone(&self.metrics_store),
94 )?;
95 let runtime_config = runtime_bundle.config_handle.clone();
96 let runtime_governor = runtime_bundle.server.runtime_observability_handle();
97 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
98
99 let join_handle = tokio::spawn(async move {
100 runtime_bundle
101 .server
102 .run_until_shutdown_with_listener(listener, shutdown_rx)
103 .await
104 .map_err(MitmError::from)
105 });
106 Ok(MitmProxyHandle {
107 join_handle: Arc::new(Mutex::new(Some(join_handle))),
108 metrics_store: Arc::clone(&self.metrics_store),
109 runtime_config,
110 runtime_governor,
111 shutdown_tx,
112 })
113 }
114
115 async fn prepare_ca_material(&self) -> Result<(), MitmError> {
116 let ca = self.ca.clone();
117 let cert_path = self.config.tls.ca_cert_path.clone();
118 let key_path = self.config.tls.ca_key_path.clone();
119
120 if ca.is_none() {
121 return Ok(());
122 }
123 let ca = ca.unwrap();
124
125 tokio::task::spawn_blocking(move || -> Result<(), MitmError> {
126 if let Some(parent) = cert_path.parent() {
127 if !parent.as_os_str().is_empty() {
128 fs::create_dir_all(parent)?;
129 }
130 }
131 if let Some(parent) = key_path.parent() {
132 if !parent.as_os_str().is_empty() {
133 fs::create_dir_all(parent)?;
134 }
135 }
136 fs::write(&cert_path, &ca.cert_pem)?;
137 write_private_key_file(&key_path, &ca.key_pem)?;
138 Ok(())
139 })
140 .await
141 .map_err(MitmError::Join)?
142 }
143}
144
/// Control handle for a proxy started with `MitmProxy::start` or
/// `MitmProxy::start_with_listener`.
///
/// Supports live config reload, metrics snapshots and graceful shutdown.
pub struct MitmProxyHandle {
    /// Spawned runtime task. `shutdown` takes it out, so a handle whose slot is
    /// already `None` treats shutdown as a no-op.
    join_handle: Arc<Mutex<Option<JoinHandle<Result<(), MitmError>>>>>,
    /// Metrics store shared with the runtime server; read by `metrics`.
    metrics_store: Arc<ProxyMetricsStore>,
    /// Live configuration handle used by `reload` and `current_config`.
    runtime_config: RuntimeConfigHandle,
    /// Runtime governor whose `active_flows` count gates graceful shutdown.
    runtime_governor: Arc<RuntimeGovernor>,
    /// Sending `true` asks the runtime task to stop.
    shutdown_tx: tokio::sync::watch::Sender<bool>,
}
153
154impl MitmProxyHandle {
155 pub async fn reload(&self, next_config: MitmConfig) -> Result<(), MitmError> {
158 self.runtime_config.apply_reload(&next_config)?;
159 Ok(())
160 }
161
162 pub async fn current_config(&self) -> MitmConfig {
164 self.runtime_config.current_config()
165 }
166
167 pub async fn shutdown(self, timeout: Duration) -> Result<(), MitmError> {
169 let mut guard = self.join_handle.lock().await;
170 let Some(handle) = guard.take() else {
171 return Ok(());
172 };
173 drop(guard);
174 let mut handle = handle;
175
176 let _ = self.shutdown_tx.send(true);
177 let deadline = tokio::time::Instant::now() + timeout;
178 let drained =
179 wait_for_active_flows_to_drain(Arc::clone(&self.runtime_governor), deadline).await;
180 if !drained {
181 handle.abort();
182 let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
183 return Err(shutdown_timeout_error());
184 }
185
186 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
187 match tokio::time::timeout(remaining, &mut handle).await {
188 Ok(join_result) => match join_result {
189 Ok(result) => result,
190 Err(error) if error.is_cancelled() => Ok(()),
191 Err(error) => Err(MitmError::Join(error)),
192 },
193 Err(_) => {
194 handle.abort();
195 let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
196 Err(shutdown_timeout_error())
197 }
198 }
199 }
200
201 pub fn metrics(&self) -> ProxyMetrics {
202 self.metrics_store.snapshot()
203 }
204}
205
206fn shutdown_timeout_error() -> MitmError {
207 MitmError::Io(std::io::Error::new(
208 std::io::ErrorKind::TimedOut,
209 "timed out waiting for proxy shutdown",
210 ))
211}
212
213async fn wait_for_active_flows_to_drain(
214 runtime_governor: Arc<RuntimeGovernor>,
215 deadline: tokio::time::Instant,
216) -> bool {
217 loop {
218 if runtime_governor.snapshot().active_flows == 0 {
219 return true;
220 }
221
222 let now = tokio::time::Instant::now();
223 if now >= deadline {
224 return false;
225 }
226
227 let sleep_for = (deadline - now).min(Duration::from_millis(25));
228 tokio::time::sleep(sleep_for).await;
229 }
230}
231
/// Writes `key_pem` to `path`, replacing any existing contents.
///
/// On Unix the file ends up with mode `0o600` (owner read/write only). The mode
/// is applied to the open handle *before* any key bytes are written, so an
/// existing file with looser permissions is tightened before it holds the new
/// key. Processes that already had the file open keep their descriptors.
///
/// On other platforms this is a plain [`fs::write`] with no permission
/// restriction.
///
/// # Errors
/// Propagates any I/O error from opening, re-permissioning or writing the file.
fn write_private_key_file(path: &Path, key_pem: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        // `mode` only affects newly created files; an existing file keeps its
        // old mode through `open`, so tighten it explicitly before writing.
        let mut file = fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .mode(0o600)
            .open(path)?;
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(key_pem)?;
        file.flush()?;
        Ok(())
    }

    #[cfg(not(unix))]
    {
        fs::write(path, key_pem)
    }
}
255
#[cfg(test)]
mod tests {
    use std::fs;
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;
    use std::sync::Arc;
    use std::time::Duration;

    use crate::server::{RuntimeBudgetConfig, RuntimeGovernor};
    use tokio::sync::Mutex;

    use super::{write_private_key_file, MitmProxyHandle};
    use crate::config::MitmConfig;
    use crate::errors::MitmError;
    use crate::metrics::ProxyMetricsStore;
    use crate::runtime::RuntimeConfigHandle;

    /// Builds a `MitmProxyHandle` directly from its parts, so shutdown logic can
    /// be exercised without starting a real runtime server.
    ///
    /// `join_handle` stands in for the spawned runtime task. `None` simulates a
    /// handle whose task was already taken.
    fn build_handle(
        runtime_governor: Arc<RuntimeGovernor>,
        shutdown_tx: tokio::sync::watch::Sender<bool>,
        join_handle: Option<tokio::task::JoinHandle<Result<(), MitmError>>>,
    ) -> MitmProxyHandle {
        let mut config = MitmConfig::default();
        // NOTE(review): a destination is added presumably because
        // `RuntimeConfigHandle::from_config` rejects an empty list — confirm.
        config
            .interception
            .destinations
            .push("api.example.com:443".to_string());
        MitmProxyHandle {
            join_handle: Arc::new(Mutex::new(join_handle)),
            metrics_store: Arc::new(ProxyMetricsStore::default()),
            runtime_config: RuntimeConfigHandle::from_config(&config)
                .expect("runtime config handle must build"),
            runtime_governor,
            shutdown_tx,
        }
    }

    /// With no runtime task in the handle, shutdown returns `Ok(())` immediately.
    #[tokio::test]
    async fn shutdown_noop_when_handle_already_consumed() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let (shutdown_tx, _) = tokio::sync::watch::channel(false);
        let handle = build_handle(runtime_governor, shutdown_tx, None);
        handle
            .shutdown(Duration::from_millis(10))
            .await
            .expect("shutdown should be a no-op when handle is empty");
    }

    /// Shutdown must not complete until the in-flight flow is released.
    #[tokio::test]
    async fn shutdown_drains_active_flows_before_joining_runtime() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        // Hold one active flow for the start of the test.
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        let flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        // Stand-in runtime task: finishes once the shutdown signal changes.
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        // Release the flow after ~60 ms, well within the 250 ms timeout below.
        let guard_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(60)).await;
            drop(flow_guard);
        });

        let handle = build_handle(
            Arc::clone(&runtime_governor),
            shutdown_tx,
            Some(join_handle),
        );
        let started = std::time::Instant::now();
        handle
            .shutdown(Duration::from_millis(250))
            .await
            .expect("shutdown should wait for active flow to drain");
        // 55 ms (not 60) presumably leaves slack for timer granularity.
        assert!(
            started.elapsed() >= Duration::from_millis(55),
            "shutdown must wait for in-flight flow drain window"
        );
        guard_task.await.expect("guard task should complete");
    }

    /// A flow held for the whole test cannot drain within 5 ms, so shutdown
    /// must report the `TimedOut` I/O error.
    #[tokio::test]
    async fn shutdown_returns_timeout_when_active_flows_do_not_drain() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        // Bound to a named variable so the flow stays active until the test ends.
        let _flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        let handle = build_handle(runtime_governor, shutdown_tx, Some(join_handle));
        let error = handle
            .shutdown(Duration::from_millis(5))
            .await
            .expect_err("active flow not draining must force timeout");
        match error {
            MitmError::Io(io_error) => {
                assert_eq!(io_error.kind(), std::io::ErrorKind::TimedOut);
                assert!(io_error
                    .to_string()
                    .contains("timed out waiting for proxy shutdown"));
            }
            other => panic!("expected timeout IO error, got {other}"),
        }
    }

    /// The private key file must be created with owner-only (0o600) permissions.
    #[cfg(unix)]
    #[test]
    fn private_key_permissions_are_owner_only_on_unix() {
        // A unique directory per run avoids collisions between concurrent runs.
        let temp_dir =
            std::env::temp_dir().join(format!("soth-mitm-key-perm-{}", uuid::Uuid::new_v4()));
        fs::create_dir_all(&temp_dir).expect("temp dir");
        let key_path = temp_dir.join("ca-key.pem");

        write_private_key_file(&key_path, b"key-material").expect("write private key");
        let mode = fs::metadata(&key_path)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "private key file must be owner-readable only");

        // Best-effort cleanup; these calls are skipped if an assertion above fails.
        let _ = fs::remove_file(&key_path);
        let _ = fs::remove_dir(&temp_dir);
    }
}