1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn _now() -> u64 {
    // `SystemTime::elapsed` is `SystemTime::now().duration_since(self)`.
    let since_epoch: Duration = UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
24
/// Absolute expiry deadline, in whole seconds since the Unix epoch, for an
/// entry that should live for `ttl` starting now.
///
/// The TTL is added to the since-epoch `Duration` with saturation, so a huge
/// `ttl` (e.g. `Duration::MAX` meaning "practically never") yields `u64::MAX`
/// instead of panicking on `SystemTime` overflow. For ordinary TTLs the
/// result is identical to `(now + ttl).duration_since(UNIX_EPOCH).as_secs()`.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn expired_time(ttl: Duration) -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .saturating_add(ttl)
        .as_secs()
}
33
34pub trait ISledExt {
35 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36 where
37 K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42 where
43 K: AsRef<[u8]> + Sync + Send,
44 {
45 let expire_at = expired_time(ttl).to_be_bytes();
46 self.insert(key, expire_at.as_slice())?;
47 Ok(true)
48 }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53 pub path: String,
54 pub cache_capacity: u64,
55 pub flush_every_ms: u64,
56}
57
/// Name of the sled tree holding the bincode-encoded values.
const KV_TREE: &[u8] = b"__kv_tree@";
/// Name of the sled tree mapping each key to its expiry deadline
/// (big-endian `u64` seconds since the Unix epoch).
///
/// NOTE(review): "tll" looks like a typo for "ttl", but this string names a
/// tree on disk; changing it would open a different (empty) tree and
/// presumably orphan existing TTL data, so it is left as-is.
const _TTL_TREE: &[u8] = b"__tll_tree@";
60
/// Key/value store backed by sled, with optional per-key expiry when the
/// `ttl` feature is enabled.
pub struct KvDb {
    /// Tree holding the bincode-encoded values, keyed by user key.
    pub(crate) kv_tree: sled::Tree,
    /// Tree mapping user keys to their expiry deadline (big-endian `u64`
    /// seconds since the Unix epoch); only present with the `ttl` feature.
    #[cfg(feature = "ttl")]
    pub(crate) ttl_tree: sled::Tree,
}
66
/// Spawns a background Tokio task that periodically purges expired keys from `db`.
///
/// Every 5 seconds the task runs `KvDb::cleanup` in batches of at most 200
/// removals, repeating with a 500 ms pause while a batch comes back full.
/// Each batch runs on Tokio's blocking pool via `spawn_blocking`, because
/// `cleanup` performs synchronous sled iteration and transactions that must
/// not stall the async worker threads. If a batch panics, the failure is
/// logged and the task waits for the next 5-second cycle.
///
/// Must be called from within a Tokio runtime (it uses `tokio::spawn`).
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>) {
    tokio::spawn(async move {
        let limit: usize = 200;
        loop {
            tokio::time::sleep(Duration::from_secs(5)).await;
            loop {
                let started = std::time::Instant::now();
                let batch_db = Arc::clone(&db);
                let batch = tokio::task::spawn_blocking(move || batch_db.cleanup(limit));
                let count = match batch.await {
                    Ok(count) => count,
                    Err(e) => {
                        log::error!("cleanup task err: {:?}", e);
                        break;
                    }
                };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                // A partial batch means no more expired entries were found.
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(500)).await;
            }
        }
    });
}
88
/// Spawns a Tokio task that calls `_evt` with the key (decoded lossily as
/// UTF-8) every time an entry is removed from `db`'s TTL tree.
///
/// The sled subscription is awaited through its `Future` implementation
/// rather than iterated: iterating a `Subscriber` blocks the calling thread,
/// which inside `tokio::spawn` would pin a runtime worker forever (and hang a
/// current-thread runtime). Insert events are ignored. The task ends when the
/// subscription yields `None`.
///
/// Must be called from within a Tokio runtime (it uses `tokio::spawn`).
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::spawn(async move {
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                _evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
106
107impl KvDb {
108 pub fn new(cfg: KvDbConfig) -> Result<Self> {
109 let c = Config::default()
110 .path(cfg.path)
111 .cache_capacity(cfg.cache_capacity)
112 .flush_every_ms(Some(cfg.flush_every_ms))
113 .mode(sled::Mode::LowSpace);
114 let db = c.open()?;
115 let kv_tree = db.open_tree(KV_TREE)?;
116 #[cfg(feature = "ttl")]
117 let ttl_tree = db.open_tree(_TTL_TREE)?;
118
119 Ok(KvDb {
121 kv_tree,
122 #[cfg(feature = "ttl")]
123 ttl_tree,
124 })
125 }
126
127 #[cfg(feature = "ttl")]
128 fn cleanup(&self, limit: usize) -> usize {
129 let mut count = 0;
130
131 for item in self.ttl_tree.iter() {
132 if count > limit {
133 break;
134 }
135 let (key, expire_at_iv) = match item {
136 Ok(item) => item,
137 Err(e) => {
138 log::error!("cleanup err: {:?}", e);
139 break;
140 }
141 };
142
143 let expire_at = match expire_at_iv.as_ref().try_into() {
144 Ok(at) => u64::from_be_bytes(at),
145 Err(e) => {
146 log::error!("cleanup err: {:?}", e);
147 break;
148 }
149 };
150
151 if expire_at > _now() {
152 break;
153 }
154
155 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
156 kv.remove(key.clone())?;
157 exp.remove(key.clone())?;
158 Ok::<_, ConflictableTransactionError<()>>(())
159 }) {
160 log::error!("cleanup err: {:?}", e);
161 } else {
162 count += 1;
163 }
164 }
165 count
166 }
167
168 #[cfg(feature = "ttl")]
169 pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
170 where
171 K: AsRef<[u8]> + Sync + Send,
172 {
173 let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
174 Ok(Some(at_bytes)) => at_bytes,
175 Ok(None) => return None,
176 Err(e) => {
177 log::error!("get_ttl_at err: {:?}", e);
178 return None;
179 }
180 };
181
182 let expire_at = match expire_at_iv.as_ref().try_into() {
183 Ok(at) => u64::from_be_bytes(at),
184 Err(e) => {
185 log::error!("get_ttl_at err: {:?}", e);
186 return None;
187 }
188 };
189
190 Some(expire_at)
191 }
192
193 #[cfg(feature = "ttl")]
194 pub fn is_expired<K>(&self, key: K) -> Option<bool>
195 where
196 K: AsRef<[u8]> + Sync + Send,
197 {
198 let expire_at = self.get_ttl_at(key);
199
200 let Some(expire_at) = expire_at else {
201 return None;
202 };
203
204 if _now() > expire_at {
205 return Some(true);
206 }
207
208 Some(false)
209 }
210
211 #[cfg(feature = "ttl")]
212 pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
213 where
214 K: AsRef<[u8]>,
215 V: Serialize + Encode + Sync + Send,
216 {
217 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
218 let expire_at = expired_time(ttl).to_be_bytes();
219
220 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
221 kv.insert(key.as_ref(), v.clone())?;
222 ttl.insert(key.as_ref(), expire_at.as_slice())?;
223 Ok::<_, ConflictableTransactionError<()>>(())
224 }) {
225 return Err(anyhow!("insert_ttl err: {:?}", e));
226 }
227 Ok(())
228 }
229
230 pub fn insert<K, V>(&self, key: K, value: V) -> Result<()>
231 where
232 K: AsRef<[u8]>,
233 V: Serialize + Encode + Sync + Send,
234 {
235 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
236 self.kv_tree.insert(key, v)?;
237 Ok(())
238 }
239
240 pub fn contains_key<K>(&self, key: K) -> bool
241 where
242 K: AsRef<[u8]> + Sync + Send,
243 {
244 #[cfg(feature = "ttl")]
245 {
246 let exp_v = self.is_expired(&key);
247
248 if let Some(v) = exp_v
250 && v
251 {
252 return false;
253 }
254 }
255
256 self.kv_tree.contains_key(key).ok().unwrap_or(false)
257 }
258
259 pub fn get<K, V>(&self, key: K) -> Option<V>
260 where
261 K: AsRef<[u8]>,
262 V: DeserializeOwned + Decode<()> + Sync + Send,
263 {
264 let val = match self.kv_tree.get(key) {
265 Ok(v) => v,
266 Err(e) => {
267 log::error!("kvdb get err: {}", e);
268 return None;
269 }
270 };
271
272 if let Some(v) = val {
273 let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
274 if let Ok((v, _)) = b {
275 return Some(v);
276 }
277 if let Err(e) = b {
278 log::error!("kvdb deserialize error: {}", e.to_string());
279 }
280 return None;
281 }
282
283 None
284 }
285}