Skip to main content

sa_token_plugin_rocket_v05/
middleware.rs

1// Author: 金书记
2//
3//! Rocket Fairings sharing the same **`run_auth_flow`** rules as [`SaTokenLayer`](crate::layer::SaTokenLayer) (see each `on_request`).
4//! 与 [`SaTokenLayer`](crate::layer::SaTokenLayer) 共用 **`run_auth_flow`** 规则的 Fairing(详见各 `on_request`)。
5
6use rocket::{Data, Request, Response};
7use rocket::fairing::{Fairing, Info, Kind};
8use rocket::http::{ContentType, Status};
9use sa_token_core::error::messages;
10use sa_token_plugin_rocket_core::run_auth_flow;
11use serde_json::json;
12
13use crate::adapter::RocketCapturedRequest;
14use crate::SaTokenState;
15
16/// sa-token Fairing - 提取并验证 token
17pub struct SaTokenFairing {
18    state: SaTokenState,
19}
20
21impl SaTokenFairing {
22    pub fn new(state: SaTokenState) -> Self {
23        Self { state }
24    }
25}
26
27#[rocket::async_trait]
28impl Fairing for SaTokenFairing {
29    fn info(&self) -> Info {
30        Info {
31            name: "SaToken Authentication",
32            kind: Kind::Request,
33        }
34    }
35
36    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
37        let adapter = RocketCapturedRequest::capture(
38            request,
39            self.state.manager.config.token_name.as_str(),
40        );
41        let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
42
43        if let Some(ref t) = flow.token {
44            request.local_cache(|| Some(t.clone()));
45        }
46        if let Some(ref id) = flow.login_id {
47            request.local_cache(|| Some(id.clone()));
48        }
49    }
50}
51
52/// sa-token 登录检查 Fairing - 强制要求登录
53pub struct SaCheckLoginFairing {
54    state: SaTokenState,
55}
56
57impl SaCheckLoginFairing {
58    pub fn new(state: SaTokenState) -> Self {
59        Self { state }
60    }
61}
62
63#[rocket::async_trait]
64impl Fairing for SaCheckLoginFairing {
65    fn info(&self) -> Info {
66        Info {
67            name: "SaToken Check Login",
68            kind: Kind::Request | Kind::Response,
69        }
70    }
71
72    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
73        let adapter = RocketCapturedRequest::capture(
74            request,
75            self.state.manager.config.token_name.as_str(),
76        );
77        let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
78
79        if flow.login_id.is_some() {
80            if let Some(ref t) = flow.token {
81                request.local_cache(|| Some(t.clone()));
82            }
83            if let Some(ref id) = flow.login_id {
84                request.local_cache(|| Some(id.clone()));
85            }
86            return;
87        }
88
89        request.local_cache(|| Some("unauthorized"));
90    }
91
92    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
93        // 检查是否标记为未授权
94        if request.local_cache(|| None::<&str>).is_some()
95            && *request.local_cache(|| None::<&str>) == Some("unauthorized") {
96                response.set_status(Status::Unauthorized);
97                response.set_sized_body(
98                    None,
99                    std::io::Cursor::new(
100                        json!({
101                            "code": 401,
102                            "message": messages::AUTH_ERROR
103                        })
104                        .to_string(),
105                    ),
106                );
107            }
108    }
109}
110
111/// sa-token 权限检查 Fairing - 强制要求特定权限
112pub struct SaCheckPermissionFairing {
113    #[allow(dead_code)]
114    state: SaTokenState,
115    permission: String,
116}
117
118impl SaCheckPermissionFairing {
119    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
120        Self {
121            state,
122            permission: permission.into(),
123        }
124    }
125}
126
127#[rocket::async_trait]
128impl Fairing for SaCheckPermissionFairing {
129    fn info(&self) -> Info {
130        Info {
131            name: "SaToken Check Permission",
132            kind: Kind::Request | Kind::Response,
133        }
134    }
135
136    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
137        // 检查是否有登录ID
138        if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
139            // 检查权限
140            if sa_token_core::StpUtil::has_permission(&login_id, &self.permission).await {
141                return;
142            }
143        }
144
145        // 无权限,标记为禁止访问
146        request.local_cache(|| Some("forbidden"));
147    }
148
149    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
150        // 检查是否标记为禁止访问
151        if request.local_cache(|| None::<&str>).is_some()
152            && *request.local_cache(|| None::<&str>) == Some("forbidden") {
153                response.set_status(Status::Forbidden);
154                response.set_header(ContentType::JSON);
155                response.set_sized_body(
156                    None,
157                    std::io::Cursor::new(
158                        json!({
159                            "code": 403,
160                            "message": messages::PERMISSION_REQUIRED
161                        })
162                        .to_string(),
163                    ),
164                );
165            }
166    }
167}
168
169/// sa-token 角色检查 Fairing - 强制要求特定角色
170pub struct SaCheckRoleFairing {
171    #[allow(dead_code)]
172    state: SaTokenState,
173    role: String,
174}
175
176impl SaCheckRoleFairing {
177    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
178        Self {
179            state,
180            role: role.into(),
181        }
182    }
183}
184
185#[rocket::async_trait]
186impl Fairing for SaCheckRoleFairing {
187    fn info(&self) -> Info {
188        Info {
189            name: "SaToken Check Role",
190            kind: Kind::Request | Kind::Response,
191        }
192    }
193
194    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
195        // 检查是否有登录ID
196        if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
197            // 检查角色
198            if sa_token_core::StpUtil::has_role(&login_id, &self.role).await {
199                return;
200            }
201        }
202
203        // 无角色,标记为禁止访问
204        request.local_cache(|| Some("forbidden_role"));
205    }
206
207    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
208        // 检查是否标记为禁止访问
209        if request.local_cache(|| None::<&str>).is_some()
210            && *request.local_cache(|| None::<&str>) == Some("forbidden_role") {
211                response.set_status(Status::Forbidden);
212                response.set_header(ContentType::JSON);
213                response.set_sized_body(
214                    None,
215                    std::io::Cursor::new(
216                        json!({
217                            "code": 403,
218                            "message": messages::ROLE_REQUIRED
219                        })
220                        .to_string(),
221                    ),
222                );
223            }
224    }
225}