1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use http::{Request, Response};
8use shield::{Session, Shield, User};
9use tower_service::Service;
10use tracing::debug;
11
12use crate::session::TowerSessionStorage;
13
14#[derive(Clone)]
15pub struct ShieldService<S, U: User> {
16 inner: S,
17 shield: Shield<U>,
18 session_key: &'static str,
19}
20
21impl<S, U: User> ShieldService<S, U> {
22 pub fn new(inner: S, shield: Shield<U>, session_key: &'static str) -> Self {
23 Self {
24 inner,
25 shield,
26 session_key,
27 }
28 }
29
30 fn internal_server_error<ResBody: Default>() -> Response<ResBody> {
31 let mut response = Response::default();
32 *response.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
33 response
34 }
35}
36
37impl<S, U: User + Clone + 'static, ReqBody, ResBody> Service<Request<ReqBody>>
38 for ShieldService<S, U>
39where
40 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
41 S::Future: Send + 'static,
42 ReqBody: Send + 'static,
43 ResBody: Default + Send,
44{
45 type Response = S::Response;
46 type Error = S::Error;
47 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
48
49 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 self.inner.poll_ready(cx)
51 }
52
53 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
54 let clone = self.inner.clone();
58 let mut inner = std::mem::replace(&mut self.inner, clone);
59
60 let shield = self.shield.clone();
61 let session_key = self.session_key;
62
63 Box::pin(async move {
64 let session = match req.extensions().get::<tower_sessions::Session>() {
65 Some(session) => session,
66 None => {
67 return Ok(Self::internal_server_error());
68 }
69 };
70
71 let session_storage =
72 match TowerSessionStorage::load(session.clone(), session_key).await {
73 Ok(session_storage) => session_storage,
74 Err(_err) => return Ok(Self::internal_server_error()),
75 };
76 let shield_session = Session::new(session_storage);
77
78 let authenticated = match shield_session.data().lock() {
79 Ok(session) => session.authentication.clone(),
80 Err(_err) => return Ok(Self::internal_server_error()),
81 };
82
83 let user = if let Some(authenticated) = authenticated {
84 match shield.storage().user_by_id(&authenticated.user_id).await {
87 Ok(user) => {
88 if user.is_none() {
89 if let Err(_err) = shield_session.purge().await {
90 return Ok(Self::internal_server_error());
91 }
92 }
93
94 user
95 }
96 Err(_err) => return Ok(Self::internal_server_error()),
97 }
98 } else {
99 None
100 };
101
102 debug!("{:?}", user.as_ref().map(|user| user.id()));
103
104 req.extensions_mut().insert(shield);
105 req.extensions_mut().insert(shield_session);
106 req.extensions_mut().insert(user);
107
108 inner.call(req).await
109 })
110 }
111}