sled_ext/
lib.rs

1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
24
/// Returns the absolute expiry time ("now" plus `ttl`) as whole seconds since the Unix epoch.
///
/// The addition is done on the duration elapsed since the epoch, so sub-second parts of both
/// operands are summed before truncating to seconds. If the sum does not fit in a `Duration`
/// (for example `ttl == Duration::MAX`), the result saturates at `u64::MAX` instead of
/// panicking.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn expired_time(ttl: Duration) -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .checked_add(ttl)
        .map_or(u64::MAX, |deadline| deadline.as_secs())
}
33
34pub trait ISledExt {
35    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36    where
37        K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42    where
43        K: AsRef<[u8]> + Sync + Send,
44    {
45        let expire_at = expired_time(ttl).to_be_bytes();
46        self.insert(key, expire_at.as_slice())?;
47        Ok(true)
48    }
49}
50
/// Settings consumed by `KvDb::new` when opening the underlying sled database.
#[derive(Serialize, Deserialize)]
pub struct KvDbConfig {
    /// Location of the database; handed to sled's `Config::path`.
    pub path: String,
    /// Handed unchanged to sled's `Config::cache_capacity`.
    pub cache_capacity: u64,
    /// Handed to sled's `Config::flush_every_ms` wrapped in `Some`.
    pub flush_every_ms: u64,
}
57
// Name of the sled tree that holds the encoded values (opened in `KvDb::new`).
const KV_TREE: &[u8] = b"__kv_tree@";
// Name of the sled tree that holds per-key expiry timestamps (opened in `KvDb::new` when the
// `ttl` feature is on). NOTE(review): the literal is spelled "tll"; it is the tree's stored
// name, so presumably it must stay as-is to keep opening the same tree — confirm before renaming.
const _TTL_TREE: &[u8] = b"__tll_tree@";
60
/// Key-value store built on a sled `Db`, with an optional per-key expiry index.
pub struct KvDb {
    /// The sled database handle that the trees below were opened from.
    pub db: Db,
    /// Tree holding the bincode-encoded values, keyed by the caller's key bytes.
    pub kv_tree: sled::Tree,
    /// Tree mapping a key to its expiry time (8 big-endian bytes, seconds since the Unix epoch).
    /// Only present when the `ttl` feature is enabled.
    #[cfg(feature = "ttl")]
    pub ttl_tree: sled::Tree,
}
67
/// Spawns a detached Tokio task that periodically removes expired entries from `db`.
///
/// The task sleeps for `interval` (default: 3 seconds) and then calls `KvDb::cleanup` in
/// batches of `limit` (default: 200). While a batch comes back full it pauses 300 ms and runs
/// another batch; once a batch comes back short it returns to the long sleep. The task never
/// finishes on its own.
///
/// A `limit` of `Some(0)` is treated as 1: with a batch size of zero the "batch came back
/// short" test (`count < limit`) could never succeed and the inner loop would spin forever.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside the context of a Tokio runtime.
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let interval = interval.unwrap_or(Duration::from_secs(3));
    let limit = limit.unwrap_or(200).max(1);

    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            loop {
                let started = std::time::Instant::now();
                let count = db.cleanup(limit);
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                // A short batch means the backlog is drained; wait for the next interval.
                if count < limit {
                    break;
                }
                // A full batch means more may be pending; yield briefly before continuing.
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
95
/// Spawns a detached Tokio task that calls `_evt` for every removal observed on the TTL tree.
///
/// The task subscribes to all keys of `db.ttl_tree` and, for each `Event::Remove` the
/// subscription yields, passes the removed key to `_evt` as a `String` (invalid UTF-8 is
/// replaced lossily). Every removal from the TTL tree is reported, whether it came from the
/// expiry sweep or from an explicit `KvDb::remove`.
///
/// The subscriber is awaited as a future rather than iterated: the iterator form blocks the
/// calling thread between events, which would park a runtime worker for the lifetime of the
/// subscription. `_evt` runs inline on the task, so it should return quickly.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside the context of a Tokio runtime.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::spawn(async move {
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                _evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
113
114impl KvDb {
115    pub fn new(cfg: KvDbConfig) -> Result<Self> {
116        let c = Config::default()
117            .path(cfg.path)
118            .cache_capacity(cfg.cache_capacity)
119            .flush_every_ms(Some(cfg.flush_every_ms))
120            .mode(sled::Mode::LowSpace);
121        let db = c.open()?;
122        let kv_tree = db.open_tree(KV_TREE)?;
123        #[cfg(feature = "ttl")]
124        let ttl_tree = db.open_tree(_TTL_TREE)?;
125
126        Ok(KvDb {
127            db,
128            kv_tree,
129            #[cfg(feature = "ttl")]
130            ttl_tree,
131        })
132    }
133
134    #[cfg(feature = "ttl")]
135    fn cleanup(&self, limit: usize) -> usize {
136        let mut count = 0;
137
138        for item in self.ttl_tree.iter() {
139            if count > limit {
140                break;
141            }
142            let (key, expire_at_iv) = match item {
143                Ok(item) => item,
144                Err(e) => {
145                    log::error!("cleanup err: {:?}", e);
146                    break;
147                }
148            };
149
150            let expire_at = match expire_at_iv.as_ref().try_into() {
151                Ok(at) => u64::from_be_bytes(at),
152                Err(e) => {
153                    log::error!("cleanup err: {:?}", e);
154                    break;
155                }
156            };
157
158            if expire_at > _now() {
159                break;
160            }
161
162            if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
163                kv.remove(key.clone())?;
164                exp.remove(key.clone())?;
165                Ok::<_, ConflictableTransactionError<()>>(())
166            }) {
167                log::error!("cleanup err: {:?}", e);
168            } else {
169                count += 1;
170            }
171        }
172        count
173    }
174
175    #[cfg(feature = "ttl")]
176    pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
177    where
178        K: AsRef<[u8]> + Sync + Send,
179    {
180        let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
181            Ok(Some(at_bytes)) => at_bytes,
182            Ok(None) => return None,
183            Err(e) => {
184                log::error!("get_ttl_at err: {:?}", e);
185                return None;
186            }
187        };
188
189        let expire_at = match expire_at_iv.as_ref().try_into() {
190            Ok(at) => u64::from_be_bytes(at),
191            Err(e) => {
192                log::error!("get_ttl_at err: {:?}", e);
193                return None;
194            }
195        };
196
197        Some(expire_at)
198    }
199
200    #[cfg(feature = "ttl")]
201    pub fn is_expired<K>(&self, key: K) -> Option<bool>
202    where
203        K: AsRef<[u8]> + Sync + Send,
204    {
205        let expire_at = self.get_ttl_at(key);
206
207        let Some(expire_at) = expire_at else {
208            return None;
209        };
210
211        if _now() > expire_at {
212            return Some(true);
213        }
214
215        Some(false)
216    }
217
218    #[cfg(feature = "ttl")]
219    pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
220    where
221        K: AsRef<[u8]>,
222        V: Serialize + Encode + Sync + Send,
223    {
224        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
225        let expire_at = expired_time(ttl).to_be_bytes();
226
227        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
228            kv.insert(key.as_ref(), v.clone())?;
229            ttl.insert(key.as_ref(), expire_at.as_slice())?;
230            Ok::<_, ConflictableTransactionError<()>>(())
231        }) {
232            return Err(anyhow!("insert_ttl err: {:?}", e));
233        }
234        Ok(())
235    }
236
237    pub fn insert_or_update<K, V>(&self, key: K, value: V) -> Result<()>
238    where
239        K: AsRef<[u8]>,
240        V: Serialize + Encode + Sync + Send,
241    {
242        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
243        self.kv_tree.insert(key, v)?;
244        Ok(())
245    }
246
247    pub fn contains_key<K>(&self, key: K) -> bool
248    where
249        K: AsRef<[u8]> + Sync + Send,
250    {
251        #[cfg(feature = "ttl")]
252        {
253            let exp_v = self.is_expired(&key);
254
255            //如果ttl 存在,并已过期 则返回false
256            if let Some(v) = exp_v
257                && v
258            {
259                return false;
260            }
261        }
262
263        self.kv_tree.contains_key(key).ok().unwrap_or(false)
264    }
265
266    pub fn get<K, V>(&self, key: K) -> Option<V>
267    where
268        K: AsRef<[u8]>,
269        V: DeserializeOwned + Decode<()> + Sync + Send,
270    {
271        let val = match self.kv_tree.get(key) {
272            Ok(v) => v,
273            Err(e) => {
274                log::error!("kvdb get err: {}", e);
275                return None;
276            }
277        };
278
279        if let Some(v) = val {
280            let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
281            if let Ok((v, _)) = b {
282                return Some(v);
283            }
284            if let Err(e) = b {
285                log::error!("kvdb deserialize error: {}", e.to_string());
286            }
287            return None;
288        }
289
290        None
291    }
292
293    pub fn remove<K>(&self, key: K) -> Result<()>
294    where
295        K: AsRef<[u8]>,
296    {
297        let key_ref = key.as_ref();
298        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
299            kv.remove(key_ref)?;
300            ttl.remove(key_ref)?;
301            Ok::<_, ConflictableTransactionError<()>>(())
302        }) {
303            return Err(anyhow!("remove key err: {:?}", e));
304        }
305        Ok(())
306    }
307
308    pub fn clean(&self) -> Result<()> {
309        self.db.clear()?;
310        self.kv_tree.clear()?;
311        self.ttl_tree.clear()?;
312        Ok(())
313    }
314}