1use std::error::Error;
8use std::fmt::{self, Display, Formatter};
9use std::io::{BufRead, BufReader, Read, Write};
10use std::net::TcpStream;
11
12use crate::enums::{Status, Version};
13use crate::header_map::HeaderMap;
14use crate::request::Request;
15use crate::KEEP_ALIVE_TIMEOUT;
16
17#[derive(Default)]
20pub struct Response {
21 pub status: Status,
23 pub headers: HeaderMap,
25 pub body: Vec<u8>,
27 pub(crate) takeover: Option<Box<dyn FnOnce(TcpStream) + Send + 'static>>,
28}
29
30impl Response {
31 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn with_status(status: Status) -> Self {
38 Self {
39 status,
40 ..Default::default()
41 }
42 }
43
44 pub fn status(mut self, status: Status) -> Self {
46 self.status = status;
47 self
48 }
49
50 pub fn with_header(name: impl Into<String>, value: impl Into<String>) -> Self {
52 Self::default().header(name.into(), value.into())
53 }
54
55 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
57 self.headers.insert(name.into(), value.into());
58 self
59 }
60
61 pub fn with_body(body: impl Into<Vec<u8>>) -> Self {
63 Self {
64 body: body.into(),
65 ..Default::default()
66 }
67 }
68
69 pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
71 self.body = body.into();
72 self
73 }
74
75 #[cfg(feature = "json")]
77 pub fn with_json(value: impl serde::Serialize) -> Self {
78 Self::default().json(value)
79 }
80
81 #[cfg(feature = "json")]
83 pub fn json(mut self, value: impl serde::Serialize) -> Self {
84 self.headers
85 .insert("Content-Type".to_string(), "application/json".to_string());
86 self.body = serde_json::to_string(&value)
87 .expect("Can't serialize json")
88 .into();
89 self
90 }
91
92 pub fn with_redirect(location: impl Into<String>) -> Self {
94 Self::default().redirect(location.into())
95 }
96
97 pub fn redirect(mut self, location: impl Into<String>) -> Self {
99 self.status = Status::TemporaryRedirect;
100 self.headers.insert("Location".to_string(), location.into());
101 self
102 }
103
104 pub fn takeover(mut self, f: impl FnOnce(TcpStream) + Send + 'static) -> Self {
106 self.takeover = Some(Box::new(f));
107 self
108 }
109
110 #[cfg(feature = "json")]
112 pub fn into_json<T: serde::de::DeserializeOwned>(self) -> Result<T, serde_json::Error> {
113 serde_json::from_slice(&self.body)
114 }
115
116 pub fn read_from_stream(stream: &mut dyn Read) -> Result<Self, InvalidResponseError> {
118 let mut reader = BufReader::new(stream);
119
120 let mut res = {
122 let mut line = String::new();
123 reader
124 .read_line(&mut line)
125 .map_err(|_| InvalidResponseError)?;
126 let mut parts = line.splitn(3, ' ');
127 let _http_version = parts.next().ok_or(InvalidResponseError)?;
128 let status_code = parts
129 .next()
130 .ok_or(InvalidResponseError)?
131 .parse::<i32>()
132 .map_err(|_| InvalidResponseError)?;
133 Response::default()
134 .status(Status::try_from(status_code).map_err(|_| InvalidResponseError)?)
135 };
136
137 loop {
139 let mut line = String::new();
140 reader
141 .read_line(&mut line)
142 .map_err(|_| InvalidResponseError)?;
143 if line == "\r\n" {
144 break;
145 }
146 let split = line.find(':').ok_or(InvalidResponseError)?;
147 res.headers.insert(
148 line[0..split].trim().to_string(),
149 line[split + 1..].trim().to_string(),
150 );
151 }
152
153 if let Some(transfer_encoding) = res.headers.get("Transfer-Encoding") {
155 if transfer_encoding == "chunked" {
156 let mut body = Vec::new();
157 loop {
158 let mut size_line = String::new();
160 reader
161 .read_line(&mut size_line)
162 .map_err(|_| InvalidResponseError)?;
163 let size = usize::from_str_radix(size_line.trim(), 16)
164 .map_err(|_| InvalidResponseError)?;
165 if size == 0 {
166 break;
167 }
168
169 let mut chunk = vec![0; size];
171 reader
172 .read_exact(&mut chunk)
173 .map_err(|_| InvalidResponseError)?;
174 body.extend_from_slice(&chunk);
175
176 let mut crlf = [0; 2];
178 reader
179 .read_exact(&mut crlf)
180 .map_err(|_| InvalidResponseError)?;
181 }
182 res.body = body;
183 return Ok(res);
184 }
185 }
186 if let Some(content_length) = res.headers.get("Content-Length") {
187 let content_length = content_length.parse().map_err(|_| InvalidResponseError)?;
188 if content_length > 0 {
189 res.body = vec![0; content_length];
190 reader
191 .read_exact(&mut res.body)
192 .map_err(|_| InvalidResponseError)?;
193 }
194 }
195 Ok(res)
196 }
197
198 pub(crate) fn write_to_stream(
199 &mut self,
200 stream: &mut dyn Write,
201 req: &Request,
202 keep_alive: bool,
203 ) {
204 self.finish_headers(req, keep_alive);
205
206 _ = write!(stream, "{} {}\r\n", req.version, self.status);
207 for (name, value) in self.headers.iter() {
208 _ = write!(stream, "{name}: {value}\r\n");
209 }
210 _ = write!(stream, "\r\n");
211 _ = stream.write_all(&self.body);
212 }
213
214 fn finish_headers(&mut self, req: &Request, keep_alive: bool) {
215 #[cfg(feature = "date")]
216 self.headers
217 .insert("Date".to_string(), chrono::Utc::now().to_rfc2822());
218 self.headers
219 .insert("Content-Length".to_string(), self.body.len().to_string());
220 if req.version == Version::Http1_1 {
221 if keep_alive && req.headers.get("Connection") != Some("close") {
222 if self.headers.get("Connection").is_none() {
223 self.headers
224 .insert("Connection".to_string(), "keep-alive".to_string());
225 self.headers.insert(
226 "Keep-Alive".to_string(),
227 format!("timeout={}", KEEP_ALIVE_TIMEOUT.as_secs()),
228 );
229 }
230 } else {
231 self.headers
232 .insert("Connection".to_string(), "close".to_string());
233 }
234 }
235 }
236}
237
/// Error returned when a byte stream cannot be parsed as an HTTP response.
///
/// Unit struct: it carries no detail about which part of the response was malformed.
/// It is `Copy` and comparable, so callers and tests can match on it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidResponseError;

impl Display for InvalidResponseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid response")
    }
}

impl Error for InvalidResponseError {}
250
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `text` as a complete response, panicking if the parser rejects it.
    fn parse(text: &str) -> Response {
        let mut bytes = text.as_bytes();
        Response::read_from_stream(&mut bytes).unwrap()
    }

    /// Serializes `response` as the reply to a default HTTP/1.1 request with keep-alive on.
    fn render(response: &mut Response) -> String {
        let request = Request {
            version: Version::Http1_1,
            ..Default::default()
        };
        let mut out = Vec::new();
        response.write_to_stream(&mut out, &request, true);
        String::from_utf8(out).unwrap()
    }

    #[test]
    fn test_parse_response() {
        let parsed = parse("HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello, world!");

        assert_eq!(parsed.status, Status::Ok);
        assert_eq!(parsed.headers.get("Content-Length").unwrap(), "13");
        assert_eq!(parsed.body, b"Hello, world!");
    }

    #[test]
    fn test_parse_response_with_headers() {
        let parsed = parse(
            "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nX-Custom-Header: Value\r\n\r\n",
        );

        assert_eq!(parsed.status, Status::NotFound);
        assert_eq!(parsed.headers.get("Content-Length").unwrap(), "0");
        assert_eq!(parsed.headers.get("X-Custom-Header").unwrap(), "Value");
        assert!(parsed.body.is_empty());
    }

    #[test]
    fn test_parse_response_invalid() {
        let mut garbage = "INVALID RESPONSE".as_bytes();
        assert!(Response::read_from_stream(&mut garbage).is_err());
    }

    #[test]
    fn test_parse_response_chunked_encoding() {
        let parsed = parse(concat!(
            "HTTP/1.1 200 OK\r\n",
            "Transfer-Encoding: chunked\r\n",
            "\r\n",
            "4\r\nBast\r\n",
            "4\r\niaan\r\n",
            "0\r\n\r\n",
        ));

        assert_eq!(parsed.status, Status::Ok);
        assert_eq!(parsed.headers.get("Transfer-Encoding").unwrap(), "chunked");
        assert_eq!(parsed.body, b"Bastiaan");
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_parse_response_with_json() {
        let parsed = parse(concat!(
            "HTTP/1.1 200 OK\r\n",
            "Content-Length: 15\r\n",
            "Content-Type: application/json\r\n",
            "\r\n",
            "{\"key\":\"value\"}",
        ));

        assert_eq!(parsed.status, Status::Ok);
        assert_eq!(
            parsed.headers.get("Content-Type").unwrap(),
            "application/json"
        );
        assert_eq!(parsed.body, b"{\"key\":\"value\"}");

        let decoded: serde_json::Value = parsed.into_json().unwrap();
        assert_eq!(decoded["key"], "value");
    }

    #[test]
    fn test_write_response() {
        let mut response = Response::with_status(Status::Ok)
            .header("Content-Length", "13")
            .body("Hello, world!");
        let wire = render(&mut response);

        assert!(wire.contains("HTTP/1.1 200 OK"));
        assert!(wire.contains("Content-Length: 13"));
        assert!(wire.contains("\r\n\r\nHello, world!"));
    }

    #[test]
    fn test_write_response_with_headers() {
        let mut response = Response::with_status(Status::NotFound)
            .header("Content-Length", "0")
            .header("X-Custom-Header", "Value");
        let wire = render(&mut response);

        for expected in [
            "HTTP/1.1 404 Not Found",
            "Content-Length: 0",
            "X-Custom-Header: Value",
            "\r\n\r\n",
        ] {
            assert!(wire.contains(expected));
        }
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_write_response_with_json() {
        let mut response = Response::with_json(serde_json::json!({"key": "value"}));
        let wire = render(&mut response);

        assert!(wire.contains("HTTP/1.1 200 OK"));
        assert!(wire.contains("Content-Type: application/json"));
        assert!(wire.contains("\r\n\r\n{\"key\":\"value\"}"));
    }
}
374}