spark_rust/wallet/internal_handlers/implementations/create_tree.rs

use crate::common_types::types::Address;
use crate::common_types::types::Encodable;
use crate::common_types::types::Secp256k1;
use crate::common_types::types::Transaction;
use crate::common_types::types::TxIn;
use crate::common_types::types::TxOut;
use crate::common_types::types::Txid;
use crate::error::{network::NetworkError, validation::ValidationError, SparkSdkError};
use crate::signer::traits::derivation_path::SparkKeyType;
use crate::signer::traits::SparkSigner;
use crate::wallet::client::SparkSdk;
use crate::wallet::internal_handlers::traits::create_tree::BuildCreationNodesFromTreeSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateAddressRequestNodeFromTreeNodesSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateDepositAddressBinaryTreeSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::CreateTreeInternalHandlers;
use crate::wallet::internal_handlers::traits::create_tree::DepositAddressTree;
use crate::wallet::internal_handlers::traits::create_tree::FinalizeTreeCreationSdkResponse;
use crate::wallet::internal_handlers::traits::create_tree::GenerateDepositAddressForTreeSdkResponse;
use crate::wallet::leaf_manager::SparkLeaf;
use crate::wallet::utils::frost::frost_commitment_to_proto_commitment;
use bitcoin::secp256k1::PublicKey;
use parking_lot::RwLock;
use spark_protos::spark::prepare_tree_address_request::Source as SourceProto;
use spark_protos::spark::AddressNode as AddressNodeProto;
use spark_protos::spark::AddressRequestNode as AddressRequestNodeProto;
use spark_protos::spark::CreateTreeRequest;
use spark_protos::spark::FinalizeNodeSignaturesRequest;
use spark_protos::spark::NodeOutput as NodeOutputProto;
use spark_protos::spark::PrepareTreeAddressRequest;
use spark_protos::spark::TreeNode as TreeNodeProto;
use spark_protos::spark::Utxo as UtxoProto;
use std::collections::VecDeque;
use std::str::FromStr;
use std::sync::Arc;
use tonic::async_trait;
use uuid::Uuid;

#[async_trait]
impl<S: SparkSigner + Send + Sync + Clone + 'static> CreateTreeInternalHandlers<S> for SparkSdk<S> {
    /// Creates a binary tree of deposit addresses.
    ///
    /// The tree is built by recursively splitting the target signing key: at each level a fresh
    /// left key is generated and the right key is derived so that the two child keys sum to the
    /// parent key. Each node in the tree represents a deposit address, and the children of each
    /// node form the next level of the tree.
    ///
    /// # Arguments
    ///
    /// * `split_level` - The number of levels to split below the root (`0` creates no children).
    /// * `target_pubkey` - The public key to split across the tree.
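    ///
    /// Illustrative shape (a sketch, not normative) for `split_level = 2`, where `T` is the
    /// target key and `L1`, `L2`, `L3` denote freshly generated temporary keys; every pair of
    /// siblings satisfies `left + right = parent` over secp256k1:
    ///
    /// ```text
    /// T
    /// ├── L1
    /// │   ├── L2
    /// │   └── L1 - L2
    /// └── T - L1
    ///     ├── L3
    ///     └── (T - L1) - L3
    /// ```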
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn create_deposit_address_binary_tree(
        &self,
        split_level: u32,
        target_pubkey: &Vec<u8>,
    ) -> Result<CreateDepositAddressBinaryTreeSdkResponse, SparkSdkError> {
        if split_level == 0 {
            return Ok(CreateDepositAddressBinaryTreeSdkResponse { tree: vec![] });
        }

        // generate left pubkey
        let left_pubkey_uuid = Uuid::now_v7().to_string();
        let left_pubkey = self.signer.new_secp256k1_keypair(
            left_pubkey_uuid,
            SparkKeyType::TemporarySigning,
            0,
            self.config.spark_config.network.to_bitcoin_network(),
        )?;
        let left_pubkey = left_pubkey.serialize().to_vec();

        // left node
        let mut left_node = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: left_pubkey.clone(),
            children: vec![],
        };

        // create left children recursively
        let left_children =
            self.create_deposit_address_binary_tree(split_level - 1, &left_pubkey)?;
        left_node.children = left_children.tree;

        // calculate right pubkey
        let right_pubkey = self.signer.subtract_secret_keys_given_pubkeys(
            &PublicKey::from_slice(target_pubkey)?,
            &PublicKey::from_slice(&left_pubkey)?,
            true,
        )?;

        // right node
        let mut right_node = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: right_pubkey.serialize().to_vec(),
            children: vec![],
        };

        // create right children recursively
        let right_children = self.create_deposit_address_binary_tree(
            split_level - 1,
            &right_pubkey.serialize().to_vec(),
        )?;
        right_node.children = right_children.tree;

        Ok(CreateDepositAddressBinaryTreeSdkResponse {
            tree: vec![
                Arc::new(RwLock::new(left_node)),
                Arc::new(RwLock::new(right_node)),
            ],
        })
    }

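    /// Recursively converts the in-memory `DepositAddressTree` nodes into the
    /// `AddressRequestNode` proto structure used by `PrepareTreeAddressRequest`, copying each
    /// node's signing public key and descending into its children.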
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn create_address_request_node_from_tree_nodes(
        &self,
        tree_nodes: &Vec<Arc<RwLock<DepositAddressTree>>>,
    ) -> Result<CreateAddressRequestNodeFromTreeNodesSdkResponse, SparkSdkError> {
        let mut results = Vec::<AddressRequestNodeProto>::new();

        for node in tree_nodes {
            let node = node.read();

            let address_request_node =
                self.create_address_request_node_from_tree_nodes(&node.children)?;
            let address_request_node = AddressRequestNodeProto {
                user_public_key: node.signing_public_key.clone(),
                children: address_request_node.address_request_nodes,
            };
            results.push(address_request_node);
        }

        Ok(CreateAddressRequestNodeFromTreeNodesSdkResponse {
            address_request_nodes: results,
        })
    }

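    /// Writes the addresses and verifying keys returned by the coordinator back onto the
    /// in-memory tree, matching `address_nodes[i]` to `tree[i]` positionally and recursing
    /// into children.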
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn apply_address_nodes_to_tree(
        &self,
        tree: &mut Vec<Arc<RwLock<DepositAddressTree>>>,
        address_nodes: Vec<AddressNodeProto>,
    ) -> Result<(), SparkSdkError> {
        for (i, node) in tree.iter_mut().enumerate() {
            let mut node = node.write();
            let node_address_data = address_nodes[i].address.clone().unwrap();
            node.address = Some(node_address_data.address);
            node.verification_key = Some(node_address_data.verifying_key);

            if !node.children.is_empty() {
                self.apply_address_nodes_to_tree(
                    &mut node.children,
                    address_nodes[i].children.clone(),
                )?;
            }
        }

        Ok(())
    }

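    /// Builds the in-memory deposit address tree for a parent output, sends a
    /// `PrepareTreeAddressRequest` (sourced from either the parent node output or an on-chain
    /// UTXO) so the coordinator can prepare an address for every node, and returns the tree
    /// root with addresses and verifying keys filled in.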
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    async fn generate_deposit_address_for_tree(
        &self,
        parent_tx: Option<Transaction>,
        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
        vout: u32,
        parent_public_key: Vec<u8>,
        split_level: u32,
    ) -> Result<GenerateDepositAddressForTreeSdkResponse, SparkSdkError> {
        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
        let network_proto = self.config.spark_config.network.marshal_proto();

        // 1. Create the binary tree given the user request
        let time_start = std::time::Instant::now();
        let deposit_address_tree =
            self.create_deposit_address_binary_tree(split_level, &parent_public_key)?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_deposit_address_binary_tree");

        // Counting the root added below, the tree has 2^(split_level + 1) - 1 nodes in total:
        // 1 for split_level = 0, 1 + 2 = 3 for split_level = 1, 1 + 2 + 4 = 7 for split_level = 2.
        // `deposit_address_tree.tree` itself only holds the root's two immediate children
        // (or is empty when split_level = 0).

        // 2. Create the address request nodes (in proto format) from the tree nodes
        let time_start = std::time::Instant::now();
        let address_nodes =
            self.create_address_request_node_from_tree_nodes(&deposit_address_tree.tree)?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_address_request_node_from_tree_nodes");

        // 3. Send PrepareTreeAddressRequest to Spark. This is the first step of tree creation.
        let request_source = match parent_node {
            Some(parent_node) => {
                let node_output = NodeOutputProto {
                    node_id: parent_node.read().id.clone(),
                    vout,
                };
                SourceProto::ParentNodeOutput(node_output)
            }
            None => {
                let mut raw_tx = Vec::new();
                parent_tx
                    .as_ref()
                    .unwrap()
                    .consensus_encode(&mut raw_tx)
                    .map_err(|e| {
                        SparkSdkError::from(ValidationError::InvalidArgument {
                            argument: e.to_string(),
                        })
                    })?;
                let utxo = UtxoProto {
                    vout,
                    raw_tx,
                    network: network_proto,
                };
                SourceProto::OnChainUtxo(utxo)
            }
        };
        let source = Some(request_source);
        let request_node = AddressRequestNodeProto {
            user_public_key: parent_public_key.clone(),
            children: address_nodes.address_request_nodes,
        };
        let node = Some(request_node);

        let mut request = tonic::Request::new(PrepareTreeAddressRequest {
            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
            node,
            source,
        });
        self.add_authorization_header_to_request(&mut request, None);

        let time_start = std::time::Instant::now();

        let spark_tree_response_ = spark_client
            .prepare_tree_address(request)
            .await
            .map_err(|status| SparkSdkError::from(NetworkError::Status(status)))?;
        let spark_tree_response = spark_tree_response_.into_inner();
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "prepare_tree_address");

        // 4. Create the root node
        let response_address_node = spark_tree_response.node.unwrap();
        let root = DepositAddressTree {
            address: None,
            verification_key: None,
            signing_public_key: parent_public_key.clone(),
            children: deposit_address_tree.tree.clone(),
        };
        let root = Arc::new(RwLock::new(root));

        // 5. Apply the address nodes to the tree
        let mut root_in_vec = vec![root];
        let time_start = std::time::Instant::now();
        self.apply_address_nodes_to_tree(&mut root_in_vec, vec![response_address_node])?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "apply_address_nodes_to_tree");

        // 6. Extract the root back from the vector and return it
        let root = root_in_vec.remove(0);

        Ok(GenerateDepositAddressForTreeSdkResponse { tree: root })
    }

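    /// Walks the deposit address tree in BFS order and builds the `CreationNode` proto tree:
    /// internal nodes get a node-tx signing job that splits the parent output across their
    /// children, while leaf nodes get both a node-tx and a refund-tx signing job.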
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    fn build_creation_nodes_from_tree(
        &self,
        parent_txid: Txid,
        txout_: &TxOut,
        vout: u32,
        root: Arc<RwLock<DepositAddressTree>>,
    ) -> Result<BuildCreationNodesFromTreeSdkResponse, SparkSdkError> {
        struct TreeNode {
            parent_txid: Txid,
            txout: TxOut,
            vout: u32,
            node: Arc<RwLock<DepositAddressTree>>,
        }

        let mut creation_node = spark_protos::spark::CreationNode::default();
        let mut queue = VecDeque::<(TreeNode, *mut spark_protos::spark::CreationNode)>::new();

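        // Each queue entry pairs a tree node with a raw pointer to the proto `CreationNode` it
        // fills in. The `unsafe` dereference below relies on two invariants: `creation_node` is
        // not moved until after the loop, and a parent's `children` vector is sized once and
        // never reallocated after pointers into its elements have been queued.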
        queue.push_back((
            TreeNode {
                parent_txid,
                txout: txout_.clone(),
                vout,
                node: root,
            },
            &mut creation_node,
        ));

        let network = self.config.spark_config.network.to_bitcoin_network();
        let secp = Secp256k1::new();

        while let Some((current, creation_ptr)) = queue.pop_front() {
            let creation_ref = unsafe { &mut *creation_ptr };
            let local_node = current.node.read();

            let user_verifying_key_ = local_node.signing_public_key.clone();
            let user_verifying_key =
                bitcoin::secp256k1::PublicKey::from_slice(&user_verifying_key_).unwrap();

            if !local_node.children.is_empty() {
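                // Internal node: build a transaction that splits this output's value equally
                // across the children's deposit addresses, then queue each child with the new
                // txid and its output index.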
                let tx = {
                    let tx_input = bitcoin::TxIn {
                        previous_output: bitcoin::OutPoint {
                            txid: current.parent_txid,
                            vout: current.vout,
                        },
                        script_sig: bitcoin::ScriptBuf::default(),
                        sequence: bitcoin::Sequence::ZERO,
                        witness: bitcoin::Witness::default(),
                    };

                    let mut output = Vec::new();
                    for child in local_node.children.iter() {
                        let child_address = child.read().address.clone().unwrap();
                        let child_address = Address::from_str(&child_address).unwrap();
                        let child_address = child_address.require_network(network).unwrap();
                        output.push(TxOut {
                            value: bitcoin::Amount::from_sat(
                                current.txout.value.to_sat() / local_node.children.len() as u64,
                            ),
                            script_pubkey: child_address.script_pubkey(),
                        });
                    }

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![tx_input],
                        output,
                    }
                };

                let mut tx_buf = vec![];
                tx.consensus_encode(&mut tx_buf).map_err(|e| {
                    SparkSdkError::from(ValidationError::InvalidArgument {
                        argument: e.to_string(),
                    })
                })?;

                let commitment = self.signer.new_frost_signing_noncepair()?;

                let signing_job = spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &commitment,
                    )?),
                };

                creation_ref.node_tx_signing_job = Some(signing_job);
                creation_ref.children = vec![Default::default(); local_node.children.len()];

                let txid = tx.compute_txid();
                for (i, child) in local_node.children.iter().enumerate() {
                    queue.push_back((
                        TreeNode {
                            parent_txid: txid,
                            txout: tx.output[i].clone(),
                            vout: i as u32,
                            node: child.clone(),
                        },
                        &mut creation_ref.children[i] as *mut _,
                    ));
                }
            } else {
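                // Leaf node: build the node transaction that forwards the full parent value to
                // this leaf's aggregated deposit address, plus a refund transaction spending it
                // back to the user (see below).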
                let aggregated_address = local_node.address.clone().unwrap();
                let aggregated_address = Address::from_str(&aggregated_address).unwrap();
                let aggregated_address = aggregated_address.require_network(network).unwrap();

                let node_tx = {
                    let input = bitcoin::TxIn {
                        previous_output: bitcoin::OutPoint {
                            txid: current.parent_txid,
                            vout: current.vout,
                        },
                        script_sig: bitcoin::ScriptBuf::default(),
                        sequence: bitcoin::Sequence::ZERO,
                        witness: bitcoin::Witness::default(),
                    };

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![input],
                        output: vec![TxOut {
                            value: current.txout.value,
                            script_pubkey: aggregated_address.script_pubkey(),
                        }],
                    }
                };

                let mut node_tx_buf = vec![];
                node_tx.consensus_encode(&mut node_tx_buf).map_err(|e| {
                    SparkSdkError::from(ValidationError::InvalidArgument {
                        argument: e.to_string(),
                    })
                })?;

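                // Refund transaction: spends output 0 of the node transaction back to a P2TR
                // address derived from the user's own signing key, for the full value.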
                let refund_tx = {
                    let user_self_xonly = user_verifying_key.x_only_public_key().0;
                    let user_self_address = Address::p2tr(&secp, user_self_xonly, None, network);

                    bitcoin::Transaction {
                        version: bitcoin::transaction::Version::TWO,
                        lock_time: bitcoin::absolute::LockTime::ZERO,
                        input: vec![TxIn {
                            previous_output: bitcoin::OutPoint {
                                txid: node_tx.compute_txid(),
                                vout: 0,
                            },
                            script_sig: bitcoin::ScriptBuf::default(),
                            // TODO: this should use the default sequence. For tree creation we can set it to MAX, since this is an SSP feature that is unlikely to be used here, but it will affect unilateral exits.
                            sequence: bitcoin::Sequence::MAX,
                            witness: bitcoin::Witness::default(),
                        }],
                        output: vec![TxOut {
                            value: current.txout.value,
                            script_pubkey: user_self_address.script_pubkey(),
                        }],
                    }
                };

                let mut refund_tx_buf = vec![];
                refund_tx
                    .consensus_encode(&mut refund_tx_buf)
                    .map_err(|e| {
                        SparkSdkError::from(ValidationError::InvalidArgument {
                            argument: e.to_string(),
                        })
                    })?;

                let node_commitment = self.signer.new_frost_signing_noncepair()?;
                let refund_commitment = self.signer.new_frost_signing_noncepair()?;

                creation_ref.node_tx_signing_job = Some(spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: node_tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &node_commitment,
                    )?),
                });

                creation_ref.refund_tx_signing_job = Some(spark_protos::spark::SigningJob {
                    signing_public_key: user_verifying_key.serialize().to_vec(),
                    raw_tx: refund_tx_buf,
                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
                        &refund_commitment,
                    )?),
                });
            }
        }

        Ok(BuildCreationNodesFromTreeSdkResponse {
            creation_nodes: creation_node,
        })
    }

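    /// Second phase of tree creation: sends a `CreateTreeRequest` carrying the signing jobs
    /// built by `build_creation_nodes_from_tree`, signs the returned node and refund
    /// transactions in BFS order, finalizes the node signatures with the coordinator, and
    /// inserts the resulting leaf nodes into the leaf manager.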
    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
    async fn finalize_tree_creation(
        &self,
        parent_tx: Option<Transaction>,
        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
        vout: u32,
        root: Arc<RwLock<DepositAddressTree>>,
    ) -> Result<FinalizeTreeCreationSdkResponse, SparkSdkError> {
        let mut request = CreateTreeRequest {
            user_identity_public_key: self.get_spark_address()?.serialize().to_vec(),
            ..Default::default()
        };

        let final_parent_tx = if let Some(ptx) = parent_tx {
            let mut raw_tx = Vec::new();
            ptx.consensus_encode(&mut raw_tx).map_err(|e| {
                SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: e.to_string(),
                })
            })?;

            request.source = Some(
                spark_protos::spark::create_tree_request::Source::OnChainUtxo(
                    spark_protos::spark::Utxo {
                        // txid,
                        vout,
                        raw_tx,
                        network: self.config.spark_config.network.marshal_proto(),
                    },
                ),
            );
            ptx
        } else if let Some(parent_node) = parent_node {
            let tx_buf = parent_node.read().node_tx.clone();
            let ptx: Transaction = bitcoin::consensus::deserialize(&tx_buf).map_err(|_| {
                SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: "Failed to parse parent node_tx".to_string(),
                })
            })?;
            let node_id = parent_node.read().id.clone();
            request.source = Some(
                spark_protos::spark::create_tree_request::Source::ParentNodeOutput(
                    spark_protos::spark::NodeOutput { node_id, vout },
                ),
            );
            ptx
        } else {
            return Err(SparkSdkError::from(ValidationError::InvalidArgument {
                argument: "No parent_tx or parent_node provided to create_tree".to_string(),
            }));
        };

        let parent_txid = final_parent_tx.compute_txid();
        let time_start = std::time::Instant::now();
        let creation_node_response = self.build_creation_nodes_from_tree(
            parent_txid,
            &final_parent_tx.output[vout as usize],
            vout,
            root.clone(),
        )?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "build_creation_nodes_from_tree");

        request.node = Some(creation_node_response.creation_nodes.clone());
        let mut tonic_request = tonic::Request::new(request);
        self.add_authorization_header_to_request(&mut tonic_request, None);

        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
        let time_start = std::time::Instant::now();
        let resp = spark_client
            .create_tree(tonic_request)
            .await
            .map_err(|status| SparkSdkError::from(NetworkError::Status(status)))?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "create_tree");

        let create_tree_result = resp.into_inner().node.ok_or_else(|| {
            SparkSdkError::from(ValidationError::InvalidArgument {
                argument: "Coordinator returned no creation node".to_string(),
            })
        })?;

        let time_start = std::time::Instant::now();
        let (node_signatures, signing_public_keys) = self.signer.sign_created_tree_in_bfs_order(
            final_parent_tx,
            vout,
            root,
            creation_node_response.creation_nodes,
            create_tree_result,
        )?;
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "sign_created_tree_in_bfs_order");

        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
        let mut spark_signatures_request = tonic::Request::new(FinalizeNodeSignaturesRequest {
            node_signatures: node_signatures.clone(),
            ..Default::default()
        });
        self.add_authorization_header_to_request(&mut spark_signatures_request, None);

        let time_start = std::time::Instant::now();
        let spark_signatures_response_ = spark_client
            .finalize_node_signatures(spark_signatures_request)
            .await
            .map_err(|status| SparkSdkError::from(NetworkError::Status(status)))?;
        let spark_signatures_response = spark_signatures_response_.into_inner();
        let duration = time_start.elapsed();
        #[cfg(feature = "telemetry")]
        tracing::debug!(duration = ?duration, "finalize_node_signatures");

        // The response nodes are expected to be in the same BFS order as the signing jobs, so
        // for a full binary tree with n = 2^(L + 1) - 1 nodes the leaves are the last
        // (n + 1) / 2 entries, i.e. everything from index n / 2 onward (indices 3..=6 for n = 7).
        let starting_index = spark_signatures_response.nodes.len() / 2;

        let leaf_nodes = spark_signatures_response
            .nodes
            .iter()
            .skip(starting_index)
            .collect::<Vec<_>>();

        let mut leaf_nodes_to_insert = vec![];
        for (i, node) in leaf_nodes.iter().enumerate() {
            let node_id = node.id.clone();
            let i = i + starting_index;
            if node_id != node_signatures[i].node_id {
                return Err(SparkSdkError::from(ValidationError::InvalidArgument {
                    argument: "Node ID mismatch".to_string(),
                }));
            }

            // TODO: this is an error-prone approach. It is correct for roots, but incorrect for
            // splits, where it should be strictly the right half, excluding the center.

            // TODO: For the parents, add aggregated.
            leaf_nodes_to_insert.push(SparkLeaf::Bitcoin((*node).clone()));
        }

        self.leaf_manager
            .insert_leaves(leaf_nodes_to_insert, false)?;

        Ok(FinalizeTreeCreationSdkResponse {
            finalize_tree_response: spark_signatures_response,
            signing_public_keys,
        })
    }
}