1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use crate::{Request, Response, middleware::Middleware};
8
9#[cfg(feature = "cache")]
10use redis::{Client, Commands};
11
12#[cfg(feature = "json")]
13use serde::{Serialize, Deserialize};
14
/// Serialized form of a cached HTTP response; `CacheMiddleware` stores it as a JSON string.
#[cfg(feature = "json")]
#[derive(Deserialize, Serialize)]
struct CachedResponse {
    /// Numeric HTTP status code.
    status_code: u16,
    /// Header name -> value. A `HashMap` keeps only one value per header name.
    headers: HashMap<String, String>,
    /// Response body as text.
    body: String,
}
23
/// A single value held by the in-memory cache, with an optional absolute expiry.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached payload.
    value: String,
    /// Instant at (and after) which the entry counts as expired; `None` never expires.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Creates an entry that expires `ttl` from now, or never when `ttl` is `None`.
    ///
    /// A TTL so large that the expiry instant cannot be represented (e.g.
    /// `Duration::MAX`) is treated as "never expires" rather than panicking.
    fn new(value: String, ttl: Option<Duration>) -> Self {
        Self {
            value,
            // `Instant + Duration` panics on overflow; `checked_add` maps that to `None`.
            expires_at: ttl.and_then(|duration| Instant::now().checked_add(duration)),
        }
    }

    /// Returns `true` once the current instant has reached `expires_at`.
    ///
    /// `>=` (not `>`) so a zero TTL is always expired instead of depending on clock ticks.
    fn is_expired(&self) -> bool {
        self.expires_at
            .map_or(false, |expires_at| Instant::now() >= expires_at)
    }
}
43
44pub struct MemoryCache {
46 store: Arc<RwLock<HashMap<String, CacheEntry>>>,
47 default_ttl: Option<Duration>,
48}
49
50impl MemoryCache {
51 pub fn new(default_ttl: Option<Duration>) -> Self {
52 Self {
53 store: Arc::new(RwLock::new(HashMap::new())),
54 default_ttl,
55 }
56 }
57
58 pub async fn get(&self, key: &str) -> Option<String> {
59 let store = self.store.read().await;
60 if let Some(entry) = store.get(key) {
61 if !entry.is_expired() {
62 return Some(entry.value.clone());
63 }
64 }
65 None
66 }
67
68 pub async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), Box<dyn std::error::Error>> {
69 let mut store = self.store.write().await;
70 let ttl = ttl.or(self.default_ttl);
71 store.insert(key.to_string(), CacheEntry::new(value.to_string(), ttl));
72 Ok(())
73 }
74
75 pub async fn delete(&self, key: &str) -> Result<bool, Box<dyn std::error::Error>> {
76 let mut store = self.store.write().await;
77 Ok(store.remove(key).is_some())
78 }
79
80 pub async fn clear(&self) -> Result<(), Box<dyn std::error::Error>> {
81 let mut store = self.store.write().await;
82 store.clear();
83 Ok(())
84 }
85
86 pub async fn cleanup_expired(&self) -> Result<usize, Box<dyn std::error::Error>> {
87 let mut store = self.store.write().await;
88 let initial_size = store.len();
89 store.retain(|_, entry| !entry.is_expired());
90 Ok(initial_size - store.len())
91 }
92
93 pub async fn size(&self) -> usize {
94 self.store.read().await.len()
95 }
96}
97
/// Redis-backed cache.
///
/// With the `cache` feature the methods talk to Redis through a `redis::Client`;
/// without it, `new` still succeeds but `get`/`set`/`delete` return an
/// "Redis cache feature not enabled" error.
pub struct RedisCache {
    #[cfg(feature = "cache")]
    client: Client,
    /// TTL applied by `set` when the caller passes `None`; `None` means no expiry.
    // Only the `cache`-feature `set` reads this field, so it is dead code otherwise.
    #[allow(dead_code)]
    default_ttl: Option<Duration>,
    #[cfg(not(feature = "cache"))]
    _phantom: std::marker::PhantomData<()>,
}

impl RedisCache {
    /// Creates a client for `redis_url`.
    ///
    /// # Errors
    /// Propagates any error returned by `Client::open`.
    #[cfg(feature = "cache")]
    pub fn new(redis_url: &str, default_ttl: Option<Duration>) -> Result<Self, redis::RedisError> {
        let client = Client::open(redis_url)?;
        Ok(Self {
            client,
            default_ttl,
        })
    }

    /// Stub constructor used when the `cache` feature is disabled; always succeeds.
    #[cfg(not(feature = "cache"))]
    pub fn new(_redis_url: &str, default_ttl: Option<Duration>) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(Self {
            default_ttl,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Converts a TTL into the millisecond expiry sent with `PSETEX`.
    ///
    /// Clamped to at least 1 ms because Redis rejects a zero expiry, and
    /// saturated at `u64::MAX` for durations that do not fit.
    #[cfg_attr(not(feature = "cache"), allow(dead_code))]
    fn expiry_millis(ttl: Duration) -> u64 {
        // Clamped to u64::MAX first, so the cast cannot truncate.
        let millis = ttl.as_millis().min(u128::from(u64::MAX)) as u64;
        millis.max(1)
    }

    /// Fetches `key`, resolving to `None` when Redis has no value for it.
    ///
    /// NOTE(review): each call opens a connection with the blocking
    /// `get_connection()` and runs a synchronous command inside this `async fn`,
    /// blocking the executor thread for the round trip; consider redis' async
    /// connection API or `tokio::task::spawn_blocking`.
    #[cfg(feature = "cache")]
    pub async fn get(&self, key: &str) -> Result<Option<String>, redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        let result: Option<String> = conn.get(key)?;
        Ok(result)
    }

    /// Always fails: the `cache` feature is disabled.
    #[cfg(not(feature = "cache"))]
    pub async fn get(&self, _key: &str) -> Result<Option<String>, Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }

    /// Stores `value` under `key`, expiring after `ttl` (or the default TTL);
    /// with neither set the key is stored without expiry.
    #[cfg(feature = "cache")]
    pub async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        match ttl.or(self.default_ttl) {
            // PSETEX (millisecond precision) instead of SETEX with `as_secs()`:
            // any sub-second TTL truncated to `SETEX key 0`, which Redis rejects.
            Some(ttl) => conn.pset_ex::<_, _, ()>(key, value, Self::expiry_millis(ttl))?,
            None => conn.set::<_, _, ()>(key, value)?,
        }
        Ok(())
    }

    /// Always fails: the `cache` feature is disabled.
    #[cfg(not(feature = "cache"))]
    pub async fn set(&self, _key: &str, _value: &str, _ttl: Option<Duration>) -> Result<(), Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }

    /// Deletes `key`; resolves to `true` if Redis reported at least one removed key.
    #[cfg(feature = "cache")]
    pub async fn delete(&self, key: &str) -> Result<bool, redis::RedisError> {
        let mut conn = self.client.get_connection()?;
        let result: i32 = conn.del(key)?;
        Ok(result > 0)
    }

    /// Always fails: the `cache` feature is disabled.
    #[cfg(not(feature = "cache"))]
    pub async fn delete(&self, _key: &str) -> Result<bool, Box<dyn std::error::Error>> {
        Err("Redis cache feature not enabled".into())
    }
}
166
/// Object-safe async cache interface, shared as `Arc<dyn Cache>` by
/// `CacheMiddleware` and `CacheWarmer`.
///
/// Each method returns a pinned, boxed, `Send` future whose elided lifetime is
/// tied to `&self` only, so implementations must own copies of the `&str` arguments.
pub trait Cache: Send + Sync {
    /// Looks up `key`; resolves to `None` when no value is available.
    fn get(
        &self,
        key: &str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send + '_>>;

    /// Stores `value` under `key`; `ttl: None` leaves expiry to the implementation.
    fn set(
        &self,
        key: &str,
        value: &str,
        ttl: Option<Duration>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send + '_>,
    >;

    /// Removes `key`; resolves to `true` if an entry was removed (as `MemoryCache` does).
    fn delete(
        &self,
        key: &str,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send + '_>,
    >;
}
173
174impl Cache for MemoryCache {
175 fn get(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send + '_>> {
176 let key = key.to_string();
177 Box::pin(async move { self.get(&key).await })
178 }
179
180 fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Box<dyn std::error::Error>>> + Send + '_>> {
181 let key = key.to_string();
182 let value = value.to_string();
183 Box::pin(async move { self.set(&key, &value, ttl).await })
184 }
185
186 fn delete(&self, key: &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool, Box<dyn std::error::Error>>> + Send + '_>> {
187 let key = key.to_string();
188 Box::pin(async move { self.delete(&key).await })
189 }
190}
191
192pub struct CacheMiddleware {
194 cache: Arc<dyn Cache>,
195 cache_duration: Duration,
196 cache_key_prefix: String,
197}
198
199impl CacheMiddleware {
200 pub fn new(cache: Arc<dyn Cache>, cache_duration: Duration) -> Self {
201 Self {
202 cache,
203 cache_duration,
204 cache_key_prefix: "torch_cache:".to_string(),
205 }
206 }
207
208 pub fn with_prefix(mut self, prefix: &str) -> Self {
209 self.cache_key_prefix = prefix.to_string();
210 self
211 }
212
213 fn generate_cache_key(&self, req: &Request) -> String {
214 format!("{}{}:{}", self.cache_key_prefix, req.method(), req.path())
215 }
216}
217
218impl Middleware for CacheMiddleware {
219 fn call(
220 &self,
221 req: Request,
222 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
223 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
224 let cache = self.cache.clone();
225 let cache_duration = self.cache_duration;
226 let cache_key = self.generate_cache_key(&req);
227
228 Box::pin(async move {
229 let is_get_request = req.method() == &http::Method::GET;
230
231 if is_get_request {
233 if let Some(cached_data) = cache.get(&cache_key).await {
235 #[cfg(feature = "json")]
236 {
237 if let Ok(cached_response) = serde_json::from_str::<CachedResponse>(&cached_data) {
239 let mut response = Response::with_status(
240 http::StatusCode::from_u16(cached_response.status_code).unwrap_or(http::StatusCode::OK)
241 ).body(cached_response.body);
242
243 for (name, value) in cached_response.headers {
245 response = response.header(&name, &value);
246 }
247
248 return response.header("X-Cache", "HIT");
249 }
250 }
251
252 #[cfg(not(feature = "json"))]
253 {
254 return Response::ok()
256 .header("X-Cache", "HIT")
257 .body(cached_data);
258 }
259 }
260 }
261
262 let response = next(req).await;
264
265 if is_get_request && response.status_code().is_success() {
267 #[cfg(feature = "json")]
268 {
269 let cached_response = CachedResponse {
270 status_code: response.status_code().as_u16(),
271 headers: response.headers().iter()
272 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
273 .collect(),
274 body: String::from_utf8_lossy(response.body_data()).to_string(),
275 };
276
277 if let Ok(serialized) = serde_json::to_string(&cached_response) {
278 if let Err(e) = cache.set(&cache_key, &serialized, Some(cache_duration)).await {
279 eprintln!("Failed to cache response: {}", e);
280 }
281 }
282 }
283
284 #[cfg(not(feature = "json"))]
285 {
286 let response_body = String::from_utf8_lossy(response.body_data());
288 if let Err(e) = cache.set(&cache_key, &response_body, Some(cache_duration)).await {
289 eprintln!("Failed to cache response: {}", e);
290 }
291 }
292 }
293
294 response.header("X-Cache", "MISS")
295 })
296 }
297}
298
299pub struct CacheWarmer {
301 cache: Arc<dyn Cache>,
302}
303
304impl CacheWarmer {
305 pub fn new(cache: Arc<dyn Cache>) -> Self {
306 Self { cache }
307 }
308
309 pub async fn warm_cache(&self, data: HashMap<String, String>) -> Result<usize, Box<dyn std::error::Error>> {
311 let mut warmed_count = 0;
312
313 for (key, value) in data {
314 if let Err(e) = self.cache.set(&key, &value, None).await {
315 eprintln!("Failed to warm cache for key {}: {}", key, e);
316 } else {
317 warmed_count += 1;
318 }
319 }
320
321 Ok(warmed_count)
322 }
323
324 pub async fn preload_from_source<F, Fut>(&self, loader: F) -> Result<usize, Box<dyn std::error::Error>>
326 where
327 F: Fn() -> Fut,
328 Fut: std::future::Future<Output = Result<HashMap<String, String>, Box<dyn std::error::Error>>>,
329 {
330 let data = loader().await?;
331 self.warm_cache(data).await
332 }
333}
334
/// Cache usage counters plus a derived hit rate.
///
/// The counters are plain public fields; callers are responsible for updating them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of stored values.
    pub sets: u64,
    /// Number of deletions.
    pub deletes: u64,
    /// Number of failed operations.
    pub errors: u64,
}

impl CacheStats {
    /// Creates zeroed counters (same as `CacheStats::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` when no
    /// hits or misses were recorded.
    pub fn hit_rate(&self) -> f64 {
        // Summing in f64 avoids `u64` overflow (a panic in debug builds) when both
        // counters are huge; below 2^53 the result is identical to integer summing.
        let total = self.hits as f64 + self.misses as f64;
        if total > 0.0 {
            self.hits as f64 / total
        } else {
            0.0
        }
    }
}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_cache() {
        let cache = MemoryCache::new(Some(Duration::from_secs(60)));

        cache.set("key1", "value1", None).await.unwrap();
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));

        assert_eq!(cache.get("nonexistent").await, None);

        assert!(cache.delete("key1").await.unwrap());
        assert_eq!(cache.get("key1").await, None);

        // Deleting a key that is no longer present reports `false`.
        assert!(!cache.delete("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = MemoryCache::new(None);

        // The TTL is wide enough that the immediate read cannot race the expiry on a
        // loaded CI machine, and the sleep comfortably outlasts it.
        cache.set("key1", "value1", Some(Duration::from_millis(100))).await.unwrap();
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));

        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(cache.get("key1").await, None);
    }

    #[tokio::test]
    async fn test_default_ttl_applies_when_none_given() {
        let cache = MemoryCache::new(Some(Duration::from_millis(1)));

        cache.set("key1", "value1", None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(cache.get("key1").await, None);
    }

    #[tokio::test]
    async fn test_cache_cleanup() {
        let cache = MemoryCache::new(None);

        cache.set("key1", "value1", Some(Duration::from_millis(1))).await.unwrap();
        cache.set("key2", "value2", Some(Duration::from_millis(1))).await.unwrap();
        cache.set("key3", "value3", None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;

        let cleaned = cache.cleanup_expired().await.unwrap();
        assert_eq!(cleaned, 2);
        assert_eq!(cache.size().await, 1);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = MemoryCache::new(None);

        cache.set("a", "1", None).await.unwrap();
        cache.set("b", "2", None).await.unwrap();
        cache.clear().await.unwrap();

        assert_eq!(cache.size().await, 0);
        assert_eq!(cache.get("a").await, None);
    }

    #[tokio::test]
    async fn test_memory_cache_through_trait_object() {
        let cache: Arc<dyn Cache> = Arc::new(MemoryCache::new(None));

        cache.set("key1", "value1", None).await.unwrap();
        assert_eq!(cache.get("key1").await, Some("value1".to_string()));

        assert!(cache.delete("key1").await.unwrap());
        assert_eq!(cache.get("key1").await, None);
    }

    #[test]
    fn test_cache_stats() {
        let mut stats = CacheStats::new();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.hits = 80;
        stats.misses = 20;

        // Tolerance instead of exact float equality.
        assert!((stats.hit_rate() - 0.8).abs() < f64::EPSILON);
    }
}