// Path: sa_token_plugin_rocket/middleware.rs
middleware.rs1use rocket::{Request, Data, Response};
6use rocket::fairing::{Fairing, Info, Kind};
7use rocket::http::{Status, ContentType};
8use crate::SaTokenState;
9use sa_token_core::{token::TokenValue, error::messages};
10use serde_json::json;
11
12pub struct SaTokenFairing {
14 state: SaTokenState,
15}
16
17impl SaTokenFairing {
18 pub fn new(state: SaTokenState) -> Self {
19 Self { state }
20 }
21}
22
23#[rocket::async_trait]
24impl Fairing for SaTokenFairing {
25 fn info(&self) -> Info {
26 Info {
27 name: "SaToken Authentication",
28 kind: Kind::Request,
29 }
30 }
31
32 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
33 let token_str = {
35 let token_name = &self.state.manager.config.token_name;
36
37 if let Some(header_val) = request.headers().get_one(token_name) {
39 Some(extract_bearer_token(header_val))
40 }
41 else if let Some(cookie) = request.cookies().get(token_name) {
43 Some(cookie.value().to_string())
44 }
45 else if let Some(query) = request.uri().query() {
47 parse_query_param(query.as_str(), token_name)
48 } else {
49 None
50 }
51 };
52
53 if let Some(token_str) = token_str {
54 let token = TokenValue::new(token_str);
55
56 if self.state.manager.is_valid(&token).await {
58 request.local_cache(|| Some(token.clone()));
60
61 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
63 request.local_cache(|| Some(token_info.login_id.clone()));
64 }
65 }
66 }
67 }
68}
69
70pub struct SaCheckLoginFairing {
72 state: SaTokenState,
73}
74
75impl SaCheckLoginFairing {
76 pub fn new(state: SaTokenState) -> Self {
77 Self { state }
78 }
79}
80
81#[rocket::async_trait]
82impl Fairing for SaCheckLoginFairing {
83 fn info(&self) -> Info {
84 Info {
85 name: "SaToken Check Login",
86 kind: Kind::Request | Kind::Response,
87 }
88 }
89
90 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
91 let token_str = {
93 let token_name = &self.state.manager.config.token_name;
94
95 if let Some(header_val) = request.headers().get_one(token_name) {
97 Some(extract_bearer_token(header_val))
98 }
99 else if let Some(cookie) = request.cookies().get(token_name) {
101 Some(cookie.value().to_string())
102 }
103 else if let Some(query) = request.uri().query() {
105 parse_query_param(query.as_str(), token_name)
106 } else {
107 None
108 }
109 };
110
111 if let Some(token_str) = token_str {
112 let token = TokenValue::new(token_str);
113
114 if self.state.manager.is_valid(&token).await {
116 request.local_cache(|| Some(token.clone()));
118
119 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
121 request.local_cache(|| Some(token_info.login_id.clone()));
122 }
123 return;
124 }
125 }
126
127 request.local_cache(|| Some("unauthorized"));
129 }
130
131 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
132 if let Some(_) = request.local_cache(|| None::<&str>) {
134 if *request.local_cache(|| None::<&str>) == Some("unauthorized") {
135 response.set_status(Status::Unauthorized);
136 response.set_sized_body(None, std::io::Cursor::new(
137 json!({
138 "code": 401,
139 "message": messages::AUTH_ERROR
140 }).to_string()
141 ));
142 }
143 }
144 }
145}
146
147pub struct SaCheckPermissionFairing {
149 #[allow(dead_code)]
150 state: SaTokenState,
151 permission: String,
152}
153
154impl SaCheckPermissionFairing {
155 pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
156 Self {
157 state,
158 permission: permission.into(),
159 }
160 }
161}
162
163#[rocket::async_trait]
164impl Fairing for SaCheckPermissionFairing {
165 fn info(&self) -> Info {
166 Info {
167 name: "SaToken Check Permission",
168 kind: Kind::Request | Kind::Response,
169 }
170 }
171
172 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
173 if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
175 if sa_token_core::StpUtil::has_permission(&login_id, &self.permission).await {
177 return;
178 }
179 }
180
181 request.local_cache(|| Some("forbidden"));
183 }
184
185 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
186 if let Some(_) = request.local_cache(|| None::<&str>) {
188 if *request.local_cache(|| None::<&str>) == Some("forbidden") {
189 response.set_status(Status::Forbidden);
190 response.set_header(ContentType::JSON);
191 response.set_sized_body(None, std::io::Cursor::new(
192 json!({
193 "code": 403,
194 "message": messages::PERMISSION_REQUIRED
195 }).to_string()
196 ));
197 }
198 }
199 }
200}
201
202pub struct SaCheckRoleFairing {
204 #[allow(dead_code)]
205 state: SaTokenState,
206 role: String,
207}
208
209impl SaCheckRoleFairing {
210 pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
211 Self {
212 state,
213 role: role.into(),
214 }
215 }
216}
217
218#[rocket::async_trait]
219impl Fairing for SaCheckRoleFairing {
220 fn info(&self) -> Info {
221 Info {
222 name: "SaToken Check Role",
223 kind: Kind::Request | Kind::Response,
224 }
225 }
226
227 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
228 if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
230 if sa_token_core::StpUtil::has_role(&login_id, &self.role).await {
232 return;
233 }
234 }
235
236 request.local_cache(|| Some("forbidden_role"));
238 }
239
240 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
241 if let Some(_) = request.local_cache(|| None::<&str>) {
243 if *request.local_cache(|| None::<&str>) == Some("forbidden_role") {
244 response.set_status(Status::Forbidden);
245 response.set_header(ContentType::JSON);
246 response.set_sized_body(None, std::io::Cursor::new(
247 json!({
248 "code": 403,
249 "message": messages::ROLE_REQUIRED
250 }).to_string()
251 ));
252 }
253 }
254 }
255}
256
/// Strips an optional `Bearer` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively, because HTTP authentication
/// schemes are case-insensitive (RFC 7235 §2.1). Extra whitespace between
/// the scheme and the credentials is dropped. Values without the scheme are
/// returned unchanged, so raw tokens sent directly in the header keep working.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `get` yields `None` instead of panicking when the value is shorter than
    // the scheme or the cut would fall inside a multi-byte character. On
    // `Some`, `SCHEME.len()` is a valid char boundary, so the slice below is safe.
    match token.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim_start().to_string()
        }
        _ => token.to_string(),
    }
}

266fn parse_query_param(query: &str, name: &str) -> Option<String> {
268 for pair in query.split('&') {
269 if let Some((key, value)) = pair.split_once('=') {
270 if key == name {
271 return urlencoding::decode(value).ok().map(|s| s.to_string());
272 }
273 }
274 }
275 None
276}