1use crate::{
2 LightningError, LightningNode, NodeInfo, PaymentOutcome, PaymentResult, SimulationError,
3};
4use async_trait::async_trait;
5use bitcoin::constants::ChainHash;
6use bitcoin::hashes::{sha256::Hash as Sha256, Hash};
7use bitcoin::secp256k1::PublicKey;
8use bitcoin::{Network, ScriptBuf, TxOut};
9use lightning::ln::chan_utils::make_funding_redeemscript;
10use std::collections::{hash_map::Entry, HashMap};
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use lightning::ln::features::{ChannelFeatures, NodeFeatures};
15use lightning::ln::msgs::{
16 LightningError as LdkError, UnsignedChannelAnnouncement, UnsignedChannelUpdate,
17};
18use lightning::ln::{PaymentHash, PaymentPreimage};
19use lightning::routing::gossip::{NetworkGraph, NodeId};
20use lightning::routing::router::{find_route, Path, PaymentParameters, Route, RouteParameters};
21use lightning::routing::scoring::ProbabilisticScorer;
22use lightning::routing::utxo::{UtxoLookup, UtxoResult};
23use lightning::util::logger::{Level, Logger, Record};
24use thiserror::Error;
25use tokio::select;
26use tokio::sync::oneshot::{channel, Receiver, Sender};
27use tokio::sync::Mutex;
28use tokio::task::JoinSet;
29use triggered::{Listener, Trigger};
30
31use crate::ShortChannelID;
32
33#[derive(Debug, Error)]
37pub enum ForwardingError {
38 #[error("ZeroAmountHtlc")]
40 ZeroAmountHtlc,
41 #[error("ChannelNotFound: {0}")]
43 ChannelNotFound(ShortChannelID),
44 #[error("NodeNotFound: {0:?}")]
46 NodeNotFound(PublicKey),
47 #[error("PaymentHashExists: {0:?}")]
50 PaymentHashExists(PaymentHash),
51 #[error("PaymentHashNotFound: {0:?}")]
53 PaymentHashNotFound(PaymentHash),
54 #[error("InsufficientBalance: amount: {0} > balance: {1}")]
56 InsufficientBalance(u64, u64),
57 #[error("LessThanMinimum: amount: {0} < minimum: {1}")]
59 LessThanMinimum(u64, u64),
60 #[error("MoreThanMaximum: amount: {0} > maximum: {1}")]
62 MoreThanMaximum(u64, u64),
63 #[error("ExceedsInFlightCount: total in flight: {0} > maximum count: {1}")]
65 ExceedsInFlightCount(u64, u64),
66 #[error("ExceedsInFlightTotal: total in flight amount: {0} > maximum amount: {0}")]
69 ExceedsInFlightTotal(u64, u64),
70 #[error("ExpiryInSeconds: cltv expressed in seconds: {0}")]
72 ExpiryInSeconds(u32, u32),
73 #[error("InsufficientCltvDelta: cltv delta: {0} < required: {1}")]
75 InsufficientCltvDelta(u32, u32),
76 #[error("InsufficientFee: offered fee: {0} (base: {1}, prop: {2}) < expected: {3}")]
78 InsufficientFee(u64, u64, u64, u64),
79 #[error("FeeOverflow: htlc amount: {0} (base: {1}, prop: {2})")]
81 FeeOverflow(u64, u64, u64),
82 #[error("SanityCheckFailed: node balance: {0} != capacity: {1}")]
84 SanityCheckFailed(u64, u64),
85}
86
87impl ForwardingError {
88 fn is_critical(&self) -> bool {
90 matches!(
91 self,
92 ForwardingError::ZeroAmountHtlc
93 | ForwardingError::ChannelNotFound(_)
94 | ForwardingError::NodeNotFound(_)
95 | ForwardingError::PaymentHashExists(_)
96 | ForwardingError::PaymentHashNotFound(_)
97 | ForwardingError::SanityCheckFailed(_, _)
98 | ForwardingError::FeeOverflow(_, _, _)
99 )
100 }
101}
102
/// A single HTLC held in flight on one side of a simulated channel.
#[derive(Clone, Copy)]
struct Htlc {
    /// Amount locked in the HTLC, in millisatoshis.
    amount_msat: u64,
    /// Absolute CLTV expiry of the HTLC.
    cltv_expiry: u32,
}
109
110#[derive(Clone)]
114pub struct ChannelPolicy {
115 pub pubkey: PublicKey,
116 pub max_htlc_count: u64,
117 pub max_in_flight_msat: u64,
118 pub min_htlc_size_msat: u64,
119 pub max_htlc_size_msat: u64,
120 pub cltv_expiry_delta: u32,
121 pub base_fee: u64,
122 pub fee_rate_prop: u64,
123}
124
125impl ChannelPolicy {
126 fn validate(&self, capacity_msat: u64) -> Result<(), SimulationError> {
128 if self.max_in_flight_msat > capacity_msat {
129 return Err(SimulationError::SimulatedNetworkError(format!(
130 "max_in_flight_msat {} > capacity {}",
131 self.max_in_flight_msat, capacity_msat
132 )));
133 }
134 if self.max_htlc_size_msat > capacity_msat {
135 return Err(SimulationError::SimulatedNetworkError(format!(
136 "max_htlc_size_msat {} > capacity {}",
137 self.max_htlc_size_msat, capacity_msat
138 )));
139 }
140 Ok(())
141 }
142}
143
/// Returns early with `Err(ForwardingError::$variant(lhs, rhs, extra...))` when `lhs $cmp rhs`
/// holds.
///
/// Any trailing expressions are appended to the variant's fields after `lhs` and `rhs`. Note
/// that `lhs` and `rhs` are expanded twice (once in the comparison, once in the error), so
/// they should be side-effect free.
macro_rules! fail_forwarding_inequality {
    ($lhs:expr, $cmp:tt, $rhs:expr, $variant:ident $(, $extra:expr)*) => {
        if $lhs $cmp $rhs {
            return Err(ForwardingError::$variant($lhs, $rhs $(, $extra)*));
        }
    };
}
158
159#[derive(Clone)]
163struct ChannelState {
164 local_balance_msat: u64,
165 in_flight: HashMap<PaymentHash, Htlc>,
166 policy: ChannelPolicy,
167}
168
169impl ChannelState {
170 fn new(policy: ChannelPolicy, local_balance_msat: u64) -> Self {
174 ChannelState {
175 local_balance_msat,
176 in_flight: HashMap::new(),
177 policy,
178 }
179 }
180
181 fn in_flight_total(&self) -> u64 {
183 self.in_flight.values().map(|h| h.amount_msat).sum()
184 }
185
186 fn check_htlc_forward(
189 &self,
190 cltv_delta: u32,
191 amt: u64,
192 fee: u64,
193 ) -> Result<(), ForwardingError> {
194 fail_forwarding_inequality!(cltv_delta, <, self.policy.cltv_expiry_delta, InsufficientCltvDelta);
195
196 let expected_fee = amt
197 .checked_mul(self.policy.fee_rate_prop)
198 .and_then(|prop_fee| (prop_fee / 1000000).checked_add(self.policy.base_fee))
199 .ok_or(ForwardingError::FeeOverflow(
200 amt,
201 self.policy.base_fee,
202 self.policy.fee_rate_prop,
203 ))?;
204
205 fail_forwarding_inequality!(
206 fee, <, expected_fee, InsufficientFee, self.policy.base_fee, self.policy.fee_rate_prop
207 );
208
209 Ok(())
210 }
211
212 fn check_outgoing_addition(&self, htlc: &Htlc) -> Result<(), ForwardingError> {
217 fail_forwarding_inequality!(htlc.amount_msat, >, self.policy.max_htlc_size_msat, MoreThanMaximum);
218 fail_forwarding_inequality!(htlc.amount_msat, <, self.policy.min_htlc_size_msat, LessThanMinimum);
219 fail_forwarding_inequality!(
220 self.in_flight.len() as u64 + 1, >, self.policy.max_htlc_count, ExceedsInFlightCount
221 );
222 fail_forwarding_inequality!(
223 self.in_flight_total() + htlc.amount_msat, >, self.policy.max_in_flight_msat, ExceedsInFlightTotal
224 );
225 fail_forwarding_inequality!(htlc.amount_msat, >, self.local_balance_msat, InsufficientBalance);
226 fail_forwarding_inequality!(htlc.cltv_expiry, >, 500000000, ExpiryInSeconds);
227
228 Ok(())
229 }
230
231 fn add_outgoing_htlc(&mut self, hash: PaymentHash, htlc: Htlc) -> Result<(), ForwardingError> {
238 self.check_outgoing_addition(&htlc)?;
239 if self.in_flight.get(&hash).is_some() {
240 return Err(ForwardingError::PaymentHashExists(hash));
241 }
242 self.local_balance_msat -= htlc.amount_msat;
243 self.in_flight.insert(hash, htlc);
244 Ok(())
245 }
246
247 fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result<Htlc, ForwardingError> {
249 self.in_flight
250 .remove(hash)
251 .ok_or(ForwardingError::PaymentHashNotFound(*hash))
252 }
253
254 fn settle_outgoing_htlc(&mut self, amt: u64, success: bool) {
258 if !success {
259 self.local_balance_msat += amt
260 }
261 }
262
263 fn settle_incoming_htlc(&mut self, amt: u64, success: bool) {
267 if success {
268 self.local_balance_msat += amt
269 }
270 }
271}
272
273#[derive(Clone)]
289pub struct SimulatedChannel {
290 capacity_msat: u64,
291 short_channel_id: ShortChannelID,
292 node_1: ChannelState,
293 node_2: ChannelState,
294}
295
296impl SimulatedChannel {
297 pub fn new(
300 capacity_msat: u64,
301 short_channel_id: ShortChannelID,
302 node_1: ChannelPolicy,
303 node_2: ChannelPolicy,
304 ) -> Self {
305 SimulatedChannel {
306 capacity_msat,
307 short_channel_id,
308 node_1: ChannelState::new(node_1, capacity_msat / 2),
309 node_2: ChannelState::new(node_2, capacity_msat / 2),
310 }
311 }
312
313 fn validate(&self) -> Result<(), SimulationError> {
315 if self.node_1.policy.pubkey == self.node_2.policy.pubkey {
316 return Err(SimulationError::SimulatedNetworkError(format!(
317 "Channel should have distinct node pubkeys, got: {} for both nodes.",
318 self.node_1.policy.pubkey
319 )));
320 }
321
322 self.node_1.policy.validate(self.capacity_msat)?;
323 self.node_2.policy.validate(self.capacity_msat)?;
324
325 Ok(())
326 }
327
328 fn get_node_mut(&mut self, pubkey: &PublicKey) -> Result<&mut ChannelState, ForwardingError> {
329 if pubkey == &self.node_1.policy.pubkey {
330 Ok(&mut self.node_1)
331 } else if pubkey == &self.node_2.policy.pubkey {
332 Ok(&mut self.node_2)
333 } else {
334 Err(ForwardingError::NodeNotFound(*pubkey))
335 }
336 }
337
338 fn get_node(&self, pubkey: &PublicKey) -> Result<&ChannelState, ForwardingError> {
339 if pubkey == &self.node_1.policy.pubkey {
340 Ok(&self.node_1)
341 } else if pubkey == &self.node_2.policy.pubkey {
342 Ok(&self.node_2)
343 } else {
344 Err(ForwardingError::NodeNotFound(*pubkey))
345 }
346 }
347
348 fn add_htlc(
352 &mut self,
353 sending_node: &PublicKey,
354 hash: PaymentHash,
355 htlc: Htlc,
356 ) -> Result<(), ForwardingError> {
357 if htlc.amount_msat == 0 {
358 return Err(ForwardingError::ZeroAmountHtlc);
359 }
360
361 self.get_node_mut(sending_node)?
362 .add_outgoing_htlc(hash, htlc)?;
363 self.sanity_check()
364 }
365
366 fn sanity_check(&self) -> Result<(), ForwardingError> {
369 let node_1_total = self.node_1.local_balance_msat + self.node_1.in_flight_total();
370 let node_2_total = self.node_2.local_balance_msat + self.node_2.in_flight_total();
371
372 fail_forwarding_inequality!(node_1_total + node_2_total, !=, self.capacity_msat, SanityCheckFailed);
373
374 Ok(())
375 }
376
377 fn remove_htlc(
382 &mut self,
383 sending_node: &PublicKey,
384 hash: &PaymentHash,
385 success: bool,
386 ) -> Result<(), ForwardingError> {
387 let htlc = self
388 .get_node_mut(sending_node)?
389 .remove_outgoing_htlc(hash)?;
390 self.settle_htlc(sending_node, htlc.amount_msat, success)?;
391 self.sanity_check()
392 }
393
394 fn settle_htlc(
398 &mut self,
399 sending_node: &PublicKey,
400 amount_msat: u64,
401 success: bool,
402 ) -> Result<(), ForwardingError> {
403 if sending_node == &self.node_1.policy.pubkey {
404 self.node_1.settle_outgoing_htlc(amount_msat, success);
405 self.node_2.settle_incoming_htlc(amount_msat, success);
406 Ok(())
407 } else if sending_node == &self.node_2.policy.pubkey {
408 self.node_2.settle_outgoing_htlc(amount_msat, success);
409 self.node_1.settle_incoming_htlc(amount_msat, success);
410 Ok(())
411 } else {
412 Err(ForwardingError::NodeNotFound(*sending_node))
413 }
414 }
415
416 fn check_htlc_forward(
418 &self,
419 forwarding_node: &PublicKey,
420 cltv_delta: u32,
421 amount_msat: u64,
422 fee_msat: u64,
423 ) -> Result<(), ForwardingError> {
424 self.get_node(forwarding_node)?
425 .check_htlc_forward(cltv_delta, amount_msat, fee_msat)
426 }
427}
428
429#[async_trait]
432trait SimNetwork: Send + Sync {
433 fn dispatch_payment(
436 &mut self,
437 source: PublicKey,
438 route: Route,
439 payment_hash: PaymentHash,
440 sender: Sender<Result<PaymentResult, LightningError>>,
441 );
442
443 async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
445}
446
447struct SimNode<'a, T: SimNetwork> {
452 info: NodeInfo,
453 network: Arc<Mutex<T>>,
455 in_flight: HashMap<PaymentHash, Receiver<Result<PaymentResult, LightningError>>>,
457 pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
459}
460
461impl<'a, T: SimNetwork> SimNode<'a, T> {
462 pub fn new(
465 pubkey: PublicKey,
466 payment_network: Arc<Mutex<T>>,
467 pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
468 ) -> Self {
469 SimNode {
470 info: node_info(pubkey),
471 network: payment_network,
472 in_flight: HashMap::new(),
473 pathfinding_graph,
474 }
475 }
476}
477
478fn node_info(pubkey: PublicKey) -> NodeInfo {
480 let mut features = NodeFeatures::empty();
482 features.set_keysend_optional();
483
484 NodeInfo {
485 pubkey,
486 alias: "".to_string(), features,
488 }
489}
490
491fn find_payment_route(
494 source: &PublicKey,
495 dest: PublicKey,
496 amount_msat: u64,
497 pathfinding_graph: &NetworkGraph<&WrappedLog>,
498) -> Result<Route, SimulationError> {
499 let scorer = ProbabilisticScorer::new(Default::default(), pathfinding_graph, &WrappedLog {});
500
501 find_route(
502 source,
503 &RouteParameters {
504 payment_params: PaymentParameters::from_node_id(dest, 0)
505 .with_max_total_cltv_expiry_delta(u32::MAX)
506 .with_max_path_count(1)
508 .with_max_channel_saturation_power_of_half(1),
510 final_value_msat: amount_msat,
511 max_total_routing_fee_msat: None,
512 },
513 pathfinding_graph,
514 None,
515 &WrappedLog {},
516 &scorer,
517 &Default::default(),
518 &[0; 32],
519 )
520 .map_err(|e| SimulationError::SimulatedNetworkError(e.err))
521}
522
523#[async_trait]
524impl<T: SimNetwork> LightningNode for SimNode<'_, T> {
525 fn get_info(&self) -> &NodeInfo {
526 &self.info
527 }
528
529 async fn get_network(&mut self) -> Result<Network, LightningError> {
530 Ok(Network::Regtest)
531 }
532
533 async fn send_payment(
536 &mut self,
537 dest: PublicKey,
538 amount_msat: u64,
539 ) -> Result<PaymentHash, LightningError> {
540 let (sender, receiver) = channel();
543 let preimage = PaymentPreimage(rand::random());
544 let payment_hash = PaymentHash(Sha256::hash(&preimage.0).to_byte_array());
545
546 match self.in_flight.entry(payment_hash) {
548 Entry::Occupied(_) => {
549 return Err(LightningError::SendPaymentError(
550 "payment hash exists".to_string(),
551 ));
552 },
553 Entry::Vacant(vacant) => {
554 vacant.insert(receiver);
555 },
556 }
557
558 let route = match find_payment_route(
559 &self.info.pubkey,
560 dest,
561 amount_msat,
562 &self.pathfinding_graph,
563 ) {
564 Ok(path) => path,
565 Err(e) => {
568 log::trace!("Could not find path for payment: {:?}.", e);
569
570 if let Err(e) = sender.send(Ok(PaymentResult {
571 htlc_count: 0,
572 payment_outcome: PaymentOutcome::RouteNotFound,
573 })) {
574 log::error!("Could not send payment result: {:?}.", e);
575 }
576
577 return Ok(payment_hash);
578 },
579 };
580
581 self.network
583 .lock()
584 .await
585 .dispatch_payment(self.info.pubkey, route, payment_hash, sender);
586
587 Ok(payment_hash)
588 }
589
590 async fn track_payment(
593 &mut self,
594 hash: &PaymentHash,
595 listener: Listener,
596 ) -> Result<PaymentResult, LightningError> {
597 match self.in_flight.remove(hash) {
598 Some(receiver) => {
599 select! {
600 biased;
601 _ = listener => Err(
602 LightningError::TrackPaymentError("shutdown during payment tracking".to_string()),
603 ),
604
605 res = receiver => {
607 res.map_err(|e| LightningError::TrackPaymentError(format!("channel receive err: {}", e)))?
608 },
609 }
610 },
611 None => Err(LightningError::TrackPaymentError(format!(
612 "payment hash {} not found",
613 hex::encode(hash.0),
614 ))),
615 }
616 }
617
618 async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
619 Ok(self.network.lock().await.lookup_node(node_id).await?.0)
620 }
621
622 async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError> {
623 Ok(self
624 .network
625 .lock()
626 .await
627 .lookup_node(&self.info.pubkey)
628 .await?
629 .1)
630 }
631}
632
633pub struct SimGraph {
635 nodes: HashMap<PublicKey, Vec<u64>>,
638
639 channels: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
641
642 tasks: JoinSet<()>,
644
645 shutdown_trigger: Trigger,
647}
648
649impl SimGraph {
650 pub fn new(
652 graph_channels: Vec<SimulatedChannel>,
653 shutdown_trigger: Trigger,
654 ) -> Result<Self, SimulationError> {
655 let mut nodes: HashMap<PublicKey, Vec<u64>> = HashMap::new();
656 let mut channels = HashMap::new();
657
658 for channel in graph_channels.iter() {
659 channel.validate()?;
662 match channels.entry(channel.short_channel_id) {
663 Entry::Occupied(_) => {
664 return Err(SimulationError::SimulatedNetworkError(format!(
665 "Simulated short channel ID should be unique: {} duplicated",
666 channel.short_channel_id
667 )))
668 },
669 Entry::Vacant(v) => v.insert(channel.clone()),
670 };
671
672 for pubkey in [channel.node_1.policy.pubkey, channel.node_2.policy.pubkey] {
674 match nodes.entry(pubkey) {
675 Entry::Occupied(o) => o.into_mut().push(channel.capacity_msat),
676 Entry::Vacant(v) => {
677 v.insert(vec![channel.capacity_msat]);
678 },
679 }
680 }
681 }
682
683 Ok(SimGraph {
684 nodes,
685 channels: Arc::new(Mutex::new(channels)),
686 tasks: JoinSet::new(),
687 shutdown_trigger,
688 })
689 }
690
691 pub async fn wait_for_shutdown(&mut self) {
694 log::debug!("Waiting for simulated graph to shutdown.");
695
696 while let Some(res) = self.tasks.join_next().await {
697 if let Err(e) = res {
698 log::error!("Graph task exited with error: {e}");
699 }
700 }
701
702 log::debug!("Simulated graph shutdown.");
703 }
704}
705
706pub async fn ln_node_from_graph<'a>(
708 graph: Arc<Mutex<SimGraph>>,
709 routing_graph: Arc<NetworkGraph<&'_ WrappedLog>>,
710) -> HashMap<PublicKey, Arc<Mutex<dyn LightningNode + '_>>> {
711 let mut nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>> = HashMap::new();
712
713 for pk in graph.lock().await.nodes.keys() {
714 nodes.insert(
715 *pk,
716 Arc::new(Mutex::new(SimNode::new(
717 *pk,
718 graph.clone(),
719 routing_graph.clone(),
720 ))),
721 );
722 }
723
724 nodes
725}
726
727pub fn populate_network_graph<'a>(
732 channels: Vec<SimulatedChannel>,
733) -> Result<NetworkGraph<&'a WrappedLog>, LdkError> {
734 let graph = NetworkGraph::new(Network::Regtest, &WrappedLog {});
735
736 let chain_hash = ChainHash::using_genesis_block(Network::Regtest);
737
738 for channel in channels {
739 let announcement = UnsignedChannelAnnouncement {
740 features: ChannelFeatures::empty(),
742 chain_hash,
743 short_channel_id: channel.short_channel_id.into(),
744 node_id_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
745 node_id_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
746 bitcoin_key_1: NodeId::from_pubkey(&channel.node_1.policy.pubkey),
749 bitcoin_key_2: NodeId::from_pubkey(&channel.node_2.policy.pubkey),
750 excess_data: Vec::new(),
752 };
753
754 let utxo_validator = UtxoValidator {
755 amount_sat: channel.capacity_msat / 1000,
756 script: make_funding_redeemscript(
757 &channel.node_1.policy.pubkey,
758 &channel.node_2.policy.pubkey,
759 )
760 .to_v0_p2wsh(),
761 };
762
763 graph.update_channel_from_unsigned_announcement(&announcement, &Some(&utxo_validator))?;
764
765 for (i, node) in [channel.node_1, channel.node_2].iter().enumerate() {
768 let update = UnsignedChannelUpdate {
769 chain_hash,
770 short_channel_id: channel.short_channel_id.into(),
771 timestamp: SystemTime::now()
772 .duration_since(UNIX_EPOCH)
773 .unwrap()
774 .as_secs() as u32,
775 flags: i as u8,
776 cltv_expiry_delta: node.policy.cltv_expiry_delta as u16,
777 htlc_minimum_msat: node.policy.min_htlc_size_msat,
778 htlc_maximum_msat: node.policy.max_htlc_size_msat,
779 fee_base_msat: node.policy.base_fee as u32,
780 fee_proportional_millionths: node.policy.fee_rate_prop as u32,
781 excess_data: Vec::new(),
782 };
783 graph.update_channel_unsigned(&update)?;
784 }
785 }
786
787 Ok(graph)
788}
789
790#[async_trait]
791impl SimNetwork for SimGraph {
792 fn dispatch_payment(
797 &mut self,
798 source: PublicKey,
799 route: Route,
800 payment_hash: PaymentHash,
801 sender: Sender<Result<PaymentResult, LightningError>>,
802 ) {
803 let path = match route.paths.first() {
805 Some(p) => p,
806 None => {
807 log::warn!("Find route did not return expected number of paths.");
808
809 if let Err(e) = sender.send(Ok(PaymentResult {
810 htlc_count: 0,
811 payment_outcome: PaymentOutcome::RouteNotFound,
812 })) {
813 log::error!("Could not send payment result: {:?}.", e);
814 }
815
816 return;
817 },
818 };
819
820 self.tasks.spawn(propagate_payment(
821 self.channels.clone(),
822 source,
823 path.clone(),
824 payment_hash,
825 sender,
826 self.shutdown_trigger.clone(),
827 ));
828 }
829
830 async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError> {
832 self.nodes
833 .get(node)
834 .map(|channels| (node_info(*node), channels.clone()))
835 .ok_or(LightningError::GetNodeInfoError(
836 "Node not found".to_string(),
837 ))
838 }
839}
840
841async fn add_htlcs(
859 nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
860 source: PublicKey,
861 route: Path,
862 payment_hash: PaymentHash,
863) -> Result<(), (Option<usize>, ForwardingError)> {
864 let mut outgoing_node = source;
865 let mut outgoing_amount = route.fee_msat() + route.final_value_msat();
866 let mut outgoing_cltv = route.hops.iter().map(|hop| hop.cltv_expiry_delta).sum();
867
868 let mut fail_idx = None;
874
875 for (i, hop) in route.hops.iter().enumerate() {
876 let mut node_lock = nodes.lock().await;
880 let scid = ShortChannelID::from(hop.short_channel_id);
881
882 if let Some(channel) = node_lock.get_mut(&scid) {
883 channel
884 .add_htlc(
885 &outgoing_node,
886 payment_hash,
887 Htlc {
888 amount_msat: outgoing_amount,
889 cltv_expiry: outgoing_cltv,
890 },
891 )
892 .map_err(|e| (fail_idx, e))?;
895
896 fail_idx = Some(i);
899
900 if i != route.hops.len() - 1 {
907 if let Some(channel) =
908 node_lock.get(&ShortChannelID::from(route.hops[i + 1].short_channel_id))
909 {
910 channel
911 .check_htlc_forward(
912 &hop.pubkey,
913 hop.cltv_expiry_delta,
914 outgoing_amount - hop.fee_msat,
915 hop.fee_msat,
916 )
917 .map_err(|e| (fail_idx, e))?;
920 }
921 }
922 } else {
923 return Err((fail_idx, ForwardingError::ChannelNotFound(scid)));
924 }
925
926 outgoing_node = hop.pubkey;
928 outgoing_amount -= hop.fee_msat;
929 outgoing_cltv -= hop.cltv_expiry_delta;
930
931 }
933
934 Ok(())
935}
936
937async fn remove_htlcs(
947 nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
948 resolution_idx: usize,
949 source: PublicKey,
950 route: Path,
951 payment_hash: PaymentHash,
952 success: bool,
953) -> Result<(), ForwardingError> {
954 for (i, hop) in route.hops[0..=resolution_idx].iter().enumerate().rev() {
955 let incoming_node = if i == 0 {
959 source
960 } else {
961 route.hops[i - 1].pubkey
962 };
963
964 match nodes
967 .lock()
968 .await
969 .get_mut(&ShortChannelID::from(hop.short_channel_id))
970 {
971 Some(channel) => channel.remove_htlc(&incoming_node, &payment_hash, success)?,
972 None => {
973 return Err(ForwardingError::ChannelNotFound(ShortChannelID::from(
974 hop.short_channel_id,
975 )))
976 },
977 }
978 }
979
980 Ok(())
981}
982
983async fn propagate_payment(
988 nodes: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,
989 source: PublicKey,
990 route: Path,
991 payment_hash: PaymentHash,
992 sender: Sender<Result<PaymentResult, LightningError>>,
993 shutdown: Trigger,
994) {
995 let notify_result = if let Err((fail_idx, err)) =
998 add_htlcs(nodes.clone(), source, route.clone(), payment_hash).await
999 {
1000 if err.is_critical() {
1001 shutdown.trigger();
1002 }
1003
1004 if let Some(resolution_idx) = fail_idx {
1005 if let Err(e) =
1006 remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await
1007 {
1008 if e.is_critical() {
1009 shutdown.trigger();
1010 }
1011 }
1012 }
1013
1014 log::debug!(
1017 "Forwarding failure for simulated payment {}: {err}",
1018 hex::encode(payment_hash.0)
1019 );
1020
1021 PaymentResult {
1022 htlc_count: 0,
1023 payment_outcome: PaymentOutcome::Unknown,
1024 }
1025 } else {
1026 if let Err(e) = remove_htlcs(
1028 nodes,
1029 route.hops.len() - 1,
1030 source,
1031 route,
1032 payment_hash,
1033 true,
1034 )
1035 .await
1036 {
1037 if e.is_critical() {
1038 shutdown.trigger();
1039 }
1040
1041 log::error!("Could not remove htlcs from channel: {e}.");
1042 }
1043
1044 PaymentResult {
1045 htlc_count: 1,
1046 payment_outcome: PaymentOutcome::Success,
1047 }
1048 };
1049
1050 if let Err(e) = sender.send(Ok(notify_result)) {
1051 log::error!("Could not notify payment result: {:?}.", e);
1052 }
1053}
1054
/// A zero-sized LDK `Logger` implementation that forwards LDK log records to the `log` crate.
pub struct WrappedLog {}
1058
1059impl Logger for WrappedLog {
1060 fn log(&self, record: Record) {
1061 match record.level {
1062 Level::Gossip => log::trace!("{}", record.args),
1063 Level::Trace => log::trace!("{}", record.args),
1064 Level::Debug => log::debug!("{}", record.args),
1065 Level::Info => log::debug!("{}", record.args),
1067 Level::Warn => log::warn!("{}", record.args),
1068 Level::Error => log::error!("{}", record.args),
1069 }
1070 }
1071}
1072
1073struct UtxoValidator {
1075 amount_sat: u64,
1076 script: ScriptBuf,
1077}
1078
1079impl UtxoLookup for UtxoValidator {
1080 fn get_utxo(&self, _genesis_hash: &ChainHash, _short_channel_id: u64) -> UtxoResult {
1081 UtxoResult::Sync(Ok(TxOut {
1082 value: self.amount_sat,
1083 script_pubkey: self.script.clone(),
1084 }))
1085 }
1086}
1087
1088#[cfg(test)]
1089mod tests {
1090 use super::*;
1091 use crate::test_utils::get_random_keypair;
1092 use bitcoin::secp256k1::PublicKey;
1093 use lightning::routing::router::Route;
1094 use mockall::mock;
1095 use std::time::Duration;
1096 use tokio::sync::oneshot;
1097 use tokio::time::timeout;
1098
1099 fn create_test_policy(max_in_flight_msat: u64) -> ChannelPolicy {
1102 let (_, pk) = get_random_keypair();
1103 ChannelPolicy {
1104 pubkey: pk,
1105 max_htlc_count: 10,
1106 max_in_flight_msat,
1107 min_htlc_size_msat: 2,
1108 max_htlc_size_msat: max_in_flight_msat / 2,
1109 cltv_expiry_delta: 10,
1110 base_fee: 1000,
1111 fee_rate_prop: 5000,
1112 }
1113 }
1114
1115 fn create_simulated_channels(n: u64, capacity_msat: u64) -> Vec<SimulatedChannel> {
1121 let mut channels: Vec<SimulatedChannel> = vec![];
1122 let (_, first_node) = get_random_keypair();
1123
1124 let mut node_1 = first_node;
1126 for i in 0..n {
1127 let (_, node_2) = get_random_keypair();
1129
1130 let node_1_to_2 = ChannelPolicy {
1131 pubkey: node_1,
1132 max_htlc_count: 483,
1133 max_in_flight_msat: capacity_msat / 2,
1134 min_htlc_size_msat: 1,
1135 max_htlc_size_msat: capacity_msat / 2,
1136 cltv_expiry_delta: 40,
1137 base_fee: 1000 * i,
1138 fee_rate_prop: 1500 * i,
1139 };
1140
1141 let node_2_to_1 = ChannelPolicy {
1142 pubkey: node_2,
1143 max_htlc_count: 483,
1144 max_in_flight_msat: capacity_msat / 2,
1145 min_htlc_size_msat: 1,
1146 max_htlc_size_msat: capacity_msat / 2,
1147 cltv_expiry_delta: 40 + 10 * i as u32,
1148 base_fee: 2000 * i,
1149 fee_rate_prop: i,
1150 };
1151
1152 channels.push(SimulatedChannel {
1153 capacity_msat,
1154 short_channel_id: ShortChannelID::from(i),
1156 node_1: ChannelState::new(node_1_to_2, capacity_msat),
1157 node_2: ChannelState::new(node_2_to_1, 0),
1158 });
1159
1160 node_1 = node_2;
1162 }
1163
1164 channels
1165 }
1166
1167 macro_rules! assert_channel_balances {
1168 ($channel_state:expr, $local_balance:expr, $in_flight_len:expr, $in_flight_total:expr) => {
1169 assert_eq!($channel_state.local_balance_msat, $local_balance);
1170 assert_eq!($channel_state.in_flight.len(), $in_flight_len);
1171 assert_eq!($channel_state.in_flight_total(), $in_flight_total);
1172 };
1173 }
1174
1175 #[test]
1177 fn test_channel_state_transitions() {
1178 let local_balance = 100_000_000;
1179 let mut channel_state =
1180 ChannelState::new(create_test_policy(local_balance / 2), local_balance);
1181
1182 assert_channel_balances!(channel_state, local_balance, 0, 0);
1184
1185 let hash_1 = PaymentHash([1; 32]);
1189 let htlc_1 = Htlc {
1190 amount_msat: 1000,
1191 cltv_expiry: 40,
1192 };
1193
1194 assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());
1195 assert_channel_balances!(
1196 channel_state,
1197 local_balance - htlc_1.amount_msat,
1198 1,
1199 htlc_1.amount_msat
1200 );
1201
1202 assert!(matches!(
1205 channel_state.add_outgoing_htlc(hash_1, htlc_1),
1206 Err(ForwardingError::PaymentHashExists(_))
1207 ));
1208
1209 let hash_2 = PaymentHash([2; 32]);
1211 let htlc_2 = Htlc {
1212 amount_msat: 1000,
1213 cltv_expiry: 40,
1214 };
1215
1216 assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());
1217 assert_channel_balances!(
1218 channel_state,
1219 local_balance - htlc_1.amount_msat - htlc_2.amount_msat,
1220 2,
1221 htlc_1.amount_msat + htlc_2.amount_msat
1222 );
1223
1224 assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
1226 channel_state.settle_outgoing_htlc(htlc_2.amount_msat, false);
1227 assert_channel_balances!(
1228 channel_state,
1229 local_balance - htlc_1.amount_msat,
1230 1,
1231 htlc_1.amount_msat
1232 );
1233
1234 assert!(matches!(
1236 channel_state.remove_outgoing_htlc(&hash_2),
1237 Err(ForwardingError::PaymentHashNotFound(_))
1238 ));
1239
1240 assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
1242 channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);
1243 assert_channel_balances!(channel_state, local_balance - htlc_1.amount_msat, 0, 0);
1244 }
1245
1246 #[test]
1248 fn test_htlc_forward() {
1249 let local_balance = 140_000;
1250 let channel_state = ChannelState::new(create_test_policy(local_balance / 2), local_balance);
1251
1252 assert!(matches!(
1254 channel_state.check_htlc_forward(channel_state.policy.cltv_expiry_delta - 1, 0, 0),
1255 Err(ForwardingError::InsufficientCltvDelta(_, _))
1256 ));
1257
1258 let htlc_amount = 1000;
1260 let htlc_fee = channel_state.policy.base_fee
1261 + (channel_state.policy.fee_rate_prop * htlc_amount) / 1e6 as u64;
1262
1263 assert!(matches!(
1264 channel_state.check_htlc_forward(
1265 channel_state.policy.cltv_expiry_delta,
1266 htlc_amount,
1267 htlc_fee - 1
1268 ),
1269 Err(ForwardingError::InsufficientFee(_, _, _, _))
1270 ));
1271
1272 assert!(channel_state
1274 .check_htlc_forward(
1275 channel_state.policy.cltv_expiry_delta,
1276 htlc_amount,
1277 htlc_fee,
1278 )
1279 .is_ok());
1280
1281 assert!(channel_state
1282 .check_htlc_forward(
1283 channel_state.policy.cltv_expiry_delta * 2,
1284 htlc_amount,
1285 htlc_fee * 3
1286 )
1287 .is_ok());
1288 }
1289
1290 #[test]
1292 fn test_check_outgoing_addition() {
1293 let local_balance = 100_000;
1295 let mut channel_state =
1296 ChannelState::new(create_test_policy(local_balance / 2), local_balance);
1297
1298 let mut htlc = Htlc {
1299 amount_msat: channel_state.policy.max_htlc_size_msat + 1,
1300 cltv_expiry: channel_state.policy.cltv_expiry_delta,
1301 };
1302 assert!(matches!(
1304 channel_state.check_outgoing_addition(&htlc),
1305 Err(ForwardingError::MoreThanMaximum(_, _))
1306 ));
1307
1308 htlc.amount_msat = channel_state.policy.min_htlc_size_msat - 1;
1310 assert!(matches!(
1311 channel_state.check_outgoing_addition(&htlc),
1312 Err(ForwardingError::LessThanMinimum(_, _))
1313 ));
1314
1315 let hash_1 = PaymentHash([1; 32]);
1317 let htlc_1 = Htlc {
1318 amount_msat: channel_state.policy.max_in_flight_msat / 2,
1319 cltv_expiry: channel_state.policy.cltv_expiry_delta,
1320 };
1321
1322 assert!(channel_state.check_outgoing_addition(&htlc_1).is_ok());
1323 assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok());
1324
1325 let hash_2 = PaymentHash([2; 32]);
1326 let htlc_2 = Htlc {
1327 amount_msat: channel_state.policy.max_in_flight_msat / 2,
1328 cltv_expiry: channel_state.policy.cltv_expiry_delta,
1329 };
1330
1331 assert!(channel_state.check_outgoing_addition(&htlc_2).is_ok());
1332 assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok());
1333
1334 htlc.amount_msat = channel_state.policy.min_htlc_size_msat;
1336 assert!(matches!(
1337 channel_state.check_outgoing_addition(&htlc),
1338 Err(ForwardingError::ExceedsInFlightTotal(_, _))
1339 ));
1340
1341 assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok());
1343 channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true);
1344
1345 assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok());
1346 channel_state.settle_outgoing_htlc(htlc_2.amount_msat, true);
1347
1348 for i in 0..channel_state.policy.max_htlc_count {
1350 let hash = PaymentHash([i.try_into().unwrap(); 32]);
1351 assert!(channel_state.check_outgoing_addition(&htlc).is_ok());
1352 assert!(channel_state.add_outgoing_htlc(hash, htlc).is_ok());
1353 }
1354
1355 let htlc_3 = Htlc {
1357 amount_msat: channel_state.policy.min_htlc_size_msat,
1358 cltv_expiry: channel_state.policy.cltv_expiry_delta,
1359 };
1360
1361 assert!(matches!(
1362 channel_state.check_outgoing_addition(&htlc_3),
1363 Err(ForwardingError::ExceedsInFlightCount(_, _))
1364 ));
1365
1366 for i in 0..channel_state.policy.max_htlc_count {
1368 let hash = PaymentHash([i.try_into().unwrap(); 32]);
1369 assert!(channel_state.remove_outgoing_htlc(&hash).is_ok());
1370 channel_state.settle_outgoing_htlc(htlc.amount_msat, true)
1371 }
1372
1373 let hash_4 = PaymentHash([1; 32]);
1375 let htlc_4 = Htlc {
1376 amount_msat: channel_state.policy.max_htlc_size_msat,
1377 cltv_expiry: channel_state.policy.cltv_expiry_delta,
1378 };
1379 assert!(channel_state.check_outgoing_addition(&htlc_4).is_ok());
1380 assert!(channel_state.add_outgoing_htlc(hash_4, htlc_4).is_ok());
1381 assert!(channel_state.remove_outgoing_htlc(&hash_4).is_ok());
1382 channel_state.settle_outgoing_htlc(htlc_4.amount_msat, true);
1383
1384 assert!(channel_state.local_balance_msat < channel_state.policy.max_htlc_size_msat);
1388 assert!(matches!(
1389 channel_state.check_outgoing_addition(&htlc_4),
1390 Err(ForwardingError::InsufficientBalance(_, _))
1391 ));
1392 }
1393
/// Exercises `SimulatedChannel::add_htlc` / `remove_htlc` from both ends of a channel whose
/// liquidity starts entirely on node_1's side, plus lookups by a pubkey that is not a party to
/// the channel.
#[test]
fn test_simulated_channel() {
    let capacity_msat = 500_000_000;
    // node_1 owns the whole capacity; node_2 starts with nothing.
    let node_1 = ChannelState::new(create_test_policy(capacity_msat / 2), capacity_msat);
    let node_2 = ChannelState::new(create_test_policy(capacity_msat / 2), 0);

    let mut channel = SimulatedChannel {
        capacity_msat,
        short_channel_id: ShortChannelID::from(123),
        node_1: node_1.clone(),
        node_2: node_2.clone(),
    };

    // An empty side cannot offer even the smallest permitted HTLC.
    let smallest_from_node_2 = Htlc {
        amount_msat: node_2.policy.min_htlc_size_msat,
        cltv_expiry: node_1.policy.cltv_expiry_delta,
    };
    assert!(matches!(
        channel.add_htlc(&node_2.policy.pubkey, PaymentHash([1; 32]), smallest_from_node_2),
        Err(ForwardingError::InsufficientBalance(_, _))
    ));

    // node_1 sends its largest allowed HTLC and it settles successfully, which gives node_2
    // enough balance to send the same HTLC back the other way.
    let hash = PaymentHash([1; 32]);
    let largest_htlc = Htlc {
        amount_msat: node_1.policy.max_htlc_size_msat,
        cltv_expiry: node_2.policy.cltv_expiry_delta,
    };
    assert!(channel
        .add_htlc(&node_1.policy.pubkey, hash, largest_htlc)
        .is_ok());
    assert!(channel
        .remove_htlc(&node_1.policy.pubkey, &hash, true)
        .is_ok());
    assert!(channel
        .add_htlc(&node_2.policy.pubkey, hash, largest_htlc)
        .is_ok());

    // A pubkey that belongs to neither side is rejected for both operations.
    let (_, stranger) = get_random_keypair();
    assert!(matches!(
        channel.add_htlc(&stranger, hash, largest_htlc),
        Err(ForwardingError::NodeNotFound(_))
    ));
    assert!(matches!(
        channel.remove_htlc(&stranger, &hash, true),
        Err(ForwardingError::NodeNotFound(_))
    ));
}
1456
// Mock implementation of the `SimNetwork` trait generated by `mockall`'s `mock!` macro. The
// generated type is `MockNetwork`, which tests construct with `MockNetwork::new()` and program
// with `expect_dispatch_payment()` / `expect_lookup_node()`. This lets `SimNode` be tested
// without a real simulated network behind it.
mock! {
    Network{}

    // `#[async_trait]` is required so mockall can mock the trait's `async fn` below.
    #[async_trait]
    impl SimNetwork for Network{
        // The payment result is delivered asynchronously through `sender` rather than returned.
        fn dispatch_payment(
            &mut self,
            source: PublicKey,
            route: Route,
            payment_hash: PaymentHash,
            sender: Sender<Result<PaymentResult, LightningError>>,
        );

        async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec<u64>), LightningError>;
    }
}
1473
1474 #[tokio::test]
1477 async fn test_simulated_node() {
1478 let mock = MockNetwork::new();
1480 let sim_network = Arc::new(Mutex::new(mock));
1481 let channels = create_simulated_channels(5, 300000000);
1482 let graph = populate_network_graph(channels.clone()).unwrap();
1483
1484 let pk = channels[0].node_1.policy.pubkey;
1486 let mut node = SimNode::new(pk, sim_network.clone(), Arc::new(graph));
1487
1488 let lookup_pk = channels[3].node_1.policy.pubkey;
1490 sim_network
1491 .lock()
1492 .await
1493 .expect_lookup_node()
1494 .returning(move |_| Ok((node_info(lookup_pk), vec![1, 2, 3])));
1495
1496 let node_info = node.get_node_info(&lookup_pk).await.unwrap();
1498 assert_eq!(lookup_pk, node_info.pubkey);
1499 assert_eq!(node.list_channels().await.unwrap().len(), 3);
1500
1501 let dest_1 = channels[2].node_1.policy.pubkey;
1504 let dest_2 = channels[4].node_1.policy.pubkey;
1505
1506 sim_network
1507 .lock()
1508 .await
1509 .expect_dispatch_payment()
1510 .returning(
1511 move |_, route: Route, _, sender: Sender<Result<PaymentResult, LightningError>>| {
1512 let receiver = route.paths[0].hops.last().unwrap().pubkey;
1515 let result = if receiver == dest_1 {
1516 PaymentResult {
1517 htlc_count: 2,
1518 payment_outcome: PaymentOutcome::Success,
1519 }
1520 } else if receiver == dest_2 {
1521 PaymentResult {
1522 htlc_count: 0,
1523 payment_outcome: PaymentOutcome::InsufficientBalance,
1524 }
1525 } else {
1526 panic!("unknown mocked receiver");
1527 };
1528
1529 sender.send(Ok(result)).unwrap();
1530 },
1531 );
1532
1533 let hash_1 = node.send_payment(dest_1, 10_000).await.unwrap();
1535 let hash_2 = node.send_payment(dest_2, 15_000).await.unwrap();
1536
1537 let (_, shutdown_listener) = triggered::trigger();
1538
1539 let result_1 = node
1540 .track_payment(&hash_1, shutdown_listener.clone())
1541 .await
1542 .unwrap();
1543 assert!(matches!(result_1.payment_outcome, PaymentOutcome::Success));
1544
1545 let result_2 = node
1546 .track_payment(&hash_2, shutdown_listener.clone())
1547 .await
1548 .unwrap();
1549 assert!(matches!(
1550 result_2.payment_outcome,
1551 PaymentOutcome::InsufficientBalance
1552 ));
1553 }
1554
/// Test harness bundling a `SimGraph` with the LDK routing graph built from the same channels,
/// so tests can find routes, dispatch payments and inspect resulting channel balances.
struct DispatchPaymentTestKit<'a> {
    /// Simulated graph under test: holds channel state and dispatches payments.
    graph: SimGraph,
    /// node_1 of each channel followed by node_2 of the last channel. Tests treat these as a
    /// chain (Alice, Bob, Carol, Dave for three channels) — presumably guaranteed by
    /// `create_simulated_channels`; confirm there.
    nodes: Vec<PublicKey>,
    /// LDK network graph populated from the same channels; used only for route finding.
    routing_graph: NetworkGraph<&'a WrappedLog>,
    /// Shared with `graph`; tests trigger it and then wait for the graph to shut down.
    shutdown: triggered::Trigger,
}
1562
1563 impl<'a> DispatchPaymentTestKit<'a> {
1564 async fn new(capacity: u64) -> Self {
1571 let (shutdown, _listener) = triggered::trigger();
1572 let channels = create_simulated_channels(3, capacity);
1573
1574 let mut nodes = channels
1577 .iter()
1578 .map(|c| c.node_1.policy.pubkey)
1579 .collect::<Vec<PublicKey>>();
1580 nodes.push(channels.last().unwrap().node_2.policy.pubkey);
1581
1582 let kit = DispatchPaymentTestKit {
1583 graph: SimGraph::new(channels.clone(), shutdown.clone())
1584 .expect("could not create test graph"),
1585 nodes,
1586 routing_graph: populate_network_graph(channels).unwrap(),
1587 shutdown,
1588 };
1589
1590 assert_eq!(
1592 kit.channel_balances().await,
1593 vec![(capacity, 0), (capacity, 0), (capacity, 0)]
1594 );
1595
1596 kit
1597 }
1598
1599 async fn channel_balances(&self) -> Vec<(u64, u64)> {
1601 let mut balances = vec![];
1602
1603 let chan_count = self.graph.channels.lock().await.len();
1606
1607 for i in 0..chan_count {
1608 let chan_lock = self.graph.channels.lock().await;
1609 let channel = chan_lock.get(&ShortChannelID::from(i as u64)).unwrap();
1610
1611 balances.push((
1614 channel.node_1.local_balance_msat,
1615 channel.node_2.local_balance_msat,
1616 ));
1617 }
1618
1619 balances
1620 }
1621
1622 async fn send_test_payemnt(
1625 &mut self,
1626 source: PublicKey,
1627 dest: PublicKey,
1628 amt: u64,
1629 ) -> Route {
1630 let route = find_payment_route(&source, dest, amt, &self.routing_graph).unwrap();
1631
1632 let (sender, receiver) = oneshot::channel();
1633 self.graph
1634 .dispatch_payment(source, route.clone(), PaymentHash([1; 32]), sender);
1635
1636 assert!(timeout(Duration::from_millis(10), receiver).await.is_ok());
1638
1639 route
1640 }
1641
1642 async fn set_channel_balance(&mut self, scid: &ShortChannelID, balance: (u64, u64)) {
1644 let mut channels_lock = self.graph.channels.lock().await;
1645 let channel = channels_lock.get_mut(scid).unwrap();
1646
1647 channel.node_1.local_balance_msat = balance.0;
1648 channel.node_2.local_balance_msat = balance.1;
1649
1650 assert!(channel.sanity_check().is_ok());
1651 }
1652 }
1653
1654 #[tokio::test]
1657 async fn test_successful_dispatch() {
1658 let chan_capacity = 500_000_000;
1659 let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1660
1661 let mut amt = 20_000;
1663 let route = test_kit
1664 .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1665 .await;
1666
1667 let route_total = amt + route.get_total_fees();
1668 let hop_1_amt = amt + route.paths[0].hops[1].fee_msat;
1669
1670 let alice_to_bob = (chan_capacity - route_total, route_total);
1672 let mut bob_to_carol = (chan_capacity - hop_1_amt, hop_1_amt);
1674 let carol_to_dave = (chan_capacity - amt, amt);
1676
1677 let mut expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1678 assert_eq!(test_kit.channel_balances().await, expected_balances);
1679
1680 let _ = test_kit
1684 .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[1], amt * 2)
1685 .await;
1686 assert_eq!(test_kit.channel_balances().await, expected_balances);
1687
1688 amt = bob_to_carol.0 / 2;
1692 let _ = test_kit
1693 .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
1694 .await;
1695
1696 bob_to_carol = (bob_to_carol.0 / 2, bob_to_carol.1 + amt);
1697 expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1698 assert_eq!(test_kit.channel_balances().await, expected_balances);
1699
1700 let _ = test_kit
1702 .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt)
1703 .await;
1704 bob_to_carol = (0, chan_capacity);
1705 expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave];
1706 assert_eq!(test_kit.channel_balances().await, expected_balances);
1707
1708 let _ = test_kit
1711 .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], 20_000)
1712 .await;
1713 assert_eq!(test_kit.channel_balances().await, expected_balances);
1714
1715 test_kit.shutdown.trigger();
1716 test_kit.graph.wait_for_shutdown().await;
1717 }
1718
1719 #[tokio::test]
1721 async fn test_successful_multi_hop() {
1722 let chan_capacity = 500_000_000;
1723 let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1724
1725 let amt = 20_000;
1727 let route = test_kit
1728 .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1729 .await;
1730
1731 let route_total = amt + route.get_total_fees();
1732 let hop_1_amt = amt + route.paths[0].hops[1].fee_msat;
1733
1734 let expected_balances = vec![
1735 (chan_capacity - route_total, route_total),
1737 (chan_capacity - hop_1_amt, hop_1_amt),
1739 (chan_capacity - amt, amt),
1741 ];
1742 assert_eq!(test_kit.channel_balances().await, expected_balances);
1743
1744 test_kit.shutdown.trigger();
1745 test_kit.graph.wait_for_shutdown().await;
1746 }
1747
1748 #[tokio::test]
1750 async fn test_single_hop_payments() {
1751 let chan_capacity = 500_000_000;
1752 let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1753
1754 let amt = 150_000;
1756 let _ = test_kit
1757 .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[1], amt)
1758 .await;
1759
1760 let expected_balances = vec![
1761 (chan_capacity - amt, amt),
1762 (chan_capacity, 0),
1763 (chan_capacity, 0),
1764 ];
1765 assert_eq!(test_kit.channel_balances().await, expected_balances);
1766
1767 let _ = test_kit
1770 .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[2], amt)
1771 .await;
1772
1773 assert_eq!(test_kit.channel_balances().await, expected_balances);
1774
1775 test_kit.shutdown.trigger();
1776 test_kit.graph.wait_for_shutdown().await;
1777 }
1778
1779 #[tokio::test]
1781 async fn test_multi_hop_faiulre() {
1782 let chan_capacity = 500_000_000;
1783 let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await;
1784
1785 test_kit
1787 .set_channel_balance(&ShortChannelID::from(1), (0, chan_capacity))
1788 .await;
1789
1790 let mut expected_balances =
1791 vec![(chan_capacity, 0), (0, chan_capacity), (chan_capacity, 0)];
1792 assert_eq!(test_kit.channel_balances().await, expected_balances);
1793
1794 let amt = 150_000;
1796 let _ = test_kit
1797 .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt)
1798 .await;
1799
1800 assert_eq!(test_kit.channel_balances().await, expected_balances);
1801
1802 expected_balances[2] = (0, chan_capacity);
1805 test_kit
1806 .set_channel_balance(&ShortChannelID::from(2), (0, chan_capacity))
1807 .await;
1808
1809 let _ = test_kit
1810 .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[0], amt)
1811 .await;
1812
1813 assert_eq!(test_kit.channel_balances().await, expected_balances);
1814
1815 test_kit.shutdown.trigger();
1816 test_kit.graph.wait_for_shutdown().await;
1817 }
1818}