// simln_lib/sim_node.rs

1use crate::{
2    LightningError, LightningNode, NodeInfo, PaymentOutcome, PaymentResult, SimulationError,
3};
4use async_trait::async_trait;
5use bitcoin::constants::ChainHash;
6use bitcoin::hashes::{sha256::Hash as Sha256, Hash};
7use bitcoin::secp256k1::PublicKey;
8use bitcoin::{Network, ScriptBuf, TxOut};
9use lightning::ln::chan_utils::make_funding_redeemscript;
10use std::collections::{hash_map::Entry, HashMap};
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use lightning::ln::features::{ChannelFeatures, NodeFeatures};
15use lightning::ln::msgs::{
16    LightningError as LdkError, UnsignedChannelAnnouncement, UnsignedChannelUpdate,
17};
18use lightning::ln::{PaymentHash, PaymentPreimage};
19use lightning::routing::gossip::{NetworkGraph, NodeId};
20use lightning::routing::router::{find_route, Path, PaymentParameters, Route, RouteParameters};
21use lightning::routing::scoring::ProbabilisticScorer;
22use lightning::routing::utxo::{UtxoLookup, UtxoResult};
23use lightning::util::logger::{Level, Logger, Record};
24use thiserror::Error;
25use tokio::select;
26use tokio::sync::oneshot::{channel, Receiver, Sender};
27use tokio::sync::Mutex;
28use tokio::task::JoinSet;
29use triggered::{Listener, Trigger};
30
31use crate::ShortChannelID;
32
33/// ForwardingError represents the various errors that we can run into when forwarding payments in a simulated network.
34/// Since we're not using real lightning nodes, these errors are not obfuscated and can be propagated to the sending
35/// node and used for analysis.
36#[derive(Debug, Error)]
37pub enum ForwardingError {
38    /// Zero amount htlcs are invalid in the protocol.
39    #[error("ZeroAmountHtlc")]
40    ZeroAmountHtlc,
41    /// The outgoing channel id was not found in the network graph.
42    #[error("ChannelNotFound: {0}")]
43    ChannelNotFound(ShortChannelID),
44    /// The node pubkey provided was not associated with the channel in the network graph.
45    #[error("NodeNotFound: {0:?}")]
46    NodeNotFound(PublicKey),
47    /// The channel has already forwarded an HTLC with the payment hash provided.
48    /// TODO: remove if MPP support is added.
49    #[error("PaymentHashExists: {0:?}")]
50    PaymentHashExists(PaymentHash),
51    /// An htlc with the payment hash provided could not be found to resolve.
52    #[error("PaymentHashNotFound: {0:?}")]
53    PaymentHashNotFound(PaymentHash),
54    /// The forwarding node did not have sufficient outgoing balance to forward the htlc (htlc amount / balance).
55    #[error("InsufficientBalance: amount: {0} > balance: {1}")]
56    InsufficientBalance(u64, u64),
57    /// The htlc forwarded is less than the channel's advertised minimum htlc amount (htlc amount / minimum).
58    #[error("LessThanMinimum: amount: {0} < minimum: {1}")]
59    LessThanMinimum(u64, u64),
60    /// The htlc forwarded is more than the channel's advertised maximum htlc amount (htlc amount / maximum).
61    #[error("MoreThanMaximum: amount: {0} > maximum: {1}")]
62    MoreThanMaximum(u64, u64),
63    /// The channel has reached its maximum allowable number of htlcs in flight (total in flight / maximim).
64    #[error("ExceedsInFlightCount: total in flight: {0} > maximum count: {1}")]
65    ExceedsInFlightCount(u64, u64),
66    /// The forwarded htlc's amount would push the channel over its maximum allowable in flight total
67    /// (total in flight / maximum).
68    #[error("ExceedsInFlightTotal: total in flight amount: {0} > maximum amount: {0}")]
69    ExceedsInFlightTotal(u64, u64),
70    /// The forwarded htlc's cltv expiry exceeds the maximum value used to express block heights in Bitcoin.
71    #[error("ExpiryInSeconds: cltv expressed in seconds: {0}")]
72    ExpiryInSeconds(u32, u32),
73    /// The forwarded htlc has insufficient cltv delta for the channel's minimum delta (cltv delta / minimum).
74    #[error("InsufficientCltvDelta: cltv delta: {0} < required: {1}")]
75    InsufficientCltvDelta(u32, u32),
76    /// The forwarded htlc has insufficient fee for the channel's policy (fee / expected fee / base fee / prop fee).
77    #[error("InsufficientFee: offered fee: {0} (base: {1}, prop: {2}) < expected: {3}")]
78    InsufficientFee(u64, u64, u64, u64),
79    /// The fee policy for a htlc amount would overflow with the given fee policy (htlc amount / base fee / prop fee).
80    #[error("FeeOverflow: htlc amount: {0} (base: {1}, prop: {2})")]
81    FeeOverflow(u64, u64, u64),
82    /// Sanity check on channel balances failed (node balances / channel capacity).
83    #[error("SanityCheckFailed: node balance: {0} != capacity: {1}")]
84    SanityCheckFailed(u64, u64),
85}
86
87impl ForwardingError {
88    /// Returns a boolean indicating whether failure to forward a htlc is a critical error that warrants shutdown.
89    fn is_critical(&self) -> bool {
90        matches!(
91            self,
92            ForwardingError::ZeroAmountHtlc
93                | ForwardingError::ChannelNotFound(_)
94                | ForwardingError::NodeNotFound(_)
95                | ForwardingError::PaymentHashExists(_)
96                | ForwardingError::PaymentHashNotFound(_)
97                | ForwardingError::SanityCheckFailed(_, _)
98                | ForwardingError::FeeOverflow(_, _, _)
99        )
100    }
101}
102
/// An htlc that has been added to a channel and is still waiting to be settled or failed.
#[derive(Clone, Copy)]
struct Htlc {
    /// Value locked up by the htlc, in millisatoshis.
    amount_msat: u64,
    /// The htlc's cltv expiry height.
    cltv_expiry: u32,
}
109
110/// Represents one node in the channel's forwarding policy and restrictions. Note that this doesn't directly map to
111/// a single concept in the protocol, a few things have been combined for the sake of simplicity. Used to manage the
112/// lightning "state machine" and check that HTLCs are added in accordance of the advertised policy.
113#[derive(Clone)]
114pub struct ChannelPolicy {
115    pub pubkey: PublicKey,
116    pub max_htlc_count: u64,
117    pub max_in_flight_msat: u64,
118    pub min_htlc_size_msat: u64,
119    pub max_htlc_size_msat: u64,
120    pub cltv_expiry_delta: u32,
121    pub base_fee: u64,
122    pub fee_rate_prop: u64,
123}
124
125impl ChannelPolicy {
126    /// Validates that the channel policy is acceptable for the size of the channel.
127    fn validate(&self, capacity_msat: u64) -> Result<(), SimulationError> {
128        if self.max_in_flight_msat > capacity_msat {
129            return Err(SimulationError::SimulatedNetworkError(format!(
130                "max_in_flight_msat {} > capacity {}",
131                self.max_in_flight_msat, capacity_msat
132            )));
133        }
134        if self.max_htlc_size_msat > capacity_msat {
135            return Err(SimulationError::SimulatedNetworkError(format!(
136                "max_htlc_size_msat {} > capacity {}",
137                self.max_htlc_size_msat, capacity_msat
138            )));
139        }
140        Ok(())
141    }
142}
143
/// Fails with the forwarding error provided if the value provided fails its inequality check.
///
/// Usage: `fail_forwarding_inequality!(lhs, op, rhs, Variant [, extra...])`. If `lhs op rhs` holds, the enclosing
/// function returns `Err(ForwardingError::Variant(lhs, rhs, extra...))`. It may therefore only be used in functions
/// returning `Result<_, ForwardingError>`.
///
/// Each operand is evaluated exactly once. The same values are used for both the comparison and the error payload,
/// so operands with side effects or non-trivial cost (e.g. summing the in-flight set) are not recomputed.
macro_rules! fail_forwarding_inequality {
    ($value_1:expr, $op:tt, $value_2:expr, $error_variant:ident $(, $opt:expr)*) => {{
        let (lhs, rhs) = ($value_1, $value_2);
        if lhs $op rhs {
            return Err(ForwardingError::$error_variant(lhs, rhs $(, $opt)*));
        }
    }};
}
158
159/// The internal state of one side of a simulated channel, including its forwarding parameters. This struct is
160/// primarily responsible for handling our view of what's currently in-flight on the channel, and how much
161/// liquidity we have.
162#[derive(Clone)]
163struct ChannelState {
164    local_balance_msat: u64,
165    in_flight: HashMap<PaymentHash, Htlc>,
166    policy: ChannelPolicy,
167}
168
169impl ChannelState {
170    /// Creates a new channel with local liquidity as allocated by the caller. The responsibility of ensuring that the
171    /// local balance of each side of the channel equals its total capacity is on the caller, as we are only dealing
172    /// with a one-sided view of the channel's state.
173    fn new(policy: ChannelPolicy, local_balance_msat: u64) -> Self {
174        ChannelState {
175            local_balance_msat,
176            in_flight: HashMap::new(),
177            policy,
178        }
179    }
180
181    /// Returns the sum of all the *in flight outgoing* HTLCs on the channel.
182    fn in_flight_total(&self) -> u64 {
183        self.in_flight.values().map(|h| h.amount_msat).sum()
184    }
185
186    /// Checks whether the proposed HTLC abides by the channel policy advertised for using this channel as the
187    /// *outgoing* link in a forward.
188    fn check_htlc_forward(
189        &self,
190        cltv_delta: u32,
191        amt: u64,
192        fee: u64,
193    ) -> Result<(), ForwardingError> {
194        fail_forwarding_inequality!(cltv_delta, <, self.policy.cltv_expiry_delta, InsufficientCltvDelta);
195
196        let expected_fee = amt
197            .checked_mul(self.policy.fee_rate_prop)
198            .and_then(|prop_fee| (prop_fee / 1000000).checked_add(self.policy.base_fee))
199            .ok_or(ForwardingError::FeeOverflow(
200                amt,
201                self.policy.base_fee,
202                self.policy.fee_rate_prop,
203            ))?;
204
205        fail_forwarding_inequality!(
206            fee, <, expected_fee, InsufficientFee, self.policy.base_fee, self.policy.fee_rate_prop
207        );
208
209        Ok(())
210    }
211
212    /// Checks whether the proposed HTLC can be added to the channel as an outgoing HTLC. This requires that we have
213    /// sufficient liquidity, and that the restrictions on our in flight htlc balance and count are not violated by
214    /// the addition of the HTLC. Specification sanity checks (such as reasonable CLTV) are also included, as this
215    /// is where we'd check it in real life.
216    fn check_outgoing_addition(&self, htlc: &Htlc) -> Result<(), ForwardingError> {
217        fail_forwarding_inequality!(htlc.amount_msat, >, self.policy.max_htlc_size_msat, MoreThanMaximum);
218        fail_forwarding_inequality!(htlc.amount_msat, <, self.policy.min_htlc_size_msat, LessThanMinimum);
219        fail_forwarding_inequality!(
220            self.in_flight.len() as u64 + 1, >, self.policy.max_htlc_count, ExceedsInFlightCount
221        );
222        fail_forwarding_inequality!(
223            self.in_flight_total() + htlc.amount_msat, >, self.policy.max_in_flight_msat, ExceedsInFlightTotal
224        );
225        fail_forwarding_inequality!(htlc.amount_msat, >, self.local_balance_msat, InsufficientBalance);
226        fail_forwarding_inequality!(htlc.cltv_expiry, >, 500000000, ExpiryInSeconds);
227
228        Ok(())
229    }
230
231    /// Adds the HTLC to our set of outgoing in-flight HTLCs. [`check_outgoing_addition`] must be called before
232    /// this to ensure that the restrictions on outgoing HTLCs are not violated. Local balance is decreased by the
233    /// HTLC amount, as this liquidity is no longer available.
234    ///
235    /// Note: MPP payments are not currently supported, so this function will fail if a duplicate payment hash is
236    /// reported.
237    fn add_outgoing_htlc(&mut self, hash: PaymentHash, htlc: Htlc) -> Result<(), ForwardingError> {
238        self.check_outgoing_addition(&htlc)?;
239        if self.in_flight.get(&hash).is_some() {
240            return Err(ForwardingError::PaymentHashExists(hash));
241        }
242        self.local_balance_msat -= htlc.amount_msat;
243        self.in_flight.insert(hash, htlc);
244        Ok(())
245    }
246
247    /// Removes the HTLC from our set of outgoing in-flight HTLCs, failing if the payment hash is not found.
248    fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result<Htlc, ForwardingError> {
249        self.in_flight
250            .remove(hash)
251            .ok_or(ForwardingError::PaymentHashNotFound(*hash))
252    }
253
254    // Updates channel state to account for the resolution of an outgoing in-flight HTLC. If the HTLC failed, the
255    // balance is failed back to the channel's local balance. If not, the in-flight balance is settled to the other
256    // node, so there is no operation.
257    fn settle_outgoing_htlc(&mut self, amt: u64, success: bool) {
258        if !success {
259            self.local_balance_msat += amt
260        }
261    }
262
263    // Updates channel state to account for the resolution of an incoming in-flight HTLC. If the HTLC succeeded,
264    // the balance is settled to the channel's local balance. If not, the in-flight balance is failed back to the
265    // other node, so there is no operation.
266    fn settle_incoming_htlc(&mut self, amt: u64, success: bool) {
267        if success {
268            self.local_balance_msat += amt
269        }
270    }
271}
272
273/// Represents a simulated channel, and is responsible for managing addition and removal of HTLCs from the channel and
274/// sanity checks. Channel state is tracked *unidirectionally* for each participant in the channel.
275///
276/// Each node represented in the channel tracks only its outgoing HTLCs, and balance is transferred between the two
277/// nodes as they settle or fail. Given some channel: node_1 <----> node_2:
278/// * HTLC sent node_1 -> node_2: added to in-flight outgoing htlcs on node_1.
279/// * HTLC sent node_2 -> node_1: added to in-flight outgoing htlcs on node_2.
280///
281/// Rules for managing balance are as follows:
282/// * When an HTLC is in flight, the channel's local outgoing liquidity decreases (as it's locked up).
283/// * When an HTLC fails, the balance is returned to the local node (the one that it was in-flight / outgoing on).
284/// * When an HTLC succeeds, the balance is sent to the remote node (the one that did not track it as in-flight).
285///
286/// With each state transition, the simulated channel checks that the sum of its local balances and in-flight equal the
287/// total channel capacity. Failure of this sanity check represents a critical failure in the state machine.
288#[derive(Clone)]
289pub struct SimulatedChannel {
290    capacity_msat: u64,
291    short_channel_id: ShortChannelID,
292    node_1: ChannelState,
293    node_2: ChannelState,
294}
295
296impl SimulatedChannel {
297    /// Creates a new channel with the capacity and policies provided. The total capacity of the channel is evenly split
298    /// between the channel participants (this is an arbitrary decision).
299    pub fn new(
300        capacity_msat: u64,
301        short_channel_id: ShortChannelID,
302        node_1: ChannelPolicy,
303        node_2: ChannelPolicy,
304    ) -> Self {
305        SimulatedChannel {
306            capacity_msat,
307            short_channel_id,
308            node_1: ChannelState::new(node_1, capacity_msat / 2),
309            node_2: ChannelState::new(node_2, capacity_msat / 2),
310        }
311    }
312
313    /// Validates that a simulated channel has distinct node pairs and valid routing policies.
314    fn validate(&self) -> Result<(), SimulationError> {
315        if self.node_1.policy.pubkey == self.node_2.policy.pubkey {
316            return Err(SimulationError::SimulatedNetworkError(format!(
317                "Channel should have distinct node pubkeys, got: {} for both nodes.",
318                self.node_1.policy.pubkey
319            )));
320        }
321
322        self.node_1.policy.validate(self.capacity_msat)?;
323        self.node_2.policy.validate(self.capacity_msat)?;
324
325        Ok(())
326    }
327
328    fn get_node_mut(&mut self, pubkey: &PublicKey) -> Result<&mut ChannelState, ForwardingError> {
329        if pubkey == &self.node_1.policy.pubkey {
330            Ok(&mut self.node_1)
331        } else if pubkey == &self.node_2.policy.pubkey {
332            Ok(&mut self.node_2)
333        } else {
334            Err(ForwardingError::NodeNotFound(*pubkey))
335        }
336    }
337
338    fn get_node(&self, pubkey: &PublicKey) -> Result<&ChannelState, ForwardingError> {
339        if pubkey == &self.node_1.policy.pubkey {
340            Ok(&self.node_1)
341        } else if pubkey == &self.node_2.policy.pubkey {
342            Ok(&self.node_2)
343        } else {
344            Err(ForwardingError::NodeNotFound(*pubkey))
345        }
346    }
347
348    /// Adds an htlc to the appropriate side of the simulated channel, checking its policy and balance are okay. The
349    /// public key of the node sending the HTLC (ie, the party that would send update_add_htlc in the protocol)
350    /// must be provided to add the outgoing htlc to its side of the channel.
351    fn add_htlc(
352        &mut self,
353        sending_node: &PublicKey,
354        hash: PaymentHash,
355        htlc: Htlc,
356    ) -> Result<(), ForwardingError> {
357        if htlc.amount_msat == 0 {
358            return Err(ForwardingError::ZeroAmountHtlc);
359        }
360
361        self.get_node_mut(sending_node)?
362            .add_outgoing_htlc(hash, htlc)?;
363        self.sanity_check()
364    }
365
366    /// Performs a sanity check on the total balances in a channel. Note that we do not currently include on-chain
367    /// fees or reserve so these values should exactly match.
368    fn sanity_check(&self) -> Result<(), ForwardingError> {
369        let node_1_total = self.node_1.local_balance_msat + self.node_1.in_flight_total();
370        let node_2_total = self.node_2.local_balance_msat + self.node_2.in_flight_total();
371
372        fail_forwarding_inequality!(node_1_total + node_2_total, !=, self.capacity_msat, SanityCheckFailed);
373
374        Ok(())
375    }
376
377    /// Removes an htlc from the appropriate side of the simulated channel, settling balances across channel sides
378    /// based on the success of the htlc. The public key of the node that originally sent the HTLC (ie, the party
379    /// that would send update_add_htlc in the protocol) must be provided to remove the htlc from its side of the
380    /// channel.
381    fn remove_htlc(
382        &mut self,
383        sending_node: &PublicKey,
384        hash: &PaymentHash,
385        success: bool,
386    ) -> Result<(), ForwardingError> {
387        let htlc = self
388            .get_node_mut(sending_node)?
389            .remove_outgoing_htlc(hash)?;
390        self.settle_htlc(sending_node, htlc.amount_msat, success)?;
391        self.sanity_check()
392    }
393
394    /// Updates the local balance of each node in the channel once a htlc has been resolved, pushing funds to the
395    /// receiving nodes in the case of a successful payment and returning balance to the sender in the case of a
396    /// failure.
397    fn settle_htlc(
398        &mut self,
399        sending_node: &PublicKey,
400        amount_msat: u64,
401        success: bool,
402    ) -> Result<(), ForwardingError> {
403        if sending_node == &self.node_1.policy.pubkey {
404            self.node_1.settle_outgoing_htlc(amount_msat, success);
405            self.node_2.settle_incoming_htlc(amount_msat, success);
406            Ok(())
407        } else if sending_node == &self.node_2.policy.pubkey {
408            self.node_2.settle_outgoing_htlc(amount_msat, success);
409            self.node_1.settle_incoming_htlc(amount_msat, success);
410            Ok(())
411        } else {
412            Err(ForwardingError::NodeNotFound(*sending_node))
413        }
414    }
415
416    /// Checks an htlc forward against the outgoing policy of the node provided.
417    fn check_htlc_forward(
418        &self,
419        forwarding_node: &PublicKey,
420        cltv_delta: u32,
421        amount_msat: u64,
422        fee_msat: u64,
423    ) -> Result<(), ForwardingError> {
424        self.get_node(forwarding_node)?
425            .check_htlc_forward(cltv_delta, amount_msat, fee_msat)
426    }
427}
428
429/// SimNetwork represents a high level network coordinator that is responsible for the task of actually propagating
430/// payments through the simulated network.
431#[async_trait]
432trait SimNetwork: Send + Sync {
433    /// Sends payments over the route provided through the network, reporting the final payment outcome to the sender
434    /// channel provided.
435    fn dispatch_payment(
436        &mut self,
437        source: PublicKey,
438        route: Route,
439        payment_hash: PaymentHash,
440        sender: Sender<Result<PaymentResult, LightningError>>,
441    );
442
443    /// Looks up a node in the simulated network and a list of its channel capacities.
444    async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
445}
446
447/// A wrapper struct used to implement the LightningNode trait (can be thought of as "the" lightning node). Passes
448/// all functionality through to a coordinating simulation network. This implementation contains both the [`SimNetwork`]
449/// implementation that will allow us to dispatch payments and a read-only NetworkGraph that is used for pathfinding.
450/// While these two could be combined, we re-use the LDK-native struct to allow re-use of their pathfinding logic.
451struct SimNode<'a, T: SimNetwork> {
452    info: NodeInfo,
453    /// The underlying execution network that will be responsible for dispatching payments.
454    network: Arc<Mutex<T>>,
455    /// Tracks the channel that will provide updates for payments by hash.
456    in_flight: HashMap<PaymentHash, Receiver<Result<PaymentResult, LightningError>>>,
457    /// A read-only graph used for pathfinding.
458    pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
459}
460
461impl<'a, T: SimNetwork> SimNode<'a, T> {
462    /// Creates a new simulation node that refers to the high level network coordinator provided to process payments
463    /// on its behalf. The pathfinding graph is provided separately so that each node can handle its own pathfinding.
464    pub fn new(
465        pubkey: PublicKey,
466        payment_network: Arc<Mutex<T>>,
467        pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
468    ) -> Self {
469        SimNode {
470            info: node_info(pubkey),
471            network: payment_network,
472            in_flight: HashMap::new(),
473            pathfinding_graph,
474        }
475    }
476}
477
478/// Produces the node info for a mocked node, filling in the features that the simulator requires.
479fn node_info(pubkey: PublicKey) -> NodeInfo {
480    // Set any features that the simulator requires here.
481    let mut features = NodeFeatures::empty();
482    features.set_keysend_optional();
483
484    NodeInfo {
485        pubkey,
486        alias: "".to_string(), // TODO: store alias?
487        features,
488    }
489}
490
491/// Uses LDK's pathfinding algorithm with default parameters to find a path from source to destination, with no
492/// restrictions on fee budget.
493fn find_payment_route(
494    source: &PublicKey,
495    dest: PublicKey,
496    amount_msat: u64,
497    pathfinding_graph: &NetworkGraph<&WrappedLog>,
498) -> Result<Route, SimulationError> {
499    let scorer = ProbabilisticScorer::new(Default::default(), pathfinding_graph, &WrappedLog {});
500
501    find_route(
502        source,
503        &RouteParameters {
504            payment_params: PaymentParameters::from_node_id(dest, 0)
505                .with_max_total_cltv_expiry_delta(u32::MAX)
506                // TODO: set non-zero value to support MPP.
507                .with_max_path_count(1)
508                // Allow sending htlcs up to 50% of the channel's capacity.
509                .with_max_channel_saturation_power_of_half(1),
510            final_value_msat: amount_msat,
511            max_total_routing_fee_msat: None,
512        },
513        pathfinding_graph,
514        None,
515        &WrappedLog {},
516        &scorer,
517        &Default::default(),
518        &[0; 32],
519    )
520    .map_err(|e| SimulationError::SimulatedNetworkError(e.err))
521}
522
523#[async_trait]
524impl<T: SimNetwork> LightningNode for SimNode<'_, T> {
525    fn get_info(&self) -> &NodeInfo {
526        &self.info
527    }
528
529    async fn get_network(&mut self) -> Result<Network, LightningError> {
530        Ok(Network::Regtest)
531    }
532
533    /// send_payment picks a random preimage for a payment, dispatches it in the network and adds a tracking channel
534    /// to our node state to be used for subsequent track_payment calls.
535    async fn send_payment(
536        &mut self,
537        dest: PublicKey,
538        amount_msat: u64,
539    ) -> Result<PaymentHash, LightningError> {
540        // Create a sender and receiver pair that will be used to report the results of the payment and add them to
541        // our internal tracking state along with the chosen payment hash.
542        let (sender, receiver) = channel();
543        let preimage = PaymentPreimage(rand::random());
544        let payment_hash = PaymentHash(Sha256::hash(&preimage.0).to_byte_array());
545
546        // Check for payment hash collision, failing the payment if we happen to repeat one.
547        match self.in_flight.entry(payment_hash) {
548            Entry::Occupied(_) => {
549                return Err(LightningError::SendPaymentError(
550                    "payment hash exists".to_string(),
551                ));
552            },
553            Entry::Vacant(vacant) => {
554                vacant.insert(receiver);
555            },
556        }
557
558        let route = match find_payment_route(
559            &self.info.pubkey,
560            dest,
561            amount_msat,
562            &self.pathfinding_graph,
563        ) {
564            Ok(path) => path,
565            // In the case that we can't find a route for the payment, we still report a successful payment *api call*
566            // and report RouteNotFound to the tracking channel. This mimics the behavior of real nodes.
567            Err(e) => {
568                log::trace!("Could not find path for payment: {:?}.", e);
569
570                if let Err(e) = sender.send(Ok(PaymentResult {
571                    htlc_count: 0,
572                    payment_outcome: PaymentOutcome::RouteNotFound,
573                })) {
574                    log::error!("Could not send payment result: {:?}.", e);
575                }
576
577                return Ok(payment_hash);
578            },
579        };
580
581        // If we did successfully obtain a route, dispatch the payment through the network and then report success.
582        self.network
583            .lock()
584            .await
585            .dispatch_payment(self.info.pubkey, route, payment_hash, sender);
586
587        Ok(payment_hash)
588    }
589
590    /// track_payment blocks until a payment outcome is returned for the payment hash provided, or the shutdown listener
591    /// provided is triggered. This call will fail if the hash provided was not obtained by calling send_payment first.
592    async fn track_payment(
593        &mut self,
594        hash: &PaymentHash,
595        listener: Listener,
596    ) -> Result<PaymentResult, LightningError> {
597        match self.in_flight.remove(hash) {
598            Some(receiver) => {
599                select! {
600                    biased;
601                    _ = listener => Err(
602                        LightningError::TrackPaymentError("shutdown during payment tracking".to_string()),
603                    ),
604
605                    // If we get a payment result back, remove from our in flight set of payments and return the result.
606                    res = receiver => {
607                        res.map_err(|e| LightningError::TrackPaymentError(format!("channel receive err: {}", e)))?
608                    },
609                }
610            },
611            None => Err(LightningError::TrackPaymentError(format!(
612                "payment hash {} not found",
613                hex::encode(hash.0),
614            ))),
615        }
616    }
617
618    async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
619        Ok(self.network.lock().await.lookup_node(node_id).await?.0)
620    }
621
622    async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError> {
623        Ok(self
624            .network
625            .lock()
626            .await
627            .lookup_node(&self.info.pubkey)
628            .await?
629            .1)
630    }
631}
632
633/// Graph is the top level struct that is used to coordinate simulation of lightning nodes.
634pub struct SimGraph {
635    /// nodes caches the list of nodes in the network with a vector of their channel capacities, only used for quick
636    /// lookup.
637    nodes: HashMap<PublicKey, Vec<u64>>,
638
639    /// channels maps the scid of a channel to its current simulation state.
640    channels: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
641
642    /// track all tasks spawned to process payments in the graph.
643    tasks: JoinSet<()>,
644
645    /// trigger shutdown if a critical error occurs.
646    shutdown_trigger: Trigger,
647}
648
649impl SimGraph {
650    /// Creates a graph on which to simulate payments.
651    pub fn new(
652        graph_channels: Vec<SimulatedChannel>,
653        shutdown_trigger: Trigger,
654    ) -> Result<Self, SimulationError> {
655        let mut nodes: HashMap<PublicKey, Vec<u64>> = HashMap::new();
656        let mut channels = HashMap::new();
657
658        for channel in graph_channels.iter() {
659            // Assert that the channel is valid and that its short channel ID is unique within the simulation, required
660            // because we use scid to identify the channel.
661            channel.validate()?;
662            match channels.entry(channel.short_channel_id) {
663                Entry::Occupied(_) => {
664                    return Err(SimulationError::SimulatedNetworkError(format!(
665                        "Simulated short channel ID should be unique: {} duplicated",
666                        channel.short_channel_id
667                    )))
668                },
669                Entry::Vacant(v) => v.insert(channel.clone()),
670            };
671
672            // It's okay to have duplicate pubkeys because one node can have many channels.
673            for pubkey in [channel.node_1.policy.pubkey, channel.node_2.policy.pubkey] {
674                match nodes.entry(pubkey) {
675                    Entry::Occupied(o) => o.into_mut().push(channel.capacity_msat),
676                    Entry::Vacant(v) => {
677                        v.insert(vec![channel.capacity_msat]);
678                    },
679                }
680            }
681        }
682
683        Ok(SimGraph {
684            nodes,
685            channels: Arc::new(Mutex::new(channels)),
686            tasks: JoinSet::new(),
687            shutdown_trigger,
688        })
689    }
690
691    /// Blocks until all tasks created by the simulator have shut down. This function does not trigger shutdown,
692    /// because it expects erroring-out tasks to handle their own shutdown triggering.
693    pub async fn wait_for_shutdown(&mut self) {
694        log::debug!("Waiting for simulated graph to shutdown.");
695
696        while let Some(res) = self.tasks.join_next().await {
697            if let Err(e) = res {
698                log::error!("Graph task exited with error: {e}");
699            }
700        }
701
702        log::debug!("Simulated graph shutdown.");
703    }
704}
705
706/// Produces a map of node public key to lightning node implementation to be used for simulations.
707pub async fn ln_node_from_graph<'a>(
708    graph: Arc<Mutex<SimGraph>>,
709    routing_graph: Arc<NetworkGraph<&'_ WrappedLog>>,
710) -> HashMap<PublicKey, Arc<Mutex<dyn LightningNode + '_>>> {
711    let mut nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>> = HashMap::new();
712
713    for pk in graph.lock().await.nodes.keys() {
714        nodes.insert(
715            *pk,
716            Arc::new(Mutex::new(SimNode::new(
717                *pk,
718                graph.clone(),
719                routing_graph.clone(),
720            ))),
721        );
722    }
723
724    nodes
725}
726
727/// Populates a network graph based on the set of simulated channels provided. This function *only* applies channel
728/// announcements, which has the effect of adding the nodes in each channel to the graph, because LDK does not export
729/// all of the fields required to apply node announcements. This means that we will not have node-level information
730/// (such as features) available in the routing graph.
731pub fn populate_network_graph<'a>(
732    channels: Vec<SimulatedChannel>,
733) -> Result<NetworkGraph<&'a WrappedLog>, LdkError> {
734    let graph = NetworkGraph::new(Network::Regtest, &WrappedLog {});
735
736    let chain_hash = ChainHash::using_genesis_block(Network::Regtest);
737
738    for channel in channels {
739        let announcement = UnsignedChannelAnnouncement {
740            // For our purposes we don't currently need any channel level features.
741            features: ChannelFeatures::empty(),
742            chain_hash,
743            short_channel_id: channel.short_channel_id.into(),
744            node_id_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
745            node_id_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
746            // Note: we don't need bitcoin keys for our purposes, so we just copy them *but* remember that we do use
747            // this for our fake utxo validation so they do matter for producing the script that we mock validate.
748            bitcoin_key_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
749            bitcoin_key_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
750            // Internal field used by LDK, we don't need it.
751            excess_data: Vec::new(),
752        };
753
754        let utxo_validator = UtxoValidator {
755            amount_sat: channel.capacity_msat / 1000,
756            script: make_funding_redeemscript(
757                &channel.node_1.policy.pubkey,
758                &channel.node_2.policy.pubkey,
759            )
760            .to_v0_p2wsh(),
761        };
762
763        graph.update_channel_from_unsigned_announcement(&announcement, &Some(&utxo_validator))?;
764
765        // The least significant bit of the channel flag field represents the direction that the channel update
766        // applies to. This value is interpreted as node_1 if it is zero, and node_2 otherwise.
767        for (i, node) in [channel.node_1, channel.node_2].iter().enumerate() {
768            let update = UnsignedChannelUpdate {
769                chain_hash,
770                short_channel_id: channel.short_channel_id.into(),
771                timestamp: SystemTime::now()
772                    .duration_since(UNIX_EPOCH)
773                    .unwrap()
774                    .as_secs() as u32,
775                flags: i as u8,
776                cltv_expiry_delta: node.policy.cltv_expiry_delta as u16,
777                htlc_minimum_msat: node.policy.min_htlc_size_msat,
778                htlc_maximum_msat: node.policy.max_htlc_size_msat,
779                fee_base_msat: node.policy.base_fee as u32,
780                fee_proportional_millionths: node.policy.fee_rate_prop as u32,
781                excess_data: Vec::new(),
782            };
783            graph.update_channel_unsigned(&update)?;
784        }
785    }
786
787    Ok(graph)
788}
789
790#[async_trait]
791impl SimNetwork for SimGraph {
792    /// dispatch_payment asynchronously propagates a payment through the simulated network, returning a tracking
793    /// channel that can be used to obtain the result of the payment. At present, MPP payments are not supported.
794    /// In future, we'll allow multiple paths for a single payment, so we allow the trait to accept a route with
795    /// multiple paths to avoid future refactoring.
796    fn dispatch_payment(
797        &mut self,
798        source: PublicKey,
799        route: Route,
800        payment_hash: PaymentHash,
801        sender: Sender<Result<PaymentResult, LightningError>>,
802    ) {
803        // Expect at least one path (right now), with the intention to support multiple in future.
804        let path = match route.paths.first() {
805            Some(p) => p,
806            None => {
807                log::warn!("Find route did not return expected number of paths.");
808
809                if let Err(e) = sender.send(Ok(PaymentResult {
810                    htlc_count: 0,
811                    payment_outcome: PaymentOutcome::RouteNotFound,
812                })) {
813                    log::error!("Could not send payment result: {:?}.", e);
814                }
815
816                return;
817            },
818        };
819
820        self.tasks.spawn(propagate_payment(
821            self.channels.clone(),
822            source,
823            path.clone(),
824            payment_hash,
825            sender,
826            self.shutdown_trigger.clone(),
827        ));
828    }
829
830    /// lookup_node fetches a node's information and channel capacities.
831    async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError> {
832        self.nodes
833            .get(node)
834            .map(|channels| (node_info(*node), channels.clone()))
835            .ok_or(LightningError::GetNodeInfoError(
836                "Node not found".to_string(),
837            ))
838    }
839}
840
/// Adds htlcs to the simulation state along the path provided. Returns the index in the path from which to fail
/// back htlcs (if any) and a forwarding error if the payment is not successfully added to the entire path.
///
/// For each hop in the route, we check both the addition of the HTLC and whether we can forward it. Take an example
/// route A --> B --> C, we will add this in two hops: A --> B then B --> C. For each hop, using A --> B as an example:
/// * Check whether A can add the outgoing HTLC (checks liquidity and in-flight restrictions).
///   * If no, fail the HTLC.
///   * If yes, add outgoing HTLC to A's channel.
/// * Check whether B will accept the forward.
///   * If no, fail the HTLC.
///   * If yes, continue to the next hop.
///
/// If successfully added to A --> B, this check will be repeated for B --> C.
///
/// Note that we don't have any special handling for the receiving node: once we've successfully added an outgoing
/// HTLC for the outgoing channel that is connected to the receiving node we'll return. To add invoice-related
/// handling, we'd need to include some logic that then decides whether to settle/fail the HTLC at the last hop here.
///
/// `nodes` is the map of simulated channels keyed by short channel ID. Its lock is taken and released once per hop.
async fn add_htlcs(
    nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
    source: PublicKey,
    route: Path,
    payment_hash: PaymentHash,
) -> Result<(), (Option<usize>, ForwardingError)> {
    let mut outgoing_node = source;
    // The source's HTLC carries the whole amount sent: the final value plus the fees of every hop. Each hop deducts
    // its own `fee_msat` (below) before the amount is used for the next channel.
    let mut outgoing_amount = route.fee_msat() + route.final_value_msat();
    // Likewise, the first HTLC's expiry is the sum of every hop's CLTV delta and is reduced hop by hop.
    let mut outgoing_cltv = route.hops.iter().map(|hop| hop.cltv_expiry_delta).sum();

    // Tracks the hop index that we need to remove htlcs from on payment completion (both success and failure).
    // Given a payment from A to C, over the route A -- B -- C, this index has the following meanings:
    // - None: A could not add the outgoing HTLC to B, no action for payment failure.
    // - Some(0): A -- B added the HTLC but B could not forward the HTLC to C, so it only needs removing on A -- B.
    // - Some(1): A -- B and B -- C added the HTLC, so it should be removed from the full route.
    let mut fail_idx = None;

    for (i, hop) in route.hops.iter().enumerate() {
        // Lock the node that we want to add the HTLC to next. We choose to lock one hop at a time (rather than for
        // the whole route) so that we can mimic the behavior of payments in the real network where the HTLCs in a
        // route don't all get to lock in in a row (they have interactions with other payments).
        let mut node_lock = nodes.lock().await;
        let scid = ShortChannelID::from(hop.short_channel_id);

        if let Some(channel) = node_lock.get_mut(&scid) {
            channel
                .add_htlc(
                    &outgoing_node,
                    payment_hash,
                    Htlc {
                        amount_msat: outgoing_amount,
                        cltv_expiry: outgoing_cltv,
                    },
                )
                // If we couldn't add this HTLC, we only need to fail back from the preceding hop, so we don't
                // have to progress our fail_idx.
                .map_err(|e| (fail_idx, e))?;

            // If the HTLC was successfully added, then we'll need to remove the HTLC from this channel if we fail,
            // so we progress our failure index to include this node.
            fail_idx = Some(i);

            // Once we've added the HTLC on this hop's channel, we want to check whether it has sufficient fee
            // and CLTV delta per the _next_ channel's policy (because fees and CLTV delta in LN are charged on
            // the outgoing link). We check the policy belonging to the node that we just forwarded to, which
            // represents the fee in that direction. The amount checked is what will be forwarded on the next
            // channel, i.e. this hop's incoming amount minus its fee.
            //
            // If the next channel is not in the map, the check is skipped here and the next iteration reports
            // ChannelNotFound.
            //
            // TODO: add invoice-related checks (including final CLTV) if we support non-keysend payments.
            if i != route.hops.len() - 1 {
                if let Some(channel) =
                    node_lock.get(&ShortChannelID::from(route.hops[i + 1].short_channel_id))
                {
                    channel
                        .check_htlc_forward(
                            &hop.pubkey,
                            hop.cltv_expiry_delta,
                            outgoing_amount - hop.fee_msat,
                            hop.fee_msat,
                        )
                        // If we haven't met forwarding conditions for the next channel's policy, then we fail at
                        // the current index, because we've already added the HTLC as outgoing.
                        .map_err(|e| (fail_idx, e))?;
                }
            }
        } else {
            return Err((fail_idx, ForwardingError::ChannelNotFound(scid)));
        }

        // Once we've taken the "hop" to the destination pubkey, it becomes the source of the next outgoing htlc.
        outgoing_node = hop.pubkey;
        outgoing_amount -= hop.fee_msat;
        outgoing_cltv -= hop.cltv_expiry_delta;

        // TODO: introduce artificial latency between hops?
    }

    Ok(())
}
936
937/// Removes htlcs from the simulation state from the index in the path provided (backwards).
938///
939/// Taking the example of a payment over A --> B --> C --> D where the payment was rejected by C because it did not
940/// have enough liquidity to forward it, we will expect a failure index of 1 because the HTLC was successfully added
941/// to A and B's outgoing channels, but not C.
942///
943/// This function will remove the HTLC one hop at a time, working backwards from the failure index, so in this
944/// case B --> C and then B --> A. We lookup the HTLC on the incoming node because it will have tracked it in its
945/// outgoing in-flight HTLCs.
946async fn remove_htlcs(
947    nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
948    resolution_idx: usize,
949    source: PublicKey,
950    route: Path,
951    payment_hash: PaymentHash,
952    success: bool,
953) -> Result<(), ForwardingError> {
954    for (i, hop) in route.hops[0..=resolution_idx].iter().enumerate().rev() {
955        // When we add HTLCs, we do so on the state of the node that sent the htlc along the channel so we need to
956        // look up our incoming node so that we can remove it when we go backwards. For the first htlc, this is just
957        // the sending node, otherwise it's the hop before.
958        let incoming_node = if i == 0 {
959            source
960        } else {
961            route.hops[i - 1].pubkey
962        };
963
964        // As with when we add HTLCs, we remove them one hop at a time (rather than locking for the whole route) to
965        // mimic the behavior of payments in a real network.
966        match nodes
967            .lock()
968            .await
969            .get_mut(&ShortChannelID::from(hop.short_channel_id))
970        {
971            Some(channel) => channel.remove_htlc(&incoming_node, &payment_hash, success)?,
972            None => {
973                return Err(ForwardingError::ChannelNotFound(ShortChannelID::from(
974                    hop.short_channel_id,
975                )))
976            },
977        }
978    }
979
980    Ok(())
981}
982
983/// Finds a payment path from the source to destination nodes provided, and propagates the appropriate htlcs through
984/// the simulated network, notifying the sender channel provided of the payment outcome. If a critical error occurs,
985/// ie a breakdown of our state machine, it will still notify the payment outcome and will use the shutdown trigger
986/// to signal that we should exit.
987async fn propagate_payment(
988    nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
989    source: PublicKey,
990    route: Path,
991    payment_hash: PaymentHash,
992    sender: Sender<Result<PaymentResult, LightningError>>,
993    shutdown: Trigger,
994) {
995    // If we partially added HTLCs along the route, we need to fail them back to the source to clean up our partial
996    // state. It's possible that we failed with the very first add, and then we don't need to clean anything up.
997    let notify_result = if let Err((fail_idx, err)) =
998        add_htlcs(nodes.clone(), source, route.clone(), payment_hash).await
999    {
1000        if err.is_critical() {
1001            shutdown.trigger();
1002        }
1003
1004        if let Some(resolution_idx) = fail_idx {
1005            if let Err(e) =
1006                remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await
1007            {
1008                if e.is_critical() {
1009                    shutdown.trigger();
1010                }
1011            }
1012        }
1013
1014        // We have more information about failures because we're in control of the whole route, so we log the
1015        // actual failure reason and then fail back with unknown failure type.
1016        log::debug!(
1017            "Forwarding failure for simulated payment {}: {err}",
1018            hex::encode(payment_hash.0)
1019        );
1020
1021        PaymentResult {
1022            htlc_count: 0,
1023            payment_outcome: PaymentOutcome::Unknown,
1024        }
1025    } else {
1026        // If we successfully added the htlc, go ahead and remove all the htlcs in the route with successful resolution.
1027        if let Err(e) = remove_htlcs(
1028            nodes,
1029            route.hops.len() - 1,
1030            source,
1031            route,
1032            payment_hash,
1033            true,
1034        )
1035        .await
1036        {
1037            if e.is_critical() {
1038                shutdown.trigger();
1039            }
1040
1041            log::error!("Could not remove htlcs from channel: {e}.");
1042        }
1043
1044        PaymentResult {
1045            htlc_count: 1,
1046            payment_outcome: PaymentOutcome::Success,
1047        }
1048    };
1049
1050    if let Err(e) = sender.send(Ok(notify_result)) {
1051        log::error!("Could not notify payment result: {:?}.", e);
1052    }
1053}
1054
1055/// WrappedLog implements LDK's logging trait so that we can provide pathfinding with a logger that uses our existing
1056/// logger.
1057pub struct WrappedLog {}
1058
1059impl Logger for WrappedLog {
1060    fn log(&self, record: Record) {
1061        match record.level {
1062            Level::Gossip => log::trace!("{}", record.args),
1063            Level::Trace => log::trace!("{}", record.args),
1064            Level::Debug => log::debug!("{}", record.args),
1065            // LDK has quite noisy info logging for pathfinding, so we downgrade their info logging to our debug level.
1066            Level::Info => log::debug!("{}", record.args),
1067            Level::Warn => log::warn!("{}", record.args),
1068            Level::Error => log::error!("{}", record.args),
1069        }
1070    }
1071}
1072
1073/// UtxoValidator is a mock utxo validator that just returns a fake output with the desired capacity for a channel.
1074struct UtxoValidator {
1075    amount_sat: u64,
1076    script: ScriptBuf,
1077}
1078
1079impl UtxoLookup for UtxoValidator {
1080    fn get_utxo(&self, _genesis_hash: &ChainHash, _short_channel_id: u64) -> UtxoResult {
1081        UtxoResult::Sync(Ok(TxOut {
1082            value: self.amount_sat,
1083            script_pubkey: self.script.clone(),
1084        }))
1085    }
1086}
1087
1088#[cfg(test)]
1089mod tests {
1090    use super::*;
1091    use crate::test_utils::get_random_keypair;
1092    use bitcoin::secp256k1::PublicKey;
1093    use lightning::routing::router::Route;
1094    use mockall::mock;
1095    use std::time::Duration;
1096    use tokio::sync::oneshot;
1097    use tokio::time::timeout;
1098
1099    /// Creates a test channel policy with its maximum HTLC size set to half of the in flight limit of the channel.
1100    /// The minimum HTLC size is hardcoded to 2 so that we can fall beneath this value with a 1 msat htlc.
1101    fn create_test_policy(max_in_flight_msat: u64) -> ChannelPolicy {
1102        let (_, pk) = get_random_keypair();
1103        ChannelPolicy {
1104            pubkey: pk,
1105            max_htlc_count: 10,
1106            max_in_flight_msat,
1107            min_htlc_size_msat: 2,
1108            max_htlc_size_msat: max_in_flight_msat / 2,
1109            cltv_expiry_delta: 10,
1110            base_fee: 1000,
1111            fee_rate_prop: 5000,
1112        }
1113    }
1114
1115    /// Creates a set of n simulated channels connected in a chain of channels, where the short channel ID of each
1116    /// channel is its index in the chain of channels and all capacity is on the side of the node that opened the
1117    /// channel.
1118    ///
1119    /// For example if n = 3 it will produce: node_1 -- node_2 -- node_3 -- node_4, connected by channels.
1120    fn create_simulated_channels(n: u64, capacity_msat: u64) -> Vec<SimulatedChannel> {
1121        let mut channels: Vec<SimulatedChannel> = vec![];
1122        let (_, first_node) = get_random_keypair();
1123
1124        // Create channels in a ring so that we'll get long payment paths.
1125        let mut node_1 = first_node;
1126        for i in 0..n {
1127            // Generate a new random node pubkey.
1128            let (_, node_2) = get_random_keypair();
1129
1130            let node_1_to_2 = ChannelPolicy {
1131                pubkey: node_1,
1132                max_htlc_count: 483,
1133                max_in_flight_msat: capacity_msat / 2,
1134                min_htlc_size_msat: 1,
1135                max_htlc_size_msat: capacity_msat / 2,
1136                cltv_expiry_delta: 40,
1137                base_fee: 1000 * i,
1138                fee_rate_prop: 1500 * i,
1139            };
1140
1141            let node_2_to_1 = ChannelPolicy {
1142                pubkey: node_2,
1143                max_htlc_count: 483,
1144                max_in_flight_msat: capacity_msat / 2,
1145                min_htlc_size_msat: 1,
1146                max_htlc_size_msat: capacity_msat / 2,
1147                cltv_expiry_delta: 40 + 10 * i as u32,
1148                base_fee: 2000 * i,
1149                fee_rate_prop: i,
1150            };
1151
1152            channels.push(SimulatedChannel {
1153                capacity_msat,
1154                // Unique channel ID per link.
1155                short_channel_id: ShortChannelID::from(i),
1156                node_1: ChannelState::new(node_1_to_2, capacity_msat),
1157                node_2: ChannelState::new(node_2_to_1, 0),
1158            });
1159
1160            // Progress source ID to create a chain of nodes.
1161            node_1 = node_2;
1162        }
1163
1164        channels
1165    }
1166
1167    macro_rules! assert_channel_balances {
1168        ($channel_state:expr, $local_balance:expr, $in_flight_len:expr, $in_flight_total:expr) => {
1169            assert_eq!($channel_state.local_balance_msat, $local_balance);
1170            assert_eq!($channel_state.in_flight.len(), $in_flight_len);
1171            assert_eq!($channel_state.in_flight_total(), $in_flight_total);
1172        };
1173    }
1174
    /// Tests state updates related to adding and removing HTLCs to a channel: balances and in-flight tracking after
    /// additions, duplicate-hash rejection, failed and successful settlement, and removal of an unknown hash.
    #[test]
    fn test_channel_state_transitions() {
        let local_balance = 100_000_000;
        let mut channel_state =
            ChannelState::new(create_test_policy(local_balance / 2), local_balance);

        // Basic sanity check that we initialize the channel correctly.
        assert_channel_balances!(channel_state, local_balance, 0, 0);

        // Add a few HTLCs to our internal state and assert that balances are as expected. We'll test
        // `check_outgoing_addition` in more detail in another test, so we just assert that we can add the htlc in
        // this test.
        let hash_1 = PaymentHash([1; 32]);
        let htlc_1 = Htlc {
            amount_msat: 1000,
            cltv_expiry: 40,
        };

        assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());
        assert_channel_balances!(
            channel_state,
            local_balance - htlc_1.amount_msat,
            1,
            htlc_1.amount_msat
        );

        // Try to add a htlc with the same payment hash and assert that we fail because we enforce one htlc per hash
        // at present.
        assert!(matches!(
            channel_state.add_outgoing_htlc(hash_1, htlc_1),
            Err(ForwardingError::PaymentHashExists(_))
        ));

        // Add a second, distinct htlc to our in-flight state.
        let hash_2 = PaymentHash([2; 32]);
        let htlc_2 = Htlc {
            amount_msat: 1000,
            cltv_expiry: 40,
        };

        assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());
        assert_channel_balances!(
            channel_state,
            local_balance - htlc_1.amount_msat - htlc_2.amount_msat,
            2,
            htlc_1.amount_msat + htlc_2.amount_msat
        );

        // Remove our second htlc with a failure so that our in-flight total drops and the amount returns to our local
        // balance.
        assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
        channel_state.settle_outgoing_htlc(htlc_2.amount_msat, false);
        assert_channel_balances!(
            channel_state,
            local_balance - htlc_1.amount_msat,
            1,
            htlc_1.amount_msat
        );

        // Try to remove the same htlc and assert that we fail because the htlc can't be found.
        assert!(matches!(
            channel_state.remove_outgoing_htlc(&hash_2),
            Err(ForwardingError::PaymentHashNotFound(_))
        ));

        // Finally, remove our original htlc with success and assert that our local balance is accordingly updated
        // (the amount stays deducted because it was paid out).
        assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
        channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);
        assert_channel_balances!(channel_state, local_balance - htlc_1.amount_msat, 0, 0);
    }
1245
1246    /// Tests policy checks applied when forwarding a htlc over a channel.
1247    #[test]
1248    fn test_htlc_forward() {
1249        let local_balance = 140_000;
1250        let channel_state = ChannelState::new(create_test_policy(local_balance / 2), local_balance);
1251
1252        // CLTV delta insufficient (one less than required).
1253        assert!(matches!(
1254            channel_state.check_htlc_forward(channel_state.policy.cltv_expiry_delta - 1, 0, 0),
1255            Err(ForwardingError::InsufficientCltvDelta(_, _))
1256        ));
1257
1258        // Test insufficient fee.
1259        let htlc_amount = 1000;
1260        let htlc_fee = channel_state.policy.base_fee
1261            + (channel_state.policy.fee_rate_prop * htlc_amount) / 1e6 as u64;
1262
1263        assert!(matches!(
1264            channel_state.check_htlc_forward(
1265                channel_state.policy.cltv_expiry_delta,
1266                htlc_amount,
1267                htlc_fee - 1
1268            ),
1269            Err(ForwardingError::InsufficientFee(_, _, _, _))
1270        ));
1271
1272        // Test exact and over-estimation of required policy.
1273        assert!(channel_state
1274            .check_htlc_forward(
1275                channel_state.policy.cltv_expiry_delta,
1276                htlc_amount,
1277                htlc_fee,
1278            )
1279            .is_ok());
1280
1281        assert!(channel_state
1282            .check_htlc_forward(
1283                channel_state.policy.cltv_expiry_delta * 2,
1284                htlc_amount,
1285                htlc_fee * 3
1286            )
1287            .is_ok());
1288    }
1289
    /// Tests the checks applied when adding an outgoing htlc to local state: maximum and minimum htlc size, the
    /// in-flight amount limit, the in-flight count limit and insufficient local balance.
    ///
    /// With `create_test_policy(local_balance / 2)` the channel has a 50_000 msat in-flight limit, a 25_000 msat
    /// maximum htlc size, a 2 msat minimum htlc size and a limit of 10 in-flight htlcs.
    #[test]
    fn test_check_outgoing_addition() {
        // Create test channel with low local liquidity so that we run into failures.
        let local_balance = 100_000;
        let mut channel_state =
            ChannelState::new(create_test_policy(local_balance / 2), local_balance);

        let mut htlc = Htlc {
            amount_msat: channel_state.policy.max_htlc_size_msat + 1,
            cltv_expiry: channel_state.policy.cltv_expiry_delta,
        };
        // HTLC maximum size exceeded.
        assert!(matches!(
            channel_state.check_outgoing_addition(&htlc),
            Err(ForwardingError::MoreThanMaximum(_, _))
        ));

        // Beneath HTLC minimum size.
        htlc.amount_msat = channel_state.policy.min_htlc_size_msat - 1;
        assert!(matches!(
            channel_state.check_outgoing_addition(&htlc),
            Err(ForwardingError::LessThanMinimum(_, _))
        ));

        // Add two large htlcs so that we will start to run into our in-flight total amount limit.
        let hash_1 = PaymentHash([1; 32]);
        let htlc_1 = Htlc {
            amount_msat: channel_state.policy.max_in_flight_msat / 2,
            cltv_expiry: channel_state.policy.cltv_expiry_delta,
        };

        assert!(channel_state.check_outgoing_addition(&htlc_1).is_ok());
        assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());

        let hash_2 = PaymentHash([2; 32]);
        let htlc_2 = Htlc {
            amount_msat: channel_state.policy.max_in_flight_msat / 2,
            cltv_expiry: channel_state.policy.cltv_expiry_delta,
        };

        assert!(channel_state.check_outgoing_addition(&htlc_2).is_ok());
        assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());

        // Now, assert that we can't add even our smallest htlc size, because we've hit our in-flight amount limit.
        htlc.amount_msat = channel_state.policy.min_htlc_size_msat;
        assert!(matches!(
            channel_state.check_outgoing_addition(&htlc),
            Err(ForwardingError::ExceedsInFlightTotal(_, _))
        ));

        // Resolve both of the htlcs successfully so that the local liquidity is no longer available.
        assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
        channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);

        assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
        channel_state.settle_outgoing_htlc(htlc_2.amount_msat, true);

        // Now we're going to add many htlcs so that we hit our in-flight count limit (unique payment hash per htlc).
        for i in 0..channel_state.policy.max_htlc_count {
            let hash = PaymentHash([i.try_into().unwrap(); 32]);
            assert!(channel_state.check_outgoing_addition(&htlc).is_ok());
            assert!(channel_state.add_outgoing_htlc(hash, htlc).is_ok());
        }

        // Try to add one more htlc and we should be rejected.
        let htlc_3 = Htlc {
            amount_msat: channel_state.policy.min_htlc_size_msat,
            cltv_expiry: channel_state.policy.cltv_expiry_delta,
        };

        assert!(matches!(
            channel_state.check_outgoing_addition(&htlc_3),
            Err(ForwardingError::ExceedsInFlightCount(_, _))
        ));

        // Resolve all in-flight htlcs.
        for i in 0..channel_state.policy.max_htlc_count {
            let hash = PaymentHash([i.try_into().unwrap(); 32]);
            assert!(channel_state.remove_outgoing_htlc(&hash).is_ok());
            channel_state.settle_outgoing_htlc(htlc.amount_msat, true)
        }

        // Add and settle another htlc to move more liquidity away from our local balance. Reusing hash [1; 32] is
        // fine because the htlc that used it has already been removed.
        let hash_4 = PaymentHash([1; 32]);
        let htlc_4 = Htlc {
            amount_msat: channel_state.policy.max_htlc_size_msat,
            cltv_expiry: channel_state.policy.cltv_expiry_delta,
        };
        assert!(channel_state.check_outgoing_addition(&htlc_4).is_ok());
        assert!(channel_state.add_outgoing_htlc(hash_4, htlc_4).is_ok());
        assert!(channel_state.remove_outgoing_htlc(&hash_4).is_ok());
        channel_state.settle_outgoing_htlc(htlc_4.amount_msat, true);

        // Finally, assert that we don't have enough balance to forward our largest possible htlc (because of all the
        // htlcs that we've settled) and assert that we fail to add a large htlc. The balance assertion here is just a
        // sanity check for the test, which will fail if we change the amounts settled/failed in the test.
        assert!(channel_state.local_balance_msat < channel_state.policy.max_htlc_size_msat);
        assert!(matches!(
            channel_state.check_outgoing_addition(&htlc_4),
            Err(ForwardingError::InsufficientBalance(_, _))
        ));
    }
1393
    /// Tests basic functionality of a `SimulatedChannel` but does not attempt to test the underlying
    /// `ChannelState`, as this is covered elsewhere in our tests.
    #[test]
    fn test_simulated_channel() {
        // Create a test channel with all balance available to node 1 as local liquidity, and none for node_2 to begin
        // with.
        let capacity_msat = 500_000_000;
        let node_1 = ChannelState::new(create_test_policy(capacity_msat / 2), capacity_msat);
        let node_2 = ChannelState::new(create_test_policy(capacity_msat / 2), 0);

        let mut simulated_channel = SimulatedChannel {
            capacity_msat,
            short_channel_id: ShortChannelID::from(123),
            node_1: node_1.clone(),
            node_2: node_2.clone(),
        };

        // Assert that we're not able to send a htlc over node_2 -> node_1 (no liquidity).
        let hash_1 = PaymentHash([1; 32]);
        let htlc_1 = Htlc {
            amount_msat: node_2.policy.min_htlc_size_msat,
            cltv_expiry: node_1.policy.cltv_expiry_delta,
        };

        assert!(matches!(
            simulated_channel.add_htlc(&node_2.policy.pubkey, hash_1, htlc_1),
            Err(ForwardingError::InsufficientBalance(_, _))
        ));

        // Assert that we can send a htlc over node_1 -> node_2. The same hash as above is reused; the earlier htlc
        // was rejected, so nothing is in flight under it.
        let hash_2 = PaymentHash([1; 32]);
        let htlc_2 = Htlc {
            amount_msat: node_1.policy.max_htlc_size_msat,
            cltv_expiry: node_2.policy.cltv_expiry_delta,
        };
        assert!(simulated_channel
            .add_htlc(&node_1.policy.pubkey, hash_2, htlc_2)
            .is_ok());

        // Settle the htlc and then assert that we can send from node_2 -> node_1 because the balance has been shifted
        // across channels.
        assert!(simulated_channel
            .remove_htlc(&node_1.policy.pubkey, &hash_2, true)
            .is_ok());

        assert!(simulated_channel
            .add_htlc(&node_2.policy.pubkey, hash_2, htlc_2)
            .is_ok());

        // Finally, try to add/remove htlcs for a pubkey that is not participating in the channel and assert that we
        // fail.
        let (_, pk) = get_random_keypair();
        assert!(matches!(
            simulated_channel.add_htlc(&pk, hash_2, htlc_2),
            Err(ForwardingError::NodeNotFound(_))
        ));

        assert!(matches!(
            simulated_channel.remove_htlc(&pk, &hash_2, true),
            Err(ForwardingError::NodeNotFound(_))
        ));
    }
1456
    // Mocked `SimNetwork` implementation (generated by mockall as `MockNetwork`), used to test `SimNode` in isolation
    // from any real payment propagation. Tests prime its behavior with `expect_*` calls.
    mock! {
        Network{}

        #[async_trait]
        impl SimNetwork for Network{
            // Primed in tests via `expect_dispatch_payment`; the mocked implementation reports results on `sender`.
            fn dispatch_payment(
                &mut self,
                source: PublicKey,
                route: Route,
                payment_hash: PaymentHash,
                sender: Sender<Result<PaymentResult, LightningError>>,
            );

            // Primed in tests via `expect_lookup_node`.
            async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
        }
    }
1473
1474    /// Tests the functionality of a `SimNode`, mocking out the `SimNetwork` that is responsible for payment
1475    /// propagation to isolate testing to just the implementation of `LightningNode`.
1476    #[tokio::test]
1477    async fn test_simulated_node() {
1478        // Mock out our network and create a routing graph with 5 hops.
1479        let mock = MockNetwork::new();
1480        let sim_network = Arc::new(Mutex::new(mock));
1481        let channels = create_simulated_channels(5, 300000000);
1482        let graph = populate_network_graph(channels.clone()).unwrap();
1483
1484        // Create a simulated node for the first channel in our network.
1485        let pk = channels[0].node_1.policy.pubkey;
1486        let mut node = SimNode::new(pk, sim_network.clone(), Arc::new(graph));
1487
1488        // Prime mock to return node info from lookup and assert that we get the pubkey we're expecting.
1489        let lookup_pk = channels[3].node_1.policy.pubkey;
1490        sim_network
1491            .lock()
1492            .await
1493            .expect_lookup_node()
1494            .returning(move |_| Ok((node_info(lookup_pk), vec![1, 2, 3])));
1495
1496        // Assert that we get three channels from the mock.
1497        let node_info = node.get_node_info(&lookup_pk).await.unwrap();
1498        assert_eq!(lookup_pk, node_info.pubkey);
1499        assert_eq!(node.list_channels().await.unwrap().len(), 3);
1500
1501        // Next, we're going to test handling of in-flight payments. To do this, we'll mock out calls to our dispatch
1502        // function to send different results depending on the destination.
1503        let dest_1 = channels[2].node_1.policy.pubkey;
1504        let dest_2 = channels[4].node_1.policy.pubkey;
1505
1506        sim_network
1507            .lock()
1508            .await
1509            .expect_dispatch_payment()
1510            .returning(
1511                move |_, route: Route, _, sender: Sender<Result<PaymentResult, LightningError>>| {
1512                    // If we've reached dispatch, we must have at least one path, grab the last hop to match the
1513                    // receiver.
1514                    let receiver = route.paths[0].hops.last().unwrap().pubkey;
1515                    let result = if receiver == dest_1 {
1516                        PaymentResult {
1517                            htlc_count: 2,
1518                            payment_outcome: PaymentOutcome::Success,
1519                        }
1520                    } else if receiver == dest_2 {
1521                        PaymentResult {
1522                            htlc_count: 0,
1523                            payment_outcome: PaymentOutcome::InsufficientBalance,
1524                        }
1525                    } else {
1526                        panic!("unknown mocked receiver");
1527                    };
1528
1529                    sender.send(Ok(result)).unwrap();
1530                },
1531            );
1532
1533        // Dispatch payments to different destinations and assert that our track payment results are as expected.
1534        let hash_1 = node.send_payment(dest_1, 10_000).await.unwrap();
1535        let hash_2 = node.send_payment(dest_2, 15_000).await.unwrap();
1536
1537        let (_, shutdown_listener) = triggered::trigger();
1538
1539        let result_1 = node
1540            .track_payment(&hash_1, shutdown_listener.clone())
1541            .await
1542            .unwrap();
1543        assert!(matches!(result_1.payment_outcome, PaymentOutcome::Success));
1544
1545        let result_2 = node
1546            .track_payment(&hash_2, shutdown_listener.clone())
1547            .await
1548            .unwrap();
1549        assert!(matches!(
1550            result_2.payment_outcome,
1551            PaymentOutcome::InsufficientBalance
1552        ));
1553    }
1554
    /// Contains elements required to test dispatch_payment functionality.
    struct DispatchPaymentTestKit<'a> {
        /// Simulated graph that test payments are dispatched through and whose channel balances are inspected.
        graph: SimGraph,
        /// Node pubkeys in the order they appear along the chain of test channels.
        nodes: Vec<PublicKey>,
        /// Routing graph built from the same channels as `graph`, used to find routes for test payments.
        routing_graph: NetworkGraph<&'a WrappedLog>,
        /// Trigger handed to `graph` on creation; tests fire it before waiting for `graph` to shut down.
        shutdown: triggered::Trigger,
    }
1562
1563    impl<'a> DispatchPaymentTestKit<'a> {
1564        /// Creates a test graph with a set of nodes connected by three channels, with all the capacity of the channel
1565        /// on the side of the first node. For example, if called with capacity = 100 it will set up the following
1566        /// network:
1567        /// Alice (100) --- (0) Bob (100) --- (0) Carol (100) --- (0) Dave
1568        ///
1569        /// The nodes pubkeys in this chain of channels are provided in-order for easy access.
1570        async fn new(capacity: u64) -> Self {
1571            let (shutdown, _listener) = triggered::trigger();
1572            let channels = create_simulated_channels(3, capacity);
1573
1574            // Collect pubkeys in-order, pushing the last node on separately because they don't have an outgoing
1575            // channel (they are not node_1 in any channel, only node_2).
1576            let mut nodes = channels
1577                .iter()
1578                .map(|c| c.node_1.policy.pubkey)
1579                .collect::<Vec<PublicKey>>();
1580            nodes.push(channels.last().unwrap().node_2.policy.pubkey);
1581
1582            let kit = DispatchPaymentTestKit {
1583                graph: SimGraph::new(channels.clone(), shutdown.clone())
1584                    .expect("could not create test graph"),
1585                nodes,
1586                routing_graph: populate_network_graph(channels).unwrap(),
1587                shutdown,
1588            };
1589
1590            // Assert that our channel balance is all on the side of the channel opener when we start up.
1591            assert_eq!(
1592                kit.channel_balances().await,
1593                vec![(capacity, 0), (capacity, 0), (capacity, 0)]
1594            );
1595
1596            kit
1597        }
1598
1599        /// Returns a vector of local/remote channel balances for channels in the network.
1600        async fn channel_balances(&self) -> Vec<(u64, u64)> {
1601            let mut balances = vec![];
1602
1603            // We can't iterate through our hashmap of channels in-order, so we take advantage of our short channel id
1604            // being the index in our chain of channels. This allows us to look up channels in-order.
1605            let chan_count = self.graph.channels.lock().await.len();
1606
1607            for i in 0..chan_count {
1608                let chan_lock = self.graph.channels.lock().await;
1609                let channel = chan_lock.get(&ShortChannelID::from(i as u64)).unwrap();
1610
1611                // Take advantage of our test setup, which always makes node_1 the channel initiator to get our
1612                // "in order" balances for the chain of channels.
1613                balances.push((
1614                    channel.node_1.local_balance_msat,
1615                    channel.node_2.local_balance_msat,
1616                ));
1617            }
1618
1619            balances
1620        }
1621
1622        // Sends a test payment from source to destination and waits for the payment to complete, returning the route
1623        // used.
1624        async fn send_test_payemnt(
1625            &mut self,
1626            source: PublicKey,
1627            dest: PublicKey,
1628            amt: u64,
1629        ) -> Route {
1630            let route = find_payment_route(&source, dest, amt, &self.routing_graph).unwrap();
1631
1632            let (sender, receiver) = oneshot::channel();
1633            self.graph
1634                .dispatch_payment(source, route.clone(), PaymentHash([1; 32]), sender);
1635
1636            // Assert that we receive from the channel or fail.
1637            assert!(timeout(Duration::from_millis(10), receiver).await.is_ok());
1638
1639            route
1640        }
1641
1642        // Sets the balance on the channel to the tuple provided, used to arrange liquidity for testing.
1643        async fn set_channel_balance(&mut self, scid: &ShortChannelID, balance: (u64, u64)) {
1644            let mut channels_lock = self.graph.channels.lock().await;
1645            let channel = channels_lock.get_mut(scid).unwrap();
1646
1647            channel.node_1.local_balance_msat = balance.0;
1648            channel.node_2.local_balance_msat = balance.1;
1649
1650            assert!(channel.sanity_check().is_ok());
1651        }
1652    }
1653
1654    /// Tests dispatch of a successfully settled payment across a test network of simulated channels:
1655    /// Alice --- Bob --- Carol --- Dave
1656    #[tokio::test]
1657    async fn test_successful_dispatch() {
1658        let chan_capacity = 500_000_000;
1659        let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1660
1661        // Send a payment that should succeed from Alice -> Dave.
1662        let mut amt = 20_000;
1663        let route = test_kit
1664            .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1665            .await;
1666
1667        let route_total = amt + route.get_total_fees();
1668        let hop_1_amt = amt + route.paths[0].hops[1].fee_msat;
1669
1670        // The sending node should have pushed the amount + total fee to the intermediary.
1671        let alice_to_bob = (chan_capacity - route_total, route_total);
1672        // The middle hop should include fees for the outgoing link.
1673        let mut bob_to_carol = (chan_capacity - hop_1_amt, hop_1_amt);
1674        // The receiving node should have the payment amount pushed to them.
1675        let carol_to_dave = (chan_capacity - amt, amt);
1676
1677        let mut expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1678        assert_eq!(test_kit.channel_balances().await, expected_balances);
1679
1680        // Next, we'll test the case where a payment fails on the first hop. This is an edge case in our state
1681        // machine, so we want to specifically hit it. To do this, we'll try to send double the amount that we just
1682        // pushed to Dave back to Bob, expecting a failure on Dave's outgoing link due to insufficient liquidity.
1683        let _ = test_kit
1684            .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[1], amt * 2)
1685            .await;
1686        assert_eq!(test_kit.channel_balances().await, expected_balances);
1687
1688        // Now, test a successful single-hop payment from Bob -> Carol. We'll do this twice, so that we can drain all
1689        // the liquidity on Bob's side (to prepare for a multi-hop failure test). Our pathfinding only allows us to
1690        // use 50% of the channel's capacity, so we need to do two payments.
1691        amt = bob_to_carol.0 / 2;
1692        let _ = test_kit
1693            .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
1694            .await;
1695
1696        bob_to_carol = (bob_to_carol.0 / 2, bob_to_carol.1 + amt);
1697        expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1698        assert_eq!(test_kit.channel_balances().await, expected_balances);
1699
1700        // When we push this amount a second time, all the liquidity should be moved to Carol's end.
1701        let _ = test_kit
1702            .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
1703            .await;
1704        bob_to_carol = (0, chan_capacity);
1705        expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1706        assert_eq!(test_kit.channel_balances().await, expected_balances);
1707
1708        // Finally, we'll test a multi-hop failure by trying to send from Alice -> Dave. Since Bob's liquidity is
1709        // drained, we expect a failure and unchanged balances along the route.
1710        let _ = test_kit
1711            .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], 20_000)
1712            .await;
1713        assert_eq!(test_kit.channel_balances().await, expected_balances);
1714
1715        test_kit.shutdown.trigger();
1716        test_kit.graph.wait_for_shutdown().await;
1717    }
1718
1719    /// Tests successful dispatch of a multi-hop payment.
1720    #[tokio::test]
1721    async fn test_successful_multi_hop() {
1722        let chan_capacity = 500_000_000;
1723        let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1724
1725        // Send a payment that should succeed from Alice -> Dave.
1726        let amt = 20_000;
1727        let route = test_kit
1728            .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1729            .await;
1730
1731        let route_total = amt + route.get_total_fees();
1732        let hop_1_amt = amt + route.paths[0].hops[1].fee_msat;
1733
1734        let expected_balances = vec![
1735            // The sending node should have pushed the amount + total fee to the intermediary.
1736            (chan_capacity - route_total, route_total),
1737            // The middle hop should include fees for the outgoing link.
1738            (chan_capacity - hop_1_amt, hop_1_amt),
1739            // The receiving node should have the payment amount pushed to them.
1740            (chan_capacity - amt, amt),
1741        ];
1742        assert_eq!(test_kit.channel_balances().await, expected_balances);
1743
1744        test_kit.shutdown.trigger();
1745        test_kit.graph.wait_for_shutdown().await;
1746    }
1747
1748    /// Tests success and failure for single hop payments, which are an edge case in our state machine.
1749    #[tokio::test]
1750    async fn test_single_hop_payments() {
1751        let chan_capacity = 500_000_000;
1752        let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1753
1754        // Send a single hop payment from Alice -> Bob, it will succeed because Alice has all the liquidity.
1755        let amt = 150_000;
1756        let _ = test_kit
1757            .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[1], amt)
1758            .await;
1759
1760        let expected_balances = vec![
1761            (chan_capacity - amt, amt),
1762            (chan_capacity, 0),
1763            (chan_capacity, 0),
1764        ];
1765        assert_eq!(test_kit.channel_balances().await, expected_balances);
1766
1767        // Send a single hop payment from Dave -> Carol that will fail due to lack of liquidity, balances should be
1768        // unchanged.
1769        let _ = test_kit
1770            .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[2], amt)
1771            .await;
1772
1773        assert_eq!(test_kit.channel_balances().await, expected_balances);
1774
1775        test_kit.shutdown.trigger();
1776        test_kit.graph.wait_for_shutdown().await;
1777    }
1778
1779    /// Tests failing back of multi-hop payments at various failure indexes.
1780    #[tokio::test]
1781    async fn test_multi_hop_faiulre() {
1782        let chan_capacity = 500_000_000;
1783        let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1784
1785        // Drain liquidity between Bob and Carol to force failures on Bob's outgoing linke.
1786        test_kit
1787            .set_channel_balance(&ShortChannelID::from(1), (0, chan_capacity))
1788            .await;
1789
1790        let mut expected_balances =
1791            vec![(chan_capacity, 0), (0, chan_capacity), (chan_capacity, 0)];
1792        assert_eq!(test_kit.channel_balances().await, expected_balances);
1793
1794        // Send a payment from Alice -> Dave which we expect to fail leaving balances unaffected.
1795        let amt = 150_000;
1796        let _ = test_kit
1797            .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1798            .await;
1799
1800        assert_eq!(test_kit.channel_balances().await, expected_balances);
1801
1802        // Push liquidity to Dave so that we can send a payment which will fail on Bob's outgoing link, leaving
1803        // balances unaffected.
1804        expected_balances[2] = (0, chan_capacity);
1805        test_kit
1806            .set_channel_balance(&ShortChannelID::from(2), (0, chan_capacity))
1807            .await;
1808
1809        let _ = test_kit
1810            .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[0], amt)
1811            .await;
1812
1813        assert_eq!(test_kit.channel_balances().await, expected_balances);
1814
1815        test_kit.shutdown.trigger();
1816        test_kit.graph.wait_for_shutdown().await;
1817    }
1818}