1use std::collections::BTreeSet;
40use std::fmt;
41use std::os::raw::c_void;
42
43use incrementalmerkletree::{Address, Level};
44use shardtree::{
45 store::{Checkpoint, ShardStore},
46 LocatedPrunableTree, LocatedTree, PrunableTree, Tree,
47};
48
49use crate::hash::{MerkleHashVote, SHARD_HEIGHT};
50use crate::serde::{read_checkpoint, read_shard_vote, write_checkpoint, write_shard_vote};
51
/// Errors produced by the KV-callback-backed shard store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvError {
    /// A host callback (`get`, `set` or `delete`) returned a failure code.
    IoError,
    /// Bytes read back from the KV store could not be decoded.
    Deserialization,
    /// A value could not be encoded before being written to the KV store.
    Serialization,
}
75
76impl fmt::Display for KvError {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 match self {
79 KvError::IoError => write!(f, "KV callback returned an error"),
80 KvError::Deserialization => write!(f, "failed to deserialize KV data"),
81 KvError::Serialization => write!(f, "failed to serialize data for KV"),
82 }
83 }
84}
85
86impl std::error::Error for KvError {}
87
/// First byte of every shard key; followed by the shard index as a big-endian `u64`
/// (see `shard_key`).
const SHARD_PREFIX: u8 = 0x0F;
/// The whole key (a single byte, no suffix) under which the tree cap is stored (see `cap_key`).
const CAP_KEY: u8 = 0x10;
/// First byte of every checkpoint key; followed by the checkpoint id as a big-endian `u32`
/// (see `checkpoint_key`).
// The three first bytes are distinct, so a prefix scan over one key family never sees keys
// of another family.
const CHECKPOINT_PREFIX: u8 = 0x11;
95
96fn shard_key(index: u64) -> [u8; 9] {
97 let mut k = [0u8; 9];
98 k[0] = SHARD_PREFIX;
99 k[1..].copy_from_slice(&index.to_be_bytes());
100 k
101}
102
103fn cap_key() -> [u8; 1] {
104 [CAP_KEY]
105}
106
107fn checkpoint_key(id: u32) -> [u8; 5] {
108 let mut k = [0u8; 5];
109 k[0] = CHECKPOINT_PREFIX;
110 k[1..].copy_from_slice(&id.to_be_bytes());
111 k
112}
113
/// Host lookup callback, invoked by `KvCallbacks::get`.
///
/// The Rust side treats return code `0` as "found": the value is read from
/// `*out_val` / `*out_val_len`, copied, and the buffer is handed back through the `free_buf`
/// callback. `1` means "not found"; any other code is reported as `KvError::IoError`.
pub type KvGetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Host write callback, invoked by `KvCallbacks::set`. `0` is success; any non-zero return
/// is reported as `KvError::IoError`.
pub type KvSetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    val: *const u8,
    val_len: usize,
) -> i32;

/// Host delete callback, invoked by `KvCallbacks::delete`. `0` is success; any non-zero
/// return is reported as `KvError::IoError`.
pub type KvDeleteFn = unsafe extern "C" fn(ctx: *mut c_void, key: *const u8, key_len: usize) -> i32;

/// Opens a host iterator over keys beginning with `prefix`.
///
/// `reverse` is passed as `1` (true) or `0` (false). In this module it is set for
/// `last_shard`, `max_checkpoint_id` and depth lookups. The host is therefore presumably
/// expected to yield keys in descending order when it is `1` — confirm with the host.
/// The Rust side treats a null return as an empty iteration. Non-null handles are released
/// with the `iter_free` callback.
pub type KvIterCreateFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    prefix: *const u8,
    prefix_len: usize,
    reverse: u8,
) -> *mut c_void;

/// Advances a host iterator.
///
/// On `0`, the key and value written to the out-parameters are copied and both buffers are
/// handed back through `free_buf`. Any non-zero return ends iteration on the Rust side; it
/// does not distinguish exhaustion from failure.
pub type KvIterNextFn = unsafe extern "C" fn(
    iter: *mut c_void,
    out_key: *mut *mut u8,
    out_key_len: *mut usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Releases an iterator handle obtained from a `KvIterCreateFn` call. It is called once,
/// when the Rust-side iterator is dropped, and only for non-null handles.
pub type KvIterFreeFn = unsafe extern "C" fn(iter: *mut c_void);

/// Releases a buffer the host handed out through `KvGetFn` or `KvIterNextFn`. It receives
/// the same pointer/length pair that the host reported.
pub type KvFreeBufFn = unsafe extern "C" fn(ptr: *mut u8, len: usize);
172
/// Table of host-supplied callbacks that back `KvShardStore`.
///
/// NOTE(review): this struct is not `#[repr(C)]`. If the host fills it in directly across
/// the FFI boundary, rather than the Rust side assembling it from individual pointers, its
/// layout is unspecified — TODO confirm how it is constructed.
#[derive(Clone, Copy)]
pub struct KvCallbacks {
    /// Opaque host context, forwarded unchanged as the first argument of `get`, `set`,
    /// `delete` and `iter_create`.
    pub ctx: *mut c_void,
    /// Point lookup; see `KvGetFn` for the return-code convention.
    pub get: KvGetFn,
    /// Insert or overwrite a key.
    pub set: KvSetFn,
    /// Remove a key.
    pub delete: KvDeleteFn,
    /// Open a prefix scan.
    pub iter_create: KvIterCreateFn,
    /// Advance a prefix scan.
    pub iter_next: KvIterNextFn,
    /// Close a prefix scan handle.
    pub iter_free: KvIterFreeFn,
    /// Release buffers handed out by `get` and `iter_next` once they have been copied.
    pub free_buf: KvFreeBufFn,
}
194
// SAFETY (review): `KvCallbacks` holds a raw `ctx` pointer and C function pointers, so the
// compiler cannot derive `Send`/`Sync`. These impls assert that the host callbacks and the
// object behind `ctx` may be used from any thread (and concurrently, for `Sync`). Nothing in
// this file enforces that — TODO confirm against the host implementation.
unsafe impl Send for KvCallbacks {}
unsafe impl Sync for KvCallbacks {}
199
200impl KvCallbacks {
205 pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KvError> {
210 let mut out_ptr: *mut u8 = std::ptr::null_mut();
211 let mut out_len: usize = 0;
212 let rc = unsafe {
213 (self.get)(
214 self.ctx,
215 key.as_ptr(),
216 key.len(),
217 &mut out_ptr,
218 &mut out_len,
219 )
220 };
221 match rc {
222 0 => {
223 let val = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
224 unsafe { (self.free_buf)(out_ptr, out_len) };
225 Ok(Some(val))
226 }
227 1 => Ok(None), _ => Err(KvError::IoError), }
230 }
231
232 pub fn set(&self, key: &[u8], val: &[u8]) -> Result<(), KvError> {
235 let rc = unsafe { (self.set)(self.ctx, key.as_ptr(), key.len(), val.as_ptr(), val.len()) };
236 if rc != 0 {
237 Err(KvError::IoError)
238 } else {
239 Ok(())
240 }
241 }
242
243 pub fn delete(&self, key: &[u8]) -> Result<(), KvError> {
245 let rc = unsafe { (self.delete)(self.ctx, key.as_ptr(), key.len()) };
246 if rc != 0 {
247 Err(KvError::IoError)
248 } else {
249 Ok(())
250 }
251 }
252
253 fn iter(&self, prefix: &[u8], reverse: bool) -> KvIter<'_> {
255 let handle =
256 unsafe { (self.iter_create)(self.ctx, prefix.as_ptr(), prefix.len(), reverse as u8) };
257 KvIter { handle, cb: self }
258 }
259}
260
261struct KvIter<'a> {
262 handle: *mut c_void,
263 cb: &'a KvCallbacks,
264}
265
266impl<'a> KvIter<'a> {
267 fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>)> {
269 if self.handle.is_null() {
270 return None;
271 }
272 let mut key_ptr: *mut u8 = std::ptr::null_mut();
273 let mut key_len: usize = 0;
274 let mut val_ptr: *mut u8 = std::ptr::null_mut();
275 let mut val_len: usize = 0;
276 let rc = unsafe {
277 (self.cb.iter_next)(
278 self.handle,
279 &mut key_ptr,
280 &mut key_len,
281 &mut val_ptr,
282 &mut val_len,
283 )
284 };
285 if rc != 0 {
286 return None;
287 }
288 let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len).to_vec() };
289 unsafe { (self.cb.free_buf)(key_ptr, key_len) };
290 let val = unsafe { std::slice::from_raw_parts(val_ptr, val_len).to_vec() };
291 unsafe { (self.cb.free_buf)(val_ptr, val_len) };
292 Some((key, val))
293 }
294}
295
296impl<'a> Drop for KvIter<'a> {
297 fn drop(&mut self) {
298 if !self.handle.is_null() {
299 unsafe { (self.cb.iter_free)(self.handle) };
300 }
301 }
302}
303
304pub struct KvShardStore {
312 pub(crate) cb: KvCallbacks,
313}
314
315impl KvShardStore {
316 pub fn new(cb: KvCallbacks) -> Self {
317 Self { cb }
318 }
319}
320
321impl ShardStore for KvShardStore {
326 type H = MerkleHashVote;
327 type CheckpointId = u32;
328 type Error = KvError;
329
330 fn get_shard(
331 &self,
332 shard_root: Address,
333 ) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
334 let idx = shard_root.index();
335 let key = shard_key(idx);
336 let Some(blob) = self.cb.get(&key)? else {
337 return Ok(None);
338 };
339 match read_shard_vote(&blob) {
340 Ok(tree) => Ok(LocatedTree::from_parts(shard_root, tree).ok()),
341 Err(_) => Err(KvError::Deserialization),
342 }
343 }
344
345 fn last_shard(&self) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
346 let prefix = [SHARD_PREFIX];
347 let mut iter = self.cb.iter(&prefix, true );
348 let Some((key, val)) = iter.next() else {
349 return Ok(None);
350 };
351 if key.len() < 9 {
352 return Ok(None);
353 }
354 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
355 let level = Level::from(SHARD_HEIGHT);
356 let addr = Address::from_parts(level, idx);
357 match read_shard_vote(&val) {
358 Ok(tree) => Ok(LocatedTree::from_parts(addr, tree).ok()),
359 Err(_) => Err(KvError::Deserialization),
360 }
361 }
362
363 fn put_shard(&mut self, subtree: LocatedPrunableTree<MerkleHashVote>) -> Result<(), KvError> {
364 let idx = subtree.root_addr().index();
365 let key = shard_key(idx);
366 let blob = write_shard_vote(subtree.root()).map_err(|_| KvError::Serialization)?;
367 self.cb.set(&key, &blob)
368 }
369
370 fn get_shard_roots(&self) -> Result<Vec<Address>, KvError> {
371 let prefix = [SHARD_PREFIX];
372 let mut iter = self.cb.iter(&prefix, false);
373 let level = Level::from(SHARD_HEIGHT);
374 let mut roots = Vec::new();
375 while let Some((key, _)) = iter.next() {
376 if key.len() < 9 {
377 continue;
378 }
379 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
380 roots.push(Address::from_parts(level, idx));
381 }
382 Ok(roots)
383 }
384
385 fn truncate_shards(&mut self, shard_index: u64) -> Result<(), KvError> {
386 let prefix = [SHARD_PREFIX];
387 let mut iter = self.cb.iter(&prefix, false);
388 let mut to_delete = Vec::new();
389 while let Some((key, _)) = iter.next() {
390 if key.len() < 9 {
391 continue;
392 }
393 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
394 if idx >= shard_index {
395 to_delete.push(key);
396 }
397 }
398 drop(iter);
399 for key in to_delete {
400 self.cb.delete(&key)?;
401 }
402 Ok(())
403 }
404
405 fn get_cap(&self) -> Result<PrunableTree<MerkleHashVote>, KvError> {
406 let key = cap_key();
407 let Some(blob) = self.cb.get(&key)? else {
408 return Ok(Tree::empty());
409 };
410 read_shard_vote(&blob).map_err(|_| KvError::Deserialization)
411 }
412
413 fn put_cap(&mut self, cap: PrunableTree<MerkleHashVote>) -> Result<(), KvError> {
414 let key = cap_key();
415 let blob = write_shard_vote(&cap).map_err(|_| KvError::Serialization)?;
416 self.cb.set(&key, &blob)
417 }
418
419 fn min_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
420 let prefix = [CHECKPOINT_PREFIX];
421 let mut iter = self.cb.iter(&prefix, false);
422 Ok(iter.next().and_then(|(k, _)| {
423 if k.len() >= 5 {
424 Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
425 } else {
426 None
427 }
428 }))
429 }
430
431 fn max_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
432 let prefix = [CHECKPOINT_PREFIX];
433 let mut iter = self.cb.iter(&prefix, true );
434 Ok(iter.next().and_then(|(k, _)| {
435 if k.len() >= 5 {
436 Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
437 } else {
438 None
439 }
440 }))
441 }
442
443 fn add_checkpoint(
444 &mut self,
445 checkpoint_id: u32,
446 checkpoint: Checkpoint,
447 ) -> Result<(), KvError> {
448 let key = checkpoint_key(checkpoint_id);
449 let blob = write_checkpoint(&checkpoint);
450 self.cb.set(&key, &blob)
451 }
452
453 fn checkpoint_count(&self) -> Result<usize, KvError> {
454 let prefix = [CHECKPOINT_PREFIX];
455 let mut iter = self.cb.iter(&prefix, false);
456 let mut count = 0usize;
457 while iter.next().is_some() {
458 count += 1;
459 }
460 Ok(count)
461 }
462
463 fn get_checkpoint_at_depth(
464 &self,
465 checkpoint_depth: usize,
466 ) -> Result<Option<(u32, Checkpoint)>, KvError> {
467 let prefix = [CHECKPOINT_PREFIX];
468 let mut iter = self.cb.iter(&prefix, true );
469 let mut seen = 0usize;
470 while let Some((key, val)) = iter.next() {
471 if seen == checkpoint_depth {
472 if key.len() < 5 {
473 return Ok(None);
474 }
475 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
476 return Ok(read_checkpoint(&val).ok().map(|cp| (id, cp)));
477 }
478 seen += 1;
479 }
480 Ok(None)
481 }
482
483 fn get_checkpoint(&self, checkpoint_id: &u32) -> Result<Option<Checkpoint>, KvError> {
484 let key = checkpoint_key(*checkpoint_id);
485 let Some(blob) = self.cb.get(&key)? else {
486 return Ok(None);
487 };
488 Ok(read_checkpoint(&blob).ok())
489 }
490
491 fn with_checkpoints<F>(&mut self, limit: usize, mut callback: F) -> Result<(), KvError>
492 where
493 F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
494 {
495 let prefix = [CHECKPOINT_PREFIX];
496 let mut iter = self.cb.iter(&prefix, false);
497 let mut count = 0usize;
498 while count < limit {
499 let Some((key, val)) = iter.next() else {
500 break;
501 };
502 if key.len() < 5 {
503 continue;
504 }
505 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
506 if let Ok(cp) = read_checkpoint(&val) {
507 callback(&id, &cp)?;
508 }
509 count += 1;
510 }
511 Ok(())
512 }
513
514 fn for_each_checkpoint<F>(&self, limit: usize, mut callback: F) -> Result<(), KvError>
515 where
516 F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
517 {
518 let prefix = [CHECKPOINT_PREFIX];
519 let mut iter = self.cb.iter(&prefix, false);
520 let mut count = 0usize;
521 while count < limit {
522 let Some((key, val)) = iter.next() else {
523 break;
524 };
525 if key.len() < 5 {
526 continue;
527 }
528 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
529 if let Ok(cp) = read_checkpoint(&val) {
530 callback(&id, &cp)?;
531 }
532 count += 1;
533 }
534 Ok(())
535 }
536
537 fn update_checkpoint_with<F>(&mut self, checkpoint_id: &u32, update: F) -> Result<bool, KvError>
538 where
539 F: Fn(&mut Checkpoint) -> Result<(), KvError>,
540 {
541 let key = checkpoint_key(*checkpoint_id);
542 let Some(blob) = self.cb.get(&key)? else {
543 return Ok(false);
544 };
545 let Ok(mut cp) = read_checkpoint(&blob) else {
546 return Ok(false);
547 };
548 update(&mut cp)?;
549 let new_blob = write_checkpoint(&cp);
550 self.cb.set(&key, &new_blob)?;
551 Ok(true)
552 }
553
554 fn remove_checkpoint(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
555 let key = checkpoint_key(*checkpoint_id);
556 self.cb.delete(&key)
557 }
558
559 fn truncate_checkpoints_retaining(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
560 let prefix = [CHECKPOINT_PREFIX];
563 let mut iter = self.cb.iter(&prefix, false);
564 let mut to_delete = Vec::new();
565 while let Some((key, _)) = iter.next() {
566 if key.len() < 5 {
567 continue;
568 }
569 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
570 if id < *checkpoint_id {
571 to_delete.push(key);
572 } else {
573 break;
574 }
575 }
576 drop(iter);
577 for key in to_delete {
578 self.cb.delete(&key)?;
579 }
580 let retain_key = checkpoint_key(*checkpoint_id);
582 if let Some(blob) = self.cb.get(&retain_key)? {
583 if let Ok(cp) = read_checkpoint(&blob) {
584 let cleared = Checkpoint::from_parts(cp.tree_state(), BTreeSet::new());
585 self.cb.set(&retain_key, &write_checkpoint(&cleared))?;
586 }
587 }
588 Ok(())
589 }
590}