Skip to main content

tnnl/
proxy.rs

1use anyhow::{Result, bail};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
/// Largest number of bytes `read_http_head` buffers before giving up on the request.
const MAX_HEADER_SIZE: usize = 8192;

/// Canned `502 Bad Gateway` response sent by `write_502`.
///
/// `Content-Length` (39) is the byte length of the plain-text body, trailing newline included.
const HTTP_502: &[u8] = b"HTTP/1.1 502 Bad Gateway\r\n\
    Content-Type: text/plain\r\n\
    Content-Length: 39\r\n\
    Connection: close\r\n\
    \r\n\
    Failed to connect to upstream service.\n";

/// Canned `404 Not Found` response sent by `write_404`.
///
/// `Content-Length` (26) is the byte length of the plain-text body, trailing newline included.
const HTTP_404: &[u8] = b"HTTP/1.1 404 Not Found\r\n\
    Content-Type: text/plain\r\n\
    Content-Length: 26\r\n\
    Connection: close\r\n\
    \r\n\
    No tunnel for this domain\n";

/// Canned `401 Unauthorized` response sent by `write_401`.
///
/// Carries a Basic-auth challenge for the `tnnl` realm and has an empty body.
const HTTP_401: &[u8] = b"HTTP/1.1 401 Unauthorized\r\n\
    WWW-Authenticate: Basic realm=\"tnnl\"\r\n\
    Content-Length: 0\r\n\
    Connection: close\r\n\
    \r\n";
11
/// Extracts the lowercased host name from the `Host` header of a buffered HTTP request.
///
/// `buf` may hold the header block followed by body bytes; only the part before the first
/// `\r\n\r\n` is decoded, so a binary body cannot prevent the host from being found.
/// The request line is skipped and the header name is matched case-insensitively.
/// A trailing `:port` is removed; a bracketed IPv6 literal (`[::1]:8080`) is returned
/// with its brackets (`[::1]`).
///
/// Returns `None` when the header block is not valid UTF-8 or contains no `Host` header.
pub fn extract_host(buf: &[u8]) -> Option<String> {
    // Body bytes are arbitrary data; decode the header block only.
    let head_len = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let header_str = std::str::from_utf8(&buf[..head_len]).ok()?;

    for line in header_str.split("\r\n").skip(1) {
        if line.is_empty() {
            break;
        }
        // `get` instead of slicing: byte 5 may fall inside a multi-byte character of an
        // unrelated header, and `line[..5]` would panic on such attacker-controlled input.
        let is_host = line.len() > 5
            && line
                .get(..5)
                .is_some_and(|name| name.eq_ignore_ascii_case("host:"));
        if !is_host {
            continue;
        }

        let authority = line[5..].trim();
        let host = if authority.starts_with('[') {
            // IPv6 literal: the colons inside the brackets are not port separators.
            authority
                .find(']')
                .map_or(authority, |end| &authority[..=end])
        } else {
            authority.split(':').next().unwrap_or(authority)
        };
        return Some(host.to_lowercase());
    }
    None
}
25
26/// Read from a TCP stream until we have the full HTTP header block.
27/// Returns the buffered bytes (headers + any body bytes already received).
28pub async fn read_http_head<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Vec<u8>> {
29    let mut buf = Vec::with_capacity(4096);
30    let mut tmp = [0u8; 4096];
31    loop {
32        let n = stream.read(&mut tmp).await?;
33        if n == 0 {
34            bail!("connection closed before headers complete");
35        }
36        buf.extend_from_slice(&tmp[..n]);
37        if buf.len() > MAX_HEADER_SIZE {
38            bail!("headers exceed max size");
39        }
40        if buf.windows(4).any(|w| w == b"\r\n\r\n") {
41            return Ok(buf);
42        }
43    }
44}
45
46pub async fn write_502<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
47    stream.write_all(HTTP_502).await?;
48    stream.flush().await?;
49    Ok(())
50}
51
52pub async fn write_401<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
53    stream.write_all(HTTP_401).await?;
54    stream.flush().await?;
55    Ok(())
56}
57
/// Returns the trimmed value of the `Authorization` header, or `None`.
///
/// `buf` may hold the header block followed by body bytes; only the part before the first
/// `\r\n\r\n` is decoded, so a binary body cannot hide the header. The request line is
/// skipped and the header name is matched case-insensitively.
///
/// Returns `None` when the header block is not valid UTF-8 or the header is absent.
pub fn extract_authorization(buf: &[u8]) -> Option<String> {
    const NAME: &str = "authorization:";

    // Body bytes are arbitrary data; decode the header block only.
    let head_len = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let text = std::str::from_utf8(&buf[..head_len]).ok()?;

    text.split("\r\n")
        .skip(1)
        .take_while(|line| !line.is_empty())
        .find_map(|line| {
            if line.len() <= NAME.len() {
                return None;
            }
            // `get` instead of slicing: the cut may fall inside a multi-byte character of
            // an unrelated header, and `line[..14]` would panic on such input.
            let name = line.get(..NAME.len())?;
            name.eq_ignore_ascii_case(NAME)
                .then(|| line[NAME.len()..].trim().to_string())
        })
}
71
72pub async fn write_404<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
73    stream.write_all(HTTP_404).await?;
74    stream.flush().await?;
75    Ok(())
76}
77
/// Reports whether the message described by `headers` uses chunked transfer encoding.
///
/// Only the header block (everything before the first `\r\n\r\n`, or the whole input when
/// there is no terminator) is inspected, so body bytes can neither break the parse nor be
/// mistaken for a header. The header name and the coding are matched case-insensitively.
///
/// The message counts as chunked when the final transfer coding listed is `chunked`, so
/// `Transfer-Encoding: gzip, chunked` is chunked while `Transfer-Encoding: gzip` is not.
/// Returns `false` when the header is absent or the header block is not valid UTF-8.
pub fn is_chunked(headers: &[u8]) -> bool {
    let head_len = headers
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(headers.len());
    let text = std::str::from_utf8(&headers[..head_len]).unwrap_or("");

    text.split("\r\n")
        .filter_map(|line| {
            let (name, value) = line.split_once(':')?;
            name.eq_ignore_ascii_case("transfer-encoding")
                .then_some(value)
        })
        // Codings may be spread over a comma-separated list and over repeated headers;
        // the last one listed is the outermost and decides the framing.
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|coding| !coding.is_empty())
        .last()
        .is_some_and(|coding| coding.eq_ignore_ascii_case("chunked"))
}
89
/// Offset of the first byte that follows the header block.
///
/// The header block ends with the first `\r\n\r\n` in `buf`; the returned index points just
/// past that terminator. Returns `None` when `buf` contains no terminator.
pub fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let terminator_start = buf
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)?;
    Some(terminator_start + TERMINATOR.len())
}
94
/// Returns the value of the first `Content-Length` header, or `0`.
///
/// Only the header block (everything before the first `\r\n\r\n`, or the whole input when
/// there is no terminator) is inspected, so body bytes that follow the headers can neither
/// make the parse fail nor be mistaken for a header. The header name is matched
/// case-insensitively.
///
/// `0` is returned when the header is missing, when its value is not a valid `usize`, or
/// when the header block is not valid UTF-8.
pub fn parse_content_length(buf: &[u8]) -> usize {
    let head_len = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(buf.len());
    let text = std::str::from_utf8(&buf[..head_len]).unwrap_or("");

    text.split("\r\n")
        .find_map(|line| {
            let (name, value) = line.split_once(':')?;
            name.eq_ignore_ascii_case("content-length")
                // The first matching header decides, even when its value is malformed.
                .then(|| value.trim().parse::<usize>().unwrap_or(0))
        })
        .unwrap_or(0)
}
106
/// Extract the status code from the first line of an HTTP response.
///
/// Only the status line (the bytes before the first `\n`) is decoded, so headers or body
/// bytes that follow it cannot affect the result. The code is the second
/// whitespace-separated token, as in `HTTP/1.1 200 OK`.
///
/// Returns `502` when the status line is not valid UTF-8, has no second token, or that
/// token does not parse as a `u16`.
pub fn parse_response_status(buf: &[u8]) -> u16 {
    let line_end = buf
        .iter()
        .position(|&byte| byte == b'\n')
        .unwrap_or(buf.len());

    std::str::from_utf8(&buf[..line_end])
        .ok()
        // A trailing `\r` is whitespace and is dropped by the split.
        .and_then(|status_line| status_line.split_ascii_whitespace().nth(1))
        .and_then(|code| code.parse().ok())
        .unwrap_or(502)
}
117
/// Try to extract method and path from the first line of an HTTP request.
///
/// Only the request line (the bytes before the first `\r\n`, or the whole input when there
/// is none) is decoded, so headers or body bytes that follow it cannot affect the result.
/// The line is split on single spaces; the first two fields are returned as
/// `(method, path)`.
///
/// Returns `("?", "?")` when the request line is not valid UTF-8 or has fewer than two
/// fields.
pub fn parse_request_line(buf: &[u8]) -> (String, String) {
    let line_end = buf
        .windows(2)
        .position(|w| w == b"\r\n")
        .unwrap_or(buf.len());
    // An undecodable line becomes "", which yields a single field and falls through to "?".
    let first_line = std::str::from_utf8(&buf[..line_end]).unwrap_or("");

    let mut fields = first_line.splitn(3, ' ');
    match (fields.next(), fields.next()) {
        (Some(method), Some(path)) => (method.to_string(), path.to_string()),
        _ => ("?".into(), "?".into()),
    }
}
132
/// Unit tests for the header-parsing helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned `(method, path)` pair that `parse_request_line` returns.
    fn request_line(method: &str, path: &str) -> (String, String) {
        (method.to_owned(), path.to_owned())
    }

    // --- extract_host ---------------------------------------------------------------

    #[test]
    fn extract_host_basic() {
        let host = extract_host(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
        assert_eq!(host.as_deref(), Some("example.com"));
    }

    #[test]
    fn extract_host_strips_port() {
        let host = extract_host(b"GET / HTTP/1.1\r\nHost: example.com:8080\r\n\r\n");
        assert_eq!(host.as_deref(), Some("example.com"));
    }

    #[test]
    fn extract_host_case_insensitive() {
        let host = extract_host(b"GET / HTTP/1.1\r\nHOST: Example.COM\r\n\r\n");
        assert_eq!(host.as_deref(), Some("example.com"));
    }

    #[test]
    fn extract_host_missing() {
        let host = extract_host(b"GET / HTTP/1.1\r\nContent-Type: text/plain\r\n\r\n");
        assert!(host.is_none());
    }

    // --- parse_request_line ---------------------------------------------------------

    #[test]
    fn parse_request_line_get() {
        let parsed = parse_request_line(b"GET /api/users HTTP/1.1\r\nHost: x\r\n\r\n");
        assert_eq!(parsed, request_line("GET", "/api/users"));
    }

    #[test]
    fn parse_request_line_post_with_query() {
        let parsed = parse_request_line(b"POST /webhook?secret=abc HTTP/1.1\r\n\r\n");
        assert_eq!(parsed, request_line("POST", "/webhook?secret=abc"));
    }

    #[test]
    fn parse_request_line_garbage() {
        let parsed = parse_request_line(b"nothttp");
        assert_eq!(parsed, request_line("?", "?"));
    }

    // --- parse_response_status ------------------------------------------------------

    #[test]
    fn parse_response_status_200() {
        let status = parse_response_status(b"HTTP/1.1 200 OK\r\n");
        assert_eq!(status, 200);
    }

    #[test]
    fn parse_response_status_404() {
        let status = parse_response_status(b"HTTP/1.1 404 Not Found\r\n");
        assert_eq!(status, 404);
    }

    #[test]
    fn parse_response_status_empty_returns_502() {
        let status = parse_response_status(b"");
        assert_eq!(status, 502);
    }

    // --- parse_content_length -------------------------------------------------------

    #[test]
    fn parse_content_length_present() {
        let length = parse_content_length(b"HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n");
        assert_eq!(length, 42);
    }

    #[test]
    fn parse_content_length_missing() {
        let length =
            parse_content_length(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n");
        assert_eq!(length, 0);
    }

    #[test]
    fn parse_content_length_case_insensitive() {
        let length = parse_content_length(b"HTTP/1.1 200 OK\r\ncontent-length: 99\r\n\r\n");
        assert_eq!(length, 99);
    }

    // --- is_chunked -----------------------------------------------------------------

    #[test]
    fn is_chunked_true() {
        let chunked = is_chunked(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n");
        assert!(chunked);
    }

    #[test]
    fn is_chunked_false_when_missing() {
        let chunked = is_chunked(b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n");
        assert!(!chunked);
    }

    #[test]
    fn is_chunked_case_insensitive() {
        let chunked = is_chunked(b"HTTP/1.1 200 OK\r\ntransfer-encoding: CHUNKED\r\n\r\n");
        assert!(chunked);
    }

    // --- headers_end ----------------------------------------------------------------

    #[test]
    fn headers_end_found() {
        let buf = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
        let body_start = headers_end(buf).expect("terminator is present in the fixture");
        assert_eq!(&buf[body_start..], b"body");
    }

    #[test]
    fn headers_end_not_found() {
        let end = headers_end(b"GET / HTTP/1.1\r\nHost: x");
        assert!(end.is_none());
    }
}
243}