rust_logic_graph/distributed/
store.rs1use crate::distributed::context::DistributedContext;
6use anyhow::{Result, Context as AnyhowContext};
7use async_trait::async_trait;
8use std::time::Duration;
9
10#[async_trait]
12pub trait ContextStore: Send + Sync {
13 async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()>;
15
16 async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>>;
18
19 async fn delete(&self, session_id: &str) -> Result<()>;
21
22 async fn exists(&self, session_id: &str) -> Result<bool>;
24
25 async fn list_sessions(&self) -> Result<Vec<String>>;
27}
28
/// [`ContextStore`] backed by a Redis server; only compiled with the `redis`
/// feature.
///
/// Each context is stored as a single Redis string under the key
/// `"{prefix}:{session_id}"`.
#[cfg(feature = "redis")]
pub struct RedisStore {
    /// Client from which a fresh connection is opened for every operation.
    client: redis::Client,
    /// Namespace joined to each session id with a `:` separator.
    prefix: String,
}
35
#[cfg(feature = "redis")]
impl RedisStore {
    /// Creates a store for the Redis server at `url`, namespacing keys under
    /// `prefix`.
    ///
    /// A connection is opened and a `PING` is sent immediately, so an invalid
    /// URL or unreachable server is reported here rather than on first use.
    ///
    /// # Errors
    ///
    /// Fails if the URL cannot be parsed, the connection cannot be
    /// established, or the `PING` round-trip fails.
    pub async fn new(url: &str, prefix: impl Into<String>) -> Result<Self> {
        let client = redis::Client::open(url).context("Failed to create Redis client")?;

        // Probe connectivity with a short-lived connection; the reply text
        // itself is not inspected.
        {
            let mut probe = client
                .get_multiplexed_async_connection()
                .await
                .context("Failed to connect to Redis")?;
            let _pong: String = redis::cmd("PING")
                .query_async(&mut probe)
                .await
                .context("Redis connection test failed")?;
        }

        let prefix: String = prefix.into();
        Ok(Self { client, prefix })
    }

    /// Builds the Redis key `"{prefix}:{session_id}"` for a session.
    fn make_key(&self, session_id: &str) -> String {
        [self.prefix.as_str(), session_id].join(":")
    }
}
73
/// Opens a multiplexed async connection from `client` for one [`RedisStore`]
/// operation.
#[cfg(feature = "redis")]
async fn redis_connection(client: &redis::Client) -> Result<redis::aio::MultiplexedConnection> {
    client
        .get_multiplexed_async_connection()
        .await
        .context("Failed to get Redis connection")
}

/// Converts a TTL into the expiry argument for Redis `SETEX`.
///
/// `SETEX` only accepts a positive whole number of seconds. Plain truncation
/// would turn any sub-second TTL (or `Duration::ZERO`) into `0`, which Redis
/// rejects. The value is therefore rounded *up* to whole seconds and clamped
/// to at least 1, so the key never lives shorter than requested. It
/// saturates at `usize::MAX` on targets where `u64` does not fit.
#[cfg_attr(not(feature = "redis"), allow(dead_code))]
fn redis_ttl_secs(ttl: Duration) -> usize {
    let whole = ttl.as_secs();
    let rounded = if ttl.subsec_nanos() > 0 {
        whole.saturating_add(1)
    } else {
        whole
    };
    usize::try_from(rounded.max(1)).unwrap_or(usize::MAX)
}

#[cfg(feature = "redis")]
#[async_trait]
impl ContextStore for RedisStore {
    /// Writes the serialized context with `SET`, or with `SETEX` when `ttl`
    /// is given.
    ///
    /// The TTL is rounded up to whole seconds (minimum one second), because
    /// that is the only granularity `SETEX` accepts.
    async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()> {
        use redis::AsyncCommands;

        let key = self.make_key(&context.session_id);
        let data = context.serialize()?;
        let mut conn = redis_connection(&self.client).await?;

        // Replies are pinned to `()` explicitly. Leaving the generic reply type
        // to inference relies on never-type fallback, which changes in Rust 2024.
        if let Some(ttl) = ttl {
            let _: () = conn
                .set_ex(&key, data, redis_ttl_secs(ttl))
                .await
                .context("Failed to save context to Redis with TTL")?;
        } else {
            let _: () = conn
                .set(&key, data)
                .await
                .context("Failed to save context to Redis")?;
        }

        Ok(())
    }

    /// Reads the key for `session_id` and deserializes it.
    ///
    /// Returns `Ok(None)` if the key is absent.
    async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        let data: Option<Vec<u8>> = conn
            .get(&key)
            .await
            .context("Failed to load context from Redis")?;

        match data {
            Some(bytes) => Ok(Some(DistributedContext::deserialize(&bytes)?)),
            None => Ok(None),
        }
    }

    /// Issues `DEL` for the session's key. The reply (number of keys removed)
    /// is discarded, so deleting an absent session succeeds.
    async fn delete(&self, session_id: &str) -> Result<()> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        let _: () = conn
            .del(&key)
            .await
            .context("Failed to delete context from Redis")?;

        Ok(())
    }

    /// Checks the session's key with `EXISTS`.
    async fn exists(&self, session_id: &str) -> Result<bool> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        let exists: bool = conn
            .exists(&key)
            .await
            .context("Failed to check existence in Redis")?;

        Ok(exists)
    }

    /// Lists session ids by matching keys against `"{prefix}:*"` and
    /// stripping the `"{prefix}:"` part.
    ///
    /// NOTE(review): this uses `KEYS`, which walks the whole keyspace and
    /// blocks the server while it runs. Consider `SCAN` for large production
    /// datasets. Glob metacharacters in `prefix` are not escaped.
    async fn list_sessions(&self) -> Result<Vec<String>> {
        use redis::AsyncCommands;

        // Build the "{prefix}:" string once instead of once per returned key.
        let key_prefix = self.make_key("");
        let pattern = format!("{}*", key_prefix);
        let mut conn = redis_connection(&self.client).await?;

        let keys: Vec<String> = conn
            .keys(&pattern)
            .await
            .context("Failed to list keys from Redis")?;

        Ok(keys
            .into_iter()
            .filter_map(|key| key.strip_prefix(key_prefix.as_str()).map(str::to_owned))
            .collect())
    }
}
160
/// Placeholder [`ContextStore`] for a Memcached backend.
///
/// Every trait operation currently fails with "not yet implemented". The
/// configuration is recorded so the backend can be filled in later without an
/// API change.
#[derive(Debug, Clone)]
// The fields are not read until the backend is implemented. Silence
// `dead_code` instead of discarding the configuration callers pass in.
#[allow(dead_code)]
pub struct MemcachedStore {
    /// Memcached server addresses, as given to [`MemcachedStore::new`].
    servers: Vec<String>,
    /// Key prefix, as given to [`MemcachedStore::new`].
    prefix: String,
}
167
168impl MemcachedStore {
169 pub fn new(servers: Vec<String>, prefix: impl Into<String>) -> Self {
171 Self {
172 servers,
173 prefix: prefix.into(),
174 }
175 }
176}
177
178#[async_trait]
179impl ContextStore for MemcachedStore {
180 async fn save(&self, _context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
181 anyhow::bail!("Memcached store not yet implemented")
183 }
184
185 async fn load(&self, _session_id: &str) -> Result<Option<DistributedContext>> {
186 anyhow::bail!("Memcached store not yet implemented")
187 }
188
189 async fn delete(&self, _session_id: &str) -> Result<()> {
190 anyhow::bail!("Memcached store not yet implemented")
191 }
192
193 async fn exists(&self, _session_id: &str) -> Result<bool> {
194 anyhow::bail!("Memcached store not yet implemented")
195 }
196
197 async fn list_sessions(&self) -> Result<Vec<String>> {
198 anyhow::bail!("Memcached store not yet implemented")
199 }
200}
201
202pub struct InMemoryStore {
204 data: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, Vec<u8>>>>,
205}
206
207impl InMemoryStore {
208 pub fn new() -> Self {
210 Self {
211 data: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
212 }
213 }
214}
215
216impl Default for InMemoryStore {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
222#[async_trait]
223impl ContextStore for InMemoryStore {
224 async fn save(&self, context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
225 let data = context.serialize()?;
226 let mut store = self.data.write().await;
227 store.insert(context.session_id.clone(), data);
228 Ok(())
229 }
230
231 async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
232 let store = self.data.read().await;
233 match store.get(session_id) {
234 Some(bytes) => {
235 let context = DistributedContext::deserialize(bytes)?;
236 Ok(Some(context))
237 }
238 None => Ok(None),
239 }
240 }
241
242 async fn delete(&self, session_id: &str) -> Result<()> {
243 let mut store = self.data.write().await;
244 store.remove(session_id);
245 Ok(())
246 }
247
248 async fn exists(&self, session_id: &str) -> Result<bool> {
249 let store = self.data.read().await;
250 Ok(store.contains_key(session_id))
251 }
252
253 async fn list_sessions(&self) -> Result<Vec<String>> {
254 let store = self.data.read().await;
255 Ok(store.keys().cloned().collect())
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Round-trips one session through save → load → exists → delete.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        let session = "test-session";

        let mut original = DistributedContext::new(session);
        original.set("key1", json!("value1"));
        store.save(&original, None).await.unwrap();

        let restored = store
            .load(session)
            .await
            .unwrap()
            .expect("a saved session should be loadable");
        assert_eq!(restored.get("key1"), Some(&json!("value1")));

        assert!(store.exists(session).await.unwrap());
        store.delete(session).await.unwrap();
        assert!(!store.exists(session).await.unwrap());
    }

    /// Every saved session id appears exactly once in `list_sessions`.
    #[tokio::test]
    async fn test_list_sessions() {
        let store = InMemoryStore::new();

        for id in ["session-1", "session-2"] {
            store.save(&DistributedContext::new(id), None).await.unwrap();
        }

        // The listing order is unspecified, so sort before comparing.
        let mut sessions = store.list_sessions().await.unwrap();
        sessions.sort();
        assert_eq!(
            sessions,
            vec!["session-1".to_string(), "session-2".to_string()]
        );
    }
}