1use anyhow::{Result, bail};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
/// Size limit, in bytes, that `read_http_head` enforces on the HTTP head it
/// buffers before giving up with an error.
const MAX_HEADER_SIZE: usize = 8192;

/// Complete canned `502 Bad Gateway` response (status line, headers, body),
/// written verbatim by `write_502`. `Content-Length: 39` must stay equal to
/// the byte length of the body text if the message is ever edited.
const HTTP_502: &[u8] = b"HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nContent-Length: 39\r\nConnection: close\r\n\r\nFailed to connect to upstream service.\n";

/// Complete canned `404 Not Found` response, written verbatim by `write_404`.
/// `Content-Length: 26` must stay equal to the byte length of the body text.
const HTTP_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\nContent-Length: 26\r\nConnection: close\r\n\r\nNo tunnel for this domain\n";

/// Complete canned `401 Unauthorized` response with a Basic-auth challenge
/// (realm `tnnl`) and an empty body, written verbatim by `write_401`.
const HTTP_401: &[u8] = b"HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"tnnl\"\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
11
/// Extracts the hostname from the `Host` header of a raw HTTP request head.
///
/// The header name is matched case-insensitively. Everything from the first
/// `:` of the value onward (the port) is dropped, and the result is
/// lowercased. Only the header section (up to the first blank line) is
/// decoded, so body bytes that arrived in the same read do not affect the
/// result.
///
/// Returns `None` when no `Host` header is found or the header section is not
/// valid UTF-8.
pub fn extract_host(buf: &[u8]) -> Option<String> {
    // Decode only the header section: trailing body bytes may be binary or
    // cut mid-character, which would otherwise make `from_utf8` fail.
    let head = match buf.windows(4).position(|w| w == b"\r\n\r\n") {
        Some(end) => &buf[..end],
        None => buf,
    };
    let header_str = std::str::from_utf8(head).ok()?;
    for line in header_str.split("\r\n").skip(1) {
        if line.is_empty() {
            break;
        }
        // Compare bytes instead of slicing the `&str`: `line[..5]` panics
        // when byte 5 falls inside a multi-byte character.
        let bytes = line.as_bytes();
        if bytes.len() > 5 && bytes[..5].eq_ignore_ascii_case(b"host:") {
            // The first five bytes are ASCII, so index 5 is a char boundary.
            let host = line[5..].trim();
            let name = host.split_once(':').map_or(host, |(name, _port)| name);
            return Some(name.to_lowercase());
        }
    }
    None
}
25
26pub async fn read_http_head<R: AsyncRead + Unpin>(stream: &mut R) -> Result<Vec<u8>> {
29 let mut buf = Vec::with_capacity(4096);
30 let mut tmp = [0u8; 4096];
31 loop {
32 let n = stream.read(&mut tmp).await?;
33 if n == 0 {
34 bail!("connection closed before headers complete");
35 }
36 buf.extend_from_slice(&tmp[..n]);
37 if buf.len() > MAX_HEADER_SIZE {
38 bail!("headers exceed max size");
39 }
40 if buf.windows(4).any(|w| w == b"\r\n\r\n") {
41 return Ok(buf);
42 }
43 }
44}
45
46pub async fn write_502<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
47 stream.write_all(HTTP_502).await?;
48 stream.flush().await?;
49 Ok(())
50}
51
52pub async fn write_401<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
53 stream.write_all(HTTP_401).await?;
54 stream.flush().await?;
55 Ok(())
56}
57
/// Returns the trimmed value of the first `Authorization` header in a raw
/// HTTP request head, matching the header name case-insensitively.
///
/// Only the header section (up to the first blank line) is decoded, so body
/// bytes that arrived in the same read do not affect the result.
///
/// Returns `None` when no such header is found or the header section is not
/// valid UTF-8.
pub fn extract_authorization(buf: &[u8]) -> Option<String> {
    const NAME: &[u8] = b"authorization:";
    // Decode only the header section: trailing body bytes may be binary or
    // cut mid-character, which would otherwise make `from_utf8` fail.
    let head = match buf.windows(4).position(|w| w == b"\r\n\r\n") {
        Some(end) => &buf[..end],
        None => buf,
    };
    let text = std::str::from_utf8(head).ok()?;
    for line in text.split("\r\n").skip(1) {
        if line.is_empty() {
            break;
        }
        // Compare bytes instead of slicing the `&str`: `line[..14]` panics
        // when byte 14 falls inside a multi-byte character.
        let bytes = line.as_bytes();
        if bytes.len() > NAME.len() && bytes[..NAME.len()].eq_ignore_ascii_case(NAME) {
            // The matched prefix is ASCII, so this index is a char boundary.
            return Some(line[NAME.len()..].trim().to_string());
        }
    }
    None
}
71
72pub async fn write_404<W: AsyncWrite + Unpin>(stream: &mut W) -> Result<()> {
73 stream.write_all(HTTP_404).await?;
74 stream.flush().await?;
75 Ok(())
76}
77
/// Reports whether an HTTP message head declares chunked transfer coding.
///
/// Per RFC 9112 §6.1, `chunked` decides the message framing only when it is
/// the *final* transfer coding. The check therefore looks at the last coding
/// of the last `Transfer-Encoding` header, so `gzip, chunked` is chunked.
/// Header names and codings are compared case-insensitively. Only the header
/// section (up to the first blank line) is examined, so body bytes in the
/// buffer cannot fake or hide the header.
pub fn is_chunked(headers: &[u8]) -> bool {
    let head = match headers.windows(4).position(|w| w == b"\r\n\r\n") {
        Some(end) => &headers[..end],
        None => headers,
    };
    let text = std::str::from_utf8(head).unwrap_or("");
    let mut chunked = false;
    for line in text.split("\r\n") {
        if let Some((name, value)) = line.split_once(':') {
            if name.eq_ignore_ascii_case("transfer-encoding") {
                // Multiple Transfer-Encoding fields form one ordered list, so
                // the last non-empty coding seen overall is the final one.
                if let Some(last) = value.rsplit(',').map(str::trim).find(|c| !c.is_empty()) {
                    chunked = last.eq_ignore_ascii_case("chunked");
                }
            }
        }
    }
    chunked
}
89
/// Returns the index just past the first `\r\n\r\n` in `buf`, which is where
/// the message body starts, or `None` if the header terminator is absent.
pub fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    let start = buf
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}
94
/// Returns the value of the `Content-Length` header in an HTTP message head.
///
/// Returns `0` when the header is absent or its value does not parse as a
/// `usize`. The header name is matched case-insensitively. Only the header
/// section (up to the first blank line) is examined, so body bytes in `buf`
/// can neither supply a bogus `Content-Length` nor, by being invalid UTF-8,
/// hide the real one.
pub fn parse_content_length(buf: &[u8]) -> usize {
    let head = match buf.windows(4).position(|w| w == b"\r\n\r\n") {
        Some(end) => &buf[..end],
        None => buf,
    };
    let text = std::str::from_utf8(head).unwrap_or("");
    for line in text.split("\r\n") {
        if let Some((name, value)) = line.split_once(':') {
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse().unwrap_or(0);
            }
        }
    }
    0
}
106
/// Parses the status code from the status line of a raw HTTP response
/// (e.g. `200` from `HTTP/1.1 200 OK`).
///
/// Only the first line is decoded, so a binary body that follows the head in
/// the same buffer does not matter. Returns `502` (Bad Gateway) when:
/// - the status line is missing or is not UTF-8;
/// - it has no second token;
/// - that token is not a three-digit status code (RFC 9112 §4).
pub fn parse_response_status(buf: &[u8]) -> u16 {
    // Cut at the first LF. A preceding CR is whitespace for the tokenizer.
    let line_end = buf.iter().position(|&b| b == b'\n').unwrap_or(buf.len());
    std::str::from_utf8(&buf[..line_end])
        .ok()
        .and_then(|line| line.split_ascii_whitespace().nth(1))
        .and_then(|code| code.parse::<u16>().ok())
        .filter(|code| (100..=999).contains(code))
        .unwrap_or(502)
}
117
/// Splits the request line of a raw HTTP request into `(method, target)`.
///
/// Only the first line (up to the first `\r\n`) is decoded, so a binary body
/// that follows the head in the same buffer does not matter. Returns
/// `("?", "?")` when the first line is not valid UTF-8 or has fewer than two
/// space-separated parts.
pub fn parse_request_line(buf: &[u8]) -> (String, String) {
    let line_end = buf
        .windows(2)
        .position(|w| w == b"\r\n")
        .unwrap_or(buf.len());
    let first_line = match std::str::from_utf8(&buf[..line_end]) {
        Ok(line) => line,
        Err(_) => return ("?".into(), "?".into()),
    };
    let mut parts = first_line.splitn(3, ' ');
    match (parts.next(), parts.next()) {
        (Some(method), Some(target)) => (method.to_string(), target.to_string()),
        _ => ("?".into(), "?".into()),
    }
}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `Some(String)` value that `extract_host` is expected to yield.
    fn host(name: &str) -> Option<String> {
        Some(String::from(name))
    }

    /// Builds the owned `(method, target)` pair that `parse_request_line`
    /// is expected to yield.
    fn request_line(method: &str, target: &str) -> (String, String) {
        (String::from(method), String::from(target))
    }

    // --- extract_host ---

    #[test]
    fn extract_host_basic() {
        let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        assert_eq!(extract_host(request), host("example.com"));
    }

    #[test]
    fn extract_host_strips_port() {
        let request = b"GET / HTTP/1.1\r\nHost: example.com:8080\r\n\r\n";
        assert_eq!(extract_host(request), host("example.com"));
    }

    #[test]
    fn extract_host_case_insensitive() {
        let request = b"GET / HTTP/1.1\r\nHOST: Example.COM\r\n\r\n";
        assert_eq!(extract_host(request), host("example.com"));
    }

    #[test]
    fn extract_host_missing() {
        let request = b"GET / HTTP/1.1\r\nContent-Type: text/plain\r\n\r\n";
        assert!(extract_host(request).is_none());
    }

    // --- parse_request_line ---

    #[test]
    fn parse_request_line_get() {
        let request = b"GET /api/users HTTP/1.1\r\nHost: x\r\n\r\n";
        assert_eq!(parse_request_line(request), request_line("GET", "/api/users"));
    }

    #[test]
    fn parse_request_line_post_with_query() {
        let request = b"POST /webhook?secret=abc HTTP/1.1\r\n\r\n";
        assert_eq!(
            parse_request_line(request),
            request_line("POST", "/webhook?secret=abc")
        );
    }

    #[test]
    fn parse_request_line_garbage() {
        assert_eq!(parse_request_line(b"nothttp"), request_line("?", "?"));
    }

    // --- parse_response_status ---

    #[test]
    fn parse_response_status_200() {
        let status = parse_response_status(b"HTTP/1.1 200 OK\r\n");
        assert_eq!(status, 200);
    }

    #[test]
    fn parse_response_status_404() {
        let status = parse_response_status(b"HTTP/1.1 404 Not Found\r\n");
        assert_eq!(status, 404);
    }

    #[test]
    fn parse_response_status_empty_returns_502() {
        let status = parse_response_status(b"");
        assert_eq!(status, 502);
    }

    // --- parse_content_length ---

    #[test]
    fn parse_content_length_present() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(parse_content_length(head), 42);
    }

    #[test]
    fn parse_content_length_missing() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_content_length(head), 0);
    }

    #[test]
    fn parse_content_length_case_insensitive() {
        let head = b"HTTP/1.1 200 OK\r\ncontent-length: 99\r\n\r\n";
        assert_eq!(parse_content_length(head), 99);
    }

    // --- is_chunked ---

    #[test]
    fn is_chunked_true() {
        let head = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n";
        assert!(is_chunked(head));
    }

    #[test]
    fn is_chunked_false_when_missing() {
        let head = b"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n";
        assert!(!is_chunked(head));
    }

    #[test]
    fn is_chunked_case_insensitive() {
        let head = b"HTTP/1.1 200 OK\r\ntransfer-encoding: CHUNKED\r\n\r\n";
        assert!(is_chunked(head));
    }

    // --- headers_end ---

    #[test]
    fn headers_end_found() {
        let raw = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
        let body_start = headers_end(raw).expect("header terminator is present");
        assert_eq!(&raw[body_start..], b"body");
    }

    #[test]
    fn headers_end_not_found() {
        assert!(headers_end(b"GET / HTTP/1.1\r\nHost: x").is_none());
    }
}
243}