Skip to main content

wae_https/middleware/
cors.rs

1#![doc = include_str!("readme.md")]
2
3use std::time::Duration;
4
5use http::{HeaderName, HeaderValue, Method, header};
6use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
7
8/// CORS 配置结构体
9///
10/// 用于配置跨域资源共享的各项参数。
11pub struct CorsConfig {
12    /// 允许的源列表
13    pub allowed_origins: Vec<HeaderValue>,
14    /// 允许的 HTTP 方法列表
15    pub allowed_methods: Vec<Method>,
16    /// 允许的请求头列表
17    pub allowed_headers: Vec<HeaderName>,
18    /// 是否允许携带凭证
19    pub allow_credentials: bool,
20    /// 预检请求缓存时间(秒)
21    pub max_age: u64,
22}
23
24impl CorsConfig {
25    /// 创建默认的 CORS 配置
26    ///
27    /// 默认配置:
28    /// - 允许所有源
29    /// - 允许所有方法
30    /// - 允许所有请求头
31    /// - 不允许凭证
32    /// - 预检缓存时间为 600 秒
33    pub fn new() -> Self {
34        Self {
35            allowed_origins: Vec::new(),
36            allowed_methods: Vec::new(),
37            allowed_headers: Vec::new(),
38            allow_credentials: false,
39            max_age: 600,
40        }
41    }
42
43    /// 添加允许的源
44    ///
45    /// # 参数
46    ///
47    /// * `origin` - 允许的源地址,如 "https://example.com"
48    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
49        if let Ok(value) = HeaderValue::from_str(&origin.into()) {
50            self.allowed_origins.push(value);
51        }
52        self
53    }
54
55    /// 设置允许的源列表
56    ///
57    /// # 参数
58    ///
59    /// * `origins` - 允许的源地址迭代器
60    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
61    where
62        I: IntoIterator<Item = S>,
63        S: Into<String>,
64    {
65        self.allowed_origins = origins.into_iter().filter_map(|s| HeaderValue::from_str(&s.into()).ok()).collect();
66        self
67    }
68
69    /// 添加允许的 HTTP 方法
70    ///
71    /// # 参数
72    ///
73    /// * `method` - 允许的 HTTP 方法,如 "GET"、"POST"
74    pub fn allow_method(mut self, method: impl Into<String>) -> Self {
75        if let Ok(m) = Method::from_bytes(method.into().as_bytes()) {
76            self.allowed_methods.push(m);
77        }
78        self
79    }
80
81    /// 设置允许的 HTTP 方法列表
82    ///
83    /// # 参数
84    ///
85    /// * `methods` - 允许的 HTTP 方法迭代器
86    pub fn allow_methods<I, S>(mut self, methods: I) -> Self
87    where
88        I: IntoIterator<Item = S>,
89        S: Into<String>,
90    {
91        self.allowed_methods = methods.into_iter().filter_map(|m| Method::from_bytes(m.into().as_bytes()).ok()).collect();
92        self
93    }
94
95    /// 添加允许的请求头
96    ///
97    /// # 参数
98    ///
99    /// * `header_name` - 允许的请求头名称,如 "Content-Type"
100    pub fn allow_header(mut self, header_name: impl Into<String>) -> Self {
101        if let Ok(name) = HeaderName::try_from(header_name.into()) {
102            self.allowed_headers.push(name);
103        }
104        self
105    }
106
107    /// 设置允许的请求头列表
108    ///
109    /// # 参数
110    ///
111    /// * `headers` - 允许的请求头名称迭代器
112    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
113    where
114        I: IntoIterator<Item = S>,
115        S: Into<String>,
116    {
117        self.allowed_headers = headers.into_iter().filter_map(|h| HeaderName::try_from(h.into()).ok()).collect();
118        self
119    }
120
121    /// 设置是否允许携带凭证
122    ///
123    /// # 参数
124    ///
125    /// * `allow` - 是否允许凭证
126    ///
127    /// # 注意
128    ///
129    /// 当启用凭证模式时,不能使用通配符源。
130    pub fn allow_credentials(mut self, allow: bool) -> Self {
131        self.allow_credentials = allow;
132        self
133    }
134
135    /// 设置预检请求缓存时间
136    ///
137    /// # 参数
138    ///
139    /// * `seconds` - 缓存时间(秒)
140    pub fn max_age(mut self, seconds: u64) -> Self {
141        self.max_age = seconds;
142        self
143    }
144
145    /// 将配置转换为 CorsLayer
146    ///
147    /// # 返回
148    ///
149    /// 返回配置好的 `CorsLayer` 实例。
150    pub fn into_layer(self) -> CorsLayer {
151        let mut cors = CorsLayer::new();
152
153        cors = if self.allowed_origins.is_empty() {
154            cors.allow_origin(AllowOrigin::any())
155        }
156        else {
157            cors.allow_origin(AllowOrigin::list(self.allowed_origins))
158        };
159
160        cors = if self.allowed_methods.is_empty() {
161            cors.allow_methods(AllowMethods::any())
162        }
163        else {
164            cors.allow_methods(AllowMethods::list(self.allowed_methods))
165        };
166
167        cors = if self.allowed_headers.is_empty() {
168            cors.allow_headers(AllowHeaders::any())
169        }
170        else {
171            cors.allow_headers(AllowHeaders::list(self.allowed_headers))
172        };
173
174        cors = if self.allow_credentials { cors.allow_credentials(true) } else { cors };
175
176        cors.max_age(Duration::from_secs(self.max_age))
177    }
178}
179
180impl Default for CorsConfig {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186/// 创建允许所有请求的 CORS 配置
187///
188/// 适用于开发环境,允许所有源、方法和请求头。
189pub fn cors_permissive() -> CorsLayer {
190    CorsLayer::permissive()
191}
192
193/// 创建严格模式的 CORS 配置
194///
195/// 仅允许常见的方法和请求头,不包含凭证支持。
196pub fn cors_strict() -> CorsLayer {
197    CorsConfig::new()
198        .allow_methods(["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
199        .allow_headers([header::CONTENT_TYPE.as_str(), header::AUTHORIZATION.as_str(), header::ACCEPT.as_str()])
200        .max_age(3600)
201        .into_layer()
202}