// spark_rust/wallet/leaf_manager/mod.rs
use std::collections::HashSet;
use std::sync::Arc;

use hashbrown::HashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use spark_protos::spark::TreeNode;
use uuid::Uuid;

use super::internal_handlers::traits::leaves::LeafSelectionResponse;
use crate::error::{validation::ValidationError, wallet::WalletError, SparkSdkError};
10
11struct SparkLeafEntry {
12 leaf: SparkLeaf,
13 status: SparkNodeStatus,
14 unlocking_id: Option<String>,
15}
16
17type LeafMap = Arc<RwLock<HashMap<String, SparkLeafEntry>>>;
18
19pub(crate) struct LeafManager {
20 leaves: LeafMap,
22}
23
/// A token-denominated leaf tracked by the wallet.
#[derive(Debug, Clone)]
pub(crate) struct TokenLeaf {
    /// Leaf identifier; `SparkLeaf::get_id` returns it and the manager keys
    /// its map by it.
    pub(crate) id: String,

    /// Identifier of the tree this leaf belongs to (not read in this module).
    pub(crate) _tree_id: String,

    /// Amount held by this leaf, as reported by `SparkLeaf::get_value`.
    pub(crate) value: u64,

    /// Public key identifying the token; leaf selection compares it against
    /// the requested token key.
    pub(crate) token_public_key: Vec<u8>,

    /// Revocation public key (not read in this module).
    pub(crate) _revocation_public_key: Vec<u8>,

    /// Token transaction hash (not read in this module).
    /// NOTE(review): presumably the transaction that created this leaf — confirm.
    pub(crate) _token_transaction_hash: Vec<u8>,
}
44
45#[derive(Debug, Clone)]
46pub(crate) enum SparkLeaf {
47 Bitcoin(TreeNode),
48 #[allow(dead_code)]
49 Token(TokenLeaf),
50}
51
52impl SparkLeaf {
53 pub(crate) fn get_id(&self) -> &String {
54 match self {
55 SparkLeaf::Bitcoin(leaf) => &leaf.id,
56 SparkLeaf::Token(leaf) => &leaf.id,
57 }
58 }
59
60 pub(crate) fn get_value(&self) -> u64 {
61 match self {
62 SparkLeaf::Bitcoin(leaf) => leaf.value,
63 SparkLeaf::Token(leaf) => leaf.value,
64 }
65 }
66
67 pub(crate) fn is_bitcoin(&self) -> bool {
68 matches!(self, SparkLeaf::Bitcoin(_))
69 }
70
71 pub(crate) fn get_token_pubkey(&self) -> Option<Vec<u8>> {
72 match self {
73 SparkLeaf::Bitcoin(_) => None,
74 SparkLeaf::Token(leaf) => Some(leaf.token_public_key.clone()),
75 }
76 }
77
78 pub(crate) fn get_tree_node(&self) -> Result<TreeNode, SparkSdkError> {
79 match self {
80 SparkLeaf::Bitcoin(leaf) => Ok(leaf.clone()),
81 SparkLeaf::Token(_) => Err(SparkSdkError::from(WalletError::LeafIsNotBitcoin {
82 leaf_id: self.get_id().clone(),
83 })),
84 }
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
90pub enum SparkNodeStatus {
91 Available,
93
94 AggregatableParent,
96
97 Transfer,
99
100 Split,
102
103 Swap,
105
106 FeeQuery,
108
109 CooperativeExit,
111}
112
113impl SparkNodeStatus {
114 fn generate_unlocking_id(&self) -> Option<String> {
115 match self {
116 SparkNodeStatus::Transfer | SparkNodeStatus::Split | SparkNodeStatus::Swap => {
117 Some(Uuid::now_v7().to_string())
118 }
119 _ => None,
120 }
121 }
122}
123
124#[derive(Debug)]
125pub(crate) struct LockLeavesResponse {
126 pub(crate) unlocking_id: Option<String>,
127 pub(crate) leaves: Vec<SparkLeaf>,
128}
129
130pub(crate) type LeafFilterFunction = fn(&SparkLeaf) -> bool;
131
132impl LeafManager {
133 pub(crate) fn new() -> Self {
134 Self {
135 leaves: Arc::new(parking_lot::RwLock::new(HashMap::new())),
136 }
137 }
138
139 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
140 pub(crate) fn filter_nodes(&self, cb: Option<LeafFilterFunction>) -> Vec<SparkLeaf> {
141 let mut nodes = Vec::new();
142 let guard = self.leaves.read();
143 for node in guard.values() {
144 if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
145 nodes.push(node.leaf.clone());
146 }
147 }
148
149 drop(guard);
150 nodes
151 }
152
153 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
154 pub(crate) fn filter_nodes_by_ids(&self, leaf_ids: &Vec<String>) -> Vec<SparkLeaf> {
155 let guard = self.leaves.read();
156 let mut nodes = Vec::new();
157 for leaf_id in leaf_ids {
158 if let Some(entry) = guard.get(leaf_id) {
159 nodes.push(entry.leaf.clone());
160 }
161 }
162
163 drop(guard);
164 nodes
165 }
166
167 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
168 pub(crate) fn filter_and_lock_nodes(
169 &self,
170 cb: Option<LeafFilterFunction>,
171 new_status: Option<SparkNodeStatus>,
172 ) -> Vec<SparkLeaf> {
173 let mut nodes = Vec::new();
175 let unlocking_id = match &new_status {
176 Some(status) => status.generate_unlocking_id(),
177 None => None,
178 };
179
180 let mut guard = self.leaves.write();
181 for node in guard.values_mut() {
182 if cb.as_ref().is_some_and(|f| f(&node.leaf)) {
183 if new_status.is_some() {
184 node.unlocking_id = unlocking_id.clone();
185 node.status = new_status.clone().unwrap();
186 }
187 nodes.push(node.leaf.clone());
188 }
189 }
190
191 drop(guard);
192 nodes
193 }
194
195 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
196 pub(crate) fn lock_leaf_ids(
197 &self,
198 leaf_ids: &Vec<String>,
199 new_status: SparkNodeStatus,
200 ) -> Result<LeafSelectionResponse, SparkSdkError> {
201 let mut leaves = Vec::new();
203 let mut guard = self.leaves.write();
204 for leaf_id in leaf_ids {
205 let get_leaf = guard.get(leaf_id);
206 if get_leaf.is_none() {
207 drop(guard);
208 return Err(SparkSdkError::from(WalletError::LeafNotFoundInWallet {
209 leaf_id: leaf_id.clone(),
210 }));
211 }
212 let leaf = get_leaf.unwrap();
213 if leaf.status != SparkNodeStatus::Available {
214 drop(guard);
215 return Err(SparkSdkError::from(WalletError::LeafNotAvailableForUse {
216 leaf_id: leaf_id.clone(),
217 }));
218 }
219 leaves.push(leaf.leaf.clone());
220 }
221
222 let unlocking_id = new_status.generate_unlocking_id();
223 for leaf_id in leaf_ids {
224 let leaf = guard.get_mut(leaf_id).unwrap();
225 leaf.status = new_status.clone();
226 leaf.unlocking_id = unlocking_id.clone();
227 }
228
229 let total_value = leaves.iter().map(|l| l.get_value()).sum();
231
232 Ok(LeafSelectionResponse {
233 leaves,
234 total_value,
235 unlocking_id,
236 exact_amount: true,
237 })
238 }
239
240 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
241 pub(crate) fn remove_all_leaves(&self) -> Result<(), SparkSdkError> {
242 let mut guard = self.leaves.write();
243 guard.clear();
244 Ok(())
245 }
246
247 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
248 pub(crate) fn insert_leaves(
249 &self,
250 new_leaves: Vec<SparkLeaf>,
251 delete_leaves_first: bool,
252 ) -> Result<(), SparkSdkError> {
253 let mut guard = self.leaves.write();
254
255 if delete_leaves_first {
256 guard.clear();
257 }
258
259 for leaf in &new_leaves {
261 let id = leaf.get_id();
262 if guard.contains_key(id.as_str()) {
263 return Err(SparkSdkError::from(
264 WalletError::LeafAlreadyExistsInWallet {
265 leaf_id: id.clone(),
266 },
267 ));
268 }
269 }
270
271 for leaf in new_leaves {
273 let id = leaf.get_id();
274 let leaf_entry = SparkLeafEntry {
275 leaf: leaf.clone(),
276 status: SparkNodeStatus::Available,
277 unlocking_id: None,
278 };
279 guard.insert(id.clone(), leaf_entry);
280 }
281
282 Ok(())
283 }
284
285 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
286 pub(crate) fn get_available_bitcoin_value(&self, filter_cb: Option<LeafFilterFunction>) -> u64 {
287 let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
289
290 let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
292 let combined_filter =
293 move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
294 Box::new(combined_filter)
295 } else {
296 Box::new(move |node: &SparkLeaf| default_filter(node))
297 };
298
299 let guard = self.leaves.read();
300 let mut available_btc_sum = 0;
301 for node in guard.values() {
302 if node.status != SparkNodeStatus::Available {
303 continue;
304 }
305
306 if filter(&node.leaf) {
307 available_btc_sum += node.leaf.get_value();
308 }
309 }
310
311 drop(guard);
312 available_btc_sum
313 }
314
315 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
316 pub(crate) fn get_available_bitcoin_leaves(
317 &self,
318 filter_cb: Option<LeafFilterFunction>,
319 new_status: SparkNodeStatus,
320 ) -> Vec<SparkLeaf> {
321 let default_filter = |ln: &SparkLeaf| ln.is_bitcoin();
323
324 let filter: Box<dyn Fn(&SparkLeaf) -> bool> = if let Some(custom_filter) = filter_cb {
326 let combined_filter =
327 move |node: &SparkLeaf| default_filter(node) && custom_filter(node);
328 Box::new(combined_filter)
329 } else {
330 Box::new(move |node: &SparkLeaf| default_filter(node))
331 };
332
333 let unlocking_id = Uuid::now_v7().to_string();
334
335 let mut guard = self.leaves.write();
336 let mut available_btc_leaves = Vec::new();
337 for node in guard.values_mut() {
338 if node.status != SparkNodeStatus::Available {
339 continue;
340 }
341
342 node.status = new_status.clone();
343 node.unlocking_id = Some(unlocking_id.clone());
344
345 if filter(&node.leaf) {
346 available_btc_leaves.push(node.leaf.clone());
347 }
348 }
349
350 drop(guard);
351 available_btc_leaves
352 }
353
354 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
356 pub(crate) fn lock_available_bitcoin_leaves(
357 &self,
358 new_status: SparkNodeStatus,
359 ) -> LockLeavesResponse {
360 let unlocking_id = Uuid::now_v7().to_string();
362 let default_filter: Box<dyn Fn(&SparkLeaf) -> bool> = Box::new(|ln| ln.is_bitcoin());
363
364 let mut guard = self.leaves.write();
365 let mut available_btc_leaves = Vec::new();
366 for node in guard.values_mut() {
367 if node.status != SparkNodeStatus::Available {
368 continue;
369 }
370
371 if default_filter(&node.leaf) {
372 node.status = new_status.clone();
373 node.unlocking_id = Some(unlocking_id.clone());
374 available_btc_leaves.push(node.leaf.clone());
375 }
376 }
377
378 drop(guard);
379 LockLeavesResponse {
380 unlocking_id: Some(unlocking_id),
381 leaves: available_btc_leaves,
382 }
383 }
384
385 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
386 pub(crate) fn select_leaves(
387 &self,
388 target_amount: u64,
389 token_pubkey: Option<Vec<u8>>,
390 new_status: SparkNodeStatus,
391 ) -> Result<LeafSelectionResponse, SparkSdkError> {
392 if target_amount == 0 {
394 return Err(SparkSdkError::from(ValidationError::InvalidInput {
395 field: "Target amount cannot be 0".to_string(),
396 }));
397 }
398
399 let filter = Box::new(|node: &SparkLeaf| node.get_token_pubkey() == token_pubkey);
401
402 let mut guard = self.leaves.write();
404 let mut filtered_leaves = Vec::new();
405 for node in guard.values() {
406 if filter(&node.leaf) {
407 filtered_leaves.push(node.leaf.clone());
408 }
409 }
410
411 filtered_leaves.sort_by_key(|b| std::cmp::Reverse(b.get_value()));
413
414 let unlocking_id = new_status.generate_unlocking_id();
415 let mut total_value = 0;
416 let mut leaves = Vec::new();
417 let mut target_reached = false;
418
419 for leaf in filtered_leaves {
420 if target_reached {
422 break;
423 }
424
425 let leaf_id = leaf.get_id().clone();
426
427 total_value += leaf.get_value();
429 leaves.push(leaf);
430
431 let leaf_entry = guard.get_mut(&leaf_id).unwrap();
433 leaf_entry.status = new_status.clone();
434 leaf_entry.unlocking_id = unlocking_id.clone();
435
436 if total_value >= target_amount {
439 target_reached = true;
440 }
441 }
442
443 drop(guard);
444
445 Ok(LeafSelectionResponse {
446 leaves,
447 total_value,
448 unlocking_id: unlocking_id.clone(),
449 exact_amount: total_value == target_amount,
450 })
451 }
452
453 #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
454 pub(crate) fn unlock_leaves(
455 &self,
456 unlocking_id: String,
457 leaf_ids: &Vec<String>,
458 delete: bool,
459 ) -> Result<(), SparkSdkError> {
460 let mut leaves = self.leaves.write();
461
462 for leaf_id in leaf_ids {
464 let leaf = leaves.get(leaf_id).ok_or_else(|| {
465 SparkSdkError::from(WalletError::LeafNotFoundInWallet {
466 leaf_id: leaf_id.clone(),
467 })
468 })?;
469
470 if leaf.unlocking_id != Some(unlocking_id.clone()) {
471 return Err(SparkSdkError::from(WalletError::LeafNotUsingExpectedLock {
472 expected: unlocking_id.clone(),
473 actual: leaf.unlocking_id.clone().unwrap_or_default(),
474 }));
475 }
476 }
477
478 if delete {
479 for leaf_id in leaf_ids {
480 leaves.remove(leaf_id);
481 }
482
483 drop(leaves);
484 return Ok(());
485 }
486
487 for leaf_id in leaf_ids {
489 let leaf = leaves.get_mut(leaf_id).unwrap();
490 leaf.status = SparkNodeStatus::Available;
491 leaf.unlocking_id = None;
492 }
493
494 drop(leaves);
495 Ok(())
496 }
497}