// spark_rust/wallet/internal_handlers/implementations/leaves.rs

1use spark_protos::spark::{query_nodes_request::Source, QueryNodesRequest};
2use tonic::{async_trait, Request};
3
4use crate::{
5    error::{network::NetworkError, SparkSdkError, WalletError},
6    signer::traits::SparkSigner,
7    wallet::{
8        internal_handlers::traits::leaves::{LeafSelectionResponse, LeavesInternalHandlers},
9        leaf_manager::{SparkLeaf, SparkNodeStatus},
10    },
11    SparkSdk,
12};
13
14#[async_trait]
15impl<S: SparkSigner + Send + Sync + Clone + 'static> LeavesInternalHandlers<S> for SparkSdk<S> {
16    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
17    async fn sync_leaves(&self) -> Result<(), SparkSdkError> {
18        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
19
20        let identity_pubkey = self.get_spark_address()?;
21        let network = self.config.spark_config.network;
22        let mut request = Request::new(QueryNodesRequest {
23            include_parents: true,
24            source: Some(Source::OwnerIdentityPubkey(
25                identity_pubkey.serialize().to_vec(),
26            )),
27            network: network.marshal_proto(),
28        });
29
30        self.add_authorization_header_to_request(&mut request, None);
31
32        let response = spark_client
33            .query_nodes(request)
34            .await
35            .map_err(|status| SparkSdkError::from(NetworkError::Status(status)))?;
36        let queried_nodes = response.into_inner().nodes;
37
38        // select the available leaves only
39        let available_leaves_with_ids = queried_nodes
40            .into_iter()
41            .filter(|node| node.1.status.to_uppercase() == "AVAILABLE")
42            .collect::<Vec<_>>();
43
44        let available_leaves = available_leaves_with_ids
45            .into_iter()
46            .map(|(_, node)| SparkLeaf::Bitcoin(node))
47            .collect();
48
49        // refresh the leaf manager with the new leaves
50        self.leaf_manager.insert_leaves(available_leaves, true)?;
51
52        Ok(())
53    }
54
55    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
56    async fn prepare_leaves_for_amount(
57        &self,
58        target_amount: u64,
59    ) -> Result<LeafSelectionResponse, SparkSdkError> {
60        if self.leaf_manager.get_available_bitcoin_value(None) < target_amount {
61            return Err(SparkSdkError::from(
62                WalletError::LeafSelectionInsufficientFunds {},
63            ));
64        }
65
66        // Attempt leaf selection with the target amount. If the exact amount is found, return the response.
67        let leaf_selection_response =
68            self.leaf_manager
69                .select_leaves(target_amount, None, SparkNodeStatus::Transfer)?;
70        if leaf_selection_response.exact_amount {
71            return Ok(leaf_selection_response);
72        }
73
74        // We don't have the exact amount available, so we'll need to swap leaves.
75        let leaf_ids = leaf_selection_response
76            .leaves
77            .iter()
78            .map(|l| l.get_id().clone())
79            .collect();
80
81        // Unlock the leaves before attempting to swap them.
82        self.leaf_manager.unlock_leaves(
83            leaf_selection_response.unlocking_id.unwrap(),
84            &leaf_ids,
85            false,
86        )?;
87
88        self.request_leaves_swap(target_amount).await?;
89
90        // Retry leaf selection now since we should have the exact amount available.
91        self.leaf_manager
92            .select_leaves(target_amount, None, SparkNodeStatus::Transfer)
93    }
94
95    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
96    async fn optimize_leaves(&self) -> Result<(), SparkSdkError> {
97        // TODO leaves: add status to track if there is any optimization.
98        //              we already keep status of the leaves, but this seems to be for all the leaves.
99        //              this requires further thinking as a part of the next PR.
100
101        // TODO leaves: or if optimization is in progress (if we are optimizing for all leaves)
102        if !self.are_leaves_inefficient()? {
103            return Ok(());
104        }
105
106        // TODO leaves: an appropriate status must be set here, and ideally passed into `are_leaves_inefficient` with the unlocking_id
107        let leaves = self
108            .leaf_manager
109            .get_available_bitcoin_leaves(None, SparkNodeStatus::Available);
110
111        let target_amount = leaves
112            .iter()
113            .map(|leaf| leaf.get_tree_node().unwrap().value)
114            .sum();
115
116        // TODO leaves: again, we have many unlocked gaps between operations in `are_leaves_efficient` and until below. We must think about this.
117        if leaves.len() > 1 {
118            self.request_leaves_swap(target_amount).await?;
119        }
120
121        self.sync_leaves().await?;
122
123        Ok(())
124    }
125
126    #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
127    fn are_leaves_inefficient(&self) -> Result<bool, SparkSdkError> {
128        // Get all available leaves
129        let leaves = self
130            .leaf_manager
131            .get_available_bitcoin_leaves(None, SparkNodeStatus::Available); // TODO: we must use a function that does *not* request a status
132
133        if leaves.len() <= 1 {
134            return Ok(false);
135        }
136
137        // Calculate total amount across all Bitcoin leaves
138        let total_btc_amount = leaves
139            .iter()
140            .fold(0, |acc, leaf| acc + leaf.get_tree_node().unwrap().value); // safe unwrap
141
142        let next_lower_power_of_two = total_btc_amount.ilog2();
143
144        let mut remaining_amount = total_btc_amount;
145        let mut optimal_leaves_length = 0;
146
147        // Iterate from highest power to lowest (e.g., from 2^63 down to 2^0)
148        for i in (0..=next_lower_power_of_two).rev() {
149            let denomination = 1u64.checked_shl(i).unwrap_or(0);
150            if denomination == 0 {
151                continue; // Skip if shift would overflow
152            }
153
154            while remaining_amount >= denomination {
155                remaining_amount -= denomination;
156                optimal_leaves_length += 1;
157            }
158        }
159
160        let is_inefficient = leaves.len() > optimal_leaves_length * 5;
161
162        #[cfg(feature = "telemetry")]
163        tracing::debug!("Leaves are inefficient: {}", is_inefficient);
164
165        Ok(is_inefficient)
166    }
167}
168
/// Returns the largest power of two that is less than or equal to `total_amount`.
///
/// Returns `0` when `total_amount` is `0`, since no power of two fits.
// `dead_code` is allowed because, outside of `cfg(test)`, nothing in this file calls it.
#[allow(dead_code)]
fn next_lower_power_of_two(total_amount: u64) -> u64 {
    // `checked_ilog2` yields `None` only for zero. Otherwise the exponent is at most 63,
    // so the shift below can never overflow.
    match total_amount.checked_ilog2() {
        None => 0,
        Some(exponent) => 1u64 << exponent,
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `next_lower_power_of_two` across zero, exact powers of two,
    /// values between powers, and the top of the `u64` range.
    #[test]
    fn test_next_lower_power_of_two() {
        let cases: &[(u64, u64)] = &[
            // Zero has no lower power of two.
            (0, 0),
            // Exact powers of two map to themselves.
            (1, 1),
            (2, 2),
            (4, 4),
            (8, 8),
            (16, 16),
            (256, 256),
            (65536, 65536),
            // Values between powers round down.
            (3, 2),
            (5, 4),
            (7, 4),
            (9, 8),
            (15, 8),
            (17, 16),
            (31, 16),
            (100, 64),
            (1000, 512),
            // Upper end of the u64 range.
            (u64::MAX, 1u64 << 63),
            ((1u64 << 63) + 1, 1u64 << 63),
            ((1u64 << 62) + 1, 1u64 << 62),
        ];

        for &(input, expected) in cases {
            assert_eq!(next_lower_power_of_two(input), expected);
        }
    }
}