small_http/
response.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::error::Error;
8use std::fmt::{self, Display, Formatter};
9use std::io::{BufRead, BufReader, Read, Write};
10
11use crate::enums::{Status, Version};
12use crate::request::HeaderMap;
13use crate::serve::KEEP_ALIVE_TIMEOUT;
14use crate::Request;
15
16// MARK: Response
17/// HTTP response
18#[derive(Default)]
19pub struct Response {
20    /// Status
21    pub status: Status,
22    /// Headers
23    pub headers: HeaderMap,
24    /// Body
25    pub body: Vec<u8>,
26}
27
28impl Response {
29    /// Create new response
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Create new response with status
35    pub fn with_status(status: Status) -> Self {
36        Self {
37            status,
38            ..Default::default()
39        }
40    }
41
42    /// Set status
43    pub fn status(mut self, status: Status) -> Self {
44        self.status = status;
45        self
46    }
47
48    /// Create new response with header
49    pub fn with_header(name: impl Into<String>, value: impl Into<String>) -> Self {
50        Self::default().header(name.into(), value.into())
51    }
52
53    /// Set header
54    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
55        self.headers.insert(name.into(), value.into());
56        self
57    }
58
59    /// Create new response with body
60    pub fn with_body(body: impl Into<Vec<u8>>) -> Self {
61        Self {
62            body: body.into(),
63            ..Default::default()
64        }
65    }
66
67    /// Set body
68    pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
69        self.body = body.into();
70        self
71    }
72
73    /// Create new response with json body
74    #[cfg(feature = "json")]
75    pub fn with_json(value: impl serde::Serialize) -> Self {
76        Self::default().json(value)
77    }
78
79    /// Set json body
80    #[cfg(feature = "json")]
81    pub fn json(mut self, value: impl serde::Serialize) -> Self {
82        self.headers
83            .insert("Content-Type".to_string(), "application/json".to_string());
84        self.body = serde_json::to_string(&value)
85            .expect("Can't serialize json")
86            .into();
87        self
88    }
89
90    /// Create new response with redirect header
91    pub fn with_redirect(location: impl Into<String>) -> Self {
92        Self::default().redirect(location.into())
93    }
94
95    /// Set redirect header
96    pub fn redirect(mut self, location: impl Into<String>) -> Self {
97        self.status = Status::TemporaryRedirect;
98        self.headers.insert("Location".to_string(), location.into());
99        self
100    }
101
102    /// Parse json out of body
103    #[cfg(feature = "json")]
104    pub fn into_json<T: serde::de::DeserializeOwned>(self) -> Result<T, serde_json::Error> {
105        serde_json::from_slice(&self.body)
106    }
107
108    pub(crate) fn read_from_stream(stream: &mut dyn Read) -> Result<Self, InvalidResponseError> {
109        let mut reader = BufReader::new(stream);
110
111        // Read first line
112        let mut res = {
113            let mut line = String::new();
114            reader
115                .read_line(&mut line)
116                .map_err(|_| InvalidResponseError)?;
117            let mut parts = line.splitn(3, ' ');
118            let _http_version = parts.next().ok_or(InvalidResponseError)?;
119            let status_code = parts
120                .next()
121                .ok_or(InvalidResponseError)?
122                .parse::<i32>()
123                .map_err(|_| InvalidResponseError)?;
124            Response::default()
125                .status(Status::try_from(status_code).map_err(|_| InvalidResponseError)?)
126        };
127
128        // Read headers
129        loop {
130            let mut line = String::new();
131            reader
132                .read_line(&mut line)
133                .map_err(|_| InvalidResponseError)?;
134            if line == "\r\n" {
135                break;
136            }
137            let split = line.find(':').ok_or(InvalidResponseError)?;
138            res.headers.insert(
139                line[0..split].trim().to_string(),
140                line[split + 1..].trim().to_string(),
141            );
142        }
143
144        // Read body
145        if let Some(transfer_encoding) = res.headers.get("Transfer-Encoding") {
146            if transfer_encoding == "chunked" {
147                let mut body = Vec::new();
148                loop {
149                    // Read chunk size
150                    let mut size_line = String::new();
151                    reader
152                        .read_line(&mut size_line)
153                        .map_err(|_| InvalidResponseError)?;
154                    let size = usize::from_str_radix(size_line.trim(), 16)
155                        .map_err(|_| InvalidResponseError)?;
156                    if size == 0 {
157                        break;
158                    }
159
160                    // Read chunk
161                    let mut chunk = vec![0; size];
162                    reader
163                        .read_exact(&mut chunk)
164                        .map_err(|_| InvalidResponseError)?;
165                    body.extend_from_slice(&chunk);
166
167                    // Read the trailing \r\n after each chunk
168                    let mut crlf = [0; 2];
169                    reader
170                        .read_exact(&mut crlf)
171                        .map_err(|_| InvalidResponseError)?;
172                }
173                res.body = body;
174                return Ok(res);
175            }
176        }
177        if let Some(content_length) = res.headers.get("Content-Length") {
178            let content_length = content_length.parse().map_err(|_| InvalidResponseError)?;
179            if content_length > 0 {
180                res.body = vec![0; content_length];
181                reader
182                    .read_exact(&mut res.body)
183                    .map_err(|_| InvalidResponseError)?;
184            }
185        }
186        Ok(res)
187    }
188
189    pub(crate) fn write_to_stream(mut self, stream: &mut dyn Write, req: &Request) {
190        // Finish headers
191        #[cfg(feature = "date")]
192        self.headers
193            .insert("Date".to_string(), chrono::Utc::now().to_rfc2822());
194        self.headers
195            .insert("Content-Length".to_string(), self.body.len().to_string());
196        if req.version == Version::Http1_1 {
197            if req.headers.get("Connection").map(|v| v.as_str()) != Some("close") {
198                self.headers
199                    .insert("Connection".to_string(), "keep-alive".to_string());
200                self.headers.insert(
201                    "Keep-Alive".to_string(),
202                    format!("timeout={}", KEEP_ALIVE_TIMEOUT.as_secs()),
203                );
204            } else {
205                self.headers
206                    .insert("Connection".to_string(), "close".to_string());
207            }
208        }
209
210        // Write response
211        _ = write!(stream, "{} {}\r\n", req.version, self.status);
212        for (name, value) in self.headers.iter() {
213            _ = write!(stream, "{}: {}\r\n", name, value);
214        }
215        _ = write!(stream, "\r\n");
216        _ = stream.write_all(&self.body);
217    }
218}
219
// MARK: InvalidResponseError
/// Error produced by `Response::read_from_stream` when the stream does not hold a
/// well-formed HTTP response. I/O failures while reading are reported the same way.
#[derive(Debug)]
pub(crate) struct InvalidResponseError;

impl Error for InvalidResponseError {}

impl Display for InvalidResponseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid response")
    }
}
231
// MARK: Tests
#[cfg(test)]
mod test {
    use super::*;

    /// Run `Response::read_from_stream` over the bytes of `raw`.
    fn parse(raw: &str) -> Result<Response, InvalidResponseError> {
        let mut stream = raw.as_bytes();
        Response::read_from_stream(&mut stream)
    }

    /// Serialize `response` as the reply to a default HTTP/1.1 request.
    fn render(response: Response) -> String {
        let request = Request {
            version: Version::Http1_1,
            ..Default::default()
        };
        let mut sink = Vec::new();
        response.write_to_stream(&mut sink, &request);
        String::from_utf8(sink).unwrap()
    }

    #[test]
    fn test_parse_response() {
        let response =
            parse("HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!").unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(response.headers.get("Content-Length").unwrap(), "13");
        assert_eq!(response.body, b"Hello, world!");
    }

    #[test]
    fn test_parse_response_with_headers() {
        let response = parse(
            "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nX-Custom-Header: Value\r\n\r\n",
        )
        .unwrap();

        assert_eq!(response.status, Status::NotFound);
        assert_eq!(response.headers.get("Content-Length").unwrap(), "0");
        assert_eq!(response.headers.get("X-Custom-Header").unwrap(), "Value");
        assert!(response.body.is_empty());
    }

    #[test]
    fn test_parse_response_invalid() {
        assert!(parse("INVALID RESPONSE").is_err());
    }

    #[test]
    fn test_parse_response_chunked_encoding() {
        let response = parse(
            "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nBast\r\n4\r\niaan\r\n0\r\n\r\n",
        )
        .unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(
            response.headers.get("Transfer-Encoding").unwrap(),
            "chunked"
        );
        assert_eq!(response.body, b"Bastiaan");
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_parse_response_with_json() {
        let response = parse(
            "HTTP/1.1 200 OK\r\nContent-Length: 15\r\nContent-Type: application/json\r\n\r\n{\"key\":\"value\"}",
        )
        .unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(
            response.headers.get("Content-Type").unwrap(),
            "application/json"
        );
        assert_eq!(response.body, b"{\"key\":\"value\"}");

        let json_value: serde_json::Value = response.into_json().unwrap();
        assert_eq!(json_value["key"], "value");
    }

    #[test]
    fn test_write_response() {
        let text = render(
            Response::with_status(Status::Ok)
                .header("Content-Length", "13")
                .body("Hello, world!"),
        );

        assert!(text.contains("HTTP/1.1 200 OK"));
        assert!(text.contains("Content-Length: 13"));
        assert!(text.contains("\r\n\r\nHello, world!"));
    }

    #[test]
    fn test_write_response_with_headers() {
        let text = render(
            Response::with_status(Status::NotFound)
                .header("Content-Length", "0")
                .header("X-Custom-Header", "Value"),
        );

        assert!(text.contains("HTTP/1.1 404 Not Found"));
        assert!(text.contains("Content-Length: 0"));
        assert!(text.contains("X-Custom-Header: Value"));
        assert!(text.contains("\r\n\r\n"));
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_write_response_with_json() {
        let text = render(Response::with_json(serde_json::json!({"key": "value"})));

        assert!(text.contains("HTTP/1.1 200 OK"));
        assert!(text.contains("Content-Type: application/json"));
        assert!(text.contains("\r\n\r\n{\"key\":\"value\"}"));
    }
}