spark_rust/wallet/leaf_manager/
mod.rs

1use crate::error::{validation::ValidationError, wallet::WalletError, SparkSdkError};
2use hashbrown::HashMap;
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use spark_protos::spark::TreeNode;
6use std::sync::Arc;
7use uuid::Uuid;
8
9use super::internal_handlers::traits::leaves::LeafSelectionResponse;
10
11struct SparkLeafEntry {
12    leaf: SparkLeaf,
13    status: SparkNodeStatus,
14    unlocking_id: Option<String>,
15}
16
17type LeafMap = Arc<RwLock<HashMap<String, SparkLeafEntry>>>;
18
19pub(crate) struct LeafManager {
20    /// The map of leaf nodes
21    leaves: LeafMap,
22}
23
/// A token-denominated leaf owned by the wallet.
///
/// Fields prefixed with `_` are stored but not read anywhere in this module.
#[derive(Debug, Clone)]
pub(crate) struct TokenLeaf {
    /// Leaf id; the child index of the leaf is also derived from it.
    pub(crate) id: String,

    /// Id of the tree this leaf belongs to.
    pub(crate) _tree_id: String,

    /// Amount held by this leaf.
    pub(crate) value: u64,

    /// Public key identifying the token.
    pub(crate) token_public_key: Vec<u8>,

    /// Revocation public key (token leaves only).
    pub(crate) _revocation_public_key: Vec<u8>,

    /// Hash of the token transaction that produced this leaf.
    pub(crate) _token_transaction_hash: Vec<u8>,
}
44
45#[derive(Debug, Clone)]
46pub(crate) enum SparkLeaf {
47    Bitcoin(TreeNode),
48    #[allow(dead_code)]
49    Token(TokenLeaf),
50}
51
52impl SparkLeaf {
53    pub(crate) fn get_id(&self) -> &String {
54        match self {
55            SparkLeaf::Bitcoin(leaf) => &leaf.id,
56            SparkLeaf::Token(leaf) => &leaf.id,
57        }
58    }
59
60    pub(crate) fn get_value(&self) -> u64 {
61        match self {
62            SparkLeaf::Bitcoin(leaf) => leaf.value,
63            SparkLeaf::Token(leaf) => leaf.value,
64        }
65    }
66
67    pub(crate) fn is_bitcoin(&self) -> bool {
68        matches!(self, SparkLeaf::Bitcoin(_))
69    }
70
71    pub(crate) fn get_token_pubkey(&self) -> Option<Vec<u8>> {
72        match self {
73            SparkLeaf::Bitcoin(_) => None,
74            SparkLeaf::Token(leaf) => Some(leaf.token_public_key.clone()),
75        }
76    }
77
78    pub(crate) fn get_tree_node(&self) -> Result<TreeNode, SparkSdkError> {
79        match self {
80            SparkLeaf::Bitcoin(leaf) => Ok(leaf.clone()),
81            SparkLeaf::Token(_) => Err(SparkSdkError::from(WalletError::LeafIsNotBitcoin {
82                leaf_id: self.get_id().clone(),
83            })),
84        }
85    }
86}
87
88/// Status of a user-owned Spark node. These are mostly leaves, but they can also represent aggregatable branch nodes (parents  ) if the user owns all children of the parent. The status of a transfer (or derivative thereof) may require a swap operation first. In this case, the status won't show the swap operation, since it is a sub-operation of the user-intended flow.
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
90pub enum SparkNodeStatus {
91    /// Available for a new operation
92    Available,
93
94    /// Aggregatable parent
95    AggregatableParent,
96
97    /// Currently used for a transfer operation. It is locked and will no longer exist if the transfer operation succeeds.
98    Transfer,
99
100    /// Currently used for a split operation. It is locked and will no longer exist if the split operation succeeds.
101    Split,
102
103    /// Currently used for a swap operation. It is locked and will no longer exist if the swap operation succeeds.
104    Swap,
105
106    /// Currently used for a fee query that will be followed by a transfer (or derivative thereof) operation. It is locked for the operation that pre-fetches the fee. Since fee requests can also be called externally by the user, these locks should have a timeout duration.
107    FeeQuery,
108
109    /// Exiting through the SSP with cooperative exit
110    CooperativeExit,
111}
112
113impl SparkNodeStatus {
114    fn generate_unlocking_id(&self) -> Option<String> {
115        match self {
116            SparkNodeStatus::Transfer | SparkNodeStatus::Split | SparkNodeStatus::Swap => {
117                Some(Uuid::now_v7().to_string())
118            }
119            _ => None,
120        }
121    }
122}
123
124#[derive(Debug)]
125pub(crate) struct LockLeavesResponse {
126    pub(crate) unlocking_id: Option<String>,
127    pub(crate) leaves: Vec<SparkLeaf>,
128}
129
130pub(crate) type LeafFilterFunction = fn(&SparkLeaf) -> bool;
131
132impl LeafManager {
133    pub(crate) fn new() -> Self {
134        Self {
135            leaves: Arc::new(parking_lot::RwLock::new(HashMap::new())),
136        }
137    }
138
139    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
140    pub(crate) fn filter_nodes(&self, cb: Option<LeafFilterFunction>) -> Vec<SparkLeaf> {
141        let mut nodes = Vec::new();
142        let guard = self.leaves.read();
143        for node in guard.values() {
144            if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
145                nodes.push(node.leaf.clone());
146            }
147        }
148
149        drop(guard);
150        nodes
151    }
152
153    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
154    pub(crate) fn filter_nodes_by_ids(&self, leaf_ids: &Vec<String>) -> Vec<SparkLeaf> {
155        let guard = self.leaves.read();
156        let mut nodes = Vec::new();
157        for leaf_id in leaf_ids {
158            if let Some(entry) = guard.get(leaf_id) {
159                nodes.push(entry.leaf.clone());
160            }
161        }
162
163        drop(guard);
164        nodes
165    }
166
167    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
168    pub(crate) fn filter_and_lock_nodes(
169        &self,
170        cb: Option<LeafFilterFunction>,
171        new_status: Option<SparkNodeStatus>,
172    ) -> Vec<SparkLeaf> {
173        // Check if all nodes are lockable
174        let mut nodes = Vec::new();
175        let unlocking_id = match &new_status {
176            Some(status) => status.generate_unlocking_id(),
177            None => None,
178        };
179
180        let mut guard = self.leaves.write();
181        for node in guard.values_mut() {
182            if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
183                if new_status.is_some() {
184                    node.unlocking_id = unlocking_id.clone();
185                    node.status = new_status.clone().unwrap();
186                }
187                nodes.push(node.leaf.clone());
188            }
189        }
190
191        drop(guard);
192        nodes
193    }
194
195    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
196    pub(crate) fn lock_leaf_ids(
197        &self,
198        leaf_ids: &Vec<String>,
199        new_status: SparkNodeStatus,
200    ) -> Result<LeafSelectionResponse, SparkSdkError> {
201        // make sure all leaves exist and are available
202        let mut leaves = Vec::new();
203        let mut guard = self.leaves.write();
204        for leaf_id in leaf_ids {
205            let get_leaf = guard.get(leaf_id);
206            if get_leaf.is_none() {
207                drop(guard);
208                return Err(SparkSdkError::from(WalletError::LeafNotFoundInWallet {
209                    leaf_id: leaf_id.clone(),
210                }));
211            }
212            let leaf = get_leaf.unwrap();
213            if leaf.status != SparkNodeStatus::Available {
214                drop(guard);
215                return Err(SparkSdkError::from(WalletError::LeafNotAvailableForUse {
216                    leaf_id: leaf_id.clone(),
217                }));
218            }
219            leaves.push(leaf.leaf.clone());
220        }
221
222        let unlocking_id = new_status.generate_unlocking_id();
223        for leaf_id in leaf_ids {
224            let leaf = guard.get_mut(leaf_id).unwrap();
225            leaf.status = new_status.clone();
226            leaf.unlocking_id = unlocking_id.clone();
227        }
228
229        // get total value of leaves
230        let total_value = leaves.iter().map(|l| l.get_value()).sum();
231
232        Ok(LeafSelectionResponse {
233            leaves,
234            total_value,
235            unlocking_id,
236            exact_amount: true,
237        })
238    }
239
240    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
241    pub(crate) fn remove_all_leaves(&self) -> Result<(), SparkSdkError> {
242        let mut guard = self.leaves.write();
243        guard.clear();
244        Ok(())
245    }
246
247    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
248    pub(crate) fn insert_leaves(
249        &self,
250        new_leaves: Vec<SparkLeaf>,
251        delete_leaves_first: bool,
252    ) -> Result<(), SparkSdkError> {
253        let mut guard = self.leaves.write();
254
255        if delete_leaves_first {
256            guard.clear();
257        }
258
259        // check if any of the leaves already exist before insertion
260        for leaf in &new_leaves {
261            let id = leaf.get_id();
262            if guard.contains_key(id.as_str()) {
263                return Err(SparkSdkError::from(
264                    WalletError::LeafAlreadyExistsInWallet {
265                        leaf_id: id.clone(),
266                    },
267                ));
268            }
269        }
270
271        // insert all leaves since we've verified none exist
272        for leaf in new_leaves {
273            let id = leaf.get_id();
274            let leaf_entry = SparkLeafEntry {
275                leaf: leaf.clone(),
276                status: SparkNodeStatus::Available,
277                unlocking_id: None,
278            };
279            guard.insert(id.clone(), leaf_entry);
280        }
281
282        Ok(())
283    }
284
285    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
286    pub(crate) fn get_available_bitcoin_value(&self, filter_cb: Option<LeafFilterFunction>) -> u64 {
287        // set the default filter for mapping btc leaves
288        let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
289
290        // combine the default filter with the custom filter
291        let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
292            let combined_filter =
293                move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
294            Box::new(combined_filter)
295        } else {
296            Box::new(move |node: &SparkLeaf| default_filter(node))
297        };
298
299        let guard = self.leaves.read();
300        let mut available_btc_sum = 0;
301        for node in guard.values() {
302            if node.status != SparkNodeStatus::Available {
303                continue;
304            }
305
306            if filter(&node.leaf) {
307                available_btc_sum += node.leaf.get_value();
308            }
309        }
310
311        drop(guard);
312        available_btc_sum
313    }
314
315    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
316    pub(crate) fn get_available_bitcoin_leaves(
317        &self,
318        filter_cb: Option<LeafFilterFunction>,
319        new_status: SparkNodeStatus,
320    ) -> Vec<SparkLeaf> {
321        // set the default filter for mapping btc leaves
322        let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
323
324        // combine the default filter with the custom filter
325        let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
326            let combined_filter =
327                move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
328            Box::new(combined_filter)
329        } else {
330            Box::new(move |node: &SparkLeaf| default_filter(node))
331        };
332
333        let unlocking_id = Uuid::now_v7().to_string();
334
335        let mut guard = self.leaves.write();
336        let mut available_btc_leaves = Vec::new();
337        for node in guard.values_mut() {
338            if node.status != SparkNodeStatus::Available {
339                continue;
340            }
341
342            node.status = new_status.clone();
343            node.unlocking_id = Some(unlocking_id.clone());
344
345            if filter(&node.leaf) {
346                available_btc_leaves.push(node.leaf.clone());
347            }
348        }
349
350        drop(guard);
351        available_btc_leaves
352    }
353
354    /// Lock all available bitcoin leaves.
355    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
356    pub(crate) fn lock_available_bitcoin_leaves(
357        &self,
358        new_status: SparkNodeStatus,
359    ) -> LockLeavesResponse {
360        // set the default filter for mapping btc leaves
361        let unlocking_id = Uuid::now_v7().to_string();
362        let default_filter: Box<dyn Fn(&SparkLeaf) -> bool> = Box::new(|ln| ln.is_bitcoin());
363
364        let mut guard = self.leaves.write();
365        let mut available_btc_leaves = Vec::new();
366        for node in guard.values_mut() {
367            if node.status != SparkNodeStatus::Available {
368                continue;
369            }
370
371            if default_filter(&node.leaf) {
372                node.status = new_status.clone();
373                node.unlocking_id = Some(unlocking_id.clone());
374                available_btc_leaves.push(node.leaf.clone());
375            }
376        }
377
378        drop(guard);
379        LockLeavesResponse {
380            unlocking_id: Some(unlocking_id),
381            leaves: available_btc_leaves,
382        }
383    }
384
385    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
386    pub(crate) fn select_leaves(
387        &self,
388        target_amount: u64,
389        token_pubkey: Option<Vec<u8>>,
390        new_status: SparkNodeStatus,
391    ) -> Result<LeafSelectionResponse, SparkSdkError> {
392        // check if target amount is 0
393        if target_amount == 0 {
394            return Err(SparkSdkError::from(ValidationError::InvalidInput {
395                field: "Target amount cannot be 0".to_string(),
396            }));
397        }
398
399        // prepare the filter, if token pubkey is None, filter for Bitcoin leaves
400        let filter = Box::new(|node: &SparkLeaf| node.get_token_pubkey() == token_pubkey);
401
402        // get all leaves that match the filter
403        let mut guard = self.leaves.write();
404        let mut filtered_leaves = Vec::new();
405        for node in guard.values() {
406            if filter(&node.leaf) {
407                filtered_leaves.push(node.leaf.clone());
408            }
409        }
410
411        // sort leaves in descending order by value
412        filtered_leaves.sort_by_key(|b| std::cmp::Reverse(b.get_value()));
413
414        let unlocking_id = new_status.generate_unlocking_id();
415        let mut total_value = 0;
416        let mut leaves = Vec::new();
417        let mut target_reached = false;
418
419        for leaf in filtered_leaves {
420            // If we already reached the target and added one more leaf, we're done
421            if target_reached {
422                break;
423            }
424
425            let leaf_id = leaf.get_id().clone();
426
427            // Add the leaf and update total
428            total_value += leaf.get_value();
429            leaves.push(leaf);
430
431            // Lock the leaf
432            let leaf_entry = guard.get_mut(&leaf_id).unwrap();
433            leaf_entry.status = new_status.clone();
434            leaf_entry.unlocking_id = unlocking_id.clone();
435
436            // Check if we've reached or exceeded the target amount
437            // If so, mark that we've reached the target, so we'll add one more leaf and then stop
438            if total_value >= target_amount {
439                target_reached = true;
440            }
441        }
442
443        drop(guard);
444
445        Ok(LeafSelectionResponse {
446            leaves,
447            total_value,
448            unlocking_id: unlocking_id.clone(),
449            exact_amount: total_value == target_amount,
450        })
451    }
452
453    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
454    pub(crate) fn unlock_leaves(
455        &self,
456        unlocking_id: String,
457        leaf_ids: &Vec<String>,
458        delete: bool,
459    ) -> Result<(), SparkSdkError> {
460        let mut leaves = self.leaves.write();
461
462        // make sure that all leaves are in transfer lock with the same request id
463        for leaf_id in leaf_ids {
464            let leaf = leaves.get(leaf_id).ok_or_else(|| {
465                SparkSdkError::from(WalletError::LeafNotFoundInWallet {
466                    leaf_id: leaf_id.clone(),
467                })
468            })?;
469
470            if leaf.unlocking_id != Some(unlocking_id.clone()) {
471                return Err(SparkSdkError::from(WalletError::LeafNotUsingExpectedLock {
472                    expected: unlocking_id.clone(),
473                    actual: leaf.unlocking_id.clone().unwrap_or_default(),
474                }));
475            }
476        }
477
478        if delete {
479            for leaf_id in leaf_ids {
480                leaves.remove(leaf_id);
481            }
482
483            drop(leaves);
484            return Ok(());
485        }
486
487        // update status to available
488        for leaf_id in leaf_ids {
489            let leaf = leaves.get_mut(leaf_id).unwrap();
490            leaf.status = SparkNodeStatus::Available;
491            leaf.unlocking_id = None;
492        }
493
494        drop(leaves);
495        Ok(())
496    }
497}