torch_web/
cache.rs

1//! High-performance caching with Redis and in-memory support
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use crate::{Request, Response, middleware::Middleware};
8
9#[cfg(feature = "cache")]
10use redis::{Client, Commands};
11
12#[cfg(feature = "json")]
13use serde::{Serialize, Deserialize};
14
/// Cached response structure for serialization
///
/// Serializable snapshot of an HTTP response (status, headers, body) so a
/// whole response can be stored as one string in any `Cache` backend.
#[cfg(feature = "json")]
#[derive(Serialize, Deserialize)]
struct CachedResponse {
    // Numeric HTTP status (e.g. 200); kept as u16 for serde friendliness.
    status_code: u16,
    // Header name -> value. NOTE(review): a HashMap cannot represent
    // repeated headers (e.g. multiple Set-Cookie) — confirm that is acceptable.
    headers: HashMap<String, String>,
    // Body as text; non-UTF-8 bodies are lossily converted before caching.
    body: String,
}
23
/// A single cached value paired with an optional absolute expiry instant.
#[derive(Debug, Clone)]
struct CacheEntry {
    value: String,
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Builds an entry; a `ttl` of `None` means the entry never expires.
    fn new(value: String, ttl: Option<Duration>) -> Self {
        let expires_at = ttl.map(|d| Instant::now() + d);
        Self { value, expires_at }
    }

    /// True once the deadline has passed; an entry without a deadline
    /// is never considered expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() > deadline,
            None => false,
        }
    }
}
43
/// In-memory cache implementation
///
/// Thread-safe string cache: a `HashMap` behind a Tokio `RwLock`, shared
/// via `Arc` so the store can be used across tasks.
pub struct MemoryCache {
    // Key -> entry map; async RwLock allows concurrent reads, exclusive writes.
    store: Arc<RwLock<HashMap<String, CacheEntry>>>,
    // Fallback TTL applied by `set` when the caller passes no explicit TTL;
    // `None` means such entries never expire.
    default_ttl: Option<Duration>,
}
49
50impl MemoryCache {
51    pub fn new(default_ttl: Option<Duration>) -> Self {
52        Self {
53            store: Arc::new(RwLock::new(HashMap::new())),
54            default_ttl,
55        }
56    }
57
58    pub async fn get(&self, key: &str) -> Option<String> {
59        let store = self.store.read().await;
60        if let Some(entry) = store.get(key) {
61            if !entry.is_expired() {
62                return Some(entry.value.clone());
63            }
64        }
65        None
66    }
67
68    pub async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), Box<dyn std::error::Error>> {
69        let mut store = self.store.write().await;
70        let ttl = ttl.or(self.default_ttl);
71        store.insert(key.to_string(), CacheEntry::new(value.to_string(), ttl));
72        Ok(())
73    }
74
75    pub async fn delete(&self, key: &str) -> Result<bool, Box<dyn std::error::Error>> {
76        let mut store = self.store.write().await;
77        Ok(store.remove(key).is_some())
78    }
79
80    pub async fn clear(&self) -> Result<(), Box<dyn std::error::Error>> {
81        let mut store = self.store.write().await;
82        store.clear();
83        Ok(())
84    }
85
86    pub async fn cleanup_expired(&self) -> Result<usize, Box<dyn std::error::Error>> {
87        let mut store = self.store.write().await;
88        let initial_size = store.len();
89        store.retain(|_, entry| !entry.is_expired());
90        Ok(initial_size - store.len())
91    }
92
93    pub async fn size(&self) -> usize {
94        self.store.read().await.len()
95    }
96}
97
/// Redis cache implementation
///
/// When the `cache` feature is disabled this type still compiles, but every
/// operation in the impl below returns a "feature not enabled" error.
pub struct RedisCache {
    // Redis client handle; only present when the `cache` feature is on.
    #[cfg(feature = "cache")]
    client: Client,
    // Fallback TTL used by `set` when no per-call TTL is given.
    #[allow(dead_code)]
    default_ttl: Option<Duration>,
    // Placeholder so the struct is constructible without the `cache` feature.
    #[cfg(not(feature = "cache"))]
    _phantom: std::marker::PhantomData<()>,
}
107
impl RedisCache {
    /// Builds a cache from a Redis connection URL. `Client::open` only
    /// validates the URL; no network I/O happens until an operation runs.
    #[cfg(feature = "cache")]
    pub fn new(redis_url: &str, default_ttl: Option<Duration>) -> Result<Self, redis::RedisError> {
        let client = Client::open(redis_url)?;
        Ok(Self {
            client,
            default_ttl,
        })
    }

    /// Feature-off stub: construction succeeds so callers can compile
    /// against the same type, but every operation below errors.
    #[cfg(not(feature = "cache"))]
    pub fn new(_redis_url: &str, default_ttl: Option<Duration>) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(Self {
            default_ttl,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Fetches `key` from Redis; `Ok(None)` on a miss.
    /// NOTE(review): `get_connection()` and the `Commands` calls are
    /// synchronous/blocking, yet this fn is `async` — this can stall an
    /// executor thread; consider the redis crate's async connection API.
    #[cfg(feature = "cache")]
    pub async fn get(&self, key: &str) -> Result<Option<String>, redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        let result: Option<String> = conn.get(key)?;
        Ok(result)
    }

    /// Feature-off stub: always errors.
    #[cfg(not(feature = "cache"))]
    pub async fn get(&self, _key: &str) -> Result<Option<String>, Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }

    /// Stores `key` = `value`: SETEX when a TTL (explicit or default)
    /// applies, plain SET otherwise.
    /// NOTE(review): `as_secs` truncates — a TTL under one second becomes
    /// 0 seconds; confirm Redis accepts that.
    #[cfg(feature = "cache")]
    pub async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        if let Some(ttl) = ttl.or(self.default_ttl) {
            conn.set_ex::<_, _, ()>(key, value, ttl.as_secs())?;
        } else {
            conn.set::<_, _, ()>(key, value)?;
        }
        Ok(())
    }

    /// Feature-off stub: always errors.
    #[cfg(not(feature = "cache"))]
    pub async fn set(&self, _key: &str, _value: &str, _ttl: Option<Duration>) -> Result<(), Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }

    /// Deletes `key`; `Ok(true)` when Redis reports at least one key removed.
    #[cfg(feature = "cache")]
    pub async fn delete(&self, key: &str) -> Result<bool, redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        let result: i32 = conn.del(key)?;
        Ok(result > 0)
    }

    /// Feature-off stub: always errors.
    #[cfg(not(feature = "cache"))]
    pub async fn delete(&self, _key: &str) -> Result<bool, Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }
}
166
/// Cache trait for unified interface
///
/// Object-safe async cache facade: methods return boxed futures so the
/// trait can be held as `Arc<dyn Cache>`. Implemented for `MemoryCache`
/// below; note `RedisCache` has inherent async methods but no `Cache`
/// impl in this file.
pub trait Cache: Send + Sync {
    /// Fetches the value for `key`, or `None` on miss/expiry.
    fn get(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send + '_>>;
    /// Stores `value` under `key` with an optional TTL.
    fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send + '_>>;
    /// Removes `key`; `Ok(true)` if an entry was present.
    fn delete(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send + '_>>;
}
173
174impl Cache for MemoryCache {
175    fn get(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send + '_>> {
176        let key = key.to_string();
177        Box::pin(async move { self.get(&key).await })
178    }
179
180    fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send + '_>> {
181        let key = key.to_string();
182        let value = value.to_string();
183        Box::pin(async move { self.set(&key, &value, ttl).await })
184    }
185
186    fn delete(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send + '_>> {
187        let key = key.to_string();
188        Box::pin(async move { self.delete(&key).await })
189    }
190}
191
/// Response caching middleware
///
/// Stores successful GET responses in the supplied `Cache` backend and
/// replays them with an `X-Cache: HIT` header on subsequent requests.
pub struct CacheMiddleware {
    // Shared backend (memory, Redis, ...) behind the object-safe trait.
    cache: Arc<dyn Cache>,
    // TTL applied to every cached response.
    cache_duration: Duration,
    // Prepended to every cache key; defaults to "torch_cache:".
    cache_key_prefix: String,
}
198
199impl CacheMiddleware {
200    pub fn new(cache: Arc<dyn Cache>, cache_duration: Duration) -> Self {
201        Self {
202            cache,
203            cache_duration,
204            cache_key_prefix: "torch_cache:".to_string(),
205        }
206    }
207
208    pub fn with_prefix(mut self, prefix: &str) -> Self {
209        self.cache_key_prefix = prefix.to_string();
210        self
211    }
212
213    fn generate_cache_key(&self, req: &Request) -> String {
214        format!("{}{}:{}", self.cache_key_prefix, req.method(), req.path())
215    }
216}
217
// Serves cached GET responses; otherwise runs the handler and caches
// successful GET results. Every non-hit response is tagged "X-Cache: MISS".
impl Middleware for CacheMiddleware {
    fn call(
        &self,
        req: Request,
        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
        // Clone/copy everything the 'static future needs before moving in.
        let cache = self.cache.clone();
        let cache_duration = self.cache_duration;
        let cache_key = self.generate_cache_key(&req);

        Box::pin(async move {
            let is_get_request = req.method() == &http::Method::GET;

            // Only cache GET requests
            if is_get_request {
                // Try to get from cache first
                if let Some(cached_data) = cache.get(&cache_key).await {
                    #[cfg(feature = "json")]
                    {
                        // Parse cached response data; on deserialization
                        // failure we fall through and treat it as a miss.
                        if let Ok(cached_response) = serde_json::from_str::<CachedResponse>(&cached_data) {
                            // An out-of-range stored status silently becomes 200 OK.
                            let mut response = Response::with_status(
                                http::StatusCode::from_u16(cached_response.status_code).unwrap_or(http::StatusCode::OK)
                            ).body(cached_response.body);

                            // Restore headers
                            for (name, value) in cached_response.headers {
                                response = response.header(&name, &value);
                            }

                            return response.header("X-Cache", "HIT");
                        }
                    }

                    #[cfg(not(feature = "json"))]
                    {
                        // Simple string caching when JSON feature is not available
                        // NOTE(review): this path always replays with Response::ok()
                        // — the original status code and headers are not stored.
                        return Response::ok()
                            .header("X-Cache", "HIT")
                            .body(cached_data);
                    }
                }
            }

            // Execute the request
            let response = next(req).await;

            // Cache successful GET responses
            if is_get_request && response.status_code().is_success() {
                #[cfg(feature = "json")]
                {
                    // Snapshot status/headers/body into the serializable form.
                    // Header values that are not valid UTF-8 are stored as "".
                    let cached_response = CachedResponse {
                        status_code: response.status_code().as_u16(),
                        headers: response.headers().iter()
                            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
                            .collect(),
                        body: String::from_utf8_lossy(response.body_data()).to_string(),
                    };

                    // Cache failures are logged but never fail the request.
                    if let Ok(serialized) = serde_json::to_string(&cached_response) {
                        if let Err(e) = cache.set(&cache_key, &serialized, Some(cache_duration)).await {
                            eprintln!("Failed to cache response: {}", e);
                        }
                    }
                }

                #[cfg(not(feature = "json"))]
                {
                    // Simple string caching when JSON feature is not available
                    let response_body = String::from_utf8_lossy(response.body_data());
                    if let Err(e) = cache.set(&cache_key, &response_body, Some(cache_duration)).await {
                        eprintln!("Failed to cache response: {}", e);
                    }
                }
            }

            // Everything that reached the handler is labelled a miss —
            // including non-GET responses.
            response.header("X-Cache", "MISS")
        })
    }
}
298
/// Cache warming utility
///
/// Pre-populates any `Cache` backend with known key/value pairs, e.g. at
/// startup, so early requests hit warm entries.
pub struct CacheWarmer {
    // Backend to fill; shared via the object-safe trait.
    cache: Arc<dyn Cache>,
}
303
304impl CacheWarmer {
305    pub fn new(cache: Arc<dyn Cache>) -> Self {
306        Self { cache }
307    }
308
309    /// Warm the cache with predefined data
310    pub async fn warm_cache(&self, data: HashMap<String, String>) -> Result<usize, Box<dyn std::error::Error>> {
311        let mut warmed_count = 0;
312        
313        for (key, value) in data {
314            if let Err(e) = self.cache.set(&key, &value, None).await {
315                eprintln!("Failed to warm cache for key {}: {}", key, e);
316            } else {
317                warmed_count += 1;
318            }
319        }
320        
321        Ok(warmed_count)
322    }
323
324    /// Preload cache from database or external source
325    pub async fn preload_from_source<F, Fut>(&self, loader: F) -> Result<usize, Box<dyn std::error::Error>>
326    where
327        F: Fn() -> Fut,
328        Fut: std::future::Future<Output = Result<HashMap<String, String>, Box<dyn std::error::Error>>>,
329    {
330        let data = loader().await?;
331        self.warm_cache(data).await
332    }
333}
334
/// Cache statistics
///
/// Plain counters describing cache effectiveness; callers update the
/// public fields directly and read the derived [`Self::hit_rate`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    pub hits: u64,
    pub misses: u64,
    pub sets: u64,
    pub deletes: u64,
    pub errors: u64,
}

impl CacheStats {
    /// All counters start at zero. Delegates to the derived `Default`
    /// instead of hand-zeroing every field (clippy: `new_without_default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of lookups that hit, in `[0.0, 1.0]`.
    /// Returns 0.0 when no lookups were recorded (avoids 0/0 = NaN).
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
365
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips set/get/delete through the in-memory backend.
    #[tokio::test]
    async fn test_memory_cache() {
        let cache = MemoryCache::new(Some(Duration::from_secs(60)));

        // Test set and get
        cache.set("key1", "value1", None).await.unwrap();
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));

        // Test non-existent key
        assert_eq!(cache.get("nonexistent").await, None);

        // Test delete
        assert!(cache.delete("key1").await.unwrap());
        assert_eq!(cache.get("key1").await, None);
    }

    // Entries with a TTL become invisible to get() once the TTL elapses.
    // NOTE(review): sleep-based timing can be flaky on loaded CI machines.
    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = MemoryCache::new(None);

        // Set with short TTL
        cache.set("key1", "value1", Some(Duration::from_millis(10))).await.unwrap();
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));

        // Wait for expiration
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cache.get("key1").await, None);
    }

    // cleanup_expired() removes only the expired entries and reports the count.
    #[tokio::test]
    async fn test_cache_cleanup() {
        let cache = MemoryCache::new(None);

        // Add expired entries
        cache.set("key1", "value1", Some(Duration::from_millis(1))).await.unwrap();
        cache.set("key2", "value2", Some(Duration::from_millis(1))).await.unwrap();
        cache.set("key3", "value3", None).await.unwrap(); // No expiration

        tokio::time::sleep(Duration::from_millis(10)).await;

        let cleaned = cache.cleanup_expired().await.unwrap();
        assert_eq!(cleaned, 2);
        assert_eq!(cache.size().await, 1);
    }

    // hit_rate() = hits / (hits + misses).
    #[test]
    fn test_cache_stats() {
        let mut stats = CacheStats::new();
        stats.hits = 80;
        stats.misses = 20;

        assert_eq!(stats.hit_rate(), 0.8);
    }
}