1use crate::error::Result;
4use bytes::Bytes;
5use slinger::{Body, Request, Response};
6use std::fmt;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::time::timeout;
11use uuid::Uuid;
12
13fn generate_session_id() -> u128 {
15 Uuid::new_v4().as_u128()
16}
17
18#[derive(Clone)]
21pub struct MitmRequest {
22 session_id: u128,
24 pub source: Option<SocketAddr>,
26 pub destination: String,
28 pub timestamp: u64,
30 is_http: bool,
32 pub request: Request,
34}
35
36impl MitmRequest {
37 pub fn new(destination: impl Into<String>, request: Request) -> Self {
39 Self {
40 session_id: generate_session_id(),
41 source: None,
42 destination: destination.into(),
43 timestamp: SystemTime::now()
44 .duration_since(UNIX_EPOCH)
45 .map(|d| d.as_millis() as u64)
46 .unwrap_or(0),
47 is_http: true,
48 request,
49 }
50 }
51
52 pub fn with_source(source: SocketAddr, destination: impl Into<String>, request: Request) -> Self {
54 Self {
55 session_id: generate_session_id(),
56 source: Some(source),
57 destination: destination.into(),
58 timestamp: SystemTime::now()
59 .duration_since(UNIX_EPOCH)
60 .map(|d| d.as_millis() as u64)
61 .unwrap_or(0),
62 is_http: true,
63 request,
64 }
65 }
66
67 pub fn raw_tcp(destination: impl Into<String>, body: impl Into<Bytes>) -> Self {
69 let request = Request {
70 body: Some(Body::from(body.into())),
71 ..Default::default()
72 };
73 Self {
74 session_id: generate_session_id(),
75 source: None,
76 destination: destination.into(),
77 timestamp: SystemTime::now()
78 .duration_since(UNIX_EPOCH)
79 .map(|d| d.as_millis() as u64)
80 .unwrap_or(0),
81 is_http: false,
82 request,
83 }
84 }
85
86 pub fn raw_tcp_with_source(
88 source: SocketAddr,
89 destination: impl Into<String>,
90 body: impl Into<Bytes>,
91 ) -> Self {
92 let request = Request {
93 body: Some(Body::from(body.into())),
94 ..Default::default()
95 };
96 Self {
97 session_id: generate_session_id(),
98 source: Some(source),
99 destination: destination.into(),
100 timestamp: SystemTime::now()
101 .duration_since(UNIX_EPOCH)
102 .map(|d| d.as_millis() as u64)
103 .unwrap_or(0),
104 is_http: false,
105 request,
106 }
107 }
108
109 pub fn session_id(&self) -> u128 {
111 self.session_id
112 }
113
114 pub fn set_session_id(&mut self, session_id: u128) {
116 self.session_id = session_id;
117 }
118
119 pub fn source(&self) -> Option<SocketAddr> {
121 self.source
122 }
123
124 pub fn destination(&self) -> &str {
126 &self.destination
127 }
128
129 pub fn timestamp(&self) -> u64 {
131 self.timestamp
132 }
133
134 pub fn request(&self) -> &Request {
136 &self.request
137 }
138
139 pub fn request_mut(&mut self) -> &mut Request {
141 &mut self.request
142 }
143
144 pub fn body(&self) -> Option<&Body> {
146 self.request.body.as_ref()
147 }
148
149 pub fn set_body(&mut self, body: impl Into<Bytes>) {
151 self.request.body = Some(Body::from(body.into()));
152 }
153
154 pub fn is_http(&self) -> bool {
156 self.is_http
157 }
158}
159
160impl fmt::Debug for MitmRequest {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 f.debug_struct("MitmRequest")
163 .field("session_id", &self.session_id)
164 .field("source", &self.source)
165 .field("destination", &self.destination)
166 .field("timestamp", &self.timestamp)
167 .field("is_http", &self.is_http())
168 .field("request", &self.request)
169 .finish()
170 }
171}
172
173#[derive(Clone)]
176pub struct MitmResponse {
177 session_id: u128,
179 pub source: String,
181 pub destination: Option<SocketAddr>,
183 pub timestamp: u64,
185 is_http: bool,
187 pub response: Response,
189}
190
191impl MitmResponse {
192 pub fn new(session_id: u128, source: impl Into<String>, response: Response) -> Self {
195 Self {
196 session_id,
197 source: source.into(),
198 destination: None,
199 timestamp: SystemTime::now()
200 .duration_since(UNIX_EPOCH)
201 .map(|d| d.as_millis() as u64)
202 .unwrap_or(0),
203 is_http: true,
204 response,
205 }
206 }
207
208 pub fn with_destination(
211 session_id: u128,
212 source: impl Into<String>,
213 destination: SocketAddr,
214 response: Response,
215 ) -> Self {
216 Self {
217 session_id,
218 source: source.into(),
219 destination: Some(destination),
220 timestamp: SystemTime::now()
221 .duration_since(UNIX_EPOCH)
222 .map(|d| d.as_millis() as u64)
223 .unwrap_or(0),
224 is_http: true,
225 response,
226 }
227 }
228
229 pub fn raw_tcp(session_id: u128, source: impl Into<String>, body: impl Into<Bytes>) -> Self {
232 let response = Response {
233 body: Some(Body::from(body.into())),
234 ..Default::default()
235 };
236 Self {
237 session_id,
238 source: source.into(),
239 destination: None,
240 timestamp: SystemTime::now()
241 .duration_since(UNIX_EPOCH)
242 .map(|d| d.as_millis() as u64)
243 .unwrap_or(0),
244 is_http: false,
245 response,
246 }
247 }
248
249 pub fn raw_tcp_with_destination(
252 session_id: u128,
253 source: impl Into<String>,
254 destination: SocketAddr,
255 body: impl Into<Bytes>,
256 ) -> Self {
257 let response = Response {
258 body: Some(Body::from(body.into())),
259 ..Default::default()
260 };
261 Self {
262 session_id,
263 source: source.into(),
264 destination: Some(destination),
265 timestamp: SystemTime::now()
266 .duration_since(UNIX_EPOCH)
267 .map(|d| d.as_millis() as u64)
268 .unwrap_or(0),
269 is_http: false,
270 response,
271 }
272 }
273
274 pub fn session_id(&self) -> u128 {
276 self.session_id
277 }
278
279 pub fn source(&self) -> &str {
281 &self.source
282 }
283
284 pub fn destination(&self) -> Option<SocketAddr> {
286 self.destination
287 }
288
289 pub fn timestamp(&self) -> u64 {
291 self.timestamp
292 }
293
294 pub fn response(&self) -> &Response {
296 &self.response
297 }
298
299 pub fn response_mut(&mut self) -> &mut Response {
301 &mut self.response
302 }
303
304 pub fn body(&self) -> Option<&Body> {
306 self.response.body.as_ref()
307 }
308
309 pub fn set_body(&mut self, body: impl Into<Bytes>) {
311 self.response.body = Some(Body::from(body.into()));
312 }
313
314 pub fn is_http(&self) -> bool {
316 self.is_http
317 }
318}
319
320impl fmt::Debug for MitmResponse {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 f.debug_struct("MitmResponse")
323 .field("session_id", &self.session_id)
324 .field("source", &self.source)
325 .field("destination", &self.destination)
326 .field("timestamp", &self.timestamp)
327 .field("is_http", &self.is_http())
328 .field("response", &self.response)
329 .finish()
330 }
331}
332#[async_trait::async_trait]
336pub trait Interceptor: Send + Sync {
337 async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
341 Ok(Some(request))
343 }
344
345 async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
350 Ok(Some(response))
352 }
353}
354
355pub struct InterceptorHandler {
358 interceptors: Vec<Arc<dyn Interceptor>>,
359 timeout_secs: u64,
361}
362
363impl InterceptorHandler {
364 pub fn new() -> Self {
366 Self {
367 interceptors: Vec::new(),
368 timeout_secs: 60,
369 }
370 }
371
372 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
374 self.timeout_secs = timeout_secs;
375 self
376 }
377
378 pub fn add_interceptor(&mut self, interceptor: Arc<dyn Interceptor>) {
382 self.interceptors.push(interceptor);
383 }
384
385 pub async fn process_request(&self, mut request: MitmRequest) -> Result<Option<MitmRequest>> {
387 for interceptor in &self.interceptors {
389 let request_clone = request.clone();
392 match timeout(
393 std::time::Duration::from_secs(self.timeout_secs),
394 interceptor.intercept_request(request_clone),
395 )
396 .await
397 {
398 Ok(Ok(Some(modified))) => request = modified,
400 Ok(Ok(None)) => return Ok(None), Ok(Err(e)) => return Err(e), Err(_) => {
404 tracing::warn!(
405 "Interceptor timed out after {}s; skipping",
406 self.timeout_secs
407 );
408 continue;
409 }
410 }
411 }
412 Ok(Some(request))
413 }
414
415 pub async fn process_response(&self, mut response: MitmResponse) -> Result<Option<MitmResponse>> {
417 for interceptor in &self.interceptors {
419 let response_clone = response.clone();
421 match timeout(
422 std::time::Duration::from_secs(self.timeout_secs),
423 interceptor.intercept_response(response_clone),
424 )
425 .await
426 {
427 Ok(Ok(Some(modified))) => response = modified,
428 Ok(Ok(None)) => return Ok(None), Ok(Err(e)) => return Err(e), Err(_) => {
431 tracing::warn!(
432 "Interceptor timed out after {}s; skipping",
433 self.timeout_secs
434 );
435 continue;
436 }
437 }
438 }
439 Ok(Some(response))
440 }
441
442 pub fn has_interceptors(&self) -> bool {
444 !self.interceptors.is_empty()
445 }
446}
447
448impl Default for InterceptorHandler {
449 fn default() -> Self {
450 Self::new()
451 }
452}
453
454pub struct InterceptorFactory;
456
457impl InterceptorFactory {
458 pub fn logging() -> LoggingInterceptor {
460 LoggingInterceptor
461 }
462}
463
464pub struct LoggingInterceptor;
466
467#[async_trait::async_trait]
469impl Interceptor for LoggingInterceptor {
470 async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
471 if request.is_http() {
472 tracing::info!(
473 "[MITM] HTTP Request (session_id={}): {} {}",
474 request.session_id(),
475 request.request().method(),
476 request.request().uri()
477 );
478 for (name, value) in request.request().headers() {
479 tracing::info!(" {}: {:?}", name, value);
480 }
481 } else {
482 tracing::info!(
483 "[MITM] TCP Request (session_id={}) to {}: {} bytes",
484 request.session_id(),
485 request.destination(),
486 request.body().map(|b| b.len()).unwrap_or(0)
487 );
488 }
489 if let Some(source) = request.source() {
490 tracing::info!(" From: {}", source);
491 }
492 tracing::info!(" Timestamp: {}", request.timestamp());
493 Ok(Some(request))
494 }
495
496 async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
497 if response.is_http() {
498 tracing::info!(
499 "[MITM] HTTP Response (session_id={}): {}",
500 response.session_id(),
501 response.response().status_code()
502 );
503 for (name, value) in response.response().headers() {
504 tracing::info!(" {}: {:?}", name, value);
505 }
506 } else {
507 tracing::info!(
508 "[MITM] TCP Response (session_id={}) from {}: {} bytes",
509 response.session_id(),
510 response.source(),
511 response.body().map(|b| b.len()).unwrap_or(0)
512 );
513 }
514 if let Some(destination) = response.destination() {
515 tracing::info!(" To: {}", destination);
516 }
517 tracing::info!(" Timestamp: {}", response.timestamp());
518 Ok(Some(response))
519 }
520}