1use std::collections::HashMap;
8use std::error::Error;
9use std::fmt::{self, Display, Formatter};
10use std::io::{BufRead, BufReader, Read, Write};
11use std::net::{Ipv4Addr, SocketAddr, TcpStream};
12use std::str::{self, FromStr};
13
14use url::Url;
15
16use crate::enums::{Method, Version};
17use crate::header_map::HeaderMap;
18use crate::response::Response;
19use crate::KEEP_ALIVE_TIMEOUT;
20
21#[derive(Clone)]
24pub struct Request {
25 pub(crate) version: Version,
27 pub url: Url,
29 pub method: Method,
31 pub headers: HeaderMap,
33 pub params: HashMap<String, String>,
35 pub body: Option<Vec<u8>>,
37 pub client_addr: SocketAddr,
39}
40
41impl Default for Request {
42 fn default() -> Self {
43 Self {
44 version: Version::Http1_1,
45 url: Url::from_str("http://localhost").expect("Should parse"),
46 method: Method::Get,
47 headers: HeaderMap::new(),
48 params: HashMap::new(),
49 body: None,
50 client_addr: (Ipv4Addr::LOCALHOST, 0).into(),
51 }
52 }
53}
54
55impl Request {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn with_method(method: Method) -> Self {
63 Self {
64 method,
65 ..Self::default()
66 }
67 }
68
69 pub fn with_url(url: impl AsRef<str>) -> Self {
71 Self {
72 url: url.as_ref().parse().expect("Invalid url"),
73 ..Self::default()
74 }
75 }
76
77 fn with_method_and_url(method: Method, url: impl AsRef<str>) -> Self {
79 Self {
80 method,
81 url: url.as_ref().parse().expect("Invalid url"),
82 ..Self::default()
83 }
84 }
85
86 pub fn get(url: impl AsRef<str>) -> Self {
88 Self::with_method_and_url(Method::Get, url)
89 }
90
91 pub fn head(url: impl AsRef<str>) -> Self {
93 Self::with_method_and_url(Method::Head, url)
94 }
95
96 pub fn post(url: impl AsRef<str>) -> Self {
98 Self::with_method_and_url(Method::Post, url)
99 }
100
101 pub fn put(url: impl AsRef<str>) -> Self {
103 Self::with_method_and_url(Method::Put, url)
104 }
105
106 pub fn delete(url: impl AsRef<str>) -> Self {
108 Self::with_method_and_url(Method::Delete, url)
109 }
110
111 pub fn connect(url: impl AsRef<str>) -> Self {
113 Self::with_method_and_url(Method::Connect, url)
114 }
115
116 pub fn options(url: impl AsRef<str>) -> Self {
118 Self::with_method_and_url(Method::Options, url)
119 }
120
121 pub fn trace(url: impl AsRef<str>) -> Self {
123 Self::with_method_and_url(Method::Trace, url)
124 }
125
126 pub fn patch(url: impl AsRef<str>) -> Self {
128 Self::with_method_and_url(Method::Patch, url)
129 }
130
131 pub fn url(mut self, url: Url) -> Self {
133 self.url = url;
134 self
135 }
136
137 pub fn method(mut self, method: Method) -> Self {
139 self.method = method;
140 self
141 }
142
143 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
145 self.headers.insert(name.into(), value.into());
146 self
147 }
148
149 pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
151 self.body = Some(body.into());
152 self
153 }
154
155 pub(crate) fn read_from_stream(
156 stream: &mut dyn Read,
157 client_addr: SocketAddr,
158 ) -> Result<Request, InvalidRequestError> {
159 let mut reader = BufReader::new(stream);
160
161 let (method, path, version) = {
163 let mut line = String::new();
164 reader
165 .read_line(&mut line)
166 .map_err(|_| InvalidRequestError("Can't read first line".to_string()))?;
167 let mut parts = line.split(' ');
168 (
169 parts
170 .next()
171 .ok_or(InvalidRequestError(
172 "Can't read 1st part of first line".to_string(),
173 ))?
174 .trim()
175 .parse()
176 .map_err(|_| InvalidRequestError("Can't parse method".to_string()))?,
177 parts
178 .next()
179 .ok_or(InvalidRequestError(
180 "Can't read 2st part of first line".to_string(),
181 ))?
182 .trim()
183 .to_string(),
184 parts
185 .next()
186 .ok_or(InvalidRequestError(
187 "Can't read 3st part of first line".to_string(),
188 ))?
189 .trim()
190 .to_string()
191 .parse()
192 .map_err(|_| InvalidRequestError("Can't parse HTTP version".to_string()))?,
193 )
194 };
195
196 let mut headers = HeaderMap::new();
198 loop {
199 let mut line = String::new();
200 reader
201 .read_line(&mut line)
202 .map_err(|_| InvalidRequestError("Can't read header line".to_string()))?;
203 if line == "\r\n" {
204 break;
205 }
206 let split = line
207 .find(':')
208 .ok_or(InvalidRequestError("Can't parse header line".to_string()))?;
209 headers.insert(
210 line[0..split].trim().to_string(),
211 line[split + 1..].trim().to_string(),
212 );
213 }
214
215 let mut body = None;
217 if let Some(content_length) = headers.get("Content-Length") {
218 let content_length = content_length
219 .parse()
220 .map_err(|_| InvalidRequestError("Can't parse Content-Length".to_string()))?;
221 if content_length > 0 {
222 let mut buffer = vec![0; content_length];
223 reader.read(&mut buffer).map_err(|_| {
224 InvalidRequestError(
225 "Can't read Content-Length amount of bytes from stream".to_string(),
226 )
227 })?;
228 body = Some(buffer);
229 }
230 }
231
232 let url = Url::from_str(&if version == Version::Http1_1 {
234 format!(
235 "http://{}{}",
236 headers.get("Host").ok_or(InvalidRequestError(
237 "HTTP version is 1.1 but Host header is not set".to_string()
238 ))?,
239 path
240 )
241 } else {
242 format!("http://localhost{path}")
243 })
244 .map_err(|_| InvalidRequestError("Can't parse request url".to_string()))?;
245
246 Ok(Request {
247 version,
248 url,
249 method,
250 headers,
251 params: HashMap::new(),
252 body,
253 client_addr,
254 })
255 }
256
257 pub fn write_to_stream(mut self, stream: &mut dyn Write, keep_alive: bool) {
259 let host = self.url.host().expect("No host in URL");
261 self.headers.insert(
262 "Host".to_string(),
263 if let Some(port) = self.url.port() {
264 format!("{}:{}", &host, port)
265 } else {
266 host.to_string()
267 },
268 );
269 self.headers.insert(
270 "Content-Length".to_string(),
271 if let Some(body) = &self.body {
272 body.len()
273 } else {
274 0
275 }
276 .to_string(),
277 );
278 if self.version == Version::Http1_1 {
279 if keep_alive {
280 self.headers
281 .insert("Connection".to_string(), "keep-alive".to_string());
282 self.headers.insert(
283 "Keep-Alive".to_string(),
284 format!("timeout={}", KEEP_ALIVE_TIMEOUT.as_secs()),
285 );
286 } else {
287 self.headers
288 .insert("Connection".to_string(), "close".to_string());
289 }
290 }
291
292 let path = self.url.path();
294 let path = if let Some(query) = self.url.query() {
295 format!("{}?{}", &path, query)
296 } else {
297 path.to_string()
298 };
299 _ = write!(stream, "{} {} HTTP/1.1\r\n", self.method, path);
300 for (name, value) in self.headers.iter() {
301 _ = write!(stream, "{name}: {value}\r\n");
302 }
303 _ = write!(stream, "\r\n");
304 if let Some(body) = &self.body {
305 _ = stream.write_all(body);
306 }
307 }
308
309 pub fn fetch(self) -> Result<Response, FetchError> {
311 let mut stream = TcpStream::connect(format!(
312 "{}:{}",
313 self.url.host().expect("No host in URL"),
314 self.url.port().unwrap_or(80)
315 ))
316 .map_err(|_| FetchError)?;
317 self.write_to_stream(&mut stream, false);
318 Response::read_from_stream(&mut stream).map_err(|_| FetchError)
319 }
320}
321
/// Error produced when an incoming request cannot be read or parsed. The wrapped string
/// describes which step failed.
#[derive(Debug)]
pub(crate) struct InvalidRequestError(String);

impl Display for InvalidRequestError {
    /// Renders as `Invalid request: <reason>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self(reason) = self;
        f.write_str("Invalid request: ")?;
        f.write_str(reason)
    }
}

impl Error for InvalidRequestError {}
333
/// Error returned by `Request::fetch` when the connection cannot be established or the
/// response cannot be read. It carries no detail about the underlying cause.
///
/// Being a field-less marker, it is cheap to copy and can be compared directly
/// (`assert_eq!(err, FetchError)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FetchError;

impl Display for FetchError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Fetch error")
    }
}

impl Error for FetchError {}
345
#[cfg(test)]
mod test {
    use std::io::Write;
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread::{self, JoinHandle};

    use super::*;
    use crate::enums::Status;

    /// Parses `raw` as a request received from the fixed client address `127.0.0.1:12345`.
    fn parse_request(raw: &[u8]) -> Result<Request, InvalidRequestError> {
        let mut stream = raw;
        Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into())
    }

    /// Serializes `request` into memory and returns the bytes as a UTF-8 string.
    fn serialize(request: Request, keep_alive: bool) -> String {
        let mut buffer = Vec::new();
        request.write_to_stream(&mut buffer, keep_alive);
        String::from_utf8(buffer).unwrap()
    }

    /// Accepts one connection on an ephemeral loopback port, reads the client's request in
    /// full, then replies with `response` and closes the connection.
    ///
    /// The request is read before replying because closing a socket that still holds unread
    /// data can make the kernel reset the connection. On some platforms that makes the
    /// client's read fail even though the response was sent, which made these tests flaky.
    fn serve_once(response: &'static [u8]) -> (SocketAddr, JoinHandle<()>) {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = thread::spawn(move || {
            let (mut stream, peer) = listener.accept().unwrap();
            let request = Request::read_from_stream(&mut stream, peer).unwrap();
            assert_eq!(request.method, Method::Get);
            stream.write_all(response).unwrap();
            stream.flush().unwrap();
        });
        (server_addr, server)
    }

    #[test]
    fn test_read_from_stream() {
        let request = parse_request(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n").unwrap();
        assert_eq!(request.method, Method::Get);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
        assert_eq!(
            request.client_addr,
            SocketAddr::from((Ipv4Addr::LOCALHOST, 12345))
        );
        assert!(request.body.is_none());
    }

    #[test]
    fn test_read_from_stream_with_body() {
        let request = parse_request(
            b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 13\r\n\r\nHello, world!",
        )
        .unwrap();
        assert_eq!(request.method, Method::Post);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
        assert_eq!(request.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_read_from_stream_with_body_lowercase_headers() {
        let request = parse_request(
            b"POST / HTTP/1.1\r\nhost: localhost\r\ncontent-Length: 13\r\n\r\nHello, world!",
        )
        .unwrap();
        assert_eq!(request.method, Method::Post);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
        assert_eq!(request.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_read_from_stream_with_query() {
        let request =
            parse_request(b"GET /search?q=rust HTTP/1.1\r\nHost: example.com\r\n\r\n").unwrap();
        assert_eq!(request.url.host_str(), Some("example.com"));
        assert_eq!(request.url.path(), "/search");
        assert_eq!(request.url.query(), Some("q=rust"));
    }

    #[test]
    fn test_read_from_stream_requires_host_for_http1_1() {
        assert!(parse_request(b"GET / HTTP/1.1\r\n\r\n").is_err());
    }

    #[test]
    fn test_invalid_request_error() {
        assert!(parse_request(b"INVALID REQUEST").is_err());
    }

    #[test]
    fn test_write_to_stream() {
        let request = Request::get("http://localhost/").header("Host", "localhost");
        let text = serialize(request, false);
        assert!(text.starts_with("GET / HTTP/1.1\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
        // Header-name casing is up to HeaderMap, so compare case-insensitively.
        let lower = text.to_ascii_lowercase();
        assert!(lower.contains("\r\ncontent-length: 0\r\n"));
        assert!(lower.contains("\r\nconnection: close\r\n"));
    }

    #[test]
    fn test_write_to_stream_keep_alive() {
        let lower = serialize(Request::get("http://localhost/"), true).to_ascii_lowercase();
        assert!(lower.contains("\r\nconnection: keep-alive\r\n"));
        assert!(lower.contains("\r\nkeep-alive: timeout="));
    }

    #[test]
    fn test_write_to_stream_with_query() {
        let text = serialize(Request::get("http://localhost/search?q=rust"), false);
        assert!(text.starts_with("GET /search?q=rust HTTP/1.1\r\n"));
    }

    #[test]
    fn test_write_to_stream_with_body() {
        let request = Request::post("http://localhost/")
            .header("Host", "localhost")
            .body("Hello, world!");
        let text = serialize(request, false);
        assert!(text.starts_with("POST / HTTP/1.1\r\n"));
        assert!(text.ends_with("\r\n\r\nHello, world!"));
        assert!(text
            .to_ascii_lowercase()
            .contains("\r\ncontent-length: 13\r\n"));
    }

    #[test]
    fn test_fetch_http1_0() {
        let (server_addr, server) =
            serve_once(b"HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\ntest");

        let res = Request::get(format!("http://{server_addr}/"))
            .fetch()
            .unwrap();
        server.join().unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());
    }

    #[test]
    fn test_fetch_http1_1() {
        let (server_addr, server) = serve_once(
            b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: close\r\n\r\ntest",
        );

        let res = Request::get(format!("http://{server_addr}/"))
            .fetch()
            .unwrap();
        server.join().unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());
    }
}