// tycho_network/network/connection_manager.rs

1use std::collections::{VecDeque, hash_map};
2use std::mem::ManuallyDrop;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use arc_swap::{ArcSwap, AsRaw};
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::{AbortHandle, JoinSet};
11use tokio_util::time::{DelayQueue, delay_queue};
12use tycho_util::{FastDashMap, FastHashMap};
13
14use crate::network::ConnectionError;
15use crate::network::config::NetworkConfig;
16use crate::network::connection::Connection;
17use crate::network::endpoint::{Connecting, ConnectionInitError, Endpoint, Into0RttResult};
18use crate::network::request_handler::InboundRequestHandler;
19use crate::network::wire::{HandshakeError, handshake};
20use crate::types::{
21    Address, BoxCloneService, Direction, DisconnectReason, PeerAffinity, PeerEvent, PeerId,
22    PeerInfo, Response, ServiceRequest,
23};
24
// ---- Histograms ----

/// Time spent on an outbound connection attempt (address resolution, connect and handshake).
const METRIC_CONNECTION_OUT_TIME: &str = "tycho_net_conn_out_time";
/// Time spent on an inbound connection handshake (including the 0.5-RTT acceptance wait).
const METRIC_CONNECTION_IN_TIME: &str = "tycho_net_conn_in_time";

// ---- Counters ----

/// Newly registered peers whose connection was dialed by us.
const METRIC_CONNECTIONS_OUT_TOTAL: &str = "tycho_net_conn_out_total";
/// Newly registered peers whose connection was accepted by us.
const METRIC_CONNECTIONS_IN_TOTAL: &str = "tycho_net_conn_in_total";
/// Failed outbound connection attempts.
const METRIC_CONNECTIONS_OUT_FAIL_TOTAL: &str = "tycho_net_conn_out_fail_total";
/// Failed inbound connection attempts.
const METRIC_CONNECTIONS_IN_FAIL_TOTAL: &str = "tycho_net_conn_in_fail_total";

// ---- Gauges ----

/// Running connection handler tasks.
const METRIC_CONNECTIONS_ACTIVE: &str = "tycho_net_conn_active";
/// Spawned connection (dial / handshake) tasks that have not finished yet.
const METRIC_CONNECTIONS_PENDING: &str = "tycho_net_conn_pending";
/// Incoming connections still waiting for the peer identity.
const METRIC_CONNECTIONS_PARTIAL: &str = "tycho_net_conn_partial";
/// Dials started by the periodic connectivity check that are still unresolved.
const METRIC_CONNECTIONS_PENDING_DIALS: &str = "tycho_net_conn_pending_dials";

/// Gauge name for the active peer set size.
const METRIC_ACTIVE_PEERS: &str = "tycho_net_active_peers";
/// Gauge name for the known peer set size.
const METRIC_KNOWN_PEERS: &str = "tycho_net_known_peers";
43
44#[derive(Debug)]
45pub(crate) enum ConnectionManagerRequest {
46    Connect(Address, PeerId, CallbackTx),
47    Shutdown(oneshot::Sender<()>),
48}
49
50pub(crate) struct ConnectionManager {
51    config: Arc<NetworkConfig>,
52    endpoint: Arc<Endpoint>,
53
54    mailbox: mpsc::Receiver<ConnectionManagerRequest>,
55
56    pending_connection_callbacks: FastHashMap<Address, PendingConnectionCallbacks>,
57    pending_partial_connections: JoinSet<Option<PartialConnection>>,
58    pending_connections: JoinSet<ConnectingOutput>,
59    connection_handlers: JoinSet<()>,
60    delayed_callbacks: DelayedCallbacksQueue,
61
62    pending_dials: FastHashMap<PeerId, CallbackRx>,
63    dial_backoff_states: FastHashMap<PeerId, DialBackoffState>,
64
65    active_peers: ActivePeers,
66    known_peers: KnownPeers,
67
68    service: BoxCloneService<ServiceRequest, Response>,
69}
70
71type CallbackTx = oneshot::Sender<Result<Connection, ConnectionError>>;
72type CallbackRx = oneshot::Receiver<Result<Connection, ConnectionError>>;
73
74impl Drop for ConnectionManager {
75    fn drop(&mut self) {
76        tracing::trace!("dropping connection manager");
77        self.endpoint.close();
78    }
79}
80
81impl ConnectionManager {
82    pub fn new(
83        config: Arc<NetworkConfig>,
84        endpoint: Arc<Endpoint>,
85        active_peers: ActivePeers,
86        known_peers: KnownPeers,
87        service: BoxCloneService<ServiceRequest, Response>,
88    ) -> (Self, mpsc::Sender<ConnectionManagerRequest>) {
89        let (mailbox_tx, mailbox) = mpsc::channel(config.connection_manager_channel_capacity);
90        let connection_manager = Self {
91            config,
92            endpoint,
93            mailbox,
94            pending_connection_callbacks: Default::default(),
95            pending_partial_connections: Default::default(),
96            pending_connections: Default::default(),
97            connection_handlers: Default::default(),
98            delayed_callbacks: Default::default(),
99            pending_dials: Default::default(),
100            dial_backoff_states: Default::default(),
101            active_peers,
102            known_peers,
103            service,
104        };
105        (connection_manager, mailbox_tx)
106    }
107
108    pub async fn start(mut self) {
109        tracing::info!("connection manager started");
110
111        let jitter = Duration::from_millis(1000).mul_f64(rand::random());
112        let mut interval = tokio::time::interval(self.config.connectivity_check_interval + jitter);
113
114        let mut shutdown_notifier = None;
115
116        loop {
117            tokio::select! {
118                now = interval.tick() => {
119                    self.handle_connectivity_check(now.into_std());
120                }
121                maybe_request = self.mailbox.recv() => {
122                    let Some(request) = maybe_request else {
123                        break;
124                    };
125
126                    match request {
127                        ConnectionManagerRequest::Connect(address, peer_id, callback) => {
128                            self.handle_connect_request(address, &peer_id, callback);
129                        }
130                        ConnectionManagerRequest::Shutdown(oneshot) => {
131                            shutdown_notifier = Some(oneshot);
132                            break;
133                        }
134                    }
135                }
136                connecting = self.endpoint.accept() => {
137                    if let Some(connecting) = connecting {
138                        self.handle_incoming(connecting);
139                    }
140                }
141                Some(connecting_output) = self.pending_connections.join_next() => {
142                    metrics::gauge!(METRIC_CONNECTIONS_PENDING).decrement(1);
143                    match connecting_output {
144                        Ok(connecting) => self.handle_connecting_result(connecting),
145                        Err(e) => {
146                            if e.is_panic() {
147                                std::panic::resume_unwind(e.into_panic());
148                            }
149                        }
150                    }
151                }
152                Some(partial_connection) = self.pending_partial_connections.join_next() => {
153                    metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).decrement(1);
154
155                    match partial_connection {
156                        Ok(None) => {}
157                        Ok(Some(PartialConnection {
158                            connection,
159                            timeout_at,
160                        })) => self.handle_incoming_impl(connection, None, timeout_at),
161                        // NOTE: unwrap here is to propagate panic from the spawned future
162                        Err(e) => {
163                            if e.is_panic() {
164                                std::panic::resume_unwind(e.into_panic());
165                            }
166                        }
167                    }
168                }
169                Some(connection_handler_output) = self.connection_handlers.join_next() => {
170                    metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
171
172                    // NOTE: unwrap here is to propagate panic from the spawned future
173                    if let Err(e) = connection_handler_output && e.is_panic() {
174                        std::panic::resume_unwind(e.into_panic());
175                    }
176                }
177                Some(peer_id) = self.delayed_callbacks.wait_for_next_expired() => {
178                    self.delayed_callbacks.execute_expired(&peer_id);
179                }
180            }
181        }
182
183        self.shutdown().await;
184
185        if let Some(tx) = shutdown_notifier {
186            _ = tx.send(());
187        }
188
189        tracing::info!("connection manager stopped");
190    }
191
192    async fn shutdown(mut self) {
193        tracing::trace!("shutting down connection manager");
194
195        self.endpoint.close();
196
197        self.pending_partial_connections.shutdown().await;
198        metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).set(0);
199
200        self.pending_connections.shutdown().await;
201        metrics::gauge!(METRIC_CONNECTIONS_PENDING).set(0);
202
203        while self.connection_handlers.join_next().await.is_some() {
204            metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).decrement(1);
205        }
206        assert!(self.active_peers.is_empty());
207
208        self.endpoint
209            .wait_idle(self.config.shutdown_idle_timeout)
210            .await;
211    }
212
213    fn handle_connectivity_check(&mut self, now: Instant) {
214        use std::collections::hash_map::Entry;
215
216        self.pending_dials
217            .retain(|peer_id, oneshot| match oneshot.try_recv() {
218                Ok(Ok(returned_peer_id)) => {
219                    debug_assert_eq!(peer_id, returned_peer_id.peer_id());
220                    self.dial_backoff_states.remove(peer_id);
221                    false
222                }
223                Ok(Err(_)) => {
224                    match self.dial_backoff_states.entry(*peer_id) {
225                        Entry::Occupied(mut entry) => entry.get_mut().update(
226                            now,
227                            self.config.connection_backoff,
228                            self.config.max_connection_backoff,
229                        ),
230                        Entry::Vacant(entry) => {
231                            entry.insert(DialBackoffState::new(
232                                now,
233                                self.config.connection_backoff,
234                                self.config.max_connection_backoff,
235                            ));
236                        }
237                    }
238                    false
239                }
240                Err(oneshot::error::TryRecvError::Closed) => {
241                    panic!("BUG: connection manager never finished dialing a peer");
242                }
243                Err(oneshot::error::TryRecvError::Empty) => true,
244            });
245
246        let outstanding_connections_limit = self
247            .config
248            .max_concurrent_outstanding_connections
249            .saturating_sub(self.pending_connections.len());
250
251        let outstanding_connections = self
252            .known_peers
253            .0
254            .iter()
255            .filter_map(|item| {
256                let value = match item.value() {
257                    KnownPeerState::Stored(item) => item.upgrade()?,
258                    KnownPeerState::Banned => return None,
259                };
260                let peer_info = value.peer_info.load();
261                let affinity = value.compute_affinity();
262
263                (affinity == PeerAffinity::High
264                    && peer_info.id != self.endpoint.peer_id()
265                    && !self.active_peers.contains(&peer_info.id)
266                    && !self.pending_dials.contains_key(&peer_info.id)
267                    && self
268                        .dial_backoff_states
269                        .get(&peer_info.id)
270                        .is_none_or(|state| now > state.next_attempt_at))
271                .then(|| arc_swap::Guard::into_inner(peer_info))
272            })
273            .take(outstanding_connections_limit)
274            .collect::<Vec<_>>();
275
276        for peer_info in outstanding_connections {
277            // TODO: handle multiple addresses
278            let address = peer_info
279                .iter_addresses()
280                .next()
281                .cloned()
282                .expect("address list must have at least one item");
283
284            let (tx, rx) = oneshot::channel();
285            self.dial_peer(address, &peer_info.id, tx);
286            self.pending_dials.insert(peer_info.id, rx);
287        }
288
289        metrics::gauge!(METRIC_CONNECTIONS_PENDING_DIALS).set(self.pending_dials.len() as f64);
290    }
291
292    fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
293        self.dial_peer(address, peer_id, callback);
294    }
295
296    fn handle_incoming(&mut self, connecting: Connecting) {
297        let remote_addr = connecting.remote_address();
298        tracing::trace!(
299            local_id = %self.endpoint.peer_id(),
300            %remote_addr,
301            "received an incoming connection",
302        );
303
304        // Split incoming connection into 0.5-RTT and 1-RTT parts.
305        match connecting.into_0rtt() {
306            Into0RttResult::Established(connection, accepted) => {
307                let timeout_at = Instant::now() + self.config.connect_timeout;
308                self.handle_incoming_impl(connection, Some(accepted), timeout_at);
309            }
310            Into0RttResult::WithoutIdentity(partial_connection) => {
311                tracing::trace!("connection identity is not available yet");
312
313                let timeout_at = Instant::now() + self.config.connect_timeout;
314                self.pending_partial_connections.spawn(async move {
315                    match tokio::time::timeout_at(timeout_at.into(), partial_connection).await {
316                        Ok(Ok(connection)) => Some(PartialConnection {
317                            connection,
318                            timeout_at,
319                        }),
320                        Ok(Err(e)) => {
321                            tracing::trace!(
322                                %remote_addr,
323                                "failed to establish an incoming connection: {e}",
324                            );
325                            None
326                        }
327                        Err(_) => {
328                            tracing::trace!(
329                                %remote_addr,
330                                "incoming connection timed out",
331                            );
332                            None
333                        }
334                    }
335                });
336                metrics::gauge!(METRIC_CONNECTIONS_PARTIAL).increment(1);
337            }
338            Into0RttResult::InvalidCertificate => {
339                tracing::trace!(%remote_addr, "invalid incoming connection");
340            }
341            Into0RttResult::Unavailable(_) => unreachable!(
342                "BUG: For incoming connections, a 0.5-RTT connection must \
343                always be successfully constructed."
344            ),
345        };
346    }
347
348    fn handle_incoming_impl(
349        &mut self,
350        connection: Connection,
351        accepted: Option<quinn::ZeroRttAccepted>,
352        timeout_at: Instant,
353    ) {
354        async fn handle_incoming_task(
355            seqno: u32,
356            connection: ConnectionClosedOnDrop,
357            accepted: Option<quinn::ZeroRttAccepted>,
358            timeout_at: Instant,
359        ) -> ConnectingOutput {
360            let target_peer_id = *connection.peer_id();
361            let target_address = connection.remote_address().into();
362            let fut = async {
363                if let Some(accepted) = accepted {
364                    // NOTE: `bool` output of this future is meaningless for servers.
365                    accepted.await;
366                }
367                handshake(&connection).await
368            };
369
370            let started_at = Instant::now();
371
372            let connecting_result = match tokio::time::timeout_at(timeout_at.into(), fut).await {
373                Ok(Ok(())) => Ok(connection.disarm()),
374                Ok(Err(e)) => Err(FullConnectionError::HandshakeFailed(e)),
375                Err(_) => Err(FullConnectionError::Timeout),
376            };
377
378            metrics::histogram!(METRIC_CONNECTION_IN_TIME).record(started_at.elapsed());
379
380            ConnectingOutput {
381                seqno,
382                drop_result: true,
383                connecting_result: ManuallyDrop::new(connecting_result),
384                target_address,
385                target_peer_id,
386                origin: Direction::Inbound,
387            }
388        }
389
390        let remote_addr = connection.remote_address();
391
392        // Check if the peer is allowed before doing anything else.
393        match self.known_peers.get_affinity(connection.peer_id()) {
394            Some(PeerAffinity::High | PeerAffinity::Allowed) => {}
395            Some(PeerAffinity::Never) => {
396                // TODO: Lower log level to trace/debug?
397                tracing::warn!(
398                    %remote_addr,
399                    peer_id = %connection.peer_id(),
400                    "rejecting connection due to PeerAffinity::Never",
401                );
402                connection.close();
403                return;
404            }
405            _ => {
406                if matches!(
407                    self.config.max_concurrent_connections,
408                    Some(limit) if self.active_peers.len() >= limit
409                ) {
410                    // TODO: Lower log level to trace/debug?
411                    tracing::warn!(
412                        %remote_addr,
413                        peer_id = %connection.peer_id(),
414                        "rejecting connection due too many concurrent connections",
415                    );
416                    connection.close();
417                    return;
418                }
419            }
420        }
421
422        let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) {
423            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
424                last_seqno: 0,
425                origin: Direction::Inbound,
426                callbacks: Default::default(),
427                abort_handle: None,
428            })),
429            hash_map::Entry::Occupied(entry) => {
430                let entry = entry.into_mut();
431
432                // Check if the incoming connection is a simultaneous dial.
433                if simultaneous_dial_tie_breaking(
434                    self.endpoint.peer_id(),
435                    connection.peer_id(),
436                    entry.origin,
437                    Direction::Inbound,
438                ) {
439                    // New connection wins the tie, abort the old one and spawn a new task.
440                    tracing::debug!(
441                        %remote_addr,
442                        peer_id = %connection.peer_id(),
443                        "cancelling old connection to mitigate simultaneous dial",
444                    );
445
446                    entry.origin = Direction::Inbound;
447                    entry.last_seqno += 1;
448                    if let Some(handle) = entry.abort_handle.take() {
449                        handle.abort();
450                    }
451                    Some(entry)
452                } else {
453                    // Old connection wins the tie, gracefully close the new one.
454                    tracing::debug!(
455                        %remote_addr,
456                        peer_id = %connection.peer_id(),
457                        "cancelling new connection to mitigate simultaneous dial",
458                    );
459
460                    connection.close();
461                    None
462                }
463            }
464        };
465
466        if let Some(entry) = entry {
467            entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task(
468                entry.last_seqno,
469                ConnectionClosedOnDrop::new(connection),
470                accepted,
471                timeout_at,
472            )));
473            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
474        }
475    }
476
477    fn handle_connecting_result(&mut self, mut res: ConnectingOutput) {
478        // Check seqno first to drop outdated results.
479        {
480            let Some(entry) = self.pending_connection_callbacks.get(&res.target_address) else {
481                tracing::trace!("connection task reordering detected");
482                return;
483            };
484
485            if entry.last_seqno != res.seqno {
486                tracing::debug!(
487                    local_id = %self.endpoint.peer_id(),
488                    peer_id = %res.target_peer_id,
489                    remote_addr = %res.target_address,
490                    "connection result is outdated"
491                );
492                return;
493            }
494        }
495
496        let callbacks = self
497            .pending_connection_callbacks
498            .remove(&res.target_address)
499            .expect("Connection tasks must be tracked")
500            .callbacks;
501
502        res.drop_result = false;
503        // SAFETY: `drop_result` is set to `false`.
504        match unsafe { ManuallyDrop::take(&mut res.connecting_result) } {
505            Ok(connection) => {
506                tracing::debug!(
507                    local_id = %self.endpoint.peer_id(),
508                    peer_id = %connection.peer_id(),
509                    remote_addr = %res.target_address,
510                    "new connection",
511                );
512
513                let connection = self.add_peer(connection);
514                self.delayed_callbacks.execute_resolved(&connection);
515
516                for callback in callbacks {
517                    _ = callback.send(Ok(connection.clone()));
518                }
519            }
520            Err(e) => {
521                tracing::debug!(
522                    local_id = %self.endpoint.peer_id(),
523                    peer_id = %res.target_peer_id,
524                    remote_addr = %res.target_address,
525                    "connection failed: {e:?}"
526                );
527
528                metrics::counter!(match res.origin {
529                    Direction::Outbound => METRIC_CONNECTIONS_OUT_FAIL_TOTAL,
530                    Direction::Inbound => METRIC_CONNECTIONS_IN_FAIL_TOTAL,
531                })
532                .increment(1);
533
534                let brief_error = e.as_brief();
535
536                // Delay sending the error to callbacks as the target peer might be
537                // in the process of connecting to us.
538                if e.was_closed() && !callbacks.is_empty() {
539                    self.delayed_callbacks.push(
540                        &res.target_peer_id,
541                        brief_error,
542                        callbacks,
543                        &self.config.connection_error_delay,
544                    );
545                    return;
546                }
547
548                for callback in callbacks {
549                    _ = callback.send(Err(brief_error));
550                }
551            }
552        }
553    }
554
555    fn add_peer(&mut self, connection: Connection) -> Connection {
556        match self.active_peers.add(self.endpoint.peer_id(), connection) {
557            AddedPeer::New(connection) => {
558                let origin = connection.origin();
559
560                let handler = InboundRequestHandler::new(
561                    self.config.clone(),
562                    connection.clone(),
563                    self.service.clone(),
564                    self.active_peers.clone(),
565                );
566
567                metrics::counter!(match origin {
568                    Direction::Outbound => METRIC_CONNECTIONS_OUT_TOTAL,
569                    Direction::Inbound => METRIC_CONNECTIONS_IN_TOTAL,
570                })
571                .increment(1);
572
573                metrics::gauge!(METRIC_CONNECTIONS_ACTIVE).increment(1);
574                self.connection_handlers.spawn(handler.start());
575
576                connection
577            }
578            AddedPeer::Existing(connection) => connection,
579        }
580    }
581
582    #[tracing::instrument(
583        level = "trace",
584        skip_all,
585        fields(
586            local_id = %self.endpoint.peer_id(),
587            peer_id = %peer_id,
588            remote_addr = %address,
589        ),
590    )]
591    fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) {
592        async fn dial_peer_task(
593            seqno: u32,
594            endpoint: Arc<Endpoint>,
595            address: Address,
596            peer_id: PeerId,
597            config: Arc<NetworkConfig>,
598        ) -> ConnectingOutput {
599            let fut = async {
600                let address = address
601                    .resolve()
602                    .await
603                    .map_err(FullConnectionError::InvalidAddress)?;
604
605                let connecting = endpoint
606                    .connect_with_expected_id(&address, &peer_id)
607                    .map_err(|e| FullConnectionError::InvalidAddress(std::io::Error::other(e)))?;
608
609                let connection = ConnectionClosedOnDrop::new(connecting.await?);
610                match handshake(&connection).await {
611                    Ok(()) => Ok(connection),
612                    Err(e) => Err(FullConnectionError::HandshakeFailed(e)),
613                }
614            };
615
616            let started_at = Instant::now();
617
618            let connecting_result = match tokio::time::timeout(config.connect_timeout, fut).await {
619                Ok(res) => res.map(ConnectionClosedOnDrop::disarm),
620                Err(_) => Err(FullConnectionError::Timeout),
621            };
622
623            metrics::histogram!(METRIC_CONNECTION_OUT_TIME).record(started_at.elapsed());
624
625            ConnectingOutput {
626                seqno,
627                drop_result: true,
628                connecting_result: ManuallyDrop::new(connecting_result),
629                target_address: address,
630                target_peer_id: peer_id,
631                origin: Direction::Outbound,
632            }
633        }
634
635        if let Some(connection) = self.active_peers.get(peer_id) {
636            tracing::debug!("peer is already connected");
637            _ = callback.send(Ok(connection));
638            return;
639        }
640
641        tracing::trace!("connecting to peer");
642
643        let entry = match self.pending_connection_callbacks.entry(address.clone()) {
644            hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks {
645                last_seqno: 0,
646                origin: Direction::Outbound,
647                callbacks: vec![callback],
648                abort_handle: None,
649            })),
650            hash_map::Entry::Occupied(entry) => {
651                let entry = entry.into_mut();
652
653                // Add the callback to the existing entry.
654                entry.callbacks.push(callback);
655
656                // Check if the outgoing connection is a simultaneous dial.
657                let break_tie = simultaneous_dial_tie_breaking(
658                    self.endpoint.peer_id(),
659                    peer_id,
660                    entry.origin,
661                    Direction::Outbound,
662                );
663
664                if break_tie && entry.origin != Direction::Outbound {
665                    // New connection wins the tie, abort the old one and spawn a new task.
666                    tracing::debug!("cancelling old connection to mitigate simultaneous dial");
667
668                    entry.origin = Direction::Outbound;
669                    entry.last_seqno += 1;
670                    if let Some(handle) = entry.abort_handle.take() {
671                        handle.abort();
672                    }
673                    Some(entry)
674                } else {
675                    // Old connection wins the tie, don't create a new one
676                    tracing::trace!("reusing old connection to mitigate simultaneous dial");
677                    None
678                }
679            }
680        };
681
682        if let Some(entry) = entry {
683            entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task(
684                entry.last_seqno,
685                self.endpoint.clone(),
686                address.clone(),
687                *peer_id,
688                self.config.clone(),
689            )));
690            metrics::gauge!(METRIC_CONNECTIONS_PENDING).increment(1);
691        }
692    }
693}
694
695struct PendingConnectionCallbacks {
696    last_seqno: u32,
697    origin: Direction,
698    callbacks: Vec<CallbackTx>,
699    abort_handle: Option<AbortHandle>,
700}
701
702struct PartialConnection {
703    connection: Connection,
704    timeout_at: Instant,
705}
706
707struct ConnectingOutput {
708    seqno: u32,
709    drop_result: bool,
710    connecting_result: ManuallyDrop<Result<Connection, FullConnectionError>>,
711    target_address: Address,
712    target_peer_id: PeerId,
713    origin: Direction,
714}
715
716impl Drop for ConnectingOutput {
717    fn drop(&mut self) {
718        if self.drop_result {
719            // SAFETY: `drop_result` is set to `true` only when the result is not used.
720            unsafe { ManuallyDrop::drop(&mut self.connecting_result) };
721        }
722    }
723}
724
725struct ConnectionClosedOnDrop {
726    connection: ManuallyDrop<Connection>,
727    close_on_drop: bool,
728}
729
730impl ConnectionClosedOnDrop {
731    fn new(connection: Connection) -> Self {
732        Self {
733            connection: ManuallyDrop::new(connection),
734            close_on_drop: true,
735        }
736    }
737
738    fn disarm(mut self) -> Connection {
739        self.close_on_drop = false;
740        // SAFETY: `drop` will not be called.
741        unsafe { ManuallyDrop::take(&mut self.connection) }
742    }
743}
744
745impl std::ops::Deref for ConnectionClosedOnDrop {
746    type Target = Connection;
747
748    #[inline]
749    fn deref(&self) -> &Self::Target {
750        &self.connection
751    }
752}
753
754impl Drop for ConnectionClosedOnDrop {
755    fn drop(&mut self) {
756        if self.close_on_drop {
757            // SAFETY: `disarm` was not called.
758            let connection = unsafe { ManuallyDrop::take(&mut self.connection) };
759            connection.close();
760        }
761    }
762}
763
764#[derive(Default)]
765struct DelayedCallbacksQueue {
766    callbacks: FastHashMap<PeerId, VecDeque<DelayedCallbacks>>,
767    expirations: DelayQueue<PeerId>,
768}
769
770impl DelayedCallbacksQueue {
771    fn push(
772        &mut self,
773        peer_id: &PeerId,
774        error: ConnectionError,
775        callbacks: Vec<CallbackTx>,
776        delay: &Duration,
777    ) {
778        tracing::debug!(%peer_id, %error, "delayed connection error");
779
780        let expires_at = Instant::now() + *delay;
781        let delay_key = self.expirations.insert_at(*peer_id, expires_at.into());
782
783        let items = self.callbacks.entry(*peer_id).or_default();
784        items.push_back(DelayedCallbacks {
785            delay_key,
786            error,
787            callbacks,
788            expires_at,
789        });
790    }
791
792    async fn wait_for_next_expired(&mut self) -> Option<PeerId> {
793        let res = futures_util::future::poll_fn(|cx| self.expirations.poll_expired(cx)).await?;
794        Some(res.into_inner())
795    }
796
797    fn execute_resolved(&mut self, connection: &Connection) {
798        let Some(items) = self.callbacks.remove(connection.peer_id()) else {
799            return;
800        };
801
802        let mut batches_executed = 0;
803        let mut callbacks_executed = 0;
804
805        for delayed in items {
806            batches_executed += 1;
807            callbacks_executed += delayed.callbacks.len();
808
809            let key = delayed.execute_with_ok(connection);
810
811            // NOTE: Delay key must exist in the queue.
812            self.expirations.remove(&key);
813        }
814
815        tracing::debug!(
816            peer_id = %connection.peer_id(),
817            batches_executed,
818            callbacks_executed,
819            "executed all delayed callbacks",
820        );
821    }
822
823    fn execute_expired(&mut self, peer_id: &PeerId) {
824        let now = Instant::now();
825
826        let mut batches_executed = 0;
827        let mut callbacks_executed = 0;
828
829        'outer: {
830            if let Some(items) = self.callbacks.get_mut(peer_id) {
831                while let Some(front) = items.front() {
832                    if !front.is_expired(&now) {
833                        break 'outer;
834                    }
835
836                    if let Some(delayed) = items.pop_front() {
837                        batches_executed += 1;
838                        callbacks_executed += delayed.callbacks.len();
839
840                        let key = delayed.execute_with_error();
841
842                        // NOTE: Might not be necessary since items are stored in order,
843                        // but it's better to be safe.
844                        self.expirations.try_remove(&key);
845                    }
846                }
847            }
848
849            // There is no need to hold an empty queue for this peer
850            self.callbacks.remove(peer_id);
851        }
852
853        tracing::debug!(
854            %peer_id,
855            batches_executed,
856            callbacks_executed,
857            "executed expired delayed callbacks"
858        );
859    }
860}
861
862struct DelayedCallbacks {
863    delay_key: delay_queue::Key,
864    error: ConnectionError,
865    callbacks: Vec<CallbackTx>,
866    expires_at: Instant,
867}
868
869impl DelayedCallbacks {
870    fn execute_with_ok(self, connection: &Connection) -> delay_queue::Key {
871        for callback in self.callbacks {
872            _ = callback.send(Ok(connection.clone()));
873        }
874        self.delay_key
875    }
876
877    fn execute_with_error(self) -> delay_queue::Key {
878        for callback in self.callbacks {
879            _ = callback.send(Err(self.error));
880        }
881        self.delay_key
882    }
883
884    fn is_expired(&self, now: &Instant) -> bool {
885        *now >= self.expires_at
886    }
887}
888
889#[derive(Debug)]
890struct DialBackoffState {
891    next_attempt_at: Instant,
892    attempts: usize,
893}
894
895impl DialBackoffState {
896    fn new(now: Instant, step: Duration, max: Duration) -> Self {
897        let mut state = Self {
898            next_attempt_at: now,
899            attempts: 0,
900        };
901        state.update(now, step, max);
902        state
903    }
904
905    fn update(&mut self, now: Instant, step: Duration, max: Duration) {
906        self.attempts += 1;
907        self.next_attempt_at = now
908            + std::cmp::min(
909                max,
910                step.saturating_mul(self.attempts.try_into().unwrap_or(u32::MAX)),
911            );
912    }
913}
914
915#[derive(Debug, thiserror::Error)]
916enum FullConnectionError {
917    #[error("invalid address")]
918    InvalidAddress(#[source] std::io::Error),
919    #[error(transparent)]
920    ConnectionFailed(quinn::ConnectionError),
921    #[error("invalid certificate")]
922    InvalidCertificate,
923    #[error("handshake failed")]
924    HandshakeFailed(#[source] HandshakeError),
925    #[error("connection timeout")]
926    Timeout,
927}
928
929impl FullConnectionError {
930    fn as_brief(&self) -> ConnectionError {
931        fn is_crypto_error(error: &quinn::ConnectionError) -> bool {
932            const QUIC_CRYPTO_FLAG: u64 = 0x100;
933
934            matches!(
935                error,
936                quinn::ConnectionError::TransportError(e)
937                if u64::from(e.code) & QUIC_CRYPTO_FLAG != 0
938            )
939        }
940
941        match self {
942            Self::InvalidAddress(_) => ConnectionError::InvalidAddress,
943            Self::ConnectionFailed(e) if is_crypto_error(e) => ConnectionError::InvalidCertificate,
944            Self::ConnectionFailed(_) => ConnectionError::ConnectionInitFailed,
945            Self::InvalidCertificate => ConnectionError::InvalidCertificate,
946            Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) if is_crypto_error(e) => {
947                ConnectionError::InvalidCertificate
948            }
949            Self::HandshakeFailed(_) => ConnectionError::HandshakeFailed,
950            Self::Timeout => ConnectionError::Timeout,
951        }
952    }
953
954    fn was_closed(&self) -> bool {
955        let connection_error = match self {
956            Self::ConnectionFailed(e)
957            | Self::HandshakeFailed(HandshakeError::ConnectionFailed(e)) => e,
958            _ => return false,
959        };
960
961        matches!(
962            connection_error,
963            quinn::ConnectionError::ApplicationClosed(closed)
964            if closed.error_code.into_inner() == 0
965        )
966    }
967}
968
969impl From<ConnectionInitError> for FullConnectionError {
970    fn from(value: ConnectionInitError) -> Self {
971        match value {
972            ConnectionInitError::ConnectionFailed(e) => Self::ConnectionFailed(e),
973            ConnectionInitError::InvalidCertificate => Self::InvalidCertificate,
974        }
975    }
976}
977
978#[derive(Clone)]
979pub(crate) struct ActivePeers(Arc<ActivePeersInner>);
980
981impl ActivePeers {
982    pub fn new(channel_size: usize) -> Self {
983        Self(Arc::new(ActivePeersInner::new(channel_size)))
984    }
985
986    pub fn get(&self, peer_id: &PeerId) -> Option<Connection> {
987        self.0.get(peer_id)
988    }
989
990    pub fn contains(&self, peer_id: &PeerId) -> bool {
991        self.0.contains(peer_id)
992    }
993
994    pub fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
995        self.0.add(local_id, new_connection)
996    }
997
998    pub fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
999        self.0.remove(peer_id, reason);
1000    }
1001
1002    pub fn remove_with_stable_id(
1003        &self,
1004        peer_id: &PeerId,
1005        stable_id: usize,
1006        reason: DisconnectReason,
1007    ) {
1008        self.0.remove_with_stable_id(peer_id, stable_id, reason);
1009    }
1010
1011    pub fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1012        self.0.subscribe()
1013    }
1014
1015    pub fn is_empty(&self) -> bool {
1016        self.0.is_empty()
1017    }
1018
1019    pub fn len(&self) -> usize {
1020        self.0.len()
1021    }
1022}
1023
1024struct ActivePeersInner {
1025    connections: FastDashMap<PeerId, Connection>,
1026    connections_len: AtomicUsize,
1027    events_tx: broadcast::Sender<PeerEvent>,
1028}
1029
1030impl ActivePeersInner {
1031    fn new(channel_size: usize) -> Self {
1032        let (events_tx, _) = broadcast::channel(channel_size);
1033        Self {
1034            connections: Default::default(),
1035            connections_len: Default::default(),
1036            events_tx,
1037        }
1038    }
1039
1040    fn get(&self, peer_id: &PeerId) -> Option<Connection> {
1041        self.connections
1042            .get(peer_id)
1043            .map(|item| item.value().clone())
1044    }
1045
1046    fn contains(&self, peer_id: &PeerId) -> bool {
1047        self.connections.contains_key(peer_id)
1048    }
1049
1050    #[must_use]
1051    fn add(&self, local_id: &PeerId, new_connection: Connection) -> AddedPeer {
1052        use dashmap::mapref::entry::Entry;
1053
1054        let mut added = false;
1055
1056        let peer_id = new_connection.peer_id();
1057        match self.connections.entry(*peer_id) {
1058            Entry::Occupied(mut entry) => {
1059                if simultaneous_dial_tie_breaking(
1060                    local_id,
1061                    peer_id,
1062                    entry.get().origin(),
1063                    new_connection.origin(),
1064                ) {
1065                    tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial");
1066                    let old_connection = entry.insert(new_connection.clone());
1067                    old_connection.close();
1068                    self.send_event(PeerEvent::lost_peer(*peer_id, DisconnectReason::Requested));
1069                } else {
1070                    tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial");
1071                    new_connection.close();
1072                    return AddedPeer::Existing(entry.get().clone());
1073                }
1074            }
1075            Entry::Vacant(entry) => {
1076                self.connections_len.fetch_add(1, Ordering::Release);
1077                entry.insert(new_connection.clone());
1078                added = true;
1079            }
1080        }
1081
1082        self.send_event(PeerEvent::new_peer(*peer_id));
1083
1084        if added {
1085            metrics::gauge!(METRIC_ACTIVE_PEERS).increment(1);
1086        }
1087        AddedPeer::New(new_connection)
1088    }
1089
1090    fn remove(&self, peer_id: &PeerId, reason: DisconnectReason) {
1091        if let Some((_, connection)) = self.connections.remove(peer_id) {
1092            connection.close();
1093            self.connections_len.fetch_sub(1, Ordering::Release);
1094            self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1095
1096            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1097        }
1098    }
1099
1100    fn remove_with_stable_id(&self, peer_id: &PeerId, stable_id: usize, reason: DisconnectReason) {
1101        if let Some((_, connection)) = self
1102            .connections
1103            .remove_if(peer_id, |_, connection| connection.stable_id() == stable_id)
1104        {
1105            connection.close();
1106            self.connections_len.fetch_sub(1, Ordering::Release);
1107            self.send_event(PeerEvent::lost_peer(*peer_id, reason));
1108
1109            metrics::gauge!(METRIC_ACTIVE_PEERS).decrement(1);
1110        }
1111    }
1112
1113    fn subscribe(&self) -> broadcast::Receiver<PeerEvent> {
1114        self.events_tx.subscribe()
1115    }
1116
1117    fn send_event(&self, event: PeerEvent) {
1118        _ = self.events_tx.send(event);
1119    }
1120
1121    fn is_empty(&self) -> bool {
1122        self.connections.is_empty()
1123    }
1124
1125    fn len(&self) -> usize {
1126        self.connections_len.load(Ordering::Acquire)
1127    }
1128}
1129
1130pub(crate) enum AddedPeer {
1131    New(Connection),
1132    Existing(Connection),
1133}
1134
1135fn simultaneous_dial_tie_breaking(
1136    local_id: &PeerId,
1137    peer_id: &PeerId,
1138    old_origin: Direction,
1139    new_origin: Direction,
1140) -> bool {
1141    match (old_origin, new_origin) {
1142        (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => {
1143            true
1144        }
1145        (Direction::Inbound, Direction::Outbound) => peer_id < local_id,
1146        (Direction::Outbound, Direction::Inbound) => local_id < peer_id,
1147    }
1148}
1149
1150#[derive(Default, Clone)]
1151#[repr(transparent)]
1152pub struct KnownPeers(Arc<FastDashMap<PeerId, KnownPeerState>>);
1153
1154impl KnownPeers {
1155    pub fn new() -> Self {
1156        Self::default()
1157    }
1158
1159    pub fn contains(&self, peer_id: &PeerId) -> bool {
1160        self.0.contains_key(peer_id)
1161    }
1162
1163    pub fn is_banned(&self, peer_id: &PeerId) -> bool {
1164        self.0
1165            .get(peer_id)
1166            .and_then(|item| {
1167                Some(match item.value() {
1168                    KnownPeerState::Stored(item) => item.upgrade()?.is_banned(),
1169                    KnownPeerState::Banned => true,
1170                })
1171            })
1172            .unwrap_or_default()
1173    }
1174
1175    pub fn get(&self, peer_id: &PeerId) -> Option<Arc<PeerInfo>> {
1176        self.0.get(peer_id).and_then(|item| match item.value() {
1177            KnownPeerState::Stored(item) => {
1178                let inner = item.upgrade()?;
1179                Some(inner.peer_info.load_full())
1180            }
1181            KnownPeerState::Banned => None,
1182        })
1183    }
1184
1185    pub fn get_affinity(&self, peer_id: &PeerId) -> Option<PeerAffinity> {
1186        self.0
1187            .get(peer_id)
1188            .and_then(|item| item.value().compute_affinity())
1189    }
1190
1191    pub fn remove(&self, peer_id: &PeerId) {
1192        self.0.remove(peer_id);
1193        metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1194    }
1195
1196    pub fn ban(&self, peer_id: &PeerId) {
1197        let mut added = false;
1198        match self.0.entry(*peer_id) {
1199            dashmap::mapref::entry::Entry::Vacant(entry) => {
1200                entry.insert(KnownPeerState::Banned);
1201                added = true;
1202            }
1203            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1204                KnownPeerState::Banned => {}
1205                KnownPeerState::Stored(item) => match item.upgrade() {
1206                    Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release),
1207                    None => *entry.get_mut() = KnownPeerState::Banned,
1208                },
1209            },
1210        }
1211
1212        if added {
1213            // NOTE: "New" banned peer is a "new" known peer.
1214            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1215        }
1216    }
1217
1218    pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option<KnownPeerHandle> {
1219        let inner = match self.0.get(peer_id)?.value() {
1220            KnownPeerState::Stored(item) => {
1221                let inner = item.upgrade()?;
1222                if with_affinity && !inner.increase_affinity() {
1223                    return None;
1224                }
1225                inner
1226            }
1227            KnownPeerState::Banned => return None,
1228        };
1229
1230        Some(KnownPeerHandle::from_inner(inner, with_affinity))
1231    }
1232
1233    /// Inserts a new handle only if the provided info is not outdated
1234    /// and the peer is not banned.
1235    pub fn insert(
1236        &self,
1237        peer_info: Arc<PeerInfo>,
1238        with_affinity: bool,
1239    ) -> Result<KnownPeerHandle, KnownPeersError> {
1240        // TODO: add capacity limit for entries without affinity
1241        let mut added = false;
1242        let inner = match self.0.entry(peer_info.id) {
1243            dashmap::mapref::entry::Entry::Vacant(entry) => {
1244                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1245                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1246                added = true;
1247                inner
1248            }
1249            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1250                KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)),
1251                KnownPeerState::Stored(item) => match item.upgrade() {
1252                    Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? {
1253                        true => inner,
1254                        false => return Err(KnownPeersError::OutdatedInfo),
1255                    },
1256                    None => {
1257                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1258                        *item = Arc::downgrade(&inner);
1259                        inner
1260                    }
1261                },
1262            },
1263        };
1264
1265        if added {
1266            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1267        }
1268
1269        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1270    }
1271
1272    /// Same as [`KnownPeers::insert`], but ignores outdated info.
1273    pub fn insert_allow_outdated(
1274        &self,
1275        peer_info: Arc<PeerInfo>,
1276        with_affinity: bool,
1277    ) -> Result<KnownPeerHandle, PeerBannedError> {
1278        // TODO: add capacity limit for entries without affinity
1279        let mut added = false;
1280        let inner = match self.0.entry(peer_info.id) {
1281            dashmap::mapref::entry::Entry::Vacant(entry) => {
1282                let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1283                entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner)));
1284                added = true;
1285                inner
1286            }
1287            dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() {
1288                KnownPeerState::Banned => return Err(PeerBannedError),
1289                KnownPeerState::Stored(item) => match item.upgrade() {
1290                    Some(inner) => {
1291                        // NOTE: Outdated info is ignored here.
1292                        inner.try_update_peer_info(&peer_info, with_affinity)?;
1293                        inner
1294                    }
1295                    None => {
1296                        let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0);
1297                        *item = Arc::downgrade(&inner);
1298                        inner
1299                    }
1300                },
1301            },
1302        };
1303
1304        if added {
1305            metrics::gauge!(METRIC_KNOWN_PEERS).increment(1);
1306        }
1307
1308        Ok(KnownPeerHandle::from_inner(inner, with_affinity))
1309    }
1310}
1311
1312enum KnownPeerState {
1313    Stored(Weak<KnownPeerInner>),
1314    Banned,
1315}
1316
1317impl KnownPeerState {
1318    fn compute_affinity(&self) -> Option<PeerAffinity> {
1319        Some(match self {
1320            Self::Stored(weak) => weak.upgrade()?.compute_affinity(),
1321            Self::Banned => PeerAffinity::Never,
1322        })
1323    }
1324}
1325
1326#[derive(Clone)]
1327#[repr(transparent)]
1328pub struct KnownPeerHandle(KnownPeerHandleState);
1329
1330impl KnownPeerHandle {
1331    fn from_inner(inner: Arc<KnownPeerInner>, with_affinity: bool) -> Self {
1332        KnownPeerHandle(if with_affinity {
1333            KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1334                KnownPeerHandleWithAffinity { inner },
1335            )))
1336        } else {
1337            KnownPeerHandleState::Simple(ManuallyDrop::new(inner))
1338        })
1339    }
1340
1341    pub fn peer_info(&self) -> arc_swap::Guard<Arc<PeerInfo>, arc_swap::DefaultStrategy> {
1342        self.inner().peer_info.load()
1343    }
1344
1345    pub fn load_peer_info(&self) -> Arc<PeerInfo> {
1346        arc_swap::Guard::into_inner(self.peer_info())
1347    }
1348
1349    pub fn is_banned(&self) -> bool {
1350        self.inner().is_banned()
1351    }
1352
1353    pub fn max_affinity(&self) -> PeerAffinity {
1354        self.inner().compute_affinity()
1355    }
1356
1357    pub fn update_peer_info(&self, peer_info: &Arc<PeerInfo>) -> Result<(), KnownPeersError> {
1358        match self.inner().try_update_peer_info(peer_info, false) {
1359            Ok(true) => Ok(()),
1360            Ok(false) => Err(KnownPeersError::OutdatedInfo),
1361            Err(e) => Err(KnownPeersError::PeerBanned(e)),
1362        }
1363    }
1364
1365    pub fn ban(&self) -> bool {
1366        let inner = self.inner();
1367        inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED
1368    }
1369
1370    pub fn increase_affinity(&mut self) -> bool {
1371        match &mut self.0 {
1372            KnownPeerHandleState::Simple(inner) => {
1373                // NOTE: Handle will be updated even if the peer is banned.
1374                inner.increase_affinity();
1375
1376                // SAFETY: Inner value was not dropped.
1377                let inner = unsafe { ManuallyDrop::take(inner) };
1378
1379                // Replace the old state with the new one, ensuring that the old state
1380                // is not dropped (because we took the value out of it).
1381                let prev_state = std::mem::replace(
1382                    &mut self.0,
1383                    KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new(
1384                        KnownPeerHandleWithAffinity { inner },
1385                    ))),
1386                );
1387
1388                // Forget the old state to avoid dropping it.
1389                #[allow(clippy::mem_forget)]
1390                std::mem::forget(prev_state);
1391
1392                true
1393            }
1394            KnownPeerHandleState::WithAffinity(_) => false,
1395        }
1396    }
1397
1398    pub fn decrease_affinity(&mut self) -> bool {
1399        match &mut self.0 {
1400            KnownPeerHandleState::Simple(_) => false,
1401            KnownPeerHandleState::WithAffinity(inner) => {
1402                // NOTE: Handle will be updated even if the peer is banned.
1403                inner.inner.decrease_affinity();
1404
1405                // SAFETY: Inner value was not dropped.
1406                let inner = unsafe { ManuallyDrop::take(inner) };
1407
1408                // Get `KnownPeerInner` out of the wrapper.
1409                let inner = match Arc::try_unwrap(inner) {
1410                    Ok(KnownPeerHandleWithAffinity { inner }) => inner,
1411                    Err(inner) => inner.inner.clone(),
1412                };
1413
1414                // Replace the old state with the new one, ensuring that the old state
1415                // is not dropped (because we took the value out of it).
1416                let prev_state = std::mem::replace(
1417                    &mut self.0,
1418                    KnownPeerHandleState::Simple(ManuallyDrop::new(inner)),
1419                );
1420
1421                // Forget the old state to avoid dropping it.
1422                #[allow(clippy::mem_forget)]
1423                std::mem::forget(prev_state);
1424
1425                true
1426            }
1427        }
1428    }
1429
1430    pub fn downgrade(&self) -> WeakKnownPeerHandle {
1431        WeakKnownPeerHandle(match &self.0 {
1432            KnownPeerHandleState::Simple(data) => {
1433                WeakKnownPeerHandleState::Simple(Arc::downgrade(data))
1434            }
1435            KnownPeerHandleState::WithAffinity(data) => {
1436                WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data))
1437            }
1438        })
1439    }
1440
1441    fn inner(&self) -> &KnownPeerInner {
1442        match &self.0 {
1443            KnownPeerHandleState::Simple(data) => data.as_ref(),
1444            KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(),
1445        }
1446    }
1447}
1448
1449#[derive(Clone)]
1450enum KnownPeerHandleState {
1451    Simple(ManuallyDrop<Arc<KnownPeerInner>>),
1452    WithAffinity(ManuallyDrop<Arc<KnownPeerHandleWithAffinity>>),
1453}
1454
1455impl Drop for KnownPeerHandleState {
1456    fn drop(&mut self) {
1457        let inner;
1458        let is_banned;
1459        match self {
1460            KnownPeerHandleState::Simple(data) => {
1461                // SAFETY: inner value is dropped only once
1462                inner = unsafe { ManuallyDrop::take(data) };
1463                is_banned = inner.is_banned();
1464            }
1465            KnownPeerHandleState::WithAffinity(data) => {
1466                // SAFETY: inner value is dropped only once
1467                match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) {
1468                    Some(data) => {
1469                        inner = data.inner;
1470                        is_banned = !inner.decrease_affinity() || inner.is_banned();
1471                    }
1472                    None => return,
1473                }
1474            }
1475        };
1476
1477        if is_banned {
1478            // Don't remove banned peers from the known peers cache
1479            return;
1480        }
1481
1482        if let Some(inner) = Arc::into_inner(inner) {
1483            // If the last reference is dropped, remove the peer from the known peers cache
1484            if let Some(peers) = inner.weak_known_peers.upgrade() {
1485                peers.remove(&inner.peer_info.load().id);
1486                metrics::gauge!(METRIC_KNOWN_PEERS).decrement(1);
1487            }
1488        }
1489    }
1490}
1491
1492#[derive(Clone, PartialEq, Eq)]
1493#[repr(transparent)]
1494pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState);
1495
1496impl WeakKnownPeerHandle {
1497    pub fn upgrade(&self) -> Option<KnownPeerHandle> {
1498        Some(KnownPeerHandle(match &self.0 {
1499            WeakKnownPeerHandleState::Simple(weak) => {
1500                KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?))
1501            }
1502            WeakKnownPeerHandleState::WithAffinity(weak) => {
1503                KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?))
1504            }
1505        }))
1506    }
1507}
1508
1509#[derive(Clone)]
1510enum WeakKnownPeerHandleState {
1511    Simple(Weak<KnownPeerInner>),
1512    WithAffinity(Weak<KnownPeerHandleWithAffinity>),
1513}
1514
1515impl Eq for WeakKnownPeerHandleState {}
1516impl PartialEq for WeakKnownPeerHandleState {
1517    #[inline]
1518    fn eq(&self, other: &Self) -> bool {
1519        match (self, other) {
1520            (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right),
1521            (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right),
1522            _ => false,
1523        }
1524    }
1525}
1526
1527struct KnownPeerHandleWithAffinity {
1528    inner: Arc<KnownPeerInner>,
1529}
1530
1531struct KnownPeerInner {
1532    peer_info: ArcSwap<PeerInfo>,
1533    affinity: AtomicUsize,
1534    weak_known_peers: Weak<FastDashMap<PeerId, KnownPeerState>>,
1535}
1536
1537impl KnownPeerInner {
1538    fn new(
1539        peer_info: Arc<PeerInfo>,
1540        with_affinity: bool,
1541        known_peers: &Arc<FastDashMap<PeerId, KnownPeerState>>,
1542    ) -> Arc<Self> {
1543        Arc::new(Self {
1544            peer_info: ArcSwap::from(peer_info),
1545            affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }),
1546            weak_known_peers: Arc::downgrade(known_peers),
1547        })
1548    }
1549
1550    fn is_banned(&self) -> bool {
1551        self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED
1552    }
1553
1554    fn compute_affinity(&self) -> PeerAffinity {
1555        match self.affinity.load(Ordering::Acquire) {
1556            0 => PeerAffinity::Allowed,
1557            AFFINITY_BANNED => PeerAffinity::Never,
1558            _ => PeerAffinity::High,
1559        }
1560    }
1561
1562    fn increase_affinity(&self) -> bool {
1563        let mut current = self.affinity.load(Ordering::Acquire);
1564        while current != AFFINITY_BANNED {
1565            debug_assert_ne!(current, AFFINITY_BANNED - 1);
1566            match self.affinity.compare_exchange_weak(
1567                current,
1568                current + 1,
1569                Ordering::Release,
1570                Ordering::Acquire,
1571            ) {
1572                Ok(_) => return true,
1573                Err(affinity) => current = affinity,
1574            }
1575        }
1576
1577        false
1578    }
1579
1580    fn decrease_affinity(&self) -> bool {
1581        let mut current = self.affinity.load(Ordering::Acquire);
1582        while current != AFFINITY_BANNED {
1583            debug_assert_ne!(current, 0);
1584            match self.affinity.compare_exchange_weak(
1585                current,
1586                current - 1,
1587                Ordering::Release,
1588                Ordering::Acquire,
1589            ) {
1590                Ok(_) => return true,
1591                Err(affinity) => current = affinity,
1592            }
1593        }
1594
1595        false
1596    }
1597
1598    fn try_update_peer_info(
1599        &self,
1600        peer_info: &Arc<PeerInfo>,
1601        with_affinity: bool,
1602    ) -> Result<bool, PeerBannedError> {
1603        struct AffinityGuard<'a> {
1604            inner: &'a KnownPeerInner,
1605            decrease_on_drop: bool,
1606        }
1607
1608        impl AffinityGuard<'_> {
1609            fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool {
1610                let with_affinity = with_affinity && !self.decrease_on_drop;
1611                let is_banned = if with_affinity {
1612                    !self.inner.increase_affinity()
1613                } else {
1614                    self.inner.is_banned()
1615                };
1616
1617                if !is_banned && with_affinity {
1618                    self.decrease_on_drop = true;
1619                }
1620
1621                is_banned
1622            }
1623        }
1624
1625        impl Drop for AffinityGuard<'_> {
1626            fn drop(&mut self) {
1627                if self.decrease_on_drop {
1628                    self.inner.decrease_affinity();
1629                }
1630            }
1631        }
1632
1633        // Create a guard to restore the peer affinity in case of an error
1634        let mut guard = AffinityGuard {
1635            inner: self,
1636            decrease_on_drop: false,
1637        };
1638
1639        let mut cur = self.peer_info.load();
1640        let updated = loop {
1641            if guard.increase_affinity_or_check_ban(with_affinity) {
1642                // Do nothing for banned peers
1643                return Err(PeerBannedError);
1644            }
1645
1646            match cur.created_at.cmp(&peer_info.created_at) {
1647                // Do nothing for the same creation time
1648                // TODO: is `created_at` equality enough?
1649                std::cmp::Ordering::Equal => break true,
1650                // Try to update peer info
1651                std::cmp::Ordering::Less => {
1652                    let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone());
1653                    if std::ptr::eq(cur.as_raw(), prev.as_raw()) {
1654                        break true;
1655                    } else {
1656                        cur = prev;
1657                    }
1658                }
1659                // Allow an outdated data
1660                std::cmp::Ordering::Greater => break false,
1661            }
1662        };
1663
1664        guard.decrease_on_drop = false;
1665        Ok(updated)
1666    }
1667}
1668
1669const AFFINITY_BANNED: usize = usize::MAX;
1670
1671#[derive(Debug, thiserror::Error)]
1672pub enum KnownPeersError {
1673    #[error(transparent)]
1674    PeerBanned(#[from] PeerBannedError),
1675    #[error("provided peer info is outdated")]
1676    OutdatedInfo,
1677}
1678
1679#[derive(Debug, Copy, Clone, thiserror::Error)]
1680#[error("peer is banned")]
1681pub struct PeerBannedError;
1682
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::make_peer_info_stub;

    /// A peer stays cached while at least one handle is alive. Once the last
    /// handle is dropped, it is evicted without being marked as banned.
    #[test]
    fn remove_from_cache_on_drop_works() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = &peer_info.id;

        // Every lookup must agree: cached, not banned, same info, `Allowed`.
        let assert_cached = || {
            assert!(peers.contains(id));
            assert!(!peers.is_banned(id));
            assert_eq!(peers.get(id), Some(peer_info.clone()));
            assert_eq!(peers.get_affinity(id), Some(PeerAffinity::Allowed));
        };

        let first = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached();
        assert_eq!(first.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(first.max_affinity(), PeerAffinity::Allowed);

        // Inserting the same info again hands out a second handle.
        let second = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached();
        assert_eq!(second.peer_info().as_ref(), peer_info.as_ref());
        assert_eq!(second.max_affinity(), PeerAffinity::Allowed);

        // Releasing a handle that is not the last one keeps the entry.
        drop(second);
        assert_cached();

        // Releasing the last handle removes every trace of the peer.
        drop(first);
        assert!(!peers.contains(id));
        assert!(!peers.is_banned(id));
        assert_eq!(peers.get(id), None);
        assert_eq!(peers.get_affinity(id), None);

        // An evicted (not banned) peer can be inserted again.
        peers.insert(peer_info.clone(), false).unwrap();
    }

    /// A high-affinity insert on top of a plain one raises the affinity seen by
    /// all handles. Dropping the high-affinity handle lowers it back.
    #[test]
    fn with_affinity_after_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = &peer_info.id;

        let assert_cached_as = |affinity: PeerAffinity| {
            assert!(peers.contains(id));
            assert_eq!(peers.get_affinity(id), Some(affinity));
        };

        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached_as(PeerAffinity::Allowed);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);

        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert_cached_as(PeerAffinity::High);
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_with_affinity);
        assert_cached_as(PeerAffinity::Allowed);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed);

        drop(handle_simple);
        assert!(!peers.contains(id));
        assert_eq!(peers.get_affinity(id), None);
    }

    /// A plain insert after a high-affinity one does not lower the affinity.
    /// Dropping the plain handle first keeps it `High` as well.
    #[test]
    fn with_affinity_before_simple() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = &peer_info.id;

        let assert_cached_as = |affinity: PeerAffinity| {
            assert!(peers.contains(id));
            assert_eq!(peers.get_affinity(id), Some(affinity));
        };

        let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap();
        assert_cached_as(PeerAffinity::High);
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);

        let handle_simple = peers.insert(peer_info.clone(), false).unwrap();
        assert_cached_as(PeerAffinity::High);
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);
        assert_eq!(handle_simple.max_affinity(), PeerAffinity::High);

        drop(handle_simple);
        assert_cached_as(PeerAffinity::High);
        assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High);

        drop(handle_with_affinity);
        assert!(!peers.contains(id));
        assert_eq!(peers.get_affinity(id), None);
    }

    /// Banning a peer while a handle is alive keeps the entry but pins its
    /// affinity to `Never`. This is visible both through the map and the handle.
    #[test]
    fn ban_while_handle_exists() {
        let peers = KnownPeers::new();

        let peer_info = make_peer_info_stub(rand::random());
        let id = &peer_info.id;

        let handle = peers.insert(peer_info.clone(), false).unwrap();
        assert!(peers.contains(id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Allowed);

        peers.ban(id);
        assert!(peers.contains(id));
        assert!(peers.is_banned(id));
        assert_eq!(handle.max_affinity(), PeerAffinity::Never);
        assert_eq!(peers.get(id), Some(peer_info.clone()));
        assert_eq!(peers.get_affinity(id), Some(PeerAffinity::Never));
    }
}