Skip to main content

vantus/core/
http.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::fmt::Write as _;
4use std::net::{IpAddr, SocketAddr};
5use std::path::Path;
6
7use bytes::Bytes;
8use http::{HeaderMap, HeaderName, HeaderValue};
9use serde::Serialize;
10
11use crate::{LogLevel, emit_default_log};
12
/// Maximum number of non-empty header lines accepted in one request head.
const MAX_HEADERS: usize = 100;
/// Maximum number of non-empty `&`-separated pairs accepted in one query string.
const MAX_QUERY_PARAMS: usize = 128;
/// Maximum length, in bytes and measured after percent-decoding, of a single query key or value.
const MAX_QUERY_VALUE_LEN: usize = 8 * 1024;
16
/// An HTTP request method.
///
/// The seven methods below get dedicated variants; any other method token is kept verbatim in
/// [`Method::Other`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Method {
    Get,
    Post,
    Put,
    Delete,
    Patch,
    Head,
    Options,
    /// A method without a dedicated variant, stored verbatim.
    Other(String),
}

impl Method {
    /// Maps a method token to a [`Method`].
    ///
    /// Matching is case-sensitive, so `"get"` yields `Method::Other("get")`, not
    /// [`Method::Get`].
    pub fn from_http_str(value: &str) -> Self {
        match value {
            "GET" => Method::Get,
            "POST" => Method::Post,
            "PUT" => Method::Put,
            "DELETE" => Method::Delete,
            "PATCH" => Method::Patch,
            "HEAD" => Method::Head,
            "OPTIONS" => Method::Options,
            custom => Method::Other(String::from(custom)),
        }
    }
}

impl fmt::Display for Method {
    /// Writes the upper-case token for known methods, or the stored text for
    /// [`Method::Other`]. Formatter width/fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Delete => "DELETE",
            Self::Patch => "PATCH",
            Self::Head => "HEAD",
            Self::Options => "OPTIONS",
            Self::Other(custom) => custom.as_str(),
        };
        f.write_str(token)
    }
}
58
59#[derive(Clone, Debug)]
60pub struct Request {
61    pub method: Method,
62    pub path: String,
63    pub version: String,
64    pub headers: HeaderMap,
65    pub body: Bytes,
66    pub query_params: HashMap<String, Vec<String>>,
67    pub remote_addr: Option<SocketAddr>,
68}
69
70impl Request {
71    pub fn body_str(&self) -> Option<&str> {
72        std::str::from_utf8(self.body.as_ref()).ok()
73    }
74
75    pub fn body_as_string(&self) -> String {
76        String::from_utf8_lossy(self.body.as_ref()).into_owned()
77    }
78
79    pub fn from_bytes(bytes: &[u8]) -> Result<Request, ParseError> {
80        let (head, body) = split_head_body(bytes);
81        let head = std::str::from_utf8(head).map_err(|_| ParseError::InvalidUtf8)?;
82
83        let mut lines = head.split("\r\n");
84        let request_line = lines.next().ok_or(ParseError::MissingRequestLine)?;
85        if request_line.trim().is_empty() {
86            return Err(ParseError::MissingRequestLine);
87        }
88        let (method, raw_path, version) = parse_request_line(request_line)?;
89
90        if !matches!(version.as_str(), "HTTP/1.0" | "HTTP/1.1") {
91            return Err(ParseError::InvalidHttpVersion);
92        }
93        let mut headers = HeaderMap::new();
94        let mut header_count = 0usize;
95        for line in lines {
96            if line.is_empty() {
97                continue;
98            }
99            header_count += 1;
100            if header_count > MAX_HEADERS {
101                return Err(ParseError::TooManyHeaders);
102            }
103            let (key, value) = line.split_once(':').ok_or(ParseError::InvalidHeaderLine)?;
104            let key = key.trim();
105            let value = value.trim();
106            if key.is_empty() {
107                return Err(ParseError::InvalidHeaderLine);
108            }
109            let key = HeaderName::try_from(key).map_err(|_| ParseError::InvalidHeaderLine)?;
110            let value = HeaderValue::from_str(value).map_err(|_| ParseError::InvalidHeaderLine)?;
111            headers.append(key, value);
112        }
113
114        let (raw_path, query_params) = if let Some((path, query)) = raw_path.split_once('?') {
115            (path.to_string(), parse_query(query)?)
116        } else {
117            (raw_path, HashMap::new())
118        };
119        let path = normalize_request_path(&raw_path)?;
120
121        let body = if let Some(content_length) = header_value(&headers, "content-length") {
122            let expected = content_length
123                .parse::<usize>()
124                .map_err(|_| ParseError::InvalidContentLength)?;
125            if body.len() < expected {
126                return Err(ParseError::BodyTooShort {
127                    expected,
128                    actual: body.len(),
129                });
130            }
131            Bytes::copy_from_slice(&body[..expected])
132        } else {
133            Bytes::copy_from_slice(body)
134        };
135
136        Ok(Request {
137            method,
138            path,
139            version,
140            headers,
141            body,
142            query_params,
143            remote_addr: None,
144        })
145    }
146
147    pub(crate) fn from_normalized_parts(
148        method: Method,
149        path: String,
150        version: String,
151        headers: HeaderMap,
152        body: Bytes,
153        query_params: HashMap<String, Vec<String>>,
154        remote_addr: Option<SocketAddr>,
155    ) -> Result<Request, ParseError> {
156        let path = normalize_request_path(&path)?;
157        Ok(Request {
158            method,
159            path,
160            version,
161            headers,
162            body,
163            query_params,
164            remote_addr,
165        })
166    }
167
168    pub(crate) fn parse_query(query: &str) -> Result<HashMap<String, Vec<String>>, ParseError> {
169        parse_query(query)
170    }
171
172    pub fn client_ip(&self, trusted_proxies: &[IpAddr]) -> Option<IpAddr> {
173        let remote_addr = self.remote_addr?;
174        if !trusted_proxies.contains(&remote_addr.ip()) {
175            return Some(remote_addr.ip());
176        }
177
178        let forwarded = header_values(&self.headers, "x-forwarded-for")
179            .flat_map(|value| value.split(','))
180            .filter_map(|item| item.trim().parse::<IpAddr>().ok())
181            .collect::<Vec<_>>();
182
183        for candidate in forwarded.into_iter().rev() {
184            if !trusted_proxies.contains(&candidate) {
185                return Some(candidate);
186            }
187        }
188
189        Some(remote_addr.ip())
190    }
191
192    pub fn header(&self, key: &str) -> Option<&str> {
193        self.headers.get(key).and_then(|value| value.to_str().ok())
194    }
195
196    pub fn header_values<'a>(&'a self, key: &'a str) -> impl Iterator<Item = &'a str> + 'a {
197        header_values(&self.headers, key)
198    }
199}
200
/// Splits raw request bytes at the first blank line (`\r\n\r\n`).
///
/// Returns `(head, body)`, neither of which contains the separator. When no separator is
/// found, the whole input is the head and the body is empty.
fn split_head_body(bytes: &[u8]) -> (&[u8], &[u8]) {
    const SEPARATOR: &[u8] = b"\r\n\r\n";

    let Some(head_end) = bytes
        .windows(SEPARATOR.len())
        .position(|chunk| chunk == SEPARATOR)
    else {
        return (bytes, &[]);
    };

    let (head, rest) = bytes.split_at(head_end);
    (head, &rest[SEPARATOR.len()..])
}
208
209fn header_value<'a>(headers: &'a HeaderMap, key: &str) -> Option<&'a str> {
210    headers.get(key).and_then(|value| value.to_str().ok())
211}
212
213fn header_values<'a>(headers: &'a HeaderMap, key: &'a str) -> impl Iterator<Item = &'a str> + 'a {
214    headers
215        .get_all(key)
216        .iter()
217        .filter_map(|value| value.to_str().ok())
218}
219
220fn parse_request_line(request_line: &str) -> Result<(Method, String, String), ParseError> {
221    if request_line.contains('\t') {
222        return Err(ParseError::InvalidRequestLine);
223    }
224
225    let mut parts = request_line.split(' ');
226    let method = parts.next().ok_or(ParseError::InvalidRequestLine)?;
227    let path = parts.next().ok_or(ParseError::InvalidRequestLine)?;
228    let version = parts.next().ok_or(ParseError::InvalidRequestLine)?;
229
230    if method.is_empty()
231        || path.is_empty()
232        || version.is_empty()
233        || parts.next().is_some()
234        || request_line.contains("  ")
235    {
236        return Err(ParseError::InvalidRequestLine);
237    }
238
239    Ok((
240        Method::from_http_str(method),
241        path.to_string(),
242        version.to_string(),
243    ))
244}
245
246fn parse_query(query: &str) -> Result<HashMap<String, Vec<String>>, ParseError> {
247    let mut params = HashMap::new();
248    let mut pair_count = 0usize;
249    for pair in query.split('&') {
250        if pair.is_empty() {
251            continue;
252        }
253        pair_count += 1;
254        if pair_count > MAX_QUERY_PARAMS {
255            return Err(ParseError::TooManyQueryParams);
256        }
257
258        let (raw_key, raw_value) = if let Some((key, value)) = pair.split_once('=') {
259            (key, value)
260        } else {
261            (pair, "")
262        };
263
264        let key = percent_decode(raw_key)?;
265        let value = percent_decode(raw_value)?;
266        if key.len() > MAX_QUERY_VALUE_LEN || value.len() > MAX_QUERY_VALUE_LEN {
267            return Err(ParseError::QueryValueTooLong);
268        }
269        params.entry(key).or_insert_with(Vec::new).push(value);
270    }
271    Ok(params)
272}
273
274fn normalize_request_path(path: &str) -> Result<String, ParseError> {
275    if !path.starts_with('/') || path.contains('\0') || path.contains('\\') {
276        return Err(ParseError::InvalidPath);
277    }
278
279    let mut normalized_segments = Vec::new();
280    for segment in path.split('/') {
281        if segment.is_empty() {
282            continue;
283        }
284        if segment == "." || segment == ".." {
285            return Err(ParseError::PathTraversal);
286        }
287        normalized_segments.push(segment);
288    }
289
290    if normalized_segments.is_empty() {
291        Ok("/".to_string())
292    } else {
293        Ok(format!("/{}", normalized_segments.join("/")))
294    }
295}
296
297fn percent_decode(value: &str) -> Result<String, ParseError> {
298    let bytes = value.as_bytes();
299    let mut decoded = Vec::with_capacity(bytes.len());
300    let mut idx = 0;
301
302    while idx < bytes.len() {
303        match bytes[idx] {
304            b'+' => {
305                decoded.push(b' ');
306                idx += 1;
307            }
308            b'%' => {
309                if idx + 2 >= bytes.len() {
310                    return Err(ParseError::InvalidPercentEncoding);
311                }
312
313                let high = decode_hex(bytes[idx + 1])?;
314                let low = decode_hex(bytes[idx + 2])?;
315                decoded.push((high << 4) | low);
316                idx += 3;
317            }
318            byte => {
319                decoded.push(byte);
320                idx += 1;
321            }
322        }
323    }
324
325    String::from_utf8(decoded).map_err(|_| ParseError::InvalidPercentEncoding)
326}
327
328fn decode_hex(byte: u8) -> Result<u8, ParseError> {
329    match byte {
330        b'0'..=b'9' => Ok(byte - b'0'),
331        b'a'..=b'f' => Ok(byte - b'a' + 10),
332        b'A'..=b'F' => Ok(byte - b'A' + 10),
333        _ => Err(ParseError::InvalidPercentEncoding),
334    }
335}
336
337#[derive(Clone, Debug)]
338pub struct Response {
339    pub status_code: u16,
340    pub status_text: String,
341    pub headers: Vec<(String, String)>,
342    pub body: Vec<u8>,
343}
344
345impl Response {
346    pub fn new(status_code: u16, status_text: impl Into<String>, body: impl Into<Vec<u8>>) -> Self {
347        Self {
348            status_code,
349            status_text: status_text.into(),
350            headers: Vec::new(),
351            body: body.into(),
352        }
353    }
354
355    pub fn ok(body: impl Into<Vec<u8>>) -> Self {
356        Self::new(200, "OK", body)
357    }
358
359    pub fn not_found() -> Self {
360        Self::from_error(404, "Not Found", "404 Not Found")
361    }
362
363    pub fn bad_request(message: impl Into<Vec<u8>>) -> Self {
364        let message = message.into();
365        Self::from_error(
366            400,
367            "Bad Request",
368            String::from_utf8_lossy(&message).into_owned(),
369        )
370    }
371
372    pub fn internal_server_error() -> Self {
373        Self::from_error(500, "Internal Server Error", "500 Internal Server Error")
374    }
375
376    pub fn from_error(
377        status_code: u16,
378        status_text: impl Into<String>,
379        body: impl Into<String>,
380    ) -> Self {
381        Self::new(status_code, status_text, body.into().into_bytes())
382            .with_header("Content-Type", "text/plain; charset=utf-8")
383    }
384
385    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
386        let key = key.into();
387        let value = value.into();
388
389        match (
390            HeaderName::from_bytes(key.as_bytes()),
391            HeaderValue::from_str(&value),
392        ) {
393            (Ok(_), Ok(valid_value)) => {
394                self.headers.push((key, value_from_header(valid_value)));
395            }
396            _ => emit_default_log(
397                LogLevel::Warn,
398                "vantus.http",
399                &format!("ignored invalid response header: {}", key),
400            ),
401        }
402        self
403    }
404
405    pub fn text(body: impl Into<String>) -> Self {
406        Self::ok(body.into().into_bytes()).with_header("Content-Type", "text/plain; charset=utf-8")
407    }
408
409    pub fn html(body: impl Into<String>) -> Self {
410        Self::ok(body.into().into_bytes()).with_header("Content-Type", "text/html; charset=utf-8")
411    }
412
413    pub fn json(body: impl Into<String>) -> Self {
414        Self::ok(body.into().into_bytes())
415            .with_header("Content-Type", "application/json; charset=utf-8")
416    }
417
418    pub fn json_value(value: serde_json::Value) -> Self {
419        match serde_json::to_vec(&value) {
420            Ok(body) => {
421                Self::ok(body).with_header("Content-Type", "application/json; charset=utf-8")
422            }
423            Err(_) => Self::internal_server_error(),
424        }
425    }
426
427    pub fn json_serialized<T: Serialize>(value: &T) -> Result<Self, serde_json::Error> {
428        serde_json::to_vec(value).map(|body| {
429            Self::ok(body).with_header("Content-Type", "application/json; charset=utf-8")
430        })
431    }
432
433    pub fn to_http_bytes(&self) -> Vec<u8> {
434        let mut response = String::with_capacity(64 + self.headers.len() * 32 + self.body.len());
435        let _ = write!(
436            response,
437            "HTTP/1.1 {} {}\r\n",
438            self.status_code, self.status_text
439        );
440        let mut has_content_length = false;
441        let mut has_connection = false;
442
443        for (key, value) in &self.headers {
444            if HeaderName::from_bytes(key.as_bytes()).is_err()
445                || HeaderValue::from_str(value).is_err()
446            {
447                emit_default_log(
448                    LogLevel::Warn,
449                    "vantus.http",
450                    &format!(
451                        "ignored invalid response header during serialization: {}",
452                        key
453                    ),
454                );
455                continue;
456            }
457            if key.eq_ignore_ascii_case("content-length") {
458                has_content_length = true;
459            }
460            if key.eq_ignore_ascii_case("connection") {
461                has_connection = true;
462            }
463            let _ = write!(response, "{key}: {value}\r\n");
464        }
465
466        if !has_content_length {
467            let _ = write!(response, "Content-Length: {}\r\n", self.body.len());
468        }
469        if !has_connection {
470            response.push_str("Connection: close\r\n");
471        }
472
473        response.push_str("\r\n");
474        let mut bytes = response.into_bytes();
475        bytes.extend_from_slice(&self.body);
476        bytes
477    }
478
479    pub async fn file_async(path: impl AsRef<Path>) -> Self {
480        let path = path.as_ref();
481        match tokio::fs::read(path).await {
482            Ok(content) => {
483                let mut res = Self::ok(content);
484
485                if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
486                    res = res.with_header("Content-Type", mime_for_ext(ext));
487                }
488                res
489            }
490            Err(_) => {
491                emit_default_log(
492                    LogLevel::Warn,
493                    "vantus.http",
494                    &format!("file not found at {:?}", path),
495                );
496                Self::not_found()
497            }
498        }
499    }
500
501    #[deprecated(note = "use Response::file_async instead")]
502    pub fn file(path: impl AsRef<Path>) -> Self {
503        let path = path.as_ref().to_path_buf();
504
505        match tokio::runtime::Handle::try_current() {
506            Ok(handle) => {
507                emit_default_log(
508                    LogLevel::Warn,
509                    "vantus.http",
510                    "Response::file is deprecated inside async runtimes; use Response::file_async",
511                );
512                match read_file_bytes_compat(&path, handle) {
513                    Ok(content) => response_from_file_bytes(&path, content),
514                    Err(_) => {
515                        emit_default_log(
516                            LogLevel::Warn,
517                            "vantus.http",
518                            &format!("file not found at {:?}", path),
519                        );
520                        Self::not_found()
521                    }
522                }
523            }
524            Err(_) => match tokio::runtime::Builder::new_current_thread()
525                .enable_all()
526                .build()
527            {
528                Ok(runtime) => runtime.block_on(Self::file_async(path)),
529                Err(_) => Self::internal_server_error(),
530            },
531        }
532    }
533}
534
535fn read_file_bytes_compat(path: &Path, handle: tokio::runtime::Handle) -> std::io::Result<Vec<u8>> {
536    tokio::task::block_in_place(|| {
537        let path = path.to_path_buf();
538        handle.block_on(async move {
539            tokio::task::spawn_blocking(move || std::fs::read(path))
540                .await
541                .map_err(|error| std::io::Error::other(error.to_string()))?
542        })
543    })
544}
545
546fn response_from_file_bytes(path: &Path, content: Vec<u8>) -> Response {
547    let mut res = Response::ok(content);
548    if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
549        res = res.with_header("Content-Type", mime_for_ext(ext));
550    }
551    res
552}
553
554fn value_from_header(value: HeaderValue) -> String {
555    value.to_str().map(str::to_string).unwrap_or_default()
556}
557
/// Maps a file extension (without the leading dot) to a `Content-Type` value.
///
/// Matching is ASCII case-insensitive, so `photo.JPG` is served as `image/jpeg` rather than
/// as a download. Unknown extensions fall back to `application/octet-stream`.
fn mime_for_ext(ext: &str) -> &'static str {
    match ext.to_ascii_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "ico" => "image/vnd.microsoft.icon",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "html" | "htm" => "text/html; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        "pdf" => "application/pdf",
        // Browsers require this exact type for `WebAssembly.instantiateStreaming`.
        "wasm" => "application/wasm",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        _ => "application/octet-stream",
    }
}
573
574#[derive(Debug)]
575pub enum ParseError {
576    MissingRequestLine,
577    InvalidRequestLine,
578    InvalidHttpVersion,
579    InvalidPath,
580    PathTraversal,
581    InvalidUtf8,
582    InvalidHeaderLine,
583    InvalidContentLength,
584    InvalidPercentEncoding,
585    TooManyHeaders,
586    TooManyQueryParams,
587    QueryValueTooLong,
588    RequestTooLarge { limit: usize },
589    BodyTooShort { expected: usize, actual: usize },
590}
591
592impl fmt::Display for ParseError {
593    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
594        match self {
595            ParseError::MissingRequestLine => write!(f, "request line is missing"),
596            ParseError::InvalidRequestLine => write!(f, "request line is invalid"),
597            ParseError::InvalidHttpVersion => write!(f, "http version is invalid"),
598            ParseError::InvalidPath => write!(f, "request path is invalid"),
599            ParseError::PathTraversal => write!(f, "request path contains traversal sequences"),
600            ParseError::InvalidUtf8 => write!(f, "request headers are not valid utf-8"),
601            ParseError::InvalidHeaderLine => write!(f, "request header line is invalid"),
602            ParseError::InvalidContentLength => write!(f, "content-length header is invalid"),
603            ParseError::InvalidPercentEncoding => {
604                write!(f, "request query percent-encoding is invalid")
605            }
606            ParseError::TooManyHeaders => {
607                write!(
608                    f,
609                    "request contains too many headers (limit: {MAX_HEADERS})"
610                )
611            }
612            ParseError::TooManyQueryParams => write!(
613                f,
614                "request contains too many query parameters (limit: {MAX_QUERY_PARAMS})"
615            ),
616            ParseError::QueryValueTooLong => write!(
617                f,
618                "query key or value exceeds maximum length ({MAX_QUERY_VALUE_LEN} bytes)"
619            ),
620            ParseError::RequestTooLarge { limit } => {
621                write!(f, "request exceeds maximum allowed size ({limit} bytes)")
622            }
623            ParseError::BodyTooShort { expected, actual } => write!(
624                f,
625                "request body is shorter than content-length (expected {expected}, got {actual})"
626            ),
627        }
628    }
629}
630
631impl std::error::Error for ParseError {}