Skip to main content

soth_mitm/
types.rs

1use std::net::IpAddr;
2use std::net::{SocketAddrV4, SocketAddrV6};
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::SystemTime;
6
7use bytes::Bytes;
8use http::HeaderMap;
9use uuid::Uuid;
10
11/// Newtype wrapping a `u64` flow identifier for type-safe flow tracking.
12///
13/// # Examples
14///
15/// ```
16/// use soth_mitm::FlowId;
17///
18/// let id = FlowId(42);
19/// assert_eq!(id.as_u64(), 42);
20/// assert_eq!(format!("{id}"), "42");
21/// ```
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize)]
23pub struct FlowId(pub u64);
24
25impl FlowId {
26    /// Returns the inner `u64` value.
27    pub fn as_u64(self) -> u64 {
28        self.0
29    }
30}
31
32impl std::fmt::Display for FlowId {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "{}", self.0)
35    }
36}
37
/// TLS protocol version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}

impl TlsVersion {
    /// Returns the lowercase label for this version: `"tls1.2"` or `"tls1.3"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Tls12 => "tls1.2",
            Self::Tls13 => "tls1.3",
        }
    }
}

impl std::fmt::Display for TlsVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) applies the caller's width, fill, alignment
        // and precision, so `{:<8}` in aligned log/table output really pads.
        f.pad(self.as_str())
    }
}
59
60/// An intercepted HTTP request passed to the handler.
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct RawRequest {
63    pub method: String,
64    pub path: String,
65    pub headers: HeaderMap,
66    pub body: Bytes,
67    pub connection_meta: Arc<ConnectionMeta>,
68}
69
70/// An intercepted HTTP response passed to the handler.
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct RawResponse {
73    pub status: u16,
74    pub headers: HeaderMap,
75    pub body: Bytes,
76    pub connection_meta: Arc<ConnectionMeta>,
77}
78
/// Discriminant for streaming frame types delivered via [`StreamChunk`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameKind {
    /// A Server-Sent Events data frame.
    SseData,
    /// One line of a newline-delimited JSON stream.
    NdjsonLine,
    /// One gRPC message.
    GrpcMessage,
    /// A WebSocket text frame.
    WebSocketText,
    /// A WebSocket binary frame.
    WebSocketBinary,
    /// A WebSocket close frame.
    WebSocketClose,
}

impl FrameKind {
    /// Returns the snake_case label for this frame kind (e.g. `"sse_data"`).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::SseData => "sse_data",
            Self::NdjsonLine => "ndjson_line",
            Self::GrpcMessage => "grpc_message",
            Self::WebSocketText => "websocket_text",
            Self::WebSocketBinary => "websocket_binary",
            Self::WebSocketClose => "websocket_close",
        }
    }
}

impl std::fmt::Display for FrameKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/fill/alignment/precision; `write_str` ignored them.
        f.pad(self.as_str())
    }
}
108
/// Direction of a WebSocket frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameDirection {
    /// Client → server (request: model, prompt, tools).
    ClientToServer,
    /// Server → client (response: content deltas, usage).
    ServerToClient,
}

impl FrameDirection {
    /// Returns the snake_case label for this direction:
    /// `"client_to_server"` or `"server_to_client"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::ClientToServer => "client_to_server",
            Self::ServerToClient => "server_to_client",
        }
    }
}

impl std::fmt::Display for FrameDirection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` so width/fill/alignment/precision from the caller are applied.
        f.pad(self.as_str())
    }
}
117
118/// A streaming data frame (SSE, NDJSON, gRPC, or WebSocket) delivered to the handler.
119#[derive(Debug, Clone, PartialEq, Eq)]
120pub struct StreamChunk {
121    pub connection_id: Uuid,
122    pub payload: Bytes,
123    pub sequence: u64,
124    pub frame_kind: FrameKind,
125    /// Direction for WebSocket frames. `None` for SSE/NDJSON/gRPC (always server→client).
126    pub direction: Option<FrameDirection>,
127}
128
/// TLS metadata recorded for the downstream connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsInfo {
    /// Server Name Indication value, if one was present.
    pub sni: Option<String>,
    /// Negotiated application protocol, if any.
    /// NOTE(review): presumably the ALPN identifier (e.g. `h2`) — confirm at the TLS layer.
    pub negotiated_proto: Option<String>,
}
135
136/// Metadata about the downstream connection (socket, TLS, process attribution).
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct ConnectionMeta {
139    pub connection_id: Uuid,
140    pub socket_family: SocketFamily,
141    pub process_info: Option<ProcessInfo>,
142    pub tls_info: Option<TlsInfo>,
143}
144
/// Socket address family for the downstream connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SocketFamily {
    /// TCP over IPv4.
    TcpV4 {
        /// Local endpoint of the socket.
        local: SocketAddrV4,
        /// Remote (peer) endpoint of the socket.
        remote: SocketAddrV4,
    },
    /// TCP over IPv6.
    TcpV6 {
        /// Local endpoint of the socket.
        local: SocketAddrV6,
        /// Remote (peer) endpoint of the socket.
        remote: SocketAddrV6,
    },
    /// Unix domain socket.
    UnixDomain {
        /// Socket path, if any.
        /// NOTE(review): presumably `None` for unnamed sockets — confirm.
        path: Option<PathBuf>,
    },
}

impl SocketFamily {
    /// Returns the snake_case label for the family, ignoring any addresses/path:
    /// `"tcp_v4"`, `"tcp_v6"` or `"unix_domain"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::TcpV4 { .. } => "tcp_v4",
            Self::TcpV6 { .. } => "tcp_v6",
            Self::UnixDomain { .. } => "unix_domain",
        }
    }
}

impl std::fmt::Display for SocketFamily {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/fill/alignment/precision; `write_str` ignored them.
        f.pad(self.as_str())
    }
}
176
177#[derive(Debug, Clone, PartialEq, Eq)]
178pub struct ConnectionInfo {
179    pub connection_id: Uuid,
180    pub source_ip: IpAddr,
181    pub source_port: u16,
182    pub destination_host: String,
183    pub destination_port: u16,
184    pub socket_family: SocketFamily,
185    pub tls_fingerprint: Option<TlsClientFingerprint>,
186    pub alpn_protocol: Option<String>,
187    pub is_http2: bool,
188    pub process_info: Option<ProcessInfo>,
189    pub connected_at: SystemTime,
190    pub request_count: u32,
191}
192
193#[derive(Debug, Clone, PartialEq, Eq)]
194pub struct TlsClientFingerprint {
195    pub ja4: String,
196    pub ja3: String,
197    pub tls_version: TlsVersion,
198    pub cipher_suites: Vec<u16>,
199    pub extensions: Vec<u16>,
200    pub elliptic_curves: Vec<u16>,
201}
202
/// Information about the local process that owns the downstream socket.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProcessInfo {
    /// Process ID of the owning process.
    pub pid: u32,
    /// Application bundle identifier, if known.
    /// NOTE(review): presumably only populated on macOS — confirm.
    pub bundle_id: Option<String>,
    /// Executable name, if known.
    pub exe_name: Option<String>,
    /// Path to the executable, if known.
    pub exe_path: Option<PathBuf>,
    /// Parent process ID, if known.
    pub parent_pid: Option<u32>,
    /// Name of the parent process, if known.
    pub parent_process_name: Option<String>,
}