1use std::collections::BTreeSet;
40use std::fmt;
41use std::os::raw::c_void;
42
43use incrementalmerkletree::{Address, Level};
44use shardtree::{
45 store::{Checkpoint, ShardStore},
46 LocatedPrunableTree, LocatedTree, PrunableTree, Tree,
47};
48
49use crate::hash::{MerkleHashVote, SHARD_HEIGHT};
50use crate::serde::{read_checkpoint, read_shard_vote, write_checkpoint, write_shard_vote};
51
/// Failure modes of the KV-callback-backed shard store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvError {
    /// A host callback reported failure through its integer return code.
    IoError,
    /// Bytes read back from the KV store could not be decoded.
    Deserialization,
    /// A value could not be encoded before being written to the KV store.
    Serialization,
}

impl fmt::Display for KvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed message first, then emit it verbatim.
        let message = match *self {
            Self::IoError => "KV callback returned an error",
            Self::Deserialization => "failed to deserialize KV data",
            Self::Serialization => "failed to serialize data for KV",
        };
        f.write_str(message)
    }
}

impl std::error::Error for KvError {}
87
/// First byte of every shard entry key; the shard index follows as a big-endian `u64`.
const SHARD_PREFIX: u8 = 0x0F;
/// The single-byte key under which the tree cap is stored.
const CAP_KEY: u8 = 0x10;
/// First byte of every checkpoint entry key; the checkpoint id follows as a big-endian `u32`.
const CHECKPOINT_PREFIX: u8 = 0x11;

/// Builds the 9-byte KV key for the shard at `index`.
///
/// The fixed-width big-endian encoding makes byte-wise key order match numeric index order.
fn shard_key(index: u64) -> [u8; 9] {
    let [b0, b1, b2, b3, b4, b5, b6, b7] = index.to_be_bytes();
    [SHARD_PREFIX, b0, b1, b2, b3, b4, b5, b6, b7]
}

/// Returns the KV key of the tree cap.
fn cap_key() -> [u8; 1] {
    [CAP_KEY; 1]
}

/// Builds the 5-byte KV key for the checkpoint with identifier `id`.
///
/// The fixed-width big-endian encoding makes byte-wise key order match numeric id order.
fn checkpoint_key(id: u32) -> [u8; 5] {
    let [b0, b1, b2, b3] = id.to_be_bytes();
    [CHECKPOINT_PREFIX, b0, b1, b2, b3]
}
113
/// Host lookup callback for `key` (`key_len` bytes).
///
/// `KvCallbacks::get` interprets the return code as follows:
/// - `0`: the key exists and the value buffer was written through `out_val` / `out_val_len`.
///   Rust copies the bytes and then hands the buffer to the free-buffer callback.
/// - `1`: the key is absent.
/// - any other value: error.
pub type KvGetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Host write callback: store `val` (`val_len` bytes) under `key` (`key_len` bytes).
///
/// Return `0` on success; any other value is treated as an error.
pub type KvSetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    val: *const u8,
    val_len: usize,
) -> i32;

/// Host delete callback for `key` (`key_len` bytes).
///
/// Return `0` on success; any other value is treated as an error.
pub type KvDeleteFn =
    unsafe extern "C" fn(ctx: *mut c_void, key: *const u8, key_len: usize) -> i32;

/// Host callback that opens an iterator over entries matching `prefix`.
///
/// Rust passes `reverse` as `1` for a reverse scan and `0` otherwise. A null return is
/// treated as an iterator with no entries.
/// NOTE(review): the store code assumes ascending byte-wise key order (descending when
/// `reverse` is 1) — confirm the host guarantees this.
pub type KvIterCreateFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    prefix: *const u8,
    prefix_len: usize,
    reverse: u8,
) -> *mut c_void;

/// Host callback that advances an iterator handle.
///
/// Return `0` after writing the next key/value buffers through the out-params; these
/// buffers are released with the free-buffer callback. Any non-zero value ends iteration.
pub type KvIterNextFn = unsafe extern "C" fn(
    iter: *mut c_void,
    out_key: *mut *mut u8,
    out_key_len: *mut usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Host callback that releases a non-null iterator handle returned by `KvIterCreateFn`.
pub type KvIterFreeFn = unsafe extern "C" fn(iter: *mut c_void);

/// Host callback that releases a buffer previously returned by the get or iter-next callbacks.
pub type KvFreeBufFn = unsafe extern "C" fn(ptr: *mut u8, len: usize);
173
174#[derive(Clone, Copy)]
185pub struct KvCallbacks {
186 pub ctx: *mut c_void,
187 pub get: KvGetFn,
188 pub set: KvSetFn,
189 pub delete: KvDeleteFn,
190 pub iter_create: KvIterCreateFn,
191 pub iter_next: KvIterNextFn,
192 pub iter_free: KvIterFreeFn,
193 pub free_buf: KvFreeBufFn,
194}
195
196unsafe impl Send for KvCallbacks {}
199unsafe impl Sync for KvCallbacks {}
200
201impl KvCallbacks {
206 pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KvError> {
211 let mut out_ptr: *mut u8 = std::ptr::null_mut();
212 let mut out_len: usize = 0;
213 let rc = unsafe {
214 (self.get)(self.ctx, key.as_ptr(), key.len(), &mut out_ptr, &mut out_len)
215 };
216 match rc {
217 0 => {
218 let val = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
219 unsafe { (self.free_buf)(out_ptr, out_len) };
220 Ok(Some(val))
221 }
222 1 => Ok(None), _ => Err(KvError::IoError), }
225 }
226
227 pub fn set(&self, key: &[u8], val: &[u8]) -> Result<(), KvError> {
230 let rc = unsafe {
231 (self.set)(self.ctx, key.as_ptr(), key.len(), val.as_ptr(), val.len())
232 };
233 if rc != 0 {
234 Err(KvError::IoError)
235 } else {
236 Ok(())
237 }
238 }
239
240 pub fn delete(&self, key: &[u8]) -> Result<(), KvError> {
242 let rc = unsafe { (self.delete)(self.ctx, key.as_ptr(), key.len()) };
243 if rc != 0 {
244 Err(KvError::IoError)
245 } else {
246 Ok(())
247 }
248 }
249
250 fn iter(&self, prefix: &[u8], reverse: bool) -> KvIter<'_> {
252 let handle = unsafe {
253 (self.iter_create)(self.ctx, prefix.as_ptr(), prefix.len(), reverse as u8)
254 };
255 KvIter {
256 handle,
257 cb: self,
258 }
259 }
260}
261
262struct KvIter<'a> {
263 handle: *mut c_void,
264 cb: &'a KvCallbacks,
265}
266
267impl<'a> KvIter<'a> {
268 fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>)> {
270 if self.handle.is_null() {
271 return None;
272 }
273 let mut key_ptr: *mut u8 = std::ptr::null_mut();
274 let mut key_len: usize = 0;
275 let mut val_ptr: *mut u8 = std::ptr::null_mut();
276 let mut val_len: usize = 0;
277 let rc = unsafe {
278 (self.cb.iter_next)(
279 self.handle,
280 &mut key_ptr,
281 &mut key_len,
282 &mut val_ptr,
283 &mut val_len,
284 )
285 };
286 if rc != 0 {
287 return None;
288 }
289 let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len).to_vec() };
290 unsafe { (self.cb.free_buf)(key_ptr, key_len) };
291 let val = unsafe { std::slice::from_raw_parts(val_ptr, val_len).to_vec() };
292 unsafe { (self.cb.free_buf)(val_ptr, val_len) };
293 Some((key, val))
294 }
295}
296
297impl<'a> Drop for KvIter<'a> {
298 fn drop(&mut self) {
299 if !self.handle.is_null() {
300 unsafe { (self.cb.iter_free)(self.handle) };
301 }
302 }
303}
304
305pub struct KvShardStore {
313 pub(crate) cb: KvCallbacks,
314}
315
316impl KvShardStore {
317 pub fn new(cb: KvCallbacks) -> Self {
318 Self { cb }
319 }
320}
321
322impl ShardStore for KvShardStore {
327 type H = MerkleHashVote;
328 type CheckpointId = u32;
329 type Error = KvError;
330
331 fn get_shard(
332 &self,
333 shard_root: Address,
334 ) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
335 let idx = shard_root.index();
336 let key = shard_key(idx);
337 let Some(blob) = self.cb.get(&key)? else {
338 return Ok(None);
339 };
340 match read_shard_vote(&blob) {
341 Ok(tree) => Ok(LocatedTree::from_parts(shard_root, tree).ok()),
342 Err(_) => Err(KvError::Deserialization),
343 }
344 }
345
346 fn last_shard(&self) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
347 let prefix = [SHARD_PREFIX];
348 let mut iter = self.cb.iter(&prefix, true );
349 let Some((key, val)) = iter.next() else {
350 return Ok(None);
351 };
352 if key.len() < 9 {
353 return Ok(None);
354 }
355 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
356 let level = Level::from(SHARD_HEIGHT);
357 let addr = Address::from_parts(level, idx);
358 match read_shard_vote(&val) {
359 Ok(tree) => Ok(LocatedTree::from_parts(addr, tree).ok()),
360 Err(_) => Err(KvError::Deserialization),
361 }
362 }
363
364 fn put_shard(
365 &mut self,
366 subtree: LocatedPrunableTree<MerkleHashVote>,
367 ) -> Result<(), KvError> {
368 let idx = subtree.root_addr().index();
369 let key = shard_key(idx);
370 let blob = write_shard_vote(subtree.root()).map_err(|_| KvError::Serialization)?;
371 self.cb.set(&key, &blob)
372 }
373
374 fn get_shard_roots(&self) -> Result<Vec<Address>, KvError> {
375 let prefix = [SHARD_PREFIX];
376 let mut iter = self.cb.iter(&prefix, false);
377 let level = Level::from(SHARD_HEIGHT);
378 let mut roots = Vec::new();
379 while let Some((key, _)) = iter.next() {
380 if key.len() < 9 {
381 continue;
382 }
383 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
384 roots.push(Address::from_parts(level, idx));
385 }
386 Ok(roots)
387 }
388
389 fn truncate_shards(&mut self, shard_index: u64) -> Result<(), KvError> {
390 let prefix = [SHARD_PREFIX];
391 let mut iter = self.cb.iter(&prefix, false);
392 let mut to_delete = Vec::new();
393 while let Some((key, _)) = iter.next() {
394 if key.len() < 9 {
395 continue;
396 }
397 let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
398 if idx >= shard_index {
399 to_delete.push(key);
400 }
401 }
402 drop(iter);
403 for key in to_delete {
404 self.cb.delete(&key)?;
405 }
406 Ok(())
407 }
408
409 fn get_cap(&self) -> Result<PrunableTree<MerkleHashVote>, KvError> {
410 let key = cap_key();
411 let Some(blob) = self.cb.get(&key)? else {
412 return Ok(Tree::empty());
413 };
414 read_shard_vote(&blob).map_err(|_| KvError::Deserialization)
415 }
416
417 fn put_cap(&mut self, cap: PrunableTree<MerkleHashVote>) -> Result<(), KvError> {
418 let key = cap_key();
419 let blob = write_shard_vote(&cap).map_err(|_| KvError::Serialization)?;
420 self.cb.set(&key, &blob)
421 }
422
423 fn min_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
424 let prefix = [CHECKPOINT_PREFIX];
425 let mut iter = self.cb.iter(&prefix, false);
426 Ok(iter.next().and_then(|(k, _)| {
427 if k.len() >= 5 {
428 Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
429 } else {
430 None
431 }
432 }))
433 }
434
435 fn max_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
436 let prefix = [CHECKPOINT_PREFIX];
437 let mut iter = self.cb.iter(&prefix, true );
438 Ok(iter.next().and_then(|(k, _)| {
439 if k.len() >= 5 {
440 Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
441 } else {
442 None
443 }
444 }))
445 }
446
447 fn add_checkpoint(
448 &mut self,
449 checkpoint_id: u32,
450 checkpoint: Checkpoint,
451 ) -> Result<(), KvError> {
452 let key = checkpoint_key(checkpoint_id);
453 let blob = write_checkpoint(&checkpoint);
454 self.cb.set(&key, &blob)
455 }
456
457 fn checkpoint_count(&self) -> Result<usize, KvError> {
458 let prefix = [CHECKPOINT_PREFIX];
459 let mut iter = self.cb.iter(&prefix, false);
460 let mut count = 0usize;
461 while iter.next().is_some() {
462 count += 1;
463 }
464 Ok(count)
465 }
466
467 fn get_checkpoint_at_depth(
468 &self,
469 checkpoint_depth: usize,
470 ) -> Result<Option<(u32, Checkpoint)>, KvError> {
471 let prefix = [CHECKPOINT_PREFIX];
472 let mut iter = self.cb.iter(&prefix, true );
473 let mut seen = 0usize;
474 while let Some((key, val)) = iter.next() {
475 if seen == checkpoint_depth {
476 if key.len() < 5 {
477 return Ok(None);
478 }
479 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
480 return Ok(read_checkpoint(&val).ok().map(|cp| (id, cp)));
481 }
482 seen += 1;
483 }
484 Ok(None)
485 }
486
487 fn get_checkpoint(&self, checkpoint_id: &u32) -> Result<Option<Checkpoint>, KvError> {
488 let key = checkpoint_key(*checkpoint_id);
489 let Some(blob) = self.cb.get(&key)? else {
490 return Ok(None);
491 };
492 Ok(read_checkpoint(&blob).ok())
493 }
494
495 fn with_checkpoints<F>(&mut self, limit: usize, mut callback: F) -> Result<(), KvError>
496 where
497 F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
498 {
499 let prefix = [CHECKPOINT_PREFIX];
500 let mut iter = self.cb.iter(&prefix, false);
501 let mut count = 0usize;
502 while count < limit {
503 let Some((key, val)) = iter.next() else {
504 break;
505 };
506 if key.len() < 5 {
507 continue;
508 }
509 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
510 if let Ok(cp) = read_checkpoint(&val) {
511 callback(&id, &cp)?;
512 }
513 count += 1;
514 }
515 Ok(())
516 }
517
518 fn for_each_checkpoint<F>(&self, limit: usize, mut callback: F) -> Result<(), KvError>
519 where
520 F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
521 {
522 let prefix = [CHECKPOINT_PREFIX];
523 let mut iter = self.cb.iter(&prefix, false);
524 let mut count = 0usize;
525 while count < limit {
526 let Some((key, val)) = iter.next() else {
527 break;
528 };
529 if key.len() < 5 {
530 continue;
531 }
532 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
533 if let Ok(cp) = read_checkpoint(&val) {
534 callback(&id, &cp)?;
535 }
536 count += 1;
537 }
538 Ok(())
539 }
540
541 fn update_checkpoint_with<F>(
542 &mut self,
543 checkpoint_id: &u32,
544 update: F,
545 ) -> Result<bool, KvError>
546 where
547 F: Fn(&mut Checkpoint) -> Result<(), KvError>,
548 {
549 let key = checkpoint_key(*checkpoint_id);
550 let Some(blob) = self.cb.get(&key)? else {
551 return Ok(false);
552 };
553 let Ok(mut cp) = read_checkpoint(&blob) else {
554 return Ok(false);
555 };
556 update(&mut cp)?;
557 let new_blob = write_checkpoint(&cp);
558 self.cb.set(&key, &new_blob)?;
559 Ok(true)
560 }
561
562 fn remove_checkpoint(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
563 let key = checkpoint_key(*checkpoint_id);
564 self.cb.delete(&key)
565 }
566
567 fn truncate_checkpoints_retaining(
568 &mut self,
569 checkpoint_id: &u32,
570 ) -> Result<(), KvError> {
571 let prefix = [CHECKPOINT_PREFIX];
574 let mut iter = self.cb.iter(&prefix, false);
575 let mut to_delete = Vec::new();
576 while let Some((key, _)) = iter.next() {
577 if key.len() < 5 {
578 continue;
579 }
580 let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
581 if id < *checkpoint_id {
582 to_delete.push(key);
583 } else {
584 break;
585 }
586 }
587 drop(iter);
588 for key in to_delete {
589 self.cb.delete(&key)?;
590 }
591 let retain_key = checkpoint_key(*checkpoint_id);
593 if let Some(blob) = self.cb.get(&retain_key)? {
594 if let Ok(cp) = read_checkpoint(&blob) {
595 let cleared = Checkpoint::from_parts(cp.tree_state(), BTreeSet::new());
596 self.cb.set(&retain_key, &write_checkpoint(&cleared))?;
597 }
598 }
599 Ok(())
600 }
601}