spark_rust/wallet/leaf_manager/
mod.rs1use crate::error::{validation::ValidationError, wallet::WalletError, SparkSdkError};
2use hashbrown::HashMap;
3use parking_lot::RwLock;
4use serde::{Deserialize, Serialize};
5use spark_protos::spark::TreeNode;
6use std::sync::Arc;
7use uuid::Uuid;
8
9use super::internal_handlers::traits::leaves::LeafSelectionResponse;
10
11struct SparkLeafEntry {
12 leaf: SparkLeaf,
13 status: SparkNodeStatus,
14 unlocking_id: Option<String>,
15}
16
17type LeafMap = Arc<RwLock<HashMap<String, SparkLeafEntry>>>;
18
19pub(crate) struct LeafManager {
20 leaves: LeafMap,
22}
23
24#[derive(Debug, Clone)]
25pub(crate) enum SparkLeaf {
26 Bitcoin(TreeNode),
27}
28
29impl SparkLeaf {
30 pub(crate) fn get_id(&self) -> &String {
31 match self {
32 SparkLeaf::Bitcoin(leaf) => &leaf.id,
33 }
34 }
35
36 pub(crate) fn get_value(&self) -> u64 {
37 match self {
38 SparkLeaf::Bitcoin(leaf) => leaf.value,
39 }
40 }
41
42 pub(crate) fn is_bitcoin(&self) -> bool {
43 matches!(self, SparkLeaf::Bitcoin(_))
44 }
45
46 pub(crate) fn get_tree_node(&self) -> Result<TreeNode, SparkSdkError> {
47 match self {
48 SparkLeaf::Bitcoin(leaf) => Ok(leaf.clone()),
49 }
50 }
51}
52
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55pub enum SparkNodeStatus {
56 Available,
58
59 AggregatableParent,
61
62 Transfer,
64
65 Split,
67
68 Swap,
70
71 FeeQuery,
73
74 CooperativeExit,
76
77 RefreshTimelock,
79}
80
81impl SparkNodeStatus {
82 fn generate_unlocking_id(&self) -> Option<String> {
83 match self {
84 SparkNodeStatus::Transfer | SparkNodeStatus::Split | SparkNodeStatus::Swap => {
85 Some(Uuid::now_v7().to_string())
86 }
87 _ => None,
88 }
89 }
90}
91
92type LeafFilterCallback = Box<dyn Fn(&SparkLeaf) -> bool>;
93
94#[derive(Debug)]
95pub(crate) struct LockLeavesResponse {
96 pub(crate) unlocking_id: Option<String>,
97 pub(crate) leaves: Vec<SparkLeaf>,
98}
99
100pub(crate) type LeafFilterFunction = fn(&SparkLeaf) -> bool;
101
102impl LeafManager {
103 pub(crate) fn new() -> Self {
104 Self {
105 leaves: Arc::new(parking_lot::RwLock::new(HashMap::new())),
106 }
107 }
108
109 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
110 pub(crate) fn filter_nodes(&self, cb: Option<LeafFilterFunction>) -> Vec<SparkLeaf> {
111 let mut nodes = Vec::new();
112 let guard = self.leaves.read();
113 for node in guard.values() {
114 if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
115 nodes.push(node.leaf.clone());
116 }
117 }
118
119 drop(guard);
120 nodes
121 }
122
123 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
124 pub(crate) fn filter_nodes_by_ids(&self, leaf_ids: &Vec<String>) -> Vec<SparkLeaf> {
125 let guard = self.leaves.read();
126 let mut nodes = Vec::new();
127 for leaf_id in leaf_ids {
128 if let Some(entry) = guard.get(leaf_id) {
129 nodes.push(entry.leaf.clone());
130 }
131 }
132
133 drop(guard);
134 nodes
135 }
136
137 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
138 pub(crate) fn filter_and_lock_nodes(
139 &self,
140 cb: Option<LeafFilterFunction>,
141 new_status: Option<SparkNodeStatus>,
142 ) -> Vec<SparkLeaf> {
143 let mut nodes = Vec::new();
145 let unlocking_id = match &new_status {
146 Some(status) => status.generate_unlocking_id(),
147 None => None,
148 };
149
150 let mut guard = self.leaves.write();
151 for node in guard.values_mut() {
152 if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
153 if new_status.is_some() {
154 node.unlocking_id = unlocking_id.clone();
155 node.status = new_status.clone().unwrap();
156 }
157 nodes.push(node.leaf.clone());
158 }
159 }
160
161 drop(guard);
162 nodes
163 }
164
165 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
166 pub(crate) fn delete_and_insert_leaves_atomically(
167 &self,
168 delete_ids: &Vec<String>,
169 insert_leaves: &Vec<SparkLeaf>,
170 ) -> Result<(), SparkSdkError> {
171 let mut guard = self.leaves.write();
172 for id in delete_ids {
173 let removed = guard.remove(id);
174 if removed.is_none() {
175 #[cfg(feature = "telemetry")]
176 tracing::warn!(
177 "Leaf not found in wallet when deleting and inserting leaves atomically: {}",
178 id
179 );
180 }
181 }
182
183 for leaf in insert_leaves {
184 guard.insert(
185 leaf.get_id().clone(),
186 SparkLeafEntry {
187 leaf: leaf.clone(),
188 status: SparkNodeStatus::Available,
189 unlocking_id: None,
190 },
191 );
192 }
193
194 Ok(())
195 }
196
197 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
198 pub(crate) fn get_all_available_leaves(
199 &self,
200 lock_callback: Option<LeafFilterCallback>,
201 new_status: Option<SparkNodeStatus>,
202 ) -> (Vec<SparkLeaf>, Option<String>) {
203 let mut nodes = Vec::new();
204 let mut guard = self.leaves.write();
205 let unlocking_id = match &new_status {
206 Some(status) => status.generate_unlocking_id(),
207 None => None,
208 };
209
210 for node in guard.values_mut() {
211 if node.status == SparkNodeStatus::Available
212 && lock_callback.as_ref().is_some_and(|f| f(&node.leaf))
213 {
214 if new_status.is_some() {
215 node.status = new_status.clone().unwrap().clone();
216 node.unlocking_id = unlocking_id.clone();
217 }
218 nodes.push(node.leaf.clone());
219 }
220 }
221
222 drop(guard);
223 (nodes, unlocking_id)
224 }
225
226 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
227 pub(crate) fn lock_leaf_ids(
228 &self,
229 leaf_ids: &Vec<String>,
230 new_status: SparkNodeStatus,
231 ) -> Result<LeafSelectionResponse, SparkSdkError> {
232 let mut leaves = Vec::new();
234 let mut guard = self.leaves.write();
235 for leaf_id in leaf_ids {
236 let get_leaf = guard.get(leaf_id);
237 if get_leaf.is_none() {
238 drop(guard);
239 return Err(SparkSdkError::from(WalletError::LeafNotFoundInWallet {
240 leaf_id: leaf_id.clone(),
241 }));
242 }
243 let leaf = get_leaf.unwrap();
244 if leaf.status != SparkNodeStatus::Available {
245 drop(guard);
246 return Err(SparkSdkError::from(WalletError::LeafNotAvailableForUse {
247 leaf_id: leaf_id.clone(),
248 }));
249 }
250 leaves.push(leaf.leaf.clone());
251 }
252
253 let unlocking_id = new_status.generate_unlocking_id();
254 for leaf_id in leaf_ids {
255 let leaf = guard.get_mut(leaf_id).unwrap();
256 leaf.status = new_status.clone();
257 leaf.unlocking_id = unlocking_id.clone();
258 }
259
260 let total_value = leaves.iter().map(|l| l.get_value()).sum();
262
263 Ok(LeafSelectionResponse {
264 leaves,
265 total_value,
266 unlocking_id,
267 exact_amount: true,
268 })
269 }
270
271 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
272 pub(crate) fn remove_all_leaves(&self) -> Result<(), SparkSdkError> {
273 let mut guard = self.leaves.write();
274 guard.clear();
275 Ok(())
276 }
277
278 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
279 pub(crate) fn insert_leaves(
280 &self,
281 new_leaves: Vec<SparkLeaf>,
282 delete_leaves_first: bool,
283 ) -> Result<(), SparkSdkError> {
284 let mut guard = self.leaves.write();
285
286 if delete_leaves_first {
287 guard.clear();
288 }
289
290 for leaf in &new_leaves {
292 let id = leaf.get_id();
293 if guard.contains_key(id.as_str()) {
294 return Err(SparkSdkError::from(
295 WalletError::LeafAlreadyExistsInWallet {
296 leaf_id: id.clone(),
297 },
298 ));
299 }
300 }
301
302 for leaf in new_leaves {
304 let id = leaf.get_id();
305 let leaf_entry = SparkLeafEntry {
306 leaf: leaf.clone(),
307 status: SparkNodeStatus::Available,
308 unlocking_id: None,
309 };
310 guard.insert(id.clone(), leaf_entry);
311 }
312
313 Ok(())
314 }
315
316 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
317 pub(crate) fn get_available_bitcoin_value(&self, filter_cb: Option<LeafFilterFunction>) -> u64 {
318 let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
320
321 let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
323 let combined_filter =
324 move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
325 Box::new(combined_filter)
326 } else {
327 Box::new(move |node: &SparkLeaf| default_filter(node))
328 };
329
330 let guard = self.leaves.read();
331 let mut available_btc_sum = 0;
332 for node in guard.values() {
333 if node.status != SparkNodeStatus::Available {
334 continue;
335 }
336
337 if filter(&node.leaf) {
338 available_btc_sum += node.leaf.get_value();
339 }
340 }
341
342 drop(guard);
343 available_btc_sum
344 }
345
346 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
347 pub(crate) fn get_available_bitcoin_leaves(
348 &self,
349 filter_cb: Option<LeafFilterFunction>,
350 new_status: SparkNodeStatus,
351 ) -> Vec<SparkLeaf> {
352 let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
354
355 let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
357 let combined_filter =
358 move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
359 Box::new(combined_filter)
360 } else {
361 Box::new(move |node: &SparkLeaf| default_filter(node))
362 };
363
364 let unlocking_id = Uuid::now_v7().to_string();
365
366 let mut guard = self.leaves.write();
367 let mut available_btc_leaves = Vec::new();
368 for node in guard.values_mut() {
369 if node.status != SparkNodeStatus::Available {
370 continue;
371 }
372
373 node.status = new_status.clone();
374 node.unlocking_id = Some(unlocking_id.clone());
375
376 if filter(&node.leaf) {
377 available_btc_leaves.push(node.leaf.clone());
378 }
379 }
380
381 drop(guard);
382 available_btc_leaves
383 }
384
385 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
387 pub(crate) fn lock_available_bitcoin_leaves(
388 &self,
389 new_status: SparkNodeStatus,
390 ) -> LockLeavesResponse {
391 let unlocking_id = Uuid::now_v7().to_string();
393 let default_filter: Box<dyn Fn(&SparkLeaf) -> bool> = Box::new(|ln| ln.is_bitcoin());
394
395 let mut guard = self.leaves.write();
396 let mut available_btc_leaves = Vec::new();
397 for node in guard.values_mut() {
398 if node.status != SparkNodeStatus::Available {
399 continue;
400 }
401
402 if default_filter(&node.leaf) {
403 node.status = new_status.clone();
404 node.unlocking_id = Some(unlocking_id.clone());
405 available_btc_leaves.push(node.leaf.clone());
406 }
407 }
408
409 drop(guard);
410 LockLeavesResponse {
411 unlocking_id: Some(unlocking_id),
412 leaves: available_btc_leaves,
413 }
414 }
415
416 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
417 pub(crate) fn select_leaves(
418 &self,
419 target_amount: u64,
420 new_status: SparkNodeStatus,
421 ) -> Result<LeafSelectionResponse, SparkSdkError> {
422 if target_amount == 0 {
424 return Err(SparkSdkError::from(ValidationError::InvalidInput {
425 field: "Target amount cannot be 0".to_string(),
426 }));
427 }
428
429 let filter = Box::new(|_node: &SparkLeaf| true);
431
432 let mut guard = self.leaves.write();
434 let mut filtered_leaves = Vec::new();
435 for node in guard.values() {
436 if filter(&node.leaf) {
437 filtered_leaves.push(node.leaf.clone());
438 }
439 }
440
441 filtered_leaves.sort_by_key(|b| std::cmp::Reverse(b.get_value()));
443
444 let unlocking_id = new_status.generate_unlocking_id();
445 let mut total_value = 0;
446 let mut leaves = Vec::new();
447 for leaf in filtered_leaves {
448 if target_amount - total_value >= leaf.get_value() {
449 let leaf_id = leaf.get_id().clone();
450
451 total_value += leaf.get_value();
452 leaves.push(leaf);
453
454 let leaf_entry = guard.get_mut(&leaf_id).unwrap();
456 leaf_entry.status = new_status.clone();
457 leaf_entry.unlocking_id = unlocking_id.clone();
458 }
459 }
460
461 drop(guard);
462
463 Ok(LeafSelectionResponse {
464 leaves,
465 total_value,
466 unlocking_id: unlocking_id.clone(),
467 exact_amount: total_value == target_amount,
468 })
469 }
470
471 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
472 pub(crate) fn unlock_leaves(
473 &self,
474 unlocking_id: String,
475 leaf_ids: &Vec<String>,
476 delete: bool,
477 ) -> Result<(), SparkSdkError> {
478 let mut leaves = self.leaves.write();
479
480 for leaf_id in leaf_ids {
482 let leaf = leaves.get(leaf_id).ok_or_else(|| {
483 SparkSdkError::from(WalletError::LeafNotFoundInWallet {
484 leaf_id: leaf_id.clone(),
485 })
486 })?;
487
488 if leaf.unlocking_id != Some(unlocking_id.clone()) {
489 return Err(SparkSdkError::from(WalletError::LeafNotUsingExpectedLock {
490 expected: unlocking_id.clone(),
491 actual: leaf.unlocking_id.clone().unwrap_or_default(),
492 }));
493 }
494 }
495
496 if delete {
497 for leaf_id in leaf_ids {
498 leaves.remove(leaf_id);
499 }
500
501 drop(leaves);
502 return Ok(());
503 }
504
505 for leaf_id in leaf_ids {
507 let leaf = leaves.get_mut(leaf_id).unwrap();
508 leaf.status = SparkNodeStatus::Available;
509 leaf.unlocking_id = None;
510 }
511
512 drop(leaves);
513 Ok(())
514 }
515}