sa_token_plugin_rocket_v05/
middleware.rs1use rocket::{Data, Request, Response};
7use rocket::fairing::{Fairing, Info, Kind};
8use rocket::http::{ContentType, Status};
9use sa_token_core::error::messages;
10use sa_token_plugin_rocket_core::run_auth_flow;
11use serde_json::json;
12
13use crate::adapter::RocketCapturedRequest;
14use crate::SaTokenState;
15
16pub struct SaTokenFairing {
18 state: SaTokenState,
19}
20
21impl SaTokenFairing {
22 pub fn new(state: SaTokenState) -> Self {
23 Self { state }
24 }
25}
26
27#[rocket::async_trait]
28impl Fairing for SaTokenFairing {
29 fn info(&self) -> Info {
30 Info {
31 name: "SaToken Authentication",
32 kind: Kind::Request,
33 }
34 }
35
36 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
37 let adapter = RocketCapturedRequest::capture(
38 request,
39 self.state.manager.config.token_name.as_str(),
40 );
41 let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
42
43 if let Some(ref t) = flow.token {
44 request.local_cache(|| Some(t.clone()));
45 }
46 if let Some(ref id) = flow.login_id {
47 request.local_cache(|| Some(id.clone()));
48 }
49 }
50}
51
52pub struct SaCheckLoginFairing {
54 state: SaTokenState,
55}
56
57impl SaCheckLoginFairing {
58 pub fn new(state: SaTokenState) -> Self {
59 Self { state }
60 }
61}
62
63#[rocket::async_trait]
64impl Fairing for SaCheckLoginFairing {
65 fn info(&self) -> Info {
66 Info {
67 name: "SaToken Check Login",
68 kind: Kind::Request | Kind::Response,
69 }
70 }
71
72 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
73 let adapter = RocketCapturedRequest::capture(
74 request,
75 self.state.manager.config.token_name.as_str(),
76 );
77 let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
78
79 if flow.login_id.is_some() {
80 if let Some(ref t) = flow.token {
81 request.local_cache(|| Some(t.clone()));
82 }
83 if let Some(ref id) = flow.login_id {
84 request.local_cache(|| Some(id.clone()));
85 }
86 return;
87 }
88
89 request.local_cache(|| Some("unauthorized"));
90 }
91
92 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
93 if request.local_cache(|| None::<&str>).is_some()
95 && *request.local_cache(|| None::<&str>) == Some("unauthorized") {
96 response.set_status(Status::Unauthorized);
97 response.set_sized_body(
98 None,
99 std::io::Cursor::new(
100 json!({
101 "code": 401,
102 "message": messages::AUTH_ERROR
103 })
104 .to_string(),
105 ),
106 );
107 }
108 }
109}
110
111pub struct SaCheckPermissionFairing {
113 #[allow(dead_code)]
114 state: SaTokenState,
115 permission: String,
116}
117
118impl SaCheckPermissionFairing {
119 pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
120 Self {
121 state,
122 permission: permission.into(),
123 }
124 }
125}
126
127#[rocket::async_trait]
128impl Fairing for SaCheckPermissionFairing {
129 fn info(&self) -> Info {
130 Info {
131 name: "SaToken Check Permission",
132 kind: Kind::Request | Kind::Response,
133 }
134 }
135
136 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
137 if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
139 if sa_token_core::StpUtil::has_permission(&login_id, &self.permission).await {
141 return;
142 }
143 }
144
145 request.local_cache(|| Some("forbidden"));
147 }
148
149 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
150 if request.local_cache(|| None::<&str>).is_some()
152 && *request.local_cache(|| None::<&str>) == Some("forbidden") {
153 response.set_status(Status::Forbidden);
154 response.set_header(ContentType::JSON);
155 response.set_sized_body(
156 None,
157 std::io::Cursor::new(
158 json!({
159 "code": 403,
160 "message": messages::PERMISSION_REQUIRED
161 })
162 .to_string(),
163 ),
164 );
165 }
166 }
167}
168
169pub struct SaCheckRoleFairing {
171 #[allow(dead_code)]
172 state: SaTokenState,
173 role: String,
174}
175
176impl SaCheckRoleFairing {
177 pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
178 Self {
179 state,
180 role: role.into(),
181 }
182 }
183}
184
185#[rocket::async_trait]
186impl Fairing for SaCheckRoleFairing {
187 fn info(&self) -> Info {
188 Info {
189 name: "SaToken Check Role",
190 kind: Kind::Request | Kind::Response,
191 }
192 }
193
194 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
195 if let Some(login_id) = request.local_cache(|| None::<String>).clone() {
197 if sa_token_core::StpUtil::has_role(&login_id, &self.role).await {
199 return;
200 }
201 }
202
203 request.local_cache(|| Some("forbidden_role"));
205 }
206
207 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
208 if request.local_cache(|| None::<&str>).is_some()
210 && *request.local_cache(|| None::<&str>) == Some("forbidden_role") {
211 response.set_status(Status::Forbidden);
212 response.set_header(ContentType::JSON);
213 response.set_sized_body(
214 None,
215 std::io::Cursor::new(
216 json!({
217 "code": 403,
218 "message": messages::ROLE_REQUIRED
219 })
220 .to_string(),
221 ),
222 );
223 }
224 }
225}