rust_logic_graph/distributed/
store.rs1use crate::distributed::context::DistributedContext;
6use anyhow::{Context as AnyhowContext, Result};
7use async_trait::async_trait;
8use std::time::Duration;
9
10#[async_trait]
12pub trait ContextStore: Send + Sync {
13 async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()>;
15
16 async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>>;
18
19 async fn delete(&self, session_id: &str) -> Result<()>;
21
22 async fn exists(&self, session_id: &str) -> Result<bool>;
24
25 async fn list_sessions(&self) -> Result<Vec<String>>;
27}
28
/// [`ContextStore`] backed by Redis.
///
/// Every session is stored under the key `"{prefix}:{session_id}"`.
#[cfg(feature = "redis")]
pub struct RedisStore {
    /// Client from which a connection is opened for each operation.
    client: redis::Client,
    /// Namespace prepended (with a `:` separator) to every session key.
    prefix: String,
}

#[cfg(feature = "redis")]
impl RedisStore {
    /// Creates a store for the Redis server at `url`, namespacing keys with `prefix`.
    ///
    /// The server is probed with `PING` before the store is returned.
    ///
    /// # Errors
    ///
    /// Fails if the client cannot be created from `url`, if connecting fails,
    /// or if the `PING` round-trip fails.
    pub async fn new(url: &str, prefix: impl Into<String>) -> Result<Self> {
        let client = redis::Client::open(url).context("Failed to create Redis client")?;

        // Fail fast on a bad URL / unreachable server instead of on first use.
        let mut probe = client
            .get_multiplexed_async_connection()
            .await
            .context("Failed to connect to Redis")?;
        let _pong: String = redis::cmd("PING")
            .query_async(&mut probe)
            .await
            .context("Redis connection test failed")?;

        Ok(Self {
            client,
            prefix: prefix.into(),
        })
    }

    /// Builds the namespaced Redis key `"{prefix}:{session_id}"`.
    fn make_key(&self, session_id: &str) -> String {
        let mut key = String::with_capacity(self.prefix.len() + 1 + session_id.len());
        key.push_str(&self.prefix);
        key.push(':');
        key.push_str(session_id);
        key
    }
}
74
/// Converts a TTL into whole seconds for `SETEX`.
///
/// Rounds up so an entry never expires before the requested TTL, and clamps to
/// at least 1 because Redis rejects a `SETEX` expiry of 0.
#[cfg_attr(not(feature = "redis"), allow(dead_code))]
fn ttl_to_seconds(ttl: Duration) -> usize {
    let secs = ttl
        .as_secs()
        .saturating_add(u64::from(ttl.subsec_nanos() > 0))
        .max(1);
    usize::try_from(secs).unwrap_or(usize::MAX)
}

/// Escapes Redis glob metacharacters (`*`, `?`, `[`, `]`, `\`) so that `text`
/// is matched literally inside a `KEYS`/`SCAN` pattern.
#[cfg_attr(not(feature = "redis"), allow(dead_code))]
fn escape_redis_glob(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        if matches!(ch, '*' | '?' | '[' | ']' | '\\') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}

#[cfg(feature = "redis")]
impl RedisStore {
    /// Opens a multiplexed connection for a single store operation.
    async fn connection(&self) -> Result<redis::aio::MultiplexedConnection> {
        self.client
            .get_multiplexed_async_connection()
            .await
            .context("Failed to get Redis connection")
    }
}

#[cfg(feature = "redis")]
#[async_trait]
impl ContextStore for RedisStore {
    /// Stores the serialized context with `SET`, or with `SETEX` when a TTL is
    /// given (rounded up to whole seconds, minimum 1).
    async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()> {
        use redis::AsyncCommands;

        let key = self.make_key(&context.session_id);
        let data = context.serialize()?;
        let mut conn = self.connection().await?;

        // The reply types are pinned to `()` explicitly: leaving them to
        // inference relies on the never-type fallback, which changes in Rust 2024.
        if let Some(ttl) = ttl {
            let _: () = conn
                .set_ex(&key, data, ttl_to_seconds(ttl))
                .await
                .context("Failed to save context to Redis with TTL")?;
        } else {
            let _: () = conn
                .set(&key, data)
                .await
                .context("Failed to save context to Redis")?;
        }

        Ok(())
    }

    /// Fetches the raw bytes with `GET` and deserializes them if present.
    async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = self.connection().await?;

        let data: Option<Vec<u8>> = conn
            .get(&key)
            .await
            .context("Failed to load context from Redis")?;

        match data {
            Some(bytes) => Ok(Some(DistributedContext::deserialize(&bytes)?)),
            None => Ok(None),
        }
    }

    /// Removes the key with `DEL`; the number of deleted keys is discarded.
    async fn delete(&self, session_id: &str) -> Result<()> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = self.connection().await?;

        let _: () = conn
            .del(&key)
            .await
            .context("Failed to delete context from Redis")?;

        Ok(())
    }

    /// Checks the key with `EXISTS`.
    async fn exists(&self, session_id: &str) -> Result<bool> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = self.connection().await?;

        let exists: bool = conn
            .exists(&key)
            .await
            .context("Failed to check existence in Redis")?;

        Ok(exists)
    }

    /// Lists the session ids under this store's prefix.
    ///
    /// NOTE: uses `KEYS`, which walks the whole keyspace in a single blocking
    /// server call; consider `SCAN` for large production databases.
    async fn list_sessions(&self) -> Result<Vec<String>> {
        use redis::AsyncCommands;

        let key_prefix = format!("{}:", self.prefix);
        // Escape the prefix so glob metacharacters in it are matched literally.
        let pattern = format!("{}*", escape_redis_glob(&key_prefix));
        let mut conn = self.connection().await?;

        let keys: Vec<String> = conn
            .keys(&pattern)
            .await
            .context("Failed to list keys from Redis")?;

        Ok(keys
            .into_iter()
            .filter_map(|key| key.strip_prefix(key_prefix.as_str()).map(str::to_owned))
            .collect())
    }
}
189
/// Placeholder [`ContextStore`] for Memcached.
///
/// Only the configuration is recorded so far; the store operations themselves
/// are not implemented yet.
#[allow(dead_code)] // fields are kept for the pending implementation but not read yet
#[derive(Debug, Clone)]
pub struct MemcachedStore {
    /// Memcached server addresses (presumably `host:port` — TODO confirm).
    servers: Vec<String>,
    /// Namespace intended to prefix every session key.
    prefix: String,
}

impl MemcachedStore {
    /// Records the `servers` to use and the key `prefix`; no connection is made.
    pub fn new(servers: Vec<String>, prefix: impl Into<String>) -> Self {
        Self {
            servers,
            prefix: prefix.into(),
        }
    }
}
206
207#[async_trait]
208impl ContextStore for MemcachedStore {
209 async fn save(&self, _context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
210 anyhow::bail!("Memcached store not yet implemented")
212 }
213
214 async fn load(&self, _session_id: &str) -> Result<Option<DistributedContext>> {
215 anyhow::bail!("Memcached store not yet implemented")
216 }
217
218 async fn delete(&self, _session_id: &str) -> Result<()> {
219 anyhow::bail!("Memcached store not yet implemented")
220 }
221
222 async fn exists(&self, _session_id: &str) -> Result<bool> {
223 anyhow::bail!("Memcached store not yet implemented")
224 }
225
226 async fn list_sessions(&self) -> Result<Vec<String>> {
227 anyhow::bail!("Memcached store not yet implemented")
228 }
229}
230
231pub struct InMemoryStore {
233 data: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, Vec<u8>>>>,
234}
235
236impl InMemoryStore {
237 pub fn new() -> Self {
239 Self {
240 data: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
241 }
242 }
243}
244
245impl Default for InMemoryStore {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251#[async_trait]
252impl ContextStore for InMemoryStore {
253 async fn save(&self, context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
254 let data = context.serialize()?;
255 let mut store = self.data.write().await;
256 store.insert(context.session_id.clone(), data);
257 Ok(())
258 }
259
260 async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
261 let store = self.data.read().await;
262 match store.get(session_id) {
263 Some(bytes) => {
264 let context = DistributedContext::deserialize(bytes)?;
265 Ok(Some(context))
266 }
267 None => Ok(None),
268 }
269 }
270
271 async fn delete(&self, session_id: &str) -> Result<()> {
272 let mut store = self.data.write().await;
273 store.remove(session_id);
274 Ok(())
275 }
276
277 async fn exists(&self, session_id: &str) -> Result<bool> {
278 let store = self.data.read().await;
279 Ok(store.contains_key(session_id))
280 }
281
282 async fn list_sessions(&self) -> Result<Vec<String>> {
283 let store = self.data.read().await;
284 Ok(store.keys().cloned().collect())
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Round-trips a context through save / load / exists / delete.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        let mut ctx = DistributedContext::new("test-session");
        ctx.set("key1", json!("value1"));

        store.save(&ctx, None).await.unwrap();

        let loaded = store
            .load("test-session")
            .await
            .unwrap()
            .expect("session was just saved");
        assert_eq!(loaded.get("key1"), Some(&json!("value1")));

        assert!(store.exists("test-session").await.unwrap());

        store.delete("test-session").await.unwrap();
        assert!(!store.exists("test-session").await.unwrap());
        assert!(store.load("test-session").await.unwrap().is_none());
    }

    /// Every saved session id is reported by `list_sessions`.
    #[tokio::test]
    async fn test_list_sessions() {
        let store = InMemoryStore::new();

        let ctx1 = DistributedContext::new("session-1");
        let ctx2 = DistributedContext::new("session-2");

        store.save(&ctx1, None).await.unwrap();
        store.save(&ctx2, None).await.unwrap();

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 2);
        assert!(sessions.contains(&"session-1".to_string()));
        assert!(sessions.contains(&"session-2".to_string()));
    }

    /// Unknown sessions are reported as absent rather than as errors.
    #[tokio::test]
    async fn test_missing_session_is_absent_not_an_error() {
        let store = InMemoryStore::default();

        assert!(store.load("missing").await.unwrap().is_none());
        assert!(!store.exists("missing").await.unwrap());
        store.delete("missing").await.unwrap();
        assert!(store.list_sessions().await.unwrap().is_empty());
    }

    /// Saving under an existing session id replaces the stored context.
    #[tokio::test]
    async fn test_save_replaces_existing_session() {
        let store = InMemoryStore::new();

        let mut first = DistributedContext::new("session");
        first.set("key1", json!("old"));
        store.save(&first, None).await.unwrap();

        let mut second = DistributedContext::new("session");
        second.set("key1", json!("new"));
        store.save(&second, None).await.unwrap();

        let loaded = store
            .load("session")
            .await
            .unwrap()
            .expect("session was saved");
        assert_eq!(loaded.get("key1"), Some(&json!("new")));
        assert_eq!(
            store.list_sessions().await.unwrap(),
            vec!["session".to_string()]
        );
    }
}