sa_token_storage_memory/
lib.rs

1// Author: 金书记
2//
3//! # sa-token-storage-memory
4//! 
5//! 内存存储实现
6//! 
7//! 适用于:
8//! - 开发测试环境
9//! - 单机部署
10//! - 不需要持久化的场景
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use async_trait::async_trait;
16use tokio::sync::RwLock;
17use chrono::{DateTime, Utc};
18use sa_token_adapter::storage::{SaStorage, StorageResult, StorageError};
19
20/// 内存存储项
21#[derive(Debug, Clone)]
22struct StorageItem {
23    value: String,
24    expire_at: Option<DateTime<Utc>>,
25}
26
27impl StorageItem {
28    fn new(value: String, ttl: Option<Duration>) -> Self {
29        let expire_at = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
30        Self { value, expire_at }
31    }
32    
33    fn is_expired(&self) -> bool {
34        if let Some(expire_at) = self.expire_at {
35            Utc::now() > expire_at
36        } else {
37            false
38        }
39    }
40}
41
42/// 内存存储实现
43#[derive(Debug, Clone)]
44pub struct MemoryStorage {
45    data: Arc<RwLock<HashMap<String, StorageItem>>>,
46}
47
48impl MemoryStorage {
49    /// 创建新的内存存储
50    pub fn new() -> Self {
51        Self {
52            data: Arc::new(RwLock::new(HashMap::new())),
53        }
54    }
55    
56    /// 清理过期的数据
57    pub async fn cleanup_expired(&self) {
58        let mut data = self.data.write().await;
59        data.retain(|_, item| !item.is_expired());
60    }
61}
62
63impl Default for MemoryStorage {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait]
70impl SaStorage for MemoryStorage {
71    async fn get(&self, key: &str) -> StorageResult<Option<String>> {
72        let data = self.data.read().await;
73        
74        if let Some(item) = data.get(key) {
75            if item.is_expired() {
76                // 数据已过期
77                drop(data);
78                self.delete(key).await?;
79                Ok(None)
80            } else {
81                Ok(Some(item.value.clone()))
82            }
83        } else {
84            Ok(None)
85        }
86    }
87    
88    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
89        let mut data = self.data.write().await;
90        let item = StorageItem::new(value.to_string(), ttl);
91        data.insert(key.to_string(), item);
92        Ok(())
93    }
94    
95    async fn delete(&self, key: &str) -> StorageResult<()> {
96        let mut data = self.data.write().await;
97        data.remove(key);
98        Ok(())
99    }
100    
101    async fn exists(&self, key: &str) -> StorageResult<bool> {
102        let data = self.data.read().await;
103        if let Some(item) = data.get(key) {
104            Ok(!item.is_expired())
105        } else {
106            Ok(false)
107        }
108    }
109    
110    async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
111        let mut data = self.data.write().await;
112        if let Some(item) = data.get_mut(key) {
113            item.expire_at = Some(Utc::now() + chrono::Duration::from_std(ttl).unwrap());
114        }
115        Ok(())
116    }
117    
118    async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
119        let data = self.data.read().await;
120        if let Some(item) = data.get(key) {
121            if let Some(expire_at) = item.expire_at {
122                let now = Utc::now();
123                if expire_at > now {
124                    let duration = (expire_at - now).to_std()
125                        .map_err(|e| StorageError::InternalError(e.to_string()))?;
126                    Ok(Some(duration))
127                } else {
128                    Ok(Some(Duration::from_secs(0)))
129                }
130            } else {
131                Ok(None) // 永不过期
132            }
133        } else {
134            Ok(None) // 键不存在
135        }
136    }
137    
138    async fn clear(&self) -> StorageResult<()> {
139        let mut data = self.data.write().await;
140        data.clear();
141        Ok(())
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        // set / get
        storage.set("key1", "value1", None).await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        // delete
        storage.delete("key1").await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);

        // exists
        storage.set("key2", "value2", None).await.unwrap();
        assert!(storage.exists("key2").await.unwrap());
        assert!(!storage.exists("key3").await.unwrap());

        // clear removes everything
        storage.clear().await.unwrap();
        assert!(!storage.exists("key2").await.unwrap());
    }

    #[tokio::test]
    async fn test_ttl() {
        let storage = MemoryStorage::new();

        // Millisecond-scale TTLs keep the test fast while leaving a wide margin
        // for the immediate read below.
        let ttl = Duration::from_millis(300);
        storage.set("key1", "value1", Some(ttl)).await.unwrap();

        // An immediate read succeeds.
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        // The remaining TTL never exceeds what was requested.
        let remaining = storage
            .ttl("key1")
            .await
            .unwrap()
            .expect("key1 was stored with a TTL");
        assert!(remaining <= ttl);

        // Wait past the deadline.
        tokio::time::sleep(Duration::from_millis(600)).await;

        // After expiry the key reads as absent.
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);
        assert!(!storage.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_ttl_for_permanent_and_missing_keys() {
        let storage = MemoryStorage::new();
        storage.set("permanent", "v", None).await.unwrap();

        assert_eq!(storage.ttl("permanent").await.unwrap(), None);
        assert_eq!(storage.ttl("missing").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_expire_sets_deadline_on_live_key() {
        let storage = MemoryStorage::new();
        storage.set("k", "v", None).await.unwrap();

        storage.expire("k", Duration::from_secs(60)).await.unwrap();
        let remaining = storage.ttl("k").await.unwrap().expect("expire set a TTL");
        assert!(remaining <= Duration::from_secs(60));

        // Expiring a missing key is a no-op, not an error.
        storage.expire("missing", Duration::from_secs(1)).await.unwrap();
        assert!(!storage.exists("missing").await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let storage = MemoryStorage::new();
        storage
            .set("short", "v", Some(Duration::from_millis(50)))
            .await
            .unwrap();
        storage.set("long", "v", None).await.unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;
        storage.cleanup_expired().await;

        let data = storage.data.read().await;
        assert_eq!(data.len(), 1);
        assert!(data.contains_key("long"));
    }
}