1use std::collections::HashMap;
2use std::fmt;
3use std::fmt::Write as _;
4use std::net::{IpAddr, SocketAddr};
5use std::path::Path;
6
7use bytes::Bytes;
8use http::{HeaderMap, HeaderName, HeaderValue};
9use serde::Serialize;
10
11use crate::{LogLevel, emit_default_log};
12
/// Maximum number of header lines accepted when parsing a request.
const MAX_HEADERS: usize = 100;
/// Maximum number of non-empty `key[=value]` pairs accepted in a query string.
const MAX_QUERY_PARAMS: usize = 128;
/// Maximum length in bytes (8 KiB) of a percent-decoded query key or value.
const MAX_QUERY_VALUE_LEN: usize = 8 * 1024;
16
/// An HTTP request method.
///
/// Seven common methods have dedicated variants; any other method string is kept
/// verbatim in [`Method::Other`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Method {
    Get,
    Post,
    Put,
    Delete,
    Patch,
    Head,
    Options,
    /// A method without a dedicated variant, stored exactly as received.
    Other(String),
}

impl Method {
    /// Maps a method string to a [`Method`].
    ///
    /// Matching is case-sensitive: `"GET"` yields [`Method::Get`], while `"get"` yields
    /// `Method::Other("get")`.
    pub fn from_http_str(value: &str) -> Self {
        const KNOWN: [(&str, Method); 7] = [
            ("GET", Method::Get),
            ("POST", Method::Post),
            ("PUT", Method::Put),
            ("DELETE", Method::Delete),
            ("PATCH", Method::Patch),
            ("HEAD", Method::Head),
            ("OPTIONS", Method::Options),
        ];

        KNOWN
            .iter()
            .find(|(token, _)| *token == value)
            .map_or_else(
                || Method::Other(value.to_string()),
                |(_, method)| method.clone(),
            )
    }
}

impl fmt::Display for Method {
    /// Writes the upper-case method name, or the stored string for [`Method::Other`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match self {
            Method::Get => "GET",
            Method::Post => "POST",
            Method::Put => "PUT",
            Method::Delete => "DELETE",
            Method::Patch => "PATCH",
            Method::Head => "HEAD",
            Method::Options => "OPTIONS",
            Method::Other(value) => value.as_str(),
        };
        f.write_str(token)
    }
}
58
/// A parsed HTTP request.
///
/// Constructed by [`Request::from_bytes`] or `Request::from_normalized_parts`; both run
/// `path` through `normalize_request_path` before storing it.
#[derive(Clone, Debug)]
pub struct Request {
    /// Request method taken from the request line.
    pub method: Method,
    /// Normalized request path. `from_bytes` strips the query string before normalizing;
    /// NOTE(review): callers of `from_normalized_parts` presumably do the same — confirm.
    pub path: String,
    /// HTTP version string from the request line (`from_bytes` only accepts
    /// `HTTP/1.0` and `HTTP/1.1`).
    pub version: String,
    /// Request headers; repeated header names keep every value.
    pub headers: HeaderMap,
    /// Request body. `from_bytes` truncates it to `Content-Length` when that header is present.
    pub body: Bytes,
    /// Query parameters: each key maps to all of its values in order of appearance
    /// (percent-decoded by `parse_query` when built via `from_bytes`).
    pub query_params: HashMap<String, Vec<String>>,
    /// Address of the directly connected peer, if known (`from_bytes` sets `None`).
    pub remote_addr: Option<SocketAddr>,
}
69
70impl Request {
71 pub fn body_str(&self) -> Option<&str> {
72 std::str::from_utf8(self.body.as_ref()).ok()
73 }
74
75 pub fn body_as_string(&self) -> String {
76 String::from_utf8_lossy(self.body.as_ref()).into_owned()
77 }
78
79 pub fn from_bytes(bytes: &[u8]) -> Result<Request, ParseError> {
80 let (head, body) = split_head_body(bytes);
81 let head = std::str::from_utf8(head).map_err(|_| ParseError::InvalidUtf8)?;
82
83 let mut lines = head.split("\r\n");
84 let request_line = lines.next().ok_or(ParseError::MissingRequestLine)?;
85 if request_line.trim().is_empty() {
86 return Err(ParseError::MissingRequestLine);
87 }
88 let (method, raw_path, version) = parse_request_line(request_line)?;
89
90 if !matches!(version.as_str(), "HTTP/1.0" | "HTTP/1.1") {
91 return Err(ParseError::InvalidHttpVersion);
92 }
93 let mut headers = HeaderMap::new();
94 let mut header_count = 0usize;
95 for line in lines {
96 if line.is_empty() {
97 continue;
98 }
99 header_count += 1;
100 if header_count > MAX_HEADERS {
101 return Err(ParseError::TooManyHeaders);
102 }
103 let (key, value) = line.split_once(':').ok_or(ParseError::InvalidHeaderLine)?;
104 let key = key.trim();
105 let value = value.trim();
106 if key.is_empty() {
107 return Err(ParseError::InvalidHeaderLine);
108 }
109 let key = HeaderName::try_from(key).map_err(|_| ParseError::InvalidHeaderLine)?;
110 let value = HeaderValue::from_str(value).map_err(|_| ParseError::InvalidHeaderLine)?;
111 headers.append(key, value);
112 }
113
114 let (raw_path, query_params) = if let Some((path, query)) = raw_path.split_once('?') {
115 (path.to_string(), parse_query(query)?)
116 } else {
117 (raw_path, HashMap::new())
118 };
119 let path = normalize_request_path(&raw_path)?;
120
121 let body = if let Some(content_length) = header_value(&headers, "content-length") {
122 let expected = content_length
123 .parse::<usize>()
124 .map_err(|_| ParseError::InvalidContentLength)?;
125 if body.len() < expected {
126 return Err(ParseError::BodyTooShort {
127 expected,
128 actual: body.len(),
129 });
130 }
131 Bytes::copy_from_slice(&body[..expected])
132 } else {
133 Bytes::copy_from_slice(body)
134 };
135
136 Ok(Request {
137 method,
138 path,
139 version,
140 headers,
141 body,
142 query_params,
143 remote_addr: None,
144 })
145 }
146
147 pub(crate) fn from_normalized_parts(
148 method: Method,
149 path: String,
150 version: String,
151 headers: HeaderMap,
152 body: Bytes,
153 query_params: HashMap<String, Vec<String>>,
154 remote_addr: Option<SocketAddr>,
155 ) -> Result<Request, ParseError> {
156 let path = normalize_request_path(&path)?;
157 Ok(Request {
158 method,
159 path,
160 version,
161 headers,
162 body,
163 query_params,
164 remote_addr,
165 })
166 }
167
168 pub(crate) fn parse_query(query: &str) -> Result<HashMap<String, Vec<String>>, ParseError> {
169 parse_query(query)
170 }
171
172 pub fn client_ip(&self, trusted_proxies: &[IpAddr]) -> Option<IpAddr> {
173 let remote_addr = self.remote_addr?;
174 if !trusted_proxies.contains(&remote_addr.ip()) {
175 return Some(remote_addr.ip());
176 }
177
178 let forwarded = header_values(&self.headers, "x-forwarded-for")
179 .flat_map(|value| value.split(','))
180 .filter_map(|item| item.trim().parse::<IpAddr>().ok())
181 .collect::<Vec<_>>();
182
183 for candidate in forwarded.into_iter().rev() {
184 if !trusted_proxies.contains(&candidate) {
185 return Some(candidate);
186 }
187 }
188
189 Some(remote_addr.ip())
190 }
191
192 pub fn header(&self, key: &str) -> Option<&str> {
193 self.headers.get(key).and_then(|value| value.to_str().ok())
194 }
195
196 pub fn header_values<'a>(&'a self, key: &'a str) -> impl Iterator<Item = &'a str> + 'a {
197 header_values(&self.headers, key)
198 }
199}
200
/// Splits raw request bytes at the first blank line (`\r\n\r\n`).
///
/// Returns `(head, body)`, with the separator belonging to neither. When no separator
/// exists, the whole input is treated as the head and the body is empty.
fn split_head_body(bytes: &[u8]) -> (&[u8], &[u8]) {
    const SEPARATOR: &[u8] = b"\r\n\r\n";

    match bytes
        .windows(SEPARATOR.len())
        .position(|candidate| candidate == SEPARATOR)
    {
        Some(split_at) => {
            let (head, rest) = bytes.split_at(split_at);
            (head, &rest[SEPARATOR.len()..])
        }
        None => (bytes, &[]),
    }
}
208
209fn header_value<'a>(headers: &'a HeaderMap, key: &str) -> Option<&'a str> {
210 headers.get(key).and_then(|value| value.to_str().ok())
211}
212
213fn header_values<'a>(headers: &'a HeaderMap, key: &'a str) -> impl Iterator<Item = &'a str> + 'a {
214 headers
215 .get_all(key)
216 .iter()
217 .filter_map(|value| value.to_str().ok())
218}
219
220fn parse_request_line(request_line: &str) -> Result<(Method, String, String), ParseError> {
221 if request_line.contains('\t') {
222 return Err(ParseError::InvalidRequestLine);
223 }
224
225 let mut parts = request_line.split(' ');
226 let method = parts.next().ok_or(ParseError::InvalidRequestLine)?;
227 let path = parts.next().ok_or(ParseError::InvalidRequestLine)?;
228 let version = parts.next().ok_or(ParseError::InvalidRequestLine)?;
229
230 if method.is_empty()
231 || path.is_empty()
232 || version.is_empty()
233 || parts.next().is_some()
234 || request_line.contains(" ")
235 {
236 return Err(ParseError::InvalidRequestLine);
237 }
238
239 Ok((
240 Method::from_http_str(method),
241 path.to_string(),
242 version.to_string(),
243 ))
244}
245
246fn parse_query(query: &str) -> Result<HashMap<String, Vec<String>>, ParseError> {
247 let mut params = HashMap::new();
248 let mut pair_count = 0usize;
249 for pair in query.split('&') {
250 if pair.is_empty() {
251 continue;
252 }
253 pair_count += 1;
254 if pair_count > MAX_QUERY_PARAMS {
255 return Err(ParseError::TooManyQueryParams);
256 }
257
258 let (raw_key, raw_value) = if let Some((key, value)) = pair.split_once('=') {
259 (key, value)
260 } else {
261 (pair, "")
262 };
263
264 let key = percent_decode(raw_key)?;
265 let value = percent_decode(raw_value)?;
266 if key.len() > MAX_QUERY_VALUE_LEN || value.len() > MAX_QUERY_VALUE_LEN {
267 return Err(ParseError::QueryValueTooLong);
268 }
269 params.entry(key).or_insert_with(Vec::new).push(value);
270 }
271 Ok(params)
272}
273
274fn normalize_request_path(path: &str) -> Result<String, ParseError> {
275 if !path.starts_with('/') || path.contains('\0') || path.contains('\\') {
276 return Err(ParseError::InvalidPath);
277 }
278
279 let mut normalized_segments = Vec::new();
280 for segment in path.split('/') {
281 if segment.is_empty() {
282 continue;
283 }
284 if segment == "." || segment == ".." {
285 return Err(ParseError::PathTraversal);
286 }
287 normalized_segments.push(segment);
288 }
289
290 if normalized_segments.is_empty() {
291 Ok("/".to_string())
292 } else {
293 Ok(format!("/{}", normalized_segments.join("/")))
294 }
295}
296
297fn percent_decode(value: &str) -> Result<String, ParseError> {
298 let bytes = value.as_bytes();
299 let mut decoded = Vec::with_capacity(bytes.len());
300 let mut idx = 0;
301
302 while idx < bytes.len() {
303 match bytes[idx] {
304 b'+' => {
305 decoded.push(b' ');
306 idx += 1;
307 }
308 b'%' => {
309 if idx + 2 >= bytes.len() {
310 return Err(ParseError::InvalidPercentEncoding);
311 }
312
313 let high = decode_hex(bytes[idx + 1])?;
314 let low = decode_hex(bytes[idx + 2])?;
315 decoded.push((high << 4) | low);
316 idx += 3;
317 }
318 byte => {
319 decoded.push(byte);
320 idx += 1;
321 }
322 }
323 }
324
325 String::from_utf8(decoded).map_err(|_| ParseError::InvalidPercentEncoding)
326}
327
328fn decode_hex(byte: u8) -> Result<u8, ParseError> {
329 match byte {
330 b'0'..=b'9' => Ok(byte - b'0'),
331 b'a'..=b'f' => Ok(byte - b'a' + 10),
332 b'A'..=b'F' => Ok(byte - b'A' + 10),
333 _ => Err(ParseError::InvalidPercentEncoding),
334 }
335}
336
/// An HTTP response, serialized to wire format by `Response::to_http_bytes`.
#[derive(Clone, Debug)]
pub struct Response {
    /// Numeric status code written into the status line.
    pub status_code: u16,
    /// Reason phrase written after the status code in the status line.
    pub status_text: String,
    /// Header name/value pairs in insertion order. `with_header` only appends pairs that
    /// pass validation; because this field is public, `to_http_bytes` validates again.
    pub headers: Vec<(String, String)>,
    /// Raw body bytes, appended after the header block.
    pub body: Vec<u8>,
}
344
345impl Response {
346 pub fn new(status_code: u16, status_text: impl Into<String>, body: impl Into<Vec<u8>>) -> Self {
347 Self {
348 status_code,
349 status_text: status_text.into(),
350 headers: Vec::new(),
351 body: body.into(),
352 }
353 }
354
355 pub fn ok(body: impl Into<Vec<u8>>) -> Self {
356 Self::new(200, "OK", body)
357 }
358
359 pub fn not_found() -> Self {
360 Self::from_error(404, "Not Found", "404 Not Found")
361 }
362
363 pub fn bad_request(message: impl Into<Vec<u8>>) -> Self {
364 let message = message.into();
365 Self::from_error(
366 400,
367 "Bad Request",
368 String::from_utf8_lossy(&message).into_owned(),
369 )
370 }
371
372 pub fn internal_server_error() -> Self {
373 Self::from_error(500, "Internal Server Error", "500 Internal Server Error")
374 }
375
376 pub fn from_error(
377 status_code: u16,
378 status_text: impl Into<String>,
379 body: impl Into<String>,
380 ) -> Self {
381 Self::new(status_code, status_text, body.into().into_bytes())
382 .with_header("Content-Type", "text/plain; charset=utf-8")
383 }
384
385 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
386 let key = key.into();
387 let value = value.into();
388
389 match (
390 HeaderName::from_bytes(key.as_bytes()),
391 HeaderValue::from_str(&value),
392 ) {
393 (Ok(_), Ok(valid_value)) => {
394 self.headers.push((key, value_from_header(valid_value)));
395 }
396 _ => emit_default_log(
397 LogLevel::Warn,
398 "vantus.http",
399 &format!("ignored invalid response header: {}", key),
400 ),
401 }
402 self
403 }
404
405 pub fn text(body: impl Into<String>) -> Self {
406 Self::ok(body.into().into_bytes()).with_header("Content-Type", "text/plain; charset=utf-8")
407 }
408
409 pub fn html(body: impl Into<String>) -> Self {
410 Self::ok(body.into().into_bytes()).with_header("Content-Type", "text/html; charset=utf-8")
411 }
412
413 pub fn json(body: impl Into<String>) -> Self {
414 Self::ok(body.into().into_bytes())
415 .with_header("Content-Type", "application/json; charset=utf-8")
416 }
417
418 pub fn json_value(value: serde_json::Value) -> Self {
419 match serde_json::to_vec(&value) {
420 Ok(body) => {
421 Self::ok(body).with_header("Content-Type", "application/json; charset=utf-8")
422 }
423 Err(_) => Self::internal_server_error(),
424 }
425 }
426
427 pub fn json_serialized<T: Serialize>(value: &T) -> Result<Self, serde_json::Error> {
428 serde_json::to_vec(value).map(|body| {
429 Self::ok(body).with_header("Content-Type", "application/json; charset=utf-8")
430 })
431 }
432
433 pub fn to_http_bytes(&self) -> Vec<u8> {
434 let mut response = String::with_capacity(64 + self.headers.len() * 32 + self.body.len());
435 let _ = write!(
436 response,
437 "HTTP/1.1 {} {}\r\n",
438 self.status_code, self.status_text
439 );
440 let mut has_content_length = false;
441 let mut has_connection = false;
442
443 for (key, value) in &self.headers {
444 if HeaderName::from_bytes(key.as_bytes()).is_err()
445 || HeaderValue::from_str(value).is_err()
446 {
447 emit_default_log(
448 LogLevel::Warn,
449 "vantus.http",
450 &format!(
451 "ignored invalid response header during serialization: {}",
452 key
453 ),
454 );
455 continue;
456 }
457 if key.eq_ignore_ascii_case("content-length") {
458 has_content_length = true;
459 }
460 if key.eq_ignore_ascii_case("connection") {
461 has_connection = true;
462 }
463 let _ = write!(response, "{key}: {value}\r\n");
464 }
465
466 if !has_content_length {
467 let _ = write!(response, "Content-Length: {}\r\n", self.body.len());
468 }
469 if !has_connection {
470 response.push_str("Connection: close\r\n");
471 }
472
473 response.push_str("\r\n");
474 let mut bytes = response.into_bytes();
475 bytes.extend_from_slice(&self.body);
476 bytes
477 }
478
479 pub async fn file_async(path: impl AsRef<Path>) -> Self {
480 let path = path.as_ref();
481 match tokio::fs::read(path).await {
482 Ok(content) => {
483 let mut res = Self::ok(content);
484
485 if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
486 res = res.with_header("Content-Type", mime_for_ext(ext));
487 }
488 res
489 }
490 Err(_) => {
491 emit_default_log(
492 LogLevel::Warn,
493 "vantus.http",
494 &format!("file not found at {:?}", path),
495 );
496 Self::not_found()
497 }
498 }
499 }
500
501 #[deprecated(note = "use Response::file_async instead")]
502 pub fn file(path: impl AsRef<Path>) -> Self {
503 let path = path.as_ref().to_path_buf();
504
505 match tokio::runtime::Handle::try_current() {
506 Ok(handle) => {
507 emit_default_log(
508 LogLevel::Warn,
509 "vantus.http",
510 "Response::file is deprecated inside async runtimes; use Response::file_async",
511 );
512 match read_file_bytes_compat(&path, handle) {
513 Ok(content) => response_from_file_bytes(&path, content),
514 Err(_) => {
515 emit_default_log(
516 LogLevel::Warn,
517 "vantus.http",
518 &format!("file not found at {:?}", path),
519 );
520 Self::not_found()
521 }
522 }
523 }
524 Err(_) => match tokio::runtime::Builder::new_current_thread()
525 .enable_all()
526 .build()
527 {
528 Ok(runtime) => runtime.block_on(Self::file_async(path)),
529 Err(_) => Self::internal_server_error(),
530 },
531 }
532 }
533}
534
535fn read_file_bytes_compat(path: &Path, handle: tokio::runtime::Handle) -> std::io::Result<Vec<u8>> {
536 tokio::task::block_in_place(|| {
537 let path = path.to_path_buf();
538 handle.block_on(async move {
539 tokio::task::spawn_blocking(move || std::fs::read(path))
540 .await
541 .map_err(|error| std::io::Error::other(error.to_string()))?
542 })
543 })
544}
545
546fn response_from_file_bytes(path: &Path, content: Vec<u8>) -> Response {
547 let mut res = Response::ok(content);
548 if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
549 res = res.with_header("Content-Type", mime_for_ext(ext));
550 }
551 res
552}
553
554fn value_from_header(value: HeaderValue) -> String {
555 value.to_str().map(str::to_string).unwrap_or_default()
556}
557
/// Maps a file extension (without the leading dot) to a `Content-Type` value.
///
/// Matching ignores ASCII case, so `PNG`, `Png` and `png` all map to `image/png`.
/// Unknown extensions fall back to `application/octet-stream`.
fn mime_for_ext(ext: &str) -> &'static str {
    match ext.to_ascii_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "webp" => "image/webp",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "html" | "htm" => "text/html; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        _ => "application/octet-stream",
    }
}
573
/// Reasons a raw request can fail to parse.
#[derive(Debug)]
pub enum ParseError {
    /// The request line is absent or blank.
    MissingRequestLine,
    /// The request line is not `method SP target SP version`.
    InvalidRequestLine,
    /// The version is neither `HTTP/1.0` nor `HTTP/1.1`.
    InvalidHttpVersion,
    /// The path does not start with `/`, or contains NUL or a backslash.
    InvalidPath,
    /// The path contains a `.` or `..` segment.
    PathTraversal,
    /// The request head is not valid UTF-8.
    InvalidUtf8,
    /// A header line lacks `:` or has an invalid name or value.
    InvalidHeaderLine,
    /// The `Content-Length` header value is invalid.
    InvalidContentLength,
    /// A query `%` escape is malformed, or the decoded bytes are not UTF-8.
    InvalidPercentEncoding,
    /// More than `MAX_HEADERS` header lines were supplied.
    TooManyHeaders,
    /// More than `MAX_QUERY_PARAMS` non-empty query pairs were supplied.
    TooManyQueryParams,
    /// A decoded query key or value is longer than `MAX_QUERY_VALUE_LEN` bytes.
    QueryValueTooLong,
    /// The request exceeded a size limit of `limit` bytes.
    /// NOTE(review): not raised in this file — presumably by the connection reader; confirm.
    RequestTooLarge { limit: usize },
    /// Fewer body bytes (`actual`) arrived than `Content-Length` declared (`expected`).
    BodyTooShort { expected: usize, actual: usize },
}
591
592impl fmt::Display for ParseError {
593 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
594 match self {
595 ParseError::MissingRequestLine => write!(f, "request line is missing"),
596 ParseError::InvalidRequestLine => write!(f, "request line is invalid"),
597 ParseError::InvalidHttpVersion => write!(f, "http version is invalid"),
598 ParseError::InvalidPath => write!(f, "request path is invalid"),
599 ParseError::PathTraversal => write!(f, "request path contains traversal sequences"),
600 ParseError::InvalidUtf8 => write!(f, "request headers are not valid utf-8"),
601 ParseError::InvalidHeaderLine => write!(f, "request header line is invalid"),
602 ParseError::InvalidContentLength => write!(f, "content-length header is invalid"),
603 ParseError::InvalidPercentEncoding => {
604 write!(f, "request query percent-encoding is invalid")
605 }
606 ParseError::TooManyHeaders => {
607 write!(
608 f,
609 "request contains too many headers (limit: {MAX_HEADERS})"
610 )
611 }
612 ParseError::TooManyQueryParams => write!(
613 f,
614 "request contains too many query parameters (limit: {MAX_QUERY_PARAMS})"
615 ),
616 ParseError::QueryValueTooLong => write!(
617 f,
618 "query key or value exceeds maximum length ({MAX_QUERY_VALUE_LEN} bytes)"
619 ),
620 ParseError::RequestTooLarge { limit } => {
621 write!(f, "request exceeds maximum allowed size ({limit} bytes)")
622 }
623 ParseError::BodyTooShort { expected, actual } => write!(
624 f,
625 "request body is shorter than content-length (expected {expected}, got {actual})"
626 ),
627 }
628 }
629}
630
/// Makes `ParseError` a standard error type. `source()` keeps its default of `None`,
/// since no variant wraps another error.
impl std::error::Error for ParseError {}