sa_token_plugin_rocket/
extractor.rs

// Author: 金书记
//
//! Rocket request guards (extractors / 提取器).
4
5use rocket::request::{FromRequest, Request, Outcome};
6use rocket::http::Status;
7use rocket::http::ContentType;
8use rocket::response::{self, Responder};
9use sa_token_core::{token::TokenValue, error::messages};
10use serde_json::json;
11
12/// 认证错误响应
13#[derive(Debug)]
14pub struct AuthError {
15    json: String,
16}
17
18impl<'r> Responder<'r, 'static> for AuthError {
19    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
20        let mut response = rocket::Response::new();
21        response.set_header(ContentType::JSON);
22        response.set_status(Status::Unauthorized);
23        response.set_sized_body(self.json.len(), std::io::Cursor::new(self.json));
24        Ok(response)
25    }
26}
27
28/// Token 守卫 - 必须存在,否则返回错误
29pub struct SaTokenGuard(pub TokenValue);
30
31impl SaTokenGuard {
32    pub fn token(&self) -> &TokenValue {
33        &self.0
34    }
35}
36
37#[rocket::async_trait]
38impl<'r> FromRequest<'r> for SaTokenGuard {
39    type Error = AuthError;
40    
41    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
42        let token = request.local_cache(|| None::<TokenValue>);
43        if let Some(token) = token {
44            return Outcome::Success(SaTokenGuard(token.clone()));
45        }
46        
47        let error = json!({
48            "code": 401,
49            "message": messages::AUTH_ERROR
50        }).to_string();
51        
52        Outcome::Error((Status::Unauthorized, AuthError { json: error }))
53    }
54}
55
56/// 可选 Token 守卫 - 不存在也不报错
57pub struct OptionalSaTokenGuard(pub Option<TokenValue>);
58
59#[rocket::async_trait]
60impl<'r> FromRequest<'r> for OptionalSaTokenGuard {
61    type Error = ();
62    
63    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
64        let token = request.local_cache(|| None::<TokenValue>).clone();
65        Outcome::Success(OptionalSaTokenGuard(token))
66    }
67}
68
/// LoginId guard: gives direct access to the ID of the logged-in user.
///
/// It derives the common traits (`Debug`, `Clone`, `PartialEq`, `Eq`), so the
/// guard can be debug-printed, cloned and compared like other public types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoginIdGuard(pub String);

impl LoginIdGuard {
    /// Returns the login ID as a string slice.
    pub fn login_id(&self) -> &str {
        &self.0
    }
}
77
78#[rocket::async_trait]
79impl<'r> FromRequest<'r> for LoginIdGuard {
80    type Error = AuthError;
81    
82    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
83        let login_id = request.local_cache(|| None::<String>);
84        if let Some(login_id) = login_id {
85            return Outcome::Success(LoginIdGuard(login_id.clone()));
86        }
87        
88        let error = json!({
89            "code": 401,
90            "message": messages::AUTH_ERROR
91        }).to_string();
92        
93        Outcome::Error((Status::Unauthorized, AuthError { json: error }))
94    }
95}