small_http/
request.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::collections::HashMap;
8use std::error::Error;
9use std::fmt::{self, Display, Formatter};
10use std::io::{BufRead, BufReader, Read, Write};
11use std::net::{Ipv4Addr, SocketAddr, TcpStream};
12use std::str::{self, FromStr};
13
14use url::Url;
15
16use crate::enums::{Method, Version};
17use crate::Response;
18
// MARK: Headers
/// An ordered list of HTTP header fields.
///
/// Fields are kept in insertion order and duplicates are allowed, so a field
/// that occurs several times is preserved as-is.
#[derive(Default, Clone)]
pub struct HeaderMap(Vec<(String, String)>);

impl HeaderMap {
    /// Creates an empty header map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the value of the first header named `name`, or `None`.
    ///
    /// Field names are compared ASCII case-insensitively, as HTTP requires
    /// (RFC 9110 §5.1), so `get("content-length")` finds `Content-Length`.
    pub fn get(&self, name: &str) -> Option<&String> {
        self.0
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v)
    }

    /// Iterates over all `(name, value)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
        self.0.iter().map(|(n, v)| (n, v))
    }

    /// Appends a header field; existing fields with the same name are kept.
    pub fn insert(&mut self, name: String, value: String) {
        self.0.push((name, value));
    }
}
44
45/// HTTP request
46#[derive(Clone)]
47pub struct Request {
48    /// HTTP version
49    pub(crate) version: Version,
50    /// URL
51    pub url: Url,
52    /// Method
53    pub method: Method,
54    /// Headers
55    pub headers: HeaderMap,
56    /// Parameters
57    pub params: HashMap<String, String>,
58    /// Body
59    pub body: Option<Vec<u8>>,
60    /// Client address
61    pub client_addr: SocketAddr,
62}
63
64impl Default for Request {
65    fn default() -> Self {
66        Self {
67            version: Version::Http1_1,
68            url: Url::from_str("http://localhost").expect("Should parse"),
69            method: Method::Get,
70            headers: HeaderMap::new(),
71            params: HashMap::new(),
72            body: None,
73            client_addr: (Ipv4Addr::LOCALHOST, 0).into(),
74        }
75    }
76}
77
78impl Request {
79    /// Create new request
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Create new request with URL
85    pub fn with_url(url: impl AsRef<str>) -> Self {
86        Self {
87            url: url.as_ref().parse().expect("Invalid url"),
88            ..Self::default()
89        }
90    }
91
92    /// Set URL
93    pub fn url(mut self, url: Url) -> Self {
94        self.url = url;
95        self
96    }
97
98    /// Set method
99    pub fn method(mut self, method: Method) -> Self {
100        self.method = method;
101        self
102    }
103
104    /// Set header
105    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
106        self.headers.insert(name.into(), value.into());
107        self
108    }
109
110    /// Set body
111    pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
112        self.body = Some(body.into());
113        self
114    }
115
116    pub(crate) fn read_from_stream(
117        stream: &mut dyn Read,
118        client_addr: SocketAddr,
119    ) -> Result<Request, InvalidRequestError> {
120        let mut reader = BufReader::new(stream);
121
122        // Read first line
123        let (method, path, version) = {
124            let mut line = String::new();
125            reader
126                .read_line(&mut line)
127                .map_err(|_| InvalidRequestError)?;
128            let mut parts = line.split(' ');
129            (
130                parts
131                    .next()
132                    .ok_or(InvalidRequestError)?
133                    .trim()
134                    .parse()
135                    .map_err(|_| InvalidRequestError)?,
136                parts.next().ok_or(InvalidRequestError)?.trim().to_string(),
137                parts
138                    .next()
139                    .ok_or(InvalidRequestError)?
140                    .trim()
141                    .to_string()
142                    .parse()
143                    .map_err(|_| InvalidRequestError)?,
144            )
145        };
146
147        // Read headers
148        let mut headers = HeaderMap::new();
149        loop {
150            let mut line = String::new();
151            reader
152                .read_line(&mut line)
153                .map_err(|_| InvalidRequestError)?;
154            if line == "\r\n" {
155                break;
156            }
157            let split = line.find(':').ok_or(InvalidRequestError)?;
158            headers.insert(
159                line[0..split].trim().to_string(),
160                line[split + 1..].trim().to_string(),
161            );
162        }
163
164        // Read body
165        let mut body = None;
166        if let Some(content_length) = headers.get("Content-Length") {
167            let content_length = content_length.parse().map_err(|_| InvalidRequestError)?;
168            if content_length > 0 {
169                let mut buffer = vec![0; content_length];
170                reader.read(&mut buffer).map_err(|_| InvalidRequestError)?;
171                body = Some(buffer);
172            }
173        }
174
175        // Parse URL
176        let url = Url::from_str(&if version == Version::Http1_1 {
177            format!(
178                "http://{}{}",
179                headers.get("Host").ok_or(InvalidRequestError)?,
180                path
181            )
182        } else {
183            format!("http://localhost{}", path)
184        })
185        .map_err(|_| InvalidRequestError)?;
186
187        Ok(Request {
188            version,
189            url,
190            method,
191            headers,
192            params: HashMap::new(),
193            body,
194            client_addr,
195        })
196    }
197
198    pub(crate) fn write_to_stream(mut self, stream: &mut dyn Write) {
199        // Finish headers
200        let host = self.url.host().expect("No host in URL");
201        self.headers.insert(
202            "Host".to_string(),
203            if let Some(port) = self.url.port() {
204                format!("{}:{}", &host, port)
205            } else {
206                host.to_string()
207            },
208        );
209        self.headers.insert(
210            "Content-Length".to_string(),
211            if let Some(body) = &self.body {
212                body.len()
213            } else {
214                0
215            }
216            .to_string(),
217        );
218        if self.version == Version::Http1_1 {
219            self.headers
220                .insert("Connection".to_string(), "close".to_string());
221        }
222
223        // Write request
224        let path = self.url.path();
225        let path = if let Some(query) = self.url.query() {
226            format!("{}?{}", &path, query)
227        } else {
228            path.to_string()
229        };
230        _ = write!(stream, "{} {} HTTP/1.1\r\n", self.method, path);
231        for (name, value) in self.headers.iter() {
232            _ = write!(stream, "{}: {}\r\n", name, value);
233        }
234        _ = write!(stream, "\r\n");
235        if let Some(body) = &self.body {
236            _ = stream.write_all(body);
237        }
238    }
239
240    /// Fetch request with http client
241    pub fn fetch(self) -> Result<Response, FetchError> {
242        let mut stream = TcpStream::connect(format!(
243            "{}:{}",
244            self.url.host().expect("No host in URL"),
245            self.url.port().unwrap_or(80)
246        ))
247        .map_err(|_| FetchError)?;
248        self.write_to_stream(&mut stream);
249        Response::read_from_stream(&mut stream).map_err(|_| FetchError)
250    }
251}
252
253// MARK: InvalidRequestError
/// Error returned when incoming bytes cannot be parsed as an HTTP request
/// (see `Request::read_from_stream`).
#[derive(Debug)]
pub(crate) struct InvalidRequestError;

impl Error for InvalidRequestError {}

impl Display for InvalidRequestError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid request")
    }
}
264
265// MARK: FetchError
/// Error returned by [`Request::fetch`] when the request could not be
/// completed.
#[derive(Debug)]
pub struct FetchError;

impl Error for FetchError {}

impl Display for FetchError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Fetch error")
    }
}
276
277// MARK: Tests
#[cfg(test)]
mod test {
    use std::io::Write;
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread;

    use super::*;
    use crate::enums::Status;

    /// Serves exactly one connection from `listener` on a background thread:
    /// parses the incoming request, answers with `raw_response` and returns the
    /// parsed request through the join handle.
    ///
    /// The request is read before replying on purpose: closing a socket that
    /// still holds unread data makes the TCP stack send a reset, which can
    /// discard the response before the client reads it and make these tests
    /// flaky.
    fn serve_once(
        listener: TcpListener,
        raw_response: &'static [u8],
    ) -> thread::JoinHandle<Request> {
        thread::spawn(move || {
            let (mut stream, peer) = listener.accept().unwrap();
            let request = Request::read_from_stream(&mut stream, peer).unwrap();
            stream.write_all(raw_response).unwrap();
            request
        })
    }

    #[test]
    fn test_read_from_stream() {
        let raw_request = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let mut stream = &raw_request[..];
        let request =
            Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into()).unwrap();
        assert_eq!(request.method, Method::Get);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
    }

    #[test]
    fn test_read_from_stream_with_body() {
        let raw_request =
            b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 13\r\n\r\nHello, world!";
        let mut stream = &raw_request[..];
        let request =
            Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into()).unwrap();
        assert_eq!(request.method, Method::Post);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
        assert_eq!(request.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_invalid_request_error() {
        let raw_request = b"INVALID REQUEST";
        let mut stream = &raw_request[..];
        let result = Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into());
        assert!(result.is_err());
    }

    #[test]
    fn test_write_to_stream() {
        let request = Request::new()
            .method(Method::Get)
            .url(Url::from_str("http://localhost/").unwrap())
            .header("Host", "localhost");

        let mut buffer = Vec::new();
        request.write_to_stream(&mut buffer);
        assert!(buffer.starts_with(b"GET / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_write_to_stream_with_body() {
        let request = Request::new()
            .method(Method::Post)
            .url(Url::from_str("http://localhost/").unwrap())
            .header("Host", "localhost")
            .body("Hello, world!");

        let mut buffer = Vec::new();
        request.write_to_stream(&mut buffer);
        assert!(buffer.starts_with(b"POST / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_fetch_http1_0() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = serve_once(
            listener,
            b"HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\ntest",
        );

        let res = Request::with_url(format!("http://{}/", server_addr))
            .fetch()
            .unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());

        let received = server.join().unwrap();
        assert_eq!(received.method, Method::Get);
        assert_eq!(received.url.path(), "/");
    }

    #[test]
    fn test_fetch_http1_1() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = serve_once(
            listener,
            b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: close\r\n\r\ntest",
        );

        let res = Request::with_url(format!("http://{}/", server_addr))
            .fetch()
            .unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());

        let received = server.join().unwrap();
        assert_eq!(received.method, Method::Get);
        assert_eq!(received.url.path(), "/");
    }
}