Skip to main content

wae_session/
extract.rs

//! Session extractor.
//!
//! Provides the axum `Session` extractor implementation.
4
5use crate::Session;
6use axum::{extract::FromRequestParts, http::request::Parts};
7use std::sync::Arc;
8
9/// Session 提取器
10///
11/// 用于在 axum 处理函数中提取 Session。
12///
13/// # 示例
14///
15/// ```rust,ignore
16/// use wae_session::SessionExtractor;
17///
18/// async fn handler(session: SessionExtractor) -> impl IntoResponse {
19///     let user_id: Option<String> = session.get_typed("user_id").await;
20///     // ...
21/// }
22/// ```
23#[derive(Debug, Clone)]
24pub struct SessionExtractor {
25    /// Session 引用
26    session: Arc<Session>,
27}
28
29impl SessionExtractor {
30    /// 获取 Session 引用
31    pub fn inner(&self) -> &Session {
32        &self.session
33    }
34
35    /// 获取 Session ID
36    pub fn id(&self) -> &str {
37        self.session.id()
38    }
39
40    /// 检查是否是新创建的 Session
41    pub fn is_new(&self) -> bool {
42        self.session.is_new()
43    }
44
45    /// 获取值
46    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
47        self.session.get(key).await
48    }
49
50    /// 获取类型化的值
51    pub async fn get_typed<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
52        self.session.get_typed(key).await
53    }
54
55    /// 设置值
56    pub async fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) {
57        self.session.set(key, value).await
58    }
59
60    /// 删除值
61    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
62        self.session.remove(key).await
63    }
64
65    /// 检查键是否存在
66    pub async fn contains(&self, key: &str) -> bool {
67        self.session.contains(key).await
68    }
69
70    /// 清空所有数据
71    pub async fn clear(&self) {
72        self.session.clear().await
73    }
74
75    /// 获取所有键
76    pub async fn keys(&self) -> Vec<String> {
77        self.session.keys().await
78    }
79
80    /// 获取数据条目数量
81    pub async fn len(&self) -> usize {
82        self.session.len().await
83    }
84
85    /// 检查是否为空
86    pub async fn is_empty(&self) -> bool {
87        self.session.is_empty().await
88    }
89}
90
91impl<S> FromRequestParts<S> for SessionExtractor
92where
93    S: Send + Sync,
94{
95    type Rejection = SessionRejection;
96
97    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
98        parts
99            .extensions
100            .get::<Arc<Session>>()
101            .cloned()
102            .map(|session| SessionExtractor { session })
103            .ok_or(SessionRejection::MissingSession)
104    }
105}
106
/// Rejection produced when a `SessionExtractor` cannot obtain a session.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare rejections directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionRejection {
    /// No `Arc<Session>` was found in the request extensions.
    MissingSession,
}

impl std::fmt::Display for SessionRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingSession => f.write_str("Session not found in request extensions"),
        }
    }
}

/// Marker impl so the rejection can be used as a `dyn std::error::Error`;
/// it has no underlying source error.
impl std::error::Error for SessionRejection {}
123
124impl axum::response::IntoResponse for SessionRejection {
125    fn into_response(self) -> axum::response::Response {
126        let body = axum::body::Body::from(self.to_string());
127        axum::http::Response::builder().status(axum::http::StatusCode::INTERNAL_SERVER_ERROR).body(body).unwrap()
128    }
129}