// sqlx_cache/cache_manager.rs
use std::collections::hash_map::Entry;
use std::collections::{BinaryHeap, HashMap};
use std::sync::mpsc::RecvTimeoutError;
use std::sync::{Arc, mpsc};
use std::thread;
use std::time::Duration;

use crate::cache_manager_config::CacheManagerConfig;
use crate::cache_task::CacheTask;
use crate::db_cache::CacheInvalidator;

10pub struct CacheManager {
11 invalidators: HashMap<&'static str, Arc<dyn CacheInvalidator>>,
12 priority_heap: BinaryHeap<CacheTask>,
13 config: CacheManagerConfig,
14 rx: mpsc::Receiver<CacheTask>,
15 tx: mpsc::Sender<CacheTask>,
16}
17
18impl CacheManager {
19 pub fn new(config: CacheManagerConfig) -> Self {
20 let (tx, rx) = mpsc::channel::<CacheTask>();
21
22 Self { invalidators: HashMap::default(), priority_heap: Default::default(), config, rx, tx }
23 }
24 pub fn register<T>(&mut self, invalidator: Arc<T>)
25 where
26 T: CacheInvalidator + 'static,
27 {
28 if self.invalidators.contains_key(invalidator.cache_id()) {
29 panic!("#{} cache currently registered!", invalidator.cache_id());
30 }
31 self.invalidators.insert(invalidator.cache_id(), invalidator);
32 }
33
34 pub fn sender(&self) -> mpsc::Sender<CacheTask> {
35 self.tx.clone()
36 }
37
38 pub fn start(mut self) {
39 thread::spawn(move || {
40 let max_pending_ms_await = Duration::from_millis(self.config.max_pending_ms_await());
41 let max_pending_bulk_ms_await = Duration::from_millis(self.config.max_pending_bulk_ms_await());
42 let mut tasks_pushed = 0;
43
44 loop {
45 tasks_pushed = 0;
46 if let Ok(task) = self.rx.recv_timeout(max_pending_ms_await) {
47 self.priority_heap.push(task);
48 tasks_pushed += 1;
49
50 while tasks_pushed < self.config.max_task_drain_size() {
51 match self.rx.recv_timeout(max_pending_bulk_ms_await) {
52 Ok(task) => {
53 self.priority_heap.push(task);
54 tasks_pushed += 1;
55 }
56 Err(_) => break,
57 }
58 }
59 }
60
61 loop {
62 match self.priority_heap.peek() {
63 Some(val) if !val.is_expired() => break,
64 Some(_) => {
65 if let Some(CacheTask::INVALIDATION { cache_id, key, .. }) = self.priority_heap.pop() {
66 self.invalidators.get(cache_id)
67 .expect("Invalidator found")
68 .invalidate(key);
69 }
70 }
71 None => break,
72 }
73 }
74 }
75 });
76 }
77}
78
79
80
81
82
83