rusty_web/headers/
mod.rs

1use std::collections::HashMap;
2use std::io::Read;
3use std::net::TcpStream;
4use regex::Regex;
5use crate::parser::parse_url_encoded;
6
/// Request header fields: each field name maps to every value received for it.
///
/// When filled by `extract_headers`, names are stored as sent (trimmed, case preserved,
/// no normalisation), and a repeated name accumulates its values in arrival order.
pub type Headers = HashMap<String, Vec<String>>;
8
9
/// Errors returned by [`extract_headers`] when the request head cannot be read.
#[derive(Debug)]
pub enum RequestHeaderError {
    /// The header section grew beyond the caller-supplied `max_size` limit.
    MaxSizeExceed,
    /// The peer closed the connection, or a read from the stream failed, before the
    /// complete header section was received.
    ClientDisconnected,
}
17
18
19/// It will try to read headers from the tcp stream.
20/// Returns type `RequestHeaderError` if failed to extract headers.
21pub fn extract_headers(stream: &mut TcpStream, start_header: &mut String,
22                       partial_body_bytes: &mut Vec<u8>, max_size: usize) -> Result<Headers, RequestHeaderError> {
23    let mut header_bytes = Vec::new();
24
25    let mut read_all_headers = false;
26
27    while !read_all_headers {
28        if header_bytes.len() > max_size {
29            return Err(RequestHeaderError::MaxSizeExceed);
30        }
31
32        let mut buffer = [0u8; 1024];
33        let read_result = stream.read(&mut buffer);
34
35        let read_size;
36
37        match read_result {
38            Ok(bytes_read) => {
39                if bytes_read == 0 {
40                    return Err(RequestHeaderError::ClientDisconnected);
41                }
42                read_size = bytes_read;
43            }
44
45            Err(_) => {
46                return Err(RequestHeaderError::ClientDisconnected);
47            }
48        }
49
50        // There will be index if the header is ended. However, contains_full_header don't take
51        // complete request header.
52        if let Some(header_end_index) = contains_full_headers(&buffer) {
53            header_bytes.extend(&buffer[..header_end_index]);
54
55            // Body starts from header_end_index + "\r\n\r\n"
56            partial_body_bytes.extend(&buffer[header_end_index + 4..read_size]);
57            read_all_headers = true;
58        } else {
59            header_bytes.extend(&buffer[..read_size]);
60        }
61    }
62
63    let raw_request_headers = String::from_utf8(header_bytes)
64        .expect("Unsupported header encoding.");
65    let header_lines: Vec<&str> = raw_request_headers.split("\r\n").collect();
66
67    let mut headers: Headers = HashMap::new();
68    for (index, header_line) in header_lines.iter().enumerate() {
69        if index == 0 {
70            *start_header = header_line.to_string();
71        }
72
73        let key_value = parse_header(header_line);
74
75        if let Some((key, value)) = key_value {
76            if headers.contains_key(&key) {
77                let values = headers.get_mut(&key).unwrap();
78                values.push(value);
79            } else {
80                let header_value: Vec<String> = vec![value];
81                headers.insert(key, header_value);
82            }
83        }
84    };
85
86    return Ok(headers);
87}
88
89
90/// Returns content length from the `Header` if available
91pub fn content_length(headers: &Headers) -> Option<usize> {
92    if let Some(values) = headers.get("Content-Length") {
93        if values.len() > 0 {
94            let value = values.get(0).unwrap();
95            let content_length_value = value.parse::<usize>().expect("Invalid content length");
96            return Some(content_length_value);
97        }
98    }
99
100    return None;
101}
102
103
104/// Returns the value of `Connection` header if available
105pub fn connection_type(headers: &Headers) -> Option<String> {
106    if let Some(values) = headers.get("Connection") {
107        if values.len() > 0 {
108            let value = values.get(0).unwrap();
109            return Some(value.to_owned());
110        }
111    }
112
113    return None;
114}
115
116/// Returns `Host` value from the Header if available.
117pub fn host(headers: &Headers) -> Option<String> {
118    let host = headers.get("Host");
119    if let Some(host) = host {
120        if host.len() > 0 {
121            let value = host.get(0).unwrap();
122            return Some(value.to_string());
123        }
124    }
125
126    return None;
127}
128
129
130/// Returns `Content-Type` value from the header if available
131pub fn extract_content_type(headers: &Headers) -> Option<String> {
132    if let Some(values) = headers.get("Content-Type") {
133        let value = values.get(0).expect("Content-Type implementation error");
134        return Some(value.to_owned());
135    }
136
137    return None;
138}
139
/// Locates the end of an HTTP request head inside `buffer`.
///
/// Returns the byte offset at which the first `"\r\n\r\n"` sequence begins. That offset
/// equals the length of the head, excluding the terminator. Returns `None` if the
/// sequence does not occur, including when `buffer` is shorter than four bytes.
pub fn contains_full_headers(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    for (offset, window) in buffer.windows(TERMINATOR.len()).enumerate() {
        if window == TERMINATOR {
            return Some(offset);
        }
    }
    None
}
145
146
147/// Returns the request method and raw path from the header line if matched
148/// ```markdown
149/// GET / HTTP/1.1
150/// ```
151pub fn parse_request_method_header(line: &str) -> Option<(String, String)> {
152    let pattern = Regex::new(r"(?<method>.+) (?<path>.+) (.+)").unwrap();
153
154    if let Some(groups) = pattern.captures(line) {
155        let request_method = &groups["method"];
156        let path = &groups["path"];
157        return Some((request_method.to_string(), path.to_string()));
158    }
159
160    return None;
161}
162
/// Splits a header line into a `(name, value)` pair at its first `:`.
///
/// Both parts are trimmed of surrounding whitespace. Any later colons stay in the value,
/// so `Host: localhost:8080` gives `("Host", "localhost:8080")`. Returns `None` when
/// the line contains no `:` at all.
///
/// Input example:
/// ```text
/// Content-Length: 10
/// ```
pub fn parse_header(line: &str) -> Option<(String, String)> {
    let (name, value) = line.split_once(':')?;
    Some((name.trim().to_string(), value.trim().to_string()))
}
178
179
180/// Returns map of url encoded key values
181/// Example: `/search?name=John&age=22`
182pub fn query_params_from_raw(raw_path: &String) -> HashMap<String, Vec<String>> {
183    let query_params: HashMap<String, Vec<String>> = HashMap::new();
184    let match_result = raw_path.find("?");
185
186    if !match_result.is_some() {
187        return query_params;
188    }
189
190    let index = match_result.unwrap();
191    if index == raw_path.len() - 1 {
192        return query_params;
193    }
194
195    let slice = &raw_path[index + 1..raw_path.len()];
196    return parse_url_encoded(slice);
197}