sa_token_plugin_axum/
extractor.rs

1// Author: 金书记
2//
3//! Axum提取器
4
5use axum::{
6    extract::FromRequestParts,
7    http::{request::Parts, StatusCode},
8    response::{IntoResponse, Response},
9    Json,
10};
11use sa_token_core::{token::TokenValue, error::messages};
12use serde_json::json;
13
14pub struct SaTokenExtractor(pub TokenValue);
15
16impl<S> FromRequestParts<S> for SaTokenExtractor
17where
18    S: Send + Sync,
19{
20    type Rejection = Response;
21    
22    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
23        match parts.extensions.get::<TokenValue>() {
24            Some(token) => Ok(SaTokenExtractor(token.clone())),
25            None => Err((
26                StatusCode::UNAUTHORIZED,
27                Json(json!({
28                    "code": 401,
29                    "message": messages::AUTH_ERROR
30                }))
31            ).into_response()),
32        }
33    }
34}
35
36pub struct OptionalSaTokenExtractor(pub Option<TokenValue>);
37
38impl<S> FromRequestParts<S> for OptionalSaTokenExtractor
39where
40    S: Send + Sync,
41{
42    type Rejection = std::convert::Infallible;
43    
44    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
45        let token = parts.extensions.get::<TokenValue>().cloned();
46        Ok(OptionalSaTokenExtractor(token))
47    }
48}
49
50pub struct LoginIdExtractor(pub String);
51
52impl<S> FromRequestParts<S> for LoginIdExtractor
53where
54    S: Send + Sync,
55{
56    type Rejection = Response;
57    
58    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
59        match parts.extensions.get::<String>() {
60            Some(login_id) => Ok(LoginIdExtractor(login_id.clone())),
61            None => Err((
62                StatusCode::UNAUTHORIZED,
63                Json(json!({
64                    "code": 401,
65                    "message": messages::AUTH_ERROR
66                }))
67            ).into_response()),
68        }
69    }
70}