rust_web_server/cors/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use std::env;
5use crate::header::Header;
6
7use crate::entry_point::Config;
8use crate::request::{METHOD, Request};
9use crate::response::{Error};
10
/// CORS policy settings consumed by `Cors::_process`.
///
/// List-valued fields are joined with `,` when they are emitted as header values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cors {
    /// Whether every origin should be allowed.
    /// NOTE(review): `Cors::_process` does not read this flag — confirm where it is consumed.
    pub allow_all: bool,
    /// Origins allowed to receive `Access-Control-Allow-Origin`.
    pub allow_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods` on `OPTIONS` responses.
    pub allow_methods: Vec<String>,
    /// Header names advertised in `Access-Control-Allow-Headers` (lowercased when emitted).
    pub allow_headers: Vec<String>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    pub allow_credentials: bool,
    /// Header names listed in `Access-Control-Expose-Headers` (lowercased when emitted).
    pub expose_headers: Vec<String>,
    /// Raw value for `Access-Control-Max-Age` (seconds, kept as text).
    pub max_age: String,
}
21
22impl Cors {
23    pub const MAX_AGE: &'static str = "86400";
24
25    pub fn get_vary_header_value() -> String {
26        Header::_ORIGIN.to_string()
27    }
28
29    pub fn allow_all(request: &Request) -> Result<Vec<Header>, Error> {
30        let mut headers : Vec<Header> = vec![];
31        let origin = request.get_header(Header::_ORIGIN.to_string());
32        if origin.is_some() {
33            let allow_origin = Header {
34                name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
35                value: origin.unwrap().value.to_string()
36            };
37            headers.push(allow_origin);
38
39            let allow_credentials = Header {
40                name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
41                value: "true".to_string()
42            };
43            headers.push(allow_credentials);
44
45            let is_options = request.method == METHOD.options;
46            if is_options {
47                let method = request.get_header(Header::_ACCESS_CONTROL_REQUEST_METHOD.to_string());
48                if method.is_some() {
49                    let allow_method = Header {
50                        name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
51                        value: method.unwrap().value.to_string()
52                    };
53                    headers.push(allow_method);
54                }
55
56                let access_control_request_headers = request.get_header(Header::_ACCESS_CONTROL_REQUEST_HEADERS.to_string());
57                if access_control_request_headers.is_some() {
58                    let request_headers = access_control_request_headers.unwrap();
59                    let allow_headers = Header {
60                        name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
61                        value: request_headers.value.to_lowercase(),
62                    };
63                    headers.push(allow_headers);
64
65                    let expose_headers = Header {
66                        name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
67                        value: request_headers.value.to_lowercase(),
68                    };
69                    headers.push(expose_headers);
70                }
71
72                let max_age = Header {
73                    name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
74                    value: Cors::MAX_AGE.to_string()
75                };
76                headers.push(max_age);
77            }
78
79        }
80
81        Ok(headers)
82    }
83
84    pub fn _process(request: &Request, cors: &Cors) -> Result<Vec<Header>, Error> {
85        let mut headers : Vec<Header> = vec![];
86
87        let allow_origins = cors.allow_origins.join(",");
88        let boxed_origin = request.get_header(Header::_ORIGIN.to_string());
89
90        if boxed_origin.is_none() {
91            return Ok(headers)
92        }
93
94        let origin = boxed_origin.unwrap();
95        let origin_value = format!("{}", origin.value);
96
97        let is_valid_origin = allow_origins.contains(&origin_value);
98        if !is_valid_origin {
99            return Ok(headers)
100        }
101
102        let allow_origin = Header {
103            name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
104            value: origin_value
105        };
106        headers.push(allow_origin);
107
108        if cors.allow_credentials {
109            let allow_credentials = Header {
110                name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
111                value: cors.allow_credentials.to_string()
112            };
113            headers.push(allow_credentials);
114        }
115
116        let is_options = request.method == METHOD.options;
117        if is_options {
118            let methods = cors.allow_methods.join(",");
119            let allow_methods = Header {
120                name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
121                value: methods
122            };
123            headers.push(allow_methods);
124
125            let allow_headers_value = cors.allow_headers.join(",");
126            let allow_headers = Header {
127                name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
128                value: allow_headers_value.to_lowercase()
129            };
130            headers.push(allow_headers);
131
132            let allow_expose_headers  = cors.expose_headers.join(",");
133            let expose_headers = Header {
134                name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
135                value: allow_expose_headers.to_lowercase()
136            };
137            headers.push(expose_headers);
138
139            let max_age = Header {
140                name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
141                value: cors.max_age.to_string()
142            };
143            headers.push(max_age);
144        }
145
146        Ok(headers)
147    }
148
149    pub fn process_using_default_config(request: &Request) -> Result<Vec<Header>, Error> {
150        let mut headers : Vec<Header> = vec![];
151        let boxed_allow_origins = env::var(Config::RWS_CONFIG_CORS_ALLOW_ORIGINS);
152        let mut allow_origins: String = "".to_string();
153        if boxed_allow_origins.is_err() {
154            eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_ALLOW_ORIGINS);
155        } else {
156            allow_origins = boxed_allow_origins.unwrap();
157        }
158
159        let boxed_origin = request.get_header(Header::_ORIGIN.to_string());
160
161        if boxed_origin.is_none() {
162            return Ok(headers)
163        }
164
165        let origin = boxed_origin.unwrap();
166        let origin_value = format!("{}", origin.value);
167
168        let is_valid_origin = allow_origins.contains(&origin_value);
169        if !is_valid_origin {
170            return Ok(headers)
171        }
172
173        let allow_origin = Header {
174            name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
175            value: origin_value
176        };
177        headers.push(allow_origin);
178
179        let boxed_is_allow_credentials = env::var(Config::RWS_CONFIG_CORS_ALLOW_CREDENTIALS);
180        if boxed_is_allow_credentials.is_err() {
181            eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_ALLOW_CREDENTIALS);
182        } else {
183            let boxed_parse = boxed_is_allow_credentials.unwrap().parse::<bool>();
184            if boxed_parse.is_err() {
185                eprintln!("unable to parse as boolean {} environment variable. Possible values are true or false", Config::RWS_CONFIG_CORS_ALLOW_CREDENTIALS);
186            } else {
187                let is_allow_credentials : bool = boxed_parse.unwrap();
188                if is_allow_credentials {
189                    let allow_credentials = Header {
190                        name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
191                        value: is_allow_credentials.to_string()
192                    };
193                    headers.push(allow_credentials);
194                }
195            }
196        }
197
198
199        let is_options = request.method == METHOD.options;
200        if is_options {
201            let boxed_methods = env::var(Config::RWS_CONFIG_CORS_ALLOW_METHODS);
202            if boxed_methods.is_err() {
203                eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_ALLOW_METHODS);
204            } else {
205                let methods = boxed_methods.unwrap();
206                let allow_methods = Header {
207                    name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
208                    value: methods
209                };
210                headers.push(allow_methods);
211            }
212
213
214            let boxed_allow_headers_env_variable = env::var(Config::RWS_CONFIG_CORS_ALLOW_HEADERS);
215            if boxed_allow_headers_env_variable.is_err() {
216                eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_ALLOW_HEADERS);
217            } else {
218                let allow_headers_env_variable = boxed_allow_headers_env_variable.unwrap();
219                let allow_headers = Header {
220                    name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
221                    value: allow_headers_env_variable.to_lowercase()
222                };
223                headers.push(allow_headers);
224            }
225
226
227            let boxed_allow_expose_headers = env::var(Config::RWS_CONFIG_CORS_EXPOSE_HEADERS);
228            if boxed_allow_expose_headers.is_err() {
229                eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_EXPOSE_HEADERS);
230            } else {
231                let allow_expose_headers  = boxed_allow_expose_headers.unwrap();
232                let expose_headers = Header {
233                    name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
234                    value: allow_expose_headers.to_lowercase()
235                };
236                headers.push(expose_headers);
237            }
238
239
240            let boxed_max_age_value = env::var(Config::RWS_CONFIG_CORS_MAX_AGE);
241            if boxed_max_age_value.is_err() {
242                eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_MAX_AGE);
243            } else {
244                let max_age_value  = boxed_max_age_value.unwrap();
245                let max_age = Header {
246                    name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
247                    value: max_age_value
248                };
249                headers.push(max_age);
250            }
251
252        }
253
254
255        Ok(headers)
256    }
257
258    pub fn get_headers(request: &Request) -> Vec<Header> {
259
260        let boxed_rws_config_cors_allow_all = env::var(Config::RWS_CONFIG_CORS_ALLOW_ALL);
261        if boxed_rws_config_cors_allow_all.is_err() {
262            eprintln!("unable to read {} environment variable", Config::RWS_CONFIG_CORS_ALLOW_ALL);
263            let boxed_cors_header_list = Cors::allow_all(&request);
264            if boxed_cors_header_list.is_err() {
265                eprintln!("unable to get Cors::allow_all headers {}", boxed_cors_header_list.err().unwrap().message);
266            } else {
267                return boxed_cors_header_list.unwrap()
268            }
269        } else {
270            let boxed_parse = boxed_rws_config_cors_allow_all.unwrap().parse::<bool>();
271            if boxed_parse.is_err() {
272                eprintln!("unable to parse as boolean {} environment variable. Possible values are true or false", Config::RWS_CONFIG_CORS_ALLOW_ALL);
273            } else {
274                let is_cors_set_to_allow_all_requests = boxed_parse.unwrap();
275                if !is_cors_set_to_allow_all_requests {
276                    let boxed_cors_header_list = Cors::process_using_default_config(&request);
277                    if boxed_cors_header_list.is_err() {
278                        eprintln!("unable to get Cors::process_using_default_config headers {}", boxed_cors_header_list.err().unwrap().message);
279                    } else {
280                        return boxed_cors_header_list.unwrap()
281                    }
282                }
283            }
284        }
285
286
287        let boxed_cors_header_list = Cors::allow_all(&request);
288        if boxed_cors_header_list.is_err() {
289            eprintln!("unable to get Cors::allow_all headers {}", boxed_cors_header_list.err().unwrap().message);
290            vec![]
291        } else {
292            return boxed_cors_header_list.unwrap()
293        }
294    }
295}