sa_token_storage_memory/
lib.rs1use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use async_trait::async_trait;
16use tokio::sync::RwLock;
17use chrono::{DateTime, Utc};
18use sa_token_adapter::storage::{SaStorage, StorageResult, StorageError};
19
20#[derive(Debug, Clone)]
22struct StorageItem {
23 value: String,
24 expire_at: Option<DateTime<Utc>>,
25}
26
27impl StorageItem {
28 fn new(value: String, ttl: Option<Duration>) -> Self {
29 let expire_at = ttl.map(|d| Utc::now() + chrono::Duration::from_std(d).unwrap());
30 Self { value, expire_at }
31 }
32
33 fn is_expired(&self) -> bool {
34 if let Some(expire_at) = self.expire_at {
35 Utc::now() > expire_at
36 } else {
37 false
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct MemoryStorage {
45 data: Arc<RwLock<HashMap<String, StorageItem>>>,
46}
47
48impl MemoryStorage {
49 pub fn new() -> Self {
51 Self {
52 data: Arc::new(RwLock::new(HashMap::new())),
53 }
54 }
55
56 pub async fn cleanup_expired(&self) {
58 let mut data = self.data.write().await;
59 data.retain(|_, item| !item.is_expired());
60 }
61}
62
63impl Default for MemoryStorage {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69#[async_trait]
70impl SaStorage for MemoryStorage {
71 async fn get(&self, key: &str) -> StorageResult<Option<String>> {
72 let data = self.data.read().await;
73
74 if let Some(item) = data.get(key) {
75 if item.is_expired() {
76 drop(data);
78 self.delete(key).await?;
79 Ok(None)
80 } else {
81 Ok(Some(item.value.clone()))
82 }
83 } else {
84 Ok(None)
85 }
86 }
87
88 async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> StorageResult<()> {
89 let mut data = self.data.write().await;
90 let item = StorageItem::new(value.to_string(), ttl);
91 data.insert(key.to_string(), item);
92 Ok(())
93 }
94
95 async fn delete(&self, key: &str) -> StorageResult<()> {
96 let mut data = self.data.write().await;
97 data.remove(key);
98 Ok(())
99 }
100
101 async fn exists(&self, key: &str) -> StorageResult<bool> {
102 let data = self.data.read().await;
103 if let Some(item) = data.get(key) {
104 Ok(!item.is_expired())
105 } else {
106 Ok(false)
107 }
108 }
109
110 async fn expire(&self, key: &str, ttl: Duration) -> StorageResult<()> {
111 let mut data = self.data.write().await;
112 if let Some(item) = data.get_mut(key) {
113 item.expire_at = Some(Utc::now() + chrono::Duration::from_std(ttl).unwrap());
114 }
115 Ok(())
116 }
117
118 async fn ttl(&self, key: &str) -> StorageResult<Option<Duration>> {
119 let data = self.data.read().await;
120 if let Some(item) = data.get(key) {
121 if let Some(expire_at) = item.expire_at {
122 let now = Utc::now();
123 if expire_at > now {
124 let duration = (expire_at - now).to_std()
125 .map_err(|e| StorageError::InternalError(e.to_string()))?;
126 Ok(Some(duration))
127 } else {
128 Ok(Some(Duration::from_secs(0)))
129 }
130 } else {
131 Ok(None) }
133 } else {
134 Ok(None) }
136 }
137
138 async fn clear(&self) -> StorageResult<()> {
139 let mut data = self.data.write().await;
140 data.clear();
141 Ok(())
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        storage.set("key1", "value1", None).await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        storage.delete("key1").await.unwrap();
        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);

        storage.set("key2", "value2", None).await.unwrap();
        assert!(storage.exists("key2").await.unwrap());
        assert!(!storage.exists("key3").await.unwrap());
    }

    #[tokio::test]
    async fn test_ttl() {
        let storage = MemoryStorage::new();

        // A short TTL keeps the suite fast. Expiry uses the wall clock
        // (`Utc::now`), so a real sleep is required; it is kept at a
        // comfortable multiple of the TTL.
        storage
            .set("key1", "value1", Some(Duration::from_millis(200)))
            .await
            .unwrap();

        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));

        tokio::time::sleep(Duration::from_millis(400)).await;

        let value = storage.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_ttl_reporting() {
        let storage = MemoryStorage::new();

        storage.set("forever", "v", None).await.unwrap();
        assert_eq!(storage.ttl("forever").await.unwrap(), None);
        assert_eq!(storage.ttl("missing").await.unwrap(), None);

        storage
            .set("timed", "v", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        let remaining = storage
            .ttl("timed")
            .await
            .unwrap()
            .expect("a key set with a ttl reports one");
        assert!(remaining <= Duration::from_secs(60));
        assert!(remaining > Duration::from_secs(50));
    }

    #[tokio::test]
    async fn test_expire_and_clear() {
        let storage = MemoryStorage::new();

        storage.set("key", "value", None).await.unwrap();
        assert_eq!(storage.ttl("key").await.unwrap(), None);

        storage.expire("key", Duration::from_secs(60)).await.unwrap();
        assert!(storage.ttl("key").await.unwrap().is_some());
        assert!(storage.exists("key").await.unwrap());

        storage.clear().await.unwrap();
        assert!(!storage.exists("key").await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let storage = MemoryStorage::new();

        storage.set("gone", "v", Some(Duration::ZERO)).await.unwrap();
        storage.set("kept", "v", None).await.unwrap();
        // Expiry is strict (`now > deadline`), so let the clock move past it.
        tokio::time::sleep(Duration::from_millis(10)).await;

        storage.cleanup_expired().await;
        assert_eq!(storage.data.read().await.len(), 1);
        assert!(storage.exists("kept").await.unwrap());
    }
}