Skip to main content

rust_web_server/cors/mod.rs

1#[cfg(test)]
2mod tests;
3
4use crate::header::Header;
5use crate::request::{METHOD, Request};
6use crate::response::Error;
7use crate::server_config::ServerConfig;
8
/// An explicit CORS policy: which origins, methods and headers a server
/// accepts, and what it advertises back to the browser.
///
/// The list-valued fields are joined with `","` when they are turned into
/// response header values by `Cors::_process`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Cors {
    /// Flag asking for every origin to be accepted.
    /// NOTE(review): `Cors::_process` does not read this field; callers
    /// presumably dispatch to `Cors::allow_all` themselves — confirm.
    pub allow_all: bool,
    /// Origins whose requests receive CORS headers.
    pub allow_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods` on preflight.
    pub allow_methods: Vec<String>,
    /// Headers advertised (lowercased) in `Access-Control-Allow-Headers`
    /// on preflight.
    pub allow_headers: Vec<String>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    pub allow_credentials: bool,
    /// Headers advertised (lowercased) in `Access-Control-Expose-Headers`.
    pub expose_headers: Vec<String>,
    /// Raw value sent as `Access-Control-Max-Age` on preflight.
    pub max_age: String,
}
19
20impl Cors {
21    pub const MAX_AGE: &'static str = "86400";
22
23    pub fn get_vary_header_value() -> String {
24        Header::_ORIGIN.to_string()
25    }
26
27    pub fn allow_all(request: &Request) -> Result<Vec<Header>, Error> {
28        let mut headers : Vec<Header> = vec![];
29        let origin = request.get_header(Header::_ORIGIN.to_string());
30        if origin.is_some() {
31            let allow_origin = Header {
32                name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
33                value: origin.unwrap().value.to_string()
34            };
35            headers.push(allow_origin);
36
37            let allow_credentials = Header {
38                name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
39                value: "true".to_string()
40            };
41            headers.push(allow_credentials);
42
43            let is_options = request.method == METHOD.options;
44            if is_options {
45                let method = request.get_header(Header::_ACCESS_CONTROL_REQUEST_METHOD.to_string());
46                if method.is_some() {
47                    let allow_method = Header {
48                        name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
49                        value: method.unwrap().value.to_string()
50                    };
51                    headers.push(allow_method);
52                }
53
54                let access_control_request_headers = request.get_header(Header::_ACCESS_CONTROL_REQUEST_HEADERS.to_string());
55                if access_control_request_headers.is_some() {
56                    let request_headers = access_control_request_headers.unwrap();
57                    let allow_headers = Header {
58                        name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
59                        value: request_headers.value.to_lowercase(),
60                    };
61                    headers.push(allow_headers);
62
63                    let expose_headers = Header {
64                        name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
65                        value: request_headers.value.to_lowercase(),
66                    };
67                    headers.push(expose_headers);
68                }
69
70                let max_age = Header {
71                    name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
72                    value: Cors::MAX_AGE.to_string()
73                };
74                headers.push(max_age);
75            }
76
77        }
78
79        Ok(headers)
80    }
81
82    pub fn _process(request: &Request, cors: &Cors) -> Result<Vec<Header>, Error> {
83        let mut headers : Vec<Header> = vec![];
84
85        let allow_origins = cors.allow_origins.join(",");
86        let boxed_origin = request.get_header(Header::_ORIGIN.to_string());
87
88        if boxed_origin.is_none() {
89            return Ok(headers)
90        }
91
92        let origin = boxed_origin.unwrap();
93        let origin_value = format!("{}", origin.value);
94
95        let is_valid_origin = allow_origins.contains(&origin_value);
96        if !is_valid_origin {
97            return Ok(headers)
98        }
99
100        let allow_origin = Header {
101            name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
102            value: origin_value
103        };
104        headers.push(allow_origin);
105
106        if cors.allow_credentials {
107            let allow_credentials = Header {
108                name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
109                value: cors.allow_credentials.to_string()
110            };
111            headers.push(allow_credentials);
112        }
113
114        let is_options = request.method == METHOD.options;
115        if is_options {
116            let methods = cors.allow_methods.join(",");
117            let allow_methods = Header {
118                name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
119                value: methods
120            };
121            headers.push(allow_methods);
122
123            let allow_headers_value = cors.allow_headers.join(",");
124            let allow_headers = Header {
125                name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
126                value: allow_headers_value.to_lowercase()
127            };
128            headers.push(allow_headers);
129
130            let allow_expose_headers  = cors.expose_headers.join(",");
131            let expose_headers = Header {
132                name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
133                value: allow_expose_headers.to_lowercase()
134            };
135            headers.push(expose_headers);
136
137            let max_age = Header {
138                name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
139                value: cors.max_age.to_string()
140            };
141            headers.push(max_age);
142        }
143
144        Ok(headers)
145    }
146
147    /// Build CORS headers using the provided [`ServerConfig`].
148    ///
149    /// This is the primary implementation. [`get_headers`] is the legacy
150    /// entry point that reads from environment variables; all new call-sites
151    /// should prefer this version.
152    pub fn get_headers_from_config(request: &Request, config: &ServerConfig) -> Vec<Header> {
153        if config.cors_allow_all {
154            return Cors::allow_all(request).unwrap_or_default();
155        }
156        Cors::process_using_config(request, config).unwrap_or_default()
157    }
158
159    fn process_using_config(request: &Request, config: &ServerConfig) -> Result<Vec<Header>, Error> {
160        let mut headers: Vec<Header> = vec![];
161
162        let boxed_origin = request.get_header(Header::_ORIGIN.to_string());
163        if boxed_origin.is_none() {
164            return Ok(headers);
165        }
166        let origin_value = boxed_origin.unwrap().value.clone();
167
168        if !config.cors_allow_origins.contains(&origin_value) {
169            return Ok(headers);
170        }
171
172        headers.push(Header {
173            name: Header::_ACCESS_CONTROL_ALLOW_ORIGIN.to_string(),
174            value: origin_value,
175        });
176
177        let credentials_str = &config.cors_allow_credentials;
178        if credentials_str.eq_ignore_ascii_case("true") {
179            headers.push(Header {
180                name: Header::_ACCESS_CONTROL_ALLOW_CREDENTIALS.to_string(),
181                value: "true".to_string(),
182            });
183        }
184
185        if request.method == METHOD.options {
186            if !config.cors_allow_methods.is_empty() {
187                headers.push(Header {
188                    name: Header::_ACCESS_CONTROL_ALLOW_METHODS.to_string(),
189                    value: config.cors_allow_methods.clone(),
190                });
191            }
192            if !config.cors_allow_headers.is_empty() {
193                headers.push(Header {
194                    name: Header::_ACCESS_CONTROL_ALLOW_HEADERS.to_string(),
195                    value: config.cors_allow_headers.to_lowercase(),
196                });
197            }
198            if !config.cors_expose_headers.is_empty() {
199                headers.push(Header {
200                    name: Header::_ACCESS_CONTROL_EXPOSE_HEADERS.to_string(),
201                    value: config.cors_expose_headers.to_lowercase(),
202                });
203            }
204            if !config.cors_max_age.is_empty() {
205                headers.push(Header {
206                    name: Header::_ACCESS_CONTROL_MAX_AGE.to_string(),
207                    value: config.cors_max_age.clone(),
208                });
209            }
210        }
211
212        Ok(headers)
213    }
214
215    /// Legacy entry point that reads CORS settings from environment variables.
216    ///
217    /// Prefer [`get_headers_from_config`] when a [`ServerConfig`] is available
218    /// (e.g. inside [`App::execute`]). This variant is kept for call-sites that
219    /// do not yet have an `App`-level config reference.
220    pub fn process_using_default_config(request: &Request) -> Result<Vec<Header>, Error> {
221        let config = ServerConfig::from_env();
222        Self::process_using_config(request, &config)
223    }
224
225    /// Legacy entry point that reads CORS settings from environment variables.
226    ///
227    /// Prefer [`get_headers_from_config`] when a [`ServerConfig`] is available.
228    pub fn get_headers(request: &Request) -> Vec<Header> {
229        let config = ServerConfig::from_env();
230        Cors::get_headers_from_config(request, &config)
231    }
232}