rust_logic_graph/distributed/
cache.rs1use crate::distributed::context::DistributedContext;
6use crate::distributed::store::ContextStore;
7use anyhow::Result;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Policy describing how a [`DistributedCache`] coordinates writes with its
/// backing [`ContextStore`].
///
/// In the current `DistributedCache` implementation only
/// [`CacheStrategy::WriteBehind`] changes behavior; every other strategy reads
/// and writes synchronously against the store.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheStrategy {
    /// Persist every write to the store before the write call returns.
    WriteThrough,
    /// Acknowledge writes immediately and persist them from a background task;
    /// a failure of that background save is not returned to the caller.
    WriteBehind,
    /// Read-through caching. Currently handled exactly like `WriteThrough`.
    ReadThrough,
    /// Cache-aside caching. Currently handled exactly like `WriteThrough`.
    CacheAside,
}

/// The default strategy is [`CacheStrategy::WriteThrough`], the same strategy
/// `DistributedCache::new` selects.
impl Default for CacheStrategy {
    fn default() -> Self {
        CacheStrategy::WriteThrough
    }
}
26
27pub struct DistributedCache {
29 store: Arc<dyn ContextStore>,
31
32 strategy: CacheStrategy,
34
35 default_ttl: Option<Duration>,
37}
38
39impl DistributedCache {
40 pub fn new(store: Arc<dyn ContextStore>) -> Self {
42 Self {
43 store,
44 strategy: CacheStrategy::WriteThrough,
45 default_ttl: Some(Duration::from_secs(3600)), }
47 }
48
49 pub fn with_config(
51 store: Arc<dyn ContextStore>,
52 strategy: CacheStrategy,
53 default_ttl: Option<Duration>,
54 ) -> Self {
55 Self {
56 store,
57 strategy,
58 default_ttl,
59 }
60 }
61
62 pub async fn get(&self, session_id: &str) -> Result<Option<DistributedContext>> {
64 match self.strategy {
65 CacheStrategy::ReadThrough | CacheStrategy::WriteThrough => {
66 self.store.load(session_id).await
67 }
68 CacheStrategy::CacheAside | CacheStrategy::WriteBehind => {
69 self.store.load(session_id).await
70 }
71 }
72 }
73
74 pub async fn put(&self, context: &DistributedContext) -> Result<()> {
76 self.put_with_ttl(context, self.default_ttl).await
77 }
78
79 pub async fn put_with_ttl(
81 &self,
82 context: &DistributedContext,
83 ttl: Option<Duration>,
84 ) -> Result<()> {
85 match self.strategy {
86 CacheStrategy::WriteThrough => {
87 self.store.save(context, ttl).await
89 }
90 CacheStrategy::WriteBehind => {
91 let store = self.store.clone();
93 let ctx = context.clone();
94 tokio::spawn(async move {
95 let _ = store.save(&ctx, ttl).await;
96 });
97 Ok(())
98 }
99 CacheStrategy::ReadThrough | CacheStrategy::CacheAside => {
100 self.store.save(context, ttl).await
101 }
102 }
103 }
104
105 pub async fn delete(&self, session_id: &str) -> Result<()> {
107 self.store.delete(session_id).await
108 }
109
110 pub async fn exists(&self, session_id: &str) -> Result<bool> {
112 self.store.exists(session_id).await
113 }
114
115 pub async fn invalidate(&self, session_id: &str) -> Result<()> {
117 self.delete(session_id).await
118 }
119
120 pub async fn get_many(&self, session_ids: &[String]) -> Result<Vec<Option<DistributedContext>>> {
122 let mut results = Vec::new();
123
124 for session_id in session_ids {
125 let context = self.get(session_id).await?;
126 results.push(context);
127 }
128
129 Ok(results)
130 }
131
132 pub async fn put_many(&self, contexts: &[DistributedContext]) -> Result<()> {
134 for context in contexts {
135 self.put(context).await?;
136 }
137 Ok(())
138 }
139
140 pub async fn stats(&self) -> CacheStats {
142 CacheStats {
143 total_contexts: self.store.list_sessions().await.unwrap_or_default().len(),
144 strategy: self.strategy,
145 default_ttl: self.default_ttl,
146 }
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct CacheStats {
153 pub total_contexts: usize,
154 pub strategy: CacheStrategy,
155 pub default_ttl: Option<Duration>,
156}
157
158pub struct CacheWarmer {
160 cache: Arc<DistributedCache>,
161}
162
163impl CacheWarmer {
164 pub fn new(cache: Arc<DistributedCache>) -> Self {
166 Self { cache }
167 }
168
169 pub async fn warm(&self, contexts: Vec<DistributedContext>) -> Result<()> {
171 self.cache.put_many(&contexts).await
172 }
173
174 pub async fn warm_from_source(
176 &self,
177 session_ids: Vec<String>,
178 source: Arc<dyn ContextStore>,
179 ) -> Result<()> {
180 for session_id in session_ids {
181 if let Some(context) = source.load(&session_id).await? {
182 self.cache.put(&context).await?;
183 }
184 }
185 Ok(())
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::store::InMemoryStore;
    use serde_json::json;

    #[tokio::test]
    async fn test_cache_put_and_get() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::new(store);

        let mut ctx = DistributedContext::new("test-session");
        ctx.set("key1", json!("value1"));

        cache.put(&ctx).await.unwrap();

        let loaded = cache.get("test-session").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().get("key1"), Some(&json!("value1")));
    }

    #[tokio::test]
    async fn test_cache_delete() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::new(store);

        let ctx = DistributedContext::new("test-session");
        cache.put(&ctx).await.unwrap();

        assert!(cache.exists("test-session").await.unwrap());

        cache.delete("test-session").await.unwrap();

        assert!(!cache.exists("test-session").await.unwrap());
    }

    // `invalidate` is documented as an alias for `delete`; pin that contract.
    #[tokio::test]
    async fn test_cache_invalidate() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::new(store);

        cache
            .put(&DistributedContext::new("test-session"))
            .await
            .unwrap();
        assert!(cache.exists("test-session").await.unwrap());

        cache.invalidate("test-session").await.unwrap();

        assert!(!cache.exists("test-session").await.unwrap());
    }

    #[tokio::test]
    async fn test_batch_operations() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::new(store);

        let ctx1 = DistributedContext::new("session-1");
        let ctx2 = DistributedContext::new("session-2");

        cache.put_many(&[ctx1, ctx2]).await.unwrap();

        let results = cache
            .get_many(&["session-1".to_string(), "session-2".to_string()])
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].is_some());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_cache_warmer() {
        let store = Arc::new(InMemoryStore::new());
        let cache = Arc::new(DistributedCache::new(store));
        let warmer = CacheWarmer::new(cache.clone());

        let ctx1 = DistributedContext::new("session-1");
        let ctx2 = DistributedContext::new("session-2");

        warmer.warm(vec![ctx1, ctx2]).await.unwrap();

        assert!(cache.exists("session-1").await.unwrap());
        assert!(cache.exists("session-2").await.unwrap());
    }

    // Contexts must be copied from a *separate* source store into the cache.
    #[tokio::test]
    async fn test_cache_warmer_from_source() {
        let source = Arc::new(InMemoryStore::new());
        DistributedCache::new(source.clone())
            .put(&DistributedContext::new("session-1"))
            .await
            .unwrap();

        let cache = Arc::new(DistributedCache::new(Arc::new(InMemoryStore::new())));
        let warmer = CacheWarmer::new(cache.clone());
        assert!(!cache.exists("session-1").await.unwrap());

        warmer
            .warm_from_source(vec!["session-1".to_string()], source)
            .await
            .unwrap();

        assert!(cache.exists("session-1").await.unwrap());
    }

    #[tokio::test]
    async fn test_new_defaults() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::new(store);

        let stats = cache.stats().await;
        assert_eq!(stats.total_contexts, 0);
        assert_eq!(stats.strategy, CacheStrategy::WriteThrough);
        assert_eq!(stats.default_ttl, Some(Duration::from_secs(3600)));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let store = Arc::new(InMemoryStore::new());
        let cache = DistributedCache::with_config(
            store,
            CacheStrategy::WriteThrough,
            Some(Duration::from_secs(300)),
        );

        let ctx1 = DistributedContext::new("session-1");
        let ctx2 = DistributedContext::new("session-2");

        cache.put(&ctx1).await.unwrap();
        cache.put(&ctx2).await.unwrap();

        let stats = cache.stats().await;
        assert_eq!(stats.total_contexts, 2);
        assert_eq!(stats.strategy, CacheStrategy::WriteThrough);
        assert_eq!(stats.default_ttl, Some(Duration::from_secs(300)));
    }
}