Skip to main content

wae_session/
session.rs

1//! Session 数据结构
2//!
3//! 提供 Session 的核心数据结构和操作方法。
4
5use serde::{Serialize, de::DeserializeOwned};
6use std::{collections::HashMap, sync::Arc};
7use tokio::sync::RwLock;
8
9/// Session ID 类型别名
10pub type SessionId = String;
11
12/// Session 数据结构
13///
14/// 表示一个 HTTP Session,包含 Session ID 和存储的数据。
15/// 支持类型化的数据存取。
16#[derive(Debug, Clone)]
17pub struct Session {
18    /// Session ID
19    id: SessionId,
20    /// Session 数据
21    data: Arc<RwLock<HashMap<String, serde_json::Value>>>,
22    /// 是否已修改
23    dirty: Arc<RwLock<bool>>,
24    /// 是否是新创建的
25    is_new: bool,
26}
27
28impl Session {
29    /// 创建新的 Session
30    pub fn new(id: impl Into<String>) -> Self {
31        Self { id: id.into(), data: Arc::new(RwLock::new(HashMap::new())), dirty: Arc::new(RwLock::new(false)), is_new: true }
32    }
33
34    /// 从现有数据创建 Session
35    pub async fn from_data(id: impl Into<String>, data: HashMap<String, serde_json::Value>) -> Self {
36        Self { id: id.into(), data: Arc::new(RwLock::new(data)), dirty: Arc::new(RwLock::new(false)), is_new: false }
37    }
38
39    /// 获取 Session ID
40    pub fn id(&self) -> &str {
41        &self.id
42    }
43
44    /// 检查是否是新创建的 Session
45    pub fn is_new(&self) -> bool {
46        self.is_new
47    }
48
49    /// 检查 Session 是否已修改
50    pub async fn is_dirty(&self) -> bool {
51        *self.dirty.read().await
52    }
53
54    /// 标记 Session 为已修改
55    pub async fn mark_dirty(&self) {
56        *self.dirty.write().await = true;
57    }
58
59    /// 获取值
60    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
61        self.data.read().await.get(key).cloned()
62    }
63
64    /// 获取类型化的值
65    pub async fn get_typed<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
66        self.data.read().await.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
67    }
68
69    /// 设置值
70    pub async fn set<T: Serialize>(&self, key: impl Into<String>, value: T) {
71        let key = key.into();
72        let json_value = serde_json::to_value(value).unwrap_or(serde_json::Value::Null);
73        self.data.write().await.insert(key, json_value);
74        *self.dirty.write().await = true;
75    }
76
77    /// 删除值
78    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
79        let result = self.data.write().await.remove(key);
80        if result.is_some() {
81            *self.dirty.write().await = true;
82        }
83        result
84    }
85
86    /// 检查键是否存在
87    pub async fn contains(&self, key: &str) -> bool {
88        self.data.read().await.contains_key(key)
89    }
90
91    /// 清空所有数据
92    pub async fn clear(&self) {
93        self.data.write().await.clear();
94        *self.dirty.write().await = true;
95    }
96
97    /// 获取所有键
98    pub async fn keys(&self) -> Vec<String> {
99        self.data.read().await.keys().cloned().collect()
100    }
101
102    /// 获取数据条目数量
103    pub async fn len(&self) -> usize {
104        self.data.read().await.len()
105    }
106
107    /// 检查是否为空
108    pub async fn is_empty(&self) -> bool {
109        self.data.read().await.is_empty()
110    }
111
112    /// 序列化为 JSON 字符串
113    pub async fn to_json(&self) -> String {
114        let data = self.data.read().await;
115        serde_json::to_string(&*data).unwrap_or_default()
116    }
117
118    /// 从 JSON 字符串反序列化
119    pub async fn from_json(&self, json: &str) -> bool {
120        if let Ok(data) = serde_json::from_str::<HashMap<String, serde_json::Value>>(json) {
121            *self.data.write().await = data;
122            true
123        }
124        else {
125            false
126        }
127    }
128}