spa_rs/session.rs
//! Tower middleware for reading and writing session data identified by a session-id cookie.
2//!
3use crate::filter::Predicate;
4use axum::{extract::Request, http::StatusCode, response::Response};
5use headers::{Cookie, HeaderMapExt};
6use parking_lot::RwLock;
7use std::{cmp::PartialEq, collections::HashMap, sync::Arc};
8
9/// Session object, can access by Extension in RequireSession layer.
10///
11/// See [RequireSession] example for usage
12#[derive(Clone)]
13pub struct Session<T> {
14 /// current session data
15 pub current: T,
16 /// session storage
17 pub all: Arc<SessionStore<T>>,
18}
19
20/// Session storage, can access by Extersion in AddSession layer.
21///
22/// See [AddSession] example for usage
23#[derive(Debug)]
24pub struct SessionStore<T> {
25 key: String,
26 inner: RwLock<HashMap<String, T>>,
27}
28
29impl<T: PartialEq> SessionStore<T> {
30 /// return new SessionStore with specific key
31 pub fn new(key: impl Into<String>) -> Self {
32 SessionStore {
33 key: key.into(),
34 inner: RwLock::new(HashMap::new()),
35 }
36 }
37
38 /// get the key reference
39 pub fn key(&self) -> &str {
40 &self.key
41 }
42
43 /// insert a new session item
44 pub fn insert(&self, k: impl Into<String>, v: T) {
45 self.inner.write().insert(k.into(), v);
46 }
47
48 /// remove the session item
49 pub fn remove(&self, v: T) {
50 self.inner.write().retain(|_, x| *x != v);
51 }
52}
53
54/// Middleware that can access and modify all sessions data. Usually used for **Login** handler
55///
56/// # Example
57///```
58/// # use spa_rs::routing::{post, Router};
59/// # use spa_rs::Extension;
60/// # use spa_rs::session::AddSession;
61/// # use spa_rs::session::SessionStore;
62/// # use axum_help::filter::FilterExLayer;
63/// # use std::sync::Arc;
64/// #
65/// #[derive(PartialEq, Clone)]
66/// struct User;
67///
68/// async fn login(Extension(session): Extension<Arc<SessionStore<User>>>) {
69/// let new_user = User;
70/// session.insert("session_id", new_user);
71/// }
72///
73/// #[tokio::main]
74/// async fn main() {
75/// let session = Arc::new(SessionStore::<User>::new("my_session"));
76/// let app = Router::new()
77/// .route("/login", post(login))
78/// .layer(FilterExLayer::new(AddSession::new(session.clone())));
79/// # axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service());
80/// }
81///```
82#[derive(Clone, Debug)]
83pub struct AddSession<T>(Arc<SessionStore<T>>);
84
85impl<T> AddSession<T> {
86 pub fn new(store: Arc<SessionStore<T>>) -> Self {
87 Self(store)
88 }
89}
90
91impl<T> Predicate<Request> for AddSession<T>
92where
93 T: Send + Sync + 'static,
94{
95 type Request = Request;
96 type Response = Response;
97
98 fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
99 request.extensions_mut().insert(self.0.clone());
100 Ok(request)
101 }
102}
103
104/// Middleware that can access and modify all sessions data.
105///
106/// # Example
107///```
108/// # use spa_rs::routing::{post, Router};
109/// # use spa_rs::Extension;
110/// # use spa_rs::session::RequireSession;
111/// # use spa_rs::session::SessionStore;
112/// # use spa_rs::session::Session;
113/// # use axum_help::filter::FilterExLayer;
114/// # use std::sync::Arc;
115/// #
116/// #[derive(PartialEq, Clone, Debug)]
117/// struct User;
118///
119/// async fn action(Extension(session): Extension<Arc<Session<User>>>) {
120/// println!("current user: {:?}", session.current);
121/// }
122///
123/// #[tokio::main]
124/// async fn main() {
125/// let session = Arc::new(SessionStore::<User>::new("my_session"));
126/// let app = Router::new()
127/// .route("/action", post(action))
128/// .layer(FilterExLayer::new(RequireSession::new(session.clone())));
129/// # axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service());
130/// }
131///```
132#[derive(Clone, Debug)]
133pub struct RequireSession<T>(Arc<SessionStore<T>>);
134
135impl<T> RequireSession<T> {
136 pub fn new(store: Arc<SessionStore<T>>) -> Self {
137 Self(store)
138 }
139}
140
141impl<T> Predicate<Request> for RequireSession<T>
142where
143 T: Clone + Send + Sync + 'static,
144{
145 type Request = Request;
146 type Response = Response;
147
148 fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
149 if let Some(cookie) = request.headers().typed_get::<Cookie>() {
150 let sessions = self.0.inner.read();
151 for (k, v) in cookie.iter() {
152 if k == self.0.key {
153 if let Some(u) = sessions.get(v) {
154 request.extensions_mut().insert(Session {
155 current: u.clone(),
156 all: self.0.clone(),
157 });
158 return Ok(request);
159 }
160 }
161 }
162 }
163
164 Err({
165 let mut response = Response::default();
166 *response.status_mut() = StatusCode::UNAUTHORIZED;
167 response
168 })
169 }
170}