// sqlite_cache/lib.rs

#[cfg(test)]
mod lib_test;

use data_encoding::BASE32_NOPAD;
use futures::channel::oneshot::{channel, Receiver, Sender};
pub use rusqlite;

use std::{
    collections::HashMap,
    sync::{mpsc, Arc, Mutex, PoisonError, Weak},
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use rusqlite::{Connection, OptionalExtension};
15
16#[derive(Clone)]
17pub struct Cache {
18    inner: Arc<CacheImpl>,
19}
20
/// Tuning knobs for a [`Cache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// How often the background thread writes out deferred expiry refreshes.
    pub flush_interval: Duration,
    /// Expired rows are deleted once every this many flushes. Must be non-zero.
    pub flush_gc_ratio: u64,
    /// Upper bound applied to every TTL passed to `Topic::set`; `None` means no cap.
    pub max_ttl: Option<Duration>,
}

impl Default for CacheConfig {
    /// Flush every 10 seconds, garbage-collect on every 30th flush, no TTL cap.
    fn default() -> Self {
        let flush_interval = Duration::from_secs(10);
        let flush_gc_ratio = 30;
        Self {
            flush_interval,
            flush_gc_ratio,
            max_ttl: None,
        }
    }
}
37
38#[derive(Clone)]
39pub struct Topic {
40    inner: Arc<TopicImpl>,
41}
42
43struct CacheImpl {
44    config: CacheConfig,
45    conn: Mutex<Connection>,
46    lazy_expiry_update: Mutex<HashMap<(Arc<str>, String), u64>>,
47    stop_tx: Mutex<mpsc::Sender<()>>,
48    completion_rx: Mutex<mpsc::Receiver<()>>,
49}
50
51struct TopicImpl {
52    cache: Cache,
53    table_name: Arc<str>,
54    listeners: Mutex<HashMap<String, Vec<Sender<()>>>>,
55}
56
57impl Drop for CacheImpl {
58    fn drop(&mut self) {
59        self.stop_tx.lock().unwrap().send(()).unwrap();
60        self.completion_rx.lock().unwrap().recv().unwrap();
61    }
62}
63
64impl Cache {
65    pub fn new(config: CacheConfig, conn: Connection) -> Result<Self, rusqlite::Error> {
66        assert!(config.flush_gc_ratio > 0);
67        let (stop_tx, stop_rx) = mpsc::channel::<()>();
68        let (completion_tx, completion_rx) = mpsc::channel::<()>();
69        conn.execute_batch("pragma journal_mode = wal;")?;
70        let inner = Arc::new(CacheImpl {
71            conn: Mutex::new(conn),
72            config: config.clone(),
73            lazy_expiry_update: Mutex::new(HashMap::new()),
74            stop_tx: Mutex::new(stop_tx),
75            completion_rx: Mutex::new(completion_rx),
76        });
77        let w = Arc::downgrade(&inner);
78        std::thread::spawn(move || periodic_task(config, stop_rx, completion_tx, w));
79        Ok(Self { inner })
80    }
81
82    fn flush(&self) {
83        let lazy_expiry_update = std::mem::take(&mut *self.inner.lazy_expiry_update.lock().unwrap());
84        for ((table_name, key), expiry) in lazy_expiry_update {
85            let res = self.inner.conn.lock().unwrap().execute(
86                &format!("update {} set expiry = ? where k = ?", table_name),
87                rusqlite::params![expiry, key],
88            );
89            if let Err(e) = res {
90                tracing::error!(table = &*table_name, key = key.as_str(), error = %e, "error updating expiry");
91            }
92        }
93    }
94
95    fn gc(&self) -> Result<(), rusqlite::Error> {
96        let now = SystemTime::now()
97            .duration_since(UNIX_EPOCH)
98            .unwrap()
99            .as_secs();
100        let tables = self
101            .inner
102            .conn
103            .lock()
104            .unwrap()
105            .unchecked_transaction()?
106            .prepare("select name from sqlite_master where type = 'table' and name like 'topic_%'")?
107            .query_map(rusqlite::params![], |x| x.get::<_, String>(0))?
108            .collect::<Result<Vec<String>, rusqlite::Error>>()?;
109        let mut total = 0usize;
110        for table in tables {
111            let count = self.inner.conn.lock().unwrap().execute(
112                &format!("delete from {} where expiry < ?", table),
113                rusqlite::params![now],
114            )?;
115            total += count;
116        }
117        if total != 0 {
118            tracing::info!(total = total, "gc deleted rows");
119        }
120        Ok(())
121    }
122
123    pub fn topic(&self, key: &str) -> Result<Topic, rusqlite::Error> {
124        let table_name = format!("topic_{}", BASE32_NOPAD.encode(key.as_bytes()));
125        self.inner.conn.lock().unwrap().execute_batch(&format!(
126            r#"
127begin transaction;
128create table if not exists {} (
129    k text primary key not null,
130    v blob not null,
131    created_at integer not null default (cast(strftime('%s', 'now') as integer)),
132    expiry integer not null,
133    ttl integer not null
134);
135create index if not exists {}_by_expiry on {} (expiry);
136commit;
137"#,
138            table_name, table_name, table_name,
139        ))?;
140        Ok(Topic {
141            inner: Arc::new(TopicImpl {
142                cache: self.clone(),
143                table_name: Arc::from(table_name),
144                listeners: Mutex::new(HashMap::new()),
145            }),
146        })
147    }
148}
149
/// A cached value as returned by `Topic::get`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Value {
    /// The stored bytes.
    pub data: Vec<u8>,
    /// Unix timestamp (seconds) taken from the row's `created_at` column, i.e. when
    /// the row was last written.
    pub created_at: u64,
}
154
155impl Topic {
156    pub fn get(&self, key: &str) -> Result<Option<Value>, rusqlite::Error> {
157        let conn = self.inner.cache.inner.conn.lock().unwrap();
158        let mut stmt = conn.prepare_cached(&format!(
159            "select v, created_at, ttl from {} where k = ?",
160            self.inner.table_name,
161        ))?;
162        let rsp: Option<(Vec<u8>, u64, u64)> = stmt
163            .query_row(rusqlite::params![key], |x| {
164                Ok((x.get(0)?, x.get(1)?, x.get(2)?))
165            })
166            .optional()?;
167        if let Some((data, created_at, ttl)) = rsp {
168            self.inner
169                .cache
170                .inner
171                .lazy_expiry_update
172                .lock()
173                .unwrap()
174                .insert(
175                    (self.inner.table_name.clone(), key.to_string()),
176                    SystemTime::now()
177                        .duration_since(UNIX_EPOCH)
178                        .unwrap()
179                        .as_secs()
180                        .saturating_add(ttl)
181                        .min(i64::MAX as u64),
182                );
183            Ok(Some(Value { data, created_at }))
184        } else {
185            Ok(None)
186        }
187    }
188
189    pub async fn get_for_update(
190        &self,
191        key: &str,
192    ) -> Result<(KeyUpdater, Option<Value>), rusqlite::Error> {
193        loop {
194            let receiver: Option<Receiver<()>>;
195            {
196                let mut listeners = self.inner.listeners.lock().unwrap();
197                if let Some(arr) = listeners.get_mut(key) {
198                    let (tx, rx) = channel();
199                    arr.push(tx);
200                    receiver = Some(rx);
201                } else {
202                    receiver = None;
203                    listeners.insert(key.to_string(), vec![]);
204                }
205            }
206
207            if let Some(receiver) = receiver {
208                let _ = receiver.await;
209            } else {
210                break;
211            }
212        }
213
214        let data = self.get(key)?;
215        Ok((
216            KeyUpdater {
217                topic: self.clone(),
218                key: key.to_string(),
219            },
220            data,
221        ))
222    }
223
224    pub fn set(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), rusqlite::Error> {
225        let conn = self.inner.cache.inner.conn.lock().unwrap();
226        let mut stmt = conn.prepare_cached(&format!(
227            "replace into {} (k, v, expiry, ttl) values(?, ?, ?, ?)",
228            self.inner.table_name
229        ))?;
230        let mut ttl = ttl.as_secs();
231        if let Some(max_ttl) = self.inner.cache.inner.config.max_ttl {
232            let max_ttl = max_ttl.as_secs();
233            ttl = ttl.min(max_ttl);
234        }
235        ttl = ttl.min(i64::MAX as u64);
236        let expiry = SystemTime::now()
237            .duration_since(UNIX_EPOCH)
238            .unwrap()
239            .as_secs()
240            .saturating_add(ttl)
241            .min(i64::MAX as u64);
242        stmt.execute(rusqlite::params![key, value, expiry, ttl])?;
243        self.inner
244            .cache
245            .inner
246            .lazy_expiry_update
247            .lock()
248            .unwrap()
249            .remove(&(self.inner.table_name.clone(), key.to_string()));
250        Ok(())
251    }
252
253    pub fn delete(&self, key: &str) -> Result<(), rusqlite::Error> {
254        let conn = self.inner.cache.inner.conn.lock().unwrap();
255        let mut stmt = conn.prepare_cached(&format!(
256            "delete from {} where k = ?",
257            self.inner.table_name
258        ))?;
259        stmt.execute(rusqlite::params![key])?;
260        Ok(())
261    }
262}
263
264pub struct KeyUpdater {
265    topic: Topic,
266    key: String,
267}
268
269impl Drop for KeyUpdater {
270    fn drop(&mut self) {
271        let mut listeners = self.topic.inner.listeners.lock().unwrap();
272        listeners.remove(self.key.as_str()).unwrap();
273    }
274}
275
276impl KeyUpdater {
277    pub fn write(self, value: &[u8], ttl: Duration) -> Result<(), rusqlite::Error> {
278        self.topic.set(&self.key, value, ttl)?;
279        Ok(())
280    }
281}
282
283fn periodic_task(
284    config: CacheConfig,
285    stop_rx: mpsc::Receiver<()>,
286    completion_tx: mpsc::Sender<()>,
287    w: Weak<CacheImpl>,
288) {
289    let mut gc_ratio_counter = 0u64;
290    loop {
291        let tx = stop_rx.recv_timeout(config.flush_interval);
292        if tx.is_ok() {
293            break;
294        }
295
296        let inner = if let Some(x) = w.upgrade() {
297            x
298        } else {
299            break;
300        };
301        let cache = Cache { inner };
302        cache.flush();
303        gc_ratio_counter += 1;
304        if gc_ratio_counter == config.flush_gc_ratio {
305            gc_ratio_counter = 0;
306            if let Err(e) = cache.gc() {
307                tracing::error!(error = %e, "gc failed");
308            }
309        }
310    }
311    tracing::info!("exiting periodic task");
312    completion_tx.send(()).unwrap();
313}