// sa_token_plugin_rocket/middleware.rs
use rocket::fairing::{Fairing, Info, Kind};
use rocket::http::{ContentType, Status};
use rocket::{Data, Request, Response};
use sa_token_core::token::TokenValue;

use crate::SaTokenState;

11pub struct SaTokenFairing {
13 state: SaTokenState,
14}
15
16impl SaTokenFairing {
17 pub fn new(state: SaTokenState) -> Self {
18 Self { state }
19 }
20}
21
22#[rocket::async_trait]
23impl Fairing for SaTokenFairing {
24 fn info(&self) -> Info {
25 Info {
26 name: "SaToken Authentication",
27 kind: Kind::Request,
28 }
29 }
30
31 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
32 let token_str = {
34 let token_name = &self.state.manager.config.token_name;
35
36 if let Some(header_val) = request.headers().get_one(token_name) {
38 Some(extract_bearer_token(header_val))
39 }
40 else if let Some(cookie) = request.cookies().get(token_name) {
42 Some(cookie.value().to_string())
43 }
44 else if let Some(query) = request.uri().query() {
46 parse_query_param(query.as_str(), token_name)
47 } else {
48 None
49 }
50 };
51
52 if let Some(token_str) = token_str {
53 let token = TokenValue::new(token_str);
54
55 if self.state.manager.is_valid(&token).await {
57 request.local_cache(|| Some(token.clone()));
59
60 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
62 request.local_cache(|| Some(token_info.login_id.clone()));
63 }
64 }
65 }
66 }
67}
68
69pub struct SaCheckLoginFairing {
71 state: SaTokenState,
72}
73
74impl SaCheckLoginFairing {
75 pub fn new(state: SaTokenState) -> Self {
76 Self { state }
77 }
78}
79
80#[rocket::async_trait]
81impl Fairing for SaCheckLoginFairing {
82 fn info(&self) -> Info {
83 Info {
84 name: "SaToken Check Login",
85 kind: Kind::Request | Kind::Response,
86 }
87 }
88
89 async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
90 let token_str = {
92 let token_name = &self.state.manager.config.token_name;
93
94 if let Some(header_val) = request.headers().get_one(token_name) {
96 Some(extract_bearer_token(header_val))
97 }
98 else if let Some(cookie) = request.cookies().get(token_name) {
100 Some(cookie.value().to_string())
101 }
102 else if let Some(query) = request.uri().query() {
104 parse_query_param(query.as_str(), token_name)
105 } else {
106 None
107 }
108 };
109
110 if let Some(token_str) = token_str {
111 let token = TokenValue::new(token_str);
112
113 if self.state.manager.is_valid(&token).await {
115 request.local_cache(|| Some(token.clone()));
117
118 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
120 request.local_cache(|| Some(token_info.login_id.clone()));
121 }
122 return;
123 }
124 }
125
126 request.local_cache(|| Some("unauthorized"));
128 }
129
130 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
131 if let Some(_) = request.local_cache(|| None::<&str>) {
133 if *request.local_cache(|| None::<&str>) == Some("unauthorized") {
134 response.set_status(Status::Unauthorized);
135 response.set_sized_body(None, std::io::Cursor::new(
136 serde_json::json!({
137 "code": 401,
138 "message": "未登录"
139 }).to_string()
140 ));
141 }
142 }
143 }
144}
145
/// Removes an optional `Bearer` authentication scheme from a header value.
///
/// The scheme is compared ASCII case-insensitively, because HTTP auth scheme
/// names are case-insensitive (RFC 7235 §2.1). Any whitespace between the scheme
/// and the credential is skipped. Values that do not start with the scheme are
/// returned unchanged, so headers carrying a bare token keep working.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `get` returns `None` (instead of panicking) when the input is shorter than
    // the scheme or byte 7 falls inside a multi-byte character.
    match token.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim_start().to_string()
        }
        _ => token.to_string(),
    }
}

155fn parse_query_param(query: &str, name: &str) -> Option<String> {
157 for pair in query.split('&') {
158 if let Some((key, value)) = pair.split_once('=') {
159 if key == name {
160 return urlencoding::decode(value).ok().map(|s| s.to_string());
161 }
162 }
163 }
164 None
165}