1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
24
/// Absolute expiry time for an entry living `ttl` from now, in whole Unix seconds.
///
/// Never panics:
/// - if `now + ttl` is not representable as a `SystemTime` (e.g. `Duration::MAX`), the
///   result saturates to `u64::MAX`, i.e. "never expires";
/// - if `now + ttl` still lies before the Unix epoch (clock set far in the past), the
///   result is 0, i.e. "already expired".
fn expired_time(ttl: Duration) -> u64 {
    match SystemTime::now().checked_add(ttl) {
        None => u64::MAX,
        Some(deadline) => deadline
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs()),
    }
}
33
34pub trait ISledExt {
35 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36 where
37 K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42 where
43 K: AsRef<[u8]> + Sync + Send,
44 {
45 let expire_at = expired_time(ttl).to_be_bytes();
46 self.insert(key, expire_at.as_slice())?;
47 Ok(true)
48 }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53 pub path: String,
54 pub cache_capacity: u64,
55 pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that holds the bincode-encoded values.
const KV_TREE: &[u8] = "__kv_tree@".as_bytes();
/// Name of the sled tree that holds per-key expiry deadlines (opened only with the `ttl`
/// feature).
///
/// NOTE: the "tll" spelling is part of the on-disk tree name. Correcting it would orphan the
/// deadlines already stored in existing databases, so it must stay as is.
const _TTL_TREE: &[u8] = "__tll_tree@".as_bytes();
60
61pub struct KvDb {
62 pub db: Db,
63 pub kv_tree: sled::Tree,
64 #[cfg(feature = "ttl")]
65 pub ttl_tree: sled::Tree,
66}
67
/// Spawns a background tokio task that periodically purges expired entries from `db`.
///
/// Every `interval` (default 3 s) the task runs `KvDb::cleanup` in batches of at most `limit`
/// entries (default 200). While a batch removes `limit` or more entries, more expired entries
/// may remain, so the task pauses 300 ms and runs another batch before sleeping again.
/// A `limit` of 0 is treated as 1; otherwise the batch loop could never terminate.
///
/// `cleanup` performs synchronous sled I/O, so each batch runs on tokio's blocking pool
/// instead of stalling an async worker thread.
///
/// Must be called from within a tokio runtime, because it uses `tokio::spawn`.
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let period = interval.unwrap_or(Duration::from_secs(3));
    let limit = limit.unwrap_or(200).max(1);

    tokio::spawn(async move {
        loop {
            tokio::time::sleep(period).await;
            loop {
                let started = std::time::Instant::now();
                let batch_db = Arc::clone(&db);
                let joined = tokio::task::spawn_blocking(move || batch_db.cleanup(limit)).await;
                let count = match joined {
                    Ok(count) => count,
                    Err(e) => {
                        // The batch panicked or was cancelled: log it and retry next period.
                        log::error!("cleanup task failed: {:?}", e);
                        break;
                    }
                };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
95
/// Spawns a tokio task that calls `_evt` with the (lossily UTF-8 decoded) key of every entry
/// removed from `db.ttl_tree`.
///
/// Removals come from TTL cleanup and from `KvDb::remove` alike, so the callback is not
/// limited to expirations. Insert events are ignored.
///
/// The subscriber is awaited through its `Future` impl rather than iterated. Blocking
/// iteration would park a tokio worker thread for the lifetime of the subscription.
/// `_evt` itself runs on the async task, so it should not block.
///
/// Must be called from within a tokio runtime, because it uses `tokio::spawn`.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    let on_remove = _evt;
    tokio::spawn(async move {
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                on_remove(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
113
114impl KvDb {
115 pub fn new(cfg: KvDbConfig) -> Result<Self> {
116 let c = Config::default()
117 .path(cfg.path)
118 .cache_capacity(cfg.cache_capacity)
119 .flush_every_ms(Some(cfg.flush_every_ms))
120 .mode(sled::Mode::LowSpace);
121 let db = c.open()?;
122 let kv_tree = db.open_tree(KV_TREE)?;
123 #[cfg(feature = "ttl")]
124 let ttl_tree = db.open_tree(_TTL_TREE)?;
125
126 Ok(KvDb {
127 db,
128 kv_tree,
129 #[cfg(feature = "ttl")]
130 ttl_tree,
131 })
132 }
133
134 #[cfg(feature = "ttl")]
135 fn cleanup(&self, limit: usize) -> usize {
136 let mut count = 0;
137
138 for item in self.ttl_tree.iter() {
139 if count > limit {
140 break;
141 }
142 let (key, expire_at_iv) = match item {
143 Ok(item) => item,
144 Err(e) => {
145 log::error!("cleanup err: {:?}", e);
146 break;
147 }
148 };
149
150 let expire_at = match expire_at_iv.as_ref().try_into() {
151 Ok(at) => u64::from_be_bytes(at),
152 Err(e) => {
153 log::error!("cleanup err: {:?}", e);
154 break;
155 }
156 };
157
158 if expire_at > _now() {
159 break;
160 }
161
162 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
163 kv.remove(key.clone())?;
164 exp.remove(key.clone())?;
165 Ok::<_, ConflictableTransactionError<()>>(())
166 }) {
167 log::error!("cleanup err: {:?}", e);
168 } else {
169 count += 1;
170 }
171 }
172 count
173 }
174
175 #[cfg(feature = "ttl")]
176 pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
177 where
178 K: AsRef<[u8]> + Sync + Send,
179 {
180 let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
181 Ok(Some(at_bytes)) => at_bytes,
182 Ok(None) => return None,
183 Err(e) => {
184 log::error!("get_ttl_at err: {:?}", e);
185 return None;
186 }
187 };
188
189 let expire_at = match expire_at_iv.as_ref().try_into() {
190 Ok(at) => u64::from_be_bytes(at),
191 Err(e) => {
192 log::error!("get_ttl_at err: {:?}", e);
193 return None;
194 }
195 };
196
197 Some(expire_at)
198 }
199
200 #[cfg(feature = "ttl")]
201 pub fn is_expired<K>(&self, key: K) -> Option<bool>
202 where
203 K: AsRef<[u8]> + Sync + Send,
204 {
205 let expire_at = self.get_ttl_at(key);
206
207 let Some(expire_at) = expire_at else {
208 return None;
209 };
210
211 if _now() > expire_at {
212 return Some(true);
213 }
214
215 Some(false)
216 }
217
218 #[cfg(feature = "ttl")]
219 pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
220 where
221 K: AsRef<[u8]>,
222 V: Serialize + Encode + Sync + Send,
223 {
224 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
225 let expire_at = expired_time(ttl).to_be_bytes();
226
227 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
228 kv.insert(key.as_ref(), v.clone())?;
229 ttl.insert(key.as_ref(), expire_at.as_slice())?;
230 Ok::<_, ConflictableTransactionError<()>>(())
231 }) {
232 return Err(anyhow!("insert_ttl err: {:?}", e));
233 }
234 Ok(())
235 }
236
237 pub fn insert_or_update<K, V>(&self, key: K, value: V) -> Result<()>
238 where
239 K: AsRef<[u8]>,
240 V: Serialize + Encode + Sync + Send,
241 {
242 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
243 self.kv_tree.insert(key, v)?;
244 Ok(())
245 }
246
247 pub fn contains_key<K>(&self, key: K) -> bool
248 where
249 K: AsRef<[u8]> + Sync + Send,
250 {
251 #[cfg(feature = "ttl")]
252 {
253 let exp_v = self.is_expired(&key);
254
255 if let Some(v) = exp_v
257 && v
258 {
259 return false;
260 }
261 }
262
263 self.kv_tree.contains_key(key).ok().unwrap_or(false)
264 }
265
266 pub fn get<K, V>(&self, key: K) -> Option<V>
267 where
268 K: AsRef<[u8]>,
269 V: DeserializeOwned + Decode<()> + Sync + Send,
270 {
271 let val = match self.kv_tree.get(key) {
272 Ok(v) => v,
273 Err(e) => {
274 log::error!("kvdb get err: {}", e);
275 return None;
276 }
277 };
278
279 if let Some(v) = val {
280 let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
281 if let Ok((v, _)) = b {
282 return Some(v);
283 }
284 if let Err(e) = b {
285 log::error!("kvdb deserialize error: {}", e.to_string());
286 }
287 return None;
288 }
289
290 None
291 }
292
293 pub fn remove<K>(&self, key: K) -> Result<()>
294 where
295 K: AsRef<[u8]>,
296 {
297 let key_ref = key.as_ref();
298 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
299 kv.remove(key_ref)?;
300 ttl.remove(key_ref)?;
301 Ok::<_, ConflictableTransactionError<()>>(())
302 }) {
303 return Err(anyhow!("remove key err: {:?}", e));
304 }
305 Ok(())
306 }
307
308 pub fn clean(&self) -> Result<()> {
309 self.db.clear()?;
310 self.kv_tree.clear()?;
311 self.ttl_tree.clear()?;
312 Ok(())
313 }
314}