rustywallet_taproot/
taptree.rs

1//! Tap tree (MAST) construction
2//!
3//! Implements Merkle Abstract Syntax Trees for Taproot script paths.
4
5use crate::error::TaprootError;
6use crate::tagged_hash::{TapLeafHash, TapNodeHash};
7
/// Leaf version byte that marks a leaf script as Tapscript (`0xc0`).
pub const TAPSCRIPT_LEAF_VERSION: u8 = 0xC0;
10
/// Version byte attached to a tap leaf.
///
/// Thin newtype over the raw `u8`; the inner value is public, so it can be
/// constructed directly without any validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LeafVersion(pub u8);
14
15impl LeafVersion {
16    /// Tapscript version (0xc0)
17    pub const TAPSCRIPT: Self = Self(TAPSCRIPT_LEAF_VERSION);
18
19    /// Create a new leaf version
20    pub fn new(version: u8) -> Result<Self, TaprootError> {
21        // Valid leaf versions have the lowest bit unset
22        if version & 0x01 != 0 {
23            return Err(TaprootError::InvalidLeafVersion(version));
24        }
25        Ok(Self(version))
26    }
27
28    /// Get the version byte
29    pub fn to_u8(self) -> u8 {
30        self.0
31    }
32}
33
34impl Default for LeafVersion {
35    fn default() -> Self {
36        Self::TAPSCRIPT
37    }
38}
39
40/// A leaf in the tap tree
41#[derive(Clone, PartialEq, Eq, Debug)]
42pub struct TapLeaf {
43    /// Leaf version
44    pub version: LeafVersion,
45    /// The script
46    pub script: Vec<u8>,
47}
48
49impl TapLeaf {
50    /// Create a new Tapscript leaf
51    pub fn new(script: Vec<u8>) -> Self {
52        Self {
53            version: LeafVersion::TAPSCRIPT,
54            script,
55        }
56    }
57
58    /// Create a leaf with custom version
59    pub fn with_version(version: LeafVersion, script: Vec<u8>) -> Self {
60        Self { version, script }
61    }
62
63    /// Compute the leaf hash
64    pub fn hash(&self) -> TapLeafHash {
65        TapLeafHash::from_script(self.version.0, &self.script)
66    }
67}
68
69/// A node in the tap tree
70#[derive(Clone, Debug)]
71pub enum TapNode {
72    /// A leaf script
73    Leaf(TapLeaf),
74    /// A branch with two children
75    Branch(Box<TapNode>, Box<TapNode>),
76}
77
78impl TapNode {
79    /// Compute the hash of this node
80    pub fn hash(&self) -> TapNodeHash {
81        match self {
82            TapNode::Leaf(leaf) => TapNodeHash::from_leaf(leaf.hash()),
83            TapNode::Branch(left, right) => {
84                TapNodeHash::from_children(&left.hash(), &right.hash())
85            }
86        }
87    }
88
89    /// Check if this is a leaf
90    pub fn is_leaf(&self) -> bool {
91        matches!(self, TapNode::Leaf(_))
92    }
93
94    /// Get the leaf if this is a leaf node
95    pub fn as_leaf(&self) -> Option<&TapLeaf> {
96        match self {
97            TapNode::Leaf(leaf) => Some(leaf),
98            TapNode::Branch(_, _) => None,
99        }
100    }
101}
102
103/// Complete tap tree
104#[derive(Clone, Debug)]
105pub struct TapTree {
106    root: TapNode,
107}
108
109impl TapTree {
110    /// Create a tap tree from a root node
111    pub fn from_node(root: TapNode) -> Self {
112        Self { root }
113    }
114
115    /// Create a tap tree with a single leaf
116    pub fn single_leaf(script: Vec<u8>) -> Self {
117        Self {
118            root: TapNode::Leaf(TapLeaf::new(script)),
119        }
120    }
121
122    /// Get the merkle root hash
123    pub fn root_hash(&self) -> TapNodeHash {
124        self.root.hash()
125    }
126
127    /// Get the root node
128    pub fn root(&self) -> &TapNode {
129        &self.root
130    }
131
132    /// Find the merkle path to a leaf
133    pub fn merkle_path(&self, target_leaf: &TapLeaf) -> Option<Vec<TapNodeHash>> {
134        let target_hash = target_leaf.hash();
135        self.find_path(&self.root, &TapNodeHash::from_leaf(target_hash))
136    }
137
138    fn find_path(&self, node: &TapNode, target: &TapNodeHash) -> Option<Vec<TapNodeHash>> {
139        match node {
140            TapNode::Leaf(leaf) => {
141                if TapNodeHash::from_leaf(leaf.hash()) == *target {
142                    Some(Vec::new())
143                } else {
144                    None
145                }
146            }
147            TapNode::Branch(left, right) => {
148                // Try left branch
149                if let Some(mut path) = self.find_path(left, target) {
150                    path.push(right.hash());
151                    return Some(path);
152                }
153                // Try right branch
154                if let Some(mut path) = self.find_path(right, target) {
155                    path.push(left.hash());
156                    return Some(path);
157                }
158                None
159            }
160        }
161    }
162
163    /// Get all leaves in the tree
164    pub fn leaves(&self) -> Vec<&TapLeaf> {
165        let mut leaves = Vec::new();
166        self.collect_leaves(&self.root, &mut leaves);
167        leaves
168    }
169
170    fn collect_leaves<'a>(&'a self, node: &'a TapNode, leaves: &mut Vec<&'a TapLeaf>) {
171        match node {
172            TapNode::Leaf(leaf) => leaves.push(leaf),
173            TapNode::Branch(left, right) => {
174                self.collect_leaves(left, leaves);
175                self.collect_leaves(right, leaves);
176            }
177        }
178    }
179}
180
181/// Builder for constructing tap trees
182#[derive(Default)]
183pub struct TapTreeBuilder {
184    leaves: Vec<(TapLeaf, u8)>, // (leaf, depth)
185}
186
187impl TapTreeBuilder {
188    /// Create a new builder
189    pub fn new() -> Self {
190        Self::default()
191    }
192
193    /// Add a leaf at a specific depth
194    pub fn add_leaf(mut self, depth: u8, script: Vec<u8>) -> Self {
195        self.leaves.push((TapLeaf::new(script), depth));
196        self
197    }
198
199    /// Add a leaf with custom version at a specific depth
200    pub fn add_leaf_with_version(
201        mut self,
202        depth: u8,
203        version: LeafVersion,
204        script: Vec<u8>,
205    ) -> Self {
206        self.leaves.push((TapLeaf::with_version(version, script), depth));
207        self
208    }
209
210    /// Build the tap tree
211    pub fn build(self) -> Result<TapTree, TaprootError> {
212        if self.leaves.is_empty() {
213            return Err(TaprootError::EmptyTree);
214        }
215
216        if self.leaves.len() == 1 {
217            return Ok(TapTree::single_leaf(self.leaves[0].0.script.clone()));
218        }
219
220        // Sort by depth (descending) for proper tree construction
221        let mut leaves = self.leaves;
222        leaves.sort_by(|a, b| b.1.cmp(&a.1));
223
224        // Build tree from leaves
225        let mut nodes: Vec<(TapNode, u8)> = leaves
226            .into_iter()
227            .map(|(leaf, depth)| (TapNode::Leaf(leaf), depth))
228            .collect();
229
230        while nodes.len() > 1 {
231            // Find two nodes at the same depth
232            let mut i = 0;
233            while i < nodes.len() - 1 {
234                if nodes[i].1 == nodes[i + 1].1 {
235                    let (right, _) = nodes.remove(i + 1);
236                    let (left, depth) = nodes.remove(i);
237                    let branch = TapNode::Branch(Box::new(left), Box::new(right));
238                    nodes.insert(i, (branch, depth.saturating_sub(1)));
239                } else {
240                    i += 1;
241                }
242            }
243
244            // If no pairs found at same depth, we have an unbalanced tree
245            // Combine the deepest nodes
246            if nodes.len() > 1 && nodes.iter().all(|(_, d)| *d == nodes[0].1) {
247                // All at same depth but odd number - this shouldn't happen with valid input
248                break;
249            }
250        }
251
252        if nodes.len() != 1 {
253            return Err(TaprootError::TreeError(
254                "Could not build balanced tree".into(),
255            ));
256        }
257
258        Ok(TapTree::from_node(nodes.remove(0).0))
259    }
260}
261
262/// Create a simple 2-leaf tree
263pub fn two_leaf_tree(script1: Vec<u8>, script2: Vec<u8>) -> TapTree {
264    let left = TapNode::Leaf(TapLeaf::new(script1));
265    let right = TapNode::Leaf(TapLeaf::new(script2));
266    TapTree::from_node(TapNode::Branch(Box::new(left), Box::new(right)))
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_leaf_hash() {
        let leaf = TapLeaf::new(vec![0x51]); // OP_1

        // Hashing the same leaf twice must give the same result.
        assert_eq!(leaf.hash(), leaf.hash());
    }

    #[test]
    fn test_single_leaf_tree() {
        let tree = TapTree::single_leaf(vec![0x51]);
        let leaves = tree.leaves();

        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0].script, vec![0x51]);
        assert!(tree.root().is_leaf());
    }

    #[test]
    fn test_two_leaf_tree() {
        let tree = two_leaf_tree(vec![0x51], vec![0x52]);
        let leaves = tree.leaves();

        // Leaves come back in left-to-right order.
        assert_eq!(leaves.len(), 2);
        assert_eq!(leaves[0].script, vec![0x51]);
        assert_eq!(leaves[1].script, vec![0x52]);
        assert!(!tree.root().is_leaf());
    }

    #[test]
    fn test_merkle_path() {
        let left = TapLeaf::new(vec![0x51]);
        let right = TapLeaf::new(vec![0x52]);
        let tree = two_leaf_tree(left.script.clone(), right.script.clone());

        // The path of the left leaf is exactly the hash of its sibling.
        let path = tree.merkle_path(&left).unwrap();
        assert_eq!(path.len(), 1);
        assert_eq!(path[0], TapNode::Leaf(right).hash());

        // A leaf that is not part of the tree has no path.
        assert!(tree.merkle_path(&TapLeaf::new(vec![0x53])).is_none());
    }

    #[test]
    fn test_builder_empty() {
        assert!(TapTreeBuilder::new().build().is_err());
    }

    #[test]
    fn test_builder_single_leaf() {
        let tree = TapTreeBuilder::new()
            .add_leaf(0, vec![0x51])
            .build()
            .unwrap();

        assert_eq!(tree.leaves().len(), 1);
    }

    #[test]
    fn test_builder_two_leaves() {
        let tree = TapTreeBuilder::new()
            .add_leaf(1, vec![0x51])
            .add_leaf(1, vec![0x52])
            .build()
            .unwrap();

        assert_eq!(tree.leaves().len(), 2);
    }

    #[test]
    fn test_leaf_version() {
        assert!(LeafVersion::new(0xc0).is_ok());
        assert!(LeafVersion::new(0xc2).is_ok());
        assert!(LeafVersion::new(0xc1).is_err()); // Odd version invalid
    }

    #[test]
    fn test_branch_hash_deterministic() {
        let tree1 = two_leaf_tree(vec![0x51], vec![0x52]);
        let tree2 = two_leaf_tree(vec![0x51], vec![0x52]);

        assert_eq!(tree1.root_hash(), tree2.root_hash());
    }
}