sqlx_cache/
cache_manager.rs

1use std::collections::{BinaryHeap, HashMap};
2use std::sync::{Arc, mpsc};
3use std::thread;
4use std::time::Duration;
5
6use crate::cache_manager_config::CacheManagerConfig;
7use crate::cache_task::CacheTask;
8use crate::db_cache::CacheInvalidator;
9
10pub struct CacheManager {
11    invalidators: HashMap<&'static str, Arc<dyn CacheInvalidator>>,
12    priority_heap: BinaryHeap<CacheTask>,
13    config: CacheManagerConfig,
14    rx: mpsc::Receiver<CacheTask>,
15    tx: mpsc::Sender<CacheTask>,
16}
17
18impl CacheManager {
19    pub fn new(config: CacheManagerConfig) -> Self {
20        let (tx, rx) = mpsc::channel::<CacheTask>();
21
22        Self { invalidators: HashMap::default(), priority_heap: Default::default(), config, rx, tx }
23    }
24    pub fn register<T>(&mut self, invalidator: Arc<T>)
25    where
26        T: CacheInvalidator + 'static,
27    {
28        if self.invalidators.contains_key(invalidator.cache_id()) {
29            panic!("#{} cache currently registered!", invalidator.cache_id());
30        }
31        self.invalidators.insert(invalidator.cache_id(), invalidator);
32    }
33
34    pub fn sender(&self) -> mpsc::Sender<CacheTask> {
35        self.tx.clone()
36    }
37
38    pub fn start(mut self) {
39        thread::spawn(move || {
40            let max_pending_ms_await = Duration::from_millis(self.config.max_pending_ms_await());
41            let max_pending_bulk_ms_await = Duration::from_millis(self.config.max_pending_bulk_ms_await());
42            let mut tasks_pushed = 0;
43
44            loop {
45                tasks_pushed = 0;
46                if let Ok(task) = self.rx.recv_timeout(max_pending_ms_await) {
47                    self.priority_heap.push(task);
48                    tasks_pushed += 1;
49
50                    while tasks_pushed < self.config.max_task_drain_size() {
51                        match self.rx.recv_timeout(max_pending_bulk_ms_await) {
52                            Ok(task) => {
53                                self.priority_heap.push(task);
54                                tasks_pushed += 1;
55                            }
56                            Err(_) => break,
57                        }
58                    }
59                }
60
61                loop {
62                    match self.priority_heap.peek() {
63                        Some(val) if !val.is_expired() => break,
64                        Some(_) => {
65                            if let Some(CacheTask::INVALIDATION { cache_id, key, .. }) = self.priority_heap.pop() {
66                                self.invalidators.get(cache_id)
67                                    .expect("Invalidator found")
68                                    .invalidate(key);
69                            }
70                        }
71                        None => break,
72                    }
73                }
74            }
75        });
76    }
77}
78
79
80
81
82
83