//! Cross-Origin Resource Sharing (CORS) policy (`solid_pod_rs/security/cors.rs`).

use std::collections::BTreeSet;
use std::time::Duration;

/// Environment variable naming the allowed origins: `*`, or a comma-separated
/// list of origins. When unset, every origin is accepted.
pub const ENV_CORS_ALLOWED_ORIGINS: &str = "CORS_ALLOWED_ORIGINS";

/// Environment variable that turns on `Access-Control-Allow-Credentials: true`.
/// Accepted truthy values (case-insensitive, whitespace-trimmed): `1`, `true`,
/// `yes`, `on`; anything else leaves credentials disabled.
pub const ENV_CORS_ALLOW_CREDENTIALS: &str = "CORS_ALLOW_CREDENTIALS";

/// Environment variable overriding the preflight cache lifetime, expressed in
/// whole seconds. Unparsable values fall back to [`DEFAULT_MAX_AGE_SECS`].
pub const ENV_CORS_MAX_AGE: &str = "CORS_MAX_AGE";

/// Default `Access-Control-Max-Age`, in seconds (one hour).
pub const DEFAULT_MAX_AGE_SECS: u64 = 3_600;

/// Response headers listed in `Access-Control-Expose-Headers` by default.
///
/// Headers outside the CORS-safelisted set are invisible to cross-origin
/// scripts unless they are exposed here. The Solid Protocol requires servers to
/// expose every response header a Solid application relies on. Besides the
/// WAC / LDP / notification headers, the list therefore includes:
///
/// * `Accept-Put` and `Allow`: supported media types and methods.
/// * `Location`: URI of a resource created by `POST`.
/// * `WWW-Authenticate`: authentication challenges on `401` responses.
pub const DEFAULT_EXPOSE_HEADERS: &[&str] = &[
    "WAC-Allow",
    "Link",
    "ETag",
    "Accept-Patch",
    "Accept-Post",
    "Updates-Via",
    "Accept-Put",
    "Allow",
    "Location",
    "WWW-Authenticate",
];

/// Which request origins a [`CorsPolicy`] grants access to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowedOrigins {
    /// Any origin is accepted.
    Wildcard,
    /// Only origins whose string form exactly equals one of these entries.
    Exact(BTreeSet<String>),
}

65#[derive(Debug, Clone)]
67pub struct CorsPolicy {
68 allowed_origins: AllowedOrigins,
69 allow_credentials: bool,
70 expose_headers: Vec<String>,
71 max_age: Duration,
72}
73
74impl CorsPolicy {
75 pub fn new() -> Self {
78 Self {
79 allowed_origins: AllowedOrigins::Wildcard,
80 allow_credentials: false,
81 expose_headers: DEFAULT_EXPOSE_HEADERS
82 .iter()
83 .map(|s| (*s).to_string())
84 .collect(),
85 max_age: Duration::from_secs(DEFAULT_MAX_AGE_SECS),
86 }
87 }
88
89 pub fn from_env() -> Self {
96 let allowed_origins = match std::env::var(ENV_CORS_ALLOWED_ORIGINS) {
97 Ok(raw) => parse_origins(&raw),
98 Err(_) => AllowedOrigins::Wildcard,
99 };
100 let allow_credentials = std::env::var(ENV_CORS_ALLOW_CREDENTIALS)
101 .ok()
102 .map(|v| {
103 let v = v.trim().to_ascii_lowercase();
104 matches!(v.as_str(), "1" | "true" | "yes" | "on")
105 })
106 .unwrap_or(false);
107 let max_age = std::env::var(ENV_CORS_MAX_AGE)
108 .ok()
109 .and_then(|v| v.trim().parse::<u64>().ok())
110 .map(Duration::from_secs)
111 .unwrap_or_else(|| Duration::from_secs(DEFAULT_MAX_AGE_SECS));
112
113 Self {
114 allowed_origins,
115 allow_credentials,
116 expose_headers: DEFAULT_EXPOSE_HEADERS
117 .iter()
118 .map(|s| (*s).to_string())
119 .collect(),
120 max_age,
121 }
122 }
123
124 pub fn with_allowed_origins(mut self, origins: AllowedOrigins) -> Self {
126 self.allowed_origins = origins;
127 self
128 }
129
130 pub fn with_allow_credentials(mut self, allow: bool) -> Self {
132 self.allow_credentials = allow;
133 self
134 }
135
136 pub fn with_expose_headers(mut self, headers: Vec<String>) -> Self {
138 self.expose_headers = headers;
139 self
140 }
141
142 pub fn with_max_age(mut self, duration: Duration) -> Self {
144 self.max_age = duration;
145 self
146 }
147
148 pub fn max_age(&self) -> Duration {
150 self.max_age
151 }
152
153 pub fn preflight_headers(
164 &self,
165 origin: Option<&str>,
166 req_method: &str,
167 req_headers: &str,
168 ) -> Option<Vec<(&'static str, String)>> {
169 let echoed_origin = self.echo_origin(origin)?;
170
171 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(8);
172 out.push(("Access-Control-Allow-Origin", echoed_origin.clone()));
173
174 out.push(("Vary", "Origin".to_string()));
177
178 if self.allow_credentials {
179 out.push(("Access-Control-Allow-Credentials", "true".to_string()));
180 }
181
182 let methods = if req_method.trim().is_empty() {
186 default_methods()
187 } else {
188 req_method.trim().to_ascii_uppercase()
189 };
190 out.push(("Access-Control-Allow-Methods", methods));
191
192 let normalised = normalise_header_list(req_headers);
196 out.push(("Access-Control-Allow-Headers", normalised));
197
198 out.push((
200 "Access-Control-Max-Age",
201 self.max_age.as_secs().to_string(),
202 ));
203
204 Some(out)
205 }
206
207 pub fn response_headers(&self, origin: Option<&str>) -> Vec<(&'static str, String)> {
213 let mut out: Vec<(&'static str, String)> = Vec::with_capacity(4);
214
215 if let Some(echoed) = self.echo_origin(origin) {
216 out.push(("Access-Control-Allow-Origin", echoed));
217 out.push(("Vary", "Origin".to_string()));
218 if self.allow_credentials {
219 out.push(("Access-Control-Allow-Credentials", "true".to_string()));
220 }
221 }
222
223 if !self.expose_headers.is_empty() {
224 out.push((
225 "Access-Control-Expose-Headers",
226 self.expose_headers.join(", "),
227 ));
228 }
229
230 out
231 }
232
233 fn echo_origin(&self, origin: Option<&str>) -> Option<String> {
240 match &self.allowed_origins {
241 AllowedOrigins::Wildcard => {
242 if self.allow_credentials {
243 origin.map(|o| o.to_string())
248 } else {
249 Some(origin.map(|o| o.to_string()).unwrap_or_else(|| "*".into()))
250 }
251 }
252 AllowedOrigins::Exact(set) => {
253 let o = origin?;
254 if set.contains(o) {
255 Some(o.to_string())
256 } else {
257 None
258 }
259 }
260 }
261 }
262}
263
264impl Default for CorsPolicy {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
270fn parse_origins(raw: &str) -> AllowedOrigins {
273 let trimmed = raw.trim();
274 if trimmed == "*" {
275 return AllowedOrigins::Wildcard;
276 }
277 let set: BTreeSet<String> = trimmed
278 .split(',')
279 .map(|s| s.trim())
280 .filter(|s| !s.is_empty())
281 .map(|s| s.to_string())
282 .collect();
283 if set.is_empty() {
284 AllowedOrigins::Wildcard
285 } else {
286 AllowedOrigins::Exact(set)
287 }
288}
289
/// Method list advertised in `Access-Control-Allow-Methods` when a preflight
/// request does not name a method.
fn default_methods() -> String {
    String::from("GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS")
}

/// Re-serialises a comma-separated list of header names (such as
/// `Access-Control-Request-Headers`) in canonical `a, b, c` form.
///
/// Surrounding whitespace is removed from every name and empty entries are
/// dropped. Names are otherwise left untouched (no case folding).
fn normalise_header_list(raw: &str) -> String {
    let mut joined = String::with_capacity(raw.len());
    for name in raw.split(',').map(str::trim).filter(|name| !name.is_empty()) {
        if !joined.is_empty() {
            joined.push_str(", ");
        }
        joined.push_str(name);
    }
    joined
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_wildcard_no_credentials_emits_star() {
        let policy = CorsPolicy::new();
        assert_eq!(
            policy.echo_origin(Some("https://x.example")).as_deref(),
            Some("https://x.example")
        );
        assert_eq!(policy.echo_origin(None).as_deref(), Some("*"));
    }

    #[test]
    fn wildcard_with_credentials_falls_back_to_origin() {
        let policy = CorsPolicy::new().with_allow_credentials(true);
        assert_eq!(
            policy.echo_origin(Some("https://x.example")).as_deref(),
            Some("https://x.example")
        );
        assert_eq!(policy.echo_origin(None), None);
    }

    #[test]
    fn exact_rejects_unlisted_origin() {
        let allowed: BTreeSet<String> =
            std::iter::once("https://good.example".to_string()).collect();
        let policy = CorsPolicy::new().with_allowed_origins(AllowedOrigins::Exact(allowed));
        assert_eq!(policy.echo_origin(Some("https://bad.example")), None);
        assert_eq!(
            policy.echo_origin(Some("https://good.example")).as_deref(),
            Some("https://good.example")
        );
    }

    #[test]
    fn normalise_header_list_collapses_whitespace() {
        let normalised = normalise_header_list(" authorization ,dpop, content-type ");
        assert_eq!(normalised, "authorization, dpop, content-type");
    }

    #[test]
    fn parse_origins_wildcard_and_list() {
        assert!(
            matches!(parse_origins("*"), AllowedOrigins::Wildcard),
            "expected wildcard"
        );
        match parse_origins("https://a,https://b") {
            AllowedOrigins::Exact(set) => {
                assert!(["https://a", "https://b"].iter().all(|o| set.contains(*o)));
            }
            AllowedOrigins::Wildcard => panic!("expected exact"),
        }
    }
}