sa_token_storage_memory/
lib.rs

1// Author: 金书记
2//
3//! # sa-token-storage-memory
4//! 
5//! 内存存储实现
6//! 
7//! 适用于:
8//! - 开发测试环境
9//! - 单机部署
10//! - 不需要持久化的场景
11
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use async_trait::async_trait;
16use tokio::sync::RwLock;
17use chrono::{DateTime, Utc};
18use sa_token_adapter::storage::{SaStorage, StorageResult, StorageError};
19
20/// 内存存储项
21#[derive(Debug, Clone)]
22struct StorageItem {
23    value: String,
24    expire_at: Option<DateTime<Utc>>,
25}
26
27impl StorageItem {
28    fn new(value: String, ttl: Option<Duration>) -> Self {
29        let expire_at = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
30        Self { value, expire_at }
31    }
32    
33    fn is_expired(&self) -> bool {
34        if let Some(expire_at) = self.expire_at {
35            Utc::now() > expire_at
36        } else {
37            false
38        }
39    }
40}
41
42/// 内存存储实现
43#[derive(Debug, Clone)]
44pub struct MemoryStorage {
45    data: Arc<RwLock<HashMap<String, StorageItem>>>,
46}
47
48impl MemoryStorage {
49    /// 创建新的内存存储
50    pub fn new() -> Self {
51        Self {
52            data: Arc::new(RwLock::new(HashMap::new())),
53        }
54    }
55    
56    /// 清理过期的数据
57    pub async fn cleanup_expired(&self) {
58        let mut data = self.data.write().await;
59        data.retain(|_, item| !item.is_expired());
60    }
61}
62
63impl Default for MemoryStorage {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69#[async_trait]
70impl SaStorage for MemoryStorage {
71    async fn get(&self, key: &str) -> StorageResult<Option<String>> {
72        let data = self.data.read().await;
73        
74        if let Some(item) = data.get(key) {
75            if item.is_expired() {
76                // 数据已过期
77                drop(data);
78                self.delete(key).await?;
79                Ok(None)
80            } else {
81                Ok(Some(item.value.clone()))
82            }
83        } else {
84            Ok(None)
85        }
86    }
87    
88    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
89        let mut data = self.data.write().await;
90        let item = StorageItem::new(value.to_string(), ttl);
91        data.insert(key.to_string(), item);
92        Ok(())
93    }
94    
95    async fn delete(&self, key: &str) -> StorageResult<()> {
96        let mut data = self.data.write().await;
97        data.remove(key);
98        Ok(())
99    }
100    
101    async fn exists(&self, key: &str) -> StorageResult<bool> {
102        let data = self.data.read().await;
103        if let Some(item) = data.get(key) {
104            Ok(!item.is_expired())
105        } else {
106            Ok(false)
107        }
108    }
109    
110    async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
111        let mut data = self.data.write().await;
112        if let Some(item) = data.get_mut(key) {
113            item.expire_at = Some(Utc::now() + chrono::Duration::from_std(ttl).unwrap());
114        }
115        Ok(())
116    }
117    
118    async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
119        let data = self.data.read().await;
120        if let Some(item) = data.get(key) {
121            if let Some(expire_at) = item.expire_at {
122                let now = Utc::now();
123                if expire_at > now {
124                    let duration = (expire_at - now).to_std()
125                        .map_err(|e| StorageError::InternalError(e.to_string()))?;
126                    Ok(Some(duration))
127                } else {
128                    Ok(Some(Duration::from_secs(0)))
129                }
130            } else {
131                Ok(None) // 永不过期
132            }
133        } else {
134            Ok(None) // 键不存在
135        }
136    }
137    
138    async fn clear(&self) -> StorageResult<()> {
139        let mut data = self.data.write().await;
140        data.clear();
141        Ok(())
142    }
143    
144    async fn keys(&self, pattern: &str) -> StorageResult<Vec<String>> {
145        let data = self.data.read().await;
146        let mut result = Vec::new();
147        
148        // 将模式转换为正则表达式
149        let pattern = pattern.replace("*", ".*");
150        let regex = match regex::Regex::new(&pattern) {
151            Ok(r) => r,
152            Err(e) => return Err(StorageError::OperationFailed(format!("Invalid pattern: {}", e))),
153        };
154        
155        // 筛选匹配的键
156        for (key, item) in data.iter() {
157            if !item.is_expired() && regex.is_match(key) {
158                result.push(key.clone());
159            }
160        }
161        
162        Ok(result)
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Round trip through set / get / delete / exists.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        // Set then get.
        storage.set("key1", "value1", None).await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        // Delete.
        storage.delete("key1").await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);

        // Existence.
        storage.set("key2", "value2", None).await.unwrap();
        assert!(storage.exists("key2").await.unwrap());
        assert!(!storage.exists("key3").await.unwrap());
    }

    /// An entry with a TTL is readable before it expires and gone afterwards.
    #[tokio::test]
    async fn test_ttl() {
        let storage = MemoryStorage::new();

        // Short TTL keeps the test fast while leaving headroom for scheduling.
        storage
            .set("key1", "value1", Some(Duration::from_millis(100)))
            .await
            .unwrap();

        // Reading immediately succeeds.
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        // Wait past the expiry.
        tokio::time::sleep(Duration::from_millis(250)).await;

        // After expiry the key reads as absent.
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);
        assert!(!storage.exists("key1").await.unwrap());
    }

    /// `ttl` reports remaining time, and `expire` attaches an expiry to a live key.
    #[tokio::test]
    async fn test_ttl_and_expire_reporting() {
        let storage = MemoryStorage::new();

        assert_eq!(storage.ttl("missing").await.unwrap(), None);

        storage.set("forever", "v", None).await.unwrap();
        assert_eq!(storage.ttl("forever").await.unwrap(), None);

        storage
            .set("timed", "v", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        let remaining = storage.ttl("timed").await.unwrap().unwrap();
        assert!(remaining > Duration::from_secs(0));
        assert!(remaining <= Duration::from_secs(60));

        storage
            .expire("forever", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(storage.ttl("forever").await.unwrap().is_some());
    }

    /// `keys` filters by prefix pattern and `clear` empties the store.
    #[tokio::test]
    async fn test_keys_and_clear() {
        let storage = MemoryStorage::new();
        storage.set("user:1", "a", None).await.unwrap();
        storage.set("user:2", "b", None).await.unwrap();
        storage.set("other:1", "c", None).await.unwrap();

        let mut found = storage.keys("user:*").await.unwrap();
        found.sort();
        assert_eq!(found, vec!["user:1".to_string(), "user:2".to_string()]);

        storage.clear().await.unwrap();
        assert!(storage.keys("*").await.unwrap().is_empty());
    }

    /// `cleanup_expired` drops expired entries and keeps live ones.
    #[tokio::test]
    async fn test_cleanup_expired() {
        let storage = MemoryStorage::new();
        storage
            .set("short", "v", Some(Duration::from_millis(20)))
            .await
            .unwrap();
        storage.set("keep", "v", None).await.unwrap();

        tokio::time::sleep(Duration::from_millis(80)).await;
        storage.cleanup_expired().await;

        assert_eq!(storage.data.read().await.len(), 1);
        assert!(storage.exists("keep").await.unwrap());
    }
}