spark_rust/wallet/internal_handlers/implementations/
create_tree.rs

1use crate::common_types::types::Address;
2use crate::common_types::types::Encodable;
3use crate::common_types::types::Secp256k1;
4use crate::common_types::types::Transaction;
5use crate::common_types::types::TxIn;
6use crate::common_types::types::TxOut;
7use crate::common_types::types::Txid;
8use crate::error::{validation::ValidationError, SparkSdkError};
9use crate::signer::traits::derivation_path::SparkKeyType;
10use crate::signer::traits::SparkSigner;
11use crate::wallet::client::SparkSdk;
12use crate::wallet::internal_handlers::traits::create_tree::BuildCreationNodesFromTreeSdkResponse;
13use crate::wallet::internal_handlers::traits::create_tree::CreateAddressRequestNodeFromTreeNodesSdkResponse;
14use crate::wallet::internal_handlers::traits::create_tree::CreateDepositAddressBinaryTreeSdkResponse;
15use crate::wallet::internal_handlers::traits::create_tree::CreateTreeInternalHandlers;
16use crate::wallet::internal_handlers::traits::create_tree::DepositAddressTree;
17use crate::wallet::internal_handlers::traits::create_tree::FinalizeTreeCreationSdkResponse;
18use crate::wallet::internal_handlers::traits::create_tree::GenerateDepositAddressForTreeSdkResponse;
19use crate::wallet::leaf_manager::SparkLeaf;
20use crate::wallet::utils::frost::frost_commitment_to_proto_commitment;
21use bitcoin::secp256k1::PublicKey;
22use parking_lot::RwLock;
23use spark_protos::spark::prepare_tree_address_request::Source as SourceProto;
24use spark_protos::spark::AddressNode as AddressNodeProto;
25use spark_protos::spark::AddressRequestNode as AddressRequestNodeProto;
26use spark_protos::spark::CreateTreeRequest;
27use spark_protos::spark::FinalizeNodeSignaturesRequest;
28use spark_protos::spark::NodeOutput as NodeOutputProto;
29use spark_protos::spark::PrepareTreeAddressRequest;
30use spark_protos::spark::TreeNode as TreeNodeProto;
31use spark_protos::spark::Utxo as UtxoProto;
32use std::collections::VecDeque;
33use std::str::FromStr;
34use std::sync::Arc;
35use tonic::async_trait;
36use uuid::Uuid;
37
38#[async_trait]
39impl<S: SparkSigner + Send + Sync + Clone + 'static> CreateTreeInternalHandlers<S> for SparkSdk<S> {
40    /// Creates a binary tree of deposit addresses.
41    ///
42    /// The tree is created by recursively splitting the target signing private key into two halves.
43    /// Each node in the tree represents a deposit address, and the children of each node are the
44    /// next level of the tree.
45    ///
46    /// # Arguments
47    ///
48    /// * `split_level` - The level of the tree to create.
49    /// * `target_pubkey` - The public key to split into the tree.
50    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
51    fn create_deposit_address_binary_tree(
52        &self,
53        split_level: u32,
54        target_pubkey: &Vec<u8>,
55    ) -> Result<CreateDepositAddressBinaryTreeSdkResponse, SparkSdkError> {
56        if split_level == 0 {
57            return Ok(CreateDepositAddressBinaryTreeSdkResponse { tree: vec![] });
58        }
59
60        // generate left pubkey
61        let left_pubkey_uuid = Uuid::now_v7().to_string();
62        let left_pubkey = self.signer.new_secp256k1_keypair(
63            left_pubkey_uuid,
64            SparkKeyType::TemporarySigning,
65            0,
66            self.config.spark_config.network.to_bitcoin_network(),
67        )?;
68        let left_pubkey = left_pubkey.serialize().to_vec();
69
70        // left node
71        let mut left_node = DepositAddressTree {
72            address: None,
73            verification_key: None,
74            signing_public_key: left_pubkey.clone(),
75            children: vec![],
76        };
77
78        // create left children recursively
79        let left_children =
80            self.create_deposit_address_binary_tree(split_level - 1, &left_pubkey)?;
81        left_node.children = left_children.tree;
82
83        // calculate right pubkey
84        let right_pubkey = self.signer.subtract_secret_keys_given_pubkeys(
85            &PublicKey::from_slice(target_pubkey)?,
86            &PublicKey::from_slice(&left_pubkey)?,
87            true,
88        )?;
89
90        // right node
91        let mut right_node = DepositAddressTree {
92            address: None,
93            verification_key: None,
94            signing_public_key: right_pubkey.serialize().to_vec(),
95            children: vec![],
96        };
97
98        // create right children recursively
99        let right_children = self.create_deposit_address_binary_tree(
100            split_level - 1,
101            &right_pubkey.serialize().to_vec(),
102        )?;
103        right_node.children = right_children.tree;
104
105        Ok(CreateDepositAddressBinaryTreeSdkResponse {
106            tree: vec![
107                Arc::new(RwLock::new(left_node)),
108                Arc::new(RwLock::new(right_node)),
109            ],
110        })
111    }
112
113    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
114    fn create_address_request_node_from_tree_nodes(
115        &self,
116        tree_nodes: &Vec<Arc<RwLock<DepositAddressTree>>>,
117    ) -> Result<CreateAddressRequestNodeFromTreeNodesSdkResponse, SparkSdkError> {
118        let mut results = Vec::<AddressRequestNodeProto>::new();
119
120        for node in tree_nodes {
121            let node = node.read();
122
123            let address_request_node =
124                self.create_address_request_node_from_tree_nodes(&node.children)?;
125            let address_request_node = AddressRequestNodeProto {
126                user_public_key: node.signing_public_key.clone(),
127                children: address_request_node.address_request_nodes,
128            };
129            results.push(address_request_node);
130        }
131
132        Ok(CreateAddressRequestNodeFromTreeNodesSdkResponse {
133            address_request_nodes: results,
134        })
135    }
136
137    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
138    fn apply_address_nodes_to_tree(
139        &self,
140        tree: &mut Vec<Arc<RwLock<DepositAddressTree>>>,
141        address_nodes: Vec<AddressNodeProto>,
142    ) -> Result<(), SparkSdkError> {
143        for (i, node) in tree.iter_mut().enumerate() {
144            let mut node = node.write();
145            let node_address_data = address_nodes[i].address.clone().unwrap();
146            node.address = Some(node_address_data.address);
147            node.verification_key = Some(node_address_data.verifying_key);
148
149            if !node.children.is_empty() {
150                self.apply_address_nodes_to_tree(
151                    &mut node.children,
152                    address_nodes[i].children.clone(),
153                )?;
154            }
155        }
156
157        Ok(())
158    }
159
160    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
161    async fn generate_deposit_address_for_tree(
162        &self,
163        parent_tx: Option<Transaction>,
164        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
165        vout: u32,
166        parent_public_key: Vec<u8>,
167        split_level: u32,
168    ) -> Result<GenerateDepositAddressForTreeSdkResponse, SparkSdkError> {
169        let network_proto = self.config.spark_config.network.marshal_proto();
170
171        // 1. Create the binary tree given the user request
172        let time_start = std::time::Instant::now();
173        let deposit_address_tree =
174            self.create_deposit_address_binary_tree(split_level, &parent_public_key)?;
175        let duration = time_start.elapsed();
176        #[cfg(feature = "telemetry")]
177        tracing::debug!(duration = ?duration, "create_deposit_address_binary_tree");
178
179        // If split_level = 0, len = 1. if split_level = 1, len = 1 + 2 = 3. If split_level = 2, len = 1 + 2 + 4 = 7.
180        // This is because the tree is a binary tree, and the number of nodes is 2^split_level.
181        // assert!(tree.len() == 2u32.pow(split_level) as usize);
182
183        // 2. Create the address request nodes (in proto format) from the tree nodes
184        let time_start = std::time::Instant::now();
185        let address_nodes =
186            self.create_address_request_node_from_tree_nodes(&deposit_address_tree.tree)?;
187        let duration = time_start.elapsed();
188        #[cfg(feature = "telemetry")]
189        tracing::debug!(duration = ?duration, "create_address_request_node_from_tree_nodes");
190
191        // 3. Send PrepareTreeAddressRequest to Spark. This is the first step of tree creation.
192        let request_source = match parent_node {
193            Some(parent_node) => {
194                let node_output = NodeOutputProto {
195                    node_id: parent_node.read().id.clone(),
196                    vout,
197                };
198                SourceProto::ParentNodeOutput(node_output)
199            }
200            None => {
201                let mut raw_tx = Vec::new();
202                parent_tx
203                    .as_ref()
204                    .unwrap()
205                    .consensus_encode(&mut raw_tx)
206                    .map_err(|e| {
207                        SparkSdkError::from(ValidationError::InvalidArgument {
208                            argument: e.to_string(),
209                        })
210                    })?;
211                let utxo = UtxoProto {
212                    vout,
213                    raw_tx,
214                    network: network_proto,
215                };
216                SourceProto::OnChainUtxo(utxo)
217            }
218        };
219        let source = Some(request_source);
220        let request_node = AddressRequestNodeProto {
221            user_public_key: parent_public_key.clone(),
222            children: address_nodes.address_request_nodes,
223        };
224        let node = Some(request_node);
225
226        let request_data = PrepareTreeAddressRequest {
227            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
228            node,
229            source,
230        };
231
232        let time_start = std::time::Instant::now();
233        let spark_tree_response = self
234            .config
235            .spark_config
236            .call_with_retry(
237                request_data,
238                |mut client, req| Box::pin(async move { client.prepare_tree_address(req).await }),
239                None,
240            )
241            .await?;
242
243        let duration = time_start.elapsed();
244        #[cfg(feature = "telemetry")]
245        tracing::debug!(duration = ?duration, "prepare_tree_address");
246
247        // 4. Create the root node
248        let response_address_node = spark_tree_response.node.unwrap();
249        let root = DepositAddressTree {
250            address: None,
251            verification_key: None,
252            signing_public_key: parent_public_key.clone(),
253            children: deposit_address_tree.tree.clone(),
254        };
255        let root = Arc::new(RwLock::new(root));
256
257        // 5. Apply the address nodes to the tree
258        let mut root_in_vec = vec![root];
259        let time_start = std::time::Instant::now();
260        self.apply_address_nodes_to_tree(&mut root_in_vec, vec![response_address_node])?;
261        let duration = time_start.elapsed();
262        #[cfg(feature = "telemetry")]
263        tracing::debug!(duration = ?duration, "apply_address_nodes_to_tree");
264
265        // 6. Extract the root back from the vector and return it
266        let root = root_in_vec.remove(0);
267
268        Ok(GenerateDepositAddressForTreeSdkResponse { tree: root })
269    }
270
271    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
272    fn build_creation_nodes_from_tree(
273        &self,
274        parent_txid: Txid,
275        txout_: &TxOut,
276        vout: u32,
277        root: Arc<RwLock<DepositAddressTree>>,
278    ) -> Result<BuildCreationNodesFromTreeSdkResponse, SparkSdkError> {
279        struct TreeNode {
280            parent_txid: Txid,
281            txout: TxOut,
282            vout: u32,
283            node: Arc<RwLock<DepositAddressTree>>,
284        }
285
286        let mut creation_node = spark_protos::spark::CreationNode::default();
287        let mut queue = VecDeque::<(TreeNode, *mut spark_protos::spark::CreationNode)>::new();
288
289        queue.push_back((
290            TreeNode {
291                parent_txid,
292                txout: txout_.clone(),
293                vout,
294                node: root,
295            },
296            &mut creation_node,
297        ));
298
299        let network = self.config.spark_config.network.to_bitcoin_network();
300        let secp = Secp256k1::new();
301
302        while let Some((current, creation_ptr)) = queue.pop_front() {
303            let creation_ref = unsafe { &mut *creation_ptr };
304            let local_node = current.node.read();
305
306            // let node_signing_key = local_node.verification_key.clone();
307            // let user_signing_key = SecretKey::from_slice(&node_signing_key).unwrap();
308            let user_verifying_key_ = local_node.signing_public_key.clone();
309            let user_verifying_key =
310                bitcoin::secp256k1::PublicKey::from_slice(&user_verifying_key_).unwrap();
311
312            if !local_node.children.is_empty() {
313                let tx = {
314                    let tx_input = bitcoin::TxIn {
315                        previous_output: bitcoin::OutPoint {
316                            txid: current.parent_txid,
317                            vout: current.vout,
318                        },
319                        script_sig: bitcoin::ScriptBuf::default(),
320                        sequence: bitcoin::Sequence::ZERO,
321                        witness: bitcoin::Witness::default(),
322                    };
323
324                    let mut output = Vec::new();
325                    for child in local_node.children.iter() {
326                        let child_address = child.read().address.clone().unwrap();
327                        let child_address = Address::from_str(&child_address).unwrap();
328                        let child_address = child_address.require_network(network).unwrap();
329                        output.push(TxOut {
330                            value: bitcoin::Amount::from_sat(
331                                current.txout.value.to_sat() / local_node.children.len() as u64,
332                            ),
333                            script_pubkey: child_address.script_pubkey(),
334                        });
335                    }
336
337                    bitcoin::Transaction {
338                        version: bitcoin::transaction::Version::TWO,
339                        lock_time: bitcoin::absolute::LockTime::ZERO,
340                        input: vec![tx_input],
341                        output,
342                    }
343                };
344
345                let mut tx_buf = vec![];
346                tx.consensus_encode(&mut tx_buf).map_err(|e| {
347                    SparkSdkError::from(ValidationError::InvalidArgument {
348                        argument: e.to_string(),
349                    })
350                })?;
351
352                let commitment = self.signer.new_frost_signing_noncepair()?;
353
354                let signing_job = spark_protos::spark::SigningJob {
355                    signing_public_key: user_verifying_key.serialize().to_vec(),
356                    raw_tx: tx_buf,
357                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
358                        &commitment,
359                    )?),
360                };
361
362                creation_ref.node_tx_signing_job = Some(signing_job);
363                creation_ref.children = vec![Default::default(); local_node.children.len()];
364
365                let txid = tx.compute_txid();
366                for (i, child) in local_node.children.iter().enumerate() {
367                    queue.push_back((
368                        TreeNode {
369                            parent_txid: txid,
370                            txout: tx.output[i].clone(),
371                            vout: i as u32,
372                            node: child.clone(),
373                        },
374                        &mut creation_ref.children[i] as *mut _,
375                    ));
376                }
377            } else {
378                let aggregated_address = local_node.address.clone().unwrap();
379                let aggregated_address = Address::from_str(&aggregated_address).unwrap();
380                let aggregated_address = aggregated_address.require_network(network).unwrap();
381
382                let node_tx = {
383                    let input = bitcoin::TxIn {
384                        previous_output: bitcoin::OutPoint {
385                            txid: current.parent_txid,
386                            vout: current.vout,
387                        },
388                        script_sig: bitcoin::ScriptBuf::default(),
389                        // sequence: bitcoin::Sequence::ZERO,
390                        sequence: bitcoin::Sequence::ZERO,
391                        witness: bitcoin::Witness::default(),
392                    };
393
394                    bitcoin::Transaction {
395                        version: bitcoin::transaction::Version::TWO,
396                        lock_time: bitcoin::absolute::LockTime::ZERO,
397                        input: vec![input],
398                        output: vec![TxOut {
399                            value: current.txout.value,
400                            script_pubkey: aggregated_address.script_pubkey(),
401                        }],
402                    }
403                };
404
405                let mut node_tx_buf = vec![];
406                node_tx.consensus_encode(&mut node_tx_buf).map_err(|e| {
407                    SparkSdkError::from(ValidationError::InvalidArgument {
408                        argument: e.to_string(),
409                    })
410                })?;
411
412                let refund_tx = {
413                    let user_self_xonly = user_verifying_key.x_only_public_key().0;
414                    let user_self_address = Address::p2tr(&secp, user_self_xonly, None, network);
415
416                    bitcoin::Transaction {
417                        version: bitcoin::transaction::Version::TWO,
418                        lock_time: bitcoin::absolute::LockTime::ZERO,
419                        input: vec![TxIn {
420                            previous_output: bitcoin::OutPoint {
421                                txid: node_tx.compute_txid(),
422                                vout: 0,
423                            },
424                            script_sig: bitcoin::ScriptBuf::default(),
425                            // TODO: this must be the default sequence. For tree creation here, we can set it to MAX, since this is an SSP feature and is unlikely to be used here. Yet, this will affect unilateral exits.
426                            sequence: bitcoin::Sequence::MAX,
427                            witness: bitcoin::Witness::default(),
428                        }],
429                        output: vec![TxOut {
430                            value: current.txout.value,
431                            script_pubkey: user_self_address.script_pubkey(),
432                        }],
433                    }
434                };
435
436                let mut refund_tx_buf = vec![];
437                refund_tx
438                    .consensus_encode(&mut refund_tx_buf)
439                    .map_err(|e| {
440                        SparkSdkError::from(ValidationError::InvalidArgument {
441                            argument: e.to_string(),
442                        })
443                    })?;
444
445                let node_commitment = self.signer.new_frost_signing_noncepair()?;
446                let refund_commitment = self.signer.new_frost_signing_noncepair()?;
447
448                creation_ref.node_tx_signing_job = Some(spark_protos::spark::SigningJob {
449                    signing_public_key: user_verifying_key.serialize().to_vec(),
450                    raw_tx: node_tx_buf,
451                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
452                        &node_commitment,
453                    )?),
454                });
455
456                creation_ref.refund_tx_signing_job = Some(spark_protos::spark::SigningJob {
457                    signing_public_key: user_verifying_key.serialize().to_vec(),
458                    raw_tx: refund_tx_buf,
459                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
460                        &refund_commitment,
461                    )?),
462                });
463            }
464        }
465
466        Ok(BuildCreationNodesFromTreeSdkResponse {
467            creation_nodes: creation_node,
468        })
469    }
470
471    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
472    async fn finalize_tree_creation(
473        &self,
474        parent_tx: Option<Transaction>,
475        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
476        vout: u32,
477        root: Arc<RwLock<DepositAddressTree>>,
478    ) -> Result<FinalizeTreeCreationSdkResponse, SparkSdkError> {
479        let mut request = CreateTreeRequest {
480            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
481            ..Default::default()
482        };
483
484        let final_parent_tx = if let Some(ptx) = parent_tx {
485            let mut raw_tx = Vec::new();
486            ptx.consensus_encode(&mut raw_tx).map_err(|e| {
487                SparkSdkError::from(ValidationError::InvalidArgument {
488                    argument: e.to_string(),
489                })
490            })?;
491
492            request.source = Some(
493                spark_protos::spark::create_tree_request::Source::OnChainUtxo(
494                    spark_protos::spark::Utxo {
495                        // txid,
496                        vout,
497                        raw_tx,
498                        network: self.config.spark_config.network.marshal_proto(),
499                    },
500                ),
501            );
502            ptx
503        } else if let Some(parent_node) = parent_node {
504            let tx_buf = parent_node.read().node_tx.clone();
505            let ptx: Transaction = bitcoin::consensus::deserialize(&tx_buf).map_err(|_| {
506                SparkSdkError::from(ValidationError::InvalidArgument {
507                    argument: "Failed to parse parent node_tx".to_string(),
508                })
509            })?;
510            let node_id = parent_node.read().id.clone();
511            request.source = Some(
512                spark_protos::spark::create_tree_request::Source::ParentNodeOutput(
513                    spark_protos::spark::NodeOutput { node_id, vout },
514                ),
515            );
516            ptx
517        } else {
518            return Err(SparkSdkError::from(ValidationError::InvalidArgument {
519                argument: "No parent_tx or parent_node provided to create_tree".to_string(),
520            }));
521        };
522
523        let parent_txid = final_parent_tx.compute_txid();
524        let time_start = std::time::Instant::now();
525        let creation_node_response = self.build_creation_nodes_from_tree(
526            parent_txid,
527            &final_parent_tx.output[vout as usize],
528            vout,
529            root.clone(),
530        )?;
531        let duration = time_start.elapsed();
532        #[cfg(feature = "telemetry")]
533        tracing::debug!(duration = ?duration, "build_creation_nodes_from_tree");
534
535        request.node = Some(creation_node_response.creation_nodes.clone());
536
537        let spark_tree_response = self
538            .config
539            .spark_config
540            .call_with_retry(
541                request,
542                |mut client, req| Box::pin(async move { client.create_tree(req).await }),
543                None,
544            )
545            .await?;
546
547        let duration = time_start.elapsed();
548        #[cfg(feature = "telemetry")]
549        tracing::debug!(duration = ?duration, "create_tree");
550
551        let tree_node = spark_tree_response.node.ok_or_else(|| {
552            SparkSdkError::from(ValidationError::InvalidArgument {
553                argument: "Coordinator returned no creation node".to_string(),
554            })
555        })?;
556
557        let time_start = std::time::Instant::now();
558        let (node_signatures, signing_public_keys) = self.signer.sign_created_tree_in_bfs_order(
559            final_parent_tx,
560            vout,
561            root,
562            creation_node_response.creation_nodes,
563            tree_node,
564        )?;
565        let duration = time_start.elapsed();
566        #[cfg(feature = "telemetry")]
567        tracing::debug!(duration = ?duration, "sign_created_tree_in_bfs_order");
568        let spark_signatures_request = FinalizeNodeSignaturesRequest {
569            node_signatures: node_signatures.clone(),
570            ..Default::default()
571        };
572
573        let time_start = std::time::Instant::now();
574
575        let spark_signatures_response = self
576            .config
577            .spark_config
578            .call_with_retry(
579                spark_signatures_request,
580                |mut client, req| {
581                    Box::pin(async move { client.finalize_node_signatures(req).await })
582                },
583                None,
584            )
585            .await?;
586
587        let duration = time_start.elapsed();
588        #[cfg(feature = "telemetry")]
589        tracing::debug!(duration = ?duration, "finalize_node_signatures");
590
591        // Slice the array to get the second half (including the center)
592        let starting_index = spark_signatures_response.nodes.len() / 2;
593
594        let leaf_nodes = spark_signatures_response
595            .nodes
596            .iter()
597            .skip(spark_signatures_response.nodes.len() / 2)
598            .collect::<Vec<_>>();
599
600        let mut leaf_nodes_to_insert = vec![];
601        for (i, node) in leaf_nodes.iter().enumerate() {
602            let node_id = node.id.clone();
603            let i = i + starting_index;
604            if node_id != node_signatures[i].clone().node_id {
605                return Err(SparkSdkError::from(ValidationError::InvalidArgument {
606                    argument: "Node ID mismatch".to_string(),
607                }));
608            }
609
610            // TODO: this is an error-prone approach. For roots, this approach is correct.
611            // For splits, this approach is incorrect because now it should be strictly the right half, excluding the center.
612
613            // TODO: For the parents, add aggregated.
614            leaf_nodes_to_insert.push(SparkLeaf::Bitcoin((*node).clone()));
615        }
616
617        self.leaf_manager
618            .insert_leaves(leaf_nodes_to_insert, false)?;
619
620        Ok(FinalizeTreeCreationSdkResponse {
621            finalize_tree_response: spark_signatures_response,
622            signing_public_keys,
623        })
624    }
625}