Skip to main content

treeship_core/merkle/
tree.rs

1use sha2::{Sha256, Digest};
2use serde::{Deserialize, Serialize};
3
4/// Direction of a sibling in the Merkle proof path.
5#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
6pub enum Direction {
7    Left,
8    Right,
9}
10
11/// One step in an inclusion proof path.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ProofStep {
14    pub direction: Direction,
15    /// Hex-encoded hash of the sibling node.
16    pub hash: String,
17}
18
/// Merkle tree algorithm identifier for legacy (v1) proofs.
///
/// `MerkleTree::verify_proof` assumes this value when a proof carries no
/// `algorithm` field.
pub const MERKLE_ALGORITHM_V1: &str = "sha256-duplicate-last";
/// Merkle tree algorithm identifier for v2 proofs.
///
/// `MerkleTree::inclusion_proof` stamps this value on every proof it builds.
pub const MERKLE_ALGORITHM_V2: &str = "sha256-rfc9162";
22
/// An inclusion proof: the sibling path from a leaf to the root.
///
/// The type derives serde's `Serialize`/`Deserialize`, so the field names
/// below are what appears in the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InclusionProof {
    /// Zero-based index of the leaf the proof was generated for.
    pub leaf_index: usize,
    /// Hex-encoded leaf hash.
    pub leaf_hash: String,
    /// Sibling steps ordered from the leaf level upward toward the root.
    pub path: Vec<ProofStep>,
    /// Algorithm used to build this proof. Missing = v1 (duplicate-last).
    ///
    /// `default` lets input without this field deserialize to `None`, and
    /// `skip_serializing_if` omits the field from output when it is `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub algorithm: Option<String>,
}
34
/// An append-only binary Merkle tree over SHA-256.
///
/// Leaves are `sha256(artifact_id)`. When a level has an odd number of nodes,
/// the unpaired last node is promoted to the next level without hashing,
/// which yields the same tree *shape* as RFC 9162 (Certificate Transparency).
///
/// NOTE(review): only the shape follows RFC 9162. Leaf and interior hashes
/// here carry no `0x00`/`0x01` domain-separation prefix, so roots are not
/// interchangeable with RFC 9162 tree heads -- confirm this is intended.
///
/// `Default` produces the same empty tree as `MerkleTree::new()`.
#[derive(Debug, Clone, Default)]
pub struct MerkleTree {
    /// All leaf hashes in insertion order.
    leaves: Vec<[u8; 32]>,
}
44
45impl MerkleTree {
46    pub fn new() -> Self {
47        Self { leaves: Vec::new() }
48    }
49
50    /// Append an artifact ID as a new leaf. Returns the leaf index.
51    pub fn append(&mut self, artifact_id: &str) -> usize {
52        let hash = Sha256::digest(artifact_id.as_bytes());
53        self.leaves.push(hash.into());
54        self.leaves.len() - 1
55    }
56
57    /// Number of leaves in the tree.
58    pub fn len(&self) -> usize {
59        self.leaves.len()
60    }
61
62    /// Whether the tree is empty.
63    pub fn is_empty(&self) -> bool {
64        self.leaves.is_empty()
65    }
66
67    /// Compute the root hash. Returns `None` for an empty tree.
68    pub fn root(&self) -> Option<[u8; 32]> {
69        if self.leaves.is_empty() {
70            return None;
71        }
72        Some(self.compute_root(&self.leaves))
73    }
74
75    /// Height of the tree: ceil(log2(n_leaves)).
76    pub fn height(&self) -> usize {
77        if self.leaves.len() <= 1 {
78            return 0;
79        }
80        (self.leaves.len() as f64).log2().ceil() as usize
81    }
82
83    /// Generate an inclusion proof for the leaf at `leaf_index`.
84    pub fn inclusion_proof(&self, leaf_index: usize) -> Option<InclusionProof> {
85        if leaf_index >= self.leaves.len() {
86            return None;
87        }
88
89        let mut path = Vec::new();
90        let mut idx = leaf_index;
91        let mut level: Vec<[u8; 32]> = self.leaves.clone();
92
93        while level.len() > 1 {
94            // RFC 9162: if idx has a sibling, add it to the proof path.
95            // If idx is the unpaired last node, it promotes without a sibling step.
96            if idx + 1 < level.len() && idx % 2 == 0 {
97                // Sibling is to the right
98                path.push(ProofStep {
99                    direction: Direction::Right,
100                    hash: hex::encode(level[idx + 1]),
101                });
102            } else if idx % 2 == 1 {
103                // Sibling is to the left
104                path.push(ProofStep {
105                    direction: Direction::Left,
106                    hash: hex::encode(level[idx - 1]),
107                });
108            }
109            // else: unpaired last node, no sibling step needed
110
111            // Move up: compute parent hashes (RFC 9162 promotion)
112            let mut next_level = Vec::with_capacity((level.len() + 1) / 2);
113            let mut i = 0;
114            while i + 1 < level.len() {
115                let mut h = Sha256::new();
116                h.update(level[i]);
117                h.update(level[i + 1]);
118                next_level.push(h.finalize().into());
119                i += 2;
120            }
121            if i < level.len() {
122                next_level.push(level[i]);
123            }
124            level = next_level;
125
126            idx /= 2;
127        }
128
129        Some(InclusionProof {
130            leaf_index,
131            leaf_hash: hex::encode(self.leaves[leaf_index]),
132            path,
133            algorithm: Some(MERKLE_ALGORITHM_V2.to_string()),
134        })
135    }
136
137    /// Verify an inclusion proof against a hex-encoded root hash.
138    ///
139    /// Recomputes the root from `artifact_id` + the proof path and checks
140    /// that it matches `root_hex`. Fully offline, no tree state needed.
141    ///
142    /// Supports both v1 (sha256-duplicate-last) and v2 (sha256-rfc9162) proofs.
143    /// Rejects unknown algorithm values.
144    pub fn verify_proof(
145        root_hex: &str,
146        artifact_id: &str,
147        proof: &InclusionProof,
148    ) -> bool {
149        // Validate algorithm field. Missing = v1 (legacy), known values accepted.
150        let algo = proof.algorithm.as_deref().unwrap_or(MERKLE_ALGORITHM_V1);
151        if algo != MERKLE_ALGORITHM_V1 && algo != MERKLE_ALGORITHM_V2 {
152            return false; // unknown algorithm -- reject
153        }
154
155        let current: [u8; 32] = Sha256::digest(artifact_id.as_bytes()).into();
156        // Verify leaf hash matches artifact
157        if hex::encode(current) != proof.leaf_hash {
158            return false;
159        }
160
161        let mut current = current;
162        for step in &proof.path {
163            let sibling = match hex::decode(&step.hash) {
164                Ok(b) if b.len() == 32 => {
165                    let mut arr = [0u8; 32];
166                    arr.copy_from_slice(&b);
167                    arr
168                }
169                _ => return false,
170            };
171
172            let mut h = Sha256::new();
173            match step.direction {
174                Direction::Right => {
175                    h.update(current);
176                    h.update(sibling);
177                }
178                Direction::Left => {
179                    h.update(sibling);
180                    h.update(current);
181                }
182            }
183            current = h.finalize().into();
184        }
185
186        hex::encode(current) == root_hex
187    }
188
189    /// Internal: compute root from a slice of leaf hashes.
190    /// RFC 9162 construction: odd nodes are promoted without hashing.
191    fn compute_root(&self, leaves: &[[u8; 32]]) -> [u8; 32] {
192        if leaves.len() == 1 {
193            return leaves[0];
194        }
195        let mut level = leaves.to_vec();
196        while level.len() > 1 {
197            let mut next_level = Vec::with_capacity((level.len() + 1) / 2);
198            let mut i = 0;
199            while i + 1 < level.len() {
200                let mut h = Sha256::new();
201                h.update(level[i]);
202                h.update(level[i + 1]);
203                next_level.push(h.finalize().into());
204                i += 2;
205            }
206            // RFC 9162: promote unpaired node without hashing
207            if i < level.len() {
208                next_level.push(level[i]);
209            }
210            level = next_level;
211        }
212        level[0]
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree whose leaves are `art_0`, `art_1`, ... `art_{n-1}`.
    fn tree_with(n: usize) -> MerkleTree {
        let mut tree = MerkleTree::new();
        for i in 0..n {
            tree.append(&format!("art_{}", i));
        }
        tree
    }

    #[test]
    fn empty_tree_has_no_root_or_proofs() {
        let tree = MerkleTree::new();
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert!(tree.root().is_none());
        assert!(tree.inclusion_proof(0).is_none());
    }

    #[test]
    fn append_returns_sequential_indices() {
        let mut tree = MerkleTree::new();
        assert_eq!(tree.append("art_a"), 0);
        assert_eq!(tree.append("art_b"), 1);
        assert_eq!(tree.len(), 2);
        assert!(!tree.is_empty());
    }

    #[test]
    fn single_leaf_root_is_leaf_hash() {
        let mut tree = MerkleTree::new();
        tree.append("art_abc123");
        let root = tree.root().unwrap();
        let expected = Sha256::digest(b"art_abc123");
        assert_eq!(root, expected.as_slice());
    }

    #[test]
    fn height_is_ceil_log2_of_leaf_count() {
        let expected = [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (8, 3), (9, 4)];
        for (leaves, height) in expected.iter() {
            assert_eq!(tree_with(*leaves).height(), *height, "leaves = {}", leaves);
        }
    }

    #[test]
    fn inclusion_proof_verifies() {
        let mut tree = MerkleTree::new();
        let ids = ["art_a", "art_b", "art_c", "art_d"];
        for id in &ids {
            tree.append(id);
        }

        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(1).unwrap(); // art_b at index 1

        assert!(MerkleTree::verify_proof(&root, "art_b", &proof));
    }

    #[test]
    fn every_leaf_verifies_for_small_trees() {
        // Covers even, odd, and power-of-two leaf counts.
        for n in 1..=9 {
            let tree = tree_with(n);
            let root = hex::encode(tree.root().unwrap());
            for i in 0..n {
                let proof = tree.inclusion_proof(i).unwrap();
                assert_eq!(proof.leaf_index, i);
                assert!(proof.path.len() <= tree.height());
                assert!(
                    MerkleTree::verify_proof(&root, &format!("art_{}", i), &proof),
                    "n = {}, leaf = {}",
                    n,
                    i
                );
            }
        }
    }

    #[test]
    fn out_of_range_index_has_no_proof() {
        let tree = tree_with(3);
        assert!(tree.inclusion_proof(3).is_none());
    }

    #[test]
    fn wrong_artifact_fails_verification() {
        let mut tree = MerkleTree::new();
        tree.append("art_a");
        tree.append("art_b");

        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(0).unwrap(); // proof for art_a

        // Try to verify art_WRONG against art_a's proof
        assert!(!MerkleTree::verify_proof(&root, "art_WRONG", &proof));
    }

    #[test]
    fn wrong_root_fails_verification() {
        let tree = tree_with(4);
        let proof = tree.inclusion_proof(2).unwrap();
        let other_root = hex::encode(tree_with(5).root().unwrap());

        assert!(!MerkleTree::verify_proof(&other_root, "art_2", &proof));
    }

    #[test]
    fn tampered_sibling_fails() {
        let mut tree = MerkleTree::new();
        tree.append("art_a");
        tree.append("art_b");

        let root = hex::encode(tree.root().unwrap());
        let mut proof = tree.inclusion_proof(0).unwrap();

        // Tamper with a sibling hash
        proof.path[0].hash = "0".repeat(64);

        assert!(!MerkleTree::verify_proof(&root, "art_a", &proof));
    }

    #[test]
    fn malformed_sibling_hash_fails() {
        let tree = tree_with(2);
        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(0).unwrap();

        // Not hexadecimal at all.
        let mut not_hex = proof.clone();
        not_hex.path[0].hash = "zz".repeat(32);
        assert!(!MerkleTree::verify_proof(&root, "art_0", &not_hex));

        // Valid hex, but only 31 bytes long.
        let mut too_short = proof.clone();
        too_short.path[0].hash = "00".repeat(31);
        assert!(!MerkleTree::verify_proof(&root, "art_0", &too_short));
    }

    #[test]
    fn unknown_algorithm_is_rejected() {
        let tree = tree_with(2);
        let root = hex::encode(tree.root().unwrap());
        let mut proof = tree.inclusion_proof(0).unwrap();

        proof.algorithm = Some("sha256-unknown".to_string());

        assert!(!MerkleTree::verify_proof(&root, "art_0", &proof));
    }

    #[test]
    fn missing_algorithm_is_treated_as_v1() {
        let tree = tree_with(4);
        let root = hex::encode(tree.root().unwrap());
        let mut proof = tree.inclusion_proof(3).unwrap();
        assert_eq!(proof.algorithm.as_deref(), Some(MERKLE_ALGORITHM_V2));

        // A legacy proof carries no algorithm field.
        proof.algorithm = None;

        assert!(MerkleTree::verify_proof(&root, "art_3", &proof));
    }

    #[test]
    fn odd_number_of_leaves() {
        // 5 leaves -- the unpaired last leaf is promoted (not duplicated)
        // until it meets the root of the first four leaves.
        let tree = tree_with(5);

        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(4).unwrap(); // last leaf

        // Promotion skips two levels, so a single left sibling remains.
        assert_eq!(proof.path.len(), 1);
        assert_eq!(proof.path[0].direction, Direction::Left);
        assert!(MerkleTree::verify_proof(&root, "art_4", &proof));
    }
}
284}