Skip to main content

solid_pod_rs/security/
cors.rs

1//! CORS policy primitive (Sprint 7 §6.2, ADR-057).
2//!
3//! Transport-agnostic CORS rules. Consumers (actix-web, axum) call
4//! [`CorsPolicy::preflight_headers`] from their `OPTIONS` handler and
5//! [`CorsPolicy::response_headers`] from the normal-response path; this
6//! crate never mounts routes itself.
7//!
8//! ## Semantics
9//!
//! - **Allowed origins.** Either [`AllowedOrigins::Wildcard`] (any
//!   origin) or [`AllowedOrigins::Exact`] (explicit allowlist). Without
//!   credentials, wildcard mode echoes the request's `Origin` when one is
//!   present and emits `*` only when none was sent. An unlisted origin
//!   yields `None` from the preflight path — the caller MUST emit a
//!   no-CORS response (typically 403 or the un-augmented 200).
//! - **Credentials + wildcard.** Per the Fetch spec,
//!   `Access-Control-Allow-Origin: *` is invalid when credentials are
//!   included. When both are configured, the policy degrades to echoing
//!   the concrete request origin (a request without an `Origin` header is
//!   not granted) and emits `Vary: Origin` so caches do not leak.
19//! - **Exposed headers.** Default set targets Solid interop (WAC-Allow,
20//!   Link, ETag, Accept-Patch, Accept-Post, Updates-Via). Operators
21//!   override via [`CorsPolicy::with_expose_headers`].
22//! - **Preflight advertising.** `Access-Control-Allow-Headers` echoes
23//!   the `Access-Control-Request-Headers` value verbatim (after
24//!   whitespace normalisation), matching JSS behaviour — consumers need
25//!   not maintain an allowlist of request headers.
26
27use std::collections::BTreeSet;
28use std::time::Duration;
29
/// Name of the environment variable listing allowed origins: a
/// comma-separated allowlist, or `*` to accept any origin.
pub const ENV_CORS_ALLOWED_ORIGINS: &str = "CORS_ALLOWED_ORIGINS";

/// Name of the environment variable that turns on credentialed CORS.
/// `CorsPolicy::from_env` treats `true`, `1`, `yes` and `on`
/// (case-insensitive, surrounding whitespace ignored) as enabled.
pub const ENV_CORS_ALLOW_CREDENTIALS: &str = "CORS_ALLOW_CREDENTIALS";

/// Name of the environment variable holding the preflight cache
/// lifetime, as a whole number of seconds.
pub const ENV_CORS_MAX_AGE: &str = "CORS_MAX_AGE";

/// Preflight cache lifetime used when none is configured: one hour.
pub const DEFAULT_MAX_AGE_SECS: u64 = 60 * 60;

/// Response headers exposed to scripts by default, chosen for Solid /
/// LDP client interoperability.
pub const DEFAULT_EXPOSE_HEADERS: &[&str] = &[
    "WAC-Allow",
    "Link",
    "ETag",
    "Accept-Patch",
    "Accept-Post",
    "Updates-Via",
];
53
/// Origin-matching strategy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowedOrigins {
    /// Any origin is permitted. Without credentials the request's
    /// `Origin` is echoed back when present (`*` when absent); combined
    /// with credentials the policy degrades to echo-concrete-origin mode
    /// and grants nothing to requests that carry no `Origin` (see module
    /// docs).
    Wildcard,
    /// Only origins present in the set are permitted. Matching is an
    /// exact, case-sensitive string comparison against the request's
    /// `Origin` header. Browsers serialise that header as lowercase
    /// `scheme://host[:port]` with no trailing slash (RFC 6454 §6.2), so
    /// entries should be written in that form.
    Exact(BTreeSet<String>),
}
64
/// CORS policy aggregate root. Immutable after construction.
///
/// Built with [`CorsPolicy::new`] or [`CorsPolicy::from_env`] and refined
/// through the `with_*` builder methods, each of which consumes the policy
/// and returns the updated value.
#[derive(Debug, Clone)]
pub struct CorsPolicy {
    /// Which request origins are granted access.
    allowed_origins: AllowedOrigins,
    /// Whether `Access-Control-Allow-Credentials: true` is emitted. In
    /// wildcard mode this also switches from `*` to echoing the concrete
    /// request origin.
    allow_credentials: bool,
    /// Header names joined with `", "` into `Access-Control-Expose-Headers`
    /// on normal responses; that header is omitted when this is empty.
    expose_headers: Vec<String>,
    /// Preflight cache lifetime, emitted as whole seconds in
    /// `Access-Control-Max-Age`.
    max_age: Duration,
}
73
74impl CorsPolicy {
75    /// Maximally permissive default: wildcard origins, no credentials,
76    /// default expose headers, 3600 s preflight cache.
77    pub fn new() -> Self {
78        Self {
79            allowed_origins: AllowedOrigins::Wildcard,
80            allow_credentials: false,
81            expose_headers: DEFAULT_EXPOSE_HEADERS
82                .iter()
83                .map(|s| (*s).to_string())
84                .collect(),
85            max_age: Duration::from_secs(DEFAULT_MAX_AGE_SECS),
86        }
87    }
88
89    /// Load from env. Missing variables fall back to defaults; present
90    /// but unparseable values also fall back (ignored).
91    ///
92    /// - `CORS_ALLOWED_ORIGINS` — comma-separated list, or `*`.
93    /// - `CORS_ALLOW_CREDENTIALS` — `true`/`1`/`yes`/`on` enables.
94    /// - `CORS_MAX_AGE` — decimal seconds.
95    pub fn from_env() -> Self {
96        let allowed_origins = match std::env::var(ENV_CORS_ALLOWED_ORIGINS) {
97            Ok(raw) => parse_origins(&raw),
98            Err(_) => AllowedOrigins::Wildcard,
99        };
100        let allow_credentials = std::env::var(ENV_CORS_ALLOW_CREDENTIALS)
101            .ok()
102            .map(|v| {
103                let v = v.trim().to_ascii_lowercase();
104                matches!(v.as_str(), "1" | "true" | "yes" | "on")
105            })
106            .unwrap_or(false);
107        let max_age = std::env::var(ENV_CORS_MAX_AGE)
108            .ok()
109            .and_then(|v| v.trim().parse::<u64>().ok())
110            .map(Duration::from_secs)
111            .unwrap_or_else(|| Duration::from_secs(DEFAULT_MAX_AGE_SECS));
112
113        Self {
114            allowed_origins,
115            allow_credentials,
116            expose_headers: DEFAULT_EXPOSE_HEADERS
117                .iter()
118                .map(|s| (*s).to_string())
119                .collect(),
120            max_age,
121        }
122    }
123
124    /// Replace the origin strategy.
125    pub fn with_allowed_origins(mut self, origins: AllowedOrigins) -> Self {
126        self.allowed_origins = origins;
127        self
128    }
129
130    /// Enable (or disable) credentialed requests.
131    pub fn with_allow_credentials(mut self, allow: bool) -> Self {
132        self.allow_credentials = allow;
133        self
134    }
135
136    /// Override the exposed-headers list.
137    pub fn with_expose_headers(mut self, headers: Vec<String>) -> Self {
138        self.expose_headers = headers;
139        self
140    }
141
142    /// Override the preflight cache duration.
143    pub fn with_max_age(mut self, duration: Duration) -> Self {
144        self.max_age = duration;
145        self
146    }
147
148    /// Current preflight cache duration.
149    pub fn max_age(&self) -> Duration {
150        self.max_age
151    }
152
153    /// Build the header set for a preflight (`OPTIONS`) response.
154    ///
155    /// Returns `None` when the request origin is not permitted; the
156    /// caller MUST emit a no-CORS response (typically plain 403 or an
157    /// un-augmented 200).
158    ///
159    /// `req_method` is the value of `Access-Control-Request-Method`.
160    /// `req_headers` is the verbatim `Access-Control-Request-Headers`
161    /// value (comma-separated); passing an empty string is valid and
162    /// yields an empty `Access-Control-Allow-Headers`.
163    pub fn preflight_headers(
164        &self,
165        origin: Option<&str>,
166        req_method: &str,
167        req_headers: &str,
168    ) -> Option<Vec<(&'static str, String)>> {
169        let echoed_origin = self.echo_origin(origin)?;
170
171        let mut out: Vec<(&'static str, String)> = Vec::with_capacity(8);
172        out.push(("Access-Control-Allow-Origin", echoed_origin.clone()));
173
174        // Vary: Origin is mandatory when echoing; harmless when
175        // emitting `*` (caches already key on it).
176        out.push(("Vary", "Origin".to_string()));
177
178        if self.allow_credentials {
179            out.push(("Access-Control-Allow-Credentials", "true".to_string()));
180        }
181
182        // Methods — echo the single requested method. Servers MAY
183        // advertise the full method list here; we keep it minimal to
184        // match JSS + Fetch spec §4.9.
185        let methods = if req_method.trim().is_empty() {
186            default_methods()
187        } else {
188            req_method.trim().to_ascii_uppercase()
189        };
190        out.push(("Access-Control-Allow-Methods", methods));
191
192        // Headers — echo the request header list verbatim (trimmed).
193        // This is the JSS approach and sidesteps maintaining an
194        // allow-list of request headers on the server.
195        let normalised = normalise_header_list(req_headers);
196        out.push(("Access-Control-Allow-Headers", normalised));
197
198        // Max-Age for preflight cache.
199        out.push((
200            "Access-Control-Max-Age",
201            self.max_age.as_secs().to_string(),
202        ));
203
204        Some(out)
205    }
206
207    /// Build the header set for a normal (non-preflight) response.
208    ///
209    /// Always emits `Access-Control-Expose-Headers` plus — when the
210    /// origin is permitted — `Access-Control-Allow-Origin` and `Vary:
211    /// Origin`.
212    pub fn response_headers(&self, origin: Option<&str>) -> Vec<(&'static str, String)> {
213        let mut out: Vec<(&'static str, String)> = Vec::with_capacity(4);
214
215        if let Some(echoed) = self.echo_origin(origin) {
216            out.push(("Access-Control-Allow-Origin", echoed));
217            out.push(("Vary", "Origin".to_string()));
218            if self.allow_credentials {
219                out.push(("Access-Control-Allow-Credentials", "true".to_string()));
220            }
221        }
222
223        if !self.expose_headers.is_empty() {
224            out.push((
225                "Access-Control-Expose-Headers",
226                self.expose_headers.join(", "),
227            ));
228        }
229
230        out
231    }
232
233    /// Compute the value to emit in `Access-Control-Allow-Origin`.
234    ///
235    /// Returns `None` when the origin is not permitted. For wildcard +
236    /// credentials, echoes the concrete request origin; for wildcard
237    /// without credentials, returns `*`; for `Exact`, returns the
238    /// matched origin verbatim.
239    fn echo_origin(&self, origin: Option<&str>) -> Option<String> {
240        match &self.allowed_origins {
241            AllowedOrigins::Wildcard => {
242                if self.allow_credentials {
243                    // RFC: `*` is invalid with credentials; must echo
244                    // the concrete origin. If the caller did not send
245                    // an Origin header, we cannot safely emit `*`, so
246                    // return None and let the caller drop CORS headers.
247                    origin.map(|o| o.to_string())
248                } else {
249                    Some(origin.map(|o| o.to_string()).unwrap_or_else(|| "*".into()))
250                }
251            }
252            AllowedOrigins::Exact(set) => {
253                let o = origin?;
254                if set.contains(o) {
255                    Some(o.to_string())
256                } else {
257                    None
258                }
259            }
260        }
261    }
262}
263
264impl Default for CorsPolicy {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270// --- helpers -------------------------------------------------------------
271
272fn parse_origins(raw: &str) -> AllowedOrigins {
273    let trimmed = raw.trim();
274    if trimmed == "*" {
275        return AllowedOrigins::Wildcard;
276    }
277    let set: BTreeSet<String> = trimmed
278        .split(',')
279        .map(|s| s.trim())
280        .filter(|s| !s.is_empty())
281        .map(|s| s.to_string())
282        .collect();
283    if set.is_empty() {
284        AllowedOrigins::Wildcard
285    } else {
286        AllowedOrigins::Exact(set)
287    }
288}
289
/// Method list advertised in `Access-Control-Allow-Methods` when the
/// preflight request names no method.
fn default_methods() -> String {
    const METHODS: [&str; 7] = ["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"];
    METHODS.join(", ")
}
293
/// Re-join a comma-separated header-name list as `"a, b, c"`.
///
/// Each element is trimmed and empty elements are dropped, so stray
/// whitespace and doubled or trailing commas disappear. Blank input yields
/// an empty string.
fn normalise_header_list(raw: &str) -> String {
    let mut joined = String::with_capacity(raw.len());
    for token in raw.split(',').map(str::trim).filter(|t| !t.is_empty()) {
        if !joined.is_empty() {
            joined.push_str(", ");
        }
        joined.push_str(token);
    }
    joined
}
301
302// --- unit tests ----------------------------------------------------------
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `Exact` strategy from string literals.
    fn exact(origins: &[&str]) -> AllowedOrigins {
        AllowedOrigins::Exact(origins.iter().map(|o| (*o).to_string()).collect())
    }

    /// Look up the first value emitted for `name`, if any.
    fn header<'a>(headers: &'a [(&'static str, String)], name: &str) -> Option<&'a str> {
        headers
            .iter()
            .find(|(key, _)| *key == name)
            .map(|(_, value)| value.as_str())
    }

    #[test]
    fn wildcard_without_credentials_echoes_origin_or_star() {
        let policy = CorsPolicy::new();
        let echoed = policy.echo_origin(Some("https://x.example")).unwrap();
        assert_eq!(echoed, "https://x.example");

        let without = policy.echo_origin(None).unwrap();
        assert_eq!(without, "*");
    }

    #[test]
    fn wildcard_with_credentials_falls_back_to_origin() {
        let policy = CorsPolicy::new().with_allow_credentials(true);
        assert_eq!(
            policy.echo_origin(Some("https://x.example")).unwrap(),
            "https://x.example"
        );
        assert!(policy.echo_origin(None).is_none());
    }

    #[test]
    fn exact_rejects_unlisted_origin() {
        let policy = CorsPolicy::new().with_allowed_origins(exact(&["https://good.example"]));
        assert!(policy.echo_origin(Some("https://bad.example")).is_none());
        assert!(policy.echo_origin(None).is_none());
        assert_eq!(
            policy.echo_origin(Some("https://good.example")).unwrap(),
            "https://good.example"
        );
    }

    #[test]
    fn preflight_rejects_unlisted_or_missing_origin() {
        let policy = CorsPolicy::new().with_allowed_origins(exact(&["https://good.example"]));
        assert!(policy
            .preflight_headers(Some("https://bad.example"), "GET", "")
            .is_none());
        assert!(policy.preflight_headers(None, "GET", "").is_none());
    }

    #[test]
    fn preflight_echoes_method_and_headers() {
        let policy = CorsPolicy::new().with_allow_credentials(true);
        let headers = policy
            .preflight_headers(Some("https://x.example"), " PUT ", "dpop ,authorization")
            .expect("wildcard policy accepts any concrete origin");
        assert_eq!(
            header(&headers, "Access-Control-Allow-Origin"),
            Some("https://x.example")
        );
        assert_eq!(header(&headers, "Vary"), Some("Origin"));
        assert_eq!(header(&headers, "Access-Control-Allow-Credentials"), Some("true"));
        assert_eq!(header(&headers, "Access-Control-Allow-Methods"), Some("PUT"));
        assert_eq!(
            header(&headers, "Access-Control-Allow-Headers"),
            Some("dpop, authorization")
        );
        assert_eq!(header(&headers, "Access-Control-Max-Age"), Some("3600"));
    }

    #[test]
    fn preflight_blank_method_advertises_default_list() {
        let headers = CorsPolicy::new()
            .with_max_age(Duration::from_secs(60))
            .preflight_headers(None, "  ", "")
            .expect("wildcard without credentials accepts a missing origin");
        assert_eq!(header(&headers, "Access-Control-Allow-Origin"), Some("*"));
        assert_eq!(
            header(&headers, "Access-Control-Allow-Methods"),
            Some("GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS")
        );
        assert_eq!(header(&headers, "Access-Control-Allow-Headers"), Some(""));
        assert_eq!(header(&headers, "Access-Control-Max-Age"), Some("60"));
        assert_eq!(header(&headers, "Access-Control-Allow-Credentials"), None);
    }

    #[test]
    fn response_headers_grant_permitted_origin_and_expose_defaults() {
        let policy = CorsPolicy::new().with_allow_credentials(true);
        let headers = policy.response_headers(Some("https://x.example"));
        assert_eq!(
            header(&headers, "Access-Control-Allow-Origin"),
            Some("https://x.example")
        );
        assert_eq!(header(&headers, "Vary"), Some("Origin"));
        assert_eq!(header(&headers, "Access-Control-Allow-Credentials"), Some("true"));
        assert_eq!(
            header(&headers, "Access-Control-Expose-Headers"),
            Some("WAC-Allow, Link, ETag, Accept-Patch, Accept-Post, Updates-Via")
        );
    }

    #[test]
    fn response_headers_omit_grant_for_rejected_origin() {
        let policy = CorsPolicy::new()
            .with_allow_credentials(true)
            .with_allowed_origins(exact(&["https://good.example"]));
        let headers = policy.response_headers(Some("https://bad.example"));
        assert_eq!(header(&headers, "Access-Control-Allow-Origin"), None);
        assert_eq!(header(&headers, "Access-Control-Allow-Credentials"), None);
        assert!(header(&headers, "Access-Control-Expose-Headers").is_some());
    }

    #[test]
    fn empty_expose_list_suppresses_header() {
        let headers = CorsPolicy::new()
            .with_expose_headers(Vec::new())
            .response_headers(None);
        assert_eq!(header(&headers, "Access-Control-Expose-Headers"), None);
        assert_eq!(header(&headers, "Access-Control-Allow-Origin"), Some("*"));
    }

    #[test]
    fn normalise_header_list_collapses_whitespace() {
        assert_eq!(
            normalise_header_list("  authorization ,dpop,   content-type "),
            "authorization, dpop, content-type"
        );
        assert_eq!(normalise_header_list(" , ,"), "");
    }

    #[test]
    fn parse_origins_wildcard_and_list() {
        assert!(matches!(parse_origins("*"), AllowedOrigins::Wildcard));
        assert!(matches!(parse_origins("  *  "), AllowedOrigins::Wildcard));
        assert!(matches!(parse_origins(""), AllowedOrigins::Wildcard));
        assert!(matches!(parse_origins(" , ,"), AllowedOrigins::Wildcard));
        match parse_origins(" https://a , https://b ") {
            AllowedOrigins::Exact(set) => {
                assert_eq!(set.len(), 2);
                assert!(set.contains("https://a"));
                assert!(set.contains("https://b"));
            }
            AllowedOrigins::Wildcard => panic!("expected exact"),
        }
    }
}