1use std::collections::{VecDeque, hash_map};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{DelayQueue, delay_queue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::ConnectionError;
15use crate::network::config::NetworkConfig;
16use crate::network::connection::Connection;
17use crate::network::endpoint::{Connecting, ConnectionInitError, Endpoint, Into0RttResult};
18use crate::network::request_handler::InboundRequestHandler;
19use crate::network::wire::{HandshakeError, handshake};
20use crate::types::{
21 Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
22 PeerInfo, Response, ServiceRequest,
23};
24
// Histograms: wall-clock time of a single connection attempt.

/// Duration of an outbound dial task, recorded whether it succeeded or not.
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
/// Duration of an inbound handshake task, recorded whether it succeeded or not.
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// Counters: connection attempts by direction and outcome.

/// Incremented for every new outbound connection registered as an active peer.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
/// Incremented for every new inbound connection registered as an active peer.
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
/// Incremented for every failed outbound connection attempt.
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
/// Incremented for every failed inbound connection attempt.
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// Gauges: sizes of the task sets owned by the connection manager.

/// Number of running per-connection request handler tasks.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
/// Number of running connecting (dial / handshake) tasks.
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
/// Number of incoming connections still waiting for the remote identity.
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
/// Number of dials started by the periodic connectivity check that have not resolved yet.
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

// Gauges: sizes of the peer registries.

/// Number of entries in the active peers map.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
/// Number of entries in the known peers map.
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
43
44#[derive(Debug)]
45pub(crate) enum ConnectionManagerRequest {
46 Connect(Address, PeerId, CallbackTx),
47 Shutdown(oneshot::Sender<()>),
48}
49
50pub(crate) struct ConnectionManager {
51 config: Arc<NetworkConfig>,
52 endpoint: Arc<Endpoint>,
53
54 mailbox: mpsc::Receiver<ConnectionManagerRequest>,
55
56 pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
57 pending_partial_connections: JoinSet<Option<PartialConnection>>,
58 pending_connections: JoinSet<ConnectingOutput>,
59 connection_handlers: JoinSet<()>,
60 delayed_callbacks: DelayedCallbacksQueue,
61
62 pending_dials: FastHashMap<PeerId, CallbackRx>,
63 dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,
64
65 active_peers: ActivePeers,
66 known_peers: KnownPeers,
67
68 service: BoxCloneService<ServiceRequest, Response>,
69}
70
71type CallbackTx = oneshot::Sender<Result<Connection, ConnectionError>>;
72type CallbackRx = oneshot::Receiver<Result<Connection, ConnectionError>>;
73
74impl Drop for ConnectionManager {
75 fn drop(&mut self) {
76 tracing::trace!("dropping connection manager");
77 self.endpoint.close();
78 }
79}
80
81impl ConnectionManager {
82 pub fn new(
83 config: Arc<NetworkConfig>,
84 endpoint: Arc<Endpoint>,
85 active_peers: ActivePeers,
86 known_peers: KnownPeers,
87 service: BoxCloneService<ServiceRequest, Response>,
88 ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
89 let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
90 let connection_manager = Self {
91 config,
92 endpoint,
93 mailbox,
94 pending_connection_callbacks: Default::default(),
95 pending_partial_connections: Default::default(),
96 pending_connections: Default::default(),
97 connection_handlers: Default::default(),
98 delayed_callbacks: Default::default(),
99 pending_dials: Default::default(),
100 dial_backoff_states: Default::default(),
101 active_peers,
102 known_peers,
103 service,
104 };
105 (connection_manager, mailbox_tx)
106 }
107
108 pub async fn start(mut self) {
109 tracing::info!("connection manager started");
110
111 let jitter = Duration::from_millis(1000).mul_f64(rand::random());
112 let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);
113
114 let mut shutdown_notifier = None;
115
116 loop {
117 tokio::select! {
118 now = interval.tick() => {
119 self.handle_connectivity_check(now.into_std());
120 }
121 maybe_request = self.mailbox.recv() => {
122 let Some(request) = maybe_request else {
123 break;
124 };
125
126 match request {
127 ConnectionManagerRequest::Connect(address, peer_id, callback) => {
128 self.handle_connect_request(address, &peer_id, callback);
129 }
130 ConnectionManagerRequest::Shutdown(oneshot) => {
131 shutdown_notifier = Some(oneshot);
132 break;
133 }
134 }
135 }
136 connecting = self.endpoint.accept() => {
137 if let Some(connecting) = connecting {
138 self.handle_incoming(connecting);
139 }
140 }
141 Some(connecting_output) = self.pending_connections.join_next() => {
142 metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
143 match connecting_output {
144 Ok(connecting) => self.handle_connecting_result(connecting),
145 Err(e) => {
146 if e.is_panic() {
147 std::panic::resume_unwind(e.into_panic());
148 }
149 }
150 }
151 }
152 Some(partial_connection) = self.pending_partial_connections.join_next() => {
153 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);
154
155 match partial_connection {
156 Ok(None) => {}
157 Ok(Some(PartialConnection {
158 connection,
159 timeout_at,
160 })) => self.handle_incoming_impl(connection, None, timeout_at),
161 Err(e) => {
163 if e.is_panic() {
164 std::panic::resume_unwind(e.into_panic());
165 }
166 }
167 }
168 }
169 Some(connection_handler_output) = self.connection_handlers.join_next() => {
170 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
171
172 if let Err(e) = connection_handler_output && e.is_panic() {
174 std::panic::resume_unwind(e.into_panic());
175 }
176 }
177 Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
178 self.delayed_callbacks.execute_expired(&peer_id);
179 }
180 }
181 }
182
183 self.shutdown().await;
184
185 if let Some(tx) = shutdown_notifier {
186 _ = tx.send(());
187 }
188
189 tracing::info!("connection manager stopped");
190 }
191
192 async fn shutdown(mut self) {
193 tracing::trace!("shutting down connection manager");
194
195 self.endpoint.close();
196
197 self.pending_partial_connections.shutdown().await;
198 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);
199
200 self.pending_connections.shutdown().await;
201 metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);
202
203 while self.connection_handlers.join_next().await.is_some() {
204 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
205 }
206 assert!(self.active_peers.is_empty());
207
208 self.endpoint
209 .wait_idle(self.config.shutdown_idle_timeout)
210 .await;
211 }
212
213 fn handle_connectivity_check(&mut self, now: Instant) {
214 use std::collections::hash_map::Entry;
215
216 self.pending_dials
217 .retain(|peer_id, oneshot| match oneshot.try_recv() {
218 Ok(Ok(returned_peer_id)) => {
219 debug_assert_eq!(peer_id, returned_peer_id.peer_id());
220 self.dial_backoff_states.remove(peer_id);
221 false
222 }
223 Ok(Err(_)) => {
224 match self.dial_backoff_states.entry(*peer_id) {
225 Entry::Occupied(mut entry) => entry.get_mut().update(
226 now,
227 self.config.connection_backoff,
228 self.config.max_connection_backoff,
229 ),
230 Entry::Vacant(entry) => {
231 entry.insert(DialBackoffState::new(
232 now,
233 self.config.connection_backoff,
234 self.config.max_connection_backoff,
235 ));
236 }
237 }
238 false
239 }
240 Err(oneshot::error::TryRecvError::Closed) => {
241 panic!("BUG: connection manager never finished dialing a peer");
242 }
243 Err(oneshot::error::TryRecvError::Empty) => true,
244 });
245
246 let outstanding_connections_limit = self
247 .config
248 .max_concurrent_outstanding_connections
249 .saturating_sub(self.pending_connections.len());
250
251 let outstanding_connections = self
252 .known_peers
253 .0
254 .iter()
255 .filter_map(|item| {
256 let value = match item.value() {
257 KnownPeerState::Stored(item) => item.upgrade()?,
258 KnownPeerState::Banned => return None,
259 };
260 let peer_info = value.peer_info.load();
261 let affinity = value.compute_affinity();
262
263 (affinity == PeerAffinity::High
264 && peer_info.id != self.endpoint.peer_id()
265 && !self.active_peers.contains(&peer_info.id)
266 && !self.pending_dials.contains_key(&peer_info.id)
267 && self
268 .dial_backoff_states
269 .get(&peer_info.id)
270 .is_none_or(|state| now > state.next_attempt_at))
271 .then(|| arc_swap::Guard::into_inner(peer_info))
272 })
273 .take(outstanding_connections_limit)
274 .collect::<Vec<_>>();
275
276 for peer_info in outstanding_connections {
277 let address = peer_info
279 .iter_addresses()
280 .next()
281 .cloned()
282 .expect("address list must have at least one item");
283
284 let (tx, rx) = oneshot::channel();
285 self.dial_peer(address, &peer_info.id, tx);
286 self.pending_dials.insert(peer_info.id, rx);
287 }
288
289 metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
290 }
291
292 fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
293 self.dial_peer(address, peer_id, callback);
294 }
295
296 fn handle_incoming(&mut self, connecting: Connecting) {
297 let remote_addr = connecting.remote_address();
298 tracing::trace!(
299 local_id = %self.endpoint.peer_id(),
300 %remote_addr,
301 "received an incoming connection",
302 );
303
304 match connecting.into_0rtt() {
306 Into0RttResult::Established(connection, accepted) => {
307 let timeout_at = Instant::now() + self.config.connect_timeout;
308 self.handle_incoming_impl(connection, Some(accepted), timeout_at);
309 }
310 Into0RttResult::WithoutIdentity(partial_connection) => {
311 tracing::trace!("connection identity is not available yet");
312
313 let timeout_at = Instant::now() + self.config.connect_timeout;
314 self.pending_partial_connections.spawn(async move {
315 match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
316 Ok(Ok(connection)) => Some(PartialConnection {
317 connection,
318 timeout_at,
319 }),
320 Ok(Err(e)) => {
321 tracing::trace!(
322 %remote_addr,
323 "failed to establish an incoming connection: {e}",
324 );
325 None
326 }
327 Err(_) => {
328 tracing::trace!(
329 %remote_addr,
330 "incoming connection timed out",
331 );
332 None
333 }
334 }
335 });
336 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
337 }
338 Into0RttResult::InvalidCertificate => {
339 tracing::trace!(%remote_addr, "invalid incoming connection");
340 }
341 Into0RttResult::Unavailable(_) => unreachable!(
342 "BUG: For incoming connections, a 0.5-RTT connection must \
343 always be successfully constructed."
344 ),
345 };
346 }
347
348 fn handle_incoming_impl(
349 &mut self,
350 connection: Connection,
351 accepted: Option<quinn::ZeroRttAccepted>,
352 timeout_at: Instant,
353 ) {
354 async fn handle_incoming_task(
355 seqno: u32,
356 connection: ConnectionClosedOnDrop,
357 accepted: Option<quinn::ZeroRttAccepted>,
358 timeout_at: Instant,
359 ) -> ConnectingOutput {
360 let target_peer_id = *connection.peer_id();
361 let target_address = connection.remote_address().into();
362 let fut = async {
363 if let Some(accepted) = accepted {
364 accepted.await;
366 }
367 handshake(&connection).await
368 };
369
370 let started_at = Instant::now();
371
372 let connecting_result = match tokio::time::timeout_at(timeout_at.into(), fut).await {
373 Ok(Ok(())) => Ok(connection.disarm()),
374 Ok(Err(e)) => Err(FullConnectionError::HandshakeFailed(e)),
375 Err(_) => Err(FullConnectionError::Timeout),
376 };
377
378 metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());
379
380 ConnectingOutput {
381 seqno,
382 drop_result: true,
383 connecting_result: ManuallyDrop::new(connecting_result),
384 target_address,
385 target_peer_id,
386 origin: Direction::Inbound,
387 }
388 }
389
390 let remote_addr = connection.remote_address();
391
392 match self.known_peers.get_affinity(connection.peer_id()) {
394 Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
395 Some(PeerAffinity::Never) => {
396 tracing::warn!(
398 %remote_addr,
399 peer_id = %connection.peer_id(),
400 "rejecting connection due to PeerAffinity::Never",
401 );
402 connection.close();
403 return;
404 }
405 _ => {
406 if matches!(
407 self.config.max_concurrent_connections,
408 Some(limit) if self.active_peers.len() >= limit
409 ) {
410 tracing::warn!(
412 %remote_addr,
413 peer_id = %connection.peer_id(),
414 "rejecting connection due too many concurrent connections",
415 );
416 connection.close();
417 return;
418 }
419 }
420 }
421
422 let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
423 hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
424 last_seqno: 0,
425 origin: Direction::Inbound,
426 callbacks: Default::default(),
427 abort_handle: None,
428 })),
429 hash_map::Entry::Occupied(entry) => {
430 let entry = entry.into_mut();
431
432 if simultaneous_dial_tie_breaking(
434 self.endpoint.peer_id(),
435 connection.peer_id(),
436 entry.origin,
437 Direction::Inbound,
438 ) {
439 tracing::debug!(
441 %remote_addr,
442 peer_id = %connection.peer_id(),
443 "cancelling old connection to mitigate simultaneous dial",
444 );
445
446 entry.origin = Direction::Inbound;
447 entry.last_seqno += 1;
448 if let Some(handle) = entry.abort_handle.take() {
449 handle.abort();
450 }
451 Some(entry)
452 } else {
453 tracing::debug!(
455 %remote_addr,
456 peer_id = %connection.peer_id(),
457 "cancelling new connection to mitigate simultaneous dial",
458 );
459
460 connection.close();
461 None
462 }
463 }
464 };
465
466 if let Some(entry) = entry {
467 entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
468 entry.last_seqno,
469 ConnectionClosedOnDrop::new(connection),
470 accepted,
471 timeout_at,
472 )));
473 metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
474 }
475 }
476
477 fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
478 {
480 let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
481 tracing::trace!("connection task reordering detected");
482 return;
483 };
484
485 if entry.last_seqno != res.seqno {
486 tracing::debug!(
487 local_id = %self.endpoint.peer_id(),
488 peer_id = %res.target_peer_id,
489 remote_addr = %res.target_address,
490 "connection result is outdated"
491 );
492 return;
493 }
494 }
495
496 let callbacks = self
497 .pending_connection_callbacks
498 .remove(&res.target_address)
499 .expect("Connection tasks must be tracked")
500 .callbacks;
501
502 res.drop_result = false;
503 match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
505 Ok(connection) => {
506 tracing::debug!(
507 local_id = %self.endpoint.peer_id(),
508 peer_id = %connection.peer_id(),
509 remote_addr = %res.target_address,
510 "new connection",
511 );
512
513 let connection = self.add_peer(connection);
514 self.delayed_callbacks.execute_resolved(&connection);
515
516 for callback in callbacks {
517 _ = callback.send(Ok(connection.clone()));
518 }
519 }
520 Err(e) => {
521 tracing::debug!(
522 local_id = %self.endpoint.peer_id(),
523 peer_id = %res.target_peer_id,
524 remote_addr = %res.target_address,
525 "connection failed: {e:?}"
526 );
527
528 metrics::counter!(match res.origin {
529 Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
530 Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
531 })
532 .increment(1);
533
534 let brief_error = e.as_brief();
535
536 if e.was_closed() && !callbacks.is_empty() {
539 self.delayed_callbacks.push(
540 &res.target_peer_id,
541 brief_error,
542 callbacks,
543 &self.config.connection_error_delay,
544 );
545 return;
546 }
547
548 for callback in callbacks {
549 _ = callback.send(Err(brief_error));
550 }
551 }
552 }
553 }
554
555 fn add_peer(&mut self, connection: Connection) -> Connection {
556 match self.active_peers.add(self.endpoint.peer_id(), connection) {
557 AddedPeer::New(connection) => {
558 let origin = connection.origin();
559
560 let handler = InboundRequestHandler::new(
561 self.config.clone(),
562 connection.clone(),
563 self.service.clone(),
564 self.active_peers.clone(),
565 );
566
567 metrics::counter!(match origin {
568 Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
569 Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
570 })
571 .increment(1);
572
573 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
574 self.connection_handlers.spawn(handler.start());
575
576 connection
577 }
578 AddedPeer::Existing(connection) => connection,
579 }
580 }
581
582 #[tracing::instrument(
583 level = "trace",
584 skip_all,
585 fields(
586 local_id = %self.endpoint.peer_id(),
587 peer_id = %peer_id,
588 remote_addr = %address,
589 ),
590 )]
591 fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
592 async fn dial_peer_task(
593 seqno: u32,
594 endpoint: Arc<Endpoint>,
595 address: Address,
596 peer_id: PeerId,
597 config: Arc<NetworkConfig>,
598 ) -> ConnectingOutput {
599 let fut = async {
600 let address = address
601 .resolve()
602 .await
603 .map_err(FullConnectionError::InvalidAddress)?;
604
605 let connecting = endpoint
606 .connect_with_expected_id(&address, &peer_id)
607 .map_err(|e| FullConnectionError::InvalidAddress(std::io::Error::other(e)))?;
608
609 let connection = ConnectionClosedOnDrop::new(connecting.await?);
610 match handshake(&connection).await {
611 Ok(()) => Ok(connection),
612 Err(e) => Err(FullConnectionError::HandshakeFailed(e)),
613 }
614 };
615
616 let started_at = Instant::now();
617
618 let connecting_result = match tokio::time::timeout(config.connect_timeout, fut).await {
619 Ok(res) => res.map(ConnectionClosedOnDrop::disarm),
620 Err(_) => Err(FullConnectionError::Timeout),
621 };
622
623 metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());
624
625 ConnectingOutput {
626 seqno,
627 drop_result: true,
628 connecting_result: ManuallyDrop::new(connecting_result),
629 target_address: address,
630 target_peer_id: peer_id,
631 origin: Direction::Outbound,
632 }
633 }
634
635 if let Some(connection) = self.active_peers.get(peer_id) {
636 tracing::debug!("peer is already connected");
637 _ = callback.send(Ok(connection));
638 return;
639 }
640
641 tracing::trace!("connecting to peer");
642
643 let entry = match self.pending_connection_callbacks.entry(address.clone()) {
644 hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
645 last_seqno: 0,
646 origin: Direction::Outbound,
647 callbacks: vec![callback],
648 abort_handle: None,
649 })),
650 hash_map::Entry::Occupied(entry) => {
651 let entry = entry.into_mut();
652
653 entry.callbacks.push(callback);
655
656 let break_tie = simultaneous_dial_tie_breaking(
658 self.endpoint.peer_id(),
659 peer_id,
660 entry.origin,
661 Direction::Outbound,
662 );
663
664 if break_tie && entry.origin != Direction::Outbound {
665 tracing::debug!("cancelling old connection to mitigate simultaneous dial");
667
668 entry.origin = Direction::Outbound;
669 entry.last_seqno += 1;
670 if let Some(handle) = entry.abort_handle.take() {
671 handle.abort();
672 }
673 Some(entry)
674 } else {
675 tracing::trace!("reusing old connection to mitigate simultaneous dial");
677 None
678 }
679 }
680 };
681
682 if let Some(entry) = entry {
683 entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
684 entry.last_seqno,
685 self.endpoint.clone(),
686 address.clone(),
687 *peer_id,
688 self.config.clone(),
689 )));
690 metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
691 }
692 }
693}
694
695struct PendingConnectionCallbacks {
696 last_seqno: u32,
697 origin: Direction,
698 callbacks: Vec<CallbackTx>,
699 abort_handle: Option<AbortHandle>,
700}
701
702struct PartialConnection {
703 connection: Connection,
704 timeout_at: Instant,
705}
706
707struct ConnectingOutput {
708 seqno: u32,
709 drop_result: bool,
710 connecting_result: ManuallyDrop<Result<Connection, FullConnectionError>>,
711 target_address: Address,
712 target_peer_id: PeerId,
713 origin: Direction,
714}
715
716impl Drop for ConnectingOutput {
717 fn drop(&mut self) {
718 if self.drop_result {
719 unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
721 }
722 }
723}
724
725struct ConnectionClosedOnDrop {
726 connection: ManuallyDrop<Connection>,
727 close_on_drop: bool,
728}
729
730impl ConnectionClosedOnDrop {
731 fn new(connection: Connection) -> Self {
732 Self {
733 connection: ManuallyDrop::new(connection),
734 close_on_drop: true,
735 }
736 }
737
738 fn disarm(mut self) -> Connection {
739 self.close_on_drop = false;
740 unsafe { ManuallyDrop::take(&mut self.connection) }
742 }
743}
744
745impl std::ops::Deref for ConnectionClosedOnDrop {
746 type Target = Connection;
747
748 #[inline]
749 fn deref(&self) -> &Self::Target {
750 &self.connection
751 }
752}
753
754impl Drop for ConnectionClosedOnDrop {
755 fn drop(&mut self) {
756 if self.close_on_drop {
757 let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
759 connection.close();
760 }
761 }
762}
763
764#[derive(Default)]
765struct DelayedCallbacksQueue {
766 callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
767 expirations: DelayQueue<PeerId>,
768}
769
770impl DelayedCallbacksQueue {
771 fn push(
772 &mut self,
773 peer_id: &PeerId,
774 error: ConnectionError,
775 callbacks: Vec<CallbackTx>,
776 delay: &Duration,
777 ) {
778 tracing::debug!(%peer_id, %error, "delayed connection error");
779
780 let expires_at = Instant::now() + *delay;
781 let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
782
783 let items = self.callbacks.entry(*peer_id).or_default();
784 items.push_back(DelayedCallbacks {
785 delay_key,
786 error,
787 callbacks,
788 expires_at,
789 });
790 }
791
792 async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
793 let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
794 Some(res.into_inner())
795 }
796
797 fn execute_resolved(&mut self, connection: &Connection) {
798 let Some(items) = self.callbacks.remove(connection.peer_id()) else {
799 return;
800 };
801
802 let mut batches_executed = 0;
803 let mut callbacks_executed = 0;
804
805 for delayed in items {
806 batches_executed += 1;
807 callbacks_executed += delayed.callbacks.len();
808
809 let key = delayed.execute_with_ok(connection);
810
811 self.expirations.remove(&key);
813 }
814
815 tracing::debug!(
816 peer_id = %connection.peer_id(),
817 batches_executed,
818 callbacks_executed,
819 "executed all delayed callbacks",
820 );
821 }
822
823 fn execute_expired(&mut self, peer_id: &PeerId) {
824 let now = Instant::now();
825
826 let mut batches_executed = 0;
827 let mut callbacks_executed = 0;
828
829 'outer: {
830 if let Some(items) = self.callbacks.get_mut(peer_id) {
831 while let Some(front) = items.front() {
832 if !front.is_expired(&now) {
833 break 'outer;
834 }
835
836 if let Some(delayed) = items.pop_front() {
837 batches_executed += 1;
838 callbacks_executed += delayed.callbacks.len();
839
840 let key = delayed.execute_with_error();
841
842 self.expirations.try_remove(&key);
845 }
846 }
847 }
848
849 self.callbacks.remove(peer_id);
851 }
852
853 tracing::debug!(
854 %peer_id,
855 batches_executed,
856 callbacks_executed,
857 "executed expired delayed callbacks"
858 );
859 }
860}
861
862struct DelayedCallbacks {
863 delay_key: delay_queue::Key,
864 error: ConnectionError,
865 callbacks: Vec<CallbackTx>,
866 expires_at: Instant,
867}
868
869impl DelayedCallbacks {
870 fn execute_with_ok(self, connection: &Connection) -> delay_queue::Key {
871 for callback in self.callbacks {
872 _ = callback.send(Ok(connection.clone()));
873 }
874 self.delay_key
875 }
876
877 fn execute_with_error(self) -> delay_queue::Key {
878 for callback in self.callbacks {
879 _ = callback.send(Err(self.error));
880 }
881 self.delay_key
882 }
883
884 fn is_expired(&self, now: &Instant) -> bool {
885 *now >= self.expires_at
886 }
887}
888
/// Linear, capped retry schedule for dialing a single peer.
#[derive(Debug)]
struct DialBackoffState {
    /// Earliest moment at which the next dial may be attempted.
    next_attempt_at: Instant,
    /// Number of failed attempts recorded so far.
    attempts: usize,
}

impl DialBackoffState {
    /// Creates the state with the first failed attempt already recorded.
    fn new(now: Instant, step: Duration, max: Duration) -> Self {
        let mut fresh = Self {
            next_attempt_at: now,
            attempts: 0,
        };
        fresh.update(now, step, max);
        fresh
    }

    /// Records one more failed attempt and reschedules the next one.
    ///
    /// The delay is `step * attempts`, saturating on overflow and capped at `max`.
    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
        self.attempts += 1;

        // Attempt counts that do not fit into `u32` are clamped.
        let factor = u32::try_from(self.attempts).unwrap_or(u32::MAX);
        let delay = step.saturating_mul(factor).min(max);

        self.next_attempt_at = now + delay;
    }
}
914
915#[derive(Debug, thiserror::Error)]
916enum FullConnectionError {
917 #[error("invalid address")]
918 InvalidAddress(#[source] std::io::Error),
919 #[error(transparent)]
920 ConnectionFailed(quinn::ConnectionError),
921 #[error("invalid certificate")]
922 InvalidCertificate,
923 #[error("handshake failed")]
924 HandshakeFailed(#[source] HandshakeError),
925 #[error("connection timeout")]
926 Timeout,
927}
928
929impl FullConnectionError {
930 fn as_brief(&self) -> ConnectionError {
931 fn is_crypto_error(error: &quinn::ConnectionError) -> bool {
932 const QUIC_CRYPTO_FLAG: u64 = 0x100;
933
934 matches!(
935 error,
936 quinn::ConnectionError::TransportError(e)
937 if u64::from(e.code) & QUIC_CRYPTO_FLAG != 0
938 )
939 }
940
941 match self {
942 Self::InvalidAddress(_) => ConnectionError::InvalidAddress,
943 Self::ConnectionFailed(e) if is_crypto_error(e) => ConnectionError::InvalidCertificate,
944 Self::ConnectionFailed(_) => ConnectionError::ConnectionInitFailed,
945 Self::InvalidCertificate => ConnectionError::InvalidCertificate,
946 Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) if is_crypto_error(e) => {
947 ConnectionError::InvalidCertificate
948 }
949 Self::HandshakeFailed(_) => ConnectionError::HandshakeFailed,
950 Self::Timeout => ConnectionError::Timeout,
951 }
952 }
953
954 fn was_closed(&self) -> bool {
955 let connection_error = match self {
956 Self::ConnectionFailed(e)
957 | Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) => e,
958 _ => return false,
959 };
960
961 matches!(
962 connection_error,
963 quinn::ConnectionError::ApplicationClosed(closed)
964 if closed.error_code.into_inner() == 0
965 )
966 }
967}
968
969impl From<ConnectionInitError> for FullConnectionError {
970 fn from(value: ConnectionInitError) -> Self {
971 match value {
972 ConnectionInitError::ConnectionFailed(e) => Self::ConnectionFailed(e),
973 ConnectionInitError::InvalidCertificate => Self::InvalidCertificate,
974 }
975 }
976}
977
978#[derive(Clone)]
979pub(crate) struct ActivePeers(Arc<ActivePeersInner>);
980
981impl ActivePeers {
982 pub fn new(channel_size: usize) -> Self {
983 Self(Arc::new(ActivePeersInner::new(channel_size)))
984 }
985
986 pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
987 self.0.get(peer_id)
988 }
989
990 pub fn contains(&self, peer_id: &PeerId) -> bool {
991 self.0.contains(peer_id)
992 }
993
994 pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
995 self.0.add(local_id, new_connection)
996 }
997
998 pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
999 self.0.remove(peer_id, reason);
1000 }
1001
1002 pub fn remove_with_stable_id(
1003 &self,
1004 peer_id: &PeerId,
1005 stable_id: usize,
1006 reason: DisconnectReason,
1007 ) {
1008 self.0.remove_with_stable_id(peer_id, stable_id, reason);
1009 }
1010
1011 pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1012 self.0.subscribe()
1013 }
1014
1015 pub fn is_empty(&self) -> bool {
1016 self.0.is_empty()
1017 }
1018
1019 pub fn len(&self) -> usize {
1020 self.0.len()
1021 }
1022}
1023
1024struct ActivePeersInner {
1025 connections: FastDashMap<PeerId, Connection>,
1026 connections_len: AtomicUsize,
1027 events_tx: broadcast::Sender<PeerEvent>,
1028}
1029
1030impl ActivePeersInner {
1031 fn new(channel_size: usize) -> Self {
1032 let (events_tx, _) = broadcast::channel(channel_size);
1033 Self {
1034 connections: Default::default(),
1035 connections_len: Default::default(),
1036 events_tx,
1037 }
1038 }
1039
1040 fn get(&self, peer_id: &PeerId) -> Option<Connection> {
1041 self.connections
1042 .get(peer_id)
1043 .map(|item| item.value().clone())
1044 }
1045
1046 fn contains(&self, peer_id: &PeerId) -> bool {
1047 self.connections.contains_key(peer_id)
1048 }
1049
1050 #[must_use]
1051 fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
1052 use dashmap::mapref::entry::Entry;
1053
1054 let mut added = false;
1055
1056 let peer_id = new_connection.peer_id();
1057 match self.connections.entry(*peer_id) {
1058 Entry::Occupied(mut entry) => {
1059 if simultaneous_dial_tie_breaking(
1060 local_id,
1061 peer_id,
1062 entry.get().origin(),
1063 new_connection.origin(),
1064 ) {
1065 tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
1066 let old_connection = entry.insert(new_connection.clone());
1067 old_connection.close();
1068 self.send_event(PeerEvent::lost_peer(*peer_id, DisconnectReason::Requested));
1069 } else {
1070 tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
1071 new_connection.close();
1072 return AddedPeer::Existing(entry.get().clone());
1073 }
1074 }
1075 Entry::Vacant(entry) => {
1076 self.connections_len.fetch_add(1, Ordering::Release);
1077 entry.insert(new_connection.clone());
1078 added = true;
1079 }
1080 }
1081
1082 self.send_event(PeerEvent::new_peer(*peer_id));
1083
1084 if added {
1085 metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
1086 }
1087 AddedPeer::New(new_connection)
1088 }
1089
1090 fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1091 if let Some((_, connection)) = self.connections.remove(peer_id) {
1092 connection.close();
1093 self.connections_len.fetch_sub(1, Ordering::Release);
1094 self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1095
1096 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1097 }
1098 }
1099
1100 fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
1101 if let Some((_, connection)) = self
1102 .connections
1103 .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
1104 {
1105 connection.close();
1106 self.connections_len.fetch_sub(1, Ordering::Release);
1107 self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1108
1109 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1110 }
1111 }
1112
1113 fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1114 self.events_tx.subscribe()
1115 }
1116
1117 fn send_event(&self, event: PeerEvent) {
1118 _ = self.events_tx.send(event);
1119 }
1120
1121 fn is_empty(&self) -> bool {
1122 self.connections.is_empty()
1123 }
1124
1125 fn len(&self) -> usize {
1126 self.connections_len.load(Ordering::Acquire)
1127 }
1128}
1129
1130pub(crate) enum AddedPeer {
1131 New(Connection),
1132 Existing(Connection),
1133}
1134
1135fn simultaneous_dial_tie_breaking(
1136 local_id: &PeerId,
1137 peer_id: &PeerId,
1138 old_origin: Direction,
1139 new_origin: Direction,
1140) -> bool {
1141 match (old_origin, new_origin) {
1142 (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1143 true
1144 }
1145 (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1146 (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1147 }
1148}
1149
1150#[derive(Default, Clone)]
1151#[repr(transparent)]
1152pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1153
1154impl KnownPeers {
1155 pub fn new() -> Self {
1156 Self::default()
1157 }
1158
1159 pub fn contains(&self, peer_id: &PeerId) -> bool {
1160 self.0.contains_key(peer_id)
1161 }
1162
1163 pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1164 self.0
1165 .get(peer_id)
1166 .and_then(|item| {
1167 Some(match item.value() {
1168 KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1169 KnownPeerState::Banned => true,
1170 })
1171 })
1172 .unwrap_or_default()
1173 }
1174
1175 pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1176 self.0.get(peer_id).and_then(|item| match item.value() {
1177 KnownPeerState::Stored(item) => {
1178 let inner = item.upgrade()?;
1179 Some(inner.peer_info.load_full())
1180 }
1181 KnownPeerState::Banned => None,
1182 })
1183 }
1184
1185 pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1186 self.0
1187 .get(peer_id)
1188 .and_then(|item| item.value().compute_affinity())
1189 }
1190
1191 pub fn remove(&self, peer_id: &PeerId) {
1192 self.0.remove(peer_id);
1193 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1194 }
1195
1196 pub fn ban(&self, peer_id: &PeerId) {
1197 let mut added = false;
1198 match self.0.entry(*peer_id) {
1199 dashmap::mapref::entry::Entry::Vacant(entry) => {
1200 entry.insert(KnownPeerState::Banned);
1201 added = true;
1202 }
1203 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1204 KnownPeerState::Banned => {}
1205 KnownPeerState::Stored(item) => match item.upgrade() {
1206 Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1207 None => *entry.get_mut() = KnownPeerState::Banned,
1208 },
1209 },
1210 }
1211
1212 if added {
1213 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1215 }
1216 }
1217
1218 pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1219 let inner = match self.0.get(peer_id)?.value() {
1220 KnownPeerState::Stored(item) => {
1221 let inner = item.upgrade()?;
1222 if with_affinity && !inner.increase_affinity() {
1223 return None;
1224 }
1225 inner
1226 }
1227 KnownPeerState::Banned => return None,
1228 };
1229
1230 Some(KnownPeerHandle::from_inner(inner, with_affinity))
1231 }
1232
1233 pub fn insert(
1236 &self,
1237 peer_info: Arc<PeerInfo>,
1238 with_affinity: bool,
1239 ) -> Result<KnownPeerHandle, KnownPeersError> {
1240 let mut added = false;
1242 let inner = match self.0.entry(peer_info.id) {
1243 dashmap::mapref::entry::Entry::Vacant(entry) => {
1244 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1245 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1246 added = true;
1247 inner
1248 }
1249 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1250 KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1251 KnownPeerState::Stored(item) => match item.upgrade() {
1252 Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1253 true => inner,
1254 false => return Err(KnownPeersError::OutdatedInfo),
1255 },
1256 None => {
1257 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1258 *item = Arc::downgrade(&inner);
1259 inner
1260 }
1261 },
1262 },
1263 };
1264
1265 if added {
1266 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1267 }
1268
1269 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1270 }
1271
1272 pub fn insert_allow_outdated(
1274 &self,
1275 peer_info: Arc<PeerInfo>,
1276 with_affinity: bool,
1277 ) -> Result<KnownPeerHandle, PeerBannedError> {
1278 let mut added = false;
1280 let inner = match self.0.entry(peer_info.id) {
1281 dashmap::mapref::entry::Entry::Vacant(entry) => {
1282 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1283 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1284 added = true;
1285 inner
1286 }
1287 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1288 KnownPeerState::Banned => return Err(PeerBannedError),
1289 KnownPeerState::Stored(item) => match item.upgrade() {
1290 Some(inner) => {
1291 inner.try_update_peer_info(&peer_info, with_affinity)?;
1293 inner
1294 }
1295 None => {
1296 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1297 *item = Arc::downgrade(&inner);
1298 inner
1299 }
1300 },
1301 },
1302 };
1303
1304 if added {
1305 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1306 }
1307
1308 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1309 }
1310}
1311
/// Value stored in the [`KnownPeers`] map for a single peer.
enum KnownPeerState {
    /// Weak reference to the shared peer state; it stops being upgradable
    /// once every strong reference to that state is gone.
    Stored(Weak<KnownPeerInner>),
    /// The peer is banned and no peer state is kept for it.
    Banned,
}
1316
1317impl KnownPeerState {
1318 fn compute_affinity(&self) -> Option<PeerAffinity> {
1319 Some(match self {
1320 Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1321 Self::Banned => PeerAffinity::Never,
1322 })
1323 }
1324}
1325
/// Strong handle to a peer stored in [`KnownPeers`].
///
/// A handle is either "simple" or "with affinity"; see
/// `KnownPeerHandleState` for the two representations.
#[derive(Clone)]
#[repr(transparent)]
pub struct KnownPeerHandle(KnownPeerHandleState);
1329
1330impl KnownPeerHandle {
1331 fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
1332 KnownPeerHandle(if with_affinity {
1333 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1334 KnownPeerHandleWithAffinity { inner },
1335 )))
1336 } else {
1337 KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
1338 })
1339 }
1340
1341 pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
1342 self.inner().peer_info.load()
1343 }
1344
1345 pub fn load_peer_info(&self) -> Arc<PeerInfo> {
1346 arc_swap::Guard::into_inner(self.peer_info())
1347 }
1348
1349 pub fn is_banned(&self) -> bool {
1350 self.inner().is_banned()
1351 }
1352
1353 pub fn max_affinity(&self) -> PeerAffinity {
1354 self.inner().compute_affinity()
1355 }
1356
1357 pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
1358 match self.inner().try_update_peer_info(peer_info, false) {
1359 Ok(true) => Ok(()),
1360 Ok(false) => Err(KnownPeersError::OutdatedInfo),
1361 Err(e) => Err(KnownPeersError::PeerBanned(e)),
1362 }
1363 }
1364
1365 pub fn ban(&self) -> bool {
1366 let inner = self.inner();
1367 inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
1368 }
1369
1370 pub fn increase_affinity(&mut self) -> bool {
1371 match &mut self.0 {
1372 KnownPeerHandleState::Simple(inner) => {
1373 inner.increase_affinity();
1375
1376 let inner = unsafe { ManuallyDrop::take(inner) };
1378
1379 let prev_state = std::mem::replace(
1382 &mut self.0,
1383 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1384 KnownPeerHandleWithAffinity { inner },
1385 ))),
1386 );
1387
1388 #[allow(clippy::mem_forget)]
1390 std::mem::forget(prev_state);
1391
1392 true
1393 }
1394 KnownPeerHandleState::WithAffinity(_) => false,
1395 }
1396 }
1397
1398 pub fn decrease_affinity(&mut self) -> bool {
1399 match &mut self.0 {
1400 KnownPeerHandleState::Simple(_) => false,
1401 KnownPeerHandleState::WithAffinity(inner) => {
1402 inner.inner.decrease_affinity();
1404
1405 let inner = unsafe { ManuallyDrop::take(inner) };
1407
1408 let inner = match Arc::try_unwrap(inner) {
1410 Ok(KnownPeerHandleWithAffinity { inner }) => inner,
1411 Err(inner) => inner.inner.clone(),
1412 };
1413
1414 let prev_state = std::mem::replace(
1417 &mut self.0,
1418 KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
1419 );
1420
1421 #[allow(clippy::mem_forget)]
1423 std::mem::forget(prev_state);
1424
1425 true
1426 }
1427 }
1428 }
1429
1430 pub fn downgrade(&self) -> WeakKnownPeerHandle {
1431 WeakKnownPeerHandle(match &self.0 {
1432 KnownPeerHandleState::Simple(data) => {
1433 WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
1434 }
1435 KnownPeerHandleState::WithAffinity(data) => {
1436 WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
1437 }
1438 })
1439 }
1440
1441 fn inner(&self) -> &KnownPeerInner {
1442 match &self.0 {
1443 KnownPeerHandleState::Simple(data) => data.as_ref(),
1444 KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
1445 }
1446 }
1447}
1448
/// Internal representation of a [`KnownPeerHandle`].
///
/// The `Arc`s are wrapped in [`ManuallyDrop`] because the `Drop`
/// implementation of this type moves them out by hand.
#[derive(Clone)]
enum KnownPeerHandleState {
    /// Direct strong reference to the peer state.
    Simple(ManuallyDrop<Arc<KnownPeerInner>>),
    /// Strong reference to a wrapper around the peer state; clones of the
    /// handle share the same wrapper.
    WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
}
1454
1455impl Drop for KnownPeerHandleState {
1456 fn drop(&mut self) {
1457 let inner;
1458 let is_banned;
1459 match self {
1460 KnownPeerHandleState::Simple(data) => {
1461 inner = unsafe { ManuallyDrop::take(data) };
1463 is_banned = inner.is_banned();
1464 }
1465 KnownPeerHandleState::WithAffinity(data) => {
1466 match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1468 Some(data) => {
1469 inner = data.inner;
1470 is_banned = !inner.decrease_affinity() || inner.is_banned();
1471 }
1472 None => return,
1473 }
1474 }
1475 };
1476
1477 if is_banned {
1478 return;
1480 }
1481
1482 if let Some(inner) = Arc::into_inner(inner) {
1483 if let Some(peers) = inner.weak_known_peers.upgrade() {
1485 peers.remove(&inner.peer_info.load().id);
1486 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1487 }
1488 }
1489 }
1490}
1491
/// Weak counterpart of [`KnownPeerHandle`] that does not keep the peer state
/// alive.
///
/// Equality is pointer-based: two weak handles are equal when they are of the
/// same kind and point to the same allocation.
#[derive(Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1495
1496impl WeakKnownPeerHandle {
1497 pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1498 Some(KnownPeerHandle(match &self.0 {
1499 WeakKnownPeerHandleState::Simple(weak) => {
1500 KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1501 }
1502 WeakKnownPeerHandleState::WithAffinity(weak) => {
1503 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1504 }
1505 }))
1506 }
1507}
1508
/// Internal representation of a [`WeakKnownPeerHandle`], mirroring the two
/// variants of `KnownPeerHandleState` with weak references.
#[derive(Clone)]
enum WeakKnownPeerHandleState {
    /// Weak reference to the peer state itself.
    Simple(Weak<KnownPeerInner>),
    /// Weak reference to the wrapper used by handles with affinity.
    WithAffinity(Weak<KnownPeerHandleWithAffinity>),
}
1514
1515impl Eq for WeakKnownPeerHandleState {}
1516impl PartialEq for WeakKnownPeerHandleState {
1517 #[inline]
1518 fn eq(&self, other: &Self) -> bool {
1519 match (self, other) {
1520 (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1521 (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1522 _ => false,
1523 }
1524 }
1525}
1526
/// Wrapper around the peer state used by handles with affinity.
///
/// Handles share it through an `Arc`; the handle destructor decreases the
/// peer affinity when it obtains the wrapper as its last owner.
struct KnownPeerHandleWithAffinity {
    /// The shared peer state.
    inner: Arc<KnownPeerInner>,
}
1530
/// Shared state of a single known peer.
struct KnownPeerInner {
    /// Current peer info; swapped atomically when newer info arrives.
    peer_info: ArcSwap<PeerInfo>,
    /// Affinity counter; the value `AFFINITY_BANNED` marks a banned peer.
    affinity: AtomicUsize,
    /// Weak back-reference to the map this peer is registered in, used to
    /// remove the entry when the last handle is dropped.
    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
}
1536
1537impl KnownPeerInner {
1538 fn new(
1539 peer_info: Arc<PeerInfo>,
1540 with_affinity: bool,
1541 known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
1542 ) -> Arc<Self> {
1543 Arc::new(Self {
1544 peer_info: ArcSwap::from(peer_info),
1545 affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
1546 weak_known_peers: Arc::downgrade(known_peers),
1547 })
1548 }
1549
1550 fn is_banned(&self) -> bool {
1551 self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
1552 }
1553
1554 fn compute_affinity(&self) -> PeerAffinity {
1555 match self.affinity.load(Ordering::Acquire) {
1556 0 => PeerAffinity::Allowed,
1557 AFFINITY_BANNED => PeerAffinity::Never,
1558 _ => PeerAffinity::High,
1559 }
1560 }
1561
1562 fn increase_affinity(&self) -> bool {
1563 let mut current = self.affinity.load(Ordering::Acquire);
1564 while current != AFFINITY_BANNED {
1565 debug_assert_ne!(current, AFFINITY_BANNED - 1);
1566 match self.affinity.compare_exchange_weak(
1567 current,
1568 current + 1,
1569 Ordering::Release,
1570 Ordering::Acquire,
1571 ) {
1572 Ok(_) => return true,
1573 Err(affinity) => current = affinity,
1574 }
1575 }
1576
1577 false
1578 }
1579
1580 fn decrease_affinity(&self) -> bool {
1581 let mut current = self.affinity.load(Ordering::Acquire);
1582 while current != AFFINITY_BANNED {
1583 debug_assert_ne!(current, 0);
1584 match self.affinity.compare_exchange_weak(
1585 current,
1586 current - 1,
1587 Ordering::Release,
1588 Ordering::Acquire,
1589 ) {
1590 Ok(_) => return true,
1591 Err(affinity) => current = affinity,
1592 }
1593 }
1594
1595 false
1596 }
1597
1598 fn try_update_peer_info(
1599 &self,
1600 peer_info: &Arc<PeerInfo>,
1601 with_affinity: bool,
1602 ) -> Result<bool, PeerBannedError> {
1603 struct AffinityGuard<'a> {
1604 inner: &'a KnownPeerInner,
1605 decrease_on_drop: bool,
1606 }
1607
1608 impl AffinityGuard<'_> {
1609 fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
1610 let with_affinity = with_affinity && !self.decrease_on_drop;
1611 let is_banned = if with_affinity {
1612 !self.inner.increase_affinity()
1613 } else {
1614 self.inner.is_banned()
1615 };
1616
1617 if !is_banned && with_affinity {
1618 self.decrease_on_drop = true;
1619 }
1620
1621 is_banned
1622 }
1623 }
1624
1625 impl Drop for AffinityGuard<'_> {
1626 fn drop(&mut self) {
1627 if self.decrease_on_drop {
1628 self.inner.decrease_affinity();
1629 }
1630 }
1631 }
1632
1633 let mut guard = AffinityGuard {
1635 inner: self,
1636 decrease_on_drop: false,
1637 };
1638
1639 let mut cur = self.peer_info.load();
1640 let updated = loop {
1641 if guard.increase_affinity_or_check_ban(with_affinity) {
1642 return Err(PeerBannedError);
1644 }
1645
1646 match cur.created_at.cmp(&peer_info.created_at) {
1647 std::cmp::Ordering::Equal => break true,
1650 std::cmp::Ordering::Less => {
1652 let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
1653 if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
1654 break true;
1655 } else {
1656 cur = prev;
1657 }
1658 }
1659 std::cmp::Ordering::Greater => break false,
1661 }
1662 };
1663
1664 guard.decrease_on_drop = false;
1665 Ok(updated)
1666 }
1667}
1668
/// Sentinel value of the affinity counter that marks a peer as banned.
const AFFINITY_BANNED: usize = usize::MAX;
1670
/// Errors returned when inserting or updating known peers.
#[derive(Debug, thiserror::Error)]
pub enum KnownPeersError {
    /// The peer is banned.
    #[error(transparent)]
    PeerBanned(#[from] PeerBannedError),
    /// The stored peer info is newer than the provided one.
    #[error("provided peer info is outdated")]
    OutdatedInfo,
}
1678
/// Error returned when an operation targets a banned peer.
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("peer is banned")]
pub struct PeerBannedError;
1682
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    /// The entry stays in the registry while any handle is alive and is
    /// removed once the last handle is dropped.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let peers = KnownPeers::new();

        // First handle: the peer becomes known with the default affinity.
        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        // Second handle for the same peer: nothing observable changes.
        let other_handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(other_handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(other_handle.max_affinity(), PeerAffinity::Allowed);

        // Dropping one of two handles keeps the entry.
        drop(other_handle);
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        // Dropping the last handle removes the entry.
        drop(handle);
        assert!(!peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), None);
        assert_eq!(peers.get_affinity(&peer_info.id), None);

        // The peer can be inserted again afterwards.
        peers.insert(peer_info.clone(), false).unwrap();
    }

    /// A handle with affinity created after a simple one raises the affinity
    /// for as long as it lives.
    #[test]
    fn with_affinity_after_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);

        // The affinity is shared, so the simple handle observes it too.
        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        // Dropping the handle with affinity restores the default affinity.
        drop(handle_with_affinity);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        // Dropping the last handle removes the entry.
        drop(handle_simple);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    /// A simple handle created after a handle with affinity does not change
    /// the affinity, neither when created nor when dropped.
    #[test]
    fn with_affinity_before_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);

        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        // Dropping the simple handle keeps the high affinity.
        drop(handle_simple);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));

        // Dropping the last handle removes the entry.
        drop(handle_with_affinity);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    /// Banning through the registry is visible through an existing handle,
    /// while the stored peer info stays readable.
    #[test]
    fn ban_while_handle_exists() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        peers.ban(&peer_info.id);
        assert!(peers.contains(&peer_info.id));
        assert!(peers.is_banned(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Never);
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::Never));
    }
}