1#[cfg(test)]
2mod lib_test;
3
4use data_encoding::BASE32_NOPAD;
5use futures::channel::oneshot::{channel, Receiver, Sender};
6pub use rusqlite;
7
8use std::{
9 collections::HashMap,
10 sync::{mpsc, Arc, Mutex, Weak},
11 time::{Duration, SystemTime, UNIX_EPOCH},
12};
13
14use rusqlite::{Connection, OptionalExtension};
15
16#[derive(Clone)]
17pub struct Cache {
18 inner: Arc<CacheImpl>,
19}
20
/// Tunables for a cache and its background maintenance thread.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// How long the background thread waits for a stop signal before each flush of
    /// pending expiry updates.
    pub flush_interval: Duration,
    /// Number of flushes between garbage-collection passes. `Cache::new` asserts that
    /// this is greater than zero.
    pub flush_gc_ratio: u64,
    /// Optional upper bound applied to the TTL passed to `Topic::set`; `None` leaves
    /// TTLs uncapped.
    pub max_ttl: Option<Duration>,
}
27
28impl Default for CacheConfig {
29 fn default() -> Self {
30 CacheConfig {
31 flush_interval: Duration::from_secs(10),
32 flush_gc_ratio: 30,
33 max_ttl: None,
34 }
35 }
36}
37
38#[derive(Clone)]
39pub struct Topic {
40 inner: Arc<TopicImpl>,
41}
42
43struct CacheImpl {
44 config: CacheConfig,
45 conn: Mutex<Connection>,
46 lazy_expiry_update: Mutex<HashMap<(Arc<str>, String), u64>>,
47 stop_tx: Mutex<mpsc::Sender<()>>,
48 completion_rx: Mutex<mpsc::Receiver<()>>,
49}
50
51struct TopicImpl {
52 cache: Cache,
53 table_name: Arc<str>,
54 listeners: Mutex<HashMap<String, Vec<Sender<()>>>>,
55}
56
57impl Drop for CacheImpl {
58 fn drop(&mut self) {
59 self.stop_tx.lock().unwrap().send(()).unwrap();
60 self.completion_rx.lock().unwrap().recv().unwrap();
61 }
62}
63
64impl Cache {
65 pub fn new(config: CacheConfig, conn: Connection) -> Result<Self, rusqlite::Error> {
66 assert!(config.flush_gc_ratio > 0);
67 let (stop_tx, stop_rx) = mpsc::channel::<()>();
68 let (completion_tx, completion_rx) = mpsc::channel::<()>();
69 conn.execute_batch("pragma journal_mode = wal;")?;
70 let inner = Arc::new(CacheImpl {
71 conn: Mutex::new(conn),
72 config: config.clone(),
73 lazy_expiry_update: Mutex::new(HashMap::new()),
74 stop_tx: Mutex::new(stop_tx),
75 completion_rx: Mutex::new(completion_rx),
76 });
77 let w = Arc::downgrade(&inner);
78 std::thread::spawn(move || periodic_task(config, stop_rx, completion_tx, w));
79 Ok(Self { inner })
80 }
81
82 fn flush(&self) {
83 let lazy_expiry_update = std::mem::take(&mut *self.inner.lazy_expiry_update.lock().unwrap());
84 for ((table_name, key), expiry) in lazy_expiry_update {
85 let res = self.inner.conn.lock().unwrap().execute(
86 &format!("update {} set expiry = ? where k = ?", table_name),
87 rusqlite::params![expiry, key],
88 );
89 if let Err(e) = res {
90 tracing::error!(table = &*table_name, key = key.as_str(), error = %e, "error updating expiry");
91 }
92 }
93 }
94
95 fn gc(&self) -> Result<(), rusqlite::Error> {
96 let now = SystemTime::now()
97 .duration_since(UNIX_EPOCH)
98 .unwrap()
99 .as_secs();
100 let tables = self
101 .inner
102 .conn
103 .lock()
104 .unwrap()
105 .unchecked_transaction()?
106 .prepare("select name from sqlite_master where type = 'table' and name like 'topic_%'")?
107 .query_map(rusqlite::params![], |x| x.get::<_, String>(0))?
108 .collect::<Result<Vec<String>, rusqlite::Error>>()?;
109 let mut total = 0usize;
110 for table in tables {
111 let count = self.inner.conn.lock().unwrap().execute(
112 &format!("delete from {} where expiry < ?", table),
113 rusqlite::params![now],
114 )?;
115 total += count;
116 }
117 if total != 0 {
118 tracing::info!(total = total, "gc deleted rows");
119 }
120 Ok(())
121 }
122
123 pub fn topic(&self, key: &str) -> Result<Topic, rusqlite::Error> {
124 let table_name = format!("topic_{}", BASE32_NOPAD.encode(key.as_bytes()));
125 self.inner.conn.lock().unwrap().execute_batch(&format!(
126 r#"
127begin transaction;
128create table if not exists {} (
129 k text primary key not null,
130 v blob not null,
131 created_at integer not null default (cast(strftime('%s', 'now') as integer)),
132 expiry integer not null,
133 ttl integer not null
134);
135create index if not exists {}_by_expiry on {} (expiry);
136commit;
137"#,
138 table_name, table_name, table_name,
139 ))?;
140 Ok(Topic {
141 inner: Arc::new(TopicImpl {
142 cache: self.clone(),
143 table_name: Arc::from(table_name),
144 listeners: Mutex::new(HashMap::new()),
145 }),
146 })
147 }
148}
149
/// A cached entry as returned by a read.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Value {
    /// The stored bytes (the `v` column).
    pub data: Vec<u8>,
    /// Value of the row's `created_at` column, which the table schema defaults to the
    /// Unix time in seconds at the moment the row is written.
    pub created_at: u64,
}
154
155impl Topic {
156 pub fn get(&self, key: &str) -> Result<Option<Value>, rusqlite::Error> {
157 let conn = self.inner.cache.inner.conn.lock().unwrap();
158 let mut stmt = conn.prepare_cached(&format!(
159 "select v, created_at, ttl from {} where k = ?",
160 self.inner.table_name,
161 ))?;
162 let rsp: Option<(Vec<u8>, u64, u64)> = stmt
163 .query_row(rusqlite::params![key], |x| {
164 Ok((x.get(0)?, x.get(1)?, x.get(2)?))
165 })
166 .optional()?;
167 if let Some((data, created_at, ttl)) = rsp {
168 self.inner
169 .cache
170 .inner
171 .lazy_expiry_update
172 .lock()
173 .unwrap()
174 .insert(
175 (self.inner.table_name.clone(), key.to_string()),
176 SystemTime::now()
177 .duration_since(UNIX_EPOCH)
178 .unwrap()
179 .as_secs()
180 .saturating_add(ttl)
181 .min(i64::MAX as u64),
182 );
183 Ok(Some(Value { data, created_at }))
184 } else {
185 Ok(None)
186 }
187 }
188
189 pub async fn get_for_update(
190 &self,
191 key: &str,
192 ) -> Result<(KeyUpdater, Option<Value>), rusqlite::Error> {
193 loop {
194 let receiver: Option<Receiver<()>>;
195 {
196 let mut listeners = self.inner.listeners.lock().unwrap();
197 if let Some(arr) = listeners.get_mut(key) {
198 let (tx, rx) = channel();
199 arr.push(tx);
200 receiver = Some(rx);
201 } else {
202 receiver = None;
203 listeners.insert(key.to_string(), vec![]);
204 }
205 }
206
207 if let Some(receiver) = receiver {
208 let _ = receiver.await;
209 } else {
210 break;
211 }
212 }
213
214 let data = self.get(key)?;
215 Ok((
216 KeyUpdater {
217 topic: self.clone(),
218 key: key.to_string(),
219 },
220 data,
221 ))
222 }
223
224 pub fn set(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), rusqlite::Error> {
225 let conn = self.inner.cache.inner.conn.lock().unwrap();
226 let mut stmt = conn.prepare_cached(&format!(
227 "replace into {} (k, v, expiry, ttl) values(?, ?, ?, ?)",
228 self.inner.table_name
229 ))?;
230 let mut ttl = ttl.as_secs();
231 if let Some(max_ttl) = self.inner.cache.inner.config.max_ttl {
232 let max_ttl = max_ttl.as_secs();
233 ttl = ttl.min(max_ttl);
234 }
235 ttl = ttl.min(i64::MAX as u64);
236 let expiry = SystemTime::now()
237 .duration_since(UNIX_EPOCH)
238 .unwrap()
239 .as_secs()
240 .saturating_add(ttl)
241 .min(i64::MAX as u64);
242 stmt.execute(rusqlite::params![key, value, expiry, ttl])?;
243 self.inner
244 .cache
245 .inner
246 .lazy_expiry_update
247 .lock()
248 .unwrap()
249 .remove(&(self.inner.table_name.clone(), key.to_string()));
250 Ok(())
251 }
252
253 pub fn delete(&self, key: &str) -> Result<(), rusqlite::Error> {
254 let conn = self.inner.cache.inner.conn.lock().unwrap();
255 let mut stmt = conn.prepare_cached(&format!(
256 "delete from {} where k = ?",
257 self.inner.table_name
258 ))?;
259 stmt.execute(rusqlite::params![key])?;
260 Ok(())
261 }
262}
263
264pub struct KeyUpdater {
265 topic: Topic,
266 key: String,
267}
268
269impl Drop for KeyUpdater {
270 fn drop(&mut self) {
271 let mut listeners = self.topic.inner.listeners.lock().unwrap();
272 listeners.remove(self.key.as_str()).unwrap();
273 }
274}
275
276impl KeyUpdater {
277 pub fn write(self, value: &[u8], ttl: Duration) -> Result<(), rusqlite::Error> {
278 self.topic.set(&self.key, value, ttl)?;
279 Ok(())
280 }
281}
282
283fn periodic_task(
284 config: CacheConfig,
285 stop_rx: mpsc::Receiver<()>,
286 completion_tx: mpsc::Sender<()>,
287 w: Weak<CacheImpl>,
288) {
289 let mut gc_ratio_counter = 0u64;
290 loop {
291 let tx = stop_rx.recv_timeout(config.flush_interval);
292 if tx.is_ok() {
293 break;
294 }
295
296 let inner = if let Some(x) = w.upgrade() {
297 x
298 } else {
299 break;
300 };
301 let cache = Cache { inner };
302 cache.flush();
303 gc_ratio_counter += 1;
304 if gc_ratio_counter == config.flush_gc_ratio {
305 gc_ratio_counter = 0;
306 if let Err(e) = cache.gc() {
307 tracing::error!(error = %e, "gc failed");
308 }
309 }
310 }
311 tracing::info!("exiting periodic task");
312 completion_tx.send(()).unwrap();
313}