sqlx_cache/
db_cache.rs

1use std::any::Any;
2use std::error::Error;
3use std::marker::PhantomData;
4use std::sync::Arc;
5use std::sync::mpsc::{Sender, SendError};
6
7use dashmap::DashMap;
8use sqlx::Pool;
9
10use crate::cache_manager::CacheManager;
11use crate::cache_task::CacheTask;
12use crate::db_cache_config::DbCacheConfig;
13use crate::db_commands::DbCommands;
14use crate::utils::GenericError;
15
16struct CacheEventProcessor<DBC>
17where
18    DBC: DbCommands + 'static,
19{
20    db_cache_config: DbCacheConfig,
21    tx: Sender<CacheTask>,
22    _phantom: PhantomData<DBC>,
23}
24
25impl<DBC> CacheEventProcessor<DBC>
26where
27    DBC: DbCommands + 'static,
28{
29    pub fn new(db_cache_config: DbCacheConfig, tx: Sender<CacheTask>) -> Self {
30        Self { db_cache_config, tx, _phantom: Default::default() }
31    }
32    pub fn invalidate(&self, key: DBC::Key) -> Result<(), SendError<CacheTask>> {
33        let task = CacheTask::invalidation(self.db_cache_config.expires_in(), self.db_cache_config.cache_id(), Box::new(key));
34        self.tx.send(task)
35    }
36}
37
38pub struct DbCache<DBC>
39where
40    DBC: DbCommands + 'static,
41{
42    db_pool: Pool<DBC::Db>,
43    cache_event_processor: CacheEventProcessor<DBC>,
44    db_storage: DashMap<DBC::Key, DBC::Value>,
45    config: DbCacheConfig,
46}
47
48impl<DBC> DbCache<DBC>
49where
50    DBC: DbCommands + 'static,
51{
52    pub fn build(cache_manager: &mut CacheManager, config: DbCacheConfig, db_pool: Pool<DBC::Db>) -> Arc<DbCache<DBC>> {
53        let self_ = Arc::new(Self {
54            db_pool,
55            cache_event_processor: CacheEventProcessor::new(config, cache_manager.sender()),
56            db_storage: DashMap::default(),
57            config,
58        });
59        cache_manager.register(self_.clone());
60        self_
61    }
62    pub async fn get(&self, key: &DBC::Key) -> Option<DBC::Value> {
63        return match self.db_storage.get(key) {
64            None => {
65                println!("cache miss for #{key} key");
66                let val = match DBC::get(&self.db_pool, key).await {
67                    None => {
68                        return None;
69                    }
70                    Some(val) => {
71                        val
72                    }
73                };
74
75
76                self.db_storage.insert(key.clone(), val.clone());
77                if let Err(err) = self.cache_event_processor.invalidate(key.clone()) {
78                    println!("Error sending invalidate cache task caused by: {err}");
79                    self.db_storage.remove(key);
80                }
81
82                Some(val)
83            }
84            Some(val) => {
85                println!("cache hit for #{key} key");
86                Some(val.value().clone())
87            }
88        };
89    }
90
91
92    pub async fn put(&self, key: DBC::Key, value: DBC::Value) -> Result<(), GenericError> {
93        DBC::put(&self.db_pool, key.clone(), value.clone()).await?;
94        self.db_storage.insert(key.clone(), value);
95        if let Err(err) = self.cache_event_processor.invalidate(key.clone()) {
96            println!("Error sending invalidate cache task caused by: {err}");
97            self.db_storage.remove(&key);
98        }
99
100        Ok(())
101    }
102
103    pub fn remove(&self, key: &DBC::Key) {
104        self.db_storage.remove(key);
105    }
106
107
108    pub fn cache(&self) -> &DashMap<DBC::Key, DBC::Value> {
109        &self.db_storage
110    }
111}
112
113
/// A cache that is identified by an id and can be told to drop a single entry.
///
/// The key is passed type-erased (`Box<dyn Any + Send>`), so caches with different key
/// types can share this one interface. Implementors downcast it to their own key type.
/// Implementors must be `Send + Sync`.
pub trait CacheInvalidator: Sync + Send {
    /// Identifier of the cache this invalidator belongs to.
    fn cache_id(&self) -> &'static str;

    /// Drops the entry identified by the type-erased `key`.
    fn invalidate(&self, key: Box<dyn Any + Send>);
}
118
119
120impl<DBC> CacheInvalidator for DbCache<DBC>
121where
122    DBC: DbCommands,
123{
124    fn invalidate(&self, key: Box<dyn Any + Send>) {
125        let val = match key.downcast::<DBC::Key>() {
126            Ok(val) => {
127                val
128            }
129            Err(err) => {
130                println!("Error executing invalidation for #{} cache caused by: {err:?}", self.cache_id());
131                return;
132            }
133        };
134
135        println!("Executing invalidation for #{val} key and #{} cache", self.cache_id());
136        self.remove(&val);
137    }
138
139    fn cache_id(&self) -> &'static str {
140        self.config.cache_id()
141    }
142}
143