1use dashmap::DashMap;
22use http::{Method, Request, Uri};
23use std::sync::atomic::{AtomicU64, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, SystemTime, UNIX_EPOCH};
26use tokio::sync::{RwLock, Semaphore};
27use tokio::task::JoinHandle;
28
29#[cfg(feature = "metrics")]
30use metrics::{counter, gauge, histogram};
31
32#[cfg(feature = "tracing")]
33use tracing::{debug, error, info, instrument};
34
/// Settings that control automatic background refresh of frequently hit cache keys.
#[derive(Debug, Clone)]
pub struct AutoRefreshConfig {
    /// Master on/off flag. It is not consulted by the code in this file;
    /// presumably callers check it before starting a refresh manager (TODO confirm).
    pub enabled: bool,
    /// Hit rate (hits per minute) a key must reach to qualify for auto-refresh.
    pub min_hits_per_minute: f64,
    /// Period between scans for refresh candidates.
    pub check_interval: Duration,
    /// Maximum number of refreshes allowed to run at the same time.
    pub max_concurrent_refreshes: usize,
    /// Period between purges of stale access statistics.
    pub cleanup_interval: Duration,
    /// Length of the window over which the hit rate is measured.
    pub hit_rate_window: Duration,
}

impl Default for AutoRefreshConfig {
    /// Disabled by default; 10 hits/minute threshold, 10 s scans, 10 concurrent
    /// refreshes, 60 s cleanup period and a 60 s hit-rate window.
    fn default() -> Self {
        Self {
            enabled: false,
            min_hits_per_minute: 10.0,
            check_interval: Duration::from_secs(10),
            max_concurrent_refreshes: 10,
            cleanup_interval: Duration::from_secs(60),
            hit_rate_window: Duration::from_secs(60),
        }
    }
}

impl AutoRefreshConfig {
    /// Checks the configuration for values the refresh machinery cannot work with.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `min_hits_per_minute` is negative
    /// or NaN, when `max_concurrent_refreshes` is zero, or when any of the
    /// durations is shorter than one millisecond.
    pub fn validate(&self) -> Result<(), String> {
        // `NaN < 0.0` is false, so NaN has to be rejected explicitly; otherwise a
        // NaN threshold passes validation and no rate can ever compare `>=` to it.
        if self.min_hits_per_minute.is_nan() || self.min_hits_per_minute < 0.0 {
            return Err("min_hits_per_minute must be non-negative".to_string());
        }
        if self.max_concurrent_refreshes == 0 {
            return Err("max_concurrent_refreshes must be at least 1".to_string());
        }
        if self.check_interval.as_millis() == 0 {
            return Err("check_interval must be greater than zero".to_string());
        }
        if self.cleanup_interval.as_millis() == 0 {
            return Err("cleanup_interval must be greater than zero".to_string());
        }
        if self.hit_rate_window.as_millis() == 0 {
            return Err("hit_rate_window must be greater than zero".to_string());
        }
        Ok(())
    }

    /// Builds an enabled configuration with the given hit-rate threshold and
    /// default values for every other field.
    pub fn enabled(min_hits_per_minute: f64) -> Self {
        Self {
            enabled: true,
            min_hits_per_minute,
            ..Default::default()
        }
    }
}
95
96#[derive(Debug, Clone)]
98pub struct RefreshMetadata {
99 pub method: Method,
100 pub uri: Uri,
101 pub headers: Vec<(String, Vec<u8>)>,
102}
103
104impl RefreshMetadata {
105 pub fn from_request<B>(req: &Request<B>) -> Self {
107 Self {
108 method: req.method().clone(),
109 uri: req.uri().clone(),
110 headers: Vec::new(),
111 }
112 }
113
114 pub fn from_request_with_headers<B>(req: &Request<B>, header_names: &[String]) -> Self {
116 let headers = req
117 .headers()
118 .iter()
119 .filter(|(name, _)| {
120 let name_str = name.as_str().to_ascii_lowercase();
121 header_names
122 .iter()
123 .any(|h| h.to_ascii_lowercase() == name_str)
124 })
125 .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
126 .collect();
127
128 Self {
129 method: req.method().clone(),
130 uri: req.uri().clone(),
131 headers,
132 }
133 }
134
135 pub fn try_into_request(&self) -> Option<Request<()>> {
139 let mut builder = Request::builder()
140 .method(self.method.clone())
141 .uri(self.uri.clone());
142
143 for (name, value) in &self.headers {
144 if let Ok(header_name) = http::header::HeaderName::from_bytes(name.as_bytes()) {
145 if let Ok(header_value) = http::header::HeaderValue::from_bytes(value) {
146 builder = builder.header(header_name, header_value);
147 }
148 }
149 }
150
151 builder.body(()).ok()
152 }
153}
154
/// Lock-free hit counters for a single key.
///
/// All fields are atomics accessed with `Ordering::Relaxed`; timestamps are
/// wall-clock milliseconds since the Unix epoch.
#[derive(Debug)]
struct AccessStats {
    /// Total number of hits recorded since creation.
    hits: AtomicU64,
    /// Timestamp of the most recent hit (initialised to the creation time).
    last_access_ms: AtomicU64,
    /// Timestamp at which the current rate window began.
    window_start_ms: AtomicU64,
    /// Number of hits recorded inside the current rate window.
    window_hits: AtomicU64,
}

impl AccessStats {
    /// Current wall-clock time in milliseconds since the Unix epoch
    /// (0 if the system clock reads earlier than the epoch).
    fn now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// Creates zeroed counters with the window starting now.
    fn new() -> Self {
        let now_ms = Self::now_ms();

        Self {
            hits: AtomicU64::new(0),
            last_access_ms: AtomicU64::new(now_ms),
            window_start_ms: AtomicU64::new(now_ms),
            window_hits: AtomicU64::new(0),
        }
    }

    /// Records one hit, starting a fresh window when the current one is older
    /// than `window_duration_ms`.
    fn record_hit(&self, window_duration_ms: u64) {
        let now_ms = Self::now_ms();

        self.hits.fetch_add(1, Ordering::Relaxed);
        self.last_access_ms.store(now_ms, Ordering::Relaxed);

        let window_start = self.window_start_ms.load(Ordering::Relaxed);
        let window_expired = now_ms.saturating_sub(window_start) > window_duration_ms;

        // Only the thread that wins the compare-exchange resets the counter;
        // threads that lose count their hit in the new window instead of each
        // overwriting the counter with 1. This narrows, but does not fully
        // close, the race between the two separate atomics.
        let won_reset = window_expired
            && self
                .window_start_ms
                .compare_exchange(window_start, now_ms, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok();

        if won_reset {
            self.window_hits.store(1, Ordering::Relaxed);
        } else {
            self.window_hits.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Hit rate of the current window, extrapolated to hits per minute.
    ///
    /// Returns `0.0` when no time has elapsed since the window started or when
    /// the window is older than `window_duration_ms`.
    fn hits_per_minute(&self, window_duration_ms: u64) -> f64 {
        let now_ms = Self::now_ms();

        let window_start = self.window_start_ms.load(Ordering::Relaxed);
        let window_hits = self.window_hits.load(Ordering::Relaxed);

        let elapsed_ms = now_ms.saturating_sub(window_start);
        if elapsed_ms == 0 || elapsed_ms > window_duration_ms {
            return 0.0;
        }

        // `elapsed_ms` is non-zero here, so the divisor is strictly positive.
        let elapsed_minutes = elapsed_ms as f64 / 60_000.0;
        window_hits as f64 / elapsed_minutes
    }

    /// Time of the most recent hit (or of creation when there were none).
    fn last_access(&self) -> SystemTime {
        let ms = self.last_access_ms.load(Ordering::Relaxed);
        UNIX_EPOCH + Duration::from_millis(ms)
    }

    /// Total number of hits recorded since creation.
    fn total_hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }
}
238
239#[derive(Clone)]
241pub struct AccessTracker {
242 stats: Arc<DashMap<String, Arc<AccessStats>>>,
243 config: Arc<AutoRefreshConfig>,
244}
245
246impl AccessTracker {
247 pub fn new(config: AutoRefreshConfig) -> Self {
248 Self {
249 stats: Arc::new(DashMap::new()),
250 config: Arc::new(config),
251 }
252 }
253
254 pub fn record_hit(&self, key: &str) {
256 let window_duration_ms = self.config.hit_rate_window.as_millis() as u64;
257
258 let stats = self
259 .stats
260 .entry(key.to_owned())
261 .or_insert_with(|| Arc::new(AccessStats::new()))
262 .clone();
263
264 stats.record_hit(window_duration_ms);
265 }
266
267 pub fn hits_per_minute(&self, key: &str) -> f64 {
269 let window_duration_ms = self.config.hit_rate_window.as_millis() as u64;
270
271 self.stats
272 .get(key)
273 .map(|stats| stats.hits_per_minute(window_duration_ms))
274 .unwrap_or(0.0)
275 }
276
277 pub fn should_auto_refresh(&self, key: &str) -> bool {
279 let rate = self.hits_per_minute(key);
280 rate >= self.config.min_hits_per_minute
281 }
282
283 pub fn cleanup_stale(&self, max_age: Duration) {
285 let now = SystemTime::now();
286 let keys_to_remove: Vec<String> = self
287 .stats
288 .iter()
289 .filter_map(|entry| {
290 let last_access = entry.value().last_access();
291 if now.duration_since(last_access).ok()? > max_age {
292 Some(entry.key().clone())
293 } else {
294 None
295 }
296 })
297 .collect();
298
299 for key in keys_to_remove {
300 self.stats.remove(&key);
301 }
302
303 #[cfg(feature = "metrics")]
304 gauge!("tower_http_cache.auto_refresh.active_keys").set(self.stats.len() as f64);
305 }
306
307 pub fn tracked_keys(&self) -> usize {
309 self.stats.len()
310 }
311
312 pub fn get_stats(&self, key: &str) -> Option<(u64, f64)> {
314 let window_duration_ms = self.config.hit_rate_window.as_millis() as u64;
315 self.stats.get(key).map(|stats| {
316 (
317 stats.total_hits(),
318 stats.hits_per_minute(window_duration_ms),
319 )
320 })
321 }
322}
323
/// Outcome of one refresh attempt: `Ok(())` on success, otherwise a boxed,
/// thread-safe error.
pub type RefreshResult = Result<(), Box<dyn std::error::Error + Send + Sync>>;

/// Boxed, pinned, `Send` future that resolves to a [`RefreshResult`].
pub type RefreshFuture = std::pin::Pin<Box<dyn std::future::Future<Output = RefreshResult> + Send>>;

/// Hook through which the background task asks for a key to be refreshed.
///
/// Implementations must be `Send + Sync`; the background task shares the
/// callback through an `Arc` and calls it from spawned tasks.
pub trait RefreshCallback: Send + Sync {
    /// Starts a refresh for `key`, using `metadata` captured from an earlier
    /// request, and returns a future that resolves to the outcome.
    fn refresh(&self, key: String, metadata: RefreshMetadata) -> RefreshFuture;
}
337
338pub struct RefreshManager {
340 tracker: AccessTracker,
341 metadata_store: Arc<DashMap<String, RefreshMetadata>>,
342 config: Arc<AutoRefreshConfig>,
343 shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
344 pub(crate) task_handle: Arc<RwLock<Option<JoinHandle<()>>>>,
345}
346
347impl RefreshManager {
348 pub fn new(config: AutoRefreshConfig) -> Self {
350 Self {
351 tracker: AccessTracker::new(config.clone()),
352 metadata_store: Arc::new(DashMap::new()),
353 config: Arc::new(config),
354 shutdown_tx: Arc::new(RwLock::new(None)),
355 task_handle: Arc::new(RwLock::new(None)),
356 }
357 }
358
359 pub fn tracker(&self) -> &AccessTracker {
361 &self.tracker
362 }
363
364 pub fn store_metadata(&self, key: String, metadata: RefreshMetadata) {
366 self.metadata_store.insert(key, metadata);
367 }
368
369 pub fn get_metadata(&self, key: &str) -> Option<RefreshMetadata> {
371 self.metadata_store
372 .get(key)
373 .map(|entry| entry.value().clone())
374 }
375
376 pub async fn start<C>(&self, callback: Arc<C>) -> Result<(), String>
380 where
381 C: RefreshCallback + 'static,
382 {
383 self.config.validate()?;
385
386 {
388 let task_guard = self.task_handle.read().await;
389 if task_guard.is_some() {
390 return Err("Refresh manager is already running".to_string());
391 }
392 }
393
394 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
395
396 {
398 let mut tx_guard = self.shutdown_tx.write().await;
399 *tx_guard = Some(shutdown_tx);
400 }
401
402 let config = self.config.clone();
403 let tracker = self.tracker.clone();
404 let metadata_store = self.metadata_store.clone();
405
406 let handle = tokio::spawn(async move {
407 refresh_task(config, tracker, metadata_store, callback, shutdown_rx).await;
408 });
409
410 {
412 let mut handle_guard = self.task_handle.write().await;
413 *handle_guard = Some(handle);
414 }
415
416 #[cfg(feature = "tracing")]
417 info!("Auto-refresh background task started");
418
419 Ok(())
420 }
421
422 pub async fn shutdown(&self) {
424 {
426 let mut tx_guard = self.shutdown_tx.write().await;
427 if let Some(tx) = tx_guard.take() {
428 let _ = tx.send(());
429 }
430 }
431
432 {
434 let mut handle_guard = self.task_handle.write().await;
435 if let Some(handle) = handle_guard.take() {
436 let _ = handle.await;
437 }
438 }
439
440 #[cfg(feature = "tracing")]
441 info!("Auto-refresh background task shutdown complete");
442 }
443}
444
445impl Drop for RefreshManager {
446 fn drop(&mut self) {
447 if let Ok(mut tx_guard) = self.shutdown_tx.try_write() {
451 if let Some(tx) = tx_guard.take() {
452 let _ = tx.send(());
453 }
454 }
455 }
456}
457
458#[cfg_attr(feature = "tracing", instrument(skip_all, name = "auto_refresh_task"))]
460async fn refresh_task<C>(
461 config: Arc<AutoRefreshConfig>,
462 tracker: AccessTracker,
463 metadata_store: Arc<DashMap<String, RefreshMetadata>>,
464 callback: Arc<C>,
465 mut shutdown_rx: tokio::sync::oneshot::Receiver<()>,
466) where
467 C: RefreshCallback + 'static,
468{
469 let mut check_interval = tokio::time::interval(config.check_interval);
470 let mut cleanup_interval = tokio::time::interval(config.cleanup_interval);
471
472 check_interval.tick().await;
474 cleanup_interval.tick().await;
475
476 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_refreshes));
477
478 #[cfg(feature = "tracing")]
479 debug!(
480 max_concurrent = config.max_concurrent_refreshes,
481 check_interval_ms = config.check_interval.as_millis(),
482 "Auto-refresh task loop started"
483 );
484
485 loop {
486 tokio::select! {
487 _ = check_interval.tick() => {
488 let candidates = find_refresh_candidates(&tracker, &metadata_store);
490
491 #[cfg(feature = "tracing")]
492 debug!(candidates = candidates.len(), "Found refresh candidates");
493
494 for (key, metadata) in candidates {
495 let permit = match semaphore.clone().try_acquire_owned() {
496 Ok(permit) => permit,
497 Err(_) => {
498 #[cfg(feature = "metrics")]
499 counter!("tower_http_cache.auto_refresh.skipped").increment(1);
500
501 #[cfg(feature = "tracing")]
502 debug!(key = %key, "Skipped refresh due to concurrency limit");
503 continue;
504 }
505 };
506
507 let callback = callback.clone();
508 let key_clone = key.clone();
509
510 tokio::spawn(async move {
511 let _permit = permit; #[cfg(feature = "metrics")]
514 {
515 counter!("tower_http_cache.auto_refresh.triggered").increment(1);
516 let start = std::time::Instant::now();
517
518 match callback.refresh(key_clone.clone(), metadata).await {
519 Ok(()) => {
520 counter!("tower_http_cache.auto_refresh.success").increment(1);
521 histogram!("tower_http_cache.auto_refresh.latency")
522 .record(start.elapsed().as_secs_f64());
523
524 #[cfg(feature = "tracing")]
525 debug!(key = %key_clone, latency_ms = start.elapsed().as_millis(), "Refresh succeeded");
526 }
527 Err(err) => {
528 counter!("tower_http_cache.auto_refresh.error").increment(1);
529
530 #[cfg(feature = "tracing")]
531 error!(key = %key_clone, error = %err, "Refresh failed");
532 }
533 }
534 }
535
536 #[cfg(not(feature = "metrics"))]
537 {
538 let result = callback.refresh(key_clone.clone(), metadata).await;
539
540 #[cfg(feature = "tracing")]
541 match result {
542 Ok(()) => debug!(key = %key_clone, "Refresh succeeded"),
543 Err(err) => error!(key = %key_clone, error = %err, "Refresh failed"),
544 }
545
546 #[cfg(not(feature = "tracing"))]
547 let _ = result;
548 }
549 });
550 }
551 }
552 _ = cleanup_interval.tick() => {
553 let max_age = config.hit_rate_window * 2;
555 tracker.cleanup_stale(max_age);
556
557 #[cfg(feature = "tracing")]
558 debug!(tracked_keys = tracker.tracked_keys(), "Cleaned up stale tracking data");
559 }
560 _ = &mut shutdown_rx => {
561 #[cfg(feature = "tracing")]
562 info!("Received shutdown signal, stopping auto-refresh task");
563 break;
564 }
565 }
566 }
567}
568
569fn find_refresh_candidates(
571 tracker: &AccessTracker,
572 metadata_store: &DashMap<String, RefreshMetadata>,
573) -> Vec<(String, RefreshMetadata)> {
574 let mut candidates = Vec::new();
575
576 for entry in metadata_store.iter() {
577 let key = entry.key();
578 if tracker.should_auto_refresh(key) {
579 candidates.push((key.clone(), entry.value().clone()));
580 }
581 }
582
583 candidates
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    /// Default config is valid; each out-of-range field is rejected.
    #[test]
    fn auto_refresh_config_validation() {
        let valid = AutoRefreshConfig::default();
        assert!(valid.validate().is_ok());

        let invalid_hits = AutoRefreshConfig {
            min_hits_per_minute: -1.0,
            ..Default::default()
        };
        assert!(invalid_hits.validate().is_err());

        let invalid_concurrent = AutoRefreshConfig {
            max_concurrent_refreshes: 0,
            ..Default::default()
        };
        assert!(invalid_concurrent.validate().is_err());

        let zero_check = AutoRefreshConfig {
            check_interval: Duration::from_millis(0),
            ..Default::default()
        };
        assert!(zero_check.validate().is_err());

        let zero_cleanup = AutoRefreshConfig {
            cleanup_interval: Duration::from_millis(0),
            ..Default::default()
        };
        assert!(zero_cleanup.validate().is_err());

        let zero_window = AutoRefreshConfig {
            hit_rate_window: Duration::from_millis(0),
            ..Default::default()
        };
        assert!(zero_window.validate().is_err());
    }

    /// Hits are counted; the rate is never negative.
    #[test]
    fn access_stats_tracks_hits() {
        let stats = AccessStats::new();
        let window_ms = 100;

        stats.record_hit(window_ms);
        stats.record_hit(window_ms);
        stats.record_hit(window_ms);

        assert_eq!(stats.total_hits(), 3);

        // The exact rate depends on how much wall-clock time has passed, so
        // only its sign is asserted.
        let rate = stats.hits_per_minute(window_ms);
        assert!(
            rate >= 0.0,
            "Hit rate should be non-negative, got: {}",
            rate
        );
    }

    /// Hits are tracked per key and unknown keys report nothing.
    #[test]
    fn access_tracker_records_and_queries() {
        let config = AutoRefreshConfig {
            min_hits_per_minute: 5.0,
            hit_rate_window: Duration::from_secs(60),
            ..Default::default()
        };
        let tracker = AccessTracker::new(config);

        tracker.record_hit("key1");
        tracker.record_hit("key1");
        tracker.record_hit("key2");

        assert_eq!(tracker.tracked_keys(), 2);

        let (hits, _rate) = tracker.get_stats("key1").expect("key1 should exist");
        assert_eq!(hits, 2);

        let (hits, _rate) = tracker.get_stats("key2").expect("key2 should exist");
        assert_eq!(hits, 1);

        assert!(tracker.get_stats("missing").is_none());
        assert_eq!(tracker.hits_per_minute("missing"), 0.0);
    }

    /// Method and URI survive a capture/rebuild round trip.
    #[test]
    fn refresh_metadata_roundtrip() {
        let req = Request::builder()
            .method(Method::GET)
            .uri("https://example.com/test")
            .body(())
            .unwrap();

        let metadata = RefreshMetadata::from_request(&req);
        assert!(metadata.headers.is_empty());

        let reconstructed = metadata
            .try_into_request()
            .expect("metadata from a valid request should rebuild");
        assert_eq!(reconstructed.method(), Method::GET);
        assert_eq!(reconstructed.uri().path(), "/test");
    }

    /// Only the requested headers are captured, matched case-insensitively.
    #[test]
    fn refresh_metadata_with_headers() {
        let req = Request::builder()
            .method(Method::GET)
            .uri("https://example.com/test")
            .header("authorization", "Bearer token")
            .header("x-custom", "value")
            .body(())
            .unwrap();

        let metadata =
            RefreshMetadata::from_request_with_headers(&req, &["authorization".to_string()]);

        assert_eq!(metadata.headers.len(), 1);
        assert_eq!(metadata.headers[0].0, "authorization");
        assert_eq!(metadata.headers[0].1, b"Bearer token".to_vec());

        let mixed_case =
            RefreshMetadata::from_request_with_headers(&req, &["Authorization".to_string()]);
        assert_eq!(mixed_case.headers.len(), 1);
        assert_eq!(mixed_case.headers[0].0, "authorization");
    }

    /// Start, double-start rejection, an actual refresh, and shutdown.
    #[tokio::test]
    async fn refresh_manager_lifecycle() {
        struct TestCallback {
            call_count: Arc<AtomicUsize>,
        }

        impl RefreshCallback for TestCallback {
            fn refresh(&self, _key: String, _metadata: RefreshMetadata) -> RefreshFuture {
                let count = self.call_count.clone();
                Box::pin(async move {
                    count.fetch_add(1, AtomicOrdering::Relaxed);
                    Ok(())
                })
            }
        }

        let config = AutoRefreshConfig {
            enabled: true,
            check_interval: Duration::from_millis(50),
            cleanup_interval: Duration::from_secs(10),
            ..Default::default()
        };

        let manager = RefreshManager::new(config);
        let call_count = Arc::new(AtomicUsize::new(0));
        let callback = Arc::new(TestCallback {
            call_count: call_count.clone(),
        });

        // A key with hits and stored metadata is a refresh candidate on the
        // first scan.
        for _ in 0..5 {
            manager.tracker().record_hit("hot");
        }
        manager.store_metadata(
            "hot".to_string(),
            RefreshMetadata {
                method: Method::GET,
                uri: Uri::from_static("http://example.com/hot"),
                headers: Vec::new(),
            },
        );
        assert!(manager.get_metadata("hot").is_some());

        assert!(manager.start(callback.clone()).await.is_ok());
        assert!(
            manager.start(callback).await.is_err(),
            "starting twice must be rejected"
        );

        // Several check intervals, so at least one scan has run.
        tokio::time::sleep(Duration::from_millis(300)).await;

        manager.shutdown().await;

        assert!(manager.task_handle.read().await.is_none());
        assert!(
            call_count.load(AtomicOrdering::Relaxed) >= 1,
            "the hot key should have been refreshed at least once"
        );
    }

    /// Fresh entries survive cleanup; entries older than `max_age` are removed.
    #[test]
    fn access_tracker_cleanup_removes_stale() {
        let config = AutoRefreshConfig {
            hit_rate_window: Duration::from_secs(60),
            ..Default::default()
        };
        let tracker = AccessTracker::new(config);

        tracker.record_hit("key1");
        tracker.record_hit("key2");

        assert_eq!(tracker.tracked_keys(), 2);

        // Nothing is an hour old yet.
        tracker.cleanup_stale(Duration::from_secs(3600));
        assert_eq!(tracker.tracked_keys(), 2);

        // Let the entries age past a tiny threshold, then purge them.
        std::thread::sleep(Duration::from_millis(5));
        tracker.cleanup_stale(Duration::from_millis(1));
        assert_eq!(tracker.tracked_keys(), 0);
    }

    /// Only keys with a qualifying hit rate are returned as candidates.
    #[test]
    fn find_refresh_candidates_filters_by_rate() {
        let config = AutoRefreshConfig {
            min_hits_per_minute: 0.1,
            hit_rate_window: Duration::from_secs(60),
            ..Default::default()
        };
        let tracker = AccessTracker::new(config);
        let metadata_store = DashMap::new();

        let metadata = RefreshMetadata {
            method: Method::GET,
            uri: Uri::from_static("http://example.com"),
            headers: Vec::new(),
        };

        // key1 is hot; key2 has metadata but no recorded hits.
        for _ in 0..10 {
            tracker.record_hit("key1");
        }
        metadata_store.insert("key1".to_string(), metadata.clone());
        metadata_store.insert("key2".to_string(), metadata.clone());

        // The rate is reported as zero while no time has elapsed in the
        // window, so let a few milliseconds pass before querying.
        std::thread::sleep(Duration::from_millis(5));

        let candidates = find_refresh_candidates(&tracker, &metadata_store);

        assert_eq!(
            candidates.len(),
            1,
            "Expected exactly 1 candidate, got: {}",
            candidates.len()
        );
        assert_eq!(candidates[0].0, "key1");
    }
}