spark_rust/wallet/leaf_manager/
mod.rs

1use crate::error::{validation::ValidationError, wallet::WalletError, SparkSdkError};
2use hashbrown::HashMap;
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use spark_protos::spark::TreeNode;
6use std::sync::Arc;
7use uuid::Uuid;
8
9use super::internal_handlers::traits::leaves::LeafSelectionResponse;
10
/// Internal bookkeeping record for a single leaf tracked by the leaf manager.
struct SparkLeafEntry {
    /// The leaf payload itself.
    leaf: SparkLeaf,
    /// Current lifecycle status (Available, locked for transfer/split/swap, ...).
    status: SparkNodeStatus,
    /// Lock token set when the leaf is locked by an operation that can later
    /// be released via `unlock_leaves`; `None` when no such lock is held.
    unlocking_id: Option<String>,
}
16
/// Thread-safe map from leaf id to its tracked entry, shared across clones.
type LeafMap = Arc<RwLock<HashMap<String, SparkLeafEntry>>>;
18
/// Tracks the wallet's leaf nodes together with their availability/lock state.
pub(crate) struct LeafManager {
    /// The map of leaf nodes, keyed by leaf id.
    leaves: LeafMap,
}
23
/// A user-owned leaf. Currently only Bitcoin leaves (wrapping a protobuf
/// `TreeNode`) exist; the enum leaves room for other asset kinds.
#[derive(Debug, Clone)]
pub(crate) enum SparkLeaf {
    /// A Bitcoin leaf backed by a Spark protobuf `TreeNode`.
    Bitcoin(TreeNode),
}
28
29impl SparkLeaf {
30    pub(crate) fn get_id(&self) -> &String {
31        match self {
32            SparkLeaf::Bitcoin(leaf) => &leaf.id,
33        }
34    }
35
36    pub(crate) fn get_value(&self) -> u64 {
37        match self {
38            SparkLeaf::Bitcoin(leaf) => leaf.value,
39        }
40    }
41
42    pub(crate) fn is_bitcoin(&self) -> bool {
43        matches!(self, SparkLeaf::Bitcoin(_))
44    }
45
46    pub(crate) fn get_tree_node(&self) -> Result<TreeNode, SparkSdkError> {
47        match self {
48            SparkLeaf::Bitcoin(leaf) => Ok(leaf.clone()),
49        }
50    }
51}
52
/// Status of a user-owned Spark node. These are mostly leaves, but they can
/// also represent aggregatable branch nodes (parents) if the user owns all
/// children of the parent. The status of a transfer (or derivative thereof)
/// may require a swap operation first. In this case, the status won't show
/// the swap operation, since it is a sub-operation of the user-intended flow.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SparkNodeStatus {
    /// Available for a new operation
    Available,

    /// Aggregatable parent
    AggregatableParent,

    /// Currently used for a transfer operation. It is locked and will no
    /// longer exist if the transfer operation succeeds.
    Transfer,

    /// Currently used for a split operation. It is locked and will no
    /// longer exist if the split operation succeeds.
    Split,

    /// Currently used for a swap operation. It is locked and will no
    /// longer exist if the swap operation succeeds.
    Swap,

    /// Currently used for a fee query that will be followed by a transfer
    /// (or derivative thereof) operation. It is locked for the operation
    /// that pre-fetches the fee. Since fee requests can also be called
    /// externally by the user, these locks should have a timeout duration.
    FeeQuery,

    /// Exiting through the SSP with cooperative exit
    CooperativeExit,

    /// Refreshing timelock
    RefreshTimelock,
}
80
81impl SparkNodeStatus {
82    fn generate_unlocking_id(&self) -> Option<String> {
83        match self {
84            SparkNodeStatus::Transfer | SparkNodeStatus::Split | SparkNodeStatus::Swap => {
85                Some(Uuid::now_v7().to_string())
86            }
87            _ => None,
88        }
89    }
90}
91
/// Boxed leaf predicate; boxed so callers can pass closures that capture state.
type LeafFilterCallback = Box<dyn Fn(&SparkLeaf) -> bool>;
93
/// Result of locking a batch of leaves in a single operation.
#[derive(Debug)]
pub(crate) struct LockLeavesResponse {
    /// Token required to later release the locked leaves via `unlock_leaves`.
    pub(crate) unlocking_id: Option<String>,
    /// The leaves that were locked by this call.
    pub(crate) leaves: Vec<SparkLeaf>,
}
99
/// Plain (capture-free) leaf predicate used by the filter/lock helpers.
pub(crate) type LeafFilterFunction = fn(&SparkLeaf) -> bool;
101
102impl LeafManager {
103    pub(crate) fn new() -> Self {
104        Self {
105            leaves: Arc::new(parking_lot::RwLock::new(HashMap::new())),
106        }
107    }
108
109    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
110    pub(crate) fn filter_nodes(&self, cb: Option<LeafFilterFunction>) -> Vec<SparkLeaf> {
111        let mut nodes = Vec::new();
112        let guard = self.leaves.read();
113        for node in guard.values() {
114            if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
115                nodes.push(node.leaf.clone());
116            }
117        }
118
119        drop(guard);
120        nodes
121    }
122
123    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
124    pub(crate) fn filter_nodes_by_ids(&self, leaf_ids: &Vec<String>) -> Vec<SparkLeaf> {
125        let guard = self.leaves.read();
126        let mut nodes = Vec::new();
127        for leaf_id in leaf_ids {
128            if let Some(entry) = guard.get(leaf_id) {
129                nodes.push(entry.leaf.clone());
130            }
131        }
132
133        drop(guard);
134        nodes
135    }
136
137    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
138    pub(crate) fn filter_and_lock_nodes(
139        &self,
140        cb: Option<LeafFilterFunction>,
141        new_status: Option<SparkNodeStatus>,
142    ) -> Vec<SparkLeaf> {
143        // Check if all nodes are lockable
144        let mut nodes = Vec::new();
145        let unlocking_id = match &new_status {
146            Some(status) => status.generate_unlocking_id(),
147            None => None,
148        };
149
150        let mut guard = self.leaves.write();
151        for node in guard.values_mut() {
152            if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
153                if new_status.is_some() {
154                    node.unlocking_id = unlocking_id.clone();
155                    node.status = new_status.clone().unwrap();
156                }
157                nodes.push(node.leaf.clone());
158            }
159        }
160
161        drop(guard);
162        nodes
163    }
164
165    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
166    pub(crate) fn delete_and_insert_leaves_atomically(
167        &self,
168        delete_ids: &Vec<String>,
169        insert_leaves: &Vec<SparkLeaf>,
170    ) -> Result<(), SparkSdkError> {
171        let mut guard = self.leaves.write();
172        for id in delete_ids {
173            let removed = guard.remove(id);
174            if removed.is_none() {
175                #[cfg(feature = "telemetry")]
176                tracing::warn!(
177                    "Leaf not found in wallet when deleting and inserting leaves atomically: {}",
178                    id
179                );
180            }
181        }
182
183        for leaf in insert_leaves {
184            guard.insert(
185                leaf.get_id().clone(),
186                SparkLeafEntry {
187                    leaf: leaf.clone(),
188                    status: SparkNodeStatus::Available,
189                    unlocking_id: None,
190                },
191            );
192        }
193
194        Ok(())
195    }
196
197    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
198    pub(crate) fn get_all_available_leaves(
199        &self,
200        lock_callback: Option<LeafFilterCallback>,
201        new_status: Option<SparkNodeStatus>,
202    ) -> (Vec<SparkLeaf>, Option<String>) {
203        let mut nodes = Vec::new();
204        let mut guard = self.leaves.write();
205        let unlocking_id = match &new_status {
206            Some(status) => status.generate_unlocking_id(),
207            None => None,
208        };
209
210        for node in guard.values_mut() {
211            if node.status == SparkNodeStatus::Available
212                && lock_callback.as_ref().is_some_and(|f| f(&node.leaf))
213            {
214                if new_status.is_some() {
215                    node.status = new_status.clone().unwrap().clone();
216                    node.unlocking_id = unlocking_id.clone();
217                }
218                nodes.push(node.leaf.clone());
219            }
220        }
221
222        drop(guard);
223        (nodes, unlocking_id)
224    }
225
226    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
227    pub(crate) fn lock_leaf_ids(
228        &self,
229        leaf_ids: &Vec<String>,
230        new_status: SparkNodeStatus,
231    ) -> Result<LeafSelectionResponse, SparkSdkError> {
232        // make sure all leaves exist and are available
233        let mut leaves = Vec::new();
234        let mut guard = self.leaves.write();
235        for leaf_id in leaf_ids {
236            let get_leaf = guard.get(leaf_id);
237            if get_leaf.is_none() {
238                drop(guard);
239                return Err(SparkSdkError::from(WalletError::LeafNotFoundInWallet {
240                    leaf_id: leaf_id.clone(),
241                }));
242            }
243            let leaf = get_leaf.unwrap();
244            if leaf.status != SparkNodeStatus::Available {
245                drop(guard);
246                return Err(SparkSdkError::from(WalletError::LeafNotAvailableForUse {
247                    leaf_id: leaf_id.clone(),
248                }));
249            }
250            leaves.push(leaf.leaf.clone());
251        }
252
253        let unlocking_id = new_status.generate_unlocking_id();
254        for leaf_id in leaf_ids {
255            let leaf = guard.get_mut(leaf_id).unwrap();
256            leaf.status = new_status.clone();
257            leaf.unlocking_id = unlocking_id.clone();
258        }
259
260        // get total value of leaves
261        let total_value = leaves.iter().map(|l| l.get_value()).sum();
262
263        Ok(LeafSelectionResponse {
264            leaves,
265            total_value,
266            unlocking_id,
267            exact_amount: true,
268        })
269    }
270
271    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
272    pub(crate) fn remove_all_leaves(&self) -> Result<(), SparkSdkError> {
273        let mut guard = self.leaves.write();
274        guard.clear();
275        Ok(())
276    }
277
278    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
279    pub(crate) fn insert_leaves(
280        &self,
281        new_leaves: Vec<SparkLeaf>,
282        delete_leaves_first: bool,
283    ) -> Result<(), SparkSdkError> {
284        let mut guard = self.leaves.write();
285
286        if delete_leaves_first {
287            guard.clear();
288        }
289
290        // check if any of the leaves already exist before insertion
291        for leaf in &new_leaves {
292            let id = leaf.get_id();
293            if guard.contains_key(id.as_str()) {
294                return Err(SparkSdkError::from(
295                    WalletError::LeafAlreadyExistsInWallet {
296                        leaf_id: id.clone(),
297                    },
298                ));
299            }
300        }
301
302        // insert all leaves since we've verified none exist
303        for leaf in new_leaves {
304            let id = leaf.get_id();
305            let leaf_entry = SparkLeafEntry {
306                leaf: leaf.clone(),
307                status: SparkNodeStatus::Available,
308                unlocking_id: None,
309            };
310            guard.insert(id.clone(), leaf_entry);
311        }
312
313        Ok(())
314    }
315
316    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
317    pub(crate) fn get_available_bitcoin_value(&self, filter_cb: Option<LeafFilterFunction>) -> u64 {
318        // set the default filter for mapping btc leaves
319        let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
320
321        // combine the default filter with the custom filter
322        let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
323            let combined_filter =
324                move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
325            Box::new(combined_filter)
326        } else {
327            Box::new(move |node: &SparkLeaf| default_filter(node))
328        };
329
330        let guard = self.leaves.read();
331        let mut available_btc_sum = 0;
332        for node in guard.values() {
333            if node.status != SparkNodeStatus::Available {
334                continue;
335            }
336
337            if filter(&node.leaf) {
338                available_btc_sum += node.leaf.get_value();
339            }
340        }
341
342        drop(guard);
343        available_btc_sum
344    }
345
346    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
347    pub(crate) fn get_available_bitcoin_leaves(
348        &self,
349        filter_cb: Option<LeafFilterFunction>,
350        new_status: SparkNodeStatus,
351    ) -> Vec<SparkLeaf> {
352        // set the default filter for mapping btc leaves
353        let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
354
355        // combine the default filter with the custom filter
356        let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
357            let combined_filter =
358                move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
359            Box::new(combined_filter)
360        } else {
361            Box::new(move |node: &SparkLeaf| default_filter(node))
362        };
363
364        let unlocking_id = Uuid::now_v7().to_string();
365
366        let mut guard = self.leaves.write();
367        let mut available_btc_leaves = Vec::new();
368        for node in guard.values_mut() {
369            if node.status != SparkNodeStatus::Available {
370                continue;
371            }
372
373            node.status = new_status.clone();
374            node.unlocking_id = Some(unlocking_id.clone());
375
376            if filter(&node.leaf) {
377                available_btc_leaves.push(node.leaf.clone());
378            }
379        }
380
381        drop(guard);
382        available_btc_leaves
383    }
384
385    /// Lock all available bitcoin leaves.
386    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
387    pub(crate) fn lock_available_bitcoin_leaves(
388        &self,
389        new_status: SparkNodeStatus,
390    ) -> LockLeavesResponse {
391        // set the default filter for mapping btc leaves
392        let unlocking_id = Uuid::now_v7().to_string();
393        let default_filter: Box<dyn Fn(&SparkLeaf) -> bool> = Box::new(|ln| ln.is_bitcoin());
394
395        let mut guard = self.leaves.write();
396        let mut available_btc_leaves = Vec::new();
397        for node in guard.values_mut() {
398            if node.status != SparkNodeStatus::Available {
399                continue;
400            }
401
402            if default_filter(&node.leaf) {
403                node.status = new_status.clone();
404                node.unlocking_id = Some(unlocking_id.clone());
405                available_btc_leaves.push(node.leaf.clone());
406            }
407        }
408
409        drop(guard);
410        LockLeavesResponse {
411            unlocking_id: Some(unlocking_id),
412            leaves: available_btc_leaves,
413        }
414    }
415
416    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
417    pub(crate) fn select_leaves(
418        &self,
419        target_amount: u64,
420        new_status: SparkNodeStatus,
421    ) -> Result<LeafSelectionResponse, SparkSdkError> {
422        // check if target amount is 0
423        if target_amount == 0 {
424            return Err(SparkSdkError::from(ValidationError::InvalidInput {
425                field: "Target amount cannot be 0".to_string(),
426            }));
427        }
428
429        // prepare the filter, if token pubkey is None, filter for Bitcoin leaves
430        let filter = Box::new(|_node: &SparkLeaf| true);
431
432        // get all leaves that match the filter
433        let mut guard = self.leaves.write();
434        let mut filtered_leaves = Vec::new();
435        for node in guard.values() {
436            if filter(&node.leaf) {
437                filtered_leaves.push(node.leaf.clone());
438            }
439        }
440
441        // sort leaves in descending order by value
442        filtered_leaves.sort_by_key(|b| std::cmp::Reverse(b.get_value()));
443
444        let unlocking_id = new_status.generate_unlocking_id();
445        let mut total_value = 0;
446        let mut leaves = Vec::new();
447        for leaf in filtered_leaves {
448            if target_amount - total_value >= leaf.get_value() {
449                let leaf_id = leaf.get_id().clone();
450
451                total_value += leaf.get_value();
452                leaves.push(leaf);
453
454                // lock the leaf
455                let leaf_entry = guard.get_mut(&leaf_id).unwrap();
456                leaf_entry.status = new_status.clone();
457                leaf_entry.unlocking_id = unlocking_id.clone();
458            }
459        }
460
461        drop(guard);
462
463        Ok(LeafSelectionResponse {
464            leaves,
465            total_value,
466            unlocking_id: unlocking_id.clone(),
467            exact_amount: total_value == target_amount,
468        })
469    }
470
471    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
472    pub(crate) fn unlock_leaves(
473        &self,
474        unlocking_id: String,
475        leaf_ids: &Vec<String>,
476        delete: bool,
477    ) -> Result<(), SparkSdkError> {
478        let mut leaves = self.leaves.write();
479
480        // make sure that all leaves are in transfer lock with the same request id
481        for leaf_id in leaf_ids {
482            let leaf = leaves.get(leaf_id).ok_or_else(|| {
483                SparkSdkError::from(WalletError::LeafNotFoundInWallet {
484                    leaf_id: leaf_id.clone(),
485                })
486            })?;
487
488            if leaf.unlocking_id != Some(unlocking_id.clone()) {
489                return Err(SparkSdkError::from(WalletError::LeafNotUsingExpectedLock {
490                    expected: unlocking_id.clone(),
491                    actual: leaf.unlocking_id.clone().unwrap_or_default(),
492                }));
493            }
494        }
495
496        if delete {
497            for leaf_id in leaf_ids {
498                leaves.remove(leaf_id);
499            }
500
501            drop(leaves);
502            return Ok(());
503        }
504
505        // update status to available
506        for leaf_id in leaf_ids {
507            let leaf = leaves.get_mut(leaf_id).unwrap();
508            leaf.status = SparkNodeStatus::Available;
509            leaf.unlocking_id = None;
510        }
511
512        drop(leaves);
513        Ok(())
514    }
515}