Skip to main content

rustgate/
intercept.rs

1use crate::handler::{
2    extract_body_bytes, extract_response_body_bytes, put_body_back, put_response_body_back,
3    BoxBody, Buffered, Dropped, RequestHandler,
4};
5use bytes::Bytes;
6use hyper::header::HeaderMap;
7use hyper::{Method, Request, Response, StatusCode, Uri, Version};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::mpsc;
10use std::sync::Arc;
11use tracing::info;
12
/// Identifier attached to each intercepted item.
///
/// Allocated from `InterceptHandler`'s counter, which starts at 1 and is
/// shared by requests and responses.
pub type InterceptId = u64;
14
/// An intercepted message, sent from the handler to the TUI for review.
///
/// The handler blocks waiting on `reply`, so every item must eventually get a
/// [`Verdict`] sent back or have its `reply` sender dropped.
pub enum InterceptedItem {
    /// A request captured by `InterceptHandler::handle_request`.
    Request {
        /// Id from the handler's counter (the same counter numbers responses).
        id: InterceptId,
        /// Request method at interception time.
        method: Method,
        /// Request URI at interception time.
        uri: Uri,
        /// HTTP version of the request.
        version: Version,
        /// Copy of the request headers at interception time.
        headers: HeaderMap,
        /// The pre-buffered request body.
        body: Bytes,
        /// Where the TUI sends its [`Verdict`]. Dropping it without sending
        /// lets the request through unchanged.
        reply: mpsc::Sender<Verdict>,
    },
    /// A response captured by `InterceptHandler::handle_response`.
    Response {
        /// Id from the handler's counter (the same counter numbers requests).
        id: InterceptId,
        /// Response status at interception time.
        status: StatusCode,
        /// HTTP version of the response.
        version: Version,
        /// Copy of the response headers at interception time.
        headers: HeaderMap,
        /// The pre-buffered response body.
        body: Bytes,
        /// Where the TUI sends its [`Verdict`]. Dropping it without sending
        /// lets the response through unchanged.
        reply: mpsc::Sender<Verdict>,
    },
}
35
/// The TUI's decision for an [`InterceptedItem`], sent back on its `reply` channel.
pub enum Verdict {
    /// Let the message through with the edited parts applied.
    Forward {
        /// Replaces the message's entire header map. The handler strips
        /// hop-by-hop headers from it afterwards.
        headers: Box<HeaderMap>,
        /// Replaces the body. When it differs from the original body, the
        /// handler recomputes the framing headers.
        body: Bytes,
        /// Currently ignored by the handler.
        method: Option<Method>,
        /// Currently ignored: the upstream connection is already resolved from
        /// the original request, so a new URI would not retarget it.
        uri: Option<Uri>,
        /// New status code, applied to responses when `Some`; ignored for requests.
        status: Option<StatusCode>,
    },
    /// Drop the message: the handler empties its body and tags it with the
    /// `Dropped` extension.
    Drop,
}
47
/// `RequestHandler` that intercepts requests/responses and sends them to the TUI.
///
/// A message is only intercepted while `active` is set and when its body was
/// pre-buffered (it carries the `Buffered` extension).
pub struct InterceptHandler {
    /// Bounded queue to the TUI; `send` blocks while it is full (backpressure).
    tx: mpsc::SyncSender<InterceptedItem>,
    /// Shared interception switch; the handler clears it once the TUI's
    /// receiving end has gone away.
    active: Arc<AtomicBool>,
    /// Next `InterceptId` to hand out (starts at 1).
    next_id: AtomicU64,
}
54
55impl InterceptHandler {
56    pub fn new(tx: mpsc::SyncSender<InterceptedItem>, active: Arc<AtomicBool>) -> Self {
57        Self {
58            tx,
59            active,
60            next_id: AtomicU64::new(1),
61        }
62    }
63}
64
65impl RequestHandler for InterceptHandler {
66    fn handle_request(&self, req: &mut Request<BoxBody>) {
67        let path = req.uri().path();
68        let display_uri = if req.uri().query().is_some() {
69            format!("{path}?...")
70        } else {
71            path.to_string()
72        };
73        info!(">> {} {} {:?}", req.method(), display_uri, req.version());
74
75        // Only intercept if body was pre-buffered and interception is active
76        if !self.active.load(Ordering::Relaxed)
77            || req.extensions().get::<Buffered>().is_none()
78        {
79            return;
80        }
81
82        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
83        let body_bytes = extract_body_bytes(req);
84        let (reply_tx, reply_rx) = mpsc::channel();
85
86        let item = InterceptedItem::Request {
87            id,
88            method: req.method().clone(),
89            uri: req.uri().clone(),
90            version: req.version(),
91            headers: req.headers().clone(),
92            body: body_bytes.clone(),
93            reply: reply_tx,
94        };
95
96        // Backpressure: block until queue has space (bounded channel)
97        let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
98        match send_result {
99            Ok(()) => {}
100            Err(_) => {
101                // Disconnected — TUI exited
102                tracing::warn!("TUI disconnected, disabling interception");
103                self.active.store(false, Ordering::Relaxed);
104                put_body_back(req, body_bytes);
105                return;
106            }
107        }
108
109        // Use block_in_place so the tokio worker thread is released while waiting
110        match tokio::task::block_in_place(|| reply_rx.recv()) {
111            Ok(Verdict::Forward {
112                headers,
113                body,
114                ..
115            }) => {
116                // Apply header and body edits only.
117                // Method/URI changes are ignored because the upstream connection
118                // is already resolved from the original request — changing the URI
119                // in the TUI would not retarget the connection.
120                *req.headers_mut() = *headers;
121                let changed = body != body_bytes;
122                put_body_back(req, body.clone());
123                fix_headers_after_edit(req.headers_mut(), body.len(), changed);
124            }
125            Ok(Verdict::Drop) => {
126                req.extensions_mut().insert(Dropped);
127                put_body_back(req, Bytes::new());
128                fix_headers_after_edit(req.headers_mut(), 0, true);
129            }
130            Err(_) => {
131                put_body_back(req, body_bytes);
132            }
133        }
134    }
135
136    fn handle_response(&self, res: &mut Response<BoxBody>) {
137        info!("<< {}", res.status());
138
139        // Only intercept if body was pre-buffered and interception is active
140        if !self.active.load(Ordering::Relaxed)
141            || res.extensions().get::<Buffered>().is_none()
142        {
143            return;
144        }
145
146        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
147        let body_bytes = extract_response_body_bytes(res);
148        let (reply_tx, reply_rx) = mpsc::channel();
149
150        let item = InterceptedItem::Response {
151            id,
152            status: res.status(),
153            version: res.version(),
154            headers: res.headers().clone(),
155            body: body_bytes.clone(),
156            reply: reply_tx,
157        };
158
159        let send_result = tokio::task::block_in_place(|| self.tx.send(item).map_err(Box::new));
160        match send_result {
161            Ok(()) => {}
162            Err(_) => {
163                tracing::warn!("TUI disconnected, disabling interception");
164                self.active.store(false, Ordering::Relaxed);
165                put_response_body_back(res, body_bytes);
166                return;
167            }
168        }
169
170        match tokio::task::block_in_place(|| reply_rx.recv()) {
171            Ok(Verdict::Forward {
172                headers,
173                body,
174                status,
175                ..
176            }) => {
177                *res.headers_mut() = *headers;
178                if let Some(s) = status {
179                    *res.status_mut() = s;
180                }
181                let changed = body != body_bytes;
182                put_response_body_back(res, body.clone());
183                fix_headers_after_edit(res.headers_mut(), body.len(), changed);
184            }
185            Ok(Verdict::Drop) => {
186                res.extensions_mut().insert(Dropped);
187                put_response_body_back(res, Bytes::new());
188                fix_headers_after_edit(res.headers_mut(), 0, true);
189            }
190            Err(_) => {
191                put_response_body_back(res, body_bytes);
192            }
193        }
194    }
195}
196
197/// Recompute framing headers after body mutation to prevent corrupt HTTP.
198/// If `body_changed` is false, preserve all original headers (no-op).
199/// Sanitize headers after interception edit.
200/// Always strips hop-by-hop headers (edit may have reintroduced them).
201/// Recomputes framing headers only if body was actually changed.
202fn fix_headers_after_edit(headers: &mut HeaderMap, body_len: usize, body_changed: bool) {
203    // Always strip hop-by-hop headers that should not be forwarded
204    for name in &[
205        hyper::header::CONNECTION,
206        hyper::header::PROXY_AUTHORIZATION,
207        hyper::header::PROXY_AUTHENTICATE,
208        hyper::header::TE,
209        hyper::header::TRAILER,
210        hyper::header::UPGRADE,
211    ] {
212        headers.remove(name);
213    }
214    // Also strip Keep-Alive (not in hyper constants)
215    headers.remove("keep-alive");
216
217    if !body_changed {
218        return; // Preserve original framing headers
219    }
220    headers.remove(hyper::header::TRANSFER_ENCODING);
221    headers.remove(hyper::header::CONTENT_ENCODING);
222    if body_len > 0 {
223        headers.insert(
224            hyper::header::CONTENT_LENGTH,
225            hyper::header::HeaderValue::from(body_len),
226        );
227    } else {
228        headers.remove(hyper::header::CONTENT_LENGTH);
229    }
230}
231
232/// Check if body is valid UTF-8 (safe for text editing).
233pub fn is_text_body(body: &Bytes) -> bool {
234    body.is_empty() || std::str::from_utf8(body).is_ok()
235}
236
237/// Serialize an HTTP request to raw text for display/editing.
238pub fn serialize_request(
239    method: &Method,
240    uri: &Uri,
241    version: Version,
242    headers: &HeaderMap,
243    body: &Bytes,
244) -> String {
245    let mut s = format!("{method} {uri} {version:?}\r\n");
246    for (name, value) in headers.iter() {
247        s.push_str(&format!(
248            "{}: {}\r\n",
249            name,
250            value.to_str().unwrap_or("<binary>")
251        ));
252    }
253    s.push_str("\r\n");
254    if !body.is_empty() {
255        match std::str::from_utf8(body) {
256            Ok(text) => s.push_str(text),
257            Err(_) => s.push_str(&format!("<binary {} bytes>", body.len())),
258        }
259    }
260    s
261}
262
263/// Serialize an HTTP response to raw text for display/editing.
264pub fn serialize_response(
265    status: StatusCode,
266    version: Version,
267    headers: &HeaderMap,
268    body: &Bytes,
269) -> String {
270    let mut s = format!("{version:?} {status}\r\n");
271    for (name, value) in headers.iter() {
272        s.push_str(&format!(
273            "{}: {}\r\n",
274            name,
275            value.to_str().unwrap_or("<binary>")
276        ));
277    }
278    s.push_str("\r\n");
279    if !body.is_empty() {
280        match std::str::from_utf8(body) {
281            Ok(text) => s.push_str(text),
282            Err(_) => s.push_str(&format!("<binary {} bytes>", body.len())),
283        }
284    }
285    s
286}
287
288/// Parse raw HTTP request text back into parts.
289pub fn parse_request_text(text: &str) -> Option<(Method, Uri, HeaderMap, Bytes)> {
290    let (head, body) = text.split_once("\r\n\r\n").unwrap_or((text, ""));
291
292    let mut lines = head.lines();
293    let request_line = lines.next()?;
294    let mut parts = request_line.splitn(3, ' ');
295    let method: Method = parts.next()?.parse().ok()?;
296    let uri: Uri = parts.next()?.parse().ok()?;
297
298    let mut headers = HeaderMap::new();
299    for line in lines {
300        if let Some((name, value)) = line.split_once(": ") {
301            if let (Ok(n), Ok(v)) = (
302                name.parse::<hyper::header::HeaderName>(),
303                value.parse::<hyper::header::HeaderValue>(),
304            ) {
305                headers.append(n, v);
306            }
307        }
308    }
309
310    Some((method, uri, headers, Bytes::from(body.to_string())))
311}
312
313/// Parse raw HTTP response text back into parts.
314pub fn parse_response_text(text: &str) -> Option<(StatusCode, HeaderMap, Bytes)> {
315    let (head, body) = text.split_once("\r\n\r\n").unwrap_or((text, ""));
316
317    let mut lines = head.lines();
318    let status_line = lines.next()?;
319    let status_str = status_line.split_once(' ')?.1;
320    let status_code: u16 = status_str.split_whitespace().next()?.parse().ok()?;
321    let status = StatusCode::from_u16(status_code).ok()?;
322
323    let mut headers = HeaderMap::new();
324    for line in lines {
325        if let Some((name, value)) = line.split_once(": ") {
326            if let (Ok(n), Ok(v)) = (
327                name.parse::<hyper::header::HeaderName>(),
328                value.parse::<hyper::header::HeaderValue>(),
329            ) {
330                headers.append(n, v);
331            }
332        }
333    }
334
335    Some((status, headers, Bytes::from(body.to_string())))
336}