// rust_logic_graph/distributed/store.rs

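//! Context store implementations.
//!
//! Provides storage backends for distributed contexts: Redis (behind the `redis`
//! feature), an in-memory store intended for tests, and a Memcached placeholder that
//! is not implemented yet.
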
use crate::distributed::context::DistributedContext;
use anyhow::{Context as AnyhowContext, Result};
use async_trait::async_trait;
use std::time::{Duration, Instant};

10/// Trait for context storage backends
11#[async_trait]
12pub trait ContextStore: Send + Sync {
13    /// Save a context to the store
14    async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()>;
15
16    /// Load a context from the store
17    async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>>;
18
19    /// Delete a context from the store
20    async fn delete(&self, session_id: &str) -> Result<()>;
21
22    /// Check if a context exists
23    async fn exists(&self, session_id: &str) -> Result<bool>;
24
25    /// List all session IDs (for debugging)
26    async fn list_sessions(&self) -> Result<Vec<String>>;
27}
28
/// Redis-backed [`ContextStore`].
///
/// Each context is stored as a single value under the key `<prefix>:<session_id>`.
#[cfg(feature = "redis")]
#[derive(Debug, Clone)]
pub struct RedisStore {
    /// Client used to open a connection for each store operation.
    client: redis::Client,
    /// Namespace prepended (followed by `:`) to every session key.
    prefix: String,
}

#[cfg(feature = "redis")]
impl RedisStore {
    /// Connect to Redis at `url` and return a store whose keys are namespaced by `prefix`.
    ///
    /// The server is probed with `PING` before the store is handed out, so a
    /// misconfigured URL fails here rather than on first use.
    ///
    /// # Errors
    ///
    /// Fails if `url` cannot be turned into a Redis client, if no connection can be
    /// established, or if the `PING` round-trip fails.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_logic_graph::distributed::RedisStore;
    ///
    /// # async fn example() -> anyhow::Result<()> {
    /// let store = RedisStore::new("redis://localhost:6379", "ctx").await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new(url: &str, prefix: impl Into<String>) -> Result<Self> {
        let client = redis::Client::open(url).context("Failed to create Redis client")?;

        // Fail fast: prove the server answers before returning a usable store.
        let mut probe = client
            .get_multiplexed_async_connection()
            .await
            .context("Failed to connect to Redis")?;
        let _pong: String = redis::cmd("PING")
            .query_async(&mut probe)
            .await
            .context("Redis connection test failed")?;

        let prefix = prefix.into();
        Ok(Self { client, prefix })
    }

    /// Build the namespaced Redis key `<prefix>:<session_id>`.
    fn make_key(&self, session_id: &str) -> String {
        let mut key = String::with_capacity(self.prefix.len() + 1 + session_id.len());
        key.push_str(&self.prefix);
        key.push(':');
        key.push_str(session_id);
        key
    }
}

/// Open a multiplexed connection for a single store operation.
///
/// A fresh connection per call means a dropped connection never poisons the store,
/// at the cost of one connect round-trip per operation.
// NOTE(review): a shared `redis::aio::ConnectionManager` (crate feature
// "connection-manager") would avoid reconnecting every call — consider if enabled.
#[cfg(feature = "redis")]
async fn redis_connection(client: &redis::Client) -> Result<redis::aio::MultiplexedConnection> {
    client
        .get_multiplexed_async_connection()
        .await
        .context("Failed to get Redis connection")
}

/// Convert a TTL into the millisecond count sent with `SET ... PX`.
///
/// Any sub-millisecond remainder is rounded up so the TTL is never shortened, and the
/// result is at least 1 because Redis rejects a zero expiry.
#[cfg(feature = "redis")]
fn redis_ttl_millis(ttl: Duration) -> Result<u64> {
    let has_partial_milli = ttl.subsec_nanos() % 1_000_000 != 0;
    let millis = ttl.as_millis() + u128::from(has_partial_milli);
    u64::try_from(millis.max(1)).context("TTL is too large for Redis")
}

/// Escape Redis glob metacharacters so `raw` matches itself literally in `SCAN MATCH`.
#[cfg(feature = "redis")]
fn escape_redis_glob(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for c in raw.chars() {
        if matches!(c, '*' | '?' | '[' | ']' | '\\') {
            escaped.push('\\');
        }
        escaped.push(c);
    }
    escaped
}

/// `COUNT` hint per `SCAN` round-trip in `list_sessions` (Redis may return more or fewer).
#[cfg(feature = "redis")]
const REDIS_SCAN_COUNT: u64 = 100;

#[cfg(feature = "redis")]
#[async_trait]
impl ContextStore for RedisStore {
    /// Store the serialized context under `<prefix>:<session_id>`, overwriting any
    /// previous value.
    ///
    /// With `Some(ttl)` the key expires after `ttl`, rounded up to whole milliseconds
    /// (minimum 1 ms); with `None` it never expires.
    async fn save(&self, context: &DistributedContext, ttl: Option<Duration>) -> Result<()> {
        let key = self.make_key(&context.session_id);
        let data = context.serialize()?;
        let mut conn = redis_connection(&self.client).await?;

        // `SET key value PX <ms>` takes millisecond precision, unlike `SETEX <secs>`.
        let mut cmd = redis::cmd("SET");
        cmd.arg(&key).arg(data);
        let failure = match ttl {
            Some(ttl) => {
                cmd.arg("PX").arg(redis_ttl_millis(ttl)?);
                "Failed to save context to Redis with TTL"
            }
            None => "Failed to save context to Redis",
        };

        // Pin the reply type to `()`; leaving it uninferred relies on never-type
        // fallback, which newer compilers warn about and edition 2024 rejects.
        cmd.query_async::<_, ()>(&mut conn).await.context(failure)
    }

    /// Load and deserialize the context for `session_id`; `Ok(None)` if the key is absent.
    async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        let data: Option<Vec<u8>> = conn
            .get(&key)
            .await
            .context("Failed to load context from Redis")?;

        match data {
            Some(bytes) => Ok(Some(DistributedContext::deserialize(&bytes)?)),
            None => Ok(None),
        }
    }

    /// Remove the key for `session_id`; removing a missing key is not an error.
    async fn delete(&self, session_id: &str) -> Result<()> {
        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        // DEL replies with the number of removed keys, which is not needed here.
        redis::cmd("DEL")
            .arg(&key)
            .query_async::<_, ()>(&mut conn)
            .await
            .context("Failed to delete context from Redis")
    }

    /// Check whether the key for `session_id` exists.
    async fn exists(&self, session_id: &str) -> Result<bool> {
        use redis::AsyncCommands;

        let key = self.make_key(session_id);
        let mut conn = redis_connection(&self.client).await?;

        conn.exists(&key)
            .await
            .context("Failed to check existence in Redis")
    }

    /// List the session IDs stored under this store's prefix, sorted and deduplicated.
    async fn list_sessions(&self) -> Result<Vec<String>> {
        let key_prefix = format!("{}:", self.prefix);
        // Escape the prefix so glob characters in it are matched literally.
        let pattern = format!("{}*", escape_redis_glob(&key_prefix));
        let mut conn = redis_connection(&self.client).await?;

        // SCAN walks the keyspace incrementally instead of blocking the server like KEYS.
        let mut sessions = Vec::new();
        let mut cursor: u64 = 0;
        loop {
            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(&pattern)
                .arg("COUNT")
                .arg(REDIS_SCAN_COUNT)
                .query_async(&mut conn)
                .await
                .context("Failed to list keys from Redis")?;
            sessions.extend(
                keys.iter()
                    .filter_map(|k| k.strip_prefix(key_prefix.as_str()))
                    .map(str::to_owned),
            );
            if next_cursor == 0 {
                break;
            }
            cursor = next_cursor;
        }

        // SCAN may report the same key more than once; return each session once.
        sessions.sort_unstable();
        sessions.dedup();
        Ok(sessions)
    }
}

190/// Memcached-based context store
191pub struct MemcachedStore {
192    // Placeholder for memcached client
193    servers: Vec<String>,
194    prefix: String,
195}
196
197impl MemcachedStore {
198    /// Create a new Memcached store
199    pub fn new(servers: Vec<String>, prefix: impl Into<String>) -> Self {
200        Self {
201            servers,
202            prefix: prefix.into(),
203        }
204    }
205}
206
207#[async_trait]
208impl ContextStore for MemcachedStore {
209    async fn save(&self, _context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
210        // TODO: Implement memcached support
211        anyhow::bail!("Memcached store not yet implemented")
212    }
213
214    async fn load(&self, _session_id: &str) -> Result<Option<DistributedContext>> {
215        anyhow::bail!("Memcached store not yet implemented")
216    }
217
218    async fn delete(&self, _session_id: &str) -> Result<()> {
219        anyhow::bail!("Memcached store not yet implemented")
220    }
221
222    async fn exists(&self, _session_id: &str) -> Result<bool> {
223        anyhow::bail!("Memcached store not yet implemented")
224    }
225
226    async fn list_sessions(&self) -> Result<Vec<String>> {
227        anyhow::bail!("Memcached store not yet implemented")
228    }
229}
230
231/// In-memory store for testing
232pub struct InMemoryStore {
233    data: std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<String, Vec<u8>>>>,
234}
235
236impl InMemoryStore {
237    /// Create a new in-memory store
238    pub fn new() -> Self {
239        Self {
240            data: std::sync::Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
241        }
242    }
243}
244
245impl Default for InMemoryStore {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
251#[async_trait]
252impl ContextStore for InMemoryStore {
253    async fn save(&self, context: &DistributedContext, _ttl: Option<Duration>) -> Result<()> {
254        let data = context.serialize()?;
255        let mut store = self.data.write().await;
256        store.insert(context.session_id.clone(), data);
257        Ok(())
258    }
259
260    async fn load(&self, session_id: &str) -> Result<Option<DistributedContext>> {
261        let store = self.data.read().await;
262        match store.get(session_id) {
263            Some(bytes) => {
264                let context = DistributedContext::deserialize(bytes)?;
265                Ok(Some(context))
266            }
267            None => Ok(None),
268        }
269    }
270
271    async fn delete(&self, session_id: &str) -> Result<()> {
272        let mut store = self.data.write().await;
273        store.remove(session_id);
274        Ok(())
275    }
276
277    async fn exists(&self, session_id: &str) -> Result<bool> {
278        let store = self.data.read().await;
279        Ok(store.contains_key(session_id))
280    }
281
282    async fn list_sessions(&self) -> Result<Vec<String>> {
283        let store = self.data.read().await;
284        Ok(store.keys().cloned().collect())
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        let mut ctx = DistributedContext::new("test-session");
        ctx.set("key1", json!("value1"));

        // Save
        store.save(&ctx, None).await.unwrap();

        // Load
        let loaded = store
            .load("test-session")
            .await
            .unwrap()
            .expect("session was just saved");
        assert_eq!(loaded.get("key1"), Some(&json!("value1")));

        // Exists
        assert!(store.exists("test-session").await.unwrap());

        // Delete
        store.delete("test-session").await.unwrap();
        assert!(!store.exists("test-session").await.unwrap());
    }

    #[tokio::test]
    async fn test_list_sessions() {
        let store = InMemoryStore::new();

        let ctx1 = DistributedContext::new("session-1");
        let ctx2 = DistributedContext::new("session-2");

        store.save(&ctx1, None).await.unwrap();
        store.save(&ctx2, None).await.unwrap();

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 2);
        assert!(sessions.contains(&"session-1".to_string()));
        assert!(sessions.contains(&"session-2".to_string()));
    }

    #[tokio::test]
    async fn test_missing_session() {
        let store = InMemoryStore::new();

        assert!(store.load("missing").await.unwrap().is_none());
        assert!(!store.exists("missing").await.unwrap());
        // Deleting an unknown session is a no-op, not an error.
        store.delete("missing").await.unwrap();
        assert!(store.list_sessions().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_save_overwrites_existing_session() {
        let store = InMemoryStore::new();
        let mut ctx = DistributedContext::new("session");

        ctx.set("key", json!("old"));
        store.save(&ctx, None).await.unwrap();
        ctx.set("key", json!("new"));
        store.save(&ctx, None).await.unwrap();

        let loaded = store
            .load("session")
            .await
            .unwrap()
            .expect("session was saved");
        assert_eq!(loaded.get("key"), Some(&json!("new")));
        assert_eq!(
            store.list_sessions().await.unwrap(),
            vec!["session".to_string()]
        );
    }

    #[tokio::test]
    async fn test_memcached_placeholder_errors() {
        let store = MemcachedStore::new(vec!["localhost:11211".to_string()], "ctx");
        let ctx = DistributedContext::new("session");

        assert!(store.save(&ctx, None).await.is_err());
        assert!(store.load("session").await.is_err());
        assert!(store.delete("session").await.is_err());
        assert!(store.exists("session").await.is_err());
        assert!(store.list_sessions().await.is_err());
    }
}