tako/plugins/
cors.rs

1#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
2//! Cross-Origin Resource Sharing (CORS) plugin for handling cross-origin HTTP requests.
3//!
4//! This module provides comprehensive CORS support for Tako web applications, enabling
5//! secure cross-origin resource sharing between different domains. The plugin handles
6//! preflight OPTIONS requests, validates origins against configured policies, and adds
7//! appropriate CORS headers to responses. It supports configurable origins, methods,
8//! headers, credentials, and cache control for flexible cross-origin access policies.
9//!
10//! The CORS plugin can be applied at both router-level (all routes) and route-level
11//! (specific routes), allowing fine-grained control over CORS policies.
12//!
13//! # Examples
14//!
15//! ```rust
16//! use tako::plugins::cors::{CorsPlugin, CorsBuilder};
17//! use tako::plugins::TakoPlugin;
18//! use tako::router::Router;
19//! use http::Method;
20//!
21//! async fn api_handler(_req: tako::types::Request) -> &'static str {
22//!     "API response"
23//! }
24//!
25//! async fn public_handler(_req: tako::types::Request) -> &'static str {
26//!     "Public response"
27//! }
28//!
29//! let mut router = Router::new();
30//!
31//! // Router-level: Basic CORS setup allowing all origins (applied to all routes)
32//! let global_cors = CorsBuilder::new().build();
33//! router.plugin(global_cors);
34//!
35//! // Route-level: Restrictive CORS for specific API endpoint
36//! let api_route = router.route(Method::GET, "/api/data", api_handler);
37//! let api_cors = CorsBuilder::new()
38//!     .allow_origin("https://app.example.com")
39//!     .allow_origin("https://admin.example.com")
40//!     .allow_methods(&[Method::GET, Method::POST, Method::PUT])
41//!     .allow_credentials(true)
42//!     .max_age_secs(86400)
43//!     .build();
44//! api_route.plugin(api_cors);
45//!
46//! // Another route without CORS restrictions (uses global if set)
47//! router.route(Method::GET, "/public", public_handler);
48//! ```
49
50use anyhow::Result;
51use http::{
52  HeaderName, HeaderValue, Method, StatusCode,
53  header::{
54    ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
55    ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ORIGIN,
56  },
57};
58
59use crate::{
60  body::TakoBody,
61  middleware::Next,
62  plugins::TakoPlugin,
63  responder::Responder,
64  router::Router,
65  types::{Request, Response},
66};
67
68/// CORS policy configuration settings for cross-origin request handling.
69///
70/// `Config` defines the Cross-Origin Resource Sharing policy including allowed origins,
71/// HTTP methods, headers, credential handling, and preflight cache duration. The
72/// configuration determines which cross-origin requests are permitted and what headers
73/// are added to responses to enable secure cross-origin communication.
74///
75/// # Examples
76///
77/// ```rust
78/// use tako::plugins::cors::Config;
79/// use http::{Method, HeaderName};
80///
81/// let config = Config {
82///     origins: vec!["https://app.example.com".to_string()],
83///     methods: vec![Method::GET, Method::POST],
84///     headers: vec![HeaderName::from_static("x-api-key")],
85///     allow_credentials: true,
86///     max_age_secs: Some(3600),
87/// };
88/// ```
89#[derive(Clone)]
90pub struct Config {
91  /// List of allowed origin URLs for cross-origin requests.
92  pub origins: Vec<String>,
93  /// List of allowed HTTP methods for cross-origin requests.
94  pub methods: Vec<Method>,
95  /// List of allowed request headers for cross-origin requests.
96  pub headers: Vec<HeaderName>,
97  /// Whether to allow credentials (cookies, authorization headers) in cross-origin requests.
98  pub allow_credentials: bool,
99  /// Maximum age in seconds for preflight request caching by browsers.
100  pub max_age_secs: Option<u32>,
101}
102
103impl Default for Config {
104  /// Provides permissive default CORS configuration suitable for development.
105  fn default() -> Self {
106    Self {
107      origins: Vec::new(),
108      methods: vec![
109        Method::GET,
110        Method::POST,
111        Method::PUT,
112        Method::PATCH,
113        Method::DELETE,
114        Method::OPTIONS,
115      ],
116      headers: Vec::new(),
117      allow_credentials: false,
118      max_age_secs: Some(3600),
119    }
120  }
121}
122
123/// Builder for configuring CORS policies with a fluent API.
124///
125/// `CorsBuilder` provides a convenient way to construct CORS configurations using
126/// method chaining. It starts with sensible defaults and allows selective customization
127/// of origins, methods, headers, and other CORS policy aspects. The builder pattern
128/// ensures all configuration is explicit while maintaining ease of use.
129///
130/// # Examples
131///
132/// ```rust
133/// use tako::plugins::cors::CorsBuilder;
134/// use http::{Method, HeaderName};
135///
136/// // Development setup - permissive CORS
137/// let dev_cors = CorsBuilder::new()
138///     .allow_credentials(false)
139///     .build();
140///
141/// // Production setup - restrictive CORS
142/// let prod_cors = CorsBuilder::new()
143///     .allow_origin("https://app.mysite.com")
144///     .allow_origin("https://admin.mysite.com")
145///     .allow_methods(&[Method::GET, Method::POST])
146///     .allow_headers(&[HeaderName::from_static("authorization")])
147///     .allow_credentials(true)
148///     .max_age_secs(86400)
149///     .build();
150/// ```
151pub struct CorsBuilder(Config);
152
153impl CorsBuilder {
154  /// Creates a new CORS configuration builder with default settings.
155  pub fn new() -> Self {
156    Self(Config::default())
157  }
158
159  /// Adds an allowed origin to the CORS policy.
160  pub fn allow_origin(mut self, o: impl Into<String>) -> Self {
161    self.0.origins.push(o.into());
162    self
163  }
164
165  /// Sets the allowed HTTP methods for cross-origin requests.
166  pub fn allow_methods(mut self, m: &[Method]) -> Self {
167    self.0.methods = m.to_vec();
168    self
169  }
170
171  /// Sets the allowed request headers for cross-origin requests.
172  pub fn allow_headers(mut self, h: &[HeaderName]) -> Self {
173    self.0.headers = h.to_vec();
174    self
175  }
176
177  /// Enables or disables credential sharing in cross-origin requests.
178  pub fn allow_credentials(mut self, allow: bool) -> Self {
179    self.0.allow_credentials = allow;
180    self
181  }
182
183  /// Sets the maximum age for preflight request caching.
184  pub fn max_age_secs(mut self, secs: u32) -> Self {
185    self.0.max_age_secs = Some(secs);
186    self
187  }
188
189  /// Builds the CORS plugin with the configured settings.
190  pub fn build(self) -> CorsPlugin {
191    CorsPlugin { cfg: self.0 }
192  }
193}
194
195/// CORS plugin for handling cross-origin resource sharing in Tako applications.
196///
197/// `CorsPlugin` implements the TakoPlugin trait to provide comprehensive CORS support
198/// including preflight request handling, origin validation, and response header
199/// management. It automatically handles OPTIONS preflight requests and adds appropriate
200/// CORS headers to all responses based on the configured policy.
201///
202/// # Examples
203///
204/// ```rust
205/// use tako::plugins::cors::{CorsPlugin, CorsBuilder};
206/// use tako::plugins::TakoPlugin;
207/// use tako::router::Router;
208/// use http::Method;
209///
210/// // Basic setup with default permissive policy
211/// let cors = CorsPlugin::default();
212/// let mut router = Router::new();
213/// router.plugin(cors);
214///
215/// // Custom restrictive policy for production
216/// let prod_cors = CorsBuilder::new()
217///     .allow_origin("https://myapp.com")
218///     .allow_methods(&[Method::GET, Method::POST])
219///     .allow_credentials(true)
220///     .build();
221/// router.plugin(prod_cors);
222/// ```
223#[derive(Clone)]
224#[doc(alias = "cors")]
225pub struct CorsPlugin {
226  cfg: Config,
227}
228
229impl Default for CorsPlugin {
230  /// Creates a CORS plugin with permissive default configuration.
231  fn default() -> Self {
232    Self {
233      cfg: Config::default(),
234    }
235  }
236}
237
238impl TakoPlugin for CorsPlugin {
239  /// Returns the plugin name for identification and debugging.
240  fn name(&self) -> &'static str {
241    "CorsPlugin"
242  }
243
244  /// Sets up the CORS plugin by registering middleware with the router.
245  fn setup(&self, router: &Router) -> Result<()> {
246    let cfg = self.cfg.clone();
247    router.middleware(move |req, next| {
248      let cfg = cfg.clone();
249      async move { handle_cors(req, next, cfg).await }
250    });
251    Ok(())
252  }
253}
254
255/// Handles CORS processing for incoming requests including preflight and actual requests.
256async fn handle_cors(req: Request, next: Next, cfg: Config) -> impl Responder {
257  let origin = req.headers().get(ORIGIN).cloned();
258
259  if req.method() == Method::OPTIONS {
260    let mut resp = http::Response::builder()
261      .status(StatusCode::NO_CONTENT)
262      .body(TakoBody::empty())
263      .unwrap();
264    add_cors_headers(&cfg, origin, &mut resp);
265    return resp.into_response();
266  }
267
268  let mut resp = next.run(req).await;
269  add_cors_headers(&cfg, origin, &mut resp);
270  resp.into_response()
271}
272
273/// Adds CORS headers to HTTP responses based on configuration and request origin.
274fn add_cors_headers(cfg: &Config, origin: Option<HeaderValue>, resp: &mut Response) {
275  // Origin validation and Access-Control-Allow-Origin header
276  let allow_origin = if cfg.origins.is_empty() {
277    "*".to_string()
278  } else if let Some(o) = &origin {
279    let s = o.to_str().unwrap_or_default();
280    if cfg.origins.iter().any(|p| p == s) {
281      s.to_string()
282    } else {
283      return; // Origin not allowed, don't add CORS headers
284    }
285  } else {
286    return; // No origin header, don't add CORS headers
287  };
288
289  resp.headers_mut().insert(
290    ACCESS_CONTROL_ALLOW_ORIGIN,
291    HeaderValue::from_str(&allow_origin).unwrap(),
292  );
293
294  // Access-Control-Allow-Methods header
295  let methods = if cfg.methods.is_empty() {
296    None
297  } else {
298    Some(
299      cfg
300        .methods
301        .iter()
302        .map(|m| m.as_str())
303        .collect::<Vec<_>>()
304        .join(","),
305    )
306  };
307  if let Some(v) = methods {
308    resp.headers_mut().insert(
309      ACCESS_CONTROL_ALLOW_METHODS,
310      HeaderValue::from_str(&v).unwrap(),
311    );
312  }
313
314  // Access-Control-Allow-Headers header
315  if cfg.headers.is_empty() {
316    // Allow all request headers by default when none are explicitly configured.
317    resp.headers_mut().insert(
318      ACCESS_CONTROL_ALLOW_HEADERS,
319      HeaderValue::from_static("*"),
320    );
321  } else {
322    let h = cfg
323      .headers
324      .iter()
325      .map(|h| h.as_str())
326      .collect::<Vec<_>>()
327      .join(",");
328    resp.headers_mut().insert(
329      ACCESS_CONTROL_ALLOW_HEADERS,
330      HeaderValue::from_str(&h).unwrap(),
331    );
332  }
333
334  // Access-Control-Allow-Credentials header
335  if cfg.allow_credentials {
336    resp.headers_mut().insert(
337      ACCESS_CONTROL_ALLOW_CREDENTIALS,
338      HeaderValue::from_static("true"),
339    );
340  }
341
342  // Access-Control-Max-Age header
343  if let Some(secs) = cfg.max_age_secs {
344    resp.headers_mut().insert(
345      ACCESS_CONTROL_MAX_AGE,
346      HeaderValue::from_str(&secs.to_string()).unwrap(),
347    );
348  }
349}