small_http/
request.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::collections::HashMap;
8use std::error::Error;
9use std::fmt::{self, Display, Formatter};
10use std::io::{BufRead, BufReader, Read, Write};
11use std::net::{Ipv4Addr, SocketAddr, TcpStream};
12use std::str::{self, FromStr};
13
14use url::Url;
15
16use crate::enums::{Method, Version};
17use crate::header_map::HeaderMap;
18use crate::response::Response;
19use crate::KEEP_ALIVE_TIMEOUT;
20
21// MARK: Request
22/// HTTP request
23#[derive(Clone)]
24pub struct Request {
25    /// HTTP version
26    pub(crate) version: Version,
27    /// URL
28    pub url: Url,
29    /// Method
30    pub method: Method,
31    /// Headers
32    pub headers: HeaderMap,
33    /// Parameters (mostly added for small-router)
34    pub params: HashMap<String, String>,
35    /// Body
36    pub body: Option<Vec<u8>>,
37    /// Client address
38    pub client_addr: SocketAddr,
39}
40
41impl Default for Request {
42    fn default() -> Self {
43        Self {
44            version: Version::Http1_1,
45            url: Url::from_str("http://localhost").expect("Should parse"),
46            method: Method::Get,
47            headers: HeaderMap::new(),
48            params: HashMap::new(),
49            body: None,
50            client_addr: (Ipv4Addr::LOCALHOST, 0).into(),
51        }
52    }
53}
54
55impl Request {
56    /// Create new request
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Create new request with method
62    pub fn with_method(method: Method) -> Self {
63        Self {
64            method,
65            ..Self::default()
66        }
67    }
68
69    /// Create new request with URL
70    pub fn with_url(url: impl AsRef<str>) -> Self {
71        Self {
72            url: url.as_ref().parse().expect("Invalid url"),
73            ..Self::default()
74        }
75    }
76
77    /// Create new request with specific method and URL
78    fn with_method_and_url(method: Method, url: impl AsRef<str>) -> Self {
79        Self {
80            method,
81            url: url.as_ref().parse().expect("Invalid url"),
82            ..Self::default()
83        }
84    }
85
86    /// Create new GET request with URL
87    pub fn get(url: impl AsRef<str>) -> Self {
88        Self::with_method_and_url(Method::Get, url)
89    }
90
91    /// Create new HEAD request with URL
92    pub fn head(url: impl AsRef<str>) -> Self {
93        Self::with_method_and_url(Method::Head, url)
94    }
95
96    /// Create new POST request with URL
97    pub fn post(url: impl AsRef<str>) -> Self {
98        Self::with_method_and_url(Method::Post, url)
99    }
100
101    /// Create new PUT request with URL
102    pub fn put(url: impl AsRef<str>) -> Self {
103        Self::with_method_and_url(Method::Put, url)
104    }
105
106    /// Create new DELETE request with URL
107    pub fn delete(url: impl AsRef<str>) -> Self {
108        Self::with_method_and_url(Method::Delete, url)
109    }
110
111    /// Create new CONNECT request with URL
112    pub fn connect(url: impl AsRef<str>) -> Self {
113        Self::with_method_and_url(Method::Connect, url)
114    }
115
116    /// Create new OPTIONS request with URL
117    pub fn options(url: impl AsRef<str>) -> Self {
118        Self::with_method_and_url(Method::Options, url)
119    }
120
121    /// Create new TRACE request with URL
122    pub fn trace(url: impl AsRef<str>) -> Self {
123        Self::with_method_and_url(Method::Trace, url)
124    }
125
126    /// Create new PATCH request with URL
127    pub fn patch(url: impl AsRef<str>) -> Self {
128        Self::with_method_and_url(Method::Patch, url)
129    }
130
131    /// Set URL
132    pub fn url(mut self, url: Url) -> Self {
133        self.url = url;
134        self
135    }
136
137    /// Set method
138    pub fn method(mut self, method: Method) -> Self {
139        self.method = method;
140        self
141    }
142
143    /// Set header
144    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
145        self.headers.insert(name.into(), value.into());
146        self
147    }
148
149    /// Set body
150    pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
151        self.body = Some(body.into());
152        self
153    }
154
155    pub(crate) fn read_from_stream(
156        stream: &mut dyn Read,
157        client_addr: SocketAddr,
158    ) -> Result<Request, InvalidRequestError> {
159        let mut reader = BufReader::new(stream);
160
161        // Read first line
162        let (method, path, version) = {
163            let mut line = String::new();
164            reader
165                .read_line(&mut line)
166                .map_err(|_| InvalidRequestError("Can't read first line".to_string()))?;
167            let mut parts = line.split(' ');
168            (
169                parts
170                    .next()
171                    .ok_or(InvalidRequestError(
172                        "Can't read 1st part of first line".to_string(),
173                    ))?
174                    .trim()
175                    .parse()
176                    .map_err(|_| InvalidRequestError("Can't parse method".to_string()))?,
177                parts
178                    .next()
179                    .ok_or(InvalidRequestError(
180                        "Can't read 2st part of first line".to_string(),
181                    ))?
182                    .trim()
183                    .to_string(),
184                parts
185                    .next()
186                    .ok_or(InvalidRequestError(
187                        "Can't read 3st part of first line".to_string(),
188                    ))?
189                    .trim()
190                    .to_string()
191                    .parse()
192                    .map_err(|_| InvalidRequestError("Can't parse HTTP version".to_string()))?,
193            )
194        };
195
196        // Read headers
197        let mut headers = HeaderMap::new();
198        loop {
199            let mut line = String::new();
200            reader
201                .read_line(&mut line)
202                .map_err(|_| InvalidRequestError("Can't read header line".to_string()))?;
203            if line == "\r\n" {
204                break;
205            }
206            let split = line
207                .find(':')
208                .ok_or(InvalidRequestError("Can't parse header line".to_string()))?;
209            headers.insert(
210                line[0..split].trim().to_string(),
211                line[split + 1..].trim().to_string(),
212            );
213        }
214
215        // Read body
216        let mut body = None;
217        if let Some(content_length) = headers.get("Content-Length") {
218            let content_length = content_length
219                .parse()
220                .map_err(|_| InvalidRequestError("Can't parse Content-Length".to_string()))?;
221            if content_length > 0 {
222                let mut buffer = vec![0; content_length];
223                reader.read(&mut buffer).map_err(|_| {
224                    InvalidRequestError(
225                        "Can't read Content-Length amount of bytes from stream".to_string(),
226                    )
227                })?;
228                body = Some(buffer);
229            }
230        }
231
232        // Parse URL
233        let url = Url::from_str(&if version == Version::Http1_1 {
234            format!(
235                "http://{}{}",
236                headers.get("Host").ok_or(InvalidRequestError(
237                    "HTTP version is 1.1 but Host header is not set".to_string()
238                ))?,
239                path
240            )
241        } else {
242            format!("http://localhost{path}")
243        })
244        .map_err(|_| InvalidRequestError("Can't parse request url".to_string()))?;
245
246        Ok(Request {
247            version,
248            url,
249            method,
250            headers,
251            params: HashMap::new(),
252            body,
253            client_addr,
254        })
255    }
256
257    /// Write request to TCP stream
258    pub fn write_to_stream(mut self, stream: &mut dyn Write, keep_alive: bool) {
259        // Finish headers
260        let host = self.url.host().expect("No host in URL");
261        self.headers.insert(
262            "Host".to_string(),
263            if let Some(port) = self.url.port() {
264                format!("{}:{}", &host, port)
265            } else {
266                host.to_string()
267            },
268        );
269        self.headers.insert(
270            "Content-Length".to_string(),
271            if let Some(body) = &self.body {
272                body.len()
273            } else {
274                0
275            }
276            .to_string(),
277        );
278        if self.version == Version::Http1_1 {
279            if keep_alive {
280                self.headers
281                    .insert("Connection".to_string(), "keep-alive".to_string());
282                self.headers.insert(
283                    "Keep-Alive".to_string(),
284                    format!("timeout={}", KEEP_ALIVE_TIMEOUT.as_secs()),
285                );
286            } else {
287                self.headers
288                    .insert("Connection".to_string(), "close".to_string());
289            }
290        }
291
292        // Write request
293        let path = self.url.path();
294        let path = if let Some(query) = self.url.query() {
295            format!("{}?{}", &path, query)
296        } else {
297            path.to_string()
298        };
299        _ = write!(stream, "{} {} HTTP/1.1\r\n", self.method, path);
300        for (name, value) in self.headers.iter() {
301            _ = write!(stream, "{name}: {value}\r\n");
302        }
303        _ = write!(stream, "\r\n");
304        if let Some(body) = &self.body {
305            _ = stream.write_all(body);
306        }
307    }
308
309    /// Fetch request with http client
310    pub fn fetch(self) -> Result<Response, FetchError> {
311        let mut stream = TcpStream::connect(format!(
312            "{}:{}",
313            self.url.host().expect("No host in URL"),
314            self.url.port().unwrap_or(80)
315        ))
316        .map_err(|_| FetchError)?;
317        self.write_to_stream(&mut stream, false);
318        Response::read_from_stream(&mut stream).map_err(|_| FetchError)
319    }
320}
321
322// MARK: InvalidRequestError
/// Error returned when an incoming request can't be read or parsed; the
/// wrapped string describes what went wrong.
#[derive(Debug)]
pub(crate) struct InvalidRequestError(String);

impl Display for InvalidRequestError {
    /// Formats as `Invalid request: <reason>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(reason) = self;
        write!(f, "Invalid request: {reason}")
    }
}

impl Error for InvalidRequestError {}
333
334// MARK: FetchError
/// Error returned by the http client when a request can't be fetched; it
/// carries no further detail about the cause.
#[derive(Debug)]
pub struct FetchError;

impl Display for FetchError {
    /// Formats as the fixed text `Fetch error`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Fetch error")
    }
}

impl Error for FetchError {}
345
346// MARK: Tests
#[cfg(test)]
mod test {
    use std::io::Write;
    use std::net::{Ipv4Addr, SocketAddr, TcpListener};
    use std::thread;

    use super::*;
    use crate::enums::Status;

    /// Parse `raw` as a request sent by the fixed test client `127.0.0.1:12345`.
    fn parse_request(mut raw: &[u8]) -> Result<Request, InvalidRequestError> {
        Request::read_from_stream(&mut raw, (Ipv4Addr::LOCALHOST, 12345).into())
    }

    /// Start a background server on a free localhost port that accepts one
    /// connection, writes `response` to it and exits; returns its address.
    fn serve_once(response: &'static [u8]) -> SocketAddr {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            let (mut conn, _) = listener.accept().unwrap();
            conn.write_all(response).unwrap();
            conn.flush().unwrap();
        });
        addr
    }

    #[test]
    fn test_read_from_stream() {
        let parsed = parse_request(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n").unwrap();
        assert_eq!(parsed.method, Method::Get);
        assert_eq!(parsed.url.to_string(), "http://localhost/");
        assert_eq!(parsed.version, Version::Http1_1);
        assert_eq!(parsed.headers.get("Host").unwrap(), "localhost");
    }

    #[test]
    fn test_read_from_stream_with_body() {
        let parsed = parse_request(
            b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 13\r\n\r\nHello, world!",
        )
        .unwrap();
        assert_eq!(parsed.method, Method::Post);
        assert_eq!(parsed.url.to_string(), "http://localhost/");
        assert_eq!(parsed.version, Version::Http1_1);
        assert_eq!(parsed.headers.get("Host").unwrap(), "localhost");
        assert_eq!(parsed.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_read_from_stream_with_body_lowercase_headers() {
        // Header names arrive in lowercase but are looked up in canonical case.
        let parsed = parse_request(
            b"POST / HTTP/1.1\r\nhost: localhost\r\ncontent-Length: 13\r\n\r\nHello, world!",
        )
        .unwrap();
        assert_eq!(parsed.method, Method::Post);
        assert_eq!(parsed.url.to_string(), "http://localhost/");
        assert_eq!(parsed.version, Version::Http1_1);
        assert_eq!(parsed.headers.get("Host").unwrap(), "localhost");
        assert_eq!(parsed.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_invalid_request_error() {
        assert!(parse_request(b"INVALID REQUEST").is_err());
    }

    #[test]
    fn test_write_to_stream() {
        let mut written = Vec::new();
        Request::get("http://localhost/")
            .header("Host", "localhost")
            .write_to_stream(&mut written, false);
        assert!(written.starts_with(b"GET / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_write_to_stream_with_body() {
        let mut written = Vec::new();
        Request::post("http://localhost/")
            .header("Host", "localhost")
            .body("Hello, world!")
            .write_to_stream(&mut written, false);
        assert!(written.starts_with(b"POST / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_fetch_http1_0() {
        let server_addr = serve_once(b"HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\ntest");

        let response = Request::get(format!("http://{server_addr}/"))
            .fetch()
            .unwrap();
        assert_eq!(response.status, Status::Ok);
        assert_eq!(response.body, "test".as_bytes());
    }

    #[test]
    fn test_fetch_http1_1() {
        let server_addr =
            serve_once(b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: closed\r\n\r\ntest");

        let response = Request::get(format!("http://{server_addr}/"))
            .fetch()
            .unwrap();
        assert_eq!(response.status, Status::Ok);
        assert_eq!(response.body, "test".as_bytes());
    }
}