small_http/
response.rs

1/*
2 * Copyright (c) 2023-2025 Bastiaan van der Plaat
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7use std::error::Error;
8use std::fmt::{self, Display, Formatter};
9use std::io::{BufRead, BufReader, Read, Write};
10use std::net::TcpStream;
11
12use crate::enums::{Status, Version};
13use crate::header_map::HeaderMap;
14use crate::request::Request;
15use crate::KEEP_ALIVE_TIMEOUT;
16
17// MARK: Response
18/// HTTP response
19#[derive(Default)]
20pub struct Response {
21    /// Status
22    pub status: Status,
23    /// Headers
24    pub headers: HeaderMap,
25    /// Body
26    pub body: Vec<u8>,
27    pub(crate) takeover: Option<Box<dyn FnOnce(TcpStream) + Send + 'static>>,
28}
29
30impl Response {
31    /// Create new response
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Create new response with status
37    pub fn with_status(status: Status) -> Self {
38        Self {
39            status,
40            ..Default::default()
41        }
42    }
43
44    /// Set status
45    pub fn status(mut self, status: Status) -> Self {
46        self.status = status;
47        self
48    }
49
50    /// Create new response with header
51    pub fn with_header(name: impl Into<String>, value: impl Into<String>) -> Self {
52        Self::default().header(name.into(), value.into())
53    }
54
55    /// Set header
56    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
57        self.headers.insert(name.into(), value.into());
58        self
59    }
60
61    /// Create new response with body
62    pub fn with_body(body: impl Into<Vec<u8>>) -> Self {
63        Self {
64            body: body.into(),
65            ..Default::default()
66        }
67    }
68
69    /// Set body
70    pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
71        self.body = body.into();
72        self
73    }
74
75    /// Create new response with json body
76    #[cfg(feature = "json")]
77    pub fn with_json(value: impl serde::Serialize) -> Self {
78        Self::default().json(value)
79    }
80
81    /// Set json body
82    #[cfg(feature = "json")]
83    pub fn json(mut self, value: impl serde::Serialize) -> Self {
84        self.headers
85            .insert("Content-Type".to_string(), "application/json".to_string());
86        self.body = serde_json::to_string(&value)
87            .expect("Can't serialize json")
88            .into();
89        self
90    }
91
92    /// Create new response with redirect header
93    pub fn with_redirect(location: impl Into<String>) -> Self {
94        Self::default().redirect(location.into())
95    }
96
97    /// Set redirect header
98    pub fn redirect(mut self, location: impl Into<String>) -> Self {
99        self.status = Status::TemporaryRedirect;
100        self.headers.insert("Location".to_string(), location.into());
101        self
102    }
103
104    /// Set takeover function
105    pub fn takeover(mut self, f: impl FnOnce(TcpStream) + Send + 'static) -> Self {
106        self.takeover = Some(Box::new(f));
107        self
108    }
109
110    /// Parse json out of body
111    #[cfg(feature = "json")]
112    pub fn into_json<T: serde::de::DeserializeOwned>(self) -> Result<T, serde_json::Error> {
113        serde_json::from_slice(&self.body)
114    }
115
116    /// Read response from stream
117    pub fn read_from_stream(stream: &mut dyn Read) -> Result<Self, InvalidResponseError> {
118        let mut reader = BufReader::new(stream);
119
120        // Read first line
121        let mut res = {
122            let mut line = String::new();
123            reader
124                .read_line(&mut line)
125                .map_err(|_| InvalidResponseError)?;
126            let mut parts = line.splitn(3, ' ');
127            let _http_version = parts.next().ok_or(InvalidResponseError)?;
128            let status_code = parts
129                .next()
130                .ok_or(InvalidResponseError)?
131                .parse::<i32>()
132                .map_err(|_| InvalidResponseError)?;
133            Response::default()
134                .status(Status::try_from(status_code).map_err(|_| InvalidResponseError)?)
135        };
136
137        // Read headers
138        loop {
139            let mut line = String::new();
140            reader
141                .read_line(&mut line)
142                .map_err(|_| InvalidResponseError)?;
143            if line == "\r\n" {
144                break;
145            }
146            let split = line.find(':').ok_or(InvalidResponseError)?;
147            res.headers.insert(
148                line[0..split].trim().to_string(),
149                line[split + 1..].trim().to_string(),
150            );
151        }
152
153        // Read body
154        if let Some(transfer_encoding) = res.headers.get("Transfer-Encoding") {
155            if transfer_encoding == "chunked" {
156                let mut body = Vec::new();
157                loop {
158                    // Read chunk size
159                    let mut size_line = String::new();
160                    reader
161                        .read_line(&mut size_line)
162                        .map_err(|_| InvalidResponseError)?;
163                    let size = usize::from_str_radix(size_line.trim(), 16)
164                        .map_err(|_| InvalidResponseError)?;
165                    if size == 0 {
166                        break;
167                    }
168
169                    // Read chunk
170                    let mut chunk = vec![0; size];
171                    reader
172                        .read_exact(&mut chunk)
173                        .map_err(|_| InvalidResponseError)?;
174                    body.extend_from_slice(&chunk);
175
176                    // Read the trailing \r\n after each chunk
177                    let mut crlf = [0; 2];
178                    reader
179                        .read_exact(&mut crlf)
180                        .map_err(|_| InvalidResponseError)?;
181                }
182                res.body = body;
183                return Ok(res);
184            }
185        }
186        if let Some(content_length) = res.headers.get("Content-Length") {
187            let content_length = content_length.parse().map_err(|_| InvalidResponseError)?;
188            if content_length > 0 {
189                res.body = vec![0; content_length];
190                reader
191                    .read_exact(&mut res.body)
192                    .map_err(|_| InvalidResponseError)?;
193            }
194        }
195        Ok(res)
196    }
197
198    pub(crate) fn write_to_stream(
199        &mut self,
200        stream: &mut dyn Write,
201        req: &Request,
202        keep_alive: bool,
203    ) {
204        self.finish_headers(req, keep_alive);
205
206        _ = write!(stream, "{} {}\r\n", req.version, self.status);
207        for (name, value) in self.headers.iter() {
208            _ = write!(stream, "{name}: {value}\r\n");
209        }
210        _ = write!(stream, "\r\n");
211        _ = stream.write_all(&self.body);
212    }
213
214    fn finish_headers(&mut self, req: &Request, keep_alive: bool) {
215        #[cfg(feature = "date")]
216        self.headers
217            .insert("Date".to_string(), chrono::Utc::now().to_rfc2822());
218        self.headers
219            .insert("Content-Length".to_string(), self.body.len().to_string());
220        if req.version == Version::Http1_1 {
221            if keep_alive && req.headers.get("Connection") != Some("close") {
222                if self.headers.get("Connection").is_none() {
223                    self.headers
224                        .insert("Connection".to_string(), "keep-alive".to_string());
225                    self.headers.insert(
226                        "Keep-Alive".to_string(),
227                        format!("timeout={}", KEEP_ALIVE_TIMEOUT.as_secs()),
228                    );
229                }
230            } else {
231                self.headers
232                    .insert("Connection".to_string(), "close".to_string());
233            }
234        }
235    }
236}
237
238// MARK: InvalidResponseError
/// Invalid response error
///
/// Returned by `Response::read_from_stream` when the stream can't be read or
/// does not contain a well-formed HTTP response.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidResponseError;

impl Display for InvalidResponseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid response")
    }
}

impl Error for InvalidResponseError {}
250
251// MARK: Tests
#[cfg(test)]
mod test {
    use super::*;

    /// Parse raw response text with `Response::read_from_stream`.
    fn parse(raw: &str) -> Result<Response, InvalidResponseError> {
        let mut bytes = raw.as_bytes();
        Response::read_from_stream(&mut bytes)
    }

    /// Serialize `response` as the reply to a default HTTP/1.1 request with
    /// keep-alive allowed, and return the written bytes as text.
    fn render(mut response: Response) -> String {
        let request = Request {
            version: Version::Http1_1,
            ..Default::default()
        };
        let mut written = Vec::new();
        response.write_to_stream(&mut written, &request, true);
        String::from_utf8(written).unwrap()
    }

    #[test]
    fn test_parse_response() {
        let response =
            parse("HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!").unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(response.headers.get("Content-Length").unwrap(), "13");
        assert_eq!(response.body, b"Hello, world!");
    }

    #[test]
    fn test_parse_response_with_headers() {
        let response = parse(
            "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nX-Custom-Header: Value\r\n\r\n",
        )
        .unwrap();

        assert_eq!(response.status, Status::NotFound);
        assert_eq!(response.headers.get("Content-Length").unwrap(), "0");
        assert_eq!(response.headers.get("X-Custom-Header").unwrap(), "Value");
        assert!(response.body.is_empty());
    }

    #[test]
    fn test_parse_response_invalid() {
        assert!(parse("INVALID RESPONSE").is_err());
    }

    #[test]
    fn test_parse_response_chunked_encoding() {
        let raw = "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nBast\r\n4\r\niaan\r\n0\r\n\r\n";
        let response = parse(raw).unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(
            response.headers.get("Transfer-Encoding").unwrap(),
            "chunked"
        );
        assert_eq!(response.body, b"Bastiaan");
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_parse_response_with_json() {
        let raw = "HTTP/1.1 200 OK\r\nContent-Length: 15\r\nContent-Type: application/json\r\n\r\n{\"key\":\"value\"}";
        let response = parse(raw).unwrap();

        assert_eq!(response.status, Status::Ok);
        assert_eq!(
            response.headers.get("Content-Type").unwrap(),
            "application/json"
        );
        assert_eq!(response.body, b"{\"key\":\"value\"}");

        let json_value: serde_json::Value = response.into_json().unwrap();
        assert_eq!(json_value["key"], "value");
    }

    #[test]
    fn test_write_response() {
        let text = render(
            Response::with_status(Status::Ok)
                .header("Content-Length", "13")
                .body("Hello, world!"),
        );

        assert!(text.contains("HTTP/1.1 200 OK"));
        assert!(text.contains("Content-Length: 13"));
        assert!(text.contains("\r\n\r\nHello, world!"));
    }

    #[test]
    fn test_write_response_with_headers() {
        let text = render(
            Response::with_status(Status::NotFound)
                .header("Content-Length", "0")
                .header("X-Custom-Header", "Value"),
        );

        assert!(text.contains("HTTP/1.1 404 Not Found"));
        assert!(text.contains("Content-Length: 0"));
        assert!(text.contains("X-Custom-Header: Value"));
        assert!(text.contains("\r\n\r\n"));
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_write_response_with_json() {
        let text = render(Response::with_json(serde_json::json!({"key": "value"})));

        assert!(text.contains("HTTP/1.1 200 OK"));
        assert!(text.contains("Content-Type: application/json"));
        assert!(text.contains("\r\n\r\n{\"key\":\"value\"}"));
    }
}