Skip to main content

sa_token_plugin_rocket_v05/
extractor.rs

// Author: 金书记
//
//! Rocket request guards (extractors)
4
5use rocket::request::{FromRequest, Request, Outcome};
6use rocket::http::Status;
7use rocket::http::ContentType;
8use rocket::response::{self, Responder};
9use sa_token_core::{token::TokenValue, error::messages, SaTokenContext};
10use std::sync::Arc;
11use serde_json::json;
12
13/// 认证错误响应
14#[derive(Debug)]
15pub struct AuthError {
16    json: String,
17}
18
19impl<'r> Responder<'r, 'static> for AuthError {
20    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
21        let mut response = rocket::Response::new();
22        response.set_header(ContentType::JSON);
23        response.set_status(Status::Unauthorized);
24        response.set_sized_body(self.json.len(), std::io::Cursor::new(self.json));
25        Ok(response)
26    }
27}
28
29/// Token 守卫 - 必须存在,否则返回错误
30pub struct SaTokenGuard(pub TokenValue);
31
32impl SaTokenGuard {
33    pub fn token(&self) -> &TokenValue {
34        &self.0
35    }
36}
37
38#[rocket::async_trait]
39impl<'r> FromRequest<'r> for SaTokenGuard {
40    type Error = AuthError;
41    
42    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
43        let token = request.local_cache(|| None::<TokenValue>);
44        if let Some(token) = token {
45            return Outcome::Success(SaTokenGuard(token.clone()));
46        }
47        
48        let error = json!({
49            "code": 401,
50            "message": messages::AUTH_ERROR
51        }).to_string();
52        
53        Outcome::Error((Status::Unauthorized, AuthError { json: error }))
54    }
55}
56
57/// 可选 Token 守卫 - 不存在也不报错
58pub struct OptionalSaTokenGuard(pub Option<TokenValue>);
59
60#[rocket::async_trait]
61impl<'r> FromRequest<'r> for OptionalSaTokenGuard {
62    type Error = ();
63    
64    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
65        let token = request.local_cache(|| None::<TokenValue>).clone();
66        Outcome::Success(OptionalSaTokenGuard(token))
67    }
68}
69
70/// 请求级 [`SaTokenContext`](来自 Fairing 写入的 `local_cache`,跨 `await` 安全)。
71///
72/// 若未挂载 [`crate::SaTokenLayer`],工厂会返回空上下文。
73pub struct SaCtx(pub Arc<SaTokenContext>);
74
75#[rocket::async_trait]
76impl<'r> FromRequest<'r> for SaCtx {
77    type Error = ();
78
79    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
80        let ctx = req.local_cache(|| Arc::new(SaTokenContext::new()));
81        Outcome::Success(SaCtx(ctx.clone()))
82    }
83}
84
/// Login-ID guard that exposes the ID of the logged-in user directly.
pub struct LoginIdGuard(pub String);

impl LoginIdGuard {
    /// Returns the login ID as a string slice.
    pub fn login_id(&self) -> &str {
        self.0.as_str()
    }
}
93
94#[rocket::async_trait]
95impl<'r> FromRequest<'r> for LoginIdGuard {
96    type Error = AuthError;
97    
98    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
99        let login_id = request.local_cache(|| None::<String>);
100        if let Some(login_id) = login_id {
101            return Outcome::Success(LoginIdGuard(login_id.clone()));
102        }
103        
104        let error = json!({
105            "code": 401,
106            "message": messages::AUTH_ERROR
107        }).to_string();
108        
109        Outcome::Error((Status::Unauthorized, AuthError { json: error }))
110    }
111}