Skip to main content

wae_session/
store.rs

1//! Session 存储模块
2//!
3//! 提供 Session 存储的 trait 定义和内存实现。
4
5use async_trait::async_trait;
6use dashmap::DashMap;
7use serde::{Serialize, de::DeserializeOwned};
8use std::{sync::Arc, time::Duration};
9use tokio::time::Instant;
10
11/// Session 存储接口
12///
13/// 定义了 Session 存储的基本操作,支持异步操作。
14/// 实现此 trait 可以创建自定义的 Session 存储后端。
15#[async_trait]
16pub trait SessionStore: Clone + Send + Sync + 'static {
17    /// 获取 Session 数据
18    async fn get(&self, session_id: &str) -> Option<String>;
19
20    /// 设置 Session 数据
21    async fn set(&self, session_id: &str, data: &str, ttl: Duration);
22
23    /// 删除 Session
24    async fn remove(&self, session_id: &str);
25
26    /// 检查 Session 是否存在
27    async fn exists(&self, session_id: &str) -> bool;
28
29    /// 刷新 Session 过期时间
30    async fn refresh(&self, session_id: &str, ttl: Duration);
31
32    /// 获取类型化的 Session 数据
33    async fn get_typed<T: DeserializeOwned + Send + Sync>(&self, session_id: &str) -> Option<T> {
34        let data = self.get(session_id).await?;
35        serde_json::from_str(&data).ok()
36    }
37
38    /// 设置类型化的 Session 数据
39    async fn set_typed<T: Serialize + Send + Sync>(&self, session_id: &str, data: &T, ttl: Duration) -> bool {
40        match serde_json::to_string(data) {
41            Ok(json) => {
42                self.set(session_id, &json, ttl).await;
43                true
44            }
45            Err(_) => false,
46        }
47    }
48}
49
/// A single stored session: its serialized payload plus its expiry deadline.
#[derive(Debug, Clone)]
struct SessionEntry {
    /// Session payload, stored verbatim as passed to `set`.
    data: String,
    /// Deadline after which the entry is treated as expired (reads compare it
    /// against `Instant::now()`).
    expires_at: Instant,
}
58
/// In-memory session store.
///
/// Sessions live in a `DashMap` behind an `Arc`, so every clone of the store
/// shares the same sessions and the map can be accessed concurrently.
/// Expired entries can be purged periodically in the background (see
/// `with_config`) or on demand via `cleanup_expired`.
#[derive(Debug, Clone)]
pub struct MemorySessionStore {
    /// Session id -> entry map, shared between clones of the store.
    storage: Arc<DashMap<String, SessionEntry>>,
    /// Delay between background cleanup passes.
    cleanup_interval: Duration,
    /// Whether background cleanup was requested at construction time.
    // The constructor acts on its own parameter; this stored copy is never
    // read afterwards, hence the `dead_code` allowance.
    #[allow(dead_code)]
    auto_cleanup: bool,
}
73
74impl MemorySessionStore {
75    /// 创建新的内存 Session 存储
76    pub fn new() -> Self {
77        Self::with_config(Duration::from_secs(60), true)
78    }
79
80    /// 使用配置创建内存 Session 存储
81    pub fn with_config(cleanup_interval: Duration, auto_cleanup: bool) -> Self {
82        let store = Self { storage: Arc::new(DashMap::new()), cleanup_interval, auto_cleanup };
83
84        if auto_cleanup {
85            store.start_cleanup_task();
86        }
87
88        store
89    }
90
91    /// 启动清理任务
92    fn start_cleanup_task(&self) {
93        let storage = Arc::clone(&self.storage);
94        let interval = self.cleanup_interval;
95
96        tokio::spawn(async move {
97            loop {
98                tokio::time::sleep(interval).await;
99                let now = Instant::now();
100
101                storage.retain(|_, entry| entry.expires_at > now);
102            }
103        });
104    }
105
106    /// 手动清理过期 Session
107    pub fn cleanup_expired(&self) {
108        let now = Instant::now();
109        self.storage.retain(|_, entry| entry.expires_at > now);
110    }
111
112    /// 获取 Session 数量
113    pub fn len(&self) -> usize {
114        self.storage.len()
115    }
116
117    /// 检查是否为空
118    pub fn is_empty(&self) -> bool {
119        self.storage.is_empty()
120    }
121
122    /// 清空所有 Session
123    pub fn clear(&self) {
124        self.storage.clear();
125    }
126}
127
128impl Default for MemorySessionStore {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134#[async_trait]
135impl SessionStore for MemorySessionStore {
136    async fn get(&self, session_id: &str) -> Option<String> {
137        self.storage.get(session_id).filter(|entry| entry.expires_at > Instant::now()).map(|entry| entry.data.clone())
138    }
139
140    async fn set(&self, session_id: &str, data: &str, ttl: Duration) {
141        let entry = SessionEntry { data: data.to_string(), expires_at: Instant::now() + ttl };
142        self.storage.insert(session_id.to_string(), entry);
143    }
144
145    async fn remove(&self, session_id: &str) {
146        self.storage.remove(session_id);
147    }
148
149    async fn exists(&self, session_id: &str) -> bool {
150        self.storage.get(session_id).map(|entry| entry.expires_at > Instant::now()).unwrap_or(false)
151    }
152
153    async fn refresh(&self, session_id: &str, ttl: Duration) {
154        if let Some(mut entry) = self.storage.get_mut(session_id) {
155            entry.expires_at = Instant::now() + ttl;
156        }
157    }
158}