Skip to main content

vote_commitment_tree/
kv_shard_store.rs

1//! [`KvShardStore`] — a [`ShardStore`] implementation backed by Go's Cosmos KV
2//! store via C function pointer callbacks.
3//!
4//! # Design
5//!
6//! Instead of maintaining an in-process copy of all shard data,
7//! `KvShardStore` forwards every [`ShardStore`] read and write directly to
8//! the Cosmos KV store through a set of C callbacks registered at creation
9//! time. Go registers `//export` functions that dispatch to the current
10//! block's `store.KVStore` through a stable proxy pointer.
11//!
12//! This gives `ShardTree` true lazy loading: on a cold start only the data
13//! that is actually accessed (the frontier shard + cap + checkpoints) is read.
14//! No explicit restore loop, no O(n) blob loading, no shard geometry in Go.
15//!
16//! # KV key schema (matches keys.go)
17//!
18//! | Prefix    | Key                              | Value           |
19//! |-----------|----------------------------------|-----------------|
20//! | `0x0F`    | `0x0F \|\| u64 BE shard_index`   | shard blob      |
21//! | `0x10`    | `0x10`                           | cap blob        |
22//! | `0x11`    | `0x11 \|\| u32 BE checkpoint_id` | checkpoint blob |
23//!
24//! # Buffer ownership
25//!
26//! `get` returns a C-malloc'd buffer that Rust frees with the provided
27//! `free_buf` callback after copying the value. All write callbacks receive
28//! a Rust-owned slice (pointer + length); they must copy the data if they
29//! need it to outlive the call.
30//!
31//! # Iterator protocol
32//!
33//! `iter_create(ctx, prefix, prefix_len, reverse)` returns an opaque handle
34//! (a `cgo.Handle` on the Go side). `iter_next` advances and writes
35//! C-malloc'd key + value; Rust frees each pair with `free_buf` before the
36//! next call. `iter_free` closes and drops the iterator. `iter_next` returns
37//! 0 on a valid entry, 1 when exhausted, -1 on error.
38
39use std::collections::BTreeSet;
40use std::fmt;
41use std::os::raw::c_void;
42
43use incrementalmerkletree::{Address, Level};
44use shardtree::{
45    store::{Checkpoint, ShardStore},
46    LocatedPrunableTree, LocatedTree, PrunableTree, Tree,
47};
48
49use crate::hash::{MerkleHashVote, SHARD_HEIGHT};
50use crate::serde::{read_checkpoint, read_shard_vote, write_checkpoint, write_shard_vote};
51
52// ---------------------------------------------------------------------------
53// KvError
54// ---------------------------------------------------------------------------
55
/// Failure modes surfaced by [`KvShardStore`].
///
/// Used as the store's error type (instead of `Infallible`) so that a failing
/// KV callback reaches the caller rather than being silently dropped:
///
/// - `IoError`: a KV callback reported a non-zero error code (disk full,
///   store closed, etc.).
/// - `Deserialization`: a blob read back from KV could not be decoded.
/// - `Serialization`: a shard or cap could not be encoded for writing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvError {
    /// A KV callback returned an error code (set, delete, or iterator failure).
    IoError,
    /// Shard or checkpoint data retrieved from KV could not be decoded.
    Deserialization,
    /// Shard or cap data could not be serialized before writing.
    Serialization,
}

impl fmt::Display for KvError {
    /// Writes the fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            KvError::IoError => "KV callback returned an error",
            KvError::Deserialization => "failed to deserialize KV data",
            KvError::Serialization => "failed to serialize data for KV",
        };
        f.write_str(message)
    }
}

impl std::error::Error for KvError {}
87
88// ---------------------------------------------------------------------------
89// KV key constants (must match keys.go 0x0F / 0x10 / 0x11)
90// ---------------------------------------------------------------------------
91
/// Prefix byte of every shard key.
const SHARD_PREFIX: u8 = 0x0F;
/// The single byte that forms the cap key.
const CAP_KEY: u8 = 0x10;
/// Prefix byte of every checkpoint key.
const CHECKPOINT_PREFIX: u8 = 0x11;

/// Builds the 9-byte shard key: `SHARD_PREFIX` followed by `index` in
/// big-endian byte order.
fn shard_key(index: u64) -> [u8; 9] {
    let mut key = [0u8; 9];
    let (prefix, suffix) = key.split_at_mut(1);
    prefix[0] = SHARD_PREFIX;
    suffix.copy_from_slice(&index.to_be_bytes());
    key
}

/// Builds the 1-byte cap key.
fn cap_key() -> [u8; 1] {
    [CAP_KEY; 1]
}

/// Builds the 5-byte checkpoint key: `CHECKPOINT_PREFIX` followed by `id` in
/// big-endian byte order.
fn checkpoint_key(id: u32) -> [u8; 5] {
    let mut key = [0u8; 5];
    let (prefix, suffix) = key.split_at_mut(1);
    prefix[0] = CHECKPOINT_PREFIX;
    suffix.copy_from_slice(&id.to_be_bytes());
    key
}
113
114// ---------------------------------------------------------------------------
115// Callback function pointer types
116// ---------------------------------------------------------------------------
117
/// Retrieve a value from the KV store.
///
/// `ctx` is the opaque context pointer; `key` / `key_len` describe the key to
/// look up.
///
/// On success (key found) writes a C-malloc'd buffer to `*out_val` and its
/// length to `*out_val_len`, then returns 0. The caller releases that buffer
/// with the `free_buf` callback ([`KvFreeBufFn`]).
/// Returns 1 if the key was not found (out pointers are unchanged).
/// Returns -1 on error.
pub type KvGetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Write a key-value pair. Returns 0 on success, -1 on error.
///
/// `key` and `val` point into caller-owned memory (pointer + length); the
/// callee must copy the bytes if it needs them after the call returns.
pub type KvSetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    val: *const u8,
    val_len: usize,
) -> i32;

/// Delete a key. Returns 0 on success, -1 on error.
///
/// `key` points into caller-owned memory (pointer + length).
pub type KvDeleteFn = unsafe extern "C" fn(ctx: *mut c_void, key: *const u8, key_len: usize)
    -> i32;

/// Create an iterator over the given prefix.
///
/// `reverse` is 1 for a reverse (descending) iterator, 0 for ascending.
/// Returns an opaque iterator handle, or null on error. The handle is passed
/// to [`KvIterNextFn`] and must eventually be released with [`KvIterFreeFn`].
pub type KvIterCreateFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    prefix: *const u8,
    prefix_len: usize,
    reverse: u8,
) -> *mut c_void;

/// Advance the iterator and return the next key-value pair as C-malloc'd
/// buffers. Caller frees with `free_buf`.
///
/// `iter` is a handle obtained from [`KvIterCreateFn`]. The key is written to
/// `*out_key` / `*out_key_len` and the value to `*out_val` / `*out_val_len`.
///
/// Returns 0 if a valid entry was written, 1 if exhausted, -1 on error.
pub type KvIterNextFn = unsafe extern "C" fn(
    iter: *mut c_void,
    out_key: *mut *mut u8,
    out_key_len: *mut usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Close and free an iterator handle.
///
/// `iter` is a handle obtained from [`KvIterCreateFn`].
pub type KvIterFreeFn = unsafe extern "C" fn(iter: *mut c_void);

/// Free a C-malloc'd buffer returned by a KV callback.
///
/// `ptr` / `len` are the pointer and length exactly as the callback reported
/// them.
pub type KvFreeBufFn = unsafe extern "C" fn(ptr: *mut u8, len: usize);
173
174// ---------------------------------------------------------------------------
175// KvCallbacks
176// ---------------------------------------------------------------------------
177
178/// Bundle of C function pointers + context passed to [`KvShardStore`].
179///
180/// # Safety
181/// All function pointers must remain valid for the lifetime of the
182/// `KvShardStore`. The `ctx` pointer must remain stable; Go achieves this
183/// via a `KvStoreProxy` whose address never changes across blocks.
184#[derive(Clone, Copy)]
185pub struct KvCallbacks {
186    pub ctx: *mut c_void,
187    pub get: KvGetFn,
188    pub set: KvSetFn,
189    pub delete: KvDeleteFn,
190    pub iter_create: KvIterCreateFn,
191    pub iter_next: KvIterNextFn,
192    pub iter_free: KvIterFreeFn,
193    pub free_buf: KvFreeBufFn,
194}
195
196// SAFETY: EndBlocker is single-threaded; all callbacks are called only on
197// the goroutine that owns the KV store.
198unsafe impl Send for KvCallbacks {}
199unsafe impl Sync for KvCallbacks {}
200
201// ---------------------------------------------------------------------------
202// Low-level helpers
203// ---------------------------------------------------------------------------
204
205impl KvCallbacks {
206    /// Fetch a value by key.
207    ///
208    /// Returns `Ok(Some(bytes))` if found, `Ok(None)` if not present, or
209    /// `Err(KvError::IoError)` if the callback signalled a hard error (rc=-1).
210    pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KvError> {
211        let mut out_ptr: *mut u8 = std::ptr::null_mut();
212        let mut out_len: usize = 0;
213        let rc = unsafe {
214            (self.get)(self.ctx, key.as_ptr(), key.len(), &mut out_ptr, &mut out_len)
215        };
216        match rc {
217            0 => {
218                let val = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
219                unsafe { (self.free_buf)(out_ptr, out_len) };
220                Ok(Some(val))
221            }
222            1 => Ok(None),        // not found
223            _ => Err(KvError::IoError), // rc=-1 or any other error code
224        }
225    }
226
227    /// Write a key-value pair. Returns `Err(KvError::IoError)` if the
228    /// callback returned a non-zero code.
229    pub fn set(&self, key: &[u8], val: &[u8]) -> Result<(), KvError> {
230        let rc = unsafe {
231            (self.set)(self.ctx, key.as_ptr(), key.len(), val.as_ptr(), val.len())
232        };
233        if rc != 0 {
234            Err(KvError::IoError)
235        } else {
236            Ok(())
237        }
238    }
239
240    /// Delete a key. Returns `Err(KvError::IoError)` if the callback failed.
241    pub fn delete(&self, key: &[u8]) -> Result<(), KvError> {
242        let rc = unsafe { (self.delete)(self.ctx, key.as_ptr(), key.len()) };
243        if rc != 0 {
244            Err(KvError::IoError)
245        } else {
246            Ok(())
247        }
248    }
249
250    /// Create a forward or reverse iterator over the given prefix.
251    fn iter(&self, prefix: &[u8], reverse: bool) -> KvIter<'_> {
252        let handle = unsafe {
253            (self.iter_create)(self.ctx, prefix.as_ptr(), prefix.len(), reverse as u8)
254        };
255        KvIter {
256            handle,
257            cb: self,
258        }
259    }
260}
261
262struct KvIter<'a> {
263    handle: *mut c_void,
264    cb: &'a KvCallbacks,
265}
266
267impl<'a> KvIter<'a> {
268    /// Advance and return `Some((key, value))`, or `None` when exhausted.
269    fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>)> {
270        if self.handle.is_null() {
271            return None;
272        }
273        let mut key_ptr: *mut u8 = std::ptr::null_mut();
274        let mut key_len: usize = 0;
275        let mut val_ptr: *mut u8 = std::ptr::null_mut();
276        let mut val_len: usize = 0;
277        let rc = unsafe {
278            (self.cb.iter_next)(
279                self.handle,
280                &mut key_ptr,
281                &mut key_len,
282                &mut val_ptr,
283                &mut val_len,
284            )
285        };
286        if rc != 0 {
287            return None;
288        }
289        let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len).to_vec() };
290        unsafe { (self.cb.free_buf)(key_ptr, key_len) };
291        let val = unsafe { std::slice::from_raw_parts(val_ptr, val_len).to_vec() };
292        unsafe { (self.cb.free_buf)(val_ptr, val_len) };
293        Some((key, val))
294    }
295}
296
297impl<'a> Drop for KvIter<'a> {
298    fn drop(&mut self) {
299        if !self.handle.is_null() {
300            unsafe { (self.cb.iter_free)(self.handle) };
301        }
302    }
303}
304
305// ---------------------------------------------------------------------------
306// KvShardStore
307// ---------------------------------------------------------------------------
308
309/// A [`ShardStore`] that stores all state in the Cosmos KV store via Go
310/// callbacks. Gives `ShardTree` true lazy loading: only the data it actually
311/// accesses is read from KV.
312pub struct KvShardStore {
313    pub(crate) cb: KvCallbacks,
314}
315
316impl KvShardStore {
317    pub fn new(cb: KvCallbacks) -> Self {
318        Self { cb }
319    }
320}
321
322// ---------------------------------------------------------------------------
323// ShardStore implementation
324// ---------------------------------------------------------------------------
325
326impl ShardStore for KvShardStore {
327    type H = MerkleHashVote;
328    type CheckpointId = u32;
329    type Error = KvError;
330
331    fn get_shard(
332        &self,
333        shard_root: Address,
334    ) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
335        let idx = shard_root.index();
336        let key = shard_key(idx);
337        let Some(blob) = self.cb.get(&key)? else {
338            return Ok(None);
339        };
340        match read_shard_vote(&blob) {
341            Ok(tree) => Ok(LocatedTree::from_parts(shard_root, tree).ok()),
342            Err(_) => Err(KvError::Deserialization),
343        }
344    }
345
346    fn last_shard(&self) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
347        let prefix = [SHARD_PREFIX];
348        let mut iter = self.cb.iter(&prefix, true /* reverse */);
349        let Some((key, val)) = iter.next() else {
350            return Ok(None);
351        };
352        if key.len() < 9 {
353            return Ok(None);
354        }
355        let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
356        let level = Level::from(SHARD_HEIGHT);
357        let addr = Address::from_parts(level, idx);
358        match read_shard_vote(&val) {
359            Ok(tree) => Ok(LocatedTree::from_parts(addr, tree).ok()),
360            Err(_) => Err(KvError::Deserialization),
361        }
362    }
363
364    fn put_shard(
365        &mut self,
366        subtree: LocatedPrunableTree<MerkleHashVote>,
367    ) -> Result<(), KvError> {
368        let idx = subtree.root_addr().index();
369        let key = shard_key(idx);
370        let blob = write_shard_vote(subtree.root()).map_err(|_| KvError::Serialization)?;
371        self.cb.set(&key, &blob)
372    }
373
374    fn get_shard_roots(&self) -> Result<Vec<Address>, KvError> {
375        let prefix = [SHARD_PREFIX];
376        let mut iter = self.cb.iter(&prefix, false);
377        let level = Level::from(SHARD_HEIGHT);
378        let mut roots = Vec::new();
379        while let Some((key, _)) = iter.next() {
380            if key.len() < 9 {
381                continue;
382            }
383            let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
384            roots.push(Address::from_parts(level, idx));
385        }
386        Ok(roots)
387    }
388
389    fn truncate_shards(&mut self, shard_index: u64) -> Result<(), KvError> {
390        let prefix = [SHARD_PREFIX];
391        let mut iter = self.cb.iter(&prefix, false);
392        let mut to_delete = Vec::new();
393        while let Some((key, _)) = iter.next() {
394            if key.len() < 9 {
395                continue;
396            }
397            let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
398            if idx >= shard_index {
399                to_delete.push(key);
400            }
401        }
402        drop(iter);
403        for key in to_delete {
404            self.cb.delete(&key)?;
405        }
406        Ok(())
407    }
408
409    fn get_cap(&self) -> Result<PrunableTree<MerkleHashVote>, KvError> {
410        let key = cap_key();
411        let Some(blob) = self.cb.get(&key)? else {
412            return Ok(Tree::empty());
413        };
414        read_shard_vote(&blob).map_err(|_| KvError::Deserialization)
415    }
416
417    fn put_cap(&mut self, cap: PrunableTree<MerkleHashVote>) -> Result<(), KvError> {
418        let key = cap_key();
419        let blob = write_shard_vote(&cap).map_err(|_| KvError::Serialization)?;
420        self.cb.set(&key, &blob)
421    }
422
423    fn min_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
424        let prefix = [CHECKPOINT_PREFIX];
425        let mut iter = self.cb.iter(&prefix, false);
426        Ok(iter.next().and_then(|(k, _)| {
427            if k.len() >= 5 {
428                Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
429            } else {
430                None
431            }
432        }))
433    }
434
435    fn max_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
436        let prefix = [CHECKPOINT_PREFIX];
437        let mut iter = self.cb.iter(&prefix, true /* reverse */);
438        Ok(iter.next().and_then(|(k, _)| {
439            if k.len() >= 5 {
440                Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
441            } else {
442                None
443            }
444        }))
445    }
446
447    fn add_checkpoint(
448        &mut self,
449        checkpoint_id: u32,
450        checkpoint: Checkpoint,
451    ) -> Result<(), KvError> {
452        let key = checkpoint_key(checkpoint_id);
453        let blob = write_checkpoint(&checkpoint);
454        self.cb.set(&key, &blob)
455    }
456
457    fn checkpoint_count(&self) -> Result<usize, KvError> {
458        let prefix = [CHECKPOINT_PREFIX];
459        let mut iter = self.cb.iter(&prefix, false);
460        let mut count = 0usize;
461        while iter.next().is_some() {
462            count += 1;
463        }
464        Ok(count)
465    }
466
467    fn get_checkpoint_at_depth(
468        &self,
469        checkpoint_depth: usize,
470    ) -> Result<Option<(u32, Checkpoint)>, KvError> {
471        let prefix = [CHECKPOINT_PREFIX];
472        let mut iter = self.cb.iter(&prefix, true /* reverse */);
473        let mut seen = 0usize;
474        while let Some((key, val)) = iter.next() {
475            if seen == checkpoint_depth {
476                if key.len() < 5 {
477                    return Ok(None);
478                }
479                let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
480                return Ok(read_checkpoint(&val).ok().map(|cp| (id, cp)));
481            }
482            seen += 1;
483        }
484        Ok(None)
485    }
486
487    fn get_checkpoint(&self, checkpoint_id: &u32) -> Result<Option<Checkpoint>, KvError> {
488        let key = checkpoint_key(*checkpoint_id);
489        let Some(blob) = self.cb.get(&key)? else {
490            return Ok(None);
491        };
492        Ok(read_checkpoint(&blob).ok())
493    }
494
495    fn with_checkpoints<F>(&mut self, limit: usize, mut callback: F) -> Result<(), KvError>
496    where
497        F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
498    {
499        let prefix = [CHECKPOINT_PREFIX];
500        let mut iter = self.cb.iter(&prefix, false);
501        let mut count = 0usize;
502        while count < limit {
503            let Some((key, val)) = iter.next() else {
504                break;
505            };
506            if key.len() < 5 {
507                continue;
508            }
509            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
510            if let Ok(cp) = read_checkpoint(&val) {
511                callback(&id, &cp)?;
512            }
513            count += 1;
514        }
515        Ok(())
516    }
517
518    fn for_each_checkpoint<F>(&self, limit: usize, mut callback: F) -> Result<(), KvError>
519    where
520        F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
521    {
522        let prefix = [CHECKPOINT_PREFIX];
523        let mut iter = self.cb.iter(&prefix, false);
524        let mut count = 0usize;
525        while count < limit {
526            let Some((key, val)) = iter.next() else {
527                break;
528            };
529            if key.len() < 5 {
530                continue;
531            }
532            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
533            if let Ok(cp) = read_checkpoint(&val) {
534                callback(&id, &cp)?;
535            }
536            count += 1;
537        }
538        Ok(())
539    }
540
541    fn update_checkpoint_with<F>(
542        &mut self,
543        checkpoint_id: &u32,
544        update: F,
545    ) -> Result<bool, KvError>
546    where
547        F: Fn(&mut Checkpoint) -> Result<(), KvError>,
548    {
549        let key = checkpoint_key(*checkpoint_id);
550        let Some(blob) = self.cb.get(&key)? else {
551            return Ok(false);
552        };
553        let Ok(mut cp) = read_checkpoint(&blob) else {
554            return Ok(false);
555        };
556        update(&mut cp)?;
557        let new_blob = write_checkpoint(&cp);
558        self.cb.set(&key, &new_blob)?;
559        Ok(true)
560    }
561
562    fn remove_checkpoint(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
563        let key = checkpoint_key(*checkpoint_id);
564        self.cb.delete(&key)
565    }
566
567    fn truncate_checkpoints_retaining(
568        &mut self,
569        checkpoint_id: &u32,
570    ) -> Result<(), KvError> {
571        // Delete all checkpoints with id < checkpoint_id; clear marks_removed
572        // on the retained checkpoint itself (matches MemoryShardStore semantics).
573        let prefix = [CHECKPOINT_PREFIX];
574        let mut iter = self.cb.iter(&prefix, false);
575        let mut to_delete = Vec::new();
576        while let Some((key, _)) = iter.next() {
577            if key.len() < 5 {
578                continue;
579            }
580            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
581            if id < *checkpoint_id {
582                to_delete.push(key);
583            } else {
584                break;
585            }
586        }
587        drop(iter);
588        for key in to_delete {
589            self.cb.delete(&key)?;
590        }
591        // Clear marks_removed on the retaining checkpoint.
592        let retain_key = checkpoint_key(*checkpoint_id);
593        if let Some(blob) = self.cb.get(&retain_key)? {
594            if let Ok(cp) = read_checkpoint(&blob) {
595                let cleared = Checkpoint::from_parts(cp.tree_state(), BTreeSet::new());
596                self.cb.set(&retain_key, &write_checkpoint(&cleared))?;
597            }
598        }
599        Ok(())
600    }
601}