// tycho_network/network/connection_manager.rs

1use std::collections::{hash_map, VecDeque};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{delay_queue, DelayQueue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::config::NetworkConfig;
15use crate::network::connection::Connection;
16use crate::network::endpoint::{Connecting, Endpoint, Into0RttResult};
17use crate::network::request_handler::InboundRequestHandler;
18use crate::network::wire::handshake;
19use crate::types::{
20    Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
21    PeerInfo, Response, ServiceRequest,
22};
23
// ---- Histograms -------------------------------------------------------------

/// Time spent establishing an outbound connection (dial + handshake).
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
/// Time spent establishing an inbound connection (0-RTT wait + handshake).
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// ---- Counters ---------------------------------------------------------------

/// Number of successfully established outbound connections.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
/// Number of successfully established inbound connections.
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
/// Number of failed outbound connection attempts.
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
/// Number of failed inbound connection attempts.
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// ---- Gauges -----------------------------------------------------------------

/// Number of running connection handler tasks.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
/// Number of connection tasks (dials / inbound handshakes) in flight.
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
/// Number of incoming connections still waiting for the peer identity.
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
/// Number of dials started by the periodic connectivity check that are not resolved yet.
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

/// Number of active peers.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
/// Number of known peers.
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
42
43#[derive(Debug)]
44pub(crate) enum ConnectionManagerRequest {
45    Connect(Address, PeerId, CallbackTx),
46    Shutdown(oneshot::Sender<()>),
47}
48
49pub(crate) struct ConnectionManager {
50    config: Arc<NetworkConfig>,
51    endpoint: Arc<Endpoint>,
52
53    mailbox: mpsc::Receiver<ConnectionManagerRequest>,
54
55    pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
56    pending_partial_connections: JoinSet<Option<PartialConnection>>,
57    pending_connections: JoinSet<ConnectingOutput>,
58    connection_handlers: JoinSet<()>,
59    delayed_callbacks: DelayedCallbacksQueue,
60
61    pending_dials: FastHashMap<PeerId, CallbackRx>,
62    dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,
63
64    active_peers: ActivePeers,
65    known_peers: KnownPeers,
66
67    service: BoxCloneService<ServiceRequest, Response>,
68}
69
70type CallbackTx = oneshot::Sender<Result<PeerId, Arc<anyhow::Error>>>;
71type CallbackRx = oneshot::Receiver<Result<PeerId, Arc<anyhow::Error>>>;
72
73impl Drop for ConnectionManager {
74    fn drop(&mut self) {
75        tracing::trace!("dropping connection manager");
76        self.endpoint.close();
77    }
78}
79
80impl ConnectionManager {
81    pub fn new(
82        config: Arc<NetworkConfig>,
83        endpoint: Arc<Endpoint>,
84        active_peers: ActivePeers,
85        known_peers: KnownPeers,
86        service: BoxCloneService<ServiceRequest, Response>,
87    ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
88        let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
89        let connection_manager = Self {
90            config,
91            endpoint,
92            mailbox,
93            pending_connection_callbacks: Default::default(),
94            pending_partial_connections: Default::default(),
95            pending_connections: Default::default(),
96            connection_handlers: Default::default(),
97            delayed_callbacks: Default::default(),
98            pending_dials: Default::default(),
99            dial_backoff_states: Default::default(),
100            active_peers,
101            known_peers,
102            service,
103        };
104        (connection_manager, mailbox_tx)
105    }
106
107    pub async fn start(mut self) {
108        tracing::info!("connection manager started");
109
110        let jitter = Duration::from_millis(1000).mul_f64(rand::random());
111        let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);
112
113        let mut shutdown_notifier = None;
114
115        loop {
116            tokio::select! {
117                now = interval.tick() => {
118                    self.handle_connectivity_check(now.into_std());
119                }
120                maybe_request = self.mailbox.recv() => {
121                    let Some(request) = maybe_request else {
122                        break;
123                    };
124
125                    match request {
126                        ConnectionManagerRequest::Connect(address, peer_id, callback) => {
127                            self.handle_connect_request(address, &peer_id, callback);
128                        }
129                        ConnectionManagerRequest::Shutdown(oneshot) => {
130                            shutdown_notifier = Some(oneshot);
131                            break;
132                        }
133                    }
134                }
135                connecting = self.endpoint.accept() => {
136                    if let Some(connecting) = connecting {
137                        self.handle_incoming(connecting);
138                    }
139                }
140                Some(connecting_output) = self.pending_connections.join_next() => {
141                    metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
142                    match connecting_output {
143                        Ok(connecting) => self.handle_connecting_result(connecting),
144                        Err(e) => {
145                            if e.is_panic() {
146                                std::panic::resume_unwind(e.into_panic());
147                            }
148                        }
149                    }
150                }
151                Some(partial_connection) = self.pending_partial_connections.join_next() => {
152                    metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);
153                    // NOTE: unwrap here is to propagate panic from the spawned future
154                    if let Some(PartialConnection { connection, timeout_at }) = partial_connection.unwrap() {
155                        self.handle_incoming_impl(connection, None, timeout_at);
156                    }
157                }
158                Some(connection_handler_output) = self.connection_handlers.join_next() => {
159                    metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
160
161                    // NOTE: unwrap here is to propagate panic from the spawned future
162                    if let Err(e) = connection_handler_output {
163                        if e.is_panic() {
164                            std::panic::resume_unwind(e.into_panic());
165                        }
166                    }
167                }
168                Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
169                    self.delayed_callbacks.execute_expired(&peer_id);
170                }
171            }
172        }
173
174        self.shutdown().await;
175
176        if let Some(tx) = shutdown_notifier {
177            _ = tx.send(());
178        }
179
180        tracing::info!("connection manager stopped");
181    }
182
183    async fn shutdown(mut self) {
184        tracing::trace!("shutting down connection manager");
185
186        self.endpoint.close();
187
188        self.pending_partial_connections.shutdown().await;
189        metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);
190
191        self.pending_connections.shutdown().await;
192        metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);
193
194        while self.connection_handlers.join_next().await.is_some() {
195            metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
196        }
197        assert!(self.active_peers.is_empty());
198
199        self.endpoint
200            .wait_idle(self.config.shutdown_idle_timeout)
201            .await;
202    }
203
204    fn handle_connectivity_check(&mut self, now: Instant) {
205        use std::collections::hash_map::Entry;
206
207        self.pending_dials
208            .retain(|peer_id, oneshot| match oneshot.try_recv() {
209                Ok(Ok(returned_peer_id)) => {
210                    debug_assert_eq!(peer_id, &returned_peer_id);
211                    self.dial_backoff_states.remove(peer_id);
212                    false
213                }
214                Ok(Err(_)) => {
215                    match self.dial_backoff_states.entry(*peer_id) {
216                        Entry::Occupied(mut entry) => entry.get_mut().update(
217                            now,
218                            self.config.connection_backoff,
219                            self.config.max_connection_backoff,
220                        ),
221                        Entry::Vacant(entry) => {
222                            entry.insert(DialBackoffState::new(
223                                now,
224                                self.config.connection_backoff,
225                                self.config.max_connection_backoff,
226                            ));
227                        }
228                    }
229                    false
230                }
231                Err(oneshot::error::TryRecvError::Closed) => {
232                    panic!("BUG: connection manager never finished dialing a peer");
233                }
234                Err(oneshot::error::TryRecvError::Empty) => true,
235            });
236
237        let outstanding_connections_limit = self
238            .config
239            .max_concurrent_outstanding_connections
240            .saturating_sub(self.pending_connections.len());
241
242        let outstanding_connections = self
243            .known_peers
244            .0
245            .iter()
246            .filter_map(|item| {
247                let value = match item.value() {
248                    KnownPeerState::Stored(item) => item.upgrade()?,
249                    KnownPeerState::Banned => return None,
250                };
251                let peer_info = value.peer_info.load();
252                let affinity = value.compute_affinity();
253
254                (affinity == PeerAffinity::High
255                    && peer_info.id != self.endpoint.peer_id()
256                    && !self.active_peers.contains(&peer_info.id)
257                    && !self.pending_dials.contains_key(&peer_info.id)
258                    && self
259                        .dial_backoff_states
260                        .get(&peer_info.id)
261                        .map_or(true, |state| now > state.next_attempt_at))
262                .then(|| arc_swap::Guard::into_inner(peer_info))
263            })
264            .take(outstanding_connections_limit)
265            .collect::<Vec<_>>();
266
267        for peer_info in outstanding_connections {
268            // TODO: handle multiple addresses
269            let address = peer_info
270                .iter_addresses()
271                .next()
272                .cloned()
273                .expect("address list must have at least one item");
274
275            let (tx, rx) = oneshot::channel();
276            self.dial_peer(address, &peer_info.id, tx);
277            self.pending_dials.insert(peer_info.id, rx);
278        }
279
280        metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
281    }
282
283    fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
284        self.dial_peer(address, peer_id, callback);
285    }
286
287    fn handle_incoming(&mut self, connecting: Connecting) {
288        let remote_addr = connecting.remote_address();
289        tracing::trace!(
290            local_id = %self.endpoint.peer_id(),
291            %remote_addr,
292            "received an incoming connection",
293        );
294
295        // Split incoming connection into 0.5-RTT and 1-RTT parts.
296        match connecting.into_0rtt() {
297            Into0RttResult::Established(connection, accepted) => {
298                let timeout_at = Instant::now() + self.config.connect_timeout;
299                self.handle_incoming_impl(connection, Some(accepted), timeout_at);
300            }
301            Into0RttResult::WithoutIdentity(partial_connection) => {
302                tracing::trace!("connection identity is not available yet");
303
304                let timeout_at = Instant::now() + self.config.connect_timeout;
305                self.pending_partial_connections.spawn(async move {
306                    match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
307                        Ok(Ok(connection)) => Some(PartialConnection {
308                            connection,
309                            timeout_at,
310                        }),
311                        Ok(Err(e)) => {
312                            tracing::trace!(
313                                %remote_addr,
314                                "failed to establish an incoming connection: {e}",
315                            );
316                            None
317                        }
318                        Err(_) => {
319                            tracing::trace!(
320                                %remote_addr,
321                                "incoming connection timed out",
322                            );
323                            None
324                        }
325                    }
326                });
327                metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
328            }
329            Into0RttResult::InvalidConnection(e) => {
330                // TODO: Lower log level to trace/debug?
331                tracing::warn!(%remote_addr, "invalid incoming connection: {e}");
332            }
333            Into0RttResult::Unavailable(_) => unreachable!(
334                "BUG: For incoming connections, a 0.5-RTT connection must \
335                always be successfully constructed."
336            ),
337        };
338    }
339
340    fn handle_incoming_impl(
341        &mut self,
342        connection: Connection,
343        accepted: Option<quinn::ZeroRttAccepted>,
344        timeout_at: Instant,
345    ) {
346        async fn handle_incoming_task(
347            seqno: u32,
348            connection: ConnectionClosedOnDrop,
349            accepted: Option<quinn::ZeroRttAccepted>,
350            timeout_at: Instant,
351        ) -> ConnectingOutput {
352            let target_peer_id = *connection.peer_id();
353            let target_address = connection.remote_address().into();
354            let fut = async {
355                if let Some(accepted) = accepted {
356                    // NOTE: `bool` output of this future is meaningless for servers.
357                    accepted.await;
358                }
359                handshake(&connection).await
360            };
361
362            let started_at = Instant::now();
363
364            let connecting_result = tokio::time::timeout_at(timeout_at.into(), fut)
365                .await
366                .map_err(Into::into)
367                .and_then(std::convert::identity)
368                .map_err(Arc::new)
369                .map(|_| connection.disarm());
370
371            metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());
372
373            ConnectingOutput {
374                seqno,
375                drop_result: true,
376                connecting_result: ManuallyDrop::new(connecting_result),
377                target_address,
378                target_peer_id,
379                origin: Direction::Inbound,
380            }
381        }
382
383        let remote_addr = connection.remote_address();
384
385        // Check if the peer is allowed before doing anything else.
386        match self.known_peers.get_affinity(connection.peer_id()) {
387            Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
388            Some(PeerAffinity::Never) => {
389                // TODO: Lower log level to trace/debug?
390                tracing::warn!(
391                    %remote_addr,
392                    peer_id = %connection.peer_id(),
393                    "rejecting connection due to PeerAffinity::Never",
394                );
395                connection.close();
396                return;
397            }
398            _ => {
399                if matches!(
400                    self.config.max_concurrent_connections,
401                    Some(limit) if self.active_peers.len() >= limit
402                ) {
403                    // TODO: Lower log level to trace/debug?
404                    tracing::warn!(
405                        %remote_addr,
406                        peer_id = %connection.peer_id(),
407                        "rejecting connection due too many concurrent connections",
408                    );
409                    connection.close();
410                    return;
411                }
412            }
413        }
414
415        let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
416            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
417                last_seqno: 0,
418                origin: Direction::Inbound,
419                callbacks: Default::default(),
420                abort_handle: None,
421            })),
422            hash_map::Entry::Occupied(entry) => {
423                let entry = entry.into_mut();
424
425                // Check if the incoming connection is a simultaneous dial.
426                if simultaneous_dial_tie_breaking(
427                    self.endpoint.peer_id(),
428                    connection.peer_id(),
429                    entry.origin,
430                    Direction::Inbound,
431                ) {
432                    // New connection wins the tie, abort the old one and spawn a new task.
433                    tracing::debug!(
434                        %remote_addr,
435                        peer_id = %connection.peer_id(),
436                        "cancelling old connection to mitigate simultaneous dial",
437                    );
438
439                    entry.origin = Direction::Inbound;
440                    entry.last_seqno += 1;
441                    if let Some(handle) = entry.abort_handle.take() {
442                        handle.abort();
443                    }
444                    Some(entry)
445                } else {
446                    // Old connection wins the tie, gracefully close the new one.
447                    tracing::debug!(
448                        %remote_addr,
449                        peer_id = %connection.peer_id(),
450                        "cancelling new connection to mitigate simultaneous dial",
451                    );
452
453                    connection.close();
454                    None
455                }
456            }
457        };
458
459        if let Some(entry) = entry {
460            entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
461                entry.last_seqno,
462                ConnectionClosedOnDrop::new(connection),
463                accepted,
464                timeout_at,
465            )));
466            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
467        }
468    }
469
470    fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
471        // Check seqno first to drop outdated results.
472        {
473            let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
474                tracing::trace!("connection task reordering detected");
475                return;
476            };
477
478            if entry.last_seqno != res.seqno {
479                tracing::debug!(
480                    local_id = %self.endpoint.peer_id(),
481                    peer_id = %res.target_peer_id,
482                    remote_addr = %res.target_address,
483                    "connection result is outdated"
484                );
485                return;
486            }
487        }
488
489        let callbacks = self
490            .pending_connection_callbacks
491            .remove(&res.target_address)
492            .expect("Connection tasks must be tracked")
493            .callbacks;
494
495        res.drop_result = false;
496        // SAFETY: `drop_result` is set to `false`.
497        match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
498            Ok(connection) => {
499                let peer_id = *connection.peer_id();
500                tracing::debug!(
501                    local_id = %self.endpoint.peer_id(),
502                    %peer_id,
503                    remote_addr = %res.target_address,
504                    "new connection",
505                );
506                self.add_peer(connection);
507
508                self.delayed_callbacks.execute_resolved(&res.target_peer_id);
509
510                for callback in callbacks {
511                    _ = callback.send(Ok(peer_id));
512                }
513            }
514            Err(e) => {
515                tracing::debug!(
516                    local_id = %self.endpoint.peer_id(),
517                    peer_id = %res.target_peer_id,
518                    remote_addr = %res.target_address,
519                    "connection failed: {e}"
520                );
521
522                metrics::counter!(match res.origin {
523                    Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
524                    Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
525                })
526                .increment(1);
527
528                // Delay sending the error to callbacks as the target peer might be
529                // in the process of connecting to us.
530                if let Some(quinn::ConnectionError::ApplicationClosed(closed)) = e.downcast_ref() {
531                    if closed.error_code.into_inner() == 0 && !callbacks.is_empty() {
532                        self.delayed_callbacks.push(
533                            &res.target_peer_id,
534                            e,
535                            callbacks,
536                            &self.config.connection_error_delay,
537                        );
538                        return;
539                    }
540                }
541
542                for callback in callbacks {
543                    _ = callback.send(Err(e.clone()));
544                }
545            }
546        }
547    }
548
549    fn add_peer(&mut self, connection: Connection) {
550        if let Some(connection) = self.active_peers.add(self.endpoint.peer_id(), connection) {
551            let origin = connection.origin();
552
553            let handler = InboundRequestHandler::new(
554                self.config.clone(),
555                connection,
556                self.service.clone(),
557                self.active_peers.clone(),
558            );
559
560            metrics::counter!(match origin {
561                Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
562                Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
563            })
564            .increment(1);
565
566            metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
567            self.connection_handlers.spawn(handler.start());
568        }
569    }
570
571    #[tracing::instrument(
572        level = "trace",
573        skip_all,
574        fields(
575            local_id = %self.endpoint.peer_id(),
576            peer_id = %peer_id,
577            remote_addr = %address,
578        ),
579    )]
580    fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
581        async fn dial_peer_task(
582            seqno: u32,
583            endpoint: Arc<Endpoint>,
584            address: Address,
585            peer_id: PeerId,
586            config: Arc<NetworkConfig>,
587        ) -> ConnectingOutput {
588            let fut = async {
589                let address = address.resolve().await?;
590                let connecting = endpoint.connect_with_expected_id(&address, &peer_id)?;
591                let connection = ConnectionClosedOnDrop::new(connecting.await?);
592                handshake(&connection).await?;
593                Ok(connection)
594            };
595
596            let started_at = Instant::now();
597
598            let connecting_result = tokio::time::timeout(config.connect_timeout, fut)
599                .await
600                .map_err(Into::into)
601                .and_then(std::convert::identity)
602                .map_err(Arc::new)
603                .map(ConnectionClosedOnDrop::disarm);
604
605            metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());
606
607            ConnectingOutput {
608                seqno,
609                drop_result: true,
610                connecting_result: ManuallyDrop::new(connecting_result),
611                target_address: address,
612                target_peer_id: peer_id,
613                origin: Direction::Outbound,
614            }
615        }
616
617        if self.active_peers.contains(peer_id) {
618            tracing::debug!("peer is already connected");
619            _ = callback.send(Ok(*peer_id));
620            return;
621        }
622
623        tracing::trace!("connecting to peer");
624
625        let entry = match self.pending_connection_callbacks.entry(address.clone()) {
626            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
627                last_seqno: 0,
628                origin: Direction::Outbound,
629                callbacks: vec![callback],
630                abort_handle: None,
631            })),
632            hash_map::Entry::Occupied(entry) => {
633                let entry = entry.into_mut();
634
635                // Add the callback to the existing entry.
636                entry.callbacks.push(callback);
637
638                // Check if the outgoing connection is a simultaneous dial.
639                let break_tie = simultaneous_dial_tie_breaking(
640                    self.endpoint.peer_id(),
641                    peer_id,
642                    entry.origin,
643                    Direction::Outbound,
644                );
645
646                if break_tie && entry.origin != Direction::Outbound {
647                    // New connection wins the tie, abort the old one and spawn a new task.
648                    tracing::debug!("cancelling old connection to mitigate simultaneous dial");
649
650                    entry.origin = Direction::Outbound;
651                    entry.last_seqno += 1;
652                    if let Some(handle) = entry.abort_handle.take() {
653                        handle.abort();
654                    }
655                    Some(entry)
656                } else {
657                    // Old connection wins the tie, don't create a new one
658                    tracing::trace!("reusing old connection to mitigate simultaneous dial");
659                    None
660                }
661            }
662        };
663
664        if let Some(entry) = entry {
665            entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
666                entry.last_seqno,
667                self.endpoint.clone(),
668                address.clone(),
669                *peer_id,
670                self.config.clone(),
671            )));
672            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
673        }
674    }
675}
676
677struct PendingConnectionCallbacks {
678    last_seqno: u32,
679    origin: Direction,
680    callbacks: Vec<CallbackTx>,
681    abort_handle: Option<AbortHandle>,
682}
683
684struct PartialConnection {
685    connection: Connection,
686    timeout_at: Instant,
687}
688
689struct ConnectingOutput {
690    seqno: u32,
691    drop_result: bool,
692    connecting_result: ManuallyDrop<Result<Connection, Arc<anyhow::Error>>>,
693    target_address: Address,
694    target_peer_id: PeerId,
695    origin: Direction,
696}
697
698impl Drop for ConnectingOutput {
699    fn drop(&mut self) {
700        if self.drop_result {
701            // SAFETY: `drop_result` is set to `true` only when the result is not used.
702            unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
703        }
704    }
705}
706
707struct ConnectionClosedOnDrop {
708    connection: ManuallyDrop<Connection>,
709    close_on_drop: bool,
710}
711
712impl ConnectionClosedOnDrop {
713    fn new(connection: Connection) -> Self {
714        Self {
715            connection: ManuallyDrop::new(connection),
716            close_on_drop: true,
717        }
718    }
719
720    fn disarm(mut self) -> Connection {
721        self.close_on_drop = false;
722        // SAFETY: `drop` will not be called.
723        unsafe { ManuallyDrop::take(&mut self.connection) }
724    }
725}
726
727impl std::ops::Deref for ConnectionClosedOnDrop {
728    type Target = Connection;
729
730    #[inline]
731    fn deref(&self) -> &Self::Target {
732        &self.connection
733    }
734}
735
736impl Drop for ConnectionClosedOnDrop {
737    fn drop(&mut self) {
738        if self.close_on_drop {
739            // SAFETY: `disarm` was not called.
740            let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
741            connection.close();
742        }
743    }
744}
745
746#[derive(Default)]
747struct DelayedCallbacksQueue {
748    callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
749    expirations: DelayQueue<PeerId>,
750}
751
752impl DelayedCallbacksQueue {
753    fn push(
754        &mut self,
755        peer_id: &PeerId,
756        error: Arc<anyhow::Error>,
757        callbacks: Vec<CallbackTx>,
758        delay: &Duration,
759    ) {
760        tracing::debug!(%peer_id, %error, "delayed connection error");
761
762        let expires_at = Instant::now() + *delay;
763        let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
764
765        let items = self.callbacks.entry(*peer_id).or_default();
766        items.push_back(DelayedCallbacks {
767            delay_key,
768            error,
769            callbacks,
770            expires_at,
771        });
772    }
773
774    async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
775        let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
776        Some(res.into_inner())
777    }
778
779    fn execute_resolved(&mut self, peer_id: &PeerId) {
780        let Some(items) = self.callbacks.remove(peer_id) else {
781            return;
782        };
783
784        let mut batches_executed = 0;
785        let mut callbacks_executed = 0;
786
787        for delayed in items {
788            batches_executed += 1;
789            callbacks_executed += delayed.callbacks.len();
790
791            let key = delayed.execute_with_ok(peer_id);
792
793            // NOTE: Delay key must exist in the queue.
794            self.expirations.remove(&key);
795        }
796
797        tracing::debug!(
798            %peer_id,
799            batches_executed,
800            callbacks_executed,
801            "executed all delayed callbacks",
802        );
803    }
804
805    fn execute_expired(&mut self, peer_id: &PeerId) {
806        let now = Instant::now();
807
808        let mut batches_executed = 0;
809        let mut callbacks_executed = 0;
810
811        'outer: {
812            if let Some(items) = self.callbacks.get_mut(peer_id) {
813                while let Some(front) = items.front() {
814                    if !front.is_expired(&now) {
815                        break 'outer;
816                    }
817
818                    if let Some(delayed) = items.pop_front() {
819                        batches_executed += 1;
820                        callbacks_executed += delayed.callbacks.len();
821
822                        let key = delayed.execute_with_error();
823
824                        // NOTE: Might not be necessary since items are stored in order,
825                        // but it's better to be safe.
826                        self.expirations.try_remove(&key);
827                    }
828                }
829            }
830
831            // There is no need to hold an empty queue for this peer
832            self.callbacks.remove(peer_id);
833        }
834
835        tracing::debug!(
836            %peer_id,
837            batches_executed,
838            callbacks_executed,
839            "executed expired delayed callbacks"
840        );
841    }
842}
843
844struct DelayedCallbacks {
845    delay_key: delay_queue::Key,
846    error: Arc<anyhow::Error>,
847    callbacks: Vec<CallbackTx>,
848    expires_at: Instant,
849}
850
851impl DelayedCallbacks {
852    fn execute_with_ok(self, peer_id: &PeerId) -> delay_queue::Key {
853        for callback in self.callbacks {
854            _ = callback.send(Ok(*peer_id));
855        }
856        self.delay_key
857    }
858
859    fn execute_with_error(self) -> delay_queue::Key {
860        for callback in self.callbacks {
861            _ = callback.send(Err(self.error.clone()));
862        }
863        self.delay_key
864    }
865
866    fn is_expired(&self, now: &Instant) -> bool {
867        *now >= self.expires_at
868    }
869}
870
/// Linear dial backoff.
///
/// After the n-th failed attempt, the next attempt is scheduled `step * n`
/// after `now`, capped at `max`.
#[derive(Debug)]
struct DialBackoffState {
    next_attempt_at: Instant,
    attempts: usize,
}

impl DialBackoffState {
    /// Creates a state that already accounts for one attempt made at `now`.
    fn new(now: Instant, step: Duration, max: Duration) -> Self {
        let mut this = Self {
            next_attempt_at: now,
            attempts: 0,
        };
        this.update(now, step, max);
        this
    }

    /// Records another attempt at `now` and reschedules the next one.
    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
        self.attempts += 1;

        // Attempt counts beyond `u32::MAX` saturate, as does the multiplication.
        let multiplier = u32::try_from(self.attempts).unwrap_or(u32::MAX);
        let delay = step.saturating_mul(multiplier).min(max);

        self.next_attempt_at = now + delay;
    }
}
896
897#[derive(Clone)]
898pub struct ActivePeers(Arc<ActivePeersInner>);
899
900impl ActivePeers {
901    pub fn new(channel_size: usize) -> Self {
902        Self(Arc::new(ActivePeersInner::new(channel_size)))
903    }
904
905    pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
906        self.0.get(peer_id)
907    }
908
909    pub fn contains(&self, peer_id: &PeerId) -> bool {
910        self.0.contains(peer_id)
911    }
912
913    pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> Option<Connection> {
914        self.0.add(local_id, new_connection)
915    }
916
917    pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
918        self.0.remove(peer_id, reason);
919    }
920
921    pub fn remove_with_stable_id(
922        &self,
923        peer_id: &PeerId,
924        stable_id: usize,
925        reason: DisconnectReason,
926    ) {
927        self.0.remove_with_stable_id(peer_id, stable_id, reason);
928    }
929
930    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
931        self.0.subscribe()
932    }
933
934    pub fn is_empty(&self) -> bool {
935        self.0.is_empty()
936    }
937
938    pub fn len(&self) -> usize {
939        self.0.len()
940    }
941
942    pub fn downgrade(this: &Self) -> WeakActivePeers {
943        WeakActivePeers(Arc::downgrade(&this.0))
944    }
945}
946
947#[derive(Clone)]
948pub struct WeakActivePeers(Weak<ActivePeersInner>);
949
950impl WeakActivePeers {
951    pub fn upgrade(&self) -> Option<ActivePeers> {
952        self.0.upgrade().map(ActivePeers)
953    }
954}
955
/// Shared state behind [`ActivePeers`]: the connection map, a separately
/// maintained entry counter, and the sender for peer events.
struct ActivePeersInner {
    connections: FastDashMap<PeerId, Connection>,
    // Incremented/decremented next to map insertions/removals; `len` reads this
    // counter instead of querying the map.
    connections_len: AtomicUsize,
    // Every `NewPeer` / `LostPeer` event is published through this sender.
    events_tx: broadcast::Sender<PeerEvent>,
}

impl ActivePeersInner {
    /// Creates an empty state whose event channel has capacity `channel_size`.
    fn new(channel_size: usize) -> Self {
        // The initial receiver is discarded; consumers obtain one via `subscribe`.
        let (events_tx, _) = broadcast::channel(channel_size);
        Self {
            connections: Default::default(),
            connections_len: Default::default(),
            events_tx,
        }
    }

    /// Returns a clone of the stored connection to `peer_id`, if any.
    fn get(&self, peer_id: &PeerId) -> Option<Connection> {
        self.connections
            .get(peer_id)
            .map(|item| item.value().clone())
    }

    /// Returns `true` if a connection to `peer_id` is stored.
    fn contains(&self, peer_id: &PeerId) -> bool {
        self.connections.contains_key(peer_id)
    }

    /// Stores `new_connection`, resolving a collision with an existing
    /// connection to the same peer via [`simultaneous_dial_tie_breaking`].
    ///
    /// - If the existing connection wins, `new_connection` is closed, no event is
    ///   emitted, and `None` is returned.
    /// - If the new one wins, the old connection is replaced and closed, and
    ///   `PeerEvent::LostPeer(.., DisconnectReason::Requested)` is emitted.
    ///
    /// When the new connection is stored, `PeerEvent::NewPeer` is emitted and the
    /// connection is returned.
    #[must_use]
    fn add(&self, local_id: &PeerId, new_connection: Connection) -> Option<Connection> {
        use dashmap::mapref::entry::Entry;

        // Only a previously vacant slot changes the number of active peers.
        let mut added = false;

        let peer_id = new_connection.peer_id();
        match self.connections.entry(*peer_id) {
            Entry::Occupied(mut entry) => {
                if simultaneous_dial_tie_breaking(
                    local_id,
                    peer_id,
                    entry.get().origin(),
                    new_connection.origin(),
                ) {
                    tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
                    let old_connection = entry.insert(new_connection.clone());
                    old_connection.close();
                    self.send_event(PeerEvent::LostPeer(*peer_id, DisconnectReason::Requested));
                } else {
                    tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
                    new_connection.close();
                    return None;
                }
            }
            Entry::Vacant(entry) => {
                self.connections_len.fetch_add(1, Ordering::Release);
                entry.insert(new_connection.clone());
                added = true;
            }
        }

        self.send_event(PeerEvent::NewPeer(*peer_id));

        if added {
            metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
        }
        Some(new_connection)
    }

    /// Removes and closes the connection to `peer_id` (if present), then emits
    /// `PeerEvent::LostPeer` with `reason`.
    fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
        if let Some((_, connection)) = self.connections.remove(peer_id) {
            connection.close();
            self.connections_len.fetch_sub(1, Ordering::Release);
            self.send_event(PeerEvent::LostPeer(*peer_id, reason));

            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
        }
    }

    /// Same as [`Self::remove`], but only if the stored connection's stable id
    /// equals `stable_id`.
    ///
    /// Presumably this lets a stale connection's teardown skip evicting a newer
    /// replacement; confirm against callers.
    fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
        if let Some((_, connection)) = self
            .connections
            .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
        {
            connection.close();
            self.connections_len.fetch_sub(1, Ordering::Release);
            self.send_event(PeerEvent::LostPeer(*peer_id, reason));

            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
        }
    }

    /// Returns a new receiver for peer events.
    fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
        self.events_tx.subscribe()
    }

    /// Publishes `event`; send errors (e.g. no active subscribers) are ignored.
    fn send_event(&self, event: PeerEvent) {
        _ = self.events_tx.send(event);
    }

    /// Returns `true` if the connection map is empty (checks the map itself,
    /// not the counter).
    fn is_empty(&self) -> bool {
        self.connections.is_empty()
    }

    /// Returns the value of the connection counter.
    fn len(&self) -> usize {
        self.connections_len.load(Ordering::Acquire)
    }
}
1061
1062fn simultaneous_dial_tie_breaking(
1063    local_id: &PeerId,
1064    peer_id: &PeerId,
1065    old_origin: Direction,
1066    new_origin: Direction,
1067) -> bool {
1068    match (old_origin, new_origin) {
1069        (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1070            true
1071        }
1072        (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1073        (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1074    }
1075}
1076
1077#[derive(Default, Clone)]
1078#[repr(transparent)]
1079pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1080
1081impl KnownPeers {
1082    pub fn new() -> Self {
1083        Self::default()
1084    }
1085
1086    pub fn contains(&self, peer_id: &PeerId) -> bool {
1087        self.0.contains_key(peer_id)
1088    }
1089
1090    pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1091        self.0
1092            .get(peer_id)
1093            .and_then(|item| {
1094                Some(match item.value() {
1095                    KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1096                    KnownPeerState::Banned => true,
1097                })
1098            })
1099            .unwrap_or_default()
1100    }
1101
1102    pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1103        self.0.get(peer_id).and_then(|item| match item.value() {
1104            KnownPeerState::Stored(item) => {
1105                let inner = item.upgrade()?;
1106                Some(inner.peer_info.load_full())
1107            }
1108            KnownPeerState::Banned => None,
1109        })
1110    }
1111
1112    pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1113        self.0
1114            .get(peer_id)
1115            .and_then(|item| item.value().compute_affinity())
1116    }
1117
1118    pub fn remove(&self, peer_id: &PeerId) {
1119        self.0.remove(peer_id);
1120        metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1121    }
1122
1123    pub fn ban(&self, peer_id: &PeerId) {
1124        let mut added = false;
1125        match self.0.entry(*peer_id) {
1126            dashmap::mapref::entry::Entry::Vacant(entry) => {
1127                entry.insert(KnownPeerState::Banned);
1128                added = true;
1129            }
1130            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1131                KnownPeerState::Banned => {}
1132                KnownPeerState::Stored(item) => match item.upgrade() {
1133                    Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1134                    None => *entry.get_mut() = KnownPeerState::Banned,
1135                },
1136            },
1137        }
1138
1139        if added {
1140            // NOTE: "New" banned peer is a "new" known peer.
1141            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1142        }
1143    }
1144
1145    pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1146        let inner = match self.0.get(peer_id)?.value() {
1147            KnownPeerState::Stored(item) => {
1148                let inner = item.upgrade()?;
1149                if with_affinity && !inner.increase_affinity() {
1150                    return None;
1151                }
1152                inner
1153            }
1154            KnownPeerState::Banned => return None,
1155        };
1156
1157        Some(KnownPeerHandle::from_inner(inner, with_affinity))
1158    }
1159
1160    /// Inserts a new handle only if the provided info is not outdated
1161    /// and the peer is not banned.
1162    pub fn insert(
1163        &self,
1164        peer_info: Arc<PeerInfo>,
1165        with_affinity: bool,
1166    ) -> Result<KnownPeerHandle, KnownPeersError> {
1167        // TODO: add capacity limit for entries without affinity
1168        let mut added = false;
1169        let inner = match self.0.entry(peer_info.id) {
1170            dashmap::mapref::entry::Entry::Vacant(entry) => {
1171                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1172                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1173                added = true;
1174                inner
1175            }
1176            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1177                KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1178                KnownPeerState::Stored(item) => match item.upgrade() {
1179                    Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1180                        true => inner,
1181                        false => return Err(KnownPeersError::OutdatedInfo),
1182                    },
1183                    None => {
1184                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1185                        *item = Arc::downgrade(&inner);
1186                        inner
1187                    }
1188                },
1189            },
1190        };
1191
1192        if added {
1193            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1194        }
1195
1196        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1197    }
1198
1199    /// Same as [`KnownPeers::insert`], but ignores outdated info.
1200    pub fn insert_allow_outdated(
1201        &self,
1202        peer_info: Arc<PeerInfo>,
1203        with_affinity: bool,
1204    ) -> Result<KnownPeerHandle, PeerBannedError> {
1205        // TODO: add capacity limit for entries without affinity
1206        let mut added = false;
1207        let inner = match self.0.entry(peer_info.id) {
1208            dashmap::mapref::entry::Entry::Vacant(entry) => {
1209                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1210                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1211                added = true;
1212                inner
1213            }
1214            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1215                KnownPeerState::Banned => return Err(PeerBannedError),
1216                KnownPeerState::Stored(item) => match item.upgrade() {
1217                    Some(inner) => {
1218                        // NOTE: Outdated info is ignored here.
1219                        inner.try_update_peer_info(&peer_info, with_affinity)?;
1220                        inner
1221                    }
1222                    None => {
1223                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1224                        *item = Arc::downgrade(&inner);
1225                        inner
1226                    }
1227                },
1228            },
1229        };
1230
1231        if added {
1232            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1233        }
1234
1235        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1236    }
1237}
1238
1239enum KnownPeerState {
1240    Stored(Weak<KnownPeerInner>),
1241    Banned,
1242}
1243
1244impl KnownPeerState {
1245    fn compute_affinity(&self) -> Option<PeerAffinity> {
1246        Some(match self {
1247            Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1248            Self::Banned => PeerAffinity::Never,
1249        })
1250    }
1251}
1252
1253#[derive(Clone)]
1254#[repr(transparent)]
1255pub struct KnownPeerHandle(KnownPeerHandleState);
1256
1257impl KnownPeerHandle {
1258    fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
1259        KnownPeerHandle(if with_affinity {
1260            KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1261                KnownPeerHandleWithAffinity { inner },
1262            )))
1263        } else {
1264            KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
1265        })
1266    }
1267
1268    pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
1269        self.inner().peer_info.load()
1270    }
1271
1272    pub fn load_peer_info(&self) -> Arc<PeerInfo> {
1273        arc_swap::Guard::into_inner(self.peer_info())
1274    }
1275
1276    pub fn is_banned(&self) -> bool {
1277        self.inner().is_banned()
1278    }
1279
1280    pub fn max_affinity(&self) -> PeerAffinity {
1281        self.inner().compute_affinity()
1282    }
1283
1284    pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
1285        match self.inner().try_update_peer_info(peer_info, false) {
1286            Ok(true) => Ok(()),
1287            Ok(false) => Err(KnownPeersError::OutdatedInfo),
1288            Err(e) => Err(KnownPeersError::PeerBanned(e)),
1289        }
1290    }
1291
1292    pub fn ban(&self) -> bool {
1293        let inner = self.inner();
1294        inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
1295    }
1296
1297    pub fn increase_affinity(&mut self) -> bool {
1298        match &mut self.0 {
1299            KnownPeerHandleState::Simple(inner) => {
1300                // NOTE: Handle will be updated even if the peer is banned.
1301                inner.increase_affinity();
1302
1303                // SAFETY: Inner value was not dropped.
1304                let inner = unsafe { ManuallyDrop::take(inner) };
1305
1306                // Replace the old state with the new one, ensuring that the old state
1307                // is not dropped (because we took the value out of it).
1308                let prev_state = std::mem::replace(
1309                    &mut self.0,
1310                    KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1311                        KnownPeerHandleWithAffinity { inner },
1312                    ))),
1313                );
1314
1315                // Forget the old state to avoid dropping it.
1316                #[allow(clippy::mem_forget)]
1317                std::mem::forget(prev_state);
1318
1319                true
1320            }
1321            KnownPeerHandleState::WithAffinity(_) => false,
1322        }
1323    }
1324
1325    pub fn decrease_affinity(&mut self) -> bool {
1326        match &mut self.0 {
1327            KnownPeerHandleState::Simple(_) => false,
1328            KnownPeerHandleState::WithAffinity(inner) => {
1329                // NOTE: Handle will be updated even if the peer is banned.
1330                inner.inner.decrease_affinity();
1331
1332                // SAFETY: Inner value was not dropped.
1333                let inner = unsafe { ManuallyDrop::take(inner) };
1334
1335                // Get `KnownPeerInner` out of the wrapper.
1336                let inner = match Arc::try_unwrap(inner) {
1337                    Ok(KnownPeerHandleWithAffinity { inner }) => inner,
1338                    Err(inner) => inner.inner.clone(),
1339                };
1340
1341                // Replace the old state with the new one, ensuring that the old state
1342                // is not dropped (because we took the value out of it).
1343                let prev_state = std::mem::replace(
1344                    &mut self.0,
1345                    KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
1346                );
1347
1348                // Forget the old state to avoid dropping it.
1349                #[allow(clippy::mem_forget)]
1350                std::mem::forget(prev_state);
1351
1352                true
1353            }
1354        }
1355    }
1356
1357    pub fn downgrade(&self) -> WeakKnownPeerHandle {
1358        WeakKnownPeerHandle(match &self.0 {
1359            KnownPeerHandleState::Simple(data) => {
1360                WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
1361            }
1362            KnownPeerHandleState::WithAffinity(data) => {
1363                WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
1364            }
1365        })
1366    }
1367
1368    fn inner(&self) -> &KnownPeerInner {
1369        match &self.0 {
1370            KnownPeerHandleState::Simple(data) => data.as_ref(),
1371            KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
1372        }
1373    }
1374}
1375
1376#[derive(Clone)]
1377enum KnownPeerHandleState {
1378    Simple(ManuallyDrop<Arc<KnownPeerInner>>),
1379    WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
1380}
1381
1382impl Drop for KnownPeerHandleState {
1383    fn drop(&mut self) {
1384        let inner;
1385        let is_banned;
1386        match self {
1387            KnownPeerHandleState::Simple(data) => {
1388                // SAFETY: inner value is dropped only once
1389                inner = unsafe { ManuallyDrop::take(data) };
1390                is_banned = inner.is_banned();
1391            }
1392            KnownPeerHandleState::WithAffinity(data) => {
1393                // SAFETY: inner value is dropped only once
1394                match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1395                    Some(data) => {
1396                        inner = data.inner;
1397                        is_banned = !inner.decrease_affinity() || inner.is_banned();
1398                    }
1399                    None => return,
1400                }
1401            }
1402        };
1403
1404        if is_banned {
1405            // Don't remove banned peers from the known peers cache
1406            return;
1407        }
1408
1409        if let Some(inner) = Arc::into_inner(inner) {
1410            // If the last reference is dropped, remove the peer from the known peers cache
1411            if let Some(peers) = inner.weak_known_peers.upgrade() {
1412                peers.remove(&inner.peer_info.load().id);
1413                metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1414            }
1415        }
1416    }
1417}
1418
1419#[derive(Clone, PartialEq, Eq)]
1420#[repr(transparent)]
1421pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1422
1423impl WeakKnownPeerHandle {
1424    pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1425        Some(KnownPeerHandle(match &self.0 {
1426            WeakKnownPeerHandleState::Simple(weak) => {
1427                KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1428            }
1429            WeakKnownPeerHandleState::WithAffinity(weak) => {
1430                KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1431            }
1432        }))
1433    }
1434}
1435
1436#[derive(Clone)]
1437enum WeakKnownPeerHandleState {
1438    Simple(Weak<KnownPeerInner>),
1439    WithAffinity(Weak<KnownPeerHandleWithAffinity>),
1440}
1441
1442impl Eq for WeakKnownPeerHandleState {}
1443impl PartialEq for WeakKnownPeerHandleState {
1444    #[inline]
1445    fn eq(&self, other: &Self) -> bool {
1446        match (self, other) {
1447            (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1448            (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1449            _ => false,
1450        }
1451    }
1452}
1453
/// Wrapper whose existence represents one affinity increment on `inner`.
/// Clones of a handle share the same wrapper, so the increment is released
/// only when the last wrapper reference goes away.
struct KnownPeerHandleWithAffinity {
    inner: Arc<KnownPeerInner>,
}

/// State shared by all handles of a single known peer.
struct KnownPeerInner {
    // Latest accepted peer info; swapped atomically by `try_update_peer_info`.
    peer_info: ArcSwap<PeerInfo>,
    // Number of affinity increments currently held, or `AFFINITY_BANNED`.
    affinity: AtomicUsize,
    // Back-reference to the owning `KnownPeers` map, used when the last handle drops.
    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
}
1463
impl KnownPeerInner {
    /// Creates a new shared state, starting with one affinity increment if
    /// `with_affinity` is set.
    fn new(
        peer_info: Arc<PeerInfo>,
        with_affinity: bool,
        known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
    ) -> Arc<Self> {
        Arc::new(Self {
            peer_info: ArcSwap::from(peer_info),
            affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
            weak_known_peers: Arc::downgrade(known_peers),
        })
    }

    /// Returns `true` if the affinity counter holds the ban sentinel.
    fn is_banned(&self) -> bool {
        self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
    }

    /// Maps the counter to a [`PeerAffinity`]: `0` is `Allowed`, the ban
    /// sentinel is `Never`, and any other count is `High`.
    fn compute_affinity(&self) -> PeerAffinity {
        match self.affinity.load(Ordering::Acquire) {
            0 => PeerAffinity::Allowed,
            AFFINITY_BANNED => PeerAffinity::Never,
            _ => PeerAffinity::High,
        }
    }

    /// Atomically increments the affinity counter unless the peer is banned.
    ///
    /// Returns `false` (without changing anything) for banned peers.
    fn increase_affinity(&self) -> bool {
        let mut current = self.affinity.load(Ordering::Acquire);
        while current != AFFINITY_BANNED {
            // Incrementing from `AFFINITY_BANNED - 1` would turn the count into a ban.
            debug_assert_ne!(current, AFFINITY_BANNED - 1);
            match self.affinity.compare_exchange_weak(
                current,
                current + 1,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(affinity) => current = affinity,
            }
        }

        false
    }

    /// Atomically decrements the affinity counter unless the peer is banned.
    ///
    /// Returns `false` (without changing anything) for banned peers. Callers
    /// must only release increments they actually hold (debug-asserted).
    fn decrease_affinity(&self) -> bool {
        let mut current = self.affinity.load(Ordering::Acquire);
        while current != AFFINITY_BANNED {
            debug_assert_ne!(current, 0);
            match self.affinity.compare_exchange_weak(
                current,
                current - 1,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(affinity) => current = affinity,
            }
        }

        false
    }

    /// Replaces the stored info with `peer_info` unless the stored one is newer.
    ///
    /// With `with_affinity`, one affinity increment is taken. It is rolled back
    /// only on error. On `Ok(_)` it is kept, including `Ok(false)`, so callers
    /// that do not hand out a handle must release it themselves.
    ///
    /// Returns `Ok(true)` if the info was replaced or has the same `created_at`,
    /// and `Ok(false)` if the stored info is newer.
    ///
    /// # Errors
    /// [`PeerBannedError`] if the peer is (or becomes) banned.
    fn try_update_peer_info(
        &self,
        peer_info: &Arc<PeerInfo>,
        with_affinity: bool,
    ) -> Result<bool, PeerBannedError> {
        struct AffinityGuard<'a> {
            inner: &'a KnownPeerInner,
            // Set once this guard owns an affinity increment to roll back.
            decrease_on_drop: bool,
        }

        impl AffinityGuard<'_> {
            /// Takes the affinity increment on the first call (if requested);
            /// afterwards it only re-checks the ban. Returns `true` if banned.
            fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
                let with_affinity = with_affinity && !self.decrease_on_drop;
                let is_banned = if with_affinity {
                    !self.inner.increase_affinity()
                } else {
                    self.inner.is_banned()
                };

                if !is_banned && with_affinity {
                    self.decrease_on_drop = true;
                }

                is_banned
            }
        }

        impl Drop for AffinityGuard<'_> {
            fn drop(&mut self) {
                if self.decrease_on_drop {
                    self.inner.decrease_affinity();
                }
            }
        }

        // Create a guard to restore the peer affinity in case of an error
        let mut guard = AffinityGuard {
            inner: self,
            decrease_on_drop: false,
        };

        let mut cur = self.peer_info.load();
        let updated = loop {
            if guard.increase_affinity_or_check_ban(with_affinity) {
                // Do nothing for banned peers
                return Err(PeerBannedError);
            }

            match cur.created_at.cmp(&peer_info.created_at) {
                // Do nothing for the same creation time
                // TODO: is `created_at` equality enough?
                std::cmp::Ordering::Equal => break true,
                // Try to update peer info; retry with the observed value if another
                // thread swapped it concurrently.
                std::cmp::Ordering::Less => {
                    let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
                    if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
                        break true;
                    } else {
                        cur = prev;
                    }
                }
                // The provided info is older than the stored one: keep the stored
                // info and report it as outdated.
                std::cmp::Ordering::Greater => break false,
            }
        };

        // Success (including outdated info): keep the increment, if one was taken.
        guard.decrease_on_drop = false;
        Ok(updated)
    }
}
1595
/// Sentinel affinity value marking a banned peer. `increase_affinity`
/// debug-asserts that counting never reaches it.
const AFFINITY_BANNED: usize = usize::MAX;

/// Errors returned when inserting or updating known peer info.
#[derive(Debug, thiserror::Error)]
pub enum KnownPeersError {
    /// The peer is banned.
    #[error(transparent)]
    PeerBanned(#[from] PeerBannedError),
    /// The stored peer info is newer than the provided one.
    #[error("provided peer info is outdated")]
    OutdatedInfo,
}

/// The operation was rejected because the peer is banned.
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("peer is banned")]
pub struct PeerBannedError;
1609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    // The cache entry must outlive every handle and disappear with the last one;
    // afterwards the same info can be inserted again.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        // Re-inserting the same info yields a second handle to the same entry.
        let other_handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        assert_eq!(other_handle.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(other_handle.max_affinity(), PeerAffinity::Allowed);

        // Dropping a non-last handle keeps the entry.
        drop(other_handle);
        assert!(peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        // Dropping the last handle removes the entry.
        drop(handle);
        assert!(!peers.contains(&peer_info.id));
        assert!(!peers.is_banned(&peer_info.id));
        assert_eq!(peers.get(&peer_info.id), None);
        assert_eq!(peers.get_affinity(&peer_info.id), None);

        peers.insert(peer_info.clone(), false).unwrap();
    }

    // An affinity handle raises the shared affinity to `High` for all handles
    // and restores `Allowed` when it is dropped.
    #[test]
    fn with_affinity_after_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);

        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_with_affinity);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);
        assert_eq!(
            peers.get_affinity(&peer_info.id),
            Some(PeerAffinity::Allowed)
        );

        drop(handle_simple);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    // A simple handle added after an affinity handle does not lower the affinity.
    #[test]
    fn with_affinity_before_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);

        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_simple);
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High));

        drop(handle_with_affinity);
        assert!(!peers.contains(&peer_info.id));
        assert_eq!(peers.get_affinity(&peer_info.id), None);
    }

    // Banning through the registry is visible through an existing handle, and
    // the peer info stays readable while the handle is alive.
    #[test]
    fn ban_while_handle_exists() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        peers.ban(&peer_info.id);
        assert!(peers.contains(&peer_info.id));
        assert!(peers.is_banned(&peer_info.id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Never);
        assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone()));
        assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::Never));
    }
}