sa_token_plugin_axum/
middleware.rs1use std::task::{Context, Poll};
10use tower::{Layer, Service};
11use http::{Request, Response, StatusCode};
12use http_body;
13use serde_json::json;
14use sa_token_core::error::messages;
15
16pub use crate::layer::SaTokenMiddleware;
17
/// Tower layer that wraps a service in [`SaCheckLoginMiddleware`], rejecting
/// requests that carry no login id.
///
/// Derives `Default` so that the argument-less `new()` constructor has the
/// conventional `Default` counterpart.
#[derive(Clone, Debug, Default)]
pub struct SaCheckLoginLayer;
21
22impl SaCheckLoginLayer {
23 pub fn new() -> Self {
24 Self
25 }
26}
27
28impl<S> Layer<S> for SaCheckLoginLayer {
29 type Service = SaCheckLoginMiddleware<S>;
30
31 fn layer(&self, inner: S) -> Self::Service {
32 SaCheckLoginMiddleware { inner }
33 }
34}
35
/// Service produced by [`SaCheckLoginLayer`].
///
/// Forwards requests to `inner` only when a login id is present in the request
/// extensions.
#[derive(Clone, Debug)]
pub struct SaCheckLoginMiddleware<S> {
    /// The wrapped service.
    inner: S,
}
43
/// Tower layer that wraps a service in [`SaCheckPermissionMiddleware`].
///
/// The wrapped service only receives requests whose login id holds
/// `permission`.
#[derive(Clone, Debug)]
pub struct SaCheckPermissionLayer {
    /// Permission identifier checked for every request.
    permission: String,
}
49
50impl SaCheckPermissionLayer {
51 pub fn new(permission: impl Into<String>) -> Self {
52 Self {
53 permission: permission.into(),
54 }
55 }
56}
57
58impl<S> Layer<S> for SaCheckPermissionLayer {
59 type Service = SaCheckPermissionMiddleware<S>;
60
61 fn layer(&self, inner: S) -> Self::Service {
62 SaCheckPermissionMiddleware {
63 inner,
64 permission: self.permission.clone(),
65 }
66 }
67}
68
/// Service produced by [`SaCheckPermissionLayer`].
///
/// Forwards requests to `inner` only when the request's login id holds
/// `permission`.
#[derive(Clone, Debug)]
pub struct SaCheckPermissionMiddleware<S> {
    /// The wrapped service.
    inner: S,
    /// Permission identifier required to reach `inner`.
    permission: String,
}
77
78impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaCheckLoginMiddleware<S>
79where
80 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
81 S::Future: Send + 'static,
82 ReqBody: Send + 'static,
83 ResBody: http_body::Body + Default + Send + 'static,
84{
85 type Response = S::Response;
86 type Error = S::Error;
87 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
88
89 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 self.inner.poll_ready(cx)
91 }
92
93 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
94 let mut inner = self.inner.clone();
95
96 Box::pin(async move {
97 if request.extensions().get::<String>().is_none() {
99 let mut response = Response::builder()
103 .status(StatusCode::UNAUTHORIZED)
104 .body(ResBody::default())
105 .expect("Unable to create response");
106
107 let error_json = serde_json::to_string(&json!({
109 "code": 401,
110 "message": messages::AUTH_ERROR
111 })).unwrap_or_default();
112
113 if let Ok(header_value) = http::header::HeaderValue::from_str(&error_json) {
115 response.headers_mut().insert("X-Sa-Token-Error", header_value);
116 }
117
118 return Ok(response);
119 }
120
121 inner.call(request).await
123 })
124 }
125}
126
127impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaCheckPermissionMiddleware<S>
128where
129 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
130 S::Future: Send + 'static,
131 ReqBody: Send + 'static,
132 ResBody: http_body::Body + Default + Send + 'static,
133{
134 type Response = S::Response;
135 type Error = S::Error;
136 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
137
138 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139 self.inner.poll_ready(cx)
140 }
141
142 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
143 let mut inner = self.inner.clone();
144 let permission = self.permission.clone();
145
146 Box::pin(async move {
147 if let Some(login_id) = request.extensions().get::<String>() {
149 if sa_token_core::StpUtil::has_permission(login_id, &permission).await {
151 return inner.call(request).await;
153 }
154 }
155
156 let mut response = Response::builder()
158 .status(StatusCode::FORBIDDEN)
159 .body(ResBody::default())
160 .expect("Unable to create response");
161
162 let error_json = serde_json::to_string(&json!({
164 "code": 403,
165 "message": messages::PERMISSION_REQUIRED
166 })).unwrap_or_default();
167
168 if let Ok(header_value) = http::header::HeaderValue::from_str(&error_json) {
170 response.headers_mut().insert("X-Sa-Token-Error", header_value);
171 }
172
173 Ok(response)
174 })
175 }
176}