Skip to main content

soth_mitm/
types.rs

1use std::net::IpAddr;
2use std::net::{SocketAddrV4, SocketAddrV6};
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::SystemTime;
6
7use bytes::Bytes;
8use http::HeaderMap;
9use uuid::Uuid;
10
11/// Newtype wrapping a `u64` flow identifier for type-safe flow tracking.
12///
13/// # Examples
14///
15/// ```
16/// use soth_mitm::FlowId;
17///
18/// let id = FlowId(42);
19/// assert_eq!(id.as_u64(), 42);
20/// assert_eq!(format!("{id}"), "42");
21/// ```
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize)]
23pub struct FlowId(pub u64);
24
25impl FlowId {
26    /// Returns the inner `u64` value.
27    pub fn as_u64(self) -> u64 {
28        self.0
29    }
30}
31
32impl std::fmt::Display for FlowId {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "{}", self.0)
35    }
36}
37
/// TLS protocol version.
///
/// Variants are declared oldest-first, so the derived `Ord` orders them by
/// protocol version (`Tls12 < Tls13`). Keep that declaration order if new
/// variants are added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}

impl TlsVersion {
    /// Returns the lowercase label for this version (`"tls1.2"` or `"tls1.3"`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Tls12 => "tls1.2",
            Self::Tls13 => "tls1.3",
        }
    }
}

impl std::fmt::Display for TlsVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment flags; `write_str` ignores them.
        f.pad(self.as_str())
    }
}
59
60/// An intercepted HTTP request passed to the handler.
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct RawRequest {
63    pub method: String,
64    pub path: String,
65    pub headers: HeaderMap,
66    pub body: Bytes,
67    pub connection_meta: Arc<ConnectionMeta>,
68}
69
70/// An intercepted HTTP response passed to the handler.
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct RawResponse {
73    pub status: u16,
74    pub headers: HeaderMap,
75    pub body: Bytes,
76    pub connection_meta: Arc<ConnectionMeta>,
77}
78
/// Discriminant for streaming frame types delivered via [`StreamChunk`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameKind {
    /// A Server-Sent Events data frame.
    SseData,
    /// One newline-delimited JSON line.
    NdjsonLine,
    /// One gRPC message.
    GrpcMessage,
    /// A WebSocket text frame.
    WebSocketText,
    /// A WebSocket binary frame.
    WebSocketBinary,
    /// A WebSocket close frame.
    WebSocketClose,
}

impl FrameKind {
    /// Returns the snake_case label for this frame kind (e.g. `"sse_data"`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SseData => "sse_data",
            Self::NdjsonLine => "ndjson_line",
            Self::GrpcMessage => "grpc_message",
            Self::WebSocketText => "websocket_text",
            Self::WebSocketBinary => "websocket_binary",
            Self::WebSocketClose => "websocket_close",
        }
    }
}

impl std::fmt::Display for FrameKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment flags; `write_str` ignores them.
        f.pad(self.as_str())
    }
}
108
109/// A streaming data frame (SSE, NDJSON, gRPC, or WebSocket) delivered to the handler.
110#[derive(Debug, Clone, PartialEq, Eq)]
111pub struct StreamChunk {
112    pub connection_id: Uuid,
113    pub payload: Bytes,
114    pub sequence: u64,
115    pub frame_kind: FrameKind,
116}
117
/// TLS metadata for the downstream connection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TlsInfo {
    /// Server Name Indication value, if one was present.
    pub sni: Option<String>,
    /// Negotiated application protocol, if any. Presumably this is the ALPN
    /// result (e.g. `"h2"`); confirm.
    pub negotiated_proto: Option<String>,
}
124
125/// Metadata about the downstream connection (socket, TLS, process attribution).
126#[derive(Debug, Clone, PartialEq, Eq)]
127pub struct ConnectionMeta {
128    pub connection_id: Uuid,
129    pub socket_family: SocketFamily,
130    pub process_info: Option<ProcessInfo>,
131    pub tls_info: Option<TlsInfo>,
132}
133
/// Socket address family for the downstream connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SocketFamily {
    /// TCP over IPv4.
    TcpV4 {
        /// Local endpoint address.
        local: SocketAddrV4,
        /// Remote endpoint address.
        remote: SocketAddrV4,
    },
    /// TCP over IPv6.
    TcpV6 {
        /// Local endpoint address.
        local: SocketAddrV6,
        /// Remote endpoint address.
        remote: SocketAddrV6,
    },
    /// Unix-domain socket.
    UnixDomain {
        /// Socket path. Presumably `None` for unnamed sockets; confirm with
        /// the producer.
        path: Option<PathBuf>,
    },
}

impl SocketFamily {
    /// Returns a short label naming the family: `"tcp_v4"`, `"tcp_v6"`, or
    /// `"unix_domain"`. Addresses and paths are not included.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TcpV4 { .. } => "tcp_v4",
            Self::TcpV6 { .. } => "tcp_v6",
            Self::UnixDomain { .. } => "unix_domain",
        }
    }
}

impl std::fmt::Display for SocketFamily {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment flags; `write_str` ignores them.
        f.pad(self.as_str())
    }
}
165
166#[derive(Debug, Clone, PartialEq, Eq)]
167pub struct ConnectionInfo {
168    pub connection_id: Uuid,
169    pub source_ip: IpAddr,
170    pub source_port: u16,
171    pub destination_host: String,
172    pub destination_port: u16,
173    pub socket_family: SocketFamily,
174    pub tls_fingerprint: Option<TlsClientFingerprint>,
175    pub alpn_protocol: Option<String>,
176    pub is_http2: bool,
177    pub process_info: Option<ProcessInfo>,
178    pub connected_at: SystemTime,
179    pub request_count: u32,
180}
181
182#[derive(Debug, Clone, PartialEq, Eq)]
183pub struct TlsClientFingerprint {
184    pub ja4: String,
185    pub ja3: String,
186    pub tls_version: TlsVersion,
187    pub cipher_suites: Vec<u16>,
188    pub extensions: Vec<u16>,
189    pub elliptic_curves: Vec<u16>,
190}
191
/// Information about the local process that owns the downstream socket.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProcessInfo {
    /// Process ID.
    pub pid: u32,
    /// Application bundle identifier, when available. This is presumably a
    /// macOS bundle ID; confirm.
    pub bundle_id: Option<String>,
    /// Executable name, when available.
    pub exe_name: Option<String>,
    /// Full path to the executable, when available.
    pub exe_path: Option<PathBuf>,
    /// Parent process ID, when available.
    pub parent_pid: Option<u32>,
    /// Parent process name, when available.
    pub parent_process_name: Option<String>,
}