tako/plugins/cors.rs
1#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
2//! Cross-Origin Resource Sharing (CORS) plugin for handling cross-origin HTTP requests.
3//!
4//! This module provides comprehensive CORS support for Tako web applications, enabling
5//! secure cross-origin resource sharing between different domains. The plugin handles
6//! preflight OPTIONS requests, validates origins against configured policies, and adds
7//! appropriate CORS headers to responses. It supports configurable origins, methods,
8//! headers, credentials, and cache control for flexible cross-origin access policies.
9//!
10//! The CORS plugin can be applied at both router-level (all routes) and route-level
11//! (specific routes), allowing fine-grained control over CORS policies.
12//!
13//! # Examples
14//!
15//! ```rust
16//! use tako::plugins::cors::{CorsPlugin, CorsBuilder};
17//! use tako::plugins::TakoPlugin;
18//! use tako::router::Router;
19//! use http::Method;
20//!
21//! async fn api_handler(_req: tako::types::Request) -> &'static str {
22//! "API response"
23//! }
24//!
25//! async fn public_handler(_req: tako::types::Request) -> &'static str {
26//! "Public response"
27//! }
28//!
29//! let mut router = Router::new();
30//!
31//! // Router-level: Basic CORS setup allowing all origins (applied to all routes)
32//! let global_cors = CorsBuilder::new().build();
33//! router.plugin(global_cors);
34//!
35//! // Route-level: Restrictive CORS for specific API endpoint
36//! let api_route = router.route(Method::GET, "/api/data", api_handler);
37//! let api_cors = CorsBuilder::new()
38//! .allow_origin("https://app.example.com")
39//! .allow_origin("https://admin.example.com")
40//! .allow_methods(&[Method::GET, Method::POST, Method::PUT])
41//! .allow_credentials(true)
42//! .max_age_secs(86400)
43//! .build();
44//! api_route.plugin(api_cors);
45//!
46//! // Another route without CORS restrictions (uses global if set)
47//! router.route(Method::GET, "/public", public_handler);
48//! ```
49
use anyhow::Result;
use http::{
  HeaderName, HeaderValue, Method, StatusCode,
  header::{
    ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
    ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN,
    VARY,
  },
};

use crate::{
  body::TakoBody,
  middleware::Next,
  plugins::TakoPlugin,
  responder::Responder,
  router::Router,
  types::{Request, Response},
};
67
68/// CORS policy configuration settings for cross-origin request handling.
69///
70/// `Config` defines the Cross-Origin Resource Sharing policy including allowed origins,
71/// HTTP methods, headers, credential handling, and preflight cache duration. The
72/// configuration determines which cross-origin requests are permitted and what headers
73/// are added to responses to enable secure cross-origin communication.
74///
75/// # Examples
76///
77/// ```rust
78/// use tako::plugins::cors::Config;
79/// use http::{Method, HeaderName};
80///
81/// let config = Config {
82/// origins: vec!["https://app.example.com".to_string()],
83/// methods: vec![Method::GET, Method::POST],
84/// headers: vec![HeaderName::from_static("x-api-key")],
85/// allow_credentials: true,
86/// max_age_secs: Some(3600),
87/// };
88/// ```
89#[derive(Clone)]
90pub struct Config {
91 /// List of allowed origin URLs for cross-origin requests.
92 pub origins: Vec<String>,
93 /// List of allowed HTTP methods for cross-origin requests.
94 pub methods: Vec<Method>,
95 /// List of allowed request headers for cross-origin requests.
96 pub headers: Vec<HeaderName>,
97 /// Whether to allow credentials (cookies, authorization headers) in cross-origin requests.
98 pub allow_credentials: bool,
99 /// Maximum age in seconds for preflight request caching by browsers.
100 pub max_age_secs: Option<u32>,
101}
102
103impl Default for Config {
104 /// Provides permissive default CORS configuration suitable for development.
105 fn default() -> Self {
106 Self {
107 origins: Vec::new(),
108 methods: vec![
109 Method::GET,
110 Method::POST,
111 Method::PUT,
112 Method::PATCH,
113 Method::DELETE,
114 Method::OPTIONS,
115 ],
116 headers: Vec::new(),
117 allow_credentials: false,
118 max_age_secs: Some(3600),
119 }
120 }
121}
122
123/// Builder for configuring CORS policies with a fluent API.
124///
125/// `CorsBuilder` provides a convenient way to construct CORS configurations using
126/// method chaining. It starts with sensible defaults and allows selective customization
127/// of origins, methods, headers, and other CORS policy aspects. The builder pattern
128/// ensures all configuration is explicit while maintaining ease of use.
129///
130/// # Examples
131///
132/// ```rust
133/// use tako::plugins::cors::CorsBuilder;
134/// use http::{Method, HeaderName};
135///
136/// // Development setup - permissive CORS
137/// let dev_cors = CorsBuilder::new()
138/// .allow_credentials(false)
139/// .build();
140///
141/// // Production setup - restrictive CORS
142/// let prod_cors = CorsBuilder::new()
143/// .allow_origin("https://app.mysite.com")
144/// .allow_origin("https://admin.mysite.com")
145/// .allow_methods(&[Method::GET, Method::POST])
146/// .allow_headers(&[HeaderName::from_static("authorization")])
147/// .allow_credentials(true)
148/// .max_age_secs(86400)
149/// .build();
150/// ```
151pub struct CorsBuilder(Config);
152
153impl CorsBuilder {
154 /// Creates a new CORS configuration builder with default settings.
155 pub fn new() -> Self {
156 Self(Config::default())
157 }
158
159 /// Adds an allowed origin to the CORS policy.
160 pub fn allow_origin(mut self, o: impl Into<String>) -> Self {
161 self.0.origins.push(o.into());
162 self
163 }
164
165 /// Sets the allowed HTTP methods for cross-origin requests.
166 pub fn allow_methods(mut self, m: &[Method]) -> Self {
167 self.0.methods = m.to_vec();
168 self
169 }
170
171 /// Sets the allowed request headers for cross-origin requests.
172 pub fn allow_headers(mut self, h: &[HeaderName]) -> Self {
173 self.0.headers = h.to_vec();
174 self
175 }
176
177 /// Enables or disables credential sharing in cross-origin requests.
178 pub fn allow_credentials(mut self, allow: bool) -> Self {
179 self.0.allow_credentials = allow;
180 self
181 }
182
183 /// Sets the maximum age for preflight request caching.
184 pub fn max_age_secs(mut self, secs: u32) -> Self {
185 self.0.max_age_secs = Some(secs);
186 self
187 }
188
189 /// Builds the CORS plugin with the configured settings.
190 pub fn build(self) -> CorsPlugin {
191 CorsPlugin { cfg: self.0 }
192 }
193}
194
195/// CORS plugin for handling cross-origin resource sharing in Tako applications.
196///
197/// `CorsPlugin` implements the TakoPlugin trait to provide comprehensive CORS support
198/// including preflight request handling, origin validation, and response header
199/// management. It automatically handles OPTIONS preflight requests and adds appropriate
200/// CORS headers to all responses based on the configured policy.
201///
202/// # Examples
203///
204/// ```rust
205/// use tako::plugins::cors::{CorsPlugin, CorsBuilder};
206/// use tako::plugins::TakoPlugin;
207/// use tako::router::Router;
208/// use http::Method;
209///
210/// // Basic setup with default permissive policy
211/// let cors = CorsPlugin::default();
212/// let mut router = Router::new();
213/// router.plugin(cors);
214///
215/// // Custom restrictive policy for production
216/// let prod_cors = CorsBuilder::new()
217/// .allow_origin("https://myapp.com")
218/// .allow_methods(&[Method::GET, Method::POST])
219/// .allow_credentials(true)
220/// .build();
221/// router.plugin(prod_cors);
222/// ```
223#[derive(Clone)]
224#[doc(alias = "cors")]
225pub struct CorsPlugin {
226 cfg: Config,
227}
228
229impl Default for CorsPlugin {
230 /// Creates a CORS plugin with permissive default configuration.
231 fn default() -> Self {
232 Self {
233 cfg: Config::default(),
234 }
235 }
236}
237
238impl TakoPlugin for CorsPlugin {
239 /// Returns the plugin name for identification and debugging.
240 fn name(&self) -> &'static str {
241 "CorsPlugin"
242 }
243
244 /// Sets up the CORS plugin by registering middleware with the router.
245 fn setup(&self, router: &Router) -> Result<()> {
246 let cfg = self.cfg.clone();
247 router.middleware(move |req, next| {
248 let cfg = cfg.clone();
249 async move { handle_cors(req, next, cfg).await }
250 });
251 Ok(())
252 }
253}
254
255/// Handles CORS processing for incoming requests including preflight and actual requests.
256async fn handle_cors(req: Request, next: Next, cfg: Config) -> impl Responder {
257 let origin = req.headers().get(ORIGIN).cloned();
258
259 if req.method() == Method::OPTIONS {
260 let mut resp = http::Response::builder()
261 .status(StatusCode::NO_CONTENT)
262 .body(TakoBody::empty())
263 .unwrap();
264 add_cors_headers(&cfg, origin, &mut resp);
265 return resp.into_response();
266 }
267
268 let mut resp = next.run(req).await;
269 add_cors_headers(&cfg, origin, &mut resp);
270 resp.into_response()
271}
272
273/// Adds CORS headers to HTTP responses based on configuration and request origin.
274fn add_cors_headers(cfg: &Config, origin: Option<HeaderValue>, resp: &mut Response) {
275 // Origin validation and Access-Control-Allow-Origin header
276 let allow_origin = if cfg.origins.is_empty() {
277 "*".to_string()
278 } else if let Some(o) = &origin {
279 let s = o.to_str().unwrap_or_default();
280 if cfg.origins.iter().any(|p| p == s) {
281 s.to_string()
282 } else {
283 return; // Origin not allowed, don't add CORS headers
284 }
285 } else {
286 return; // No origin header, don't add CORS headers
287 };
288
289 resp.headers_mut().insert(
290 ACCESS_CONTROL_ALLOW_ORIGIN,
291 HeaderValue::from_str(&allow_origin).unwrap(),
292 );
293
294 // Access-Control-Allow-Methods header
295 let methods = if cfg.methods.is_empty() {
296 None
297 } else {
298 Some(
299 cfg
300 .methods
301 .iter()
302 .map(|m| m.as_str())
303 .collect::<Vec<_>>()
304 .join(","),
305 )
306 };
307 if let Some(v) = methods {
308 resp.headers_mut().insert(
309 ACCESS_CONTROL_ALLOW_METHODS,
310 HeaderValue::from_str(&v).unwrap(),
311 );
312 }
313
314 // Access-Control-Allow-Headers header
315 if cfg.headers.is_empty() {
316 // Allow all request headers by default when none are explicitly configured.
317 resp.headers_mut().insert(
318 ACCESS_CONTROL_ALLOW_HEADERS,
319 HeaderValue::from_static("*"),
320 );
321 } else {
322 let h = cfg
323 .headers
324 .iter()
325 .map(|h| h.as_str())
326 .collect::<Vec<_>>()
327 .join(",");
328 resp.headers_mut().insert(
329 ACCESS_CONTROL_ALLOW_HEADERS,
330 HeaderValue::from_str(&h).unwrap(),
331 );
332 }
333
334 // Access-Control-Allow-Credentials header
335 if cfg.allow_credentials {
336 resp.headers_mut().insert(
337 ACCESS_CONTROL_ALLOW_CREDENTIALS,
338 HeaderValue::from_static("true"),
339 );
340 }
341
342 // Access-Control-Max-Age header
343 if let Some(secs) = cfg.max_age_secs {
344 resp.headers_mut().insert(
345 ACCESS_CONTROL_MAX_AGE,
346 HeaderValue::from_str(&secs.to_string()).unwrap(),
347 );
348 }
349}