1use crate::handler::{
2 extract_body_bytes, extract_response_body_bytes, put_body_back, put_response_body_back,
3 BoxBody, Buffered, Dropped, RequestHandler,
4};
5use bytes::Bytes;
6use hyper::header::HeaderMap;
7use hyper::{Method, Request, Response, StatusCode, Uri, Version};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::mpsc;
10use std::sync::Arc;
11use tracing::info;
12
/// Sequence number that `InterceptHandler` assigns to each intercepted request or response
/// (allocated from an atomic counter that starts at 1).
pub type InterceptId = u64;
14
15pub enum InterceptedItem {
17 Request {
18 id: InterceptId,
19 method: Method,
20 uri: Uri,
21 version: Version,
22 headers: HeaderMap,
23 body: Bytes,
24 reply: mpsc::Sender<Verdict>,
25 },
26 Response {
27 id: InterceptId,
28 status: StatusCode,
29 version: Version,
30 headers: HeaderMap,
31 body: Bytes,
32 reply: mpsc::Sender<Verdict>,
33 },
34}
35
36pub enum Verdict {
38 Forward {
39 headers: Box<HeaderMap>,
40 body: Bytes,
41 method: Option<Method>,
42 uri: Option<Uri>,
43 status: Option<StatusCode>,
44 },
45 Drop,
46}
47
48pub struct InterceptHandler {
50 tx: mpsc::SyncSender<InterceptedItem>,
51 active: Arc<AtomicBool>,
52 next_id: AtomicU64,
53}
54
55impl InterceptHandler {
56 pub fn new(tx: mpsc::SyncSender<InterceptedItem>, active: Arc<AtomicBool>) -> Self {
57 Self {
58 tx,
59 active,
60 next_id: AtomicU64::new(1),
61 }
62 }
63}
64
65impl RequestHandler for InterceptHandler {
66 fn handle_request(&self, req: &mut Request<BoxBody>) {
67 let path = req.uri().path();
68 let display_uri = if req.uri().query().is_some() {
69 format!("{path}?...")
70 } else {
71 path.to_string()
72 };
73 info!(">> {} {} {:?}", req.method(), display_uri, req.version());
74
75 if !self.active.load(Ordering::Relaxed)
77 || req.extensions().get::<Buffered>().is_none()
78 {
79 return;
80 }
81
82 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
83 let body_bytes = extract_body_bytes(req);
84 let (reply_tx, reply_rx) = mpsc::channel();
85
86 let item = InterceptedItem::Request {
87 id,
88 method: req.method().clone(),
89 uri: req.uri().clone(),
90 version: req.version(),
91 headers: req.headers().clone(),
92 body: body_bytes.clone(),
93 reply: reply_tx,
94 };
95
96 let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
98 match send_result {
99 Ok(()) => {}
100 Err(_) => {
101 tracing::warn!("TUI disconnected, disabling interception");
103 self.active.store(false, Ordering::Relaxed);
104 put_body_back(req, body_bytes);
105 return;
106 }
107 }
108
109 match tokio::task::block_in_place(|| reply_rx.recv()) {
111 Ok(Verdict::Forward {
112 headers,
113 body,
114 ..
115 }) => {
116 *req.headers_mut() = *headers;
121 let changed = body != body_bytes;
122 put_body_back(req, body.clone());
123 fix_headers_after_edit(req.headers_mut(), body.len(), changed);
124 }
125 Ok(Verdict::Drop) => {
126 req.extensions_mut().insert(Dropped);
127 put_body_back(req, Bytes::new());
128 fix_headers_after_edit(req.headers_mut(), 0, true);
129 }
130 Err(_) => {
131 put_body_back(req, body_bytes);
132 }
133 }
134 }
135
136 fn handle_response(&self, res: &mut Response<BoxBody>) {
137 info!("<< {}", res.status());
138
139 if !self.active.load(Ordering::Relaxed)
141 || res.extensions().get::<Buffered>().is_none()
142 {
143 return;
144 }
145
146 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
147 let body_bytes = extract_response_body_bytes(res);
148 let (reply_tx, reply_rx) = mpsc::channel();
149
150 let item = InterceptedItem::Response {
151 id,
152 status: res.status(),
153 version: res.version(),
154 headers: res.headers().clone(),
155 body: body_bytes.clone(),
156 reply: reply_tx,
157 };
158
159 let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
160 match send_result {
161 Ok(()) => {}
162 Err(_) => {
163 tracing::warn!("TUI disconnected, disabling interception");
164 self.active.store(false, Ordering::Relaxed);
165 put_response_body_back(res, body_bytes);
166 return;
167 }
168 }
169
170 match tokio::task::block_in_place(|| reply_rx.recv()) {
171 Ok(Verdict::Forward {
172 headers,
173 body,
174 status,
175 ..
176 }) => {
177 *res.headers_mut() = *headers;
178 if let Some(s) = status {
179 *res.status_mut() = s;
180 }
181 let changed = body != body_bytes;
182 put_response_body_back(res, body.clone());
183 fix_headers_after_edit(res.headers_mut(), body.len(), changed);
184 }
185 Ok(Verdict::Drop) => {
186 res.extensions_mut().insert(Dropped);
187 put_response_body_back(res, Bytes::new());
188 fix_headers_after_edit(res.headers_mut(), 0, true);
189 }
190 Err(_) => {
191 put_response_body_back(res, body_bytes);
192 }
193 }
194 }
195}
196
197fn fix_headers_after_edit(headers: &mut HeaderMap, body_len: usize, body_changed: bool) {
203 for name in &[
205 hyper::header::CONNECTION,
206 hyper::header::PROXY_AUTHORIZATION,
207 hyper::header::PROXY_AUTHENTICATE,
208 hyper::header::TE,
209 hyper::header::TRAILER,
210 hyper::header::UPGRADE,
211 ] {
212 headers.remove(name);
213 }
214 headers.remove("keep-alive");
216
217 if !body_changed {
218 return; }
220 headers.remove(hyper::header::TRANSFER_ENCODING);
221 headers.remove(hyper::header::CONTENT_ENCODING);
222 if body_len > 0 {
223 headers.insert(
224 hyper::header::CONTENT_LENGTH,
225 hyper::header::HeaderValue::from(body_len),
226 );
227 } else {
228 headers.remove(hyper::header::CONTENT_LENGTH);
229 }
230}
231
232pub fn is_text_body(body: &Bytes) -> bool {
234 body.is_empty() || std::str::from_utf8(body).is_ok()
235}
236
237pub fn serialize_request(
239 method: &Method,
240 uri: &Uri,
241 version: Version,
242 headers: &HeaderMap,
243 body: &Bytes,
244) -> String {
245 let mut s = format!("{method} {uri} {version:?}\r\n");
246 for (name, value) in headers.iter() {
247 s.push_str(&format!(
248 "{}: {}\r\n",
249 name,
250 value.to_str().unwrap_or("<binary>")
251 ));
252 }
253 s.push_str("\r\n");
254 if !body.is_empty() {
255 match std::str::from_utf8(body) {
256 Ok(text) => s.push_str(text),
257 Err(_) => s.push_str(&format!("<binary {} bytes>", body.len())),
258 }
259 }
260 s
261}
262
263pub fn serialize_response(
265 status: StatusCode,
266 version: Version,
267 headers: &HeaderMap,
268 body: &Bytes,
269) -> String {
270 let mut s = format!("{version:?} {status}\r\n");
271 for (name, value) in headers.iter() {
272 s.push_str(&format!(
273 "{}: {}\r\n",
274 name,
275 value.to_str().unwrap_or("<binary>")
276 ));
277 }
278 s.push_str("\r\n");
279 if !body.is_empty() {
280 match std::str::from_utf8(body) {
281 Ok(text) => s.push_str(text),
282 Err(_) => s.push_str(&format!("<binary {} bytes>", body.len())),
283 }
284 }
285 s
286}
287
288pub fn parse_request_text(text: &str) -> Option<(Method, Uri, HeaderMap, Bytes)> {
290 let (head, body) = text.split_once("\r\n\r\n").unwrap_or((text, ""));
291
292 let mut lines = head.lines();
293 let request_line = lines.next()?;
294 let mut parts = request_line.splitn(3, ' ');
295 let method: Method = parts.next()?.parse().ok()?;
296 let uri: Uri = parts.next()?.parse().ok()?;
297
298 let mut headers = HeaderMap::new();
299 for line in lines {
300 if let Some((name, value)) = line.split_once(": ") {
301 if let (Ok(n), Ok(v)) = (
302 name.parse::<hyper::header::HeaderName>(),
303 value.parse::<hyper::header::HeaderValue>(),
304 ) {
305 headers.append(n, v);
306 }
307 }
308 }
309
310 Some((method, uri, headers, Bytes::from(body.to_string())))
311}
312
313pub fn parse_response_text(text: &str) -> Option<(StatusCode, HeaderMap, Bytes)> {
315 let (head, body) = text.split_once("\r\n\r\n").unwrap_or((text, ""));
316
317 let mut lines = head.lines();
318 let status_line = lines.next()?;
319 let status_str = status_line.split_once(' ')?.1;
320 let status_code: u16 = status_str.split_whitespace().next()?.parse().ok()?;
321 let status = StatusCode::from_u16(status_code).ok()?;
322
323 let mut headers = HeaderMap::new();
324 for line in lines {
325 if let Some((name, value)) = line.split_once(": ") {
326 if let (Ok(n), Ok(v)) = (
327 name.parse::<hyper::header::HeaderName>(),
328 value.parse::<hyper::header::HeaderValue>(),
329 ) {
330 headers.append(n, v);
331 }
332 }
333 }
334
335 Some((status, headers, Bytes::from(body.to_string())))
336}