spark_rust/wallet/internal_handlers/implementations/
create_tree.rs
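//! Internal handlers for creating a binary tree of Spark deposit addresses.
//!
//! Tree creation is a two-phase exchange with the Spark coordinator:
//! 1. `generate_deposit_address_for_tree` derives the local key tree, calls
//!    `prepare_tree_address`, and applies the returned addresses to the tree.
//! 2. `finalize_tree_creation` builds the node and refund transactions, calls `create_tree`,
//!    signs the result in BFS order, and finalizes the node signatures.
//!
//! A rough usage sketch (illustrative only; assumes an initialized `sdk`, a confirmed
//! `deposit_tx` paying the funding output at `vout`, and the matching `parent_public_key`):
//!
//! ```ignore
//! let root = sdk
//!     .generate_deposit_address_for_tree(Some(deposit_tx.clone()), None, vout, parent_public_key, 2)
//!     .await?
//!     .tree;
//! let finalized = sdk
//!     .finalize_tree_creation(Some(deposit_tx), None, vout, root)
//!     .await?;
//! ```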

use crate::common_types::types::Address;
use crate::common_types::types::Encodable;
use crate::common_types::types::Secp256k1;
use crate::common_types::types::Transaction;
use crate::common_types::types::TxIn;
use crate::common_types::types::TxOut;
use crate::common_types::types::Txid;
use crate::error::{validation::ValidationError, SparkSdkError};
use crate::signer::traits::SparkSigner;
use crate::wallet::client::SparkSdk;
use crate::wallet::internal_handlers::traits::create_tree::BuildCreationNodesFromTreeSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateAddressRequestNodeFromTreeNodesSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateDepositAddressBinaryTreeSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateTreeInternalHandlers;
use crate::wallet::internal_handlers::traits::create_tree::DepositAddressTree;
use crate::wallet::internal_handlers::traits::create_tree::FinalizeTreeCreationSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::GenerateDepositAddressForTreeSdkResponse;
use crate::wallet::leaf_manager::SparkLeaf;
use crate::wallet::utils::frost::frost_commitment_to_proto_commitment;
use bitcoin::secp256k1::PublicKey;
use parking_lot::RwLock;
use spark_cryptography::derivation_path::SparkKeyType;
use spark_protos::spark::prepare_tree_address_request::Source as SourceProto;
use spark_protos::spark::AddressNode as AddressNodeProto;
use spark_protos::spark::AddressRequestNode as AddressRequestNodeProto;
use spark_protos::spark::CreateTreeRequest;
use spark_protos::spark::FinalizeNodeSignaturesRequest;
use spark_protos::spark::NodeOutput as NodeOutputProto;
use spark_protos::spark::PrepareTreeAddressRequest;
use spark_protos::spark::TreeNode as TreeNodeProto;
use spark_protos::spark::Utxo as UtxoProto;
use std::collections::VecDeque;
use std::str::FromStr;
use std::sync::Arc;
use tonic::async_trait;
use uuid::Uuid;

#[async_trait]
impl<S: SparkSigner + Send + Sync + Clone + 'static> CreateTreeInternalHandlers<S> for SparkSdk<S> {
    /// Creates a binary tree of deposit addresses.
    ///
    /// The tree is built by recursively splitting the target signing key: at each level a fresh
    /// left key is derived, and the right key is obtained by subtracting the left key from the
    /// target key. Each node in the tree represents a deposit address, and the children of each
    /// node form the next level of the tree.
    ///
    /// # Arguments
    ///
    /// * `split_level` - The depth of the tree to create; a value of 0 returns an empty tree.
    /// * `target_pubkey` - The public key whose key material is split across the tree.
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn create_deposit_address_binary_tree(
        &self,
        split_level: u32,
        target_pubkey: &Vec<u8>,
    ) -> Result<CreateDepositAddressBinaryTreeSdkResponse, SparkSdkError> {
        if split_level == 0 {
            return Ok(CreateDepositAddressBinaryTreeSdkResponse { tree: vec![] });
        }

        // generate left pubkey
        let left_pubkey_uuid = Uuid::now_v7().to_string();
        let left_keypair = self.signer.derive_spark_key(
            left_pubkey_uuid,
            0,
            SparkKeyType::Deposit,
            self.config.spark_config.network.to_bitcoin_network(),
        )?;
        let left_pk = left_keypair.public_key();
        let left_pubkey = left_pk.serialize().to_vec();

        // left node
        let mut left_node = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: left_pubkey.clone(),
            children: vec![],
        };

        // create left children recursively
        let left_children =
            self.create_deposit_address_binary_tree(split_level - 1, &left_pubkey)?;
        left_node.children = left_children.tree;

        // calculate right pubkey
        let right_pubkey = self.signer.subtract_secret_keys_given_pubkeys(
            &PublicKey::from_slice(target_pubkey)?,
            &PublicKey::from_slice(&left_pubkey)?,
            true,
        )?;

        // right node
        let mut right_node = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: right_pubkey.serialize().to_vec(),
            children: vec![],
        };

        // create right children recursively
        let right_children = self.create_deposit_address_binary_tree(
            split_level - 1,
            &right_pubkey.serialize().to_vec(),
        )?;
        right_node.children = right_children.tree;

        Ok(CreateDepositAddressBinaryTreeSdkResponse {
            tree: vec![
                Arc::new(RwLock::new(left_node)),
                Arc::new(RwLock::new(right_node)),
            ],
        })
    }

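    /// Recursively converts the in-memory [`DepositAddressTree`] nodes into the
    /// `AddressRequestNode` proto representation expected by `prepare_tree_address`,
    /// preserving the parent/child structure.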
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn create_address_request_node_from_tree_nodes(
        &self,
        tree_nodes: &Vec<Arc<RwLock<DepositAddressTree>>>,
    ) -> Result<CreateAddressRequestNodeFromTreeNodesSdkResponse, SparkSdkError> {
        let mut results = Vec::<AddressRequestNodeProto>::new();

        for node in tree_nodes {
            let node = node.read();

            let address_request_node =
                self.create_address_request_node_from_tree_nodes(&node.children)?;
            let address_request_node = AddressRequestNodeProto {
                user_public_key: node.signing_public_key.clone(),
                children: address_request_node.address_request_nodes,
            };
            results.push(address_request_node);
        }

        Ok(CreateAddressRequestNodeFromTreeNodesSdkResponse {
            address_request_nodes: results,
        })
    }

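    /// Writes the addresses and verifying keys returned by the coordinator back onto the local
    /// tree, recursing into children so that every `DepositAddressTree` node is populated.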
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn apply_address_nodes_to_tree(
        &self,
        tree: &mut Vec<Arc<RwLock<DepositAddressTree>>>,
        address_nodes: Vec<AddressNodeProto>,
    ) -> Result<(), SparkSdkError> {
        for (i, node) in tree.iter_mut().enumerate() {
            let mut node = node.write();
            let node_address_data = address_nodes[i].address.clone().unwrap();
            node.address = Some(node_address_data.address);
            node.verification_key = Some(node_address_data.verifying_key);

            if !node.children.is_empty() {
                self.apply_address_nodes_to_tree(
                    &mut node.children,
                    address_nodes[i].children.clone(),
                )?;
            }
        }

        Ok(())
    }

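    /// Builds the local binary key tree, asks the coordinator to prepare addresses for it via
    /// `prepare_tree_address`, applies the returned addresses and verifying keys to the tree,
    /// and returns the populated root node.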
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    async fn generate_deposit_address_for_tree(
        &self,
        parent_tx: Option<Transaction>,
        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
        vout: u32,
        parent_public_key: Vec<u8>,
        split_level: u32,
    ) -> Result<GenerateDepositAddressForTreeSdkResponse, SparkSdkError> {
        let network_proto = self.config.spark_config.network.marshal_proto();

        // 1. Create the binary tree given the user request
        let time_start = std::time::Instant::now();
        let deposit_address_tree =
            self.create_deposit_address_binary_tree(split_level, &parent_public_key)?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_deposit_address_binary_tree");

        // Counting the root added below, the full binary tree contains 2^(split_level + 1) - 1
        // nodes: split_level = 0 -> 1, split_level = 1 -> 3, split_level = 2 -> 7.
        // `deposit_address_tree.tree` itself only holds the root's two direct children (or is
        // empty when split_level = 0).

        // 2. Create the address request nodes (in proto format) from the tree nodes
        let time_start = std::time::Instant::now();
        let address_nodes =
            self.create_address_request_node_from_tree_nodes(&deposit_address_tree.tree)?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_address_request_node_from_tree_nodes");

        // 3. Send PrepareTreeAddressRequest to Spark. This is the first step of tree creation.
        let request_source = match parent_node {
            Some(parent_node) => {
                let node_output = NodeOutputProto {
                    node_id: parent_node.read().id.clone(),
                    vout,
                };
                SourceProto::ParentNodeOutput(node_output)
            }
            None => {
                let mut raw_tx = Vec::new();
                parent_tx
                    .as_ref()
                    .unwrap()
                    .consensus_encode(&mut raw_tx)
                    .map_err(|e| {
                        SparkSdkError::from(ValidationError::InvalidArgument {
                            argument: e.to_string(),
                        })
                    })?;
                let utxo = UtxoProto {
                    vout,
                    raw_tx,
                    network: network_proto,
                };
                SourceProto::OnChainUtxo(utxo)
            }
        };
        let source = Some(request_source);
        let request_node = AddressRequestNodeProto {
            user_public_key: parent_public_key.clone(),
            children: address_nodes.address_request_nodes,
        };
        let node = Some(request_node);

        let request_data = PrepareTreeAddressRequest {
            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
            node,
            source,
        };

        let time_start = std::time::Instant::now();
        let spark_tree_response = self
            .config
            .spark_config
            .call_with_retry(
                request_data,
                |mut client, req| Box::pin(async move { client.prepare_tree_address(req).await }),
                None,
            )
            .await?;

        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "prepare_tree_address");

        // 4. Create the root node
        let response_address_node = spark_tree_response.node.unwrap();
        let root = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: parent_public_key.clone(),
            children: deposit_address_tree.tree.clone(),
        };
        let root = Arc::new(RwLock::new(root));

        // 5. Apply the address nodes to the tree
        let mut root_in_vec = vec![root];
        let time_start = std::time::Instant::now();
        self.apply_address_nodes_to_tree(&mut root_in_vec, vec![response_address_node])?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "apply_address_nodes_to_tree");

        // 6. Extract the root back from the vector and return it
        let root = root_in_vec.remove(0);

        Ok(GenerateDepositAddressForTreeSdkResponse { tree: root })
    }

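    /// Walks the populated address tree in BFS order and builds the `CreationNode` proto for the
    /// `create_tree` RPC: internal nodes get a node-transaction signing job that splits the
    /// parent output across their children, while leaves get both a node-transaction and a
    /// refund-transaction signing job.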
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn build_creation_nodes_from_tree(
        &self,
        parent_txid: Txid,
        txout_: &TxOut,
        vout: u32,
        root: Arc<RwLock<DepositAddressTree>>,
    ) -> Result<BuildCreationNodesFromTreeSdkResponse, SparkSdkError> {
        struct TreeNode {
            parent_txid: Txid,
            txout: TxOut,
            vout: u32,
            node: Arc<RwLock<DepositAddressTree>>,
        }

        let mut creation_node = spark_protos::spark::CreationNode::default();
        let mut queue = VecDeque::<(TreeNode, *mut spark_protos::spark::CreationNode)>::new();

        queue.push_back((
            TreeNode {
                parent_txid,
                txout: txout_.clone(),
                vout,
                node: root,
            },
            &mut creation_node,
        ));

        let network = self.config.spark_config.network.to_bitcoin_network();
        let secp = Secp256k1::new();

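        // SAFETY note on the raw pointers queued above and dereferenced below: each pointer
        // targets either `creation_node` itself or an element of a `children` Vec that is sized
        // once (`vec![Default::default(); n]`) and never reallocated afterwards, and
        // `creation_node` is not moved until the loop finishes, so the pointers stay valid.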
        while let Some((current, creation_ptr)) = queue.pop_front() {
            let creation_ref = unsafe { &mut *creation_ptr };
            let local_node = current.node.read();

            let user_verifying_key_ = local_node.signing_public_key.clone();
            let user_verifying_key =
                bitcoin::secp256k1::PublicKey::from_slice(&user_verifying_key_).unwrap();

            if !local_node.children.is_empty() {
                let tx = {
                    let tx_input = bitcoin::TxIn {
                        previous_output: bitcoin::OutPoint {
                            txid: current.parent_txid,
                            vout: current.vout,
                        },
                        script_sig: bitcoin::ScriptBuf::default(),
                        sequence: bitcoin::Sequence::ZERO,
                        witness: bitcoin::Witness::default(),
                    };

                    // Split the parent output's value evenly across the children.
                    let mut output = Vec::new();
                    for child in local_node.children.iter() {
                        let child_address = child.read().address.clone().unwrap();
                        let child_address = Address::from_str(&child_address).unwrap();
                        let child_address = child_address.require_network(network).unwrap();
                        output.push(TxOut {
                            value: bitcoin::Amount::from_sat(
                                current.txout.value.to_sat() / local_node.children.len() as u64,
                            ),
                            script_pubkey: child_address.script_pubkey(),
                        });
                    }

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![tx_input],
                        output,
                    }
                };

                let mut tx_buf = vec![];
                tx.consensus_encode(&mut tx_buf).map_err(|e| {
                    SparkSdkError::from(ValidationError::InvalidArgument {
                        argument: e.to_string(),
                    })
                })?;

                let commitment = self.signer.generate_frost_signing_commitments()?;

                let signing_job = spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &commitment,
                    )?),
                };

                creation_ref.node_tx_signing_job = Some(signing_job);
                creation_ref.children = vec![Default::default(); local_node.children.len()];

                // Enqueue the children, pointing each at its pre-allocated slot in `children`.
                let txid = tx.compute_txid();
                for (i, child) in local_node.children.iter().enumerate() {
                    queue.push_back((
                        TreeNode {
                            parent_txid: txid,
                            txout: tx.output[i].clone(),
                            vout: i as u32,
                            node: child.clone(),
                        },
                        &mut creation_ref.children[i] as *mut _,
                    ));
                }
            } else {
                let aggregated_address = local_node.address.clone().unwrap();
                let aggregated_address = Address::from_str(&aggregated_address).unwrap();
                let aggregated_address = aggregated_address.require_network(network).unwrap();

                let node_tx = {
                    let input = bitcoin::TxIn {
                        previous_output: bitcoin::OutPoint {
                            txid: current.parent_txid,
                            vout: current.vout,
                        },
                        script_sig: bitcoin::ScriptBuf::default(),
                        sequence: bitcoin::Sequence::ZERO,
                        witness: bitcoin::Witness::default(),
                    };

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![input],
                        output: vec![TxOut {
                            value: current.txout.value,
                            script_pubkey: aggregated_address.script_pubkey(),
                        }],
                    }
                };

                let mut node_tx_buf = vec![];
                node_tx.consensus_encode(&mut node_tx_buf).map_err(|e| {
                    SparkSdkError::from(ValidationError::InvalidArgument {
                        argument: e.to_string(),
                    })
                })?;

                // The refund transaction pays the leaf back to a P2TR address derived from the
                // user's signing key.
                let refund_tx = {
                    let user_self_xonly = user_verifying_key.x_only_public_key().0;
                    let user_self_address = Address::p2tr(&secp, user_self_xonly, None, network);

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![TxIn {
                            previous_output: bitcoin::OutPoint {
                                txid: node_tx.compute_txid(),
                                vout: 0,
                            },
                            script_sig: bitcoin::ScriptBuf::default(),
                            // TODO: this must be the default sequence. For tree creation here, we
                            // can set it to MAX, since this is an SSP feature and is unlikely to
                            // be used here. Yet, this will affect unilateral exits.
                            sequence: bitcoin::Sequence::MAX,
                            witness: bitcoin::Witness::default(),
                        }],
                        output: vec![TxOut {
                            value: current.txout.value,
                            script_pubkey: user_self_address.script_pubkey(),
                        }],
                    }
                };

                let mut refund_tx_buf = vec![];
                refund_tx
                    .consensus_encode(&mut refund_tx_buf)
                    .map_err(|e| {
                        SparkSdkError::from(ValidationError::InvalidArgument {
                            argument: e.to_string(),
                        })
                    })?;

                let node_commitment = self.signer.generate_frost_signing_commitments()?;
                let refund_commitment = self.signer.generate_frost_signing_commitments()?;

                creation_ref.node_tx_signing_job = Some(spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: node_tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &node_commitment,
                    )?),
                });

                creation_ref.refund_tx_signing_job = Some(spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: refund_tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &refund_commitment,
                    )?),
                });
            }
        }

        Ok(BuildCreationNodesFromTreeSdkResponse {
            creation_nodes: creation_node,
        })
    }

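    /// Finalizes tree creation: encodes the parent source (on-chain UTXO or parent node output),
    /// builds the creation nodes, calls `create_tree`, signs the returned tree in BFS order,
    /// finalizes the node signatures with the coordinator, and stores the resulting leaves in
    /// the leaf manager.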
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    async fn finalize_tree_creation(
        &self,
        parent_tx: Option<Transaction>,
        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
        vout: u32,
        root: Arc<RwLock<DepositAddressTree>>,
    ) -> Result<FinalizeTreeCreationSdkResponse, SparkSdkError> {
        let mut request = CreateTreeRequest {
            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
            ..Default::default()
        };

        let final_parent_tx = if let Some(ptx) = parent_tx {
            let mut raw_tx = Vec::new();
            ptx.consensus_encode(&mut raw_tx).map_err(|e| {
                SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: e.to_string(),
                })
            })?;

            request.source = Some(
                spark_protos::spark::create_tree_request::Source::OnChainUtxo(
                    spark_protos::spark::Utxo {
                        vout,
                        raw_tx,
                        network: self.config.spark_config.network.marshal_proto(),
                    },
                ),
            );
            ptx
        } else if let Some(parent_node) = parent_node {
            let tx_buf = parent_node.read().node_tx.clone();
            let ptx: Transaction = bitcoin::consensus::deserialize(&tx_buf).map_err(|_| {
                SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: "Failed to parse parent node_tx".to_string(),
                })
            })?;
            let node_id = parent_node.read().id.clone();
            request.source = Some(
                spark_protos::spark::create_tree_request::Source::ParentNodeOutput(
                    spark_protos::spark::NodeOutput { node_id, vout },
                ),
            );
            ptx
        } else {
            return Err(SparkSdkError::from(ValidationError::InvalidArgument {
                argument: "No parent_tx or parent_node provided to create_tree".to_string(),
            }));
        };

        let parent_txid = final_parent_tx.compute_txid();
        let time_start = std::time::Instant::now();
        let creation_node_response = self.build_creation_nodes_from_tree(
            parent_txid,
            &final_parent_tx.output[vout as usize],
            vout,
            root.clone(),
        )?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "build_creation_nodes_from_tree");

        request.node = Some(creation_node_response.creation_nodes.clone());

        let time_start = std::time::Instant::now();
        let spark_tree_response = self
            .config
            .spark_config
            .call_with_retry(
                request,
                |mut client, req| Box::pin(async move { client.create_tree(req).await }),
                None,
            )
            .await?;

        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_tree");

        let tree_node = spark_tree_response.node.ok_or_else(|| {
            SparkSdkError::from(ValidationError::InvalidArgument {
                argument: "Coordinator returned no creation node".to_string(),
            })
        })?;

        let time_start = std::time::Instant::now();
        let (node_signatures, signing_public_keys) = self.signer.sign_created_tree_in_bfs_order(
            final_parent_tx,
            vout,
            root,
            creation_node_response.creation_nodes,
            tree_node,
        )?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "sign_created_tree_in_bfs_order");
        let spark_signatures_request = FinalizeNodeSignaturesRequest {
            node_signatures: node_signatures.clone(),
            ..Default::default()
        };

        let time_start = std::time::Instant::now();

        let spark_signatures_response = self
            .config
            .spark_config
            .call_with_retry(
                spark_signatures_request,
                |mut client, req| {
                    Box::pin(async move { client.finalize_node_signatures(req).await })
                },
                None,
            )
            .await?;

        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "finalize_node_signatures");

        // The returned nodes are in BFS order; the second half (from the middle element onward)
        // holds the leaves, which are the only nodes inserted into the leaf manager.
        let starting_index = spark_signatures_response.nodes.len() / 2;

        let leaf_nodes = spark_signatures_response
            .nodes
            .iter()
            .skip(starting_index)
            .collect::<Vec<_>>();

        let mut leaf_nodes_to_insert = vec![];
        for (i, node) in leaf_nodes.iter().enumerate() {
            let node_id = node.id.clone();
            let i = i + starting_index;
            if node_id != node_signatures[i].node_id {
                return Err(SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: "Node ID mismatch".to_string(),
                }));
            }

            // TODO: this is an error-prone approach. For roots, this approach is correct.
            // For splits, this approach is incorrect because now it should be strictly the right
            // half, excluding the center.

            // TODO: For the parents, add aggregated.
            leaf_nodes_to_insert.push(SparkLeaf::Bitcoin((*node).clone()));
        }

        self.leaf_manager
            .insert_leaves(leaf_nodes_to_insert, false)?;

        Ok(FinalizeTreeCreationSdkResponse {
            finalize_tree_response: spark_signatures_response,
            signing_public_keys,
        })
    }
}