sled_ext/
lib.rs

1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn _now() -> u64 {
    let since_epoch = UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
24
/// Absolute expiry time, in whole Unix seconds, for an entry that lives `ttl` from now.
///
/// When `now + ttl` cannot be represented as a [`SystemTime`] (e.g. `Duration::MAX`),
/// this saturates to `u64::MAX` — effectively "never expires" — instead of panicking
/// on a caller-supplied TTL.
///
/// # Panics
/// Panics if `now + ttl` lies before the Unix epoch, which only happens with a
/// badly skewed system clock.
fn expired_time(ttl: Duration) -> u64 {
    let Some(expire_at) = SystemTime::now().checked_add(ttl) else {
        return u64::MAX;
    };
    expire_at
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}
33
34pub trait ISledExt {
35    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36    where
37        K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42    where
43        K: AsRef<[u8]> + Sync + Send,
44    {
45        let expire_at = expired_time(ttl).to_be_bytes();
46        self.insert(key, expire_at.as_slice())?;
47        Ok(true)
48    }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53    pub path: String,
54    pub cache_capacity: u64,
55    pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that holds the bincode-encoded values.
const KV_TREE: &[u8] = "__kv_tree@".as_bytes();
/// Name of the sled tree that holds per-key expiry timestamps.
///
/// NOTE(review): "tll" looks like a typo for "ttl", but this name is persisted
/// in existing databases; renaming it would orphan their TTL entries.
const _TTL_TREE: &[u8] = "__tll_tree@".as_bytes();
60
61pub struct KvDb {
62    pub(crate) kv_tree: sled::Tree,
63    #[cfg(feature = "ttl")]
64    pub(crate) ttl_tree: sled::Tree,
65}
66
/// Spawns a background tokio task that periodically purges expired entries from `db`.
///
/// The task repeats forever: it sleeps 5 s, then runs cleanup passes of at most
/// 200 deletions each, pausing 500 ms between passes. It goes back to sleep as
/// soon as a pass removes fewer than 200 entries.
///
/// Each pass runs on tokio's blocking pool. The sled scan and transactions are
/// synchronous and would otherwise stall an async worker thread.
///
/// Must be called from within a tokio runtime (`tokio::spawn` requires one).
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>) {
    const LIMIT: usize = 200;
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            loop {
                let started = std::time::Instant::now();
                let pass_db = Arc::clone(&db);
                let pass = tokio::task::spawn_blocking(move || pass_db.cleanup(LIMIT));
                let count = match pass.await {
                    Ok(count) => count,
                    Err(e) => {
                        log::error!("cleanup task err: {:?}", e);
                        break;
                    }
                };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                if count < LIMIT {
                    break;
                }
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
        }
    });
}
88
/// Spawns a tokio task that calls `_evt` with the key of every entry removed
/// from `db`'s TTL tree.
///
/// In the code visible here, TTL entries are removed only by the expiry cleanup,
/// so `_evt` effectively receives expired keys. Keys are converted with
/// [`String::from_utf8_lossy`], so invalid UTF-8 bytes become U+FFFD.
///
/// The subscription is awaited as a future rather than iterated. Iterating a
/// sled `Subscriber` blocks, and doing that inside `tokio::spawn` would
/// permanently occupy an async worker thread.
///
/// Must be called from within a tokio runtime.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::spawn(async move {
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                _evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
106
107impl KvDb {
108    pub fn new(cfg: KvDbConfig) -> Result<Self> {
109        let c = Config::default()
110            .path(cfg.path)
111            .cache_capacity(cfg.cache_capacity)
112            .flush_every_ms(Some(cfg.flush_every_ms))
113            .mode(sled::Mode::LowSpace);
114        let db = c.open()?;
115        let kv_tree = db.open_tree(KV_TREE)?;
116        #[cfg(feature = "ttl")]
117        let ttl_tree = db.open_tree(_TTL_TREE)?;
118
119        // let db = Arc::new(db);
120        Ok(KvDb {
121            kv_tree,
122            #[cfg(feature = "ttl")]
123            ttl_tree,
124        })
125    }
126
127    #[cfg(feature = "ttl")]
128    fn cleanup(&self, limit: usize) -> usize {
129        let mut count = 0;
130
131        for item in self.ttl_tree.iter() {
132            if count > limit {
133                break;
134            }
135            let (key, expire_at_iv) = match item {
136                Ok(item) => item,
137                Err(e) => {
138                    log::error!("cleanup err: {:?}", e);
139                    break;
140                }
141            };
142
143            let expire_at = match expire_at_iv.as_ref().try_into() {
144                Ok(at) => u64::from_be_bytes(at),
145                Err(e) => {
146                    log::error!("cleanup err: {:?}", e);
147                    break;
148                }
149            };
150
151            if expire_at > _now() {
152                break;
153            }
154
155            if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
156                kv.remove(key.clone())?;
157                exp.remove(key.clone())?;
158                Ok::<_, ConflictableTransactionError<()>>(())
159            }) {
160                log::error!("cleanup err: {:?}", e);
161            } else {
162                count += 1;
163            }
164        }
165        count
166    }
167
168    #[cfg(feature = "ttl")]
169    pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
170    where
171        K: AsRef<[u8]> + Sync + Send,
172    {
173        let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
174            Ok(Some(at_bytes)) => at_bytes,
175            Ok(None) => return None,
176            Err(e) => {
177                log::error!("get_ttl_at err: {:?}", e);
178                return None;
179            }
180        };
181
182        let expire_at = match expire_at_iv.as_ref().try_into() {
183            Ok(at) => u64::from_be_bytes(at),
184            Err(e) => {
185                log::error!("get_ttl_at err: {:?}", e);
186                return None;
187            }
188        };
189
190        Some(expire_at)
191    }
192
193    #[cfg(feature = "ttl")]
194    pub fn is_expired<K>(&self, key: K) -> Option<bool>
195    where
196        K: AsRef<[u8]> + Sync + Send,
197    {
198        let expire_at = self.get_ttl_at(key);
199
200        let Some(expire_at) = expire_at else {
201            return None;
202        };
203
204        if _now() > expire_at {
205            return Some(true);
206        }
207
208        Some(false)
209    }
210
211    #[cfg(feature = "ttl")]
212    pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
213    where
214        K: AsRef<[u8]>,
215        V: Serialize + Encode + Sync + Send,
216    {
217        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
218        let expire_at = expired_time(ttl).to_be_bytes();
219
220        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
221            kv.insert(key.as_ref(), v.clone())?;
222            ttl.insert(key.as_ref(), expire_at.as_slice())?;
223            Ok::<_, ConflictableTransactionError<()>>(())
224        }) {
225            return Err(anyhow!("insert_ttl err: {:?}", e));
226        }
227        Ok(())
228    }
229
230    pub fn insert<K, V>(&self, key: K, value: V) -> Result<()>
231    where
232        K: AsRef<[u8]>,
233        V: Serialize + Encode + Sync + Send,
234    {
235        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
236        self.kv_tree.insert(key, v)?;
237        Ok(())
238    }
239
240    pub fn contains_key<K>(&self, key: K) -> bool
241    where
242        K: AsRef<[u8]> + Sync + Send,
243    {
244        #[cfg(feature = "ttl")]
245        {
246            let exp_v = self.is_expired(&key);
247
248            //如果ttl 存在,并已过期 则返回false
249            if let Some(v) = exp_v
250                && v
251            {
252                return false;
253            }
254        }
255
256        self.kv_tree.contains_key(key).ok().unwrap_or(false)
257    }
258
259    pub fn get<K, V>(&self, key: K) -> Option<V>
260    where
261        K: AsRef<[u8]>,
262        V: DeserializeOwned + Decode<()> + Sync + Send,
263    {
264        let val = match self.kv_tree.get(key) {
265            Ok(v) => v,
266            Err(e) => {
267                log::error!("kvdb get err: {}", e);
268                return None;
269            }
270        };
271
272        if let Some(v) = val {
273            let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
274            if let Ok((v, _)) = b {
275                return Some(v);
276            }
277            if let Err(e) = b {
278                log::error!("kvdb deserialize error: {}", e.to_string());
279            }
280            return None;
281        }
282
283        None
284    }
285}