tycho_network/network/
connection_manager.rs

1use std::collections::{VecDeque, hash_map};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{DelayQueue, delay_queue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::ConnectionError;
15use crate::network::config::NetworkConfig;
16use crate::network::connection::Connection;
17use crate::network::endpoint::{Connecting, ConnectionInitError, Endpoint, Into0RttResult};
18use crate::network::request_handler::InboundRequestHandler;
19use crate::network::wire::{HandshakeError, handshake};
20use crate::types::{
21    Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
22    PeerInfo, Response, ServiceRequest,
23};
24
// Histograms

/// Recorded once per outbound dial attempt with the time the attempt took,
/// regardless of whether it succeeded.
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
/// Recorded once per inbound connection attempt with the time its setup took,
/// regardless of whether it succeeded.
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// Counters

/// Incremented when a new outbound connection is registered as an active peer.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
/// Incremented when a new inbound connection is registered as an active peer.
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
/// Incremented when an outbound connection attempt finishes with an error.
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
/// Incremented when an inbound connection attempt finishes with an error.
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// Gauges

/// Number of spawned connection handler tasks that have not finished yet.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
/// Number of spawned connection tasks (dials and inbound handshakes) that have not been joined yet.
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
/// Number of inbound connections still waiting for the remote identity to become available.
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
/// Number of dials started by the connectivity check that are still awaiting a result.
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

// NOTE(review): the two names below are not referenced in this part of the file;
// presumably they report the sizes of the active and known peer sets — confirm at the use site.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
43
44#[derive(Debug)]
45pub(crate) enum ConnectionManagerRequest {
46    Connect(Address, PeerId, CallbackTx),
47    Shutdown(oneshot::Sender<()>),
48}
49
50pub(crate) struct ConnectionManager {
51    config: Arc<NetworkConfig>,
52    endpoint: Arc<Endpoint>,
53
54    mailbox: mpsc::Receiver<ConnectionManagerRequest>,
55
56    pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
57    pending_partial_connections: JoinSet<Option<PartialConnection>>,
58    pending_connections: JoinSet<ConnectingOutput>,
59    connection_handlers: JoinSet<()>,
60    delayed_callbacks: DelayedCallbacksQueue,
61
62    pending_dials: FastHashMap<PeerId, CallbackRx>,
63    dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,
64
65    active_peers: ActivePeers,
66    known_peers: KnownPeers,
67
68    service: BoxCloneService<ServiceRequest, Response>,
69}
70
71type CallbackTx = oneshot::Sender<Result<Connection, ConnectionError>>;
72type CallbackRx = oneshot::Receiver<Result<Connection, ConnectionError>>;
73
74impl Drop for ConnectionManager {
75    fn drop(&mut self) {
76        tracing::trace!("dropping connection manager");
77        self.endpoint.close();
78    }
79}
80
81impl ConnectionManager {
82    pub fn new(
83        config: Arc<NetworkConfig>,
84        endpoint: Arc<Endpoint>,
85        active_peers: ActivePeers,
86        known_peers: KnownPeers,
87        service: BoxCloneService<ServiceRequest, Response>,
88    ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
89        let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
90        let connection_manager = Self {
91            config,
92            endpoint,
93            mailbox,
94            pending_connection_callbacks: Default::default(),
95            pending_partial_connections: Default::default(),
96            pending_connections: Default::default(),
97            connection_handlers: Default::default(),
98            delayed_callbacks: Default::default(),
99            pending_dials: Default::default(),
100            dial_backoff_states: Default::default(),
101            active_peers,
102            known_peers,
103            service,
104        };
105        (connection_manager, mailbox_tx)
106    }
107
108    pub async fn start(mut self) {
109        tracing::info!("connection manager started");
110
111        let jitter = Duration::from_millis(1000).mul_f64(rand::random());
112        let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);
113
114        let mut shutdown_notifier = None;
115
116        loop {
117            tokio::select! {
118                now = interval.tick() => {
119                    self.handle_connectivity_check(now.into_std());
120                }
121                maybe_request = self.mailbox.recv() => {
122                    let Some(request) = maybe_request else {
123                        break;
124                    };
125
126                    match request {
127                        ConnectionManagerRequest::Connect(address, peer_id, callback) => {
128                            self.handle_connect_request(address, &peer_id, callback);
129                        }
130                        ConnectionManagerRequest::Shutdown(oneshot) => {
131                            shutdown_notifier = Some(oneshot);
132                            break;
133                        }
134                    }
135                }
136                connecting = self.endpoint.accept() => {
137                    if let Some(connecting) = connecting {
138                        self.handle_incoming(connecting);
139                    }
140                }
141                Some(connecting_output) = self.pending_connections.join_next() => {
142                    metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
143                    match connecting_output {
144                        Ok(connecting) => self.handle_connecting_result(connecting),
145                        Err(e) => {
146                            if e.is_panic() {
147                                std::panic::resume_unwind(e.into_panic());
148                            }
149                        }
150                    }
151                }
152                Some(partial_connection) = self.pending_partial_connections.join_next() => {
153                    metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);
154
155                    match partial_connection {
156                        Ok(None) => {}
157                        Ok(Some(PartialConnection {
158                            connection,
159                            timeout_at,
160                        })) => self.handle_incoming_impl(connection, None, timeout_at),
161                        // NOTE: unwrap here is to propagate panic from the spawned future
162                        Err(e) => {
163                            if e.is_panic() {
164                                std::panic::resume_unwind(e.into_panic());
165                            }
166                        }
167                    }
168                }
169                Some(connection_handler_output) = self.connection_handlers.join_next() => {
170                    metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
171
172                    // NOTE: unwrap here is to propagate panic from the spawned future
173                    if let Err(e) = connection_handler_output {
174                        if e.is_panic() {
175                            std::panic::resume_unwind(e.into_panic());
176                        }
177                    }
178                }
179                Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
180                    self.delayed_callbacks.execute_expired(&peer_id);
181                }
182            }
183        }
184
185        self.shutdown().await;
186
187        if let Some(tx) = shutdown_notifier {
188            _ = tx.send(());
189        }
190
191        tracing::info!("connection manager stopped");
192    }
193
194    async fn shutdown(mut self) {
195        tracing::trace!("shutting down connection manager");
196
197        self.endpoint.close();
198
199        self.pending_partial_connections.shutdown().await;
200        metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);
201
202        self.pending_connections.shutdown().await;
203        metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);
204
205        while self.connection_handlers.join_next().await.is_some() {
206            metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
207        }
208        assert!(self.active_peers.is_empty());
209
210        self.endpoint
211            .wait_idle(self.config.shutdown_idle_timeout)
212            .await;
213    }
214
215    fn handle_connectivity_check(&mut self, now: Instant) {
216        use std::collections::hash_map::Entry;
217
218        self.pending_dials
219            .retain(|peer_id, oneshot| match oneshot.try_recv() {
220                Ok(Ok(returned_peer_id)) => {
221                    debug_assert_eq!(peer_id, returned_peer_id.peer_id());
222                    self.dial_backoff_states.remove(peer_id);
223                    false
224                }
225                Ok(Err(_)) => {
226                    match self.dial_backoff_states.entry(*peer_id) {
227                        Entry::Occupied(mut entry) => entry.get_mut().update(
228                            now,
229                            self.config.connection_backoff,
230                            self.config.max_connection_backoff,
231                        ),
232                        Entry::Vacant(entry) => {
233                            entry.insert(DialBackoffState::new(
234                                now,
235                                self.config.connection_backoff,
236                                self.config.max_connection_backoff,
237                            ));
238                        }
239                    }
240                    false
241                }
242                Err(oneshot::error::TryRecvError::Closed) => {
243                    panic!("BUG: connection manager never finished dialing a peer");
244                }
245                Err(oneshot::error::TryRecvError::Empty) => true,
246            });
247
248        let outstanding_connections_limit = self
249            .config
250            .max_concurrent_outstanding_connections
251            .saturating_sub(self.pending_connections.len());
252
253        let outstanding_connections = self
254            .known_peers
255            .0
256            .iter()
257            .filter_map(|item| {
258                let value = match item.value() {
259                    KnownPeerState::Stored(item) => item.upgrade()?,
260                    KnownPeerState::Banned => return None,
261                };
262                let peer_info = value.peer_info.load();
263                let affinity = value.compute_affinity();
264
265                (affinity == PeerAffinity::High
266                    && peer_info.id != self.endpoint.peer_id()
267                    && !self.active_peers.contains(&peer_info.id)
268                    && !self.pending_dials.contains_key(&peer_info.id)
269                    && self
270                        .dial_backoff_states
271                        .get(&peer_info.id)
272                        .is_none_or(|state| now > state.next_attempt_at))
273                .then(|| arc_swap::Guard::into_inner(peer_info))
274            })
275            .take(outstanding_connections_limit)
276            .collect::<Vec<_>>();
277
278        for peer_info in outstanding_connections {
279            // TODO: handle multiple addresses
280            let address = peer_info
281                .iter_addresses()
282                .next()
283                .cloned()
284                .expect("address list must have at least one item");
285
286            let (tx, rx) = oneshot::channel();
287            self.dial_peer(address, &peer_info.id, tx);
288            self.pending_dials.insert(peer_info.id, rx);
289        }
290
291        metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
292    }
293
294    fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
295        self.dial_peer(address, peer_id, callback);
296    }
297
298    fn handle_incoming(&mut self, connecting: Connecting) {
299        let remote_addr = connecting.remote_address();
300        tracing::trace!(
301            local_id = %self.endpoint.peer_id(),
302            %remote_addr,
303            "received an incoming connection",
304        );
305
306        // Split incoming connection into 0.5-RTT and 1-RTT parts.
307        match connecting.into_0rtt() {
308            Into0RttResult::Established(connection, accepted) => {
309                let timeout_at = Instant::now() + self.config.connect_timeout;
310                self.handle_incoming_impl(connection, Some(accepted), timeout_at);
311            }
312            Into0RttResult::WithoutIdentity(partial_connection) => {
313                tracing::trace!("connection identity is not available yet");
314
315                let timeout_at = Instant::now() + self.config.connect_timeout;
316                self.pending_partial_connections.spawn(async move {
317                    match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
318                        Ok(Ok(connection)) => Some(PartialConnection {
319                            connection,
320                            timeout_at,
321                        }),
322                        Ok(Err(e)) => {
323                            tracing::trace!(
324                                %remote_addr,
325                                "failed to establish an incoming connection: {e}",
326                            );
327                            None
328                        }
329                        Err(_) => {
330                            tracing::trace!(
331                                %remote_addr,
332                                "incoming connection timed out",
333                            );
334                            None
335                        }
336                    }
337                });
338                metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
339            }
340            Into0RttResult::InvalidCertificate => {
341                tracing::trace!(%remote_addr, "invalid incoming connection");
342            }
343            Into0RttResult::Unavailable(_) => unreachable!(
344                "BUG: For incoming connections, a 0.5-RTT connection must \
345                always be successfully constructed."
346            ),
347        };
348    }
349
350    fn handle_incoming_impl(
351        &mut self,
352        connection: Connection,
353        accepted: Option<quinn::ZeroRttAccepted>,
354        timeout_at: Instant,
355    ) {
356        async fn handle_incoming_task(
357            seqno: u32,
358            connection: ConnectionClosedOnDrop,
359            accepted: Option<quinn::ZeroRttAccepted>,
360            timeout_at: Instant,
361        ) -> ConnectingOutput {
362            let target_peer_id = *connection.peer_id();
363            let target_address = connection.remote_address().into();
364            let fut = async {
365                if let Some(accepted) = accepted {
366                    // NOTE: `bool` output of this future is meaningless for servers.
367                    accepted.await;
368                }
369                handshake(&connection).await
370            };
371
372            let started_at = Instant::now();
373
374            let connecting_result = match tokio::time::timeout_at(timeout_at.into(), fut).await {
375                Ok(Ok(())) => Ok(connection.disarm()),
376                Ok(Err(e)) => Err(FullConnectionError::HandshakeFailed(e)),
377                Err(_) => Err(FullConnectionError::Timeout),
378            };
379
380            metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());
381
382            ConnectingOutput {
383                seqno,
384                drop_result: true,
385                connecting_result: ManuallyDrop::new(connecting_result),
386                target_address,
387                target_peer_id,
388                origin: Direction::Inbound,
389            }
390        }
391
392        let remote_addr = connection.remote_address();
393
394        // Check if the peer is allowed before doing anything else.
395        match self.known_peers.get_affinity(connection.peer_id()) {
396            Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
397            Some(PeerAffinity::Never) => {
398                // TODO: Lower log level to trace/debug?
399                tracing::warn!(
400                    %remote_addr,
401                    peer_id = %connection.peer_id(),
402                    "rejecting connection due to PeerAffinity::Never",
403                );
404                connection.close();
405                return;
406            }
407            _ => {
408                if matches!(
409                    self.config.max_concurrent_connections,
410                    Some(limit) if self.active_peers.len() >= limit
411                ) {
412                    // TODO: Lower log level to trace/debug?
413                    tracing::warn!(
414                        %remote_addr,
415                        peer_id = %connection.peer_id(),
416                        "rejecting connection due too many concurrent connections",
417                    );
418                    connection.close();
419                    return;
420                }
421            }
422        }
423
424        let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
425            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
426                last_seqno: 0,
427                origin: Direction::Inbound,
428                callbacks: Default::default(),
429                abort_handle: None,
430            })),
431            hash_map::Entry::Occupied(entry) => {
432                let entry = entry.into_mut();
433
434                // Check if the incoming connection is a simultaneous dial.
435                if simultaneous_dial_tie_breaking(
436                    self.endpoint.peer_id(),
437                    connection.peer_id(),
438                    entry.origin,
439                    Direction::Inbound,
440                ) {
441                    // New connection wins the tie, abort the old one and spawn a new task.
442                    tracing::debug!(
443                        %remote_addr,
444                        peer_id = %connection.peer_id(),
445                        "cancelling old connection to mitigate simultaneous dial",
446                    );
447
448                    entry.origin = Direction::Inbound;
449                    entry.last_seqno += 1;
450                    if let Some(handle) = entry.abort_handle.take() {
451                        handle.abort();
452                    }
453                    Some(entry)
454                } else {
455                    // Old connection wins the tie, gracefully close the new one.
456                    tracing::debug!(
457                        %remote_addr,
458                        peer_id = %connection.peer_id(),
459                        "cancelling new connection to mitigate simultaneous dial",
460                    );
461
462                    connection.close();
463                    None
464                }
465            }
466        };
467
468        if let Some(entry) = entry {
469            entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
470                entry.last_seqno,
471                ConnectionClosedOnDrop::new(connection),
472                accepted,
473                timeout_at,
474            )));
475            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
476        }
477    }
478
479    fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
480        // Check seqno first to drop outdated results.
481        {
482            let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
483                tracing::trace!("connection task reordering detected");
484                return;
485            };
486
487            if entry.last_seqno != res.seqno {
488                tracing::debug!(
489                    local_id = %self.endpoint.peer_id(),
490                    peer_id = %res.target_peer_id,
491                    remote_addr = %res.target_address,
492                    "connection result is outdated"
493                );
494                return;
495            }
496        }
497
498        let callbacks = self
499            .pending_connection_callbacks
500            .remove(&res.target_address)
501            .expect("Connection tasks must be tracked")
502            .callbacks;
503
504        res.drop_result = false;
505        // SAFETY: `drop_result` is set to `false`.
506        match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
507            Ok(connection) => {
508                tracing::debug!(
509                    local_id = %self.endpoint.peer_id(),
510                    peer_id = %connection.peer_id(),
511                    remote_addr = %res.target_address,
512                    "new connection",
513                );
514
515                let connection = self.add_peer(connection);
516                self.delayed_callbacks.execute_resolved(&connection);
517
518                for callback in callbacks {
519                    _ = callback.send(Ok(connection.clone()));
520                }
521            }
522            Err(e) => {
523                tracing::debug!(
524                    local_id = %self.endpoint.peer_id(),
525                    peer_id = %res.target_peer_id,
526                    remote_addr = %res.target_address,
527                    "connection failed: {e:?}"
528                );
529
530                metrics::counter!(match res.origin {
531                    Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
532                    Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
533                })
534                .increment(1);
535
536                let brief_error = e.as_brief();
537
538                // Delay sending the error to callbacks as the target peer might be
539                // in the process of connecting to us.
540                if e.was_closed() && !callbacks.is_empty() {
541                    self.delayed_callbacks.push(
542                        &res.target_peer_id,
543                        brief_error,
544                        callbacks,
545                        &self.config.connection_error_delay,
546                    );
547                    return;
548                }
549
550                for callback in callbacks {
551                    _ = callback.send(Err(brief_error));
552                }
553            }
554        }
555    }
556
557    fn add_peer(&mut self, connection: Connection) -> Connection {
558        match self.active_peers.add(self.endpoint.peer_id(), connection) {
559            AddedPeer::New(connection) => {
560                let origin = connection.origin();
561
562                let handler = InboundRequestHandler::new(
563                    self.config.clone(),
564                    connection.clone(),
565                    self.service.clone(),
566                    self.active_peers.clone(),
567                );
568
569                metrics::counter!(match origin {
570                    Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
571                    Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
572                })
573                .increment(1);
574
575                metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
576                self.connection_handlers.spawn(handler.start());
577
578                connection
579            }
580            AddedPeer::Existing(connection) => connection,
581        }
582    }
583
584    #[tracing::instrument(
585        level = "trace",
586        skip_all,
587        fields(
588            local_id = %self.endpoint.peer_id(),
589            peer_id = %peer_id,
590            remote_addr = %address,
591        ),
592    )]
593    fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
594        async fn dial_peer_task(
595            seqno: u32,
596            endpoint: Arc<Endpoint>,
597            address: Address,
598            peer_id: PeerId,
599            config: Arc<NetworkConfig>,
600        ) -> ConnectingOutput {
601            let fut = async {
602                let address = address
603                    .resolve()
604                    .await
605                    .map_err(FullConnectionError::InvalidAddress)?;
606
607                let connecting = endpoint
608                    .connect_with_expected_id(&address, &peer_id)
609                    .map_err(|e| FullConnectionError::InvalidAddress(std::io::Error::other(e)))?;
610
611                let connection = ConnectionClosedOnDrop::new(connecting.await?);
612                match handshake(&connection).await {
613                    Ok(()) => Ok(connection),
614                    Err(e) => Err(FullConnectionError::HandshakeFailed(e)),
615                }
616            };
617
618            let started_at = Instant::now();
619
620            let connecting_result = match tokio::time::timeout(config.connect_timeout, fut).await {
621                Ok(res) => res.map(ConnectionClosedOnDrop::disarm),
622                Err(_) => Err(FullConnectionError::Timeout),
623            };
624
625            metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());
626
627            ConnectingOutput {
628                seqno,
629                drop_result: true,
630                connecting_result: ManuallyDrop::new(connecting_result),
631                target_address: address,
632                target_peer_id: peer_id,
633                origin: Direction::Outbound,
634            }
635        }
636
637        if let Some(connection) = self.active_peers.get(peer_id) {
638            tracing::debug!("peer is already connected");
639            _ = callback.send(Ok(connection));
640            return;
641        }
642
643        tracing::trace!("connecting to peer");
644
645        let entry = match self.pending_connection_callbacks.entry(address.clone()) {
646            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
647                last_seqno: 0,
648                origin: Direction::Outbound,
649                callbacks: vec![callback],
650                abort_handle: None,
651            })),
652            hash_map::Entry::Occupied(entry) => {
653                let entry = entry.into_mut();
654
655                // Add the callback to the existing entry.
656                entry.callbacks.push(callback);
657
658                // Check if the outgoing connection is a simultaneous dial.
659                let break_tie = simultaneous_dial_tie_breaking(
660                    self.endpoint.peer_id(),
661                    peer_id,
662                    entry.origin,
663                    Direction::Outbound,
664                );
665
666                if break_tie && entry.origin != Direction::Outbound {
667                    // New connection wins the tie, abort the old one and spawn a new task.
668                    tracing::debug!("cancelling old connection to mitigate simultaneous dial");
669
670                    entry.origin = Direction::Outbound;
671                    entry.last_seqno += 1;
672                    if let Some(handle) = entry.abort_handle.take() {
673                        handle.abort();
674                    }
675                    Some(entry)
676                } else {
677                    // Old connection wins the tie, don't create a new one
678                    tracing::trace!("reusing old connection to mitigate simultaneous dial");
679                    None
680                }
681            }
682        };
683
684        if let Some(entry) = entry {
685            entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
686                entry.last_seqno,
687                self.endpoint.clone(),
688                address.clone(),
689                *peer_id,
690                self.config.clone(),
691            )));
692            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
693        }
694    }
695}
696
697struct PendingConnectionCallbacks {
698    last_seqno: u32,
699    origin: Direction,
700    callbacks: Vec<CallbackTx>,
701    abort_handle: Option<AbortHandle>,
702}
703
704struct PartialConnection {
705    connection: Connection,
706    timeout_at: Instant,
707}
708
709struct ConnectingOutput {
710    seqno: u32,
711    drop_result: bool,
712    connecting_result: ManuallyDrop<Result<Connection, FullConnectionError>>,
713    target_address: Address,
714    target_peer_id: PeerId,
715    origin: Direction,
716}
717
718impl Drop for ConnectingOutput {
719    fn drop(&mut self) {
720        if self.drop_result {
721            // SAFETY: `drop_result` is set to `true` only when the result is not used.
722            unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
723        }
724    }
725}
726
727struct ConnectionClosedOnDrop {
728    connection: ManuallyDrop<Connection>,
729    close_on_drop: bool,
730}
731
732impl ConnectionClosedOnDrop {
733    fn new(connection: Connection) -> Self {
734        Self {
735            connection: ManuallyDrop::new(connection),
736            close_on_drop: true,
737        }
738    }
739
740    fn disarm(mut self) -> Connection {
741        self.close_on_drop = false;
742        // SAFETY: `drop` will not be called.
743        unsafe { ManuallyDrop::take(&mut self.connection) }
744    }
745}
746
747impl std::ops::Deref for ConnectionClosedOnDrop {
748    type Target = Connection;
749
750    #[inline]
751    fn deref(&self) -> &Self::Target {
752        &self.connection
753    }
754}
755
756impl Drop for ConnectionClosedOnDrop {
757    fn drop(&mut self) {
758        if self.close_on_drop {
759            // SAFETY: `disarm` was not called.
760            let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
761            connection.close();
762        }
763    }
764}
765
766#[derive(Default)]
767struct DelayedCallbacksQueue {
768    callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
769    expirations: DelayQueue<PeerId>,
770}
771
772impl DelayedCallbacksQueue {
773    fn push(
774        &mut self,
775        peer_id: &PeerId,
776        error: ConnectionError,
777        callbacks: Vec<CallbackTx>,
778        delay: &Duration,
779    ) {
780        tracing::debug!(%peer_id, %error, "delayed connection error");
781
782        let expires_at = Instant::now() + *delay;
783        let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
784
785        let items = self.callbacks.entry(*peer_id).or_default();
786        items.push_back(DelayedCallbacks {
787            delay_key,
788            error,
789            callbacks,
790            expires_at,
791        });
792    }
793
794    async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
795        let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
796        Some(res.into_inner())
797    }
798
799    fn execute_resolved(&mut self, connection: &Connection) {
800        let Some(items) = self.callbacks.remove(connection.peer_id()) else {
801            return;
802        };
803
804        let mut batches_executed = 0;
805        let mut callbacks_executed = 0;
806
807        for delayed in items {
808            batches_executed += 1;
809            callbacks_executed += delayed.callbacks.len();
810
811            let key = delayed.execute_with_ok(connection);
812
813            // NOTE: Delay key must exist in the queue.
814            self.expirations.remove(&key);
815        }
816
817        tracing::debug!(
818            peer_id = %connection.peer_id(),
819            batches_executed,
820            callbacks_executed,
821            "executed all delayed callbacks",
822        );
823    }
824
825    fn execute_expired(&mut self, peer_id: &PeerId) {
826        let now = Instant::now();
827
828        let mut batches_executed = 0;
829        let mut callbacks_executed = 0;
830
831        'outer: {
832            if let Some(items) = self.callbacks.get_mut(peer_id) {
833                while let Some(front) = items.front() {
834                    if !front.is_expired(&now) {
835                        break 'outer;
836                    }
837
838                    if let Some(delayed) = items.pop_front() {
839                        batches_executed += 1;
840                        callbacks_executed += delayed.callbacks.len();
841
842                        let key = delayed.execute_with_error();
843
844                        // NOTE: Might not be necessary since items are stored in order,
845                        // but it's better to be safe.
846                        self.expirations.try_remove(&key);
847                    }
848                }
849            }
850
851            // There is no need to hold an empty queue for this peer
852            self.callbacks.remove(peer_id);
853        }
854
855        tracing::debug!(
856            %peer_id,
857            batches_executed,
858            callbacks_executed,
859            "executed expired delayed callbacks"
860        );
861    }
862}
863
/// One batch of callbacks parked in [`DelayedCallbacksQueue`].
struct DelayedCallbacks {
    /// Key of the timer entry created for this batch.
    delay_key: delay_queue::Key,
    /// Error sent to every callback by `execute_with_error`.
    error: ConnectionError,
    /// Callback senders waiting for the outcome.
    callbacks: Vec<CallbackTx>,
    /// Instant from which `is_expired` reports the batch as expired.
    expires_at: Instant,
}
870
871impl DelayedCallbacks {
872    fn execute_with_ok(self, connection: &Connection) -> delay_queue::Key {
873        for callback in self.callbacks {
874            _ = callback.send(Ok(connection.clone()));
875        }
876        self.delay_key
877    }
878
879    fn execute_with_error(self) -> delay_queue::Key {
880        for callback in self.callbacks {
881            _ = callback.send(Err(self.error));
882        }
883        self.delay_key
884    }
885
886    fn is_expired(&self, now: &Instant) -> bool {
887        *now >= self.expires_at
888    }
889}
890
/// Linear dial backoff: the delay grows by one `step` per attempt, capped at `max`.
#[derive(Debug)]
struct DialBackoffState {
    /// Earliest instant at which the next attempt is allowed.
    next_attempt_at: Instant,
    /// Number of attempts recorded so far.
    attempts: usize,
}

impl DialBackoffState {
    /// Creates the state with the first attempt already recorded at `now`.
    fn new(now: Instant, step: Duration, max: Duration) -> Self {
        let mut fresh = Self {
            next_attempt_at: now,
            attempts: 0,
        };
        fresh.update(now, step, max);
        fresh
    }

    /// Records one more attempt and schedules the next one at
    /// `now + min(max, step * attempts)`.
    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
        self.attempts += 1;

        // Attempt counts that do not fit into `u32` are clamped; the
        // multiplication itself saturates as well.
        let factor = u32::try_from(self.attempts).unwrap_or(u32::MAX);
        let delay = step.saturating_mul(factor).min(max);

        self.next_attempt_at = now + delay;
    }
}
916
/// Detailed connection failure; `as_brief` condenses it into a [`ConnectionError`].
#[derive(Debug, thiserror::Error)]
enum FullConnectionError {
    /// The address was rejected; carries the underlying I/O error as the source.
    #[error("invalid address")]
    InvalidAddress(#[source] std::io::Error),
    /// The QUIC connection failed; displayed as the inner error.
    #[error(transparent)]
    ConnectionFailed(quinn::ConnectionError),
    /// The certificate was rejected.
    #[error("invalid certificate")]
    InvalidCertificate,
    /// The handshake failed; carries the handshake error as the source.
    #[error("handshake failed")]
    HandshakeFailed(#[source] HandshakeError),
    /// The attempt timed out.
    #[error("connection timeout")]
    Timeout,
}
930
931impl FullConnectionError {
932    fn as_brief(&self) -> ConnectionError {
933        fn is_crypto_error(error: &quinn::ConnectionError) -> bool {
934            const QUIC_CRYPTO_FLAG: u64 = 0x100;
935
936            matches!(
937                error,
938                quinn::ConnectionError::TransportError(e)
939                if u64::from(e.code) & QUIC_CRYPTO_FLAG != 0
940            )
941        }
942
943        match self {
944            Self::InvalidAddress(_) => ConnectionError::InvalidAddress,
945            Self::ConnectionFailed(e) if is_crypto_error(e) => ConnectionError::InvalidCertificate,
946            Self::ConnectionFailed(_) => ConnectionError::ConnectionInitFailed,
947            Self::InvalidCertificate => ConnectionError::InvalidCertificate,
948            Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) if is_crypto_error(e) => {
949                ConnectionError::InvalidCertificate
950            }
951            Self::HandshakeFailed(_) => ConnectionError::HandshakeFailed,
952            Self::Timeout => ConnectionError::Timeout,
953        }
954    }
955
956    fn was_closed(&self) -> bool {
957        let connection_error = match self {
958            Self::ConnectionFailed(e)
959            | Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) => e,
960            _ => return false,
961        };
962
963        matches!(
964            connection_error,
965            quinn::ConnectionError::ApplicationClosed(closed)
966            if closed.error_code.into_inner() == 0
967        )
968    }
969}
970
971impl From<ConnectionInitError> for FullConnectionError {
972    fn from(value: ConnectionInitError) -> Self {
973        match value {
974            ConnectionInitError::ConnectionFailed(e) => Self::ConnectionFailed(e),
975            ConnectionInitError::InvalidCertificate => Self::InvalidCertificate,
976        }
977    }
978}
979
/// Shared handle to the set of active peer connections.
///
/// Cloning only clones the inner `Arc`, so all clones share one state.
#[derive(Clone)]
pub(crate) struct ActivePeers(Arc<ActivePeersInner>);
982
983impl ActivePeers {
984    pub fn new(channel_size: usize) -> Self {
985        Self(Arc::new(ActivePeersInner::new(channel_size)))
986    }
987
988    pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
989        self.0.get(peer_id)
990    }
991
992    pub fn contains(&self, peer_id: &PeerId) -> bool {
993        self.0.contains(peer_id)
994    }
995
996    pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
997        self.0.add(local_id, new_connection)
998    }
999
1000    pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1001        self.0.remove(peer_id, reason);
1002    }
1003
1004    pub fn remove_with_stable_id(
1005        &self,
1006        peer_id: &PeerId,
1007        stable_id: usize,
1008        reason: DisconnectReason,
1009    ) {
1010        self.0.remove_with_stable_id(peer_id, stable_id, reason);
1011    }
1012
1013    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1014        self.0.subscribe()
1015    }
1016
1017    pub fn is_empty(&self) -> bool {
1018        self.0.is_empty()
1019    }
1020
1021    pub fn len(&self) -> usize {
1022        self.0.len()
1023    }
1024}
1025
/// Shared state behind [`ActivePeers`].
struct ActivePeersInner {
    /// Active connection per peer.
    connections: FastDashMap<PeerId, Connection>,
    /// Connection counter: incremented when a connection is inserted for a
    /// new peer and decremented on removal; read by `len`.
    connections_len: AtomicUsize,
    /// Sender side of the peer event broadcast channel.
    events_tx: broadcast::Sender<PeerEvent>,
}
1031
1032impl ActivePeersInner {
1033    fn new(channel_size: usize) -> Self {
1034        let (events_tx, _) = broadcast::channel(channel_size);
1035        Self {
1036            connections: Default::default(),
1037            connections_len: Default::default(),
1038            events_tx,
1039        }
1040    }
1041
1042    fn get(&self, peer_id: &PeerId) -> Option<Connection> {
1043        self.connections
1044            .get(peer_id)
1045            .map(|item| item.value().clone())
1046    }
1047
1048    fn contains(&self, peer_id: &PeerId) -> bool {
1049        self.connections.contains_key(peer_id)
1050    }
1051
1052    #[must_use]
1053    fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
1054        use dashmap::mapref::entry::Entry;
1055
1056        let mut added = false;
1057
1058        let peer_id = new_connection.peer_id();
1059        match self.connections.entry(*peer_id) {
1060            Entry::Occupied(mut entry) => {
1061                if simultaneous_dial_tie_breaking(
1062                    local_id,
1063                    peer_id,
1064                    entry.get().origin(),
1065                    new_connection.origin(),
1066                ) {
1067                    tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
1068                    let old_connection = entry.insert(new_connection.clone());
1069                    old_connection.close();
1070                    self.send_event(PeerEvent::lost_peer(*peer_id, DisconnectReason::Requested));
1071                } else {
1072                    tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
1073                    new_connection.close();
1074                    return AddedPeer::Existing(entry.get().clone());
1075                }
1076            }
1077            Entry::Vacant(entry) => {
1078                self.connections_len.fetch_add(1, Ordering::Release);
1079                entry.insert(new_connection.clone());
1080                added = true;
1081            }
1082        }
1083
1084        self.send_event(PeerEvent::new_peer(*peer_id));
1085
1086        if added {
1087            metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
1088        }
1089        AddedPeer::New(new_connection)
1090    }
1091
1092    fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1093        if let Some((_, connection)) = self.connections.remove(peer_id) {
1094            connection.close();
1095            self.connections_len.fetch_sub(1, Ordering::Release);
1096            self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1097
1098            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1099        }
1100    }
1101
1102    fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
1103        if let Some((_, connection)) = self
1104            .connections
1105            .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
1106        {
1107            connection.close();
1108            self.connections_len.fetch_sub(1, Ordering::Release);
1109            self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1110
1111            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1112        }
1113    }
1114
1115    fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1116        self.events_tx.subscribe()
1117    }
1118
1119    fn send_event(&self, event: PeerEvent) {
1120        _ = self.events_tx.send(event);
1121    }
1122
1123    fn is_empty(&self) -> bool {
1124        self.connections.is_empty()
1125    }
1126
1127    fn len(&self) -> usize {
1128        self.connections_len.load(Ordering::Acquire)
1129    }
1130}
1131
/// Outcome of registering a connection in [`ActivePeers`].
pub(crate) enum AddedPeer {
    /// The provided connection was stored and is now the active one.
    New(Connection),
    /// The provided connection was closed; carries the connection that was
    /// already active for the peer.
    Existing(Connection),
}
1136
1137fn simultaneous_dial_tie_breaking(
1138    local_id: &PeerId,
1139    peer_id: &PeerId,
1140    old_origin: Direction,
1141    new_origin: Direction,
1142) -> bool {
1143    match (old_origin, new_origin) {
1144        (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1145            true
1146        }
1147        (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1148        (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1149    }
1150}
1151
/// Shared registry of known peers.
///
/// Entries hold only weak references to the peer state, or a ban marker.
/// Cloning clones the inner `Arc`, so all clones share one map.
#[derive(Default, Clone)]
#[repr(transparent)]
pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1155
1156impl KnownPeers {
1157    pub fn new() -> Self {
1158        Self::default()
1159    }
1160
1161    pub fn contains(&self, peer_id: &PeerId) -> bool {
1162        self.0.contains_key(peer_id)
1163    }
1164
1165    pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1166        self.0
1167            .get(peer_id)
1168            .and_then(|item| {
1169                Some(match item.value() {
1170                    KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1171                    KnownPeerState::Banned => true,
1172                })
1173            })
1174            .unwrap_or_default()
1175    }
1176
1177    pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1178        self.0.get(peer_id).and_then(|item| match item.value() {
1179            KnownPeerState::Stored(item) => {
1180                let inner = item.upgrade()?;
1181                Some(inner.peer_info.load_full())
1182            }
1183            KnownPeerState::Banned => None,
1184        })
1185    }
1186
1187    pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1188        self.0
1189            .get(peer_id)
1190            .and_then(|item| item.value().compute_affinity())
1191    }
1192
1193    pub fn remove(&self, peer_id: &PeerId) {
1194        self.0.remove(peer_id);
1195        metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1196    }
1197
1198    pub fn ban(&self, peer_id: &PeerId) {
1199        let mut added = false;
1200        match self.0.entry(*peer_id) {
1201            dashmap::mapref::entry::Entry::Vacant(entry) => {
1202                entry.insert(KnownPeerState::Banned);
1203                added = true;
1204            }
1205            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1206                KnownPeerState::Banned => {}
1207                KnownPeerState::Stored(item) => match item.upgrade() {
1208                    Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1209                    None => *entry.get_mut() = KnownPeerState::Banned,
1210                },
1211            },
1212        }
1213
1214        if added {
1215            // NOTE: "New" banned peer is a "new" known peer.
1216            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1217        }
1218    }
1219
1220    pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1221        let inner = match self.0.get(peer_id)?.value() {
1222            KnownPeerState::Stored(item) => {
1223                let inner = item.upgrade()?;
1224                if with_affinity && !inner.increase_affinity() {
1225                    return None;
1226                }
1227                inner
1228            }
1229            KnownPeerState::Banned => return None,
1230        };
1231
1232        Some(KnownPeerHandle::from_inner(inner, with_affinity))
1233    }
1234
1235    /// Inserts a new handle only if the provided info is not outdated
1236    /// and the peer is not banned.
1237    pub fn insert(
1238        &self,
1239        peer_info: Arc<PeerInfo>,
1240        with_affinity: bool,
1241    ) -> Result<KnownPeerHandle, KnownPeersError> {
1242        // TODO: add capacity limit for entries without affinity
1243        let mut added = false;
1244        let inner = match self.0.entry(peer_info.id) {
1245            dashmap::mapref::entry::Entry::Vacant(entry) => {
1246                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1247                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1248                added = true;
1249                inner
1250            }
1251            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1252                KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1253                KnownPeerState::Stored(item) => match item.upgrade() {
1254                    Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1255                        true => inner,
1256                        false => return Err(KnownPeersError::OutdatedInfo),
1257                    },
1258                    None => {
1259                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1260                        *item = Arc::downgrade(&inner);
1261                        inner
1262                    }
1263                },
1264            },
1265        };
1266
1267        if added {
1268            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1269        }
1270
1271        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1272    }
1273
1274    /// Same as [`KnownPeers::insert`], but ignores outdated info.
1275    pub fn insert_allow_outdated(
1276        &self,
1277        peer_info: Arc<PeerInfo>,
1278        with_affinity: bool,
1279    ) -> Result<KnownPeerHandle, PeerBannedError> {
1280        // TODO: add capacity limit for entries without affinity
1281        let mut added = false;
1282        let inner = match self.0.entry(peer_info.id) {
1283            dashmap::mapref::entry::Entry::Vacant(entry) => {
1284                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1285                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1286                added = true;
1287                inner
1288            }
1289            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1290                KnownPeerState::Banned => return Err(PeerBannedError),
1291                KnownPeerState::Stored(item) => match item.upgrade() {
1292                    Some(inner) => {
1293                        // NOTE: Outdated info is ignored here.
1294                        inner.try_update_peer_info(&peer_info, with_affinity)?;
1295                        inner
1296                    }
1297                    None => {
1298                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1299                        *item = Arc::downgrade(&inner);
1300                        inner
1301                    }
1302                },
1303            },
1304        };
1305
1306        if added {
1307            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1308        }
1309
1310        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1311    }
1312}
1313
/// Value stored in the [`KnownPeers`] map for a peer.
enum KnownPeerState {
    /// Weak reference to the peer state; the strong references live in handles.
    Stored(Weak<KnownPeerInner>),
    /// Ban marker that carries no peer state.
    Banned,
}
1318
1319impl KnownPeerState {
1320    fn compute_affinity(&self) -> Option<PeerAffinity> {
1321        Some(match self {
1322            Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1323            Self::Banned => PeerAffinity::Never,
1324        })
1325    }
1326}
1327
/// Strong handle to a known peer.
///
/// The handle is either "simple" or "with affinity"; see
/// `increase_affinity` / `decrease_affinity` for switching between the two.
#[derive(Clone)]
#[repr(transparent)]
pub struct KnownPeerHandle(KnownPeerHandleState);
1331
impl KnownPeerHandle {
    /// Wraps `inner` into a handle of the requested kind.
    ///
    /// This does not touch the affinity counter itself; the caller is
    /// expected to have accounted for it already.
    fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
        KnownPeerHandle(if with_affinity {
            KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
                KnownPeerHandleWithAffinity { inner },
            )))
        } else {
            KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
        })
    }

    /// Returns a guard for the current peer info.
    pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
        self.inner().peer_info.load()
    }

    /// Returns the current peer info as an owned `Arc`.
    pub fn load_peer_info(&self) -> Arc<PeerInfo> {
        arc_swap::Guard::into_inner(self.peer_info())
    }

    /// Returns `true` if the peer is banned.
    pub fn is_banned(&self) -> bool {
        self.inner().is_banned()
    }

    /// Returns the affinity computed from the shared affinity counter.
    pub fn max_affinity(&self) -> PeerAffinity {
        self.inner().compute_affinity()
    }

    /// Replaces the stored peer info with `peer_info` if it is not outdated.
    ///
    /// # Errors
    ///
    /// Returns [`KnownPeersError::OutdatedInfo`] when a newer info is already
    /// stored and [`KnownPeersError::PeerBanned`] when the peer is banned.
    pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
        match self.inner().try_update_peer_info(peer_info, false) {
            Ok(true) => Ok(()),
            Ok(false) => Err(KnownPeersError::OutdatedInfo),
            Err(e) => Err(KnownPeersError::PeerBanned(e)),
        }
    }

    /// Bans the peer. Returns `true` if it was not banned before this call.
    pub fn ban(&self) -> bool {
        let inner = self.inner();
        inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
    }

    /// Turns a simple handle into a handle with affinity.
    ///
    /// Returns `false` (and does nothing) if the handle already has affinity.
    pub fn increase_affinity(&mut self) -> bool {
        match &mut self.0 {
            KnownPeerHandleState::Simple(inner) => {
                // NOTE: Handle will be updated even if the peer is banned.
                inner.increase_affinity();

                // SAFETY: Inner value was not dropped.
                let inner = unsafe { ManuallyDrop::take(inner) };

                // Replace the old state with the new one, ensuring that the old state
                // is not dropped (because we took the value out of it).
                let prev_state = std::mem::replace(
                    &mut self.0,
                    KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
                        KnownPeerHandleWithAffinity { inner },
                    ))),
                );

                // Forget the old state to avoid dropping it.
                #[allow(clippy::mem_forget)]
                std::mem::forget(prev_state);

                true
            }
            KnownPeerHandleState::WithAffinity(_) => false,
        }
    }

    /// Turns a handle with affinity back into a simple handle.
    ///
    /// Returns `false` (and does nothing) if the handle is already simple.
    pub fn decrease_affinity(&mut self) -> bool {
        match &mut self.0 {
            KnownPeerHandleState::Simple(_) => false,
            KnownPeerHandleState::WithAffinity(inner) => {
                // NOTE: Handle will be updated even if the peer is banned.
                // NOTE(review): the counter is decremented here even when other
                // clones still share the same wrapper (the `Err` branch below);
                // the `Drop` impl appears to decrement again once the last of
                // those clones goes away. Confirm this cannot underflow.
                inner.inner.decrease_affinity();

                // SAFETY: Inner value was not dropped.
                let inner = unsafe { ManuallyDrop::take(inner) };

                // Get `KnownPeerInner` out of the wrapper.
                let inner = match Arc::try_unwrap(inner) {
                    Ok(KnownPeerHandleWithAffinity { inner }) => inner,
                    // The wrapper is still shared: keep our own reference to the state.
                    Err(inner) => inner.inner.clone(),
                };

                // Replace the old state with the new one, ensuring that the old state
                // is not dropped (because we took the value out of it).
                let prev_state = std::mem::replace(
                    &mut self.0,
                    KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
                );

                // Forget the old state to avoid dropping it.
                #[allow(clippy::mem_forget)]
                std::mem::forget(prev_state);

                true
            }
        }
    }

    /// Creates a weak handle of the same kind as this one.
    pub fn downgrade(&self) -> WeakKnownPeerHandle {
        WeakKnownPeerHandle(match &self.0 {
            KnownPeerHandleState::Simple(data) => {
                WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
            }
            KnownPeerHandleState::WithAffinity(data) => {
                WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
            }
        })
    }

    /// Returns the shared peer state regardless of the handle kind.
    fn inner(&self) -> &KnownPeerInner {
        match &self.0 {
            KnownPeerHandleState::Simple(data) => data.as_ref(),
            KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
        }
    }
}
1450
/// Internal state of a [`KnownPeerHandle`].
///
/// The `Arc`s are wrapped in `ManuallyDrop` because the `Drop` impl of this
/// enum takes them out by value.
#[derive(Clone)]
enum KnownPeerHandleState {
    /// Direct reference to the peer state.
    Simple(ManuallyDrop<Arc<KnownPeerInner>>),
    /// Reference to a wrapper around the peer state; clones share the wrapper.
    WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
}
1456
1457impl Drop for KnownPeerHandleState {
1458    fn drop(&mut self) {
1459        let inner;
1460        let is_banned;
1461        match self {
1462            KnownPeerHandleState::Simple(data) => {
1463                // SAFETY: inner value is dropped only once
1464                inner = unsafe { ManuallyDrop::take(data) };
1465                is_banned = inner.is_banned();
1466            }
1467            KnownPeerHandleState::WithAffinity(data) => {
1468                // SAFETY: inner value is dropped only once
1469                match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1470                    Some(data) => {
1471                        inner = data.inner;
1472                        is_banned = !inner.decrease_affinity() || inner.is_banned();
1473                    }
1474                    None => return,
1475                }
1476            }
1477        };
1478
1479        if is_banned {
1480            // Don't remove banned peers from the known peers cache
1481            return;
1482        }
1483
1484        if let Some(inner) = Arc::into_inner(inner) {
1485            // If the last reference is dropped, remove the peer from the known peers cache
1486            if let Some(peers) = inner.weak_known_peers.upgrade() {
1487                peers.remove(&inner.peer_info.load().id);
1488                metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1489            }
1490        }
1491    }
1492}
1493
/// Weak counterpart of [`KnownPeerHandle`].
///
/// Two weak handles are equal when they are of the same kind and point to
/// the same allocation.
#[derive(Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1497
1498impl WeakKnownPeerHandle {
1499    pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1500        Some(KnownPeerHandle(match &self.0 {
1501            WeakKnownPeerHandleState::Simple(weak) => {
1502                KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1503            }
1504            WeakKnownPeerHandleState::WithAffinity(weak) => {
1505                KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1506            }
1507        }))
1508    }
1509}
1510
/// Internal state of a [`WeakKnownPeerHandle`], mirroring `KnownPeerHandleState`.
#[derive(Clone)]
enum WeakKnownPeerHandleState {
    /// Weak reference to the peer state.
    Simple(Weak<KnownPeerInner>),
    /// Weak reference to the affinity wrapper.
    WithAffinity(Weak<KnownPeerHandleWithAffinity>),
}
1516
1517impl Eq for WeakKnownPeerHandleState {}
1518impl PartialEq for WeakKnownPeerHandleState {
1519    #[inline]
1520    fn eq(&self, other: &Self) -> bool {
1521        match (self, other) {
1522            (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1523            (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1524            _ => false,
1525        }
1526    }
1527}
1528
/// Wrapper shared by all clones of a handle with affinity.
struct KnownPeerHandleWithAffinity {
    /// The peer state this wrapper refers to.
    inner: Arc<KnownPeerInner>,
}
1532
/// State of a known peer, shared by all of its handles.
struct KnownPeerInner {
    /// Current peer info; swapped atomically on update.
    peer_info: ArcSwap<PeerInfo>,
    /// Affinity counter; holds `AFFINITY_BANNED` for banned peers.
    affinity: AtomicUsize,
    /// Back reference to the map that owns the entry for this peer.
    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
}
1538
impl KnownPeerInner {
    /// Creates the shared state; the affinity counter starts at `1` when
    /// `with_affinity` is set and at `0` otherwise.
    fn new(
        peer_info: Arc<PeerInfo>,
        with_affinity: bool,
        known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
    ) -> Arc<Self> {
        Arc::new(Self {
            peer_info: ArcSwap::from(peer_info),
            affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
            weak_known_peers: Arc::downgrade(known_peers),
        })
    }

    /// Returns `true` if the affinity counter holds the banned value.
    fn is_banned(&self) -> bool {
        self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
    }

    /// Maps the affinity counter to [`PeerAffinity`]: zero is `Allowed`,
    /// the banned value is `Never`, anything else is `High`.
    fn compute_affinity(&self) -> PeerAffinity {
        match self.affinity.load(Ordering::Acquire) {
            0 => PeerAffinity::Allowed,
            AFFINITY_BANNED => PeerAffinity::Never,
            _ => PeerAffinity::High,
        }
    }

    /// Increments the affinity counter unless the peer is banned.
    ///
    /// Returns `false` (leaving the counter untouched) for banned peers.
    fn increase_affinity(&self) -> bool {
        let mut current = self.affinity.load(Ordering::Acquire);
        while current != AFFINITY_BANNED {
            // Incrementing from here would produce the banned value.
            debug_assert_ne!(current, AFFINITY_BANNED - 1);
            match self.affinity.compare_exchange_weak(
                current,
                current + 1,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                // Lost the race (or failed spuriously): retry with the observed value.
                Err(affinity) => current = affinity,
            }
        }

        false
    }

    /// Decrements the affinity counter unless the peer is banned.
    ///
    /// Returns `false` (leaving the counter untouched) for banned peers.
    fn decrease_affinity(&self) -> bool {
        let mut current = self.affinity.load(Ordering::Acquire);
        while current != AFFINITY_BANNED {
            // Callers must not decrease a counter that is already zero.
            debug_assert_ne!(current, 0);
            match self.affinity.compare_exchange_weak(
                current,
                current - 1,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                // Lost the race (or failed spuriously): retry with the observed value.
                Err(affinity) => current = affinity,
            }
        }

        false
    }

    /// Tries to replace the stored peer info with `peer_info`, optionally
    /// increasing the affinity counter.
    ///
    /// Returns `Ok(true)` when the stored info has the same `created_at` or
    /// was replaced by the newer `peer_info`, and `Ok(false)` when the stored
    /// info is newer. In both `Ok` cases an affinity increase requested via
    /// `with_affinity` is kept.
    ///
    /// # Errors
    ///
    /// Returns [`PeerBannedError`] when the peer is banned; an affinity
    /// increase made by this call is rolled back in that case.
    fn try_update_peer_info(
        &self,
        peer_info: &Arc<PeerInfo>,
        with_affinity: bool,
    ) -> Result<bool, PeerBannedError> {
        /// Remembers whether this call increased the affinity so that the
        /// increase can be undone on early return.
        struct AffinityGuard<'a> {
            inner: &'a KnownPeerInner,
            decrease_on_drop: bool,
        }

        impl AffinityGuard<'_> {
            /// Increases the affinity at most once per guard (if requested),
            /// otherwise only checks the ban flag. Returns `true` if banned.
            fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
                // Skip the increase if it was already done on a previous iteration.
                let with_affinity = with_affinity && !self.decrease_on_drop;
                let is_banned = if with_affinity {
                    !self.inner.increase_affinity()
                } else {
                    self.inner.is_banned()
                };

                if !is_banned && with_affinity {
                    self.decrease_on_drop = true;
                }

                is_banned
            }
        }

        impl Drop for AffinityGuard<'_> {
            fn drop(&mut self) {
                if self.decrease_on_drop {
                    self.inner.decrease_affinity();
                }
            }
        }

        // Create a guard to restore the peer affinity in case of an error
        let mut guard = AffinityGuard {
            inner: self,
            decrease_on_drop: false,
        };

        let mut cur = self.peer_info.load();
        let updated = loop {
            // Re-checked on every retry so that a concurrent ban is noticed.
            if guard.increase_affinity_or_check_ban(with_affinity) {
                // Do nothing for banned peers
                return Err(PeerBannedError);
            }

            match cur.created_at.cmp(&peer_info.created_at) {
                // Do nothing for the same creation time
                // TODO: is `created_at` equality enough?
                std::cmp::Ordering::Equal => break true,
                // Try to update peer info
                std::cmp::Ordering::Less => {
                    let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
                    // The swap succeeded if the previous value is the one we compared against.
                    if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
                        break true;
                    } else {
                        // Someone else updated the info first: retry against their value.
                        cur = prev;
                    }
                }
                // Allow an outdated data
                std::cmp::Ordering::Greater => break false,
            }
        };

        // Success: keep the affinity increase (if any).
        guard.decrease_on_drop = false;
        Ok(updated)
    }
}
1670
/// Sentinel value of the affinity counter that marks a banned peer.
const AFFINITY_BANNED: usize = usize::MAX;
1672
/// Errors returned when inserting or updating known peers.
#[derive(Debug, thiserror::Error)]
pub enum KnownPeersError {
    /// The peer is banned.
    #[error(transparent)]
    PeerBanned(#[from] PeerBannedError),
    /// A newer peer info is already stored.
    #[error("provided peer info is outdated")]
    OutdatedInfo,
}
1680
/// Error returned when an operation is rejected because the peer is banned.
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("peer is banned")]
pub struct PeerBannedError;
1684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    /// An entry stays in the cache while at least one handle is alive and is
    /// removed once the last handle is dropped; re-inserting afterwards works.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let known = KnownPeers::new();
        let info = make_peer_info_stub(rand::random());

        // The peer is tracked, not banned, and exposed with `Allowed` affinity.
        let assert_tracked = || {
            assert!(known.contains(&info.id));
            assert!(!known.is_banned(&info.id));
            assert_eq!(known.get(&info.id), Some(info.clone()));
            assert_eq!(known.get_affinity(&info.id), Some(PeerAffinity::Allowed));
        };

        // First handle creates the entry.
        let first = known.insert(info.clone(), false).unwrap();
        assert_tracked();
        assert_eq!(first.peer_info().as_ref(), info.as_ref());
        assert_eq!(first.max_affinity(), PeerAffinity::Allowed);

        // Second handle refers to the same entry.
        let second = known.insert(info.clone(), false).unwrap();
        assert_tracked();
        assert_eq!(second.peer_info().as_ref(), info.as_ref());
        assert_eq!(second.max_affinity(), PeerAffinity::Allowed);

        // Dropping one of two handles keeps the entry.
        drop(second);
        assert_tracked();

        // Dropping the last handle removes the entry entirely.
        drop(first);
        assert!(!known.contains(&info.id));
        assert!(!known.is_banned(&info.id));
        assert_eq!(known.get(&info.id), None);
        assert_eq!(known.get_affinity(&info.id), None);

        // The same peer can be inserted again after removal.
        known.insert(info.clone(), false).unwrap();
    }

    /// A handle inserted with affinity raises the peer to `High`; dropping it
    /// brings the peer back to `Allowed` while a plain handle still exists.
    #[test]
    fn with_affinity_after_simple() {
        let known = KnownPeers::new();
        let info = make_peer_info_stub(rand::random());

        // The peer is tracked and reports the given affinity.
        let assert_affinity = |expected: PeerAffinity| {
            assert!(known.contains(&info.id));
            assert_eq!(known.get_affinity(&info.id), Some(expected));
        };

        let plain = known.insert(info.clone(), false).unwrap();
        assert_affinity(PeerAffinity::Allowed);
        assert_eq!(plain.max_affinity(), PeerAffinity::Allowed);

        let pinned = known.insert(info.clone(), true).unwrap();
        assert_affinity(PeerAffinity::High);
        assert_eq!(pinned.max_affinity(), PeerAffinity::High);
        assert_eq!(plain.max_affinity(), PeerAffinity::High);

        drop(pinned);
        assert!(known.contains(&info.id));
        assert_eq!(plain.max_affinity(), PeerAffinity::Allowed);
        assert_eq!(known.get_affinity(&info.id), Some(PeerAffinity::Allowed));

        drop(plain);
        assert!(!known.contains(&info.id));
        assert_eq!(known.get_affinity(&info.id), None);
    }

    /// A plain handle added after a handle with affinity does not lower the
    /// peer's `High` affinity, and neither does dropping that plain handle.
    #[test]
    fn with_affinity_before_simple() {
        let known = KnownPeers::new();
        let info = make_peer_info_stub(rand::random());

        // The peer is tracked and reports `High` affinity.
        let assert_high = || {
            assert!(known.contains(&info.id));
            assert_eq!(known.get_affinity(&info.id), Some(PeerAffinity::High));
        };

        let pinned = known.insert(info.clone(), true).unwrap();
        assert_high();
        assert_eq!(pinned.max_affinity(), PeerAffinity::High);

        let plain = known.insert(info.clone(), false).unwrap();
        assert_high();
        assert_eq!(pinned.max_affinity(), PeerAffinity::High);
        assert_eq!(plain.max_affinity(), PeerAffinity::High);

        drop(plain);
        assert!(known.contains(&info.id));
        assert_eq!(pinned.max_affinity(), PeerAffinity::High);
        assert_eq!(known.get_affinity(&info.id), Some(PeerAffinity::High));

        drop(pinned);
        assert!(!known.contains(&info.id));
        assert_eq!(known.get_affinity(&info.id), None);
    }

    /// Banning a peer that still has a live handle keeps the entry and its
    /// info, but both the handle and the cache report `Never` affinity.
    #[test]
    fn ban_while_handle_exists() {
        let known = KnownPeers::new();
        let info = make_peer_info_stub(rand::random());

        let live = known.insert(info.clone(), false).unwrap();
        assert!(known.contains(&info.id));
        assert_eq!(live.max_affinity(), PeerAffinity::Allowed);

        known.ban(&info.id);
        assert!(known.contains(&info.id));
        assert!(known.is_banned(&info.id));
        assert_eq!(live.max_affinity(), PeerAffinity::Never);
        assert_eq!(known.get(&info.id), Some(info.clone()));
        assert_eq!(known.get_affinity(&info.id), Some(PeerAffinity::Never));
    }
}