// rustywallet_taproot/taptree.rs

use crate::error::TaprootError;
use crate::tagged_hash::{TapLeafHash, TapNodeHash};

8pub const TAPSCRIPT_LEAF_VERSION: u8 = 0xc0;
10
11#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
13pub struct LeafVersion(pub u8);
14
15impl LeafVersion {
16 pub const TAPSCRIPT: Self = Self(TAPSCRIPT_LEAF_VERSION);
18
19 pub fn new(version: u8) -> Result<Self, TaprootError> {
21 if version & 0x01 != 0 {
23 return Err(TaprootError::InvalidLeafVersion(version));
24 }
25 Ok(Self(version))
26 }
27
28 pub fn to_u8(self) -> u8 {
30 self.0
31 }
32}
33
34impl Default for LeafVersion {
35 fn default() -> Self {
36 Self::TAPSCRIPT
37 }
38}
39
40#[derive(Clone, PartialEq, Eq, Debug)]
42pub struct TapLeaf {
43 pub version: LeafVersion,
45 pub script: Vec<u8>,
47}
48
49impl TapLeaf {
50 pub fn new(script: Vec<u8>) -> Self {
52 Self {
53 version: LeafVersion::TAPSCRIPT,
54 script,
55 }
56 }
57
58 pub fn with_version(version: LeafVersion, script: Vec<u8>) -> Self {
60 Self { version, script }
61 }
62
63 pub fn hash(&self) -> TapLeafHash {
65 TapLeafHash::from_script(self.version.0, &self.script)
66 }
67}
68
69#[derive(Clone, Debug)]
71pub enum TapNode {
72 Leaf(TapLeaf),
74 Branch(Box<TapNode>, Box<TapNode>),
76}
77
78impl TapNode {
79 pub fn hash(&self) -> TapNodeHash {
81 match self {
82 TapNode::Leaf(leaf) => TapNodeHash::from_leaf(leaf.hash()),
83 TapNode::Branch(left, right) => {
84 TapNodeHash::from_children(&left.hash(), &right.hash())
85 }
86 }
87 }
88
89 pub fn is_leaf(&self) -> bool {
91 matches!(self, TapNode::Leaf(_))
92 }
93
94 pub fn as_leaf(&self) -> Option<&TapLeaf> {
96 match self {
97 TapNode::Leaf(leaf) => Some(leaf),
98 TapNode::Branch(_, _) => None,
99 }
100 }
101}
102
103#[derive(Clone, Debug)]
105pub struct TapTree {
106 root: TapNode,
107}
108
109impl TapTree {
110 pub fn from_node(root: TapNode) -> Self {
112 Self { root }
113 }
114
115 pub fn single_leaf(script: Vec<u8>) -> Self {
117 Self {
118 root: TapNode::Leaf(TapLeaf::new(script)),
119 }
120 }
121
122 pub fn root_hash(&self) -> TapNodeHash {
124 self.root.hash()
125 }
126
127 pub fn root(&self) -> &TapNode {
129 &self.root
130 }
131
132 pub fn merkle_path(&self, target_leaf: &TapLeaf) -> Option<Vec<TapNodeHash>> {
134 let target_hash = target_leaf.hash();
135 self.find_path(&self.root, &TapNodeHash::from_leaf(target_hash))
136 }
137
138 fn find_path(&self, node: &TapNode, target: &TapNodeHash) -> Option<Vec<TapNodeHash>> {
139 match node {
140 TapNode::Leaf(leaf) => {
141 if TapNodeHash::from_leaf(leaf.hash()) == *target {
142 Some(Vec::new())
143 } else {
144 None
145 }
146 }
147 TapNode::Branch(left, right) => {
148 if let Some(mut path) = self.find_path(left, target) {
150 path.push(right.hash());
151 return Some(path);
152 }
153 if let Some(mut path) = self.find_path(right, target) {
155 path.push(left.hash());
156 return Some(path);
157 }
158 None
159 }
160 }
161 }
162
163 pub fn leaves(&self) -> Vec<&TapLeaf> {
165 let mut leaves = Vec::new();
166 self.collect_leaves(&self.root, &mut leaves);
167 leaves
168 }
169
170 fn collect_leaves<'a>(&'a self, node: &'a TapNode, leaves: &mut Vec<&'a TapLeaf>) {
171 match node {
172 TapNode::Leaf(leaf) => leaves.push(leaf),
173 TapNode::Branch(left, right) => {
174 self.collect_leaves(left, leaves);
175 self.collect_leaves(right, leaves);
176 }
177 }
178 }
179}
180
181#[derive(Default)]
183pub struct TapTreeBuilder {
184 leaves: Vec<(TapLeaf, u8)>, }
186
187impl TapTreeBuilder {
188 pub fn new() -> Self {
190 Self::default()
191 }
192
193 pub fn add_leaf(mut self, depth: u8, script: Vec<u8>) -> Self {
195 self.leaves.push((TapLeaf::new(script), depth));
196 self
197 }
198
199 pub fn add_leaf_with_version(
201 mut self,
202 depth: u8,
203 version: LeafVersion,
204 script: Vec<u8>,
205 ) -> Self {
206 self.leaves.push((TapLeaf::with_version(version, script), depth));
207 self
208 }
209
210 pub fn build(self) -> Result<TapTree, TaprootError> {
212 if self.leaves.is_empty() {
213 return Err(TaprootError::EmptyTree);
214 }
215
216 if self.leaves.len() == 1 {
217 return Ok(TapTree::single_leaf(self.leaves[0].0.script.clone()));
218 }
219
220 let mut leaves = self.leaves;
222 leaves.sort_by(|a, b| b.1.cmp(&a.1));
223
224 let mut nodes: Vec<(TapNode, u8)> = leaves
226 .into_iter()
227 .map(|(leaf, depth)| (TapNode::Leaf(leaf), depth))
228 .collect();
229
230 while nodes.len() > 1 {
231 let mut i = 0;
233 while i < nodes.len() - 1 {
234 if nodes[i].1 == nodes[i + 1].1 {
235 let (right, _) = nodes.remove(i + 1);
236 let (left, depth) = nodes.remove(i);
237 let branch = TapNode::Branch(Box::new(left), Box::new(right));
238 nodes.insert(i, (branch, depth.saturating_sub(1)));
239 } else {
240 i += 1;
241 }
242 }
243
244 if nodes.len() > 1 && nodes.iter().all(|(_, d)| *d == nodes[0].1) {
247 break;
249 }
250 }
251
252 if nodes.len() != 1 {
253 return Err(TaprootError::TreeError(
254 "Could not build balanced tree".into(),
255 ));
256 }
257
258 Ok(TapTree::from_node(nodes.remove(0).0))
259 }
260}
261
262pub fn two_leaf_tree(script1: Vec<u8>, script2: Vec<u8>) -> TapTree {
264 let left = TapNode::Leaf(TapLeaf::new(script1));
265 let right = TapNode::Leaf(TapLeaf::new(script2));
266 TapTree::from_node(TapNode::Branch(Box::new(left), Box::new(right)))
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing the same leaf twice gives the same value.
    #[test]
    fn test_leaf_hash() {
        let leaf = TapLeaf::new(vec![0x51]);
        let first = leaf.hash();
        let second = leaf.hash();
        assert_eq!(first, second);
    }

    /// A single-leaf tree reports exactly one leaf.
    #[test]
    fn test_single_leaf_tree() {
        let tree = TapTree::single_leaf(vec![0x51]);
        assert_eq!(tree.leaves().len(), 1);
    }

    /// A two-leaf tree reports exactly two leaves.
    #[test]
    fn test_two_leaf_tree() {
        let tree = two_leaf_tree(vec![0x51], vec![0x52]);
        assert_eq!(tree.leaves().len(), 2);
    }

    /// In a two-leaf tree the path of a leaf is just its one sibling.
    #[test]
    fn test_merkle_path() {
        let first_script = vec![0x51];
        let second_script = vec![0x52];
        let tree = two_leaf_tree(first_script.clone(), second_script.clone());

        let wanted = TapLeaf::new(first_script);
        let path = tree.merkle_path(&wanted).unwrap();

        assert_eq!(path.len(), 1);
    }

    /// The builder accepts a lone leaf.
    #[test]
    fn test_builder_single_leaf() {
        let built = TapTreeBuilder::new().add_leaf(0, vec![0x51]).build();
        let tree = built.unwrap();

        assert_eq!(tree.leaves().len(), 1);
    }

    /// The builder joins two leaves of equal depth.
    #[test]
    fn test_builder_two_leaves() {
        let built = TapTreeBuilder::new()
            .add_leaf(1, vec![0x51])
            .add_leaf(1, vec![0x52])
            .build();
        let tree = built.unwrap();

        assert_eq!(tree.leaves().len(), 2);
    }

    /// Even version bytes are accepted, odd ones are rejected.
    #[test]
    fn test_leaf_version() {
        assert!(LeafVersion::new(0xc0).is_ok());
        assert!(LeafVersion::new(0xc2).is_ok());
        assert!(LeafVersion::new(0xc1).is_err());
    }

    /// Two trees built from the same scripts share a root hash.
    #[test]
    fn test_branch_hash_deterministic() {
        let first = two_leaf_tree(vec![0x51], vec![0x52]);
        let second = two_leaf_tree(vec![0x51], vec![0x52]);

        assert_eq!(first.root_hash(), second.root_hash());
    }
}