Skip to main content

vote_commitment_tree/
kv_shard_store.rs

1//! [`KvShardStore`] — a [`ShardStore`] implementation backed by Go's Cosmos KV
2//! store via C function pointer callbacks.
3//!
4//! # Design
5//!
6//! Instead of maintaining an in-process copy of all shard data,
7//! `KvShardStore` forwards every [`ShardStore`] read and write directly to
8//! the Cosmos KV store through a set of C callbacks registered at creation
9//! time. Go registers `//export` functions that dispatch to the current
10//! block's `store.KVStore` through a stable proxy pointer.
11//!
12//! This gives `ShardTree` true lazy loading: on a cold start only the data
13//! that is actually accessed (the frontier shard + cap + checkpoints) is read.
14//! No explicit restore loop, no O(n) blob loading, no shard geometry in Go.
15//!
16//! # KV key schema (matches keys.go)
17//!
18//! | Prefix    | Key                              | Value           |
19//! |-----------|----------------------------------|-----------------|
20//! | `0x0F`    | `0x0F \|\| u64 BE shard_index`   | shard blob      |
21//! | `0x10`    | `0x10`                           | cap blob        |
22//! | `0x11`    | `0x11 \|\| u32 BE checkpoint_id` | checkpoint blob |
23//!
24//! # Buffer ownership
25//!
26//! `get` returns a C-malloc'd buffer that Rust frees with the provided
27//! `free_buf` callback after copying the value. All write callbacks receive
28//! a Rust-owned slice (pointer + length); they must copy the data if they
29//! need it to outlive the call.
30//!
31//! # Iterator protocol
32//!
33//! `iter_create(ctx, prefix, prefix_len, reverse)` returns an opaque handle
34//! (a `cgo.Handle` on the Go side). `iter_next` advances and writes
35//! C-malloc'd key + value; Rust frees each pair with `free_buf` before the
36//! next call. `iter_free` closes and drops the iterator. `iter_next` returns
37//! 0 on a valid entry, 1 when exhausted, -1 on error.
38
39use std::collections::BTreeSet;
40use std::fmt;
41use std::os::raw::c_void;
42
43use incrementalmerkletree::{Address, Level};
44use shardtree::{
45    store::{Checkpoint, ShardStore},
46    LocatedPrunableTree, LocatedTree, PrunableTree, Tree,
47};
48
49use crate::hash::{MerkleHashVote, SHARD_HEIGHT};
50use crate::serde::{read_checkpoint, read_shard_vote, write_checkpoint, write_shard_vote};
51
52// ---------------------------------------------------------------------------
53// KvError
54// ---------------------------------------------------------------------------
55
/// Failure modes surfaced by [`KvShardStore`] and the [`KvCallbacks`] helpers.
///
/// Used as the store's error type so that a failing KV callback or an
/// undecodable blob reaches the caller instead of being hidden.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvError {
    /// A KV callback reported failure through its return code.
    IoError,
    /// A blob read back from the KV store could not be decoded.
    Deserialization,
    /// A shard or cap could not be encoded prior to being written.
    Serialization,
}
75
76impl fmt::Display for KvError {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        match self {
79            KvError::IoError => write!(f, "KV callback returned an error"),
80            KvError::Deserialization => write!(f, "failed to deserialize KV data"),
81            KvError::Serialization => write!(f, "failed to serialize data for KV"),
82        }
83    }
84}
85
86impl std::error::Error for KvError {}
87
88// ---------------------------------------------------------------------------
89// KV key constants (must match keys.go 0x0F / 0x10 / 0x11)
90// ---------------------------------------------------------------------------
91
/// Leading byte of every shard key (`0x0F || u64 BE shard_index`).
const SHARD_PREFIX: u8 = 0x0F;
/// The complete, single-byte key under which the cap blob is stored.
const CAP_KEY: u8 = 0x10;
/// Leading byte of every checkpoint key (`0x11 || u32 BE checkpoint_id`).
const CHECKPOINT_PREFIX: u8 = 0x11;
95
96fn shard_key(index: u64) -> [u8; 9] {
97    let mut k = [0u8; 9];
98    k[0] = SHARD_PREFIX;
99    k[1..].copy_from_slice(&index.to_be_bytes());
100    k
101}
102
103fn cap_key() -> [u8; 1] {
104    [CAP_KEY]
105}
106
107fn checkpoint_key(id: u32) -> [u8; 5] {
108    let mut k = [0u8; 5];
109    k[0] = CHECKPOINT_PREFIX;
110    k[1..].copy_from_slice(&id.to_be_bytes());
111    k
112}
113
114// ---------------------------------------------------------------------------
115// Callback function pointer types
116// ---------------------------------------------------------------------------
117
/// Callback that looks up `key` (`key_len` bytes) in the KV store.
///
/// Return codes:
/// - `0`: found. A C-malloc'd buffer is written to `*out_val` and its length
///   to `*out_val_len`.
/// - `1`: key not present; the out pointers are left unchanged.
/// - `-1`: error.
pub type KvGetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Callback that stores `val` under `key`. Yields 0 on success, -1 on error.
pub type KvSetFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
    val: *const u8,
    val_len: usize,
) -> i32;

/// Callback that removes `key`. Yields 0 on success, -1 on error.
pub type KvDeleteFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    key: *const u8,
    key_len: usize,
) -> i32;

/// Callback that opens an iterator over all keys starting with `prefix`.
///
/// Pass `reverse = 1` for descending order and `reverse = 0` for ascending.
/// The result is an opaque iterator handle; null signals an error.
pub type KvIterCreateFn = unsafe extern "C" fn(
    ctx: *mut c_void,
    prefix: *const u8,
    prefix_len: usize,
    reverse: u8,
) -> *mut c_void;

/// Callback that steps an iterator and hands back the next key-value pair
/// in C-malloc'd buffers, which the caller releases with `free_buf`.
///
/// Yields 0 when an entry was written, 1 when the iterator is exhausted,
/// and -1 on error.
pub type KvIterNextFn = unsafe extern "C" fn(
    iter: *mut c_void,
    out_key: *mut *mut u8,
    out_key_len: *mut usize,
    out_val: *mut *mut u8,
    out_val_len: *mut usize,
) -> i32;

/// Callback that closes an iterator handle and releases it.
pub type KvIterFreeFn = unsafe extern "C" fn(iter: *mut c_void);

/// Callback that releases a C-malloc'd buffer handed out by a KV callback.
pub type KvFreeBufFn = unsafe extern "C" fn(ptr: *mut u8, len: usize);
172
173// ---------------------------------------------------------------------------
174// KvCallbacks
175// ---------------------------------------------------------------------------
176
177/// Bundle of C function pointers + context passed to [`KvShardStore`].
178///
179/// # Safety
180/// All function pointers must remain valid for the lifetime of the
181/// `KvShardStore`. The `ctx` pointer must remain stable; Go achieves this
182/// via a `KvStoreProxy` whose address never changes across blocks.
183#[derive(Clone, Copy)]
184pub struct KvCallbacks {
185    pub ctx: *mut c_void,
186    pub get: KvGetFn,
187    pub set: KvSetFn,
188    pub delete: KvDeleteFn,
189    pub iter_create: KvIterCreateFn,
190    pub iter_next: KvIterNextFn,
191    pub iter_free: KvIterFreeFn,
192    pub free_buf: KvFreeBufFn,
193}
194
195// SAFETY: EndBlocker is single-threaded; all callbacks are called only on
196// the goroutine that owns the KV store.
197unsafe impl Send for KvCallbacks {}
198unsafe impl Sync for KvCallbacks {}
199
200// ---------------------------------------------------------------------------
201// Low-level helpers
202// ---------------------------------------------------------------------------
203
204impl KvCallbacks {
205    /// Fetch a value by key.
206    ///
207    /// Returns `Ok(Some(bytes))` if found, `Ok(None)` if not present, or
208    /// `Err(KvError::IoError)` if the callback signalled a hard error (rc=-1).
209    pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, KvError> {
210        let mut out_ptr: *mut u8 = std::ptr::null_mut();
211        let mut out_len: usize = 0;
212        let rc = unsafe {
213            (self.get)(
214                self.ctx,
215                key.as_ptr(),
216                key.len(),
217                &mut out_ptr,
218                &mut out_len,
219            )
220        };
221        match rc {
222            0 => {
223                let val = unsafe { std::slice::from_raw_parts(out_ptr, out_len).to_vec() };
224                unsafe { (self.free_buf)(out_ptr, out_len) };
225                Ok(Some(val))
226            }
227            1 => Ok(None),              // not found
228            _ => Err(KvError::IoError), // rc=-1 or any other error code
229        }
230    }
231
232    /// Write a key-value pair. Returns `Err(KvError::IoError)` if the
233    /// callback returned a non-zero code.
234    pub fn set(&self, key: &[u8], val: &[u8]) -> Result<(), KvError> {
235        let rc = unsafe { (self.set)(self.ctx, key.as_ptr(), key.len(), val.as_ptr(), val.len()) };
236        if rc != 0 {
237            Err(KvError::IoError)
238        } else {
239            Ok(())
240        }
241    }
242
243    /// Delete a key. Returns `Err(KvError::IoError)` if the callback failed.
244    pub fn delete(&self, key: &[u8]) -> Result<(), KvError> {
245        let rc = unsafe { (self.delete)(self.ctx, key.as_ptr(), key.len()) };
246        if rc != 0 {
247            Err(KvError::IoError)
248        } else {
249            Ok(())
250        }
251    }
252
253    /// Create a forward or reverse iterator over the given prefix.
254    fn iter(&self, prefix: &[u8], reverse: bool) -> KvIter<'_> {
255        let handle =
256            unsafe { (self.iter_create)(self.ctx, prefix.as_ptr(), prefix.len(), reverse as u8) };
257        KvIter { handle, cb: self }
258    }
259}
260
261struct KvIter<'a> {
262    handle: *mut c_void,
263    cb: &'a KvCallbacks,
264}
265
266impl<'a> KvIter<'a> {
267    /// Advance and return `Some((key, value))`, or `None` when exhausted.
268    fn next(&mut self) -> Option<(Vec<u8>, Vec<u8>)> {
269        if self.handle.is_null() {
270            return None;
271        }
272        let mut key_ptr: *mut u8 = std::ptr::null_mut();
273        let mut key_len: usize = 0;
274        let mut val_ptr: *mut u8 = std::ptr::null_mut();
275        let mut val_len: usize = 0;
276        let rc = unsafe {
277            (self.cb.iter_next)(
278                self.handle,
279                &mut key_ptr,
280                &mut key_len,
281                &mut val_ptr,
282                &mut val_len,
283            )
284        };
285        if rc != 0 {
286            return None;
287        }
288        let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len).to_vec() };
289        unsafe { (self.cb.free_buf)(key_ptr, key_len) };
290        let val = unsafe { std::slice::from_raw_parts(val_ptr, val_len).to_vec() };
291        unsafe { (self.cb.free_buf)(val_ptr, val_len) };
292        Some((key, val))
293    }
294}
295
296impl<'a> Drop for KvIter<'a> {
297    fn drop(&mut self) {
298        if !self.handle.is_null() {
299            unsafe { (self.cb.iter_free)(self.handle) };
300        }
301    }
302}
303
304// ---------------------------------------------------------------------------
305// KvShardStore
306// ---------------------------------------------------------------------------
307
308/// A [`ShardStore`] that stores all state in the Cosmos KV store via Go
309/// callbacks. Gives `ShardTree` true lazy loading: only the data it actually
310/// accesses is read from KV.
311pub struct KvShardStore {
312    pub(crate) cb: KvCallbacks,
313}
314
315impl KvShardStore {
316    pub fn new(cb: KvCallbacks) -> Self {
317        Self { cb }
318    }
319}
320
321// ---------------------------------------------------------------------------
322// ShardStore implementation
323// ---------------------------------------------------------------------------
324
325impl ShardStore for KvShardStore {
326    type H = MerkleHashVote;
327    type CheckpointId = u32;
328    type Error = KvError;
329
330    fn get_shard(
331        &self,
332        shard_root: Address,
333    ) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
334        let idx = shard_root.index();
335        let key = shard_key(idx);
336        let Some(blob) = self.cb.get(&key)? else {
337            return Ok(None);
338        };
339        match read_shard_vote(&blob) {
340            Ok(tree) => Ok(LocatedTree::from_parts(shard_root, tree).ok()),
341            Err(_) => Err(KvError::Deserialization),
342        }
343    }
344
345    fn last_shard(&self) -> Result<Option<LocatedPrunableTree<MerkleHashVote>>, KvError> {
346        let prefix = [SHARD_PREFIX];
347        let mut iter = self.cb.iter(&prefix, true /* reverse */);
348        let Some((key, val)) = iter.next() else {
349            return Ok(None);
350        };
351        if key.len() < 9 {
352            return Ok(None);
353        }
354        let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
355        let level = Level::from(SHARD_HEIGHT);
356        let addr = Address::from_parts(level, idx);
357        match read_shard_vote(&val) {
358            Ok(tree) => Ok(LocatedTree::from_parts(addr, tree).ok()),
359            Err(_) => Err(KvError::Deserialization),
360        }
361    }
362
363    fn put_shard(&mut self, subtree: LocatedPrunableTree<MerkleHashVote>) -> Result<(), KvError> {
364        let idx = subtree.root_addr().index();
365        let key = shard_key(idx);
366        let blob = write_shard_vote(subtree.root()).map_err(|_| KvError::Serialization)?;
367        self.cb.set(&key, &blob)
368    }
369
370    fn get_shard_roots(&self) -> Result<Vec<Address>, KvError> {
371        let prefix = [SHARD_PREFIX];
372        let mut iter = self.cb.iter(&prefix, false);
373        let level = Level::from(SHARD_HEIGHT);
374        let mut roots = Vec::new();
375        while let Some((key, _)) = iter.next() {
376            if key.len() < 9 {
377                continue;
378            }
379            let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
380            roots.push(Address::from_parts(level, idx));
381        }
382        Ok(roots)
383    }
384
385    fn truncate_shards(&mut self, shard_index: u64) -> Result<(), KvError> {
386        let prefix = [SHARD_PREFIX];
387        let mut iter = self.cb.iter(&prefix, false);
388        let mut to_delete = Vec::new();
389        while let Some((key, _)) = iter.next() {
390            if key.len() < 9 {
391                continue;
392            }
393            let idx = u64::from_be_bytes(key[1..9].try_into().unwrap());
394            if idx >= shard_index {
395                to_delete.push(key);
396            }
397        }
398        drop(iter);
399        for key in to_delete {
400            self.cb.delete(&key)?;
401        }
402        Ok(())
403    }
404
405    fn get_cap(&self) -> Result<PrunableTree<MerkleHashVote>, KvError> {
406        let key = cap_key();
407        let Some(blob) = self.cb.get(&key)? else {
408            return Ok(Tree::empty());
409        };
410        read_shard_vote(&blob).map_err(|_| KvError::Deserialization)
411    }
412
413    fn put_cap(&mut self, cap: PrunableTree<MerkleHashVote>) -> Result<(), KvError> {
414        let key = cap_key();
415        let blob = write_shard_vote(&cap).map_err(|_| KvError::Serialization)?;
416        self.cb.set(&key, &blob)
417    }
418
419    fn min_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
420        let prefix = [CHECKPOINT_PREFIX];
421        let mut iter = self.cb.iter(&prefix, false);
422        Ok(iter.next().and_then(|(k, _)| {
423            if k.len() >= 5 {
424                Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
425            } else {
426                None
427            }
428        }))
429    }
430
431    fn max_checkpoint_id(&self) -> Result<Option<u32>, KvError> {
432        let prefix = [CHECKPOINT_PREFIX];
433        let mut iter = self.cb.iter(&prefix, true /* reverse */);
434        Ok(iter.next().and_then(|(k, _)| {
435            if k.len() >= 5 {
436                Some(u32::from_be_bytes(k[1..5].try_into().unwrap()))
437            } else {
438                None
439            }
440        }))
441    }
442
443    fn add_checkpoint(
444        &mut self,
445        checkpoint_id: u32,
446        checkpoint: Checkpoint,
447    ) -> Result<(), KvError> {
448        let key = checkpoint_key(checkpoint_id);
449        let blob = write_checkpoint(&checkpoint);
450        self.cb.set(&key, &blob)
451    }
452
453    fn checkpoint_count(&self) -> Result<usize, KvError> {
454        let prefix = [CHECKPOINT_PREFIX];
455        let mut iter = self.cb.iter(&prefix, false);
456        let mut count = 0usize;
457        while iter.next().is_some() {
458            count += 1;
459        }
460        Ok(count)
461    }
462
463    fn get_checkpoint_at_depth(
464        &self,
465        checkpoint_depth: usize,
466    ) -> Result<Option<(u32, Checkpoint)>, KvError> {
467        let prefix = [CHECKPOINT_PREFIX];
468        let mut iter = self.cb.iter(&prefix, true /* reverse */);
469        let mut seen = 0usize;
470        while let Some((key, val)) = iter.next() {
471            if seen == checkpoint_depth {
472                if key.len() < 5 {
473                    return Ok(None);
474                }
475                let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
476                return Ok(read_checkpoint(&val).ok().map(|cp| (id, cp)));
477            }
478            seen += 1;
479        }
480        Ok(None)
481    }
482
483    fn get_checkpoint(&self, checkpoint_id: &u32) -> Result<Option<Checkpoint>, KvError> {
484        let key = checkpoint_key(*checkpoint_id);
485        let Some(blob) = self.cb.get(&key)? else {
486            return Ok(None);
487        };
488        Ok(read_checkpoint(&blob).ok())
489    }
490
491    fn with_checkpoints<F>(&mut self, limit: usize, mut callback: F) -> Result<(), KvError>
492    where
493        F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
494    {
495        let prefix = [CHECKPOINT_PREFIX];
496        let mut iter = self.cb.iter(&prefix, false);
497        let mut count = 0usize;
498        while count < limit {
499            let Some((key, val)) = iter.next() else {
500                break;
501            };
502            if key.len() < 5 {
503                continue;
504            }
505            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
506            if let Ok(cp) = read_checkpoint(&val) {
507                callback(&id, &cp)?;
508            }
509            count += 1;
510        }
511        Ok(())
512    }
513
514    fn for_each_checkpoint<F>(&self, limit: usize, mut callback: F) -> Result<(), KvError>
515    where
516        F: FnMut(&u32, &Checkpoint) -> Result<(), KvError>,
517    {
518        let prefix = [CHECKPOINT_PREFIX];
519        let mut iter = self.cb.iter(&prefix, false);
520        let mut count = 0usize;
521        while count < limit {
522            let Some((key, val)) = iter.next() else {
523                break;
524            };
525            if key.len() < 5 {
526                continue;
527            }
528            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
529            if let Ok(cp) = read_checkpoint(&val) {
530                callback(&id, &cp)?;
531            }
532            count += 1;
533        }
534        Ok(())
535    }
536
537    fn update_checkpoint_with<F>(&mut self, checkpoint_id: &u32, update: F) -> Result<bool, KvError>
538    where
539        F: Fn(&mut Checkpoint) -> Result<(), KvError>,
540    {
541        let key = checkpoint_key(*checkpoint_id);
542        let Some(blob) = self.cb.get(&key)? else {
543            return Ok(false);
544        };
545        let Ok(mut cp) = read_checkpoint(&blob) else {
546            return Ok(false);
547        };
548        update(&mut cp)?;
549        let new_blob = write_checkpoint(&cp);
550        self.cb.set(&key, &new_blob)?;
551        Ok(true)
552    }
553
554    fn remove_checkpoint(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
555        let key = checkpoint_key(*checkpoint_id);
556        self.cb.delete(&key)
557    }
558
559    fn truncate_checkpoints_retaining(&mut self, checkpoint_id: &u32) -> Result<(), KvError> {
560        // Delete all checkpoints with id < checkpoint_id; clear marks_removed
561        // on the retained checkpoint itself (matches MemoryShardStore semantics).
562        let prefix = [CHECKPOINT_PREFIX];
563        let mut iter = self.cb.iter(&prefix, false);
564        let mut to_delete = Vec::new();
565        while let Some((key, _)) = iter.next() {
566            if key.len() < 5 {
567                continue;
568            }
569            let id = u32::from_be_bytes(key[1..5].try_into().unwrap());
570            if id < *checkpoint_id {
571                to_delete.push(key);
572            } else {
573                break;
574            }
575        }
576        drop(iter);
577        for key in to_delete {
578            self.cb.delete(&key)?;
579        }
580        // Clear marks_removed on the retaining checkpoint.
581        let retain_key = checkpoint_key(*checkpoint_id);
582        if let Some(blob) = self.cb.get(&retain_key)? {
583            if let Ok(cp) = read_checkpoint(&blob) {
584                let cleared = Checkpoint::from_parts(cp.tree_state(), BTreeSet::new());
585                self.cb.set(&retain_key, &write_checkpoint(&cleared))?;
586            }
587        }
588        Ok(())
589    }
590}