Skip to main content

vote_commitment_tree/
serde.rs

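//! Binary serialization for shard trees and checkpoints.
//!
//! Shared between [`crate::kv_shard_store`] and any future persistence layer.
//!
//! # Shard format
//!
//! ```text
//! [version: u8 = 1]
//! [tree: node]
//!
//! node :=
//!   0x00                                -- Nil
//!   0x01 [hash: 32 bytes] [flags: u8]   -- Leaf (flags = RetentionFlags bits)
//!   0x02 [has_ann: u8] [hash: 32 bytes if has_ann=1] [left: node] [right: node]  -- Parent
//! ```
//!
//! Nodes are written in pre-order: a parent's tag and optional annotation come first, followed
//! by its complete left subtree and then its complete right subtree. The `has_ann` byte is
//! written as 0 or 1.
//!
//! # Checkpoint format
//!
//! ```text
//! [has_position: u8]  [position: u64 LE if has_position=1]
//! [marks_count: u32 LE]  [mark_position: u64 LE] × marks_count
//! ```
//!
//! The `has_position` byte is written as 0 or 1; marks are written in ascending position order.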
23
24use std::collections::BTreeSet;
25use std::io::{self, Cursor, Read, Write};
26use std::ops::Deref;
27use std::sync::Arc;
28
29use incrementalmerkletree::Position;
30use shardtree::{
31    store::{Checkpoint, TreeState},
32    Node, PrunableTree, RetentionFlags, Tree,
33};
34
35use crate::hash::MerkleHashVote;
36
/// Format version stored as the first byte of every shard blob; `read_shard_vote` rejects
/// any other value.
const SHARD_SER_VERSION: u8 = 1;

// Node tags: the first byte of each encoded node (see the shard format in the module docs).
/// Tag for an empty (`Nil`) node, which carries no payload.
const NODE_NIL: u8 = 0x00;
/// Tag for a leaf: followed by a 32-byte hash and a one-byte retention-flags field.
const NODE_LEAF: u8 = 0x01;
/// Tag for a parent: followed by an annotation flag, an optional hash, then both subtrees.
const NODE_PARENT: u8 = 0x02;
41
42fn write_hash<W: Write>(w: &mut W, h: &MerkleHashVote) -> io::Result<()> {
43    w.write_all(&h.to_bytes())
44}
45
46fn write_node<W: Write>(w: &mut W, tree: &PrunableTree<MerkleHashVote>) -> io::Result<()> {
47    match tree.deref() {
48        Node::Parent { ann, left, right } => {
49            w.write_all(&[NODE_PARENT])?;
50            match ann.as_ref() {
51                None => w.write_all(&[0u8])?,
52                Some(h) => {
53                    w.write_all(&[1u8])?;
54                    write_hash(w, h)?;
55                }
56            }
57            write_node(w, left)?;
58            write_node(w, right)?;
59            Ok(())
60        }
61        Node::Leaf { value } => {
62            w.write_all(&[NODE_LEAF])?;
63            write_hash(w, &value.0)?;
64            w.write_all(&[value.1.bits()])?;
65            Ok(())
66        }
67        Node::Nil => {
68            w.write_all(&[NODE_NIL])?;
69            Ok(())
70        }
71    }
72}
73
74/// Serialize a `PrunableTree<MerkleHashVote>` to a versioned blob.
75pub fn write_shard_vote(tree: &PrunableTree<MerkleHashVote>) -> io::Result<Vec<u8>> {
76    let mut buf = Vec::new();
77    buf.push(SHARD_SER_VERSION);
78    write_node(&mut buf, tree)?;
79    Ok(buf)
80}
81
82fn read_hash<R: Read>(r: &mut R) -> io::Result<MerkleHashVote> {
83    let mut bytes = [0u8; 32];
84    r.read_exact(&mut bytes)?;
85    MerkleHashVote::from_bytes(&bytes)
86        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid MerkleHashVote"))
87}
88
89fn read_node<R: Read>(r: &mut R) -> io::Result<PrunableTree<MerkleHashVote>> {
90    let mut tag = [0u8; 1];
91    r.read_exact(&mut tag)?;
92    match tag[0] {
93        NODE_NIL => Ok(Tree::empty()),
94        NODE_LEAF => {
95            let hash = read_hash(r)?;
96            let mut flag = [0u8; 1];
97            r.read_exact(&mut flag)?;
98            let flags = RetentionFlags::from_bits_truncate(flag[0]);
99            Ok(Tree::leaf((hash, flags)))
100        }
101        NODE_PARENT => {
102            let mut ann_flag = [0u8; 1];
103            r.read_exact(&mut ann_flag)?;
104            let ann = if ann_flag[0] == 1 {
105                Some(Arc::new(read_hash(r)?))
106            } else {
107                None
108            };
109            let left = read_node(r)?;
110            let right = read_node(r)?;
111            Ok(Tree::parent(ann, left, right))
112        }
113        t => Err(io::Error::new(
114            io::ErrorKind::InvalidData,
115            format!("unknown node tag: {t}"),
116        )),
117    }
118}
119
120/// Deserialize a shard blob produced by [`write_shard_vote`].
121pub fn read_shard_vote(data: &[u8]) -> io::Result<PrunableTree<MerkleHashVote>> {
122    let mut cur = Cursor::new(data);
123    let mut version = [0u8; 1];
124    cur.read_exact(&mut version)?;
125    if version[0] != SHARD_SER_VERSION {
126        return Err(io::Error::new(
127            io::ErrorKind::InvalidData,
128            format!("unknown shard version: {}", version[0]),
129        ));
130    }
131    read_node(&mut cur)
132}
133
134/// Serialize a [`Checkpoint`] to bytes.
135pub fn write_checkpoint(cp: &Checkpoint) -> Vec<u8> {
136    let mut buf = Vec::new();
137    match cp.position() {
138        None => buf.push(0u8),
139        Some(pos) => {
140            buf.push(1u8);
141            buf.extend_from_slice(&u64::from(pos).to_le_bytes());
142        }
143    }
144    let marks: Vec<u64> = cp.marks_removed().iter().map(|p| u64::from(*p)).collect();
145    let count = marks.len() as u32;
146    buf.extend_from_slice(&count.to_le_bytes());
147    for m in marks {
148        buf.extend_from_slice(&m.to_le_bytes());
149    }
150    buf
151}
152
153/// Deserialize a checkpoint blob produced by [`write_checkpoint`].
154pub fn read_checkpoint(data: &[u8]) -> io::Result<Checkpoint> {
155    let mut cur = Cursor::new(data);
156
157    let mut flag = [0u8; 1];
158    cur.read_exact(&mut flag)?;
159    let tree_state = if flag[0] == 0 {
160        TreeState::Empty
161    } else {
162        let mut pos_bytes = [0u8; 8];
163        cur.read_exact(&mut pos_bytes)?;
164        TreeState::AtPosition(Position::from(u64::from_le_bytes(pos_bytes)))
165    };
166
167    let mut count_bytes = [0u8; 4];
168    cur.read_exact(&mut count_bytes)?;
169    let count = u32::from_le_bytes(count_bytes) as usize;
170
171    let mut marks = BTreeSet::new();
172    for _ in 0..count {
173        let mut pos_bytes = [0u8; 8];
174        cur.read_exact(&mut pos_bytes)?;
175        marks.insert(Position::from(u64::from_le_bytes(pos_bytes)));
176    }
177    Ok(Checkpoint::from_parts(tree_state, marks))
178}