1use std::collections::{VecDeque, hash_map};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{DelayQueue, delay_queue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::ConnectionError;
15use crate::network::config::NetworkConfig;
16use crate::network::connection::Connection;
17use crate::network::endpoint::{Connecting, ConnectionInitError, Endpoint, Into0RttResult};
18use crate::network::request_handler::InboundRequestHandler;
19use crate::network::wire::{HandshakeError, handshake};
20use crate::types::{
21 Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
22 PeerInfo, Response, ServiceRequest,
23};
24
// Histograms: time spent by a connection task, recorded for outbound dials
// and for inbound connections respectively.
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// Counters: connections registered as new active peers and failed
// connection attempts, split by direction.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// Gauges: sizes of the connection manager task sets and of the set of
// dials started by the connectivity check.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

// Gauges: number of entries in `ActivePeers` and `KnownPeers`.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
43
/// Requests accepted by [`ConnectionManager`] through its mailbox channel.
#[derive(Debug)]
pub(crate) enum ConnectionManagerRequest {
    /// Dial the peer with the given id at the given address; the outcome is
    /// delivered through the callback.
    Connect(Address, PeerId, CallbackTx),
    /// Stop the event loop; the sender is notified after shutdown completes.
    Shutdown(oneshot::Sender<()>),
}
49
/// Event-loop state that tracks all in-flight and established connections
/// of a single endpoint.
pub(crate) struct ConnectionManager {
    config: Arc<NetworkConfig>,
    endpoint: Arc<Endpoint>,

    /// Receiving half of the request channel created in [`ConnectionManager::new`].
    mailbox: mpsc::Receiver<ConnectionManagerRequest>,

    /// Callbacks and bookkeeping of the current connection attempt, keyed by
    /// the remote address.
    pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
    /// Incoming connections whose peer identity is not available yet.
    pending_partial_connections: JoinSet<Option<PartialConnection>>,
    /// Inbound and outbound connection attempts that have not finished yet.
    pending_connections: JoinSet<ConnectingOutput>,
    /// Request handler tasks of connections registered as active peers.
    connection_handlers: JoinSet<()>,
    /// Error callbacks whose delivery was postponed.
    delayed_callbacks: DelayedCallbacksQueue,

    /// Dials started by the connectivity check that have not reported yet.
    pending_dials: FastHashMap<PeerId, CallbackRx>,
    /// Backoff state of peers whose last connectivity-check dial failed.
    dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,

    active_peers: ActivePeers,
    known_peers: KnownPeers,

    /// Service handed to every spawned `InboundRequestHandler`.
    service: BoxCloneService<ServiceRequest, Response>,
}
70
/// Sending half of a one-shot channel that reports the outcome of a connection attempt.
type CallbackTx = oneshot::Sender<Result<Connection, ConnectionError>>;
/// Receiving half matching [`CallbackTx`].
type CallbackRx = oneshot::Receiver<Result<Connection, ConnectionError>>;
73
impl Drop for ConnectionManager {
    /// Closes the endpoint when the manager goes away, regardless of whether
    /// `shutdown` has run.
    fn drop(&mut self) {
        tracing::trace!("dropping connection manager");
        self.endpoint.close();
    }
}
80
81impl ConnectionManager {
82 pub fn new(
83 config: Arc<NetworkConfig>,
84 endpoint: Arc<Endpoint>,
85 active_peers: ActivePeers,
86 known_peers: KnownPeers,
87 service: BoxCloneService<ServiceRequest, Response>,
88 ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
89 let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
90 let connection_manager = Self {
91 config,
92 endpoint,
93 mailbox,
94 pending_connection_callbacks: Default::default(),
95 pending_partial_connections: Default::default(),
96 pending_connections: Default::default(),
97 connection_handlers: Default::default(),
98 delayed_callbacks: Default::default(),
99 pending_dials: Default::default(),
100 dial_backoff_states: Default::default(),
101 active_peers,
102 known_peers,
103 service,
104 };
105 (connection_manager, mailbox_tx)
106 }
107
    /// Runs the connection manager event loop.
    ///
    /// The loop ends when the mailbox is closed or a `Shutdown` request is
    /// received; afterwards the manager is shut down and the shutdown
    /// requester (if any) is notified. Panics of spawned tasks are re-raised
    /// on the calling task.
    pub async fn start(mut self) {
        tracing::info!("connection manager started");

        // Add up to one second of random jitter to the connectivity check period.
        let jitter = Duration::from_millis(1000).mul_f64(rand::random());
        let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);

        let mut shutdown_notifier = None;

        loop {
            tokio::select! {
                // Periodic connectivity check.
                now = interval.tick() => {
                    self.handle_connectivity_check(now.into_std());
                }
                // Requests from the mailbox.
                maybe_request = self.mailbox.recv() => {
                    // All senders were dropped.
                    let Some(request) = maybe_request else {
                        break;
                    };

                    match request {
                        ConnectionManagerRequest::Connect(address, peer_id, callback) => {
                            self.handle_connect_request(address, &peer_id, callback);
                        }
                        ConnectionManagerRequest::Shutdown(oneshot) => {
                            // Notified only after `shutdown` has completed.
                            shutdown_notifier = Some(oneshot);
                            break;
                        }
                    }
                }
                // New incoming connection.
                // NOTE(review): a `None` result is ignored here; confirm that
                // `accept` does not resolve to `None` repeatedly.
                connecting = self.endpoint.accept() => {
                    if let Some(connecting) = connecting {
                        self.handle_incoming(connecting);
                    }
                }
                // A connection attempt (inbound or outbound) has finished.
                Some(connecting_output) = self.pending_connections.join_next() => {
                    metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
                    match connecting_output {
                        Ok(connecting) => self.handle_connecting_result(connecting),
                        Err(e) => {
                            // Aborted tasks are ignored, panics are propagated.
                            if e.is_panic() {
                                std::panic::resume_unwind(e.into_panic());
                            }
                        }
                    }
                }
                // A partial incoming connection has finished waiting for its identity.
                Some(partial_connection) = self.pending_partial_connections.join_next() => {
                    metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);

                    match partial_connection {
                        // Failed or timed out; already logged by the task.
                        Ok(None) => {}
                        Ok(Some(PartialConnection {
                            connection,
                            timeout_at,
                        })) => self.handle_incoming_impl(connection, None, timeout_at),
                        Err(e) => {
                            if e.is_panic() {
                                std::panic::resume_unwind(e.into_panic());
                            }
                        }
                    }
                }
                // A request handler of an established connection has finished.
                Some(connection_handler_output) = self.connection_handlers.join_next() => {
                    metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);

                    if let Err(e) = connection_handler_output {
                        if e.is_panic() {
                            std::panic::resume_unwind(e.into_panic());
                        }
                    }
                }
                // Some delayed error callbacks are due.
                Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
                    self.delayed_callbacks.execute_expired(&peer_id);
                }
            }
        }

        self.shutdown().await;

        if let Some(tx) = shutdown_notifier {
            _ = tx.send(());
        }

        tracing::info!("connection manager stopped");
    }
193
    /// Closes the endpoint and waits for all tracked tasks to stop.
    ///
    /// # Panics
    ///
    /// Panics if there are still active peers after all connection handlers
    /// have finished.
    async fn shutdown(mut self) {
        tracing::trace!("shutting down connection manager");

        self.endpoint.close();

        // Abort connection attempts that are still in progress.
        self.pending_partial_connections.shutdown().await;
        metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);

        self.pending_connections.shutdown().await;
        metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);

        // Connection handlers are not aborted; wait until each one finishes.
        while self.connection_handlers.join_next().await.is_some() {
            metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
        }
        assert!(self.active_peers.is_empty());

        self.endpoint
            .wait_idle(self.config.shutdown_idle_timeout)
            .await;
    }
214
215 fn handle_connectivity_check(&mut self, now: Instant) {
216 use std::collections::hash_map::Entry;
217
218 self.pending_dials
219 .retain(|peer_id, oneshot| match oneshot.try_recv() {
220 Ok(Ok(returned_peer_id)) => {
221 debug_assert_eq!(peer_id, returned_peer_id.peer_id());
222 self.dial_backoff_states.remove(peer_id);
223 false
224 }
225 Ok(Err(_)) => {
226 match self.dial_backoff_states.entry(*peer_id) {
227 Entry::Occupied(mut entry) => entry.get_mut().update(
228 now,
229 self.config.connection_backoff,
230 self.config.max_connection_backoff,
231 ),
232 Entry::Vacant(entry) => {
233 entry.insert(DialBackoffState::new(
234 now,
235 self.config.connection_backoff,
236 self.config.max_connection_backoff,
237 ));
238 }
239 }
240 false
241 }
242 Err(oneshot::error::TryRecvError::Closed) => {
243 panic!("BUG: connection manager never finished dialing a peer");
244 }
245 Err(oneshot::error::TryRecvError::Empty) => true,
246 });
247
248 let outstanding_connections_limit = self
249 .config
250 .max_concurrent_outstanding_connections
251 .saturating_sub(self.pending_connections.len());
252
253 let outstanding_connections = self
254 .known_peers
255 .0
256 .iter()
257 .filter_map(|item| {
258 let value = match item.value() {
259 KnownPeerState::Stored(item) => item.upgrade()?,
260 KnownPeerState::Banned => return None,
261 };
262 let peer_info = value.peer_info.load();
263 let affinity = value.compute_affinity();
264
265 (affinity == PeerAffinity::High
266 && peer_info.id != self.endpoint.peer_id()
267 && !self.active_peers.contains(&peer_info.id)
268 && !self.pending_dials.contains_key(&peer_info.id)
269 && self
270 .dial_backoff_states
271 .get(&peer_info.id)
272 .is_none_or(|state| now > state.next_attempt_at))
273 .then(|| arc_swap::Guard::into_inner(peer_info))
274 })
275 .take(outstanding_connections_limit)
276 .collect::<Vec<_>>();
277
278 for peer_info in outstanding_connections {
279 let address = peer_info
281 .iter_addresses()
282 .next()
283 .cloned()
284 .expect("address list must have at least one item");
285
286 let (tx, rx) = oneshot::channel();
287 self.dial_peer(address, &peer_info.id, tx);
288 self.pending_dials.insert(peer_info.id, rx);
289 }
290
291 metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
292 }
293
    /// Handles an explicit `Connect` request by dialing the peer; the result
    /// is delivered through `callback`.
    fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
        self.dial_peer(address, peer_id, callback);
    }
297
298 fn handle_incoming(&mut self, connecting: Connecting) {
299 let remote_addr = connecting.remote_address();
300 tracing::trace!(
301 local_id = %self.endpoint.peer_id(),
302 %remote_addr,
303 "received an incoming connection",
304 );
305
306 match connecting.into_0rtt() {
308 Into0RttResult::Established(connection, accepted) => {
309 let timeout_at = Instant::now() + self.config.connect_timeout;
310 self.handle_incoming_impl(connection, Some(accepted), timeout_at);
311 }
312 Into0RttResult::WithoutIdentity(partial_connection) => {
313 tracing::trace!("connection identity is not available yet");
314
315 let timeout_at = Instant::now() + self.config.connect_timeout;
316 self.pending_partial_connections.spawn(async move {
317 match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
318 Ok(Ok(connection)) => Some(PartialConnection {
319 connection,
320 timeout_at,
321 }),
322 Ok(Err(e)) => {
323 tracing::trace!(
324 %remote_addr,
325 "failed to establish an incoming connection: {e}",
326 );
327 None
328 }
329 Err(_) => {
330 tracing::trace!(
331 %remote_addr,
332 "incoming connection timed out",
333 );
334 None
335 }
336 }
337 });
338 metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
339 }
340 Into0RttResult::InvalidCertificate => {
341 tracing::trace!(%remote_addr, "invalid incoming connection");
342 }
343 Into0RttResult::Unavailable(_) => unreachable!(
344 "BUG: For incoming connections, a 0.5-RTT connection must \
345 always be successfully constructed."
346 ),
347 };
348 }
349
    /// Validates an incoming connection with a known peer identity and
    /// spawns the task which finishes its handshake.
    ///
    /// The connection is rejected (closed) when the peer has
    /// `PeerAffinity::Never`, when the peer has no `High`/`Allowed` affinity
    /// and the concurrent connections limit is reached, or when it loses the
    /// simultaneous dial tie-breaking against an attempt already tracked for
    /// the same remote address.
    ///
    /// `accepted` is awaited before the handshake when present. `timeout_at`
    /// is the deadline for the remaining part of the connection setup.
    fn handle_incoming_impl(
        &mut self,
        connection: Connection,
        accepted: Option<quinn::ZeroRttAccepted>,
        timeout_at: Instant,
    ) {
        /// Awaits `accepted` (if any) and the handshake, bounded by `timeout_at`.
        async fn handle_incoming_task(
            seqno: u32,
            connection: ConnectionClosedOnDrop,
            accepted: Option<quinn::ZeroRttAccepted>,
            timeout_at: Instant,
        ) -> ConnectingOutput {
            let target_peer_id = *connection.peer_id();
            let target_address = connection.remote_address().into();
            let fut = async {
                if let Some(accepted) = accepted {
                    accepted.await;
                }
                handshake(&connection).await
            };

            let started_at = Instant::now();

            // On any failure the guard is dropped at the end of this function
            // and closes the connection.
            let connecting_result = match tokio::time::timeout_at(timeout_at.into(), fut).await {
                Ok(Ok(())) => Ok(connection.disarm()),
                Ok(Err(e)) => Err(FullConnectionError::HandshakeFailed(e)),
                Err(_) => Err(FullConnectionError::Timeout),
            };

            metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());

            ConnectingOutput {
                seqno,
                drop_result: true,
                connecting_result: ManuallyDrop::new(connecting_result),
                target_address,
                target_peer_id,
                origin: Direction::Inbound,
            }
        }

        let remote_addr = connection.remote_address();

        match self.known_peers.get_affinity(connection.peer_id()) {
            // Peers with an explicit positive affinity bypass the limit check.
            Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
            Some(PeerAffinity::Never) => {
                tracing::warn!(
                    %remote_addr,
                    peer_id = %connection.peer_id(),
                    "rejecting connection due to PeerAffinity::Never",
                );
                connection.close();
                return;
            }
            // Everything else is subject to the concurrent connections limit.
            _ => {
                if matches!(
                    self.config.max_concurrent_connections,
                    Some(limit) if self.active_peers.len() >= limit
                ) {
                    tracing::warn!(
                        %remote_addr,
                        peer_id = %connection.peer_id(),
                        "rejecting connection due too many concurrent connections",
                    );
                    connection.close();
                    return;
                }
            }
        }

        // `Some(entry)` means that a new task must be spawned for this connection.
        let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
                last_seqno: 0,
                origin: Direction::Inbound,
                callbacks: Default::default(),
                abort_handle: None,
            })),
            hash_map::Entry::Occupied(entry) => {
                let entry = entry.into_mut();

                if simultaneous_dial_tie_breaking(
                    self.endpoint.peer_id(),
                    connection.peer_id(),
                    entry.origin,
                    Direction::Inbound,
                ) {
                    tracing::debug!(
                        %remote_addr,
                        peer_id = %connection.peer_id(),
                        "cancelling old connection to mitigate simultaneous dial",
                    );

                    // Bumping the seqno marks any result of the old task as outdated.
                    entry.origin = Direction::Inbound;
                    entry.last_seqno += 1;
                    if let Some(handle) = entry.abort_handle.take() {
                        handle.abort();
                    }
                    Some(entry)
                } else {
                    tracing::debug!(
                        %remote_addr,
                        peer_id = %connection.peer_id(),
                        "cancelling new connection to mitigate simultaneous dial",
                    );

                    connection.close();
                    None
                }
            }
        };

        if let Some(entry) = entry {
            entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
                entry.last_seqno,
                ConnectionClosedOnDrop::new(connection),
                accepted,
                timeout_at,
            )));
            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
        }
    }
478
    /// Processes the output of a finished connection task.
    ///
    /// Results which do not belong to the latest attempt for the target
    /// address are discarded. Otherwise the pending callbacks for the address
    /// are removed and resolved: with the registered connection on success,
    /// or with a brief error on failure. When the failure is a close by the
    /// remote application with error code zero, the error callbacks are
    /// postponed by `connection_error_delay` instead.
    fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
        {
            // Check that the result is still relevant. On early return `res`
            // is dropped with `drop_result == true` and releases its payload.
            let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
                tracing::trace!("connection task reordering detected");
                return;
            };

            if entry.last_seqno != res.seqno {
                tracing::debug!(
                    local_id = %self.endpoint.peer_id(),
                    peer_id = %res.target_peer_id,
                    remote_addr = %res.target_address,
                    "connection result is outdated"
                );
                return;
            }
        }

        let callbacks = self
            .pending_connection_callbacks
            .remove(&res.target_address)
            .expect("Connection tasks must be tracked")
            .callbacks;

        // Ownership of the payload is transferred out of `res` below, so its
        // destructor must no longer touch it.
        res.drop_result = false;
        // SAFETY: `drop_result` was just cleared, so `Drop for ConnectingOutput`
        // will not drop `connecting_result`, and the field is not read again
        // after this `take`.
        match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
            Ok(connection) => {
                tracing::debug!(
                    local_id = %self.endpoint.peer_id(),
                    peer_id = %connection.peer_id(),
                    remote_addr = %res.target_address,
                    "new connection",
                );

                // `add_peer` may return an already existing connection instead.
                let connection = self.add_peer(connection);
                self.delayed_callbacks.execute_resolved(&connection);

                for callback in callbacks {
                    _ = callback.send(Ok(connection.clone()));
                }
            }
            Err(e) => {
                tracing::debug!(
                    local_id = %self.endpoint.peer_id(),
                    peer_id = %res.target_peer_id,
                    remote_addr = %res.target_address,
                    "connection failed: {e:?}"
                );

                metrics::counter!(match res.origin {
                    Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
                    Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
                })
                .increment(1);

                let brief_error = e.as_brief();

                // Postpone the error: a later successful connection with the
                // same peer resolves the postponed callbacks with `Ok`
                // (see `DelayedCallbacksQueue::execute_resolved`).
                if e.was_closed() && !callbacks.is_empty() {
                    self.delayed_callbacks.push(
                        &res.target_peer_id,
                        brief_error,
                        callbacks,
                        &self.config.connection_error_delay,
                    );
                    return;
                }

                for callback in callbacks {
                    _ = callback.send(Err(brief_error));
                }
            }
        }
    }
556
557 fn add_peer(&mut self, connection: Connection) -> Connection {
558 match self.active_peers.add(self.endpoint.peer_id(), connection) {
559 AddedPeer::New(connection) => {
560 let origin = connection.origin();
561
562 let handler = InboundRequestHandler::new(
563 self.config.clone(),
564 connection.clone(),
565 self.service.clone(),
566 self.active_peers.clone(),
567 );
568
569 metrics::counter!(match origin {
570 Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
571 Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
572 })
573 .increment(1);
574
575 metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
576 self.connection_handlers.spawn(handler.start());
577
578 connection
579 }
580 AddedPeer::Existing(connection) => connection,
581 }
582 }
583
    /// Starts (or joins) an outbound connection attempt to the peer.
    ///
    /// - If the peer is already connected, `callback` is resolved immediately.
    /// - If there is no attempt for `address` yet, a new dial task is spawned.
    /// - If an attempt for `address` exists, `callback` is attached to it. An
    ///   inbound attempt is additionally replaced by a new outbound dial when
    ///   the tie-breaking rule prefers the outbound direction.
    #[tracing::instrument(
        level = "trace",
        skip_all,
        fields(
            local_id = %self.endpoint.peer_id(),
            peer_id = %peer_id,
            remote_addr = %address,
        ),
    )]
    fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
        /// Resolves the address, connects and performs the handshake, all
        /// bounded by `config.connect_timeout`.
        async fn dial_peer_task(
            seqno: u32,
            endpoint: Arc<Endpoint>,
            address: Address,
            peer_id: PeerId,
            config: Arc<NetworkConfig>,
        ) -> ConnectingOutput {
            let fut = async {
                let address = address
                    .resolve()
                    .await
                    .map_err(FullConnectionError::InvalidAddress)?;

                let connecting = endpoint
                    .connect_with_expected_id(&address, &peer_id)
                    .map_err(|e| FullConnectionError::InvalidAddress(std::io::Error::other(e)))?;

                // The guard closes the connection if the handshake fails or
                // the whole future is dropped on timeout.
                let connection = ConnectionClosedOnDrop::new(connecting.await?);
                match handshake(&connection).await {
                    Ok(()) => Ok(connection),
                    Err(e) => Err(FullConnectionError::HandshakeFailed(e)),
                }
            };

            let started_at = Instant::now();

            let connecting_result = match tokio::time::timeout(config.connect_timeout, fut).await {
                Ok(res) => res.map(ConnectionClosedOnDrop::disarm),
                Err(_) => Err(FullConnectionError::Timeout),
            };

            metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());

            ConnectingOutput {
                seqno,
                drop_result: true,
                connecting_result: ManuallyDrop::new(connecting_result),
                target_address: address,
                target_peer_id: peer_id,
                origin: Direction::Outbound,
            }
        }

        if let Some(connection) = self.active_peers.get(peer_id) {
            tracing::debug!("peer is already connected");
            _ = callback.send(Ok(connection));
            return;
        }

        tracing::trace!("connecting to peer");

        // `Some(entry)` means that a new dial task must be spawned.
        let entry = match self.pending_connection_callbacks.entry(address.clone()) {
            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
                last_seqno: 0,
                origin: Direction::Outbound,
                callbacks: vec![callback],
                abort_handle: None,
            })),
            hash_map::Entry::Occupied(entry) => {
                let entry = entry.into_mut();

                // The callback is resolved by whichever attempt wins.
                entry.callbacks.push(callback);

                let break_tie = simultaneous_dial_tie_breaking(
                    self.endpoint.peer_id(),
                    peer_id,
                    entry.origin,
                    Direction::Outbound,
                );

                // An existing outbound attempt is always reused.
                if break_tie && entry.origin != Direction::Outbound {
                    tracing::debug!("cancelling old connection to mitigate simultaneous dial");

                    // Bumping the seqno marks any result of the old task as outdated.
                    entry.origin = Direction::Outbound;
                    entry.last_seqno += 1;
                    if let Some(handle) = entry.abort_handle.take() {
                        handle.abort();
                    }
                    Some(entry)
                } else {
                    tracing::trace!("reusing old connection to mitigate simultaneous dial");
                    None
                }
            }
        };

        if let Some(entry) = entry {
            entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
                entry.last_seqno,
                self.endpoint.clone(),
                address.clone(),
                *peer_id,
                self.config.clone(),
            )));
            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
        }
    }
695}
696
/// Bookkeeping of the current connection attempt for a remote address.
struct PendingConnectionCallbacks {
    /// Sequence number of the latest attempt; results carrying a different
    /// number are treated as outdated.
    last_seqno: u32,
    /// Direction of the latest attempt.
    origin: Direction,
    /// Callbacks to resolve once the attempt finishes.
    callbacks: Vec<CallbackTx>,
    /// Handle used to abort the task of the latest attempt.
    abort_handle: Option<AbortHandle>,
}
703
/// Output of a task which waited for the identity of an incoming connection.
struct PartialConnection {
    connection: Connection,
    /// Deadline computed when the connection was accepted; reused for the
    /// handshake that follows.
    timeout_at: Instant,
}
708
/// Output of a finished connection task (inbound or outbound).
struct ConnectingOutput {
    /// Sequence number of the attempt that produced this output.
    seqno: u32,
    /// Whether `connecting_result` is still owned by this value and must be
    /// dropped together with it.
    drop_result: bool,
    /// Result of the attempt. Moved out manually by
    /// `ConnectionManager::handle_connecting_result`.
    connecting_result: ManuallyDrop<Result<Connection, FullConnectionError>>,
    target_address: Address,
    target_peer_id: PeerId,
    origin: Direction,
}

impl Drop for ConnectingOutput {
    fn drop(&mut self) {
        if self.drop_result {
            // SAFETY: `drop_result` is cleared before the result is moved out
            // with `ManuallyDrop::take`, so the value is still initialized
            // here and is dropped exactly once.
            unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
        }
    }
}
726
/// Guard which closes the wrapped connection on drop unless it was disarmed.
struct ConnectionClosedOnDrop {
    connection: ManuallyDrop<Connection>,
    /// Set to `false` by `disarm`, after which `drop` leaves `connection` untouched.
    close_on_drop: bool,
}

impl ConnectionClosedOnDrop {
    /// Wraps the connection into an armed guard.
    fn new(connection: Connection) -> Self {
        Self {
            connection: ManuallyDrop::new(connection),
            close_on_drop: true,
        }
    }

    /// Consumes the guard and returns the connection without closing it.
    fn disarm(mut self) -> Connection {
        self.close_on_drop = false;
        // SAFETY: `close_on_drop` is now `false`, so the destructor of `self`
        // will not access `connection` after it was moved out here.
        unsafe { ManuallyDrop::take(&mut self.connection) }
    }
}

impl std::ops::Deref for ConnectionClosedOnDrop {
    type Target = Connection;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.connection
    }
}

impl Drop for ConnectionClosedOnDrop {
    fn drop(&mut self) {
        if self.close_on_drop {
            // SAFETY: the guard is still armed, so `disarm` has not moved the
            // connection out, and the field is not used after this point.
            let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
            connection.close();
        }
    }
}
765
/// Queue of error callbacks whose delivery is postponed for some time.
#[derive(Default)]
struct DelayedCallbacksQueue {
    /// Postponed callback batches per peer, in insertion order.
    callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
    /// Timers of the batches; each item is the id of the peer a batch belongs to.
    expirations: DelayQueue<PeerId>,
}
771
772impl DelayedCallbacksQueue {
773 fn push(
774 &mut self,
775 peer_id: &PeerId,
776 error: ConnectionError,
777 callbacks: Vec<CallbackTx>,
778 delay: &Duration,
779 ) {
780 tracing::debug!(%peer_id, %error, "delayed connection error");
781
782 let expires_at = Instant::now() + *delay;
783 let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
784
785 let items = self.callbacks.entry(*peer_id).or_default();
786 items.push_back(DelayedCallbacks {
787 delay_key,
788 error,
789 callbacks,
790 expires_at,
791 });
792 }
793
794 async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
795 let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
796 Some(res.into_inner())
797 }
798
799 fn execute_resolved(&mut self, connection: &Connection) {
800 let Some(items) = self.callbacks.remove(connection.peer_id()) else {
801 return;
802 };
803
804 let mut batches_executed = 0;
805 let mut callbacks_executed = 0;
806
807 for delayed in items {
808 batches_executed += 1;
809 callbacks_executed += delayed.callbacks.len();
810
811 let key = delayed.execute_with_ok(connection);
812
813 self.expirations.remove(&key);
815 }
816
817 tracing::debug!(
818 peer_id = %connection.peer_id(),
819 batches_executed,
820 callbacks_executed,
821 "executed all delayed callbacks",
822 );
823 }
824
825 fn execute_expired(&mut self, peer_id: &PeerId) {
826 let now = Instant::now();
827
828 let mut batches_executed = 0;
829 let mut callbacks_executed = 0;
830
831 'outer: {
832 if let Some(items) = self.callbacks.get_mut(peer_id) {
833 while let Some(front) = items.front() {
834 if !front.is_expired(&now) {
835 break 'outer;
836 }
837
838 if let Some(delayed) = items.pop_front() {
839 batches_executed += 1;
840 callbacks_executed += delayed.callbacks.len();
841
842 let key = delayed.execute_with_error();
843
844 self.expirations.try_remove(&key);
847 }
848 }
849 }
850
851 self.callbacks.remove(peer_id);
853 }
854
855 tracing::debug!(
856 %peer_id,
857 batches_executed,
858 callbacks_executed,
859 "executed expired delayed callbacks"
860 );
861 }
862}
863
864struct DelayedCallbacks {
865 delay_key: delay_queue::Key,
866 error: ConnectionError,
867 callbacks: Vec<CallbackTx>,
868 expires_at: Instant,
869}
870
871impl DelayedCallbacks {
872 fn execute_with_ok(self, connection: &Connection) -> delay_queue::Key {
873 for callback in self.callbacks {
874 _ = callback.send(Ok(connection.clone()));
875 }
876 self.delay_key
877 }
878
879 fn execute_with_error(self) -> delay_queue::Key {
880 for callback in self.callbacks {
881 _ = callback.send(Err(self.error));
882 }
883 self.delay_key
884 }
885
886 fn is_expired(&self, now: &Instant) -> bool {
887 *now >= self.expires_at
888 }
889}
890
/// Linear backoff state for automatic dials to a single peer.
#[derive(Debug)]
struct DialBackoffState {
    /// Earliest moment at which the next dial may be attempted.
    next_attempt_at: Instant,
    /// Number of failed attempts recorded so far.
    attempts: usize,
}

impl DialBackoffState {
    /// Creates the state for a peer after its first failed attempt.
    fn new(now: Instant, step: Duration, max: Duration) -> Self {
        // Same as starting from zero attempts and recording one failure:
        // the first delay is a single `step`, capped by `max`.
        Self {
            next_attempt_at: now + step.min(max),
            attempts: 1,
        }
    }

    /// Records one more failed attempt.
    ///
    /// The next attempt is scheduled at `now + min(max, step * attempts)`.
    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
        self.attempts += 1;

        // An attempt count that does not fit into `u32` saturates.
        let multiplier = u32::try_from(self.attempts).unwrap_or(u32::MAX);
        let delay = step.saturating_mul(multiplier).min(max);

        self.next_attempt_at = now + delay;
    }
}
916
/// Detailed error of a connection attempt, used internally by the
/// connection manager. Callers receive the brief form produced by
/// `FullConnectionError::as_brief`.
#[derive(Debug, thiserror::Error)]
enum FullConnectionError {
    /// The address could not be resolved or a connection to it could not be started.
    #[error("invalid address")]
    InvalidAddress(#[source] std::io::Error),
    /// The connection failed while being established.
    #[error(transparent)]
    ConnectionFailed(quinn::ConnectionError),
    #[error("invalid certificate")]
    InvalidCertificate,
    #[error("handshake failed")]
    HandshakeFailed(#[source] HandshakeError),
    /// The attempt did not finish within its deadline.
    #[error("connection timeout")]
    Timeout,
}
930
931impl FullConnectionError {
932 fn as_brief(&self) -> ConnectionError {
933 fn is_crypto_error(error: &quinn::ConnectionError) -> bool {
934 const QUIC_CRYPTO_FLAG: u64 = 0x100;
935
936 matches!(
937 error,
938 quinn::ConnectionError::TransportError(e)
939 if u64::from(e.code) & QUIC_CRYPTO_FLAG != 0
940 )
941 }
942
943 match self {
944 Self::InvalidAddress(_) => ConnectionError::InvalidAddress,
945 Self::ConnectionFailed(e) if is_crypto_error(e) => ConnectionError::InvalidCertificate,
946 Self::ConnectionFailed(_) => ConnectionError::ConnectionInitFailed,
947 Self::InvalidCertificate => ConnectionError::InvalidCertificate,
948 Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) if is_crypto_error(e) => {
949 ConnectionError::InvalidCertificate
950 }
951 Self::HandshakeFailed(_) => ConnectionError::HandshakeFailed,
952 Self::Timeout => ConnectionError::Timeout,
953 }
954 }
955
956 fn was_closed(&self) -> bool {
957 let connection_error = match self {
958 Self::ConnectionFailed(e)
959 | Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) => e,
960 _ => return false,
961 };
962
963 matches!(
964 connection_error,
965 quinn::ConnectionError::ApplicationClosed(closed)
966 if closed.error_code.into_inner() == 0
967 )
968 }
969}
970
/// Maps connection initialization errors one-to-one, which allows using `?`
/// on the result of awaiting a `Connecting`.
impl From<ConnectionInitError> for FullConnectionError {
    fn from(value: ConnectionInitError) -> Self {
        match value {
            ConnectionInitError::ConnectionFailed(e) => Self::ConnectionFailed(e),
            ConnectionInitError::InvalidCertificate => Self::InvalidCertificate,
        }
    }
}
979
/// Cheaply cloneable handle to the shared set of active connections.
#[derive(Clone)]
pub(crate) struct ActivePeers(Arc<ActivePeersInner>);

impl ActivePeers {
    /// Creates an empty set; `channel_size` is the capacity of the peer
    /// events broadcast channel.
    pub fn new(channel_size: usize) -> Self {
        Self(Arc::new(ActivePeersInner::new(channel_size)))
    }

    /// Returns a clone of the active connection with the peer, if any.
    pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
        self.0.get(peer_id)
    }

    /// Returns `true` if there is an active connection with the peer.
    pub fn contains(&self, peer_id: &PeerId) -> bool {
        self.0.contains(peer_id)
    }

    /// Registers a new connection, resolving conflicts with an existing
    /// connection to the same peer via simultaneous dial tie-breaking.
    pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
        self.0.add(local_id, new_connection)
    }

    /// Removes and closes the connection with the peer, if any.
    pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
        self.0.remove(peer_id, reason);
    }

    /// Same as [`ActivePeers::remove`], but only if the stored connection
    /// has the specified stable id.
    pub fn remove_with_stable_id(
        &self,
        peer_id: &PeerId,
        stable_id: usize,
        reason: DisconnectReason,
    ) {
        self.0.remove_with_stable_id(peer_id, stable_id, reason);
    }

    /// Subscribes to peer events (new and lost peers).
    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
        self.0.subscribe()
    }

    /// Returns `true` if there are no active connections.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the number of active connections.
    pub fn len(&self) -> usize {
        self.0.len()
    }
}
1025
/// Shared state behind [`ActivePeers`].
struct ActivePeersInner {
    /// Active connections by peer id.
    connections: FastDashMap<PeerId, Connection>,
    /// Number of entries in `connections`, maintained separately on every
    /// insert of a new key and on every removal.
    connections_len: AtomicUsize,
    /// Broadcast channel for new/lost peer events.
    events_tx: broadcast::Sender<PeerEvent>,
}
1031
1032impl ActivePeersInner {
1033 fn new(channel_size: usize) -> Self {
1034 let (events_tx, _) = broadcast::channel(channel_size);
1035 Self {
1036 connections: Default::default(),
1037 connections_len: Default::default(),
1038 events_tx,
1039 }
1040 }
1041
1042 fn get(&self, peer_id: &PeerId) -> Option<Connection> {
1043 self.connections
1044 .get(peer_id)
1045 .map(|item| item.value().clone())
1046 }
1047
1048 fn contains(&self, peer_id: &PeerId) -> bool {
1049 self.connections.contains_key(peer_id)
1050 }
1051
1052 #[must_use]
1053 fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
1054 use dashmap::mapref::entry::Entry;
1055
1056 let mut added = false;
1057
1058 let peer_id = new_connection.peer_id();
1059 match self.connections.entry(*peer_id) {
1060 Entry::Occupied(mut entry) => {
1061 if simultaneous_dial_tie_breaking(
1062 local_id,
1063 peer_id,
1064 entry.get().origin(),
1065 new_connection.origin(),
1066 ) {
1067 tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
1068 let old_connection = entry.insert(new_connection.clone());
1069 old_connection.close();
1070 self.send_event(PeerEvent::lost_peer(*peer_id, DisconnectReason::Requested));
1071 } else {
1072 tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
1073 new_connection.close();
1074 return AddedPeer::Existing(entry.get().clone());
1075 }
1076 }
1077 Entry::Vacant(entry) => {
1078 self.connections_len.fetch_add(1, Ordering::Release);
1079 entry.insert(new_connection.clone());
1080 added = true;
1081 }
1082 }
1083
1084 self.send_event(PeerEvent::new_peer(*peer_id));
1085
1086 if added {
1087 metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
1088 }
1089 AddedPeer::New(new_connection)
1090 }
1091
1092 fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1093 if let Some((_, connection)) = self.connections.remove(peer_id) {
1094 connection.close();
1095 self.connections_len.fetch_sub(1, Ordering::Release);
1096 self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1097
1098 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1099 }
1100 }
1101
1102 fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
1103 if let Some((_, connection)) = self
1104 .connections
1105 .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
1106 {
1107 connection.close();
1108 self.connections_len.fetch_sub(1, Ordering::Release);
1109 self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1110
1111 metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1112 }
1113 }
1114
1115 fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1116 self.events_tx.subscribe()
1117 }
1118
1119 fn send_event(&self, event: PeerEvent) {
1120 _ = self.events_tx.send(event);
1121 }
1122
1123 fn is_empty(&self) -> bool {
1124 self.connections.is_empty()
1125 }
1126
1127 fn len(&self) -> usize {
1128 self.connections_len.load(Ordering::Acquire)
1129 }
1130}
1131
/// Outcome of registering a connection in [`ActivePeers`].
pub(crate) enum AddedPeer {
    /// The passed connection is now the active connection of the peer.
    New(Connection),
    /// The passed connection was closed; contains the connection that was
    /// already active.
    Existing(Connection),
}
1136
1137fn simultaneous_dial_tie_breaking(
1138 local_id: &PeerId,
1139 peer_id: &PeerId,
1140 old_origin: Direction,
1141 new_origin: Direction,
1142) -> bool {
1143 match (old_origin, new_origin) {
1144 (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1145 true
1146 }
1147 (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1148 (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1149 }
1150}
1151
/// Cheaply cloneable shared registry of known peers, keyed by peer id.
///
/// An entry is either a weak reference to stored peer data or a ban marker.
#[derive(Default, Clone)]
#[repr(transparent)]
pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1155
1156impl KnownPeers {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }
1160
    /// Returns `true` if the registry has an entry for the peer, including
    /// ban markers and entries whose stored data is no longer alive.
    pub fn contains(&self, peer_id: &PeerId) -> bool {
        self.0.contains_key(peer_id)
    }
1164
1165 pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1166 self.0
1167 .get(peer_id)
1168 .and_then(|item| {
1169 Some(match item.value() {
1170 KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1171 KnownPeerState::Banned => true,
1172 })
1173 })
1174 .unwrap_or_default()
1175 }
1176
1177 pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1178 self.0.get(peer_id).and_then(|item| match item.value() {
1179 KnownPeerState::Stored(item) => {
1180 let inner = item.upgrade()?;
1181 Some(inner.peer_info.load_full())
1182 }
1183 KnownPeerState::Banned => None,
1184 })
1185 }
1186
1187 pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1188 self.0
1189 .get(peer_id)
1190 .and_then(|item| item.value().compute_affinity())
1191 }
1192
1193 pub fn remove(&self, peer_id: &PeerId) {
1194 self.0.remove(peer_id);
1195 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1196 }
1197
1198 pub fn ban(&self, peer_id: &PeerId) {
1199 let mut added = false;
1200 match self.0.entry(*peer_id) {
1201 dashmap::mapref::entry::Entry::Vacant(entry) => {
1202 entry.insert(KnownPeerState::Banned);
1203 added = true;
1204 }
1205 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1206 KnownPeerState::Banned => {}
1207 KnownPeerState::Stored(item) => match item.upgrade() {
1208 Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1209 None => *entry.get_mut() = KnownPeerState::Banned,
1210 },
1211 },
1212 }
1213
1214 if added {
1215 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1217 }
1218 }
1219
1220 pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1221 let inner = match self.0.get(peer_id)?.value() {
1222 KnownPeerState::Stored(item) => {
1223 let inner = item.upgrade()?;
1224 if with_affinity && !inner.increase_affinity() {
1225 return None;
1226 }
1227 inner
1228 }
1229 KnownPeerState::Banned => return None,
1230 };
1231
1232 Some(KnownPeerHandle::from_inner(inner, with_affinity))
1233 }
1234
1235 pub fn insert(
1238 &self,
1239 peer_info: Arc<PeerInfo>,
1240 with_affinity: bool,
1241 ) -> Result<KnownPeerHandle, KnownPeersError> {
1242 let mut added = false;
1244 let inner = match self.0.entry(peer_info.id) {
1245 dashmap::mapref::entry::Entry::Vacant(entry) => {
1246 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1247 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1248 added = true;
1249 inner
1250 }
1251 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1252 KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1253 KnownPeerState::Stored(item) => match item.upgrade() {
1254 Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1255 true => inner,
1256 false => return Err(KnownPeersError::OutdatedInfo),
1257 },
1258 None => {
1259 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1260 *item = Arc::downgrade(&inner);
1261 inner
1262 }
1263 },
1264 },
1265 };
1266
1267 if added {
1268 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1269 }
1270
1271 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1272 }
1273
1274 pub fn insert_allow_outdated(
1276 &self,
1277 peer_info: Arc<PeerInfo>,
1278 with_affinity: bool,
1279 ) -> Result<KnownPeerHandle, PeerBannedError> {
1280 let mut added = false;
1282 let inner = match self.0.entry(peer_info.id) {
1283 dashmap::mapref::entry::Entry::Vacant(entry) => {
1284 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1285 entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1286 added = true;
1287 inner
1288 }
1289 dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1290 KnownPeerState::Banned => return Err(PeerBannedError),
1291 KnownPeerState::Stored(item) => match item.upgrade() {
1292 Some(inner) => {
1293 inner.try_update_peer_info(&peer_info, with_affinity)?;
1295 inner
1296 }
1297 None => {
1298 let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1299 *item = Arc::downgrade(&inner);
1300 inner
1301 }
1302 },
1303 },
1304 };
1305
1306 if added {
1307 metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1308 }
1309
1310 Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1311 }
1312}
1313
/// State of a single entry in the `KnownPeers` map.
enum KnownPeerState {
    /// Peer data referenced weakly by the map; strong references are held by handles.
    Stored(Weak<KnownPeerInner>),
    /// Ban marker that carries no peer data.
    Banned,
}
1318
1319impl KnownPeerState {
1320 fn compute_affinity(&self) -> Option<PeerAffinity> {
1321 Some(match self {
1322 Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1323 Self::Banned => PeerAffinity::Never,
1324 })
1325 }
1326}
1327
/// Strong handle to a known peer.
///
/// It wraps a `KnownPeerHandleState`, whose `Drop` implementation may remove the peer
/// from the registry when the last strong reference goes away.
#[derive(Clone)]
#[repr(transparent)]
pub struct KnownPeerHandle(KnownPeerHandleState);
1331
1332impl KnownPeerHandle {
1333 fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
1334 KnownPeerHandle(if with_affinity {
1335 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1336 KnownPeerHandleWithAffinity { inner },
1337 )))
1338 } else {
1339 KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
1340 })
1341 }
1342
1343 pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
1344 self.inner().peer_info.load()
1345 }
1346
1347 pub fn load_peer_info(&self) -> Arc<PeerInfo> {
1348 arc_swap::Guard::into_inner(self.peer_info())
1349 }
1350
1351 pub fn is_banned(&self) -> bool {
1352 self.inner().is_banned()
1353 }
1354
1355 pub fn max_affinity(&self) -> PeerAffinity {
1356 self.inner().compute_affinity()
1357 }
1358
1359 pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
1360 match self.inner().try_update_peer_info(peer_info, false) {
1361 Ok(true) => Ok(()),
1362 Ok(false) => Err(KnownPeersError::OutdatedInfo),
1363 Err(e) => Err(KnownPeersError::PeerBanned(e)),
1364 }
1365 }
1366
1367 pub fn ban(&self) -> bool {
1368 let inner = self.inner();
1369 inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
1370 }
1371
1372 pub fn increase_affinity(&mut self) -> bool {
1373 match &mut self.0 {
1374 KnownPeerHandleState::Simple(inner) => {
1375 inner.increase_affinity();
1377
1378 let inner = unsafe { ManuallyDrop::take(inner) };
1380
1381 let prev_state = std::mem::replace(
1384 &mut self.0,
1385 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1386 KnownPeerHandleWithAffinity { inner },
1387 ))),
1388 );
1389
1390 #[allow(clippy::mem_forget)]
1392 std::mem::forget(prev_state);
1393
1394 true
1395 }
1396 KnownPeerHandleState::WithAffinity(_) => false,
1397 }
1398 }
1399
1400 pub fn decrease_affinity(&mut self) -> bool {
1401 match &mut self.0 {
1402 KnownPeerHandleState::Simple(_) => false,
1403 KnownPeerHandleState::WithAffinity(inner) => {
1404 inner.inner.decrease_affinity();
1406
1407 let inner = unsafe { ManuallyDrop::take(inner) };
1409
1410 let inner = match Arc::try_unwrap(inner) {
1412 Ok(KnownPeerHandleWithAffinity { inner }) => inner,
1413 Err(inner) => inner.inner.clone(),
1414 };
1415
1416 let prev_state = std::mem::replace(
1419 &mut self.0,
1420 KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
1421 );
1422
1423 #[allow(clippy::mem_forget)]
1425 std::mem::forget(prev_state);
1426
1427 true
1428 }
1429 }
1430 }
1431
1432 pub fn downgrade(&self) -> WeakKnownPeerHandle {
1433 WeakKnownPeerHandle(match &self.0 {
1434 KnownPeerHandleState::Simple(data) => {
1435 WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
1436 }
1437 KnownPeerHandleState::WithAffinity(data) => {
1438 WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
1439 }
1440 })
1441 }
1442
1443 fn inner(&self) -> &KnownPeerInner {
1444 match &self.0 {
1445 KnownPeerHandleState::Simple(data) => data.as_ref(),
1446 KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
1447 }
1448 }
1449}
1450
/// Internal representation of a `KnownPeerHandle`.
///
/// Both variants keep their `Arc` inside `ManuallyDrop` because the `Drop` implementation
/// and the state transitions in `KnownPeerHandle` move the `Arc` out by hand.
#[derive(Clone)]
enum KnownPeerHandleState {
    /// Handle that only references the peer data.
    Simple(ManuallyDrop<Arc<KnownPeerInner>>),
    /// Handle that references a `KnownPeerHandleWithAffinity` wrapper shared with its clones.
    WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
}
1456
1457impl Drop for KnownPeerHandleState {
1458 fn drop(&mut self) {
1459 let inner;
1460 let is_banned;
1461 match self {
1462 KnownPeerHandleState::Simple(data) => {
1463 inner = unsafe { ManuallyDrop::take(data) };
1465 is_banned = inner.is_banned();
1466 }
1467 KnownPeerHandleState::WithAffinity(data) => {
1468 match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1470 Some(data) => {
1471 inner = data.inner;
1472 is_banned = !inner.decrease_affinity() || inner.is_banned();
1473 }
1474 None => return,
1475 }
1476 }
1477 };
1478
1479 if is_banned {
1480 return;
1482 }
1483
1484 if let Some(inner) = Arc::into_inner(inner) {
1485 if let Some(peers) = inner.weak_known_peers.upgrade() {
1487 peers.remove(&inner.peer_info.load().id);
1488 metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1489 }
1490 }
1491 }
1492}
1493
/// Weak counterpart of `KnownPeerHandle`.
///
/// Equality is delegated to `WeakKnownPeerHandleState`, which compares by pointer
/// identity within the same variant.
#[derive(Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1497
1498impl WeakKnownPeerHandle {
1499 pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1500 Some(KnownPeerHandle(match &self.0 {
1501 WeakKnownPeerHandleState::Simple(weak) => {
1502 KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1503 }
1504 WeakKnownPeerHandleState::WithAffinity(weak) => {
1505 KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1506 }
1507 }))
1508 }
1509}
1510
/// Internal representation of a `WeakKnownPeerHandle`; mirrors `KnownPeerHandleState`
/// with weak references.
#[derive(Clone)]
enum WeakKnownPeerHandleState {
    /// Weak reference to the peer data itself.
    Simple(Weak<KnownPeerInner>),
    /// Weak reference to the wrapper shared by handles with affinity.
    WithAffinity(Weak<KnownPeerHandleWithAffinity>),
}
1516
1517impl Eq for WeakKnownPeerHandleState {}
1518impl PartialEq for WeakKnownPeerHandleState {
1519 #[inline]
1520 fn eq(&self, other: &Self) -> bool {
1521 match (self, other) {
1522 (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1523 (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1524 _ => false,
1525 }
1526 }
1527}
1528
/// Wrapper around the peer data that is shared (through an `Arc`) by a handle with
/// affinity and its clones.
///
/// The `Drop` implementation of `KnownPeerHandleState` decreases the peer's affinity
/// once the last reference to this wrapper is given up.
struct KnownPeerHandleWithAffinity {
    /// The peer data this wrapper keeps alive.
    inner: Arc<KnownPeerInner>,
}
1532
/// Peer data shared by all handles to the same peer.
struct KnownPeerInner {
    /// Latest peer info; swapped atomically when newer info arrives.
    peer_info: ArcSwap<PeerInfo>,
    /// Affinity counter: `0` maps to `PeerAffinity::Allowed`, `AFFINITY_BANNED` to
    /// `PeerAffinity::Never`, and any other value to `PeerAffinity::High`.
    affinity: AtomicUsize,
    /// Weak back-reference to the registry map, used to remove the entry when the
    /// last handle is dropped.
    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
}
1538
1539impl KnownPeerInner {
1540 fn new(
1541 peer_info: Arc<PeerInfo>,
1542 with_affinity: bool,
1543 known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
1544 ) -> Arc<Self> {
1545 Arc::new(Self {
1546 peer_info: ArcSwap::from(peer_info),
1547 affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
1548 weak_known_peers: Arc::downgrade(known_peers),
1549 })
1550 }
1551
1552 fn is_banned(&self) -> bool {
1553 self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
1554 }
1555
1556 fn compute_affinity(&self) -> PeerAffinity {
1557 match self.affinity.load(Ordering::Acquire) {
1558 0 => PeerAffinity::Allowed,
1559 AFFINITY_BANNED => PeerAffinity::Never,
1560 _ => PeerAffinity::High,
1561 }
1562 }
1563
1564 fn increase_affinity(&self) -> bool {
1565 let mut current = self.affinity.load(Ordering::Acquire);
1566 while current != AFFINITY_BANNED {
1567 debug_assert_ne!(current, AFFINITY_BANNED - 1);
1568 match self.affinity.compare_exchange_weak(
1569 current,
1570 current + 1,
1571 Ordering::Release,
1572 Ordering::Acquire,
1573 ) {
1574 Ok(_) => return true,
1575 Err(affinity) => current = affinity,
1576 }
1577 }
1578
1579 false
1580 }
1581
1582 fn decrease_affinity(&self) -> bool {
1583 let mut current = self.affinity.load(Ordering::Acquire);
1584 while current != AFFINITY_BANNED {
1585 debug_assert_ne!(current, 0);
1586 match self.affinity.compare_exchange_weak(
1587 current,
1588 current - 1,
1589 Ordering::Release,
1590 Ordering::Acquire,
1591 ) {
1592 Ok(_) => return true,
1593 Err(affinity) => current = affinity,
1594 }
1595 }
1596
1597 false
1598 }
1599
1600 fn try_update_peer_info(
1601 &self,
1602 peer_info: &Arc<PeerInfo>,
1603 with_affinity: bool,
1604 ) -> Result<bool, PeerBannedError> {
1605 struct AffinityGuard<'a> {
1606 inner: &'a KnownPeerInner,
1607 decrease_on_drop: bool,
1608 }
1609
1610 impl AffinityGuard<'_> {
1611 fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
1612 let with_affinity = with_affinity && !self.decrease_on_drop;
1613 let is_banned = if with_affinity {
1614 !self.inner.increase_affinity()
1615 } else {
1616 self.inner.is_banned()
1617 };
1618
1619 if !is_banned && with_affinity {
1620 self.decrease_on_drop = true;
1621 }
1622
1623 is_banned
1624 }
1625 }
1626
1627 impl Drop for AffinityGuard<'_> {
1628 fn drop(&mut self) {
1629 if self.decrease_on_drop {
1630 self.inner.decrease_affinity();
1631 }
1632 }
1633 }
1634
1635 let mut guard = AffinityGuard {
1637 inner: self,
1638 decrease_on_drop: false,
1639 };
1640
1641 let mut cur = self.peer_info.load();
1642 let updated = loop {
1643 if guard.increase_affinity_or_check_ban(with_affinity) {
1644 return Err(PeerBannedError);
1646 }
1647
1648 match cur.created_at.cmp(&peer_info.created_at) {
1649 std::cmp::Ordering::Equal => break true,
1652 std::cmp::Ordering::Less => {
1654 let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
1655 if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
1656 break true;
1657 } else {
1658 cur = prev;
1659 }
1660 }
1661 std::cmp::Ordering::Greater => break false,
1663 }
1664 };
1665
1666 guard.decrease_on_drop = false;
1667 Ok(updated)
1668 }
1669}
1670
/// Sentinel value of the `KnownPeerInner::affinity` counter that marks a banned peer.
const AFFINITY_BANNED: usize = usize::MAX;
1672
/// Errors returned when inserting or updating peer info in `KnownPeers`.
#[derive(Debug, thiserror::Error)]
pub enum KnownPeersError {
    /// The peer is banned.
    #[error(transparent)]
    PeerBanned(#[from] PeerBannedError),
    /// The stored peer info is newer than the provided one.
    #[error("provided peer info is outdated")]
    OutdatedInfo,
}
1680
/// Error returned when an operation is rejected because the peer is banned.
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("peer is banned")]
pub struct PeerBannedError;
1684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    /// The entry stays cached while any handle is alive and disappears with the last one.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = peer_info.id;

        // Shared expectations for "the peer is cached, not banned, with default affinity".
        let assert_cached = || {
            assert!(peers.contains(&id));
            assert!(!peers.is_banned(&id));
            assert_eq!(peers.get(&id), Some(peer_info.clone()));
            assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::Allowed));
        };

        let first = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached();

        assert_eq!(first.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(first.max_affinity(), PeerAffinity::Allowed);

        let second = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached();

        assert_eq!(second.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(second.max_affinity(), PeerAffinity::Allowed);

        // Dropping one of two handles keeps the entry.
        drop(second);
        assert_cached();

        // Dropping the last handle removes it.
        drop(first);
        assert!(!peers.contains(&id));
        assert!(!peers.is_banned(&id));
        assert_eq!(peers.get(&id), None);
        assert_eq!(peers.get_affinity(&id), None);

        // The same peer can be inserted again afterwards.
        peers.insert(peer_info.clone(), false).unwrap();
    }

    /// A handle with affinity created after a simple one raises the shared affinity.
    #[test]
    fn with_affinity_after_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = peer_info.id;

        let simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::Allowed));
        assert_eq!(simple.max_affinity(), PeerAffinity::Allowed);

        let with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::High));
        assert_eq!(with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(simple.max_affinity(), PeerAffinity::High);

        drop(with_affinity);
        assert!(peers.contains(&id));
        assert_eq!(simple.max_affinity(), PeerAffinity::Allowed);
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::Allowed));

        drop(simple);
        assert!(!peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), None);
    }

    /// A simple handle created after one with affinity does not lower the affinity.
    #[test]
    fn with_affinity_before_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = peer_info.id;

        let with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::High));
        assert_eq!(with_affinity.max_affinity(), PeerAffinity::High);

        let simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::High));
        assert_eq!(with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(simple.max_affinity(), PeerAffinity::High);

        drop(simple);
        assert!(peers.contains(&id));
        assert_eq!(with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::High));

        drop(with_affinity);
        assert!(!peers.contains(&id));
        assert_eq!(peers.get_affinity(&id), None);
    }

    /// Banning through the registry is visible through an existing handle.
    #[test]
    fn ban_while_handle_exists() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = peer_info.id;

        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        peers.ban(&id);
        assert!(peers.contains(&id));
        assert!(peers.is_banned(&id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Never);
        assert_eq!(peers.get(&id), Some(peer_info.clone()));
        assert_eq!(peers.get_affinity(&id), Some(PeerAffinity::Never));
    }
}