1use std::collections::HashMap;
2use std::io::Read;
3use std::net::TcpStream;
4use regex::Regex;
5use crate::parser::parse_url_encoded;
6
/// Parsed request headers: each header name maps to every value received for
/// it, in arrival order (a header repeated on several lines yields several
/// entries in its `Vec`).
pub type Headers = std::collections::HashMap<String, Vec<String>>;
8
9
/// Errors that can occur while reading the header section of a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestHeaderError {
    /// The header section grew beyond the caller-supplied size limit.
    MaxSizeExceed,
    /// The peer closed the connection, or a read on the stream failed,
    /// before the header section was complete.
    ClientDisconnected,
}

impl std::fmt::Display for RequestHeaderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            RequestHeaderError::MaxSizeExceed => {
                "request headers exceed the maximum allowed size"
            }
            RequestHeaderError::ClientDisconnected => {
                "client disconnected before the request headers were complete"
            }
        };
        f.write_str(message)
    }
}

// No underlying cause is carried, so the default `source()` (None) is correct.
impl std::error::Error for RequestHeaderError {}
17
18
19pub fn extract_headers(stream: &mut TcpStream, start_header: &mut String,
22 partial_body_bytes: &mut Vec<u8>, max_size: usize) -> Result<Headers, RequestHeaderError> {
23 let mut header_bytes = Vec::new();
24
25 let mut read_all_headers = false;
26
27 while !read_all_headers {
28 if header_bytes.len() > max_size {
29 return Err(RequestHeaderError::MaxSizeExceed);
30 }
31
32 let mut buffer = [0u8; 1024];
33 let read_result = stream.read(&mut buffer);
34
35 let read_size;
36
37 match read_result {
38 Ok(bytes_read) => {
39 if bytes_read == 0 {
40 return Err(RequestHeaderError::ClientDisconnected);
41 }
42 read_size = bytes_read;
43 }
44
45 Err(_) => {
46 return Err(RequestHeaderError::ClientDisconnected);
47 }
48 }
49
50 if let Some(header_end_index) = contains_full_headers(&buffer) {
53 header_bytes.extend(&buffer[..header_end_index]);
54
55 partial_body_bytes.extend(&buffer[header_end_index + 4..read_size]);
57 read_all_headers = true;
58 } else {
59 header_bytes.extend(&buffer[..read_size]);
60 }
61 }
62
63 let raw_request_headers = String::from_utf8(header_bytes)
64 .expect("Unsupported header encoding.");
65 let header_lines: Vec<&str> = raw_request_headers.split("\r\n").collect();
66
67 let mut headers: Headers = HashMap::new();
68 for (index, header_line) in header_lines.iter().enumerate() {
69 if index == 0 {
70 *start_header = header_line.to_string();
71 }
72
73 let key_value = parse_header(header_line);
74
75 if let Some((key, value)) = key_value {
76 if headers.contains_key(&key) {
77 let values = headers.get_mut(&key).unwrap();
78 values.push(value);
79 } else {
80 let header_value: Vec<String> = vec![value];
81 headers.insert(key, header_value);
82 }
83 }
84 };
85
86 return Ok(headers);
87}
88
89
90pub fn content_length(headers: &Headers) -> Option<usize> {
92 if let Some(values) = headers.get("Content-Length") {
93 if values.len() > 0 {
94 let value = values.get(0).unwrap();
95 let content_length_value = value.parse::<usize>().expect("Invalid content length");
96 return Some(content_length_value);
97 }
98 }
99
100 return None;
101}
102
103
104pub fn connection_type(headers: &Headers) -> Option<String> {
106 if let Some(values) = headers.get("Connection") {
107 if values.len() > 0 {
108 let value = values.get(0).unwrap();
109 return Some(value.to_owned());
110 }
111 }
112
113 return None;
114}
115
116pub fn host(headers: &Headers) -> Option<String> {
118 let host = headers.get("Host");
119 if let Some(host) = host {
120 if host.len() > 0 {
121 let value = host.get(0).unwrap();
122 return Some(value.to_string());
123 }
124 }
125
126 return None;
127}
128
129
130pub fn extract_content_type(headers: &Headers) -> Option<String> {
132 if let Some(values) = headers.get("Content-Type") {
133 let value = values.get(0).expect("Content-Type implementation error");
134 return Some(value.to_owned());
135 }
136
137 return None;
138}
139
/// Locates the blank line that terminates an HTTP header section.
///
/// Returns the byte offset of the first `\r\n\r\n` in `buffer` (the index of
/// its leading `\r`), or `None` if the sequence does not occur.
pub fn contains_full_headers(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    // `starts_with` is false whenever fewer than 4 bytes remain, so offsets
    // too close to the end can never match.
    (0..buffer.len()).find(|&offset| buffer[offset..].starts_with(TERMINATOR))
}
145
146
/// Parses an HTTP request line such as `GET /index.html HTTP/1.1`.
///
/// The line is expected to have the form `METHOD SP TARGET SP VERSION`. The
/// method is everything before the first space and the version is everything
/// after the last space. Whatever lies between is returned as the path. The
/// version must be non-empty but is otherwise ignored.
///
/// Returns `Some((method, path))`, or `None` if the line does not split into
/// three non-empty space-separated parts.
///
/// Plain string splitting replaces a regular expression that was recompiled
/// on every call. For well-formed request lines it yields the same result, and
/// it no longer folds extra spaces into the method.
pub fn parse_request_method_header(line: &str) -> Option<(String, String)> {
    let (method, rest) = line.split_once(' ')?;
    let (path, version) = rest.rsplit_once(' ')?;

    if method.is_empty() || path.is_empty() || version.is_empty() {
        return None;
    }

    Some((method.to_string(), path.to_string()))
}
162
/// Splits a raw `Name: value` header line at its first colon.
///
/// Both halves are trimmed of surrounding whitespace, so
/// `Host:  example.com:80` yields `("Host", "example.com:80")`. Colons after
/// the first stay in the value. Returns `None` when the line contains no
/// colon at all.
pub fn parse_header(line: &str) -> Option<(String, String)> {
    line.split_once(':')
        .map(|(name, value)| (name.trim().to_owned(), value.trim().to_owned()))
}
178
179
180pub fn query_params_from_raw(raw_path: &String) -> HashMap<String, Vec<String>> {
183 let query_params: HashMap<String, Vec<String>> = HashMap::new();
184 let match_result = raw_path.find("?");
185
186 if !match_result.is_some() {
187 return query_params;
188 }
189
190 let index = match_result.unwrap();
191 if index == raw_path.len() - 1 {
192 return query_params;
193 }
194
195 let slice = &raw_path[index + 1..raw_path.len()];
196 return parse_url_encoded(slice);
197}