1use std::collections::HashMap;
8use std::error::Error;
9use std::fmt::{self, Display, Formatter};
10use std::io::{BufRead, BufReader, Read, Write};
11use std::net::{Ipv4Addr, SocketAddr, TcpStream};
12use std::str::{self, FromStr};
13
14use url::Url;
15
16use crate::enums::{Method, Version};
17use crate::Response;
18
/// An ordered collection of HTTP header fields.
///
/// Fields are kept in insertion order and duplicates are preserved: inserting
/// a name that already exists adds a second field rather than replacing it.
#[derive(Default, Clone)]
pub struct HeaderMap(Vec<(String, String)>);

impl HeaderMap {
    /// Creates an empty header map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the value of the first field whose name matches `name`.
    ///
    /// Names are compared ASCII case-insensitively, as HTTP field names are
    /// case-insensitive (RFC 9110 §5.1). For example, `get("content-length")`
    /// finds a field inserted as `Content-Length`.
    pub fn get(&self, name: &str) -> Option<&String> {
        self.0
            .iter()
            .find(|(field, _)| field.eq_ignore_ascii_case(name))
            .map(|(_, value)| value)
    }

    /// Iterates over `(name, value)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
        self.0.iter().map(|(name, value)| (name, value))
    }

    /// Appends a field; existing fields with the same name are kept.
    pub fn insert(&mut self, name: String, value: String) {
        self.0.push((name, value));
    }
}
44
45#[derive(Clone)]
47pub struct Request {
48 pub(crate) version: Version,
50 pub url: Url,
52 pub method: Method,
54 pub headers: HeaderMap,
56 pub params: HashMap<String, String>,
58 pub body: Option<Vec<u8>>,
60 pub client_addr: SocketAddr,
62}
63
64impl Default for Request {
65 fn default() -> Self {
66 Self {
67 version: Version::Http1_1,
68 url: Url::from_str("http://localhost").expect("Should parse"),
69 method: Method::Get,
70 headers: HeaderMap::new(),
71 params: HashMap::new(),
72 body: None,
73 client_addr: (Ipv4Addr::LOCALHOST, 0).into(),
74 }
75 }
76}
77
78impl Request {
79 pub fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn with_url(url: impl AsRef<str>) -> Self {
86 Self {
87 url: url.as_ref().parse().expect("Invalid url"),
88 ..Self::default()
89 }
90 }
91
92 pub fn url(mut self, url: Url) -> Self {
94 self.url = url;
95 self
96 }
97
98 pub fn method(mut self, method: Method) -> Self {
100 self.method = method;
101 self
102 }
103
104 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
106 self.headers.insert(name.into(), value.into());
107 self
108 }
109
110 pub fn body(mut self, body: impl Into<Vec<u8>>) -> Self {
112 self.body = Some(body.into());
113 self
114 }
115
116 pub(crate) fn read_from_stream(
117 stream: &mut dyn Read,
118 client_addr: SocketAddr,
119 ) -> Result<Request, InvalidRequestError> {
120 let mut reader = BufReader::new(stream);
121
122 let (method, path, version) = {
124 let mut line = String::new();
125 reader
126 .read_line(&mut line)
127 .map_err(|_| InvalidRequestError)?;
128 let mut parts = line.split(' ');
129 (
130 parts
131 .next()
132 .ok_or(InvalidRequestError)?
133 .trim()
134 .parse()
135 .map_err(|_| InvalidRequestError)?,
136 parts.next().ok_or(InvalidRequestError)?.trim().to_string(),
137 parts
138 .next()
139 .ok_or(InvalidRequestError)?
140 .trim()
141 .to_string()
142 .parse()
143 .map_err(|_| InvalidRequestError)?,
144 )
145 };
146
147 let mut headers = HeaderMap::new();
149 loop {
150 let mut line = String::new();
151 reader
152 .read_line(&mut line)
153 .map_err(|_| InvalidRequestError)?;
154 if line == "\r\n" {
155 break;
156 }
157 let split = line.find(':').ok_or(InvalidRequestError)?;
158 headers.insert(
159 line[0..split].trim().to_string(),
160 line[split + 1..].trim().to_string(),
161 );
162 }
163
164 let mut body = None;
166 if let Some(content_length) = headers.get("Content-Length") {
167 let content_length = content_length.parse().map_err(|_| InvalidRequestError)?;
168 if content_length > 0 {
169 let mut buffer = vec![0; content_length];
170 reader.read(&mut buffer).map_err(|_| InvalidRequestError)?;
171 body = Some(buffer);
172 }
173 }
174
175 let url = Url::from_str(&if version == Version::Http1_1 {
177 format!(
178 "http://{}{}",
179 headers.get("Host").ok_or(InvalidRequestError)?,
180 path
181 )
182 } else {
183 format!("http://localhost{}", path)
184 })
185 .map_err(|_| InvalidRequestError)?;
186
187 Ok(Request {
188 version,
189 url,
190 method,
191 headers,
192 params: HashMap::new(),
193 body,
194 client_addr,
195 })
196 }
197
198 pub(crate) fn write_to_stream(mut self, stream: &mut dyn Write) {
199 let host = self.url.host().expect("No host in URL");
201 self.headers.insert(
202 "Host".to_string(),
203 if let Some(port) = self.url.port() {
204 format!("{}:{}", &host, port)
205 } else {
206 host.to_string()
207 },
208 );
209 self.headers.insert(
210 "Content-Length".to_string(),
211 if let Some(body) = &self.body {
212 body.len()
213 } else {
214 0
215 }
216 .to_string(),
217 );
218 if self.version == Version::Http1_1 {
219 self.headers
220 .insert("Connection".to_string(), "close".to_string());
221 }
222
223 let path = self.url.path();
225 let path = if let Some(query) = self.url.query() {
226 format!("{}?{}", &path, query)
227 } else {
228 path.to_string()
229 };
230 _ = write!(stream, "{} {} HTTP/1.1\r\n", self.method, path);
231 for (name, value) in self.headers.iter() {
232 _ = write!(stream, "{}: {}\r\n", name, value);
233 }
234 _ = write!(stream, "\r\n");
235 if let Some(body) = &self.body {
236 _ = stream.write_all(body);
237 }
238 }
239
240 pub fn fetch(self) -> Result<Response, FetchError> {
242 let mut stream = TcpStream::connect(format!(
243 "{}:{}",
244 self.url.host().expect("No host in URL"),
245 self.url.port().unwrap_or(80)
246 ))
247 .map_err(|_| FetchError)?;
248 self.write_to_stream(&mut stream);
249 Response::read_from_stream(&mut stream).map_err(|_| FetchError)
250 }
251}
252
/// Error produced when bytes read from a client do not form a valid HTTP request.
#[derive(Debug)]
pub(crate) struct InvalidRequestError;

impl Display for InvalidRequestError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Invalid request")
    }
}

impl Error for InvalidRequestError {}
264
/// Error returned by `Request::fetch` when the request could not be sent or
/// its response could not be read.
#[derive(Debug)]
pub struct FetchError;

impl Display for FetchError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Fetch error")
    }
}

impl Error for FetchError {}
276
#[cfg(test)]
mod test {
    use std::io::{BufRead, BufReader, Write};
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread::{self, JoinHandle};

    use super::*;
    use crate::enums::Status;

    /// Spawns a one-shot server on `listener`. It accepts one connection,
    /// reads the request head up to the blank line, writes `response` and
    /// closes.
    ///
    /// Reading the request before replying matters. Closing a socket that
    /// still has unread incoming data can make the OS reset the connection
    /// instead of closing it cleanly. That can intermittently break the
    /// client's read of the response.
    fn serve_once(listener: TcpListener, response: &'static [u8]) -> JoinHandle<()> {
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut reader = BufReader::new(stream.try_clone().unwrap());
            let mut line = String::new();
            while reader.read_line(&mut line).unwrap() > 0 && line != "\r\n" {
                line.clear();
            }
            stream.write_all(response).unwrap();
        })
    }

    #[test]
    fn test_read_from_stream() {
        let raw_request = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n";
        let mut stream = &raw_request[..];
        let request =
            Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into()).unwrap();
        assert_eq!(request.method, Method::Get);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
    }

    #[test]
    fn test_read_from_stream_with_body() {
        let raw_request =
            b"POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 13\r\n\r\nHello, world!";
        let mut stream = &raw_request[..];
        let request =
            Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into()).unwrap();
        assert_eq!(request.method, Method::Post);
        assert_eq!(request.url.to_string(), "http://localhost/");
        assert_eq!(request.version, Version::Http1_1);
        assert_eq!(request.headers.get("Host").unwrap(), "localhost");
        assert_eq!(request.body.unwrap(), b"Hello, world!");
    }

    #[test]
    fn test_invalid_request_error() {
        let raw_request = b"INVALID REQUEST";
        let mut stream = &raw_request[..];
        let result = Request::read_from_stream(&mut stream, (Ipv4Addr::LOCALHOST, 12345).into());
        assert!(result.is_err());
    }

    #[test]
    fn test_write_to_stream() {
        let request = Request::new()
            .method(Method::Get)
            .url(Url::from_str("http://localhost/").unwrap())
            .header("Host", "localhost");

        let mut buffer = Vec::new();
        request.write_to_stream(&mut buffer);
        assert!(buffer.starts_with(b"GET / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_write_to_stream_with_body() {
        let request = Request::new()
            .method(Method::Post)
            .url(Url::from_str("http://localhost/").unwrap())
            .header("Host", "localhost")
            .body("Hello, world!");

        let mut buffer = Vec::new();
        request.write_to_stream(&mut buffer);
        assert!(buffer.starts_with(b"POST / HTTP/1.1\r\n"));
    }

    #[test]
    fn test_fetch_http1_0() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = serve_once(listener, b"HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\ntest");

        let res = Request::with_url(format!("http://{}/", server_addr))
            .fetch()
            .unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());
        server.join().unwrap();
    }

    #[test]
    fn test_fetch_http1_1() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = serve_once(
            listener,
            b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: close\r\n\r\ntest",
        );

        let res = Request::with_url(format!("http://{}/", server_addr))
            .fetch()
            .unwrap();
        assert_eq!(res.status, Status::Ok);
        assert_eq!(res.body, "test".as_bytes());
        server.join().unwrap();
    }
}
383}