sa_token_plugin_rocket/
middleware.rs

1// Author: 金书记
2//
3//! Rocket Fairing (中间件)
4
5use rocket::{Request, Data, Response};
6use rocket::fairing::{Fairing, Info, Kind};
7use rocket::http::{Status, ContentType};
8use crate::SaTokenState;
9use sa_token_core::{token::TokenValue, error::messages};
10use serde_json::json;
11
12/// sa-token Fairing - 提取并验证 token
13pub struct SaTokenFairing {
14    state: SaTokenState,
15}
16
17impl SaTokenFairing {
18    pub fn new(state: SaTokenState) -> Self {
19        Self { state }
20    }
21}
22
23#[rocket::async_trait]
24impl Fairing for SaTokenFairing {
25    fn info(&self) -> Info {
26        Info {
27            name: "SaToken Authentication",
28            kind: Kind::Request,
29        }
30    }
31    
32    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
33        // 提取 token
34        let token_str = {
35            let token_name = &self.state.manager.config.token_name;
36            
37            // 1. 从 Header 获取
38            if let Some(header_val) = request.headers().get_one(token_name) {
39                Some(extract_bearer_token(header_val))
40            }
41            // 2. 从 Cookie 获取
42            else if let Some(cookie) = request.cookies().get(token_name) {
43                Some(cookie.value().to_string())
44            }
45            // 3. 从 Query 参数获取
46            else if let Some(query) = request.uri().query() {
47                parse_query_param(query.as_str(), token_name)
48            } else {
49                None
50            }
51        };
52        
53        if let Some(token_str) = token_str {
54            let token = TokenValue::new(token_str);
55            
56            // 验证 token
57            if self.state.manager.is_valid(&token).await {
58                // 存储 token 到本地缓存
59                request.local_cache(|| Some(token.clone()));
60                
61                // 获取并存储 login_id
62                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
63                    request.local_cache(|| Some(token_info.login_id.clone()));
64                }
65            }
66        }
67    }
68}
69
70/// sa-token 登录检查 Fairing - 强制要求登录
71pub struct SaCheckLoginFairing {
72    state: SaTokenState,
73}
74
75impl SaCheckLoginFairing {
76    pub fn new(state: SaTokenState) -> Self {
77        Self { state }
78    }
79}
80
81#[rocket::async_trait]
82impl Fairing for SaCheckLoginFairing {
83    fn info(&self) -> Info {
84        Info {
85            name: "SaToken Check Login",
86            kind: Kind::Request | Kind::Response,
87        }
88    }
89    
90    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
91        // 提取 token
92        let token_str = {
93            let token_name = &self.state.manager.config.token_name;
94            
95            // 1. 从 Header 获取
96            if let Some(header_val) = request.headers().get_one(token_name) {
97                Some(extract_bearer_token(header_val))
98            }
99            // 2. 从 Cookie 获取
100            else if let Some(cookie) = request.cookies().get(token_name) {
101                Some(cookie.value().to_string())
102            }
103            // 3. 从 Query 参数获取
104            else if let Some(query) = request.uri().query() {
105                parse_query_param(query.as_str(), token_name)
106            } else {
107                None
108            }
109        };
110        
111        if let Some(token_str) = token_str {
112            let token = TokenValue::new(token_str);
113            
114            // 验证 token
115            if self.state.manager.is_valid(&token).await {
116                // 存储 token
117                request.local_cache(|| Some(token.clone()));
118                
119                // 获取并存储 login_id
120                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
121                    request.local_cache(|| Some(token_info.login_id.clone()));
122                }
123                return;
124            }
125        }
126        
127        // 未登录,标记为未授权
128        request.local_cache(|| Some("unauthorized"));
129    }
130    
131    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
132        // 检查是否标记为未授权
133        if let Some(_) = request.local_cache(|| None::<&str>) {
134            if *request.local_cache(|| None::<&str>) == Some("unauthorized") {
135                response.set_status(Status::Unauthorized);
136                response.set_sized_body(None, std::io::Cursor::new(
137                    json!({
138                        "code": 401,
139                        "message": messages::AUTH_ERROR
140                    }).to_string()
141                ));
142            }
143        }
144    }
145}
146
147/// sa-token 权限检查 Fairing - 强制要求特定权限
148pub struct SaCheckPermissionFairing {
149    #[allow(dead_code)]
150    state: SaTokenState,
151    permission: String,
152}
153
154impl SaCheckPermissionFairing {
155    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
156        Self {
157            state,
158            permission: permission.into(),
159        }
160    }
161}
162
163#[rocket::async_trait]
164impl Fairing for SaCheckPermissionFairing {
165    fn info(&self) -> Info {
166        Info {
167            name: "SaToken Check Permission",
168            kind: Kind::Request | Kind::Response,
169        }
170    }
171    
172    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
173        // 检查是否有登录ID
174        if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
175            // 检查权限
176            if sa_token_core::StpUtil::has_permission(&login_id, &self.permission).await {
177                return;
178            }
179        }
180        
181        // 无权限,标记为禁止访问
182        request.local_cache(|| Some("forbidden"));
183    }
184    
185    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
186        // 检查是否标记为禁止访问
187        if let Some(_) = request.local_cache(|| None::<&str>) {
188            if *request.local_cache(|| None::<&str>) == Some("forbidden") {
189                response.set_status(Status::Forbidden);
190                response.set_header(ContentType::JSON);
191                response.set_sized_body(None, std::io::Cursor::new(
192                    json!({
193                        "code": 403,
194                        "message": messages::PERMISSION_REQUIRED
195                    }).to_string()
196                ));
197            }
198        }
199    }
200}
201
202/// sa-token 角色检查 Fairing - 强制要求特定角色
203pub struct SaCheckRoleFairing {
204    #[allow(dead_code)]
205    state: SaTokenState,
206    role: String,
207}
208
209impl SaCheckRoleFairing {
210    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
211        Self {
212            state,
213            role: role.into(),
214        }
215    }
216}
217
218#[rocket::async_trait]
219impl Fairing for SaCheckRoleFairing {
220    fn info(&self) -> Info {
221        Info {
222            name: "SaToken Check Role",
223            kind: Kind::Request | Kind::Response,
224        }
225    }
226    
227    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
228        // 检查是否有登录ID
229        if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
230            // 检查角色
231            if sa_token_core::StpUtil::has_role(&login_id, &self.role).await {
232                return;
233            }
234        }
235        
236        // 无角色,标记为禁止访问
237        request.local_cache(|| Some("forbidden_role"));
238    }
239    
240    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
241        // 检查是否标记为禁止访问
242        if let Some(_) = request.local_cache(|| None::<&str>) {
243            if *request.local_cache(|| None::<&str>) == Some("forbidden_role") {
244                response.set_status(Status::Forbidden);
245                response.set_header(ContentType::JSON);
246                response.set_sized_body(None, std::io::Cursor::new(
247                    json!({
248                        "code": 403,
249                        "message": messages::ROLE_REQUIRED
250                    }).to_string()
251                ));
252            }
253        }
254    }
255}
256
/// Extracts the credential from a header value that may carry a `Bearer ` scheme prefix.
///
/// If `token` starts with `Bearer ` (scheme matched case-insensitively, because HTTP
/// authentication scheme names are case-insensitive per RFC 7235 §2.1), the prefix is
/// removed and the remainder is returned. Otherwise the value is returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const PREFIX: &str = "Bearer ";

    // `get` yields `None` for inputs shorter than the prefix or when the cut would split a
    // multi-byte character, so the slice in the matching arm cannot panic.
    match token.get(..PREFIX.len()) {
        Some(scheme) if scheme.eq_ignore_ascii_case(PREFIX) => {
            token[PREFIX.len()..].to_string()
        }
        _ => token.to_string(),
    }
}
265
266/// 解析查询参数
267fn parse_query_param(query: &str, name: &str) -> Option<String> {
268    for pair in query.split('&') {
269        if let Some((key, value)) = pair.split_once('=') {
270            if key == name {
271                return urlencoding::decode(value).ok().map(|s| s.to_string());
272            }
273        }
274    }
275    None
276}