1use crate::Session;
6use axum::{extract::FromRequestParts, http::request::Parts};
7use std::sync::Arc;
8
9#[derive(Debug, Clone)]
24pub struct SessionExtractor {
25 session: Arc<Session>,
27}
28
29impl SessionExtractor {
30 pub fn inner(&self) -> &Session {
32 &self.session
33 }
34
35 pub fn id(&self) -> &str {
37 self.session.id()
38 }
39
40 pub fn is_new(&self) -> bool {
42 self.session.is_new()
43 }
44
45 pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
47 self.session.get(key).await
48 }
49
50 pub async fn get_typed<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
52 self.session.get_typed(key).await
53 }
54
55 pub async fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) {
57 self.session.set(key, value).await
58 }
59
60 pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
62 self.session.remove(key).await
63 }
64
65 pub async fn contains(&self, key: &str) -> bool {
67 self.session.contains(key).await
68 }
69
70 pub async fn clear(&self) {
72 self.session.clear().await
73 }
74
75 pub async fn keys(&self) -> Vec<String> {
77 self.session.keys().await
78 }
79
80 pub async fn len(&self) -> usize {
82 self.session.len().await
83 }
84
85 pub async fn is_empty(&self) -> bool {
87 self.session.is_empty().await
88 }
89}
90
91impl<S> FromRequestParts<S> for SessionExtractor
92where
93 S: Send + Sync,
94{
95 type Rejection = SessionRejection;
96
97 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
98 parts
99 .extensions
100 .get::<Arc<Session>>()
101 .cloned()
102 .map(|session| SessionExtractor { session })
103 .ok_or(SessionRejection::MissingSession)
104 }
105}
106
/// Reasons why [`SessionExtractor`] can fail to produce a session.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare rejections directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionRejection {
    /// No `Arc<Session>` was found in the request extensions.
    MissingSession,
}
113
114impl std::fmt::Display for SessionRejection {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 match self {
117 SessionRejection::MissingSession => write!(f, "Session not found in request extensions"),
118 }
119 }
120}
121
122impl std::error::Error for SessionRejection {}
123
124impl axum::response::IntoResponse for SessionRejection {
125 fn into_response(self) -> axum::response::Response {
126 let body = axum::body::Body::from(self.to_string());
127 axum::http::Response::builder().status(axum::http::StatusCode::INTERNAL_SERVER_ERROR).body(body).unwrap()
128 }
129}