1use std::collections::{hash_map, VecDeque};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{delay_queue, DelayQueue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::config::NetworkConfig;
15use crate::network::connection::Connection;
16use crate::network::endpoint::{Connecting, Endpoint, Into0RttResult};
17use crate::network::request_handler::InboundRequestHandler;
18use crate::network::wire::handshake;
19use crate::types::{
20 Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
21 PeerInfo, Response, ServiceRequest,
22};
23
// --- Connection establishment latency (histograms) ---

/// Time spent on an outbound attempt: address resolution, connect and handshake.
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
/// Time spent finishing an inbound attempt: 0.5-RTT acceptance and handshake.
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// --- Connection outcomes (counters) ---

/// Outbound connections that were registered as active.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
/// Inbound connections that were registered as active.
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
/// Outbound connection attempts that failed.
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
/// Inbound connection attempts that failed.
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// --- Connection manager task counts (gauges) ---

/// Running connection handler tasks.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
/// Connection tasks (dials and inbound handshakes) still in progress.
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
/// Incoming connections still waiting for the peer identity.
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
/// Automatic dials started by the connectivity check and not yet resolved.
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

// --- Peer registries (gauges) ---

/// Entries in the active peers set.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
/// Entries in the known peers registry.
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
42
/// Requests accepted by the [`ConnectionManager`] event loop through its mailbox.
#[derive(Debug)]
pub(crate) enum ConnectionManagerRequest {
    /// Dial the peer with the given id at the given address; the outcome is
    /// delivered through the callback.
    Connect(Address, PeerId, CallbackTx),
    /// Stop the event loop; the sender is notified after shutdown has finished.
    Shutdown(oneshot::Sender<()>),
}

/// Owns the network endpoint and drives all connection establishment:
/// outbound dials, inbound accepts, handshakes, simultaneous-dial resolution
/// and the per-connection request handlers.
pub(crate) struct ConnectionManager {
    config: Arc<NetworkConfig>,
    endpoint: Arc<Endpoint>,

    /// Incoming requests (see [`ConnectionManagerRequest`]).
    mailbox: mpsc::Receiver<ConnectionManagerRequest>,

    /// In-flight connection attempts keyed by remote address, with the
    /// callbacks waiting for their result.
    pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
    /// Incoming connections whose peer identity is not available yet.
    pending_partial_connections: JoinSet<Option<PartialConnection>>,
    /// Connection tasks (outbound dials and inbound handshakes).
    pending_connections: JoinSet<ConnectingOutput>,
    /// Request handler tasks of established connections.
    connection_handlers: JoinSet<()>,
    /// Callbacks whose error delivery is postponed.
    delayed_callbacks: DelayedCallbacksQueue,

    /// Dials started by the periodic connectivity check, awaiting their result.
    pending_dials: FastHashMap<PeerId, CallbackRx>,
    /// Backoff state of peers whose automatic dial failed.
    dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,

    active_peers: ActivePeers,
    known_peers: KnownPeers,

    /// Service passed to every inbound request handler.
    service: BoxCloneService<ServiceRequest, Response>,
}

/// Sender half of a connection result: the connected peer id or a shared error.
type CallbackTx = oneshot::Sender<Result<PeerId, Arc<anyhow::Error>>>;
/// Receiver half of a connection result: the connected peer id or a shared error.
type CallbackRx = oneshot::Receiver<Result<PeerId, Arc<anyhow::Error>>>;

impl Drop for ConnectionManager {
    // Closes the endpoint whenever the manager is dropped, including the case
    // where `shutdown` never ran.
    fn drop(&mut self) {
        tracing::trace!("dropping connection manager");
        self.endpoint.close();
    }
}
79
80impl ConnectionManager {
81 pub fn new(
82 config: Arc<NetworkConfig>,
83 endpoint: Arc<Endpoint>,
84 active_peers: ActivePeers,
85 known_peers: KnownPeers,
86 service: BoxCloneService<ServiceRequest, Response>,
87 ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
88 let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
89 let connection_manager = Self {
90 config,
91 endpoint,
92 mailbox,
93 pending_connection_callbacks: Default::default(),
94 pending_partial_connections: Default::default(),
95 pending_connections: Default::default(),
96 connection_handlers: Default::default(),
97 delayed_callbacks: Default::default(),
98 pending_dials: Default::default(),
99 dial_backoff_states: Default::default(),
100 active_peers,
101 known_peers,
102 service,
103 };
104 (connection_manager, mailbox_tx)
105 }
106
107 pub async fn start(mut self) {
108 tracing::info!("connection manager started");
109
110 let jitter = Duration::from_millis(1000).mul_f64(rand::random());
111 let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);
112
113 let mut shutdown_notifier = None;
114
115 loop {
116 tokio::select! {
117 now = interval.tick() => {
118 self.handle_connectivity_check(now.into_std());
119 }
120 maybe_request = self.mailbox.recv() => {
121 let Some(request) = maybe_request else {
122 break;
123 };
124
125 match request {
126 ConnectionManagerRequest::Connect(address, peer_id, callback) => {
127 self.handle_connect_request(address, &peer_id, callback);
128 }
129 ConnectionManagerRequest::Shutdown(oneshot) => {
130 shutdown_notifier = Some(oneshot);
131 break;
132 }
133 }
134 }
135 connecting = self.endpoint.accept() => {
136 if let Some(connecting) = connecting {
137 self.handle_incoming(connecting);
138 }
139 }
140 Some(connecting_output) = self.pending_connections.join_next() => {
141 metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
142 match connecting_output {
143 Ok(connecting) => self.handle_connecting_result(connecting),
144 Err(e) => {
145 if e.is_panic() {
146 std::panic::resume_unwind(e.into_panic());
147 }
148 }
149 }
150 }
151 Some(partial_connection) = self.pending_partial_connections.join_next() => {
152 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);
153 if let Some(PartialConnection { connection, timeout_at }) = partial_connection.unwrap() {
155 self.handle_incoming_impl(connection, None, timeout_at);
156 }
157 }
158 Some(connection_handler_output) = self.connection_handlers.join_next() => {
159 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
160
161 if let Err(e) = connection_handler_output {
163 if e.is_panic() {
164 std::panic::resume_unwind(e.into_panic());
165 }
166 }
167 }
168 Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
169 self.delayed_callbacks.execute_expired(&peer_id);
170 }
171 }
172 }
173
174 self.shutdown().await;
175
176 if let Some(tx) = shutdown_notifier {
177 _ = tx.send(());
178 }
179
180 tracing::info!("connection manager stopped");
181 }
182
183 async fn shutdown(mut self) {
184 tracing::trace!("shutting down connection manager");
185
186 self.endpoint.close();
187
188 self.pending_partial_connections.shutdown().await;
189 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);
190
191 self.pending_connections.shutdown().await;
192 metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);
193
194 while self.connection_handlers.join_next().await.is_some() {
195 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
196 }
197 assert!(self.active_peers.is_empty());
198
199 self.endpoint
200 .wait_idle(self.config.shutdown_idle_timeout)
201 .await;
202 }
203
204 fn handle_connectivity_check(&mut self, now: Instant) {
205 use std::collections::hash_map::Entry;
206
207 self.pending_dials
208 .retain(|peer_id, oneshot| match oneshot.try_recv() {
209 Ok(Ok(returned_peer_id)) => {
210 debug_assert_eq!(peer_id, &returned_peer_id);
211 self.dial_backoff_states.remove(peer_id);
212 false
213 }
214 Ok(Err(_)) => {
215 match self.dial_backoff_states.entry(*peer_id) {
216 Entry::Occupied(mut entry) => entry.get_mut().update(
217 now,
218 self.config.connection_backoff,
219 self.config.max_connection_backoff,
220 ),
221 Entry::Vacant(entry) => {
222 entry.insert(DialBackoffState::new(
223 now,
224 self.config.connection_backoff,
225 self.config.max_connection_backoff,
226 ));
227 }
228 }
229 false
230 }
231 Err(oneshot::error::TryRecvError::Closed) => {
232 panic!("BUG: connection manager never finished dialing a peer");
233 }
234 Err(oneshot::error::TryRecvError::Empty) => true,
235 });
236
237 let outstanding_connections_limit = self
238 .config
239 .max_concurrent_outstanding_connections
240 .saturating_sub(self.pending_connections.len());
241
242 let outstanding_connections = self
243 .known_peers
244 .0
245 .iter()
246 .filter_map(|item| {
247 let value = match item.value() {
248 KnownPeerState::Stored(item) => item.upgrade()?,
249 KnownPeerState::Banned => return None,
250 };
251 let peer_info = value.peer_info.load();
252 let affinity = value.compute_affinity();
253
254 (affinity == PeerAffinity::High
255 && peer_info.id != self.endpoint.peer_id()
256 && !self.active_peers.contains(&peer_info.id)
257 && !self.pending_dials.contains_key(&peer_info.id)
258 && self
259 .dial_backoff_states
260 .get(&peer_info.id)
261 .map_or(true, |state| now > state.next_attempt_at))
262 .then(|| arc_swap::Guard::into_inner(peer_info))
263 })
264 .take(outstanding_connections_limit)
265 .collect::<Vec<_>>();
266
267 for peer_info in outstanding_connections {
268 let address = peer_info
270 .iter_addresses()
271 .next()
272 .cloned()
273 .expect("address list must have at least one item");
274
275 let (tx, rx) = oneshot::channel();
276 self.dial_peer(address, &peer_info.id, tx);
277 self.pending_dials.insert(peer_info.id, rx);
278 }
279
280 metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
281 }
282
283 fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
284 self.dial_peer(address, peer_id, callback);
285 }
286
287 fn handle_incoming(&mut self, connecting: Connecting) {
288 let remote_addr = connecting.remote_address();
289 tracing::trace!(
290 local_id = %self.endpoint.peer_id(),
291 %remote_addr,
292 "received an incoming connection",
293 );
294
295 match connecting.into_0rtt() {
297 Into0RttResult::Established(connection, accepted) => {
298 let timeout_at = Instant::now() + self.config.connect_timeout;
299 self.handle_incoming_impl(connection, Some(accepted), timeout_at);
300 }
301 Into0RttResult::WithoutIdentity(partial_connection) => {
302 tracing::trace!("connection identity is not available yet");
303
304 let timeout_at = Instant::now() + self.config.connect_timeout;
305 self.pending_partial_connections.spawn(async move {
306 match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
307 Ok(Ok(connection)) => Some(PartialConnection {
308 connection,
309 timeout_at,
310 }),
311 Ok(Err(e)) => {
312 tracing::trace!(
313 %remote_addr,
314 "failed to establish an incoming connection: {e}",
315 );
316 None
317 }
318 Err(_) => {
319 tracing::trace!(
320 %remote_addr,
321 "incoming connection timed out",
322 );
323 None
324 }
325 }
326 });
327 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
328 }
329 Into0RttResult::InvalidConnection(e) => {
330 tracing::warn!(%remote_addr, "invalid incoming connection: {e}");
332 }
333 Into0RttResult::Unavailable(_) => unreachable!(
334 "BUG: For incoming connections, a 0.5-RTT connection must \
335 always be successfully constructed."
336 ),
337 };
338 }
339
340 fn handle_incoming_impl(
341 &mut self,
342 connection: Connection,
343 accepted: Option<quinn::ZeroRttAccepted>,
344 timeout_at: Instant,
345 ) {
346 async fn handle_incoming_task(
347 seqno: u32,
348 connection: ConnectionClosedOnDrop,
349 accepted: Option<quinn::ZeroRttAccepted>,
350 timeout_at: Instant,
351 ) -> ConnectingOutput {
352 let target_peer_id = *connection.peer_id();
353 let target_address = connection.remote_address().into();
354 let fut = async {
355 if let Some(accepted) = accepted {
356 accepted.await;
358 }
359 handshake(&connection).await
360 };
361
362 let started_at = Instant::now();
363
364 let connecting_result = tokio::time::timeout_at(timeout_at.into(), fut)
365 .await
366 .map_err(Into::into)
367 .and_then(std::convert::identity)
368 .map_err(Arc::new)
369 .map(|_| connection.disarm());
370
371 metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());
372
373 ConnectingOutput {
374 seqno,
375 drop_result: true,
376 connecting_result: ManuallyDrop::new(connecting_result),
377 target_address,
378 target_peer_id,
379 origin: Direction::Inbound,
380 }
381 }
382
383 let remote_addr = connection.remote_address();
384
385 match self.known_peers.get_affinity(connection.peer_id()) {
387 Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
388 Some(PeerAffinity::Never) => {
389 tracing::warn!(
391 %remote_addr,
392 peer_id = %connection.peer_id(),
393 "rejecting connection due to PeerAffinity::Never",
394 );
395 connection.close();
396 return;
397 }
398 _ => {
399 if matches!(
400 self.config.max_concurrent_connections,
401 Some(limit) if self.active_peers.len() >= limit
402 ) {
403 tracing::warn!(
405 %remote_addr,
406 peer_id = %connection.peer_id(),
407 "rejecting connection due too many concurrent connections",
408 );
409 connection.close();
410 return;
411 }
412 }
413 }
414
415 let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
416 hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
417 last_seqno: 0,
418 origin: Direction::Inbound,
419 callbacks: Default::default(),
420 abort_handle: None,
421 })),
422 hash_map::Entry::Occupied(entry) => {
423 let entry = entry.into_mut();
424
425 if simultaneous_dial_tie_breaking(
427 self.endpoint.peer_id(),
428 connection.peer_id(),
429 entry.origin,
430 Direction::Inbound,
431 ) {
432 tracing::debug!(
434 %remote_addr,
435 peer_id = %connection.peer_id(),
436 "cancelling old connection to mitigate simultaneous dial",
437 );
438
439 entry.origin = Direction::Inbound;
440 entry.last_seqno += 1;
441 if let Some(handle) = entry.abort_handle.take() {
442 handle.abort();
443 }
444 Some(entry)
445 } else {
446 tracing::debug!(
448 %remote_addr,
449 peer_id = %connection.peer_id(),
450 "cancelling new connection to mitigate simultaneous dial",
451 );
452
453 connection.close();
454 None
455 }
456 }
457 };
458
459 if let Some(entry) = entry {
460 entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
461 entry.last_seqno,
462 ConnectionClosedOnDrop::new(connection),
463 accepted,
464 timeout_at,
465 )));
466 metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
467 }
468 }
469
470 fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
471 {
473 let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
474 tracing::trace!("connection task reordering detected");
475 return;
476 };
477
478 if entry.last_seqno != res.seqno {
479 tracing::debug!(
480 local_id = %self.endpoint.peer_id(),
481 peer_id = %res.target_peer_id,
482 remote_addr = %res.target_address,
483 "connection result is outdated"
484 );
485 return;
486 }
487 }
488
489 let callbacks = self
490 .pending_connection_callbacks
491 .remove(&res.target_address)
492 .expect("Connection tasks must be tracked")
493 .callbacks;
494
495 res.drop_result = false;
496 match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
498 Ok(connection) => {
499 let peer_id = *connection.peer_id();
500 tracing::debug!(
501 local_id = %self.endpoint.peer_id(),
502 %peer_id,
503 remote_addr = %res.target_address,
504 "new connection",
505 );
506 self.add_peer(connection);
507
508 self.delayed_callbacks.execute_resolved(&res.target_peer_id);
509
510 for callback in callbacks {
511 _ = callback.send(Ok(peer_id));
512 }
513 }
514 Err(e) => {
515 tracing::debug!(
516 local_id = %self.endpoint.peer_id(),
517 peer_id = %res.target_peer_id,
518 remote_addr = %res.target_address,
519 "connection failed: {e}"
520 );
521
522 metrics::counter!(match res.origin {
523 Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
524 Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
525 })
526 .increment(1);
527
528 if let Some(quinn::ConnectionError::ApplicationClosed(closed)) = e.downcast_ref() {
531 if closed.error_code.into_inner() == 0 && !callbacks.is_empty() {
532 self.delayed_callbacks.push(
533 &res.target_peer_id,
534 e,
535 callbacks,
536 &self.config.connection_error_delay,
537 );
538 return;
539 }
540 }
541
542 for callback in callbacks {
543 _ = callback.send(Err(e.clone()));
544 }
545 }
546 }
547 }
548
549 fn add_peer(&mut self, connection: Connection) {
550 if let Some(connection) = self.active_peers.add(self.endpoint.peer_id(), connection) {
551 let origin = connection.origin();
552
553 let handler = InboundRequestHandler::new(
554 self.config.clone(),
555 connection,
556 self.service.clone(),
557 self.active_peers.clone(),
558 );
559
560 metrics::counter!(match origin {
561 Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
562 Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
563 })
564 .increment(1);
565
566 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
567 self.connection_handlers.spawn(handler.start());
568 }
569 }
570
571 #[tracing::instrument(
572 level = "trace",
573 skip_all,
574 fields(
575 local_id = %self.endpoint.peer_id(),
576 peer_id = %peer_id,
577 remote_addr = %address,
578 ),
579 )]
580 fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
581 async fn dial_peer_task(
582 seqno: u32,
583 endpoint: Arc<Endpoint>,
584 address: Address,
585 peer_id: PeerId,
586 config: Arc<NetworkConfig>,
587 ) -> ConnectingOutput {
588 let fut = async {
589 let address = address.resolve().await?;
590 let connecting = endpoint.connect_with_expected_id(&address, &peer_id)?;
591 let connection = ConnectionClosedOnDrop::new(connecting.await?);
592 handshake(&connection).await?;
593 Ok(connection)
594 };
595
596 let started_at = Instant::now();
597
598 let connecting_result = tokio::time::timeout(config.connect_timeout, fut)
599 .await
600 .map_err(Into::into)
601 .and_then(std::convert::identity)
602 .map_err(Arc::new)
603 .map(ConnectionClosedOnDrop::disarm);
604
605 metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());
606
607 ConnectingOutput {
608 seqno,
609 drop_result: true,
610 connecting_result: ManuallyDrop::new(connecting_result),
611 target_address: address,
612 target_peer_id: peer_id,
613 origin: Direction::Outbound,
614 }
615 }
616
617 if self.active_peers.contains(peer_id) {
618 tracing::debug!("peer is already connected");
619 _ = callback.send(Ok(*peer_id));
620 return;
621 }
622
623 tracing::trace!("connecting to peer");
624
625 let entry = match self.pending_connection_callbacks.entry(address.clone()) {
626 hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
627 last_seqno: 0,
628 origin: Direction::Outbound,
629 callbacks: vec![callback],
630 abort_handle: None,
631 })),
632 hash_map::Entry::Occupied(entry) => {
633 let entry = entry.into_mut();
634
635 entry.callbacks.push(callback);
637
638 let break_tie = simultaneous_dial_tie_breaking(
640 self.endpoint.peer_id(),
641 peer_id,
642 entry.origin,
643 Direction::Outbound,
644 );
645
646 if break_tie && entry.origin != Direction::Outbound {
647 tracing::debug!("cancelling old connection to mitigate simultaneous dial");
649
650 entry.origin = Direction::Outbound;
651 entry.last_seqno += 1;
652 if let Some(handle) = entry.abort_handle.take() {
653 handle.abort();
654 }
655 Some(entry)
656 } else {
657 tracing::trace!("reusing old connection to mitigate simultaneous dial");
659 None
660 }
661 }
662 };
663
664 if let Some(entry) = entry {
665 entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
666 entry.last_seqno,
667 self.endpoint.clone(),
668 address.clone(),
669 *peer_id,
670 self.config.clone(),
671 )));
672 metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
673 }
674 }
675}
676
/// State of the connection attempt currently tracked for one remote address.
struct PendingConnectionCallbacks {
    /// Seqno of the latest spawned attempt; outputs carrying a different seqno
    /// are treated as outdated.
    last_seqno: u32,
    /// Direction of the latest spawned attempt.
    origin: Direction,
    /// Callbacks notified with the result of the attempt.
    callbacks: Vec<CallbackTx>,
    /// Handle used to abort the running attempt when it gets superseded.
    abort_handle: Option<AbortHandle>,
}

/// Incoming connection whose peer identity has become available.
struct PartialConnection {
    connection: Connection,
    /// Deadline carried over from the moment the connection was accepted.
    timeout_at: Instant,
}

/// Output of a connection task (outbound dial or inbound handshake).
struct ConnectingOutput {
    /// Seqno of the attempt that produced this output.
    seqno: u32,
    /// Whether `Drop` still has to drop `connecting_result`; cleared by the
    /// consumer right before it takes the result out.
    drop_result: bool,
    connecting_result: ManuallyDrop<Result<Connection, Arc<anyhow::Error>>>,
    target_address: Address,
    target_peer_id: PeerId,
    origin: Direction,
}

impl Drop for ConnectingOutput {
    fn drop(&mut self) {
        if self.drop_result {
            // SAFETY: `drop_result` is only cleared right before the result is
            // moved out with `ManuallyDrop::take`; while it is still set the
            // value is initialized and has not been dropped.
            unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
        }
    }
}
706
707struct ConnectionClosedOnDrop {
708 connection: ManuallyDrop<Connection>,
709 close_on_drop: bool,
710}
711
712impl ConnectionClosedOnDrop {
713 fn new(connection: Connection) -> Self {
714 Self {
715 connection: ManuallyDrop::new(connection),
716 close_on_drop: true,
717 }
718 }
719
720 fn disarm(mut self) -> Connection {
721 self.close_on_drop = false;
722 unsafe { ManuallyDrop::take(&mut self.connection) }
724 }
725}
726
727impl std::ops::Deref for ConnectionClosedOnDrop {
728 type Target = Connection;
729
730 #[inline]
731 fn deref(&self) -> &Self::Target {
732 &self.connection
733 }
734}
735
736impl Drop for ConnectionClosedOnDrop {
737 fn drop(&mut self) {
738 if self.close_on_drop {
739 let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
741 connection.close();
742 }
743 }
744}
745
746#[derive(Default)]
747struct DelayedCallbacksQueue {
748 callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
749 expirations: DelayQueue<PeerId>,
750}
751
752impl DelayedCallbacksQueue {
753 fn push(
754 &mut self,
755 peer_id: &PeerId,
756 error: Arc<anyhow::Error>,
757 callbacks: Vec<CallbackTx>,
758 delay: &Duration,
759 ) {
760 tracing::debug!(%peer_id, %error, "delayed connection error");
761
762 let expires_at = Instant::now() + *delay;
763 let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
764
765 let items = self.callbacks.entry(*peer_id).or_default();
766 items.push_back(DelayedCallbacks {
767 delay_key,
768 error,
769 callbacks,
770 expires_at,
771 });
772 }
773
774 async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
775 let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
776 Some(res.into_inner())
777 }
778
779 fn execute_resolved(&mut self, peer_id: &PeerId) {
780 let Some(items) = self.callbacks.remove(peer_id) else {
781 return;
782 };
783
784 let mut batches_executed = 0;
785 let mut callbacks_executed = 0;
786
787 for delayed in items {
788 batches_executed += 1;
789 callbacks_executed += delayed.callbacks.len();
790
791 let key = delayed.execute_with_ok(peer_id);
792
793 self.expirations.remove(&key);
795 }
796
797 tracing::debug!(
798 %peer_id,
799 batches_executed,
800 callbacks_executed,
801 "executed all delayed callbacks",
802 );
803 }
804
805 fn execute_expired(&mut self, peer_id: &PeerId) {
806 let now = Instant::now();
807
808 let mut batches_executed = 0;
809 let mut callbacks_executed = 0;
810
811 'outer: {
812 if let Some(items) = self.callbacks.get_mut(peer_id) {
813 while let Some(front) = items.front() {
814 if !front.is_expired(&now) {
815 break 'outer;
816 }
817
818 if let Some(delayed) = items.pop_front() {
819 batches_executed += 1;
820 callbacks_executed += delayed.callbacks.len();
821
822 let key = delayed.execute_with_error();
823
824 self.expirations.try_remove(&key);
827 }
828 }
829 }
830
831 self.callbacks.remove(peer_id);
833 }
834
835 tracing::debug!(
836 %peer_id,
837 batches_executed,
838 callbacks_executed,
839 "executed expired delayed callbacks"
840 );
841 }
842}
843
844struct DelayedCallbacks {
845 delay_key: delay_queue::Key,
846 error: Arc<anyhow::Error>,
847 callbacks: Vec<CallbackTx>,
848 expires_at: Instant,
849}
850
851impl DelayedCallbacks {
852 fn execute_with_ok(self, peer_id: &PeerId) -> delay_queue::Key {
853 for callback in self.callbacks {
854 _ = callback.send(Ok(*peer_id));
855 }
856 self.delay_key
857 }
858
859 fn execute_with_error(self) -> delay_queue::Key {
860 for callback in self.callbacks {
861 _ = callback.send(Err(self.error.clone()));
862 }
863 self.delay_key
864 }
865
866 fn is_expired(&self, now: &Instant) -> bool {
867 *now >= self.expires_at
868 }
869}
870
/// Linear backoff for automatic dials of a single peer.
///
/// The `n`-th consecutive failure postpones the next attempt by
/// `min(max, step * n)`.
#[derive(Debug)]
struct DialBackoffState {
    /// The peer may be dialed again only after this instant.
    next_attempt_at: Instant,
    /// Number of consecutive failed attempts recorded so far.
    attempts: usize,
}

impl DialBackoffState {
    /// Creates a state that already accounts for one failed attempt at `now`.
    fn new(now: Instant, step: Duration, max: Duration) -> Self {
        let mut this = Self {
            next_attempt_at: now,
            attempts: 0,
        };
        this.update(now, step, max);
        this
    }

    /// Records another failed attempt at `now` and moves `next_attempt_at`.
    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
        self.attempts += 1;

        // Attempt counts that do not fit in `u32` saturate, and the product
        // itself saturates too, so the delay never overflows before capping.
        let multiplier = u32::try_from(self.attempts).unwrap_or(u32::MAX);
        let delay = step.saturating_mul(multiplier).min(max);

        self.next_attempt_at = now + delay;
    }
}
896
/// Cheaply clonable handle to the set of currently connected peers.
///
/// All clones share the same set; mutations emit [`PeerEvent`]s to subscribers.
#[derive(Clone)]
pub struct ActivePeers(Arc<ActivePeersInner>);

impl ActivePeers {
    /// Creates an empty set; `channel_size` is the capacity of the peer
    /// events broadcast channel.
    pub fn new(channel_size: usize) -> Self {
        Self(Arc::new(ActivePeersInner::new(channel_size)))
    }

    /// Returns a clone of the active connection to `peer_id`, if any.
    pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
        self.0.get(peer_id)
    }

    /// Returns `true` if there is an active connection to `peer_id`.
    pub fn contains(&self, peer_id: &PeerId) -> bool {
        self.0.contains(peer_id)
    }

    /// Registers `new_connection`, resolving an existing connection to the
    /// same peer with the simultaneous-dial tie-break.
    ///
    /// Returns the connection if it was kept, or `None` if it was closed in
    /// favor of the existing one.
    pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> Option<Connection> {
        self.0.add(local_id, new_connection)
    }

    /// Removes and closes the connection to `peer_id` (if any) and emits
    /// `PeerEvent::LostPeer` with `reason`.
    pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
        self.0.remove(peer_id, reason);
    }

    /// Like [`ActivePeers::remove`], but only if the current connection has
    /// the given `stable_id`. Presumably this lets the handler of a replaced
    /// connection avoid removing its replacement.
    pub fn remove_with_stable_id(
        &self,
        peer_id: &PeerId,
        stable_id: usize,
        reason: DisconnectReason,
    ) {
        self.0.remove_with_stable_id(peer_id, stable_id, reason);
    }

    /// Subscribes to peer connect/disconnect events.
    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
        self.0.subscribe()
    }

    /// Returns `true` if no peers are connected.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the number of connected peers.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Creates a weak handle that does not keep the set alive.
    pub fn downgrade(this: &Self) -> WeakActivePeers {
        WeakActivePeers(Arc::downgrade(&this.0))
    }
}

/// Weak counterpart of [`ActivePeers`].
#[derive(Clone)]
pub struct WeakActivePeers(Weak<ActivePeersInner>);

impl WeakActivePeers {
    /// Returns a strong handle if the set still exists.
    pub fn upgrade(&self) -> Option<ActivePeers> {
        self.0.upgrade().map(ActivePeers)
    }
}
955
956struct ActivePeersInner {
957 connections: FastDashMap<PeerId, Connection>,
958 connections_len: AtomicUsize,
959 events_tx: broadcast::Sender<PeerEvent>,
960}
961
962impl ActivePeersInner {
963 fn new(channel_size: usize) -> Self {
964 let (events_tx, _) = broadcast::channel(channel_size);
965 Self {
966 connections: Default::default(),
967 connections_len: Default::default(),
968 events_tx,
969 }
970 }
971
972 fn get(&self, peer_id: &PeerId) -> Option<Connection> {
973 self.connections
974 .get(peer_id)
975 .map(|item| item.value().clone())
976 }
977
978 fn contains(&self, peer_id: &PeerId) -> bool {
979 self.connections.contains_key(peer_id)
980 }
981
982 #[must_use]
983 fn add(&self, local_id: &PeerId, new_connection: Connection) -> Option<Connection> {
984 use dashmap::mapref::entry::Entry;
985
986 let mut added = false;
987
988 let peer_id = new_connection.peer_id();
989 match self.connections.entry(*peer_id) {
990 Entry::Occupied(mut entry) => {
991 if simultaneous_dial_tie_breaking(
992 local_id,
993 peer_id,
994 entry.get().origin(),
995 new_connection.origin(),
996 ) {
997 tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
998 let old_connection = entry.insert(new_connection.clone());
999 old_connection.close();
1000 self.send_event(PeerEvent::LostPeer(*peer_id, DisconnectReason::Requested));
1001 } else {
1002 tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
1003 new_connection.close();
1004 return None;
1005 }
1006 }
1007 Entry::Vacant(entry) => {
1008 self.connections_len.fetch_add(1, Ordering::Release);
1009 entry.insert(new_connection.clone());
1010 added = true;
1011 }
1012 }
1013
1014 self.send_event(PeerEvent::NewPeer(*peer_id));
1015
1016 if added {
1017 metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
1018 }
1019 Some(new_connection)
1020 }
1021
1022 fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1023 if let Some((_, connection)) = self.connections.remove(peer_id) {
1024 connection.close();
1025 self.connections_len.fetch_sub(1, Ordering::Release);
1026 self.send_event(PeerEvent::LostPeer(*peer_id, reason));
1027
1028 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1029 }
1030 }
1031
1032 fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
1033 if let Some((_, connection)) = self
1034 .connections
1035 .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
1036 {
1037 connection.close();
1038 self.connections_len.fetch_sub(1, Ordering::Release);
1039 self.send_event(PeerEvent::LostPeer(*peer_id, reason));
1040
1041 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1042 }
1043 }
1044
1045 fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1046 self.events_tx.subscribe()
1047 }
1048
1049 fn send_event(&self, event: PeerEvent) {
1050 _ = self.events_tx.send(event);
1051 }
1052
1053 fn is_empty(&self) -> bool {
1054 self.connections.is_empty()
1055 }
1056
1057 fn len(&self) -> usize {
1058 self.connections_len.load(Ordering::Acquire)
1059 }
1060}
1061
1062fn simultaneous_dial_tie_breaking(
1063 local_id: &PeerId,
1064 peer_id: &PeerId,
1065 old_origin: Direction,
1066 new_origin: Direction,
1067) -> bool {
1068 match (old_origin, new_origin) {
1069 (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1070 true
1071 }
1072 (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1073 (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1074 }
1075}
1076
1077#[derive(Default, Clone)]
1078#[repr(transparent)]
1079pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1080
1081impl KnownPeers {
1082 pub fn new() -> Self {
1083 Self::default()
1084 }
1085
1086 pub fn contains(&self, peer_id: &PeerId) -> bool {
1087 self.0.contains_key(peer_id)
1088 }
1089
1090 pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1091 self.0
1092 .get(peer_id)
1093 .and_then(|item| {
1094 Some(match item.value() {
1095 KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1096 KnownPeerState::Banned => true,
1097 })
1098 })
1099 .unwrap_or_default()
1100 }
1101
1102 pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1103 self.0.get(peer_id).and_then(|item| match item.value() {
1104 KnownPeerState::Stored(item) => {
1105 let inner = item.upgrade()?;
1106 Some(inner.peer_info.load_full())
1107 }
1108 KnownPeerState::Banned => None,
1109 })
1110 }
1111
1112 pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1113 self.0
1114 .get(peer_id)
1115 .and_then(|item| item.value().compute_affinity())
1116 }
1117
1118 pub fn remove(&self, peer_id: &PeerId) {
1119 self.0.remove(peer_id);
1120 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1121 }
1122
1123 pub fn ban(&self, peer_id: &PeerId) {
1124 let mut added = false;
1125 match self.0.entry(*peer_id) {
1126 dashmap::mapref::entry::Entry::Vacant(entry) => {
1127 entry.insert(KnownPeerState::Banned);
1128 added = true;
1129 }
1130 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1131 KnownPeerState::Banned => {}
1132 KnownPeerState::Stored(item) => match item.upgrade() {
1133 Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1134 None => *entry.get_mut() = KnownPeerState::Banned,
1135 },
1136 },
1137 }
1138
1139 if added {
1140 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1142 }
1143 }
1144
1145 pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1146 let inner = match self.0.get(peer_id)?.value() {
1147 KnownPeerState::Stored(item) => {
1148 let inner = item.upgrade()?;
1149 if with_affinity && !inner.increase_affinity() {
1150 return None;
1151 }
1152 inner
1153 }
1154 KnownPeerState::Banned => return None,
1155 };
1156
1157 Some(KnownPeerHandle::from_inner(inner, with_affinity))
1158 }
1159
1160 pub fn insert(
1163 &self,
1164 peer_info: Arc<PeerInfo>,
1165 with_affinity: bool,
1166 ) -> Result<KnownPeerHandle, KnownPeersError> {
1167 let mut added = false;
1169 let inner = match self.0.entry(peer_info.id) {
1170 dashmap::mapref::entry::Entry::Vacant(entry) => {
1171 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1172 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1173 added = true;
1174 inner
1175 }
1176 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1177 KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1178 KnownPeerState::Stored(item) => match item.upgrade() {
1179 Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1180 true => inner,
1181 false => return Err(KnownPeersError::OutdatedInfo),
1182 },
1183 None => {
1184 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1185 *item = Arc::downgrade(&inner);
1186 inner
1187 }
1188 },
1189 },
1190 };
1191
1192 if added {
1193 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1194 }
1195
1196 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1197 }
1198
1199 pub fn insert_allow_outdated(
1201 &self,
1202 peer_info: Arc<PeerInfo>,
1203 with_affinity: bool,
1204 ) -> Result<KnownPeerHandle, PeerBannedError> {
1205 let mut added = false;
1207 let inner = match self.0.entry(peer_info.id) {
1208 dashmap::mapref::entry::Entry::Vacant(entry) => {
1209 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1210 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1211 added = true;
1212 inner
1213 }
1214 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1215 KnownPeerState::Banned => return Err(PeerBannedError),
1216 KnownPeerState::Stored(item) => match item.upgrade() {
1217 Some(inner) => {
1218 inner.try_update_peer_info(&peer_info, with_affinity)?;
1220 inner
1221 }
1222 None => {
1223 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1224 *item = Arc::downgrade(&inner);
1225 inner
1226 }
1227 },
1228 },
1229 };
1230
1231 if added {
1232 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1233 }
1234
1235 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1236 }
1237}
1238
1239enum KnownPeerState {
1240 Stored(Weak<KnownPeerInner>),
1241 Banned,
1242}
1243
1244impl KnownPeerState {
1245 fn compute_affinity(&self) -> Option<PeerAffinity> {
1246 Some(match self {
1247 Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1248 Self::Banned => PeerAffinity::Never,
1249 })
1250 }
1251}
1252
1253#[derive(Clone)]
1254#[repr(transparent)]
1255pub struct KnownPeerHandle(KnownPeerHandleState);
1256
1257impl KnownPeerHandle {
1258 fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
1259 KnownPeerHandle(if with_affinity {
1260 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1261 KnownPeerHandleWithAffinity { inner },
1262 )))
1263 } else {
1264 KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
1265 })
1266 }
1267
1268 pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
1269 self.inner().peer_info.load()
1270 }
1271
1272 pub fn load_peer_info(&self) -> Arc<PeerInfo> {
1273 arc_swap::Guard::into_inner(self.peer_info())
1274 }
1275
1276 pub fn is_banned(&self) -> bool {
1277 self.inner().is_banned()
1278 }
1279
1280 pub fn max_affinity(&self) -> PeerAffinity {
1281 self.inner().compute_affinity()
1282 }
1283
1284 pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
1285 match self.inner().try_update_peer_info(peer_info, false) {
1286 Ok(true) => Ok(()),
1287 Ok(false) => Err(KnownPeersError::OutdatedInfo),
1288 Err(e) => Err(KnownPeersError::PeerBanned(e)),
1289 }
1290 }
1291
1292 pub fn ban(&self) -> bool {
1293 let inner = self.inner();
1294 inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
1295 }
1296
1297 pub fn increase_affinity(&mut self) -> bool {
1298 match &mut self.0 {
1299 KnownPeerHandleState::Simple(inner) => {
1300 inner.increase_affinity();
1302
1303 let inner = unsafe { ManuallyDrop::take(inner) };
1305
1306 let prev_state = std::mem::replace(
1309 &mut self.0,
1310 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1311 KnownPeerHandleWithAffinity { inner },
1312 ))),
1313 );
1314
1315 #[allow(clippy::mem_forget)]
1317 std::mem::forget(prev_state);
1318
1319 true
1320 }
1321 KnownPeerHandleState::WithAffinity(_) => false,
1322 }
1323 }
1324
1325 pub fn decrease_affinity(&mut self) -> bool {
1326 match &mut self.0 {
1327 KnownPeerHandleState::Simple(_) => false,
1328 KnownPeerHandleState::WithAffinity(inner) => {
1329 inner.inner.decrease_affinity();
1331
1332 let inner = unsafe { ManuallyDrop::take(inner) };
1334
1335 let inner = match Arc::try_unwrap(inner) {
1337 Ok(KnownPeerHandleWithAffinity { inner }) => inner,
1338 Err(inner) => inner.inner.clone(),
1339 };
1340
1341 let prev_state = std::mem::replace(
1344 &mut self.0,
1345 KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
1346 );
1347
1348 #[allow(clippy::mem_forget)]
1350 std::mem::forget(prev_state);
1351
1352 true
1353 }
1354 }
1355 }
1356
1357 pub fn downgrade(&self) -> WeakKnownPeerHandle {
1358 WeakKnownPeerHandle(match &self.0 {
1359 KnownPeerHandleState::Simple(data) => {
1360 WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
1361 }
1362 KnownPeerHandleState::WithAffinity(data) => {
1363 WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
1364 }
1365 })
1366 }
1367
1368 fn inner(&self) -> &KnownPeerInner {
1369 match &self.0 {
1370 KnownPeerHandleState::Simple(data) => data.as_ref(),
1371 KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
1372 }
1373 }
1374}
1375
1376#[derive(Clone)]
1377enum KnownPeerHandleState {
1378 Simple(ManuallyDrop<Arc<KnownPeerInner>>),
1379 WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
1380}
1381
1382impl Drop for KnownPeerHandleState {
1383 fn drop(&mut self) {
1384 let inner;
1385 let is_banned;
1386 match self {
1387 KnownPeerHandleState::Simple(data) => {
1388 inner = unsafe { ManuallyDrop::take(data) };
1390 is_banned = inner.is_banned();
1391 }
1392 KnownPeerHandleState::WithAffinity(data) => {
1393 match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1395 Some(data) => {
1396 inner = data.inner;
1397 is_banned = !inner.decrease_affinity() || inner.is_banned();
1398 }
1399 None => return,
1400 }
1401 }
1402 };
1403
1404 if is_banned {
1405 return;
1407 }
1408
1409 if let Some(inner) = Arc::into_inner(inner) {
1410 if let Some(peers) = inner.weak_known_peers.upgrade() {
1412 peers.remove(&inner.peer_info.load().id);
1413 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1414 }
1415 }
1416 }
1417}
1418
1419#[derive(Clone, PartialEq, Eq)]
1420#[repr(transparent)]
1421pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1422
1423impl WeakKnownPeerHandle {
1424 pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1425 Some(KnownPeerHandle(match &self.0 {
1426 WeakKnownPeerHandleState::Simple(weak) => {
1427 KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1428 }
1429 WeakKnownPeerHandleState::WithAffinity(weak) => {
1430 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1431 }
1432 }))
1433 }
1434}
1435
1436#[derive(Clone)]
1437enum WeakKnownPeerHandleState {
1438 Simple(Weak<KnownPeerInner>),
1439 WithAffinity(Weak<KnownPeerHandleWithAffinity>),
1440}
1441
1442impl Eq for WeakKnownPeerHandleState {}
1443impl PartialEq for WeakKnownPeerHandleState {
1444 #[inline]
1445 fn eq(&self, other: &Self) -> bool {
1446 match (self, other) {
1447 (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1448 (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1449 _ => false,
1450 }
1451 }
1452}
1453
/// Shared wrapper behind affinity-holding handles.
///
/// All clones of a `WithAffinity` handle share one instance, which accounts for a single
/// affinity increment on `inner`. Dropping its last clone releases that increment (see
/// `Drop for KnownPeerHandleState`).
struct KnownPeerHandleWithAffinity {
    inner: Arc<KnownPeerInner>,
}

/// Shared state of a known peer, referenced strongly by handles and weakly by the
/// known-peers map.
struct KnownPeerInner {
    /// Latest known peer info; replaced atomically by `try_update_peer_info`.
    peer_info: ArcSwap<PeerInfo>,
    /// Number of affinity increments, or `AFFINITY_BANNED` once the peer is banned.
    affinity: AtomicUsize,
    /// Back-reference to the map, used to remove the entry when the last handle is dropped.
    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
}
1463
1464impl KnownPeerInner {
1465 fn new(
1466 peer_info: Arc<PeerInfo>,
1467 with_affinity: bool,
1468 known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
1469 ) -> Arc<Self> {
1470 Arc::new(Self {
1471 peer_info: ArcSwap::from(peer_info),
1472 affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
1473 weak_known_peers: Arc::downgrade(known_peers),
1474 })
1475 }
1476
1477 fn is_banned(&self) -> bool {
1478 self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
1479 }
1480
1481 fn compute_affinity(&self) -> PeerAffinity {
1482 match self.affinity.load(Ordering::Acquire) {
1483 0 => PeerAffinity::Allowed,
1484 AFFINITY_BANNED => PeerAffinity::Never,
1485 _ => PeerAffinity::High,
1486 }
1487 }
1488
1489 fn increase_affinity(&self) -> bool {
1490 let mut current = self.affinity.load(Ordering::Acquire);
1491 while current != AFFINITY_BANNED {
1492 debug_assert_ne!(current, AFFINITY_BANNED - 1);
1493 match self.affinity.compare_exchange_weak(
1494 current,
1495 current + 1,
1496 Ordering::Release,
1497 Ordering::Acquire,
1498 ) {
1499 Ok(_) => return true,
1500 Err(affinity) => current = affinity,
1501 }
1502 }
1503
1504 false
1505 }
1506
1507 fn decrease_affinity(&self) -> bool {
1508 let mut current = self.affinity.load(Ordering::Acquire);
1509 while current != AFFINITY_BANNED {
1510 debug_assert_ne!(current, 0);
1511 match self.affinity.compare_exchange_weak(
1512 current,
1513 current - 1,
1514 Ordering::Release,
1515 Ordering::Acquire,
1516 ) {
1517 Ok(_) => return true,
1518 Err(affinity) => current = affinity,
1519 }
1520 }
1521
1522 false
1523 }
1524
1525 fn try_update_peer_info(
1526 &self,
1527 peer_info: &Arc<PeerInfo>,
1528 with_affinity: bool,
1529 ) -> Result<bool, PeerBannedError> {
1530 struct AffinityGuard<'a> {
1531 inner: &'a KnownPeerInner,
1532 decrease_on_drop: bool,
1533 }
1534
1535 impl AffinityGuard<'_> {
1536 fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
1537 let with_affinity = with_affinity && !self.decrease_on_drop;
1538 let is_banned = if with_affinity {
1539 !self.inner.increase_affinity()
1540 } else {
1541 self.inner.is_banned()
1542 };
1543
1544 if !is_banned && with_affinity {
1545 self.decrease_on_drop = true;
1546 }
1547
1548 is_banned
1549 }
1550 }
1551
1552 impl Drop for AffinityGuard<'_> {
1553 fn drop(&mut self) {
1554 if self.decrease_on_drop {
1555 self.inner.decrease_affinity();
1556 }
1557 }
1558 }
1559
1560 let mut guard = AffinityGuard {
1562 inner: self,
1563 decrease_on_drop: false,
1564 };
1565
1566 let mut cur = self.peer_info.load();
1567 let updated = loop {
1568 if guard.increase_affinity_or_check_ban(with_affinity) {
1569 return Err(PeerBannedError);
1571 }
1572
1573 match cur.created_at.cmp(&peer_info.created_at) {
1574 std::cmp::Ordering::Equal => break true,
1577 std::cmp::Ordering::Less => {
1579 let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
1580 if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
1581 break true;
1582 } else {
1583 cur = prev;
1584 }
1585 }
1586 std::cmp::Ordering::Greater => break false,
1588 }
1589 };
1590
1591 guard.decrease_on_drop = false;
1592 Ok(updated)
1593 }
1594}
1595
/// Sentinel value of `KnownPeerInner::affinity` that marks a banned peer.
///
/// The increment/decrement helpers leave this value untouched.
const AFFINITY_BANNED: usize = usize::MAX;

/// Errors returned when inserting or updating known peer info.
#[derive(Debug, thiserror::Error)]
pub enum KnownPeersError {
    /// The peer is banned.
    #[error(transparent)]
    PeerBanned(#[from] PeerBannedError),
    /// The stored info has a newer `created_at` than the provided one.
    #[error("provided peer info is outdated")]
    OutdatedInfo,
}

/// The peer is banned.
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("peer is banned")]
pub struct PeerBannedError;
1609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    /// The entry lives as long as at least one handle does. After the last handle is
    /// dropped it is removed, and the peer can be inserted again.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        // Re-inserting the same info yields a second handle to the same state.
        let other_handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(other_handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(other_handle.max_affinity(), PeerAffinity::Allowed);

        // One handle is still alive: the entry must remain.
        drop(other_handle);
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        // Last handle gone: the entry is removed.
        drop(handle);
        assert!(!peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), None);
        assert_eq!(peers.get_affinity(&peer_info.id), None);

        peers.insert(peer_info.clone(), false).unwrap();
    }

    /// An affinity handle added after a simple one raises the shared affinity to `High`.
    /// Dropping it lowers the affinity back to `Allowed`.
    #[test]
    fn with_affinity_after_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);

        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_with_affinity);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        drop(handle_simple);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    /// A simple handle added after an affinity handle does not lower the affinity.
    /// Dropping it keeps `High` while the affinity handle lives.
    #[test]
    fn with_affinity_before_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);

        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_simple);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));

        drop(handle_with_affinity);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    /// Banning through the map while a handle is alive is visible through that handle.
    /// The peer info remains retrievable.
    #[test]
    fn ban_while_handle_exists() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        peers.ban(&peer_info.id);
        assert!(peers.contains(&peer_info.id));
        assert!(peers.is_banned(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Never);
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::Never));
    }
}