Skip to main content

slinger_mitm/
interceptor.rs

1//! Traffic interception and modification interfaces
2
3use crate::error::Result;
4use bytes::Bytes;
5use slinger::{Body, Request, Response};
6use std::fmt;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{SystemTime, UNIX_EPOCH};
10use tokio::time::timeout;
11use uuid::Uuid;
12
13/// Generate a new unique session ID using UUID v4
14fn generate_session_id() -> u128 {
15  Uuid::new_v4().as_u128()
16}
17
18/// MITM Request wrapper that wraps slinger::Request with connection metadata.
19/// Used for both HTTP and non-HTTP (raw TCP) traffic interception.
20#[derive(Clone)]
21pub struct MitmRequest {
22  /// Unique session ID to correlate this request with its response (UUID v4 as u128)
23  session_id: u128,
24  /// Source address and port (client)
25  pub source: Option<SocketAddr>,
26  /// Destination address (host:port)
27  pub destination: String,
28  /// Timestamp when the request was intercepted
29  pub timestamp: u64,
30  /// Whether this is an HTTP request (true) or raw TCP (false)
31  is_http: bool,
32  /// The underlying request (contains body for both HTTP and raw TCP)
33  pub request: Request,
34}
35
36impl MitmRequest {
37  /// Create a new MITM request wrapper for HTTP traffic
38  pub fn new(destination: impl Into<String>, request: Request) -> Self {
39    Self {
40      session_id: generate_session_id(),
41      source: None,
42      destination: destination.into(),
43      timestamp: SystemTime::now()
44        .duration_since(UNIX_EPOCH)
45        .map(|d| d.as_millis() as u64)
46        .unwrap_or(0),
47      is_http: true,
48      request,
49    }
50  }
51
52  /// Create a new MITM request with source address for HTTP traffic
53  pub fn with_source(source: SocketAddr, destination: impl Into<String>, request: Request) -> Self {
54    Self {
55      session_id: generate_session_id(),
56      source: Some(source),
57      destination: destination.into(),
58      timestamp: SystemTime::now()
59        .duration_since(UNIX_EPOCH)
60        .map(|d| d.as_millis() as u64)
61        .unwrap_or(0),
62      is_http: true,
63      request,
64    }
65  }
66
67  /// Create a MITM request for raw TCP data (non-HTTP)
68  pub fn raw_tcp(destination: impl Into<String>, body: impl Into<Bytes>) -> Self {
69    let request = Request {
70      body: Some(Body::from(body.into())),
71      ..Default::default()
72    };
73    Self {
74      session_id: generate_session_id(),
75      source: None,
76      destination: destination.into(),
77      timestamp: SystemTime::now()
78        .duration_since(UNIX_EPOCH)
79        .map(|d| d.as_millis() as u64)
80        .unwrap_or(0),
81      is_http: false,
82      request,
83    }
84  }
85
86  /// Create a MITM request for raw TCP data with source address
87  pub fn raw_tcp_with_source(
88    source: SocketAddr,
89    destination: impl Into<String>,
90    body: impl Into<Bytes>,
91  ) -> Self {
92    let request = Request {
93      body: Some(Body::from(body.into())),
94      ..Default::default()
95    };
96    Self {
97      session_id: generate_session_id(),
98      source: Some(source),
99      destination: destination.into(),
100      timestamp: SystemTime::now()
101        .duration_since(UNIX_EPOCH)
102        .map(|d| d.as_millis() as u64)
103        .unwrap_or(0),
104      is_http: false,
105      request,
106    }
107  }
108
109  /// Get the session ID (used to correlate request with response)
110  pub fn session_id(&self) -> u128 {
111    self.session_id
112  }
113
114  /// Set the session ID (used to override auto-generated session_id for TCP connections)
115  pub fn set_session_id(&mut self, session_id: u128) {
116    self.session_id = session_id;
117  }
118
119  /// Get the source address
120  pub fn source(&self) -> Option<SocketAddr> {
121    self.source
122  }
123
124  /// Get the destination address
125  pub fn destination(&self) -> &str {
126    &self.destination
127  }
128
129  /// Get the timestamp
130  pub fn timestamp(&self) -> u64 {
131    self.timestamp
132  }
133
134  /// Get the underlying request
135  pub fn request(&self) -> &Request {
136    &self.request
137  }
138
139  /// Get a mutable reference to the underlying request
140  pub fn request_mut(&mut self) -> &mut Request {
141    &mut self.request
142  }
143
144  /// Get the body as bytes (for raw TCP traffic)
145  pub fn body(&self) -> Option<&Body> {
146    self.request.body.as_ref()
147  }
148
149  /// Set the body (for raw TCP traffic)
150  pub fn set_body(&mut self, body: impl Into<Bytes>) {
151    self.request.body = Some(Body::from(body.into()));
152  }
153
154  /// Check if this is an HTTP request (true) or raw TCP (false)
155  pub fn is_http(&self) -> bool {
156    self.is_http
157  }
158}
159
160impl fmt::Debug for MitmRequest {
161  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162    f.debug_struct("MitmRequest")
163      .field("session_id", &self.session_id)
164      .field("source", &self.source)
165      .field("destination", &self.destination)
166      .field("timestamp", &self.timestamp)
167      .field("is_http", &self.is_http())
168      .field("request", &self.request)
169      .finish()
170  }
171}
172
173/// MITM Response wrapper that wraps slinger::Response with connection metadata.
174/// Used for both HTTP and non-HTTP (raw TCP) traffic interception.
175#[derive(Clone)]
176pub struct MitmResponse {
177  /// Unique session ID to correlate this response with its request (UUID v4 as u128)
178  session_id: u128,
179  /// Source address (where the response came from, host:port)
180  pub source: String,
181  /// Destination address and port (client)
182  pub destination: Option<SocketAddr>,
183  /// Timestamp when the response was intercepted
184  pub timestamp: u64,
185  /// Whether this is an HTTP response (true) or raw TCP (false)
186  is_http: bool,
187  /// The underlying response (contains body for both HTTP and raw TCP)
188  pub response: Response,
189}
190
191impl MitmResponse {
192  /// Create a new MITM response wrapper for HTTP traffic
193  /// The session_id should match the corresponding MitmRequest's session_id
194  pub fn new(session_id: u128, source: impl Into<String>, response: Response) -> Self {
195    Self {
196      session_id,
197      source: source.into(),
198      destination: None,
199      timestamp: SystemTime::now()
200        .duration_since(UNIX_EPOCH)
201        .map(|d| d.as_millis() as u64)
202        .unwrap_or(0),
203      is_http: true,
204      response,
205    }
206  }
207
208  /// Create a new MITM response with destination address for HTTP traffic
209  /// The session_id should match the corresponding MitmRequest's session_id
210  pub fn with_destination(
211    session_id: u128,
212    source: impl Into<String>,
213    destination: SocketAddr,
214    response: Response,
215  ) -> Self {
216    Self {
217      session_id,
218      source: source.into(),
219      destination: Some(destination),
220      timestamp: SystemTime::now()
221        .duration_since(UNIX_EPOCH)
222        .map(|d| d.as_millis() as u64)
223        .unwrap_or(0),
224      is_http: true,
225      response,
226    }
227  }
228
229  /// Create a MITM response for raw TCP data (non-HTTP)
230  /// The session_id should match the corresponding MitmRequest's session_id
231  pub fn raw_tcp(session_id: u128, source: impl Into<String>, body: impl Into<Bytes>) -> Self {
232    let response = Response {
233      body: Some(Body::from(body.into())),
234      ..Default::default()
235    };
236    Self {
237      session_id,
238      source: source.into(),
239      destination: None,
240      timestamp: SystemTime::now()
241        .duration_since(UNIX_EPOCH)
242        .map(|d| d.as_millis() as u64)
243        .unwrap_or(0),
244      is_http: false,
245      response,
246    }
247  }
248
249  /// Create a MITM response for raw TCP data with destination address
250  /// The session_id should match the corresponding MitmRequest's session_id
251  pub fn raw_tcp_with_destination(
252    session_id: u128,
253    source: impl Into<String>,
254    destination: SocketAddr,
255    body: impl Into<Bytes>,
256  ) -> Self {
257    let response = Response {
258      body: Some(Body::from(body.into())),
259      ..Default::default()
260    };
261    Self {
262      session_id,
263      source: source.into(),
264      destination: Some(destination),
265      timestamp: SystemTime::now()
266        .duration_since(UNIX_EPOCH)
267        .map(|d| d.as_millis() as u64)
268        .unwrap_or(0),
269      is_http: false,
270      response,
271    }
272  }
273
274  /// Get the session ID (used to correlate response with request)
275  pub fn session_id(&self) -> u128 {
276    self.session_id
277  }
278
279  /// Get the source address
280  pub fn source(&self) -> &str {
281    &self.source
282  }
283
284  /// Get the destination address
285  pub fn destination(&self) -> Option<SocketAddr> {
286    self.destination
287  }
288
289  /// Get the timestamp
290  pub fn timestamp(&self) -> u64 {
291    self.timestamp
292  }
293
294  /// Get the underlying response
295  pub fn response(&self) -> &Response {
296    &self.response
297  }
298
299  /// Get a mutable reference to the underlying response
300  pub fn response_mut(&mut self) -> &mut Response {
301    &mut self.response
302  }
303
304  /// Get the body as bytes (for raw TCP traffic)
305  pub fn body(&self) -> Option<&Body> {
306    self.response.body.as_ref()
307  }
308
309  /// Set the body (for raw TCP traffic)
310  pub fn set_body(&mut self, body: impl Into<Bytes>) {
311    self.response.body = Some(Body::from(body.into()));
312  }
313
314  /// Check if this is an HTTP response (true) or raw TCP (false)
315  pub fn is_http(&self) -> bool {
316    self.is_http
317  }
318}
319
320impl fmt::Debug for MitmResponse {
321  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322    f.debug_struct("MitmResponse")
323      .field("session_id", &self.session_id)
324      .field("source", &self.source)
325      .field("destination", &self.destination)
326      .field("timestamp", &self.timestamp)
327      .field("is_http", &self.is_http())
328      .field("response", &self.response)
329      .finish()
330  }
331}
332/// Unified trait for intercepting both requests and responses with automatic correlation
333/// This trait is recommended over separate RequestInterceptor and ResponseInterceptor
334/// as it provides automatic session correlation between requests and responses.
335#[async_trait::async_trait]
336pub trait Interceptor: Send + Sync {
337  /// Intercept and optionally modify a request
338  ///
339  /// Return `None` to block the request, or return a modified request
340  async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
341    // Default implementation passes through
342    Ok(Some(request))
343  }
344
345  /// Intercept and optionally modify a response
346  /// The response is automatically correlated with its request via session_id
347  ///
348  /// Return `None` to block the response, or return a modified response
349  async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
350    // Default implementation passes through
351    Ok(Some(response))
352  }
353}
354
355/// Combined interceptor handler for both HTTP and TCP traffic
356/// Manages automatic correlation between requests and responses via session IDs
357pub struct InterceptorHandler {
358  interceptors: Vec<Arc<dyn Interceptor>>,
359  /// Per-interceptor timeout in seconds
360  timeout_secs: u64,
361}
362
363impl InterceptorHandler {
364  /// Create a new interceptor handler
365  pub fn new() -> Self {
366    Self {
367      interceptors: Vec::new(),
368      timeout_secs: 60,
369    }
370  }
371
372  /// Create a new interceptor handler with a configurable per-interceptor timeout
373  pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
374    self.timeout_secs = timeout_secs;
375    self
376  }
377
378  /// Add a unified interceptor that handles both requests and responses
379  /// This is the recommended way to add interceptors as it provides automatic
380  /// session correlation between requests and responses
381  pub fn add_interceptor(&mut self, interceptor: Arc<dyn Interceptor>) {
382    self.interceptors.push(interceptor);
383  }
384
385  /// Process a request through all interceptors
386  pub async fn process_request(&self, mut request: MitmRequest) -> Result<Option<MitmRequest>> {
387    // Process through unified interceptors first
388    for interceptor in &self.interceptors {
389      // Clone the request so a timed-out interceptor doesn't consume ownership
390      // of the in-band request. Dashed interceptor results will replace `request`.
391      let request_clone = request.clone();
392      match timeout(
393        std::time::Duration::from_secs(self.timeout_secs),
394        interceptor.intercept_request(request_clone),
395      )
396      .await
397      {
398        // Interceptor completed within timeout
399        Ok(Ok(Some(modified))) => request = modified,
400        Ok(Ok(None)) => return Ok(None), // Request blocked by interceptor
401        Ok(Err(e)) => return Err(e),     // Interceptor returned an error
402        // Timeout -> skip this interceptor but continue processing
403        Err(_) => {
404          tracing::warn!(
405            "Interceptor timed out after {}s; skipping",
406            self.timeout_secs
407          );
408          continue;
409        }
410      }
411    }
412    Ok(Some(request))
413  }
414
415  /// Process a response through all interceptors
416  pub async fn process_response(&self, mut response: MitmResponse) -> Result<Option<MitmResponse>> {
417    // Process through unified interceptors first
418    for interceptor in &self.interceptors {
419      // Clone before invoking so timeouts don't consume the in-band response
420      let response_clone = response.clone();
421      match timeout(
422        std::time::Duration::from_secs(self.timeout_secs),
423        interceptor.intercept_response(response_clone),
424      )
425      .await
426      {
427        Ok(Ok(Some(modified))) => response = modified,
428        Ok(Ok(None)) => return Ok(None), // Response blocked
429        Ok(Err(e)) => return Err(e),     // Interceptor error
430        Err(_) => {
431          tracing::warn!(
432            "Interceptor timed out after {}s; skipping",
433            self.timeout_secs
434          );
435          continue;
436        }
437      }
438    }
439    Ok(Some(response))
440  }
441
442  /// Check if any interceptors are registered
443  pub fn has_interceptors(&self) -> bool {
444    !self.interceptors.is_empty()
445  }
446}
447
448impl Default for InterceptorHandler {
449  fn default() -> Self {
450    Self::new()
451  }
452}
453
454/// Factory for creating pre-built interceptors
455pub struct InterceptorFactory;
456
457impl InterceptorFactory {
458  /// Create a logging interceptor that prints requests/responses
459  pub fn logging() -> LoggingInterceptor {
460    LoggingInterceptor
461  }
462}
463
/// Pass-through [`Interceptor`] that logs each request and response via `tracing`
/// at info level, for both HTTP and raw TCP traffic.
///
/// It is stateless, so it derives `Copy` and `Default` and can be created freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingInterceptor;
466
467// Unified Interceptor trait implementation (recommended)
468#[async_trait::async_trait]
469impl Interceptor for LoggingInterceptor {
470  async fn intercept_request(&self, request: MitmRequest) -> Result<Option<MitmRequest>> {
471    if request.is_http() {
472      tracing::info!(
473        "[MITM] HTTP Request (session_id={}): {} {}",
474        request.session_id(),
475        request.request().method(),
476        request.request().uri()
477      );
478      for (name, value) in request.request().headers() {
479        tracing::info!("  {}: {:?}", name, value);
480      }
481    } else {
482      tracing::info!(
483        "[MITM] TCP Request (session_id={}) to {}: {} bytes",
484        request.session_id(),
485        request.destination(),
486        request.body().map(|b| b.len()).unwrap_or(0)
487      );
488    }
489    if let Some(source) = request.source() {
490      tracing::info!("  From: {}", source);
491    }
492    tracing::info!("  Timestamp: {}", request.timestamp());
493    Ok(Some(request))
494  }
495
496  async fn intercept_response(&self, response: MitmResponse) -> Result<Option<MitmResponse>> {
497    if response.is_http() {
498      tracing::info!(
499        "[MITM] HTTP Response (session_id={}): {}",
500        response.session_id(),
501        response.response().status_code()
502      );
503      for (name, value) in response.response().headers() {
504        tracing::info!("  {}: {:?}", name, value);
505      }
506    } else {
507      tracing::info!(
508        "[MITM] TCP Response (session_id={}) from {}: {} bytes",
509        response.session_id(),
510        response.source(),
511        response.body().map(|b| b.len()).unwrap_or(0)
512      );
513    }
514    if let Some(destination) = response.destination() {
515      tracing::info!("  To: {}", destination);
516    }
517    tracing::info!("  Timestamp: {}", response.timestamp());
518    Ok(Some(response))
519  }
520}