1use std::fs;
2#[cfg(unix)]
3use std::io::Write;
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8use crate::server::RuntimeGovernor;
9use tokio::sync::Mutex;
10use tokio::task::JoinHandle;
11
12use crate::ca::CertificateAuthority;
13use crate::config::MitmConfig;
14use crate::errors::MitmError;
15use crate::handler::InterceptHandler;
16use crate::metrics::{ProxyMetrics, ProxyMetricsStore};
17use crate::runtime::{build_runtime_server, RuntimeConfigHandle};
18
19pub struct MitmProxy<H: InterceptHandler> {
23 config: MitmConfig,
24 handler: Arc<H>,
25 ca: Option<CertificateAuthority>,
26 metrics_store: Arc<ProxyMetricsStore>,
27}
28
29impl<H: InterceptHandler> MitmProxy<H> {
30 pub(crate) fn new(
31 config: MitmConfig,
32 handler: Arc<H>,
33 ca: Option<CertificateAuthority>,
34 metrics_store: Arc<ProxyMetricsStore>,
35 ) -> Self {
36 Self {
37 config,
38 handler,
39 ca,
40 metrics_store,
41 }
42 }
43
44 pub async fn run(self) -> Result<(), MitmError> {
46 self.prepare_ca_material().await?;
47 let runtime_bundle = build_runtime_server(
48 &self.config,
49 Arc::clone(&self.handler),
50 Arc::clone(&self.metrics_store),
51 )?;
52 runtime_bundle.server.run().await.map_err(MitmError::from)
53 }
54
55 pub async fn start(self) -> Result<MitmProxyHandle, MitmError> {
57 self.prepare_ca_material().await?;
58 let runtime_bundle = build_runtime_server(
59 &self.config,
60 Arc::clone(&self.handler),
61 Arc::clone(&self.metrics_store),
62 )?;
63 let runtime_config = runtime_bundle.config_handle.clone();
64 let runtime_governor = runtime_bundle.server.runtime_observability_handle();
65 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
66
67 let join_handle = tokio::spawn(async move {
68 runtime_bundle
69 .server
70 .run_until_shutdown(shutdown_rx)
71 .await
72 .map_err(MitmError::from)
73 });
74 Ok(MitmProxyHandle {
75 join_handle: Arc::new(Mutex::new(Some(join_handle))),
76 metrics_store: Arc::clone(&self.metrics_store),
77 runtime_config,
78 runtime_governor,
79 shutdown_tx,
80 })
81 }
82
83 async fn prepare_ca_material(&self) -> Result<(), MitmError> {
84 let ca = self.ca.clone();
85 let cert_path = self.config.tls.ca_cert_path.clone();
86 let key_path = self.config.tls.ca_key_path.clone();
87
88 if ca.is_none() {
89 return Ok(());
90 }
91 let ca = ca.unwrap();
92
93 tokio::task::spawn_blocking(move || -> Result<(), MitmError> {
94 if let Some(parent) = cert_path.parent() {
95 if !parent.as_os_str().is_empty() {
96 fs::create_dir_all(parent)?;
97 }
98 }
99 if let Some(parent) = key_path.parent() {
100 if !parent.as_os_str().is_empty() {
101 fs::create_dir_all(parent)?;
102 }
103 }
104 fs::write(&cert_path, &ca.cert_pem)?;
105 write_private_key_file(&key_path, &ca.key_pem)?;
106 Ok(())
107 })
108 .await
109 .map_err(MitmError::Join)?
110 }
111}
112
/// Control handle for a proxy launched with `MitmProxy::start`.
///
/// It exposes configuration reloads, metrics snapshots and graceful shutdown. Note that field
/// declaration order is also drop order (e.g. `join_handle` is dropped before `shutdown_tx`),
/// so do not reorder fields casually.
pub struct MitmProxyHandle {
    // Background server task; `shutdown` takes it, leaving `None`.
    join_handle: Arc<Mutex<Option<JoinHandle<Result<(), MitmError>>>>>,
    // Metrics store shared with the running server; read by `metrics`.
    metrics_store: Arc<ProxyMetricsStore>,
    // Applies reloads and reports the configuration currently in effect.
    runtime_config: RuntimeConfigHandle,
    // Polled during `shutdown` for the number of active flows.
    runtime_governor: Arc<RuntimeGovernor>,
    // Sending `true` asks the server task (via `run_until_shutdown`) to stop.
    shutdown_tx: tokio::sync::watch::Sender<bool>,
}
121
122impl MitmProxyHandle {
123 pub async fn reload(&self, next_config: MitmConfig) -> Result<(), MitmError> {
126 self.runtime_config.apply_reload(&next_config)?;
127 Ok(())
128 }
129
130 pub async fn current_config(&self) -> MitmConfig {
132 self.runtime_config.current_config()
133 }
134
135 pub async fn shutdown(self, timeout: Duration) -> Result<(), MitmError> {
137 let mut guard = self.join_handle.lock().await;
138 let Some(handle) = guard.take() else {
139 return Ok(());
140 };
141 drop(guard);
142 let mut handle = handle;
143
144 let _ = self.shutdown_tx.send(true);
145 let deadline = tokio::time::Instant::now() + timeout;
146 let drained =
147 wait_for_active_flows_to_drain(Arc::clone(&self.runtime_governor), deadline).await;
148 if !drained {
149 handle.abort();
150 let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
151 return Err(shutdown_timeout_error());
152 }
153
154 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
155 match tokio::time::timeout(remaining, &mut handle).await {
156 Ok(join_result) => match join_result {
157 Ok(result) => result,
158 Err(error) if error.is_cancelled() => Ok(()),
159 Err(error) => Err(MitmError::Join(error)),
160 },
161 Err(_) => {
162 handle.abort();
163 let _ = tokio::time::timeout(Duration::from_millis(100), &mut handle).await;
164 Err(shutdown_timeout_error())
165 }
166 }
167 }
168
169 pub fn metrics(&self) -> ProxyMetrics {
170 self.metrics_store.snapshot()
171 }
172}
173
174fn shutdown_timeout_error() -> MitmError {
175 MitmError::Io(std::io::Error::new(
176 std::io::ErrorKind::TimedOut,
177 "timed out waiting for proxy shutdown",
178 ))
179}
180
181async fn wait_for_active_flows_to_drain(
182 runtime_governor: Arc<RuntimeGovernor>,
183 deadline: tokio::time::Instant,
184) -> bool {
185 loop {
186 if runtime_governor.snapshot().active_flows == 0 {
187 return true;
188 }
189
190 let now = tokio::time::Instant::now();
191 if now >= deadline {
192 return false;
193 }
194
195 let sleep_for = (deadline - now).min(Duration::from_millis(25));
196 tokio::time::sleep(sleep_for).await;
197 }
198}
199
/// Writes `key_pem` to `path`, restricting access to the file owner on Unix.
///
/// On Unix the file is opened with mode `0o600`. Its permissions are then forced to `0o600`
/// through the open descriptor *before* any key bytes are written. This way a pre-existing
/// file with looser permissions never holds the new key while it is still group/world-readable.
/// Descriptors that other processes opened beforehand are not affected by this.
/// On other platforms the file is written with default permissions.
///
/// # Errors
///
/// Propagates any I/O error from opening, changing permissions on, or writing the file.
fn write_private_key_file(path: &Path, key_pem: &[u8]) -> std::io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = fs::OpenOptions::new()
            .create(true)
            .truncate(true)
            .write(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only applies when the file is newly created; tighten an existing file
        // before the secret lands in it rather than afterwards.
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(key_pem)?;
        file.flush()?;
        Ok(())
    }

    #[cfg(not(unix))]
    {
        fs::write(path, key_pem)?;
        Ok(())
    }
}
223
#[cfg(test)]
mod tests {
    use std::fs;
    #[cfg(unix)]
    use std::os::unix::fs::PermissionsExt;
    use std::sync::Arc;
    use std::time::Duration;

    use crate::server::{RuntimeBudgetConfig, RuntimeGovernor};
    use tokio::sync::Mutex;

    use super::{write_private_key_file, MitmProxyHandle};
    use crate::config::MitmConfig;
    use crate::errors::MitmError;
    use crate::metrics::ProxyMetricsStore;
    use crate::runtime::RuntimeConfigHandle;

    /// Builds a `MitmProxyHandle` around the supplied governor, shutdown sender and task.
    ///
    /// This bypasses `MitmProxy::start` so shutdown behaviour can be exercised in isolation.
    fn build_handle(
        runtime_governor: Arc<RuntimeGovernor>,
        shutdown_tx: tokio::sync::watch::Sender<bool>,
        join_handle: Option<tokio::task::JoinHandle<Result<(), MitmError>>>,
    ) -> MitmProxyHandle {
        let mut config = MitmConfig::default();
        // NOTE(review): presumably `RuntimeConfigHandle::from_config` rejects an empty
        // destination list — confirm.
        config
            .interception
            .destinations
            .push("api.example.com:443".to_string());
        MitmProxyHandle {
            join_handle: Arc::new(Mutex::new(join_handle)),
            metrics_store: Arc::new(ProxyMetricsStore::default()),
            runtime_config: RuntimeConfigHandle::from_config(&config)
                .expect("runtime config handle must build"),
            runtime_governor,
            shutdown_tx,
        }
    }

    /// With no server task stored, `shutdown` returns `Ok(())` without doing anything.
    #[tokio::test]
    async fn shutdown_noop_when_handle_already_consumed() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let (shutdown_tx, _) = tokio::sync::watch::channel(false);
        let handle = build_handle(runtime_governor, shutdown_tx, None);
        handle
            .shutdown(Duration::from_millis(10))
            .await
            .expect("shutdown should be a no-op when handle is empty");
    }

    /// `shutdown` must not finish before an in-flight flow (held for ~60 ms) is released.
    #[tokio::test]
    async fn shutdown_drains_active_flows_before_joining_runtime() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        let flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        // Stand-in server task: exits as soon as the shutdown signal changes.
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        // Releases the active flow after 60 ms.
        let guard_task = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(60)).await;
            drop(flow_guard);
        });

        let handle = build_handle(
            Arc::clone(&runtime_governor),
            shutdown_tx,
            Some(join_handle),
        );
        let started = std::time::Instant::now();
        handle
            .shutdown(Duration::from_millis(250))
            .await
            .expect("shutdown should wait for active flow to drain");
        assert!(
            started.elapsed() >= Duration::from_millis(55),
            "shutdown must wait for in-flight flow drain window"
        );
        guard_task.await.expect("guard task should complete");
    }

    /// A flow that is never released forces `shutdown` to report a `TimedOut` I/O error.
    #[tokio::test]
    async fn shutdown_returns_timeout_when_active_flows_do_not_drain() {
        let runtime_governor = Arc::new(RuntimeGovernor::new(RuntimeBudgetConfig::default()));
        let permit = runtime_governor
            .clone()
            .try_acquire_flow_permit()
            .expect("flow permit");
        // Kept alive for the whole test so the flow never drains.
        let _flow_guard = runtime_governor.begin_flow(permit);
        let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false);
        let join_handle = tokio::spawn(async move {
            let _ = shutdown_rx.changed().await;
            Ok(())
        });
        let handle = build_handle(runtime_governor, shutdown_tx, Some(join_handle));
        let error = handle
            .shutdown(Duration::from_millis(5))
            .await
            .expect_err("active flow not draining must force timeout");
        match error {
            MitmError::Io(io_error) => {
                assert_eq!(io_error.kind(), std::io::ErrorKind::TimedOut);
                assert!(io_error
                    .to_string()
                    .contains("timed out waiting for proxy shutdown"));
            }
            other => panic!("expected timeout IO error, got {other}"),
        }
    }

    /// A freshly created key file ends up readable and writable by its owner only.
    #[cfg(unix)]
    #[test]
    fn private_key_permissions_are_owner_only_on_unix() {
        let temp_dir =
            std::env::temp_dir().join(format!("soth-mitm-key-perm-{}", uuid::Uuid::new_v4()));
        fs::create_dir_all(&temp_dir).expect("temp dir");
        let key_path = temp_dir.join("ca-key.pem");

        write_private_key_file(&key_path, b"key-material").expect("write private key");
        let mode = fs::metadata(&key_path)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "private key file must be owner-readable only");

        let _ = fs::remove_file(&key_path);
        let _ = fs::remove_dir(&temp_dir);
    }

    /// An existing, world-readable key file is tightened to `0o600` and its old contents truncated.
    #[cfg(unix)]
    #[test]
    fn private_key_permissions_are_tightened_for_existing_file() {
        let temp_dir = std::env::temp_dir().join(format!(
            "soth-mitm-key-perm-existing-{}",
            uuid::Uuid::new_v4()
        ));
        fs::create_dir_all(&temp_dir).expect("temp dir");
        let key_path = temp_dir.join("ca-key.pem");
        fs::write(&key_path, b"stale-key-material-that-is-longer").expect("seed key file");
        fs::set_permissions(&key_path, fs::Permissions::from_mode(0o644))
            .expect("loosen permissions");

        write_private_key_file(&key_path, b"key-material").expect("write private key");
        let mode = fs::metadata(&key_path)
            .expect("metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600, "pre-existing key file must be tightened to owner-only");
        assert_eq!(
            fs::read(&key_path).expect("read key"),
            b"key-material".to_vec(),
            "stale key contents must be truncated"
        );

        let _ = fs::remove_file(&key_path);
        let _ = fs::remove_dir(&temp_dir);
    }
}
355}