tycho_network/dht/
peer_resolver.rs

1use std::mem::ManuallyDrop;
2use std::sync::atomic::{AtomicU32, Ordering};
3use std::sync::{Arc, Mutex, Weak};
4use std::time::Duration;
5
6use exponential_backoff::Backoff;
7use serde::{Deserialize, Serialize};
8use tokio::sync::{Notify, Semaphore};
9use tycho_util::futures::JoinTask;
10use tycho_util::time::now_sec;
11use tycho_util::{FastDashMap, serde_helpers};
12
13use crate::dht::DhtService;
14use crate::network::{KnownPeerHandle, KnownPeersError, Network, PeerBannedError, WeakNetwork};
15use crate::proto::dht;
16use crate::types::{PeerId, PeerInfo};
17
18pub struct PeerResolverBuilder {
19    inner: PeerResolverConfig,
20    dht_service: DhtService,
21}
22
23impl PeerResolverBuilder {
24    pub fn with_config(mut self, config: PeerResolverConfig) -> Self {
25        self.inner = config;
26        self
27    }
28
29    pub fn build(self, network: &Network) -> PeerResolver {
30        let semaphore = Semaphore::new(self.inner.max_parallel_resolve_requests);
31
32        PeerResolver {
33            inner: Arc::new(PeerResolverInner {
34                weak_network: Network::downgrade(network),
35                dht_service: self.dht_service,
36                config: Default::default(),
37                tasks: Default::default(),
38                semaphore,
39            }),
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45#[serde(default)]
46pub struct PeerResolverConfig {
47    /// Maximum number of parallel resolve requests.
48    ///
49    /// Default: 100.
50    pub max_parallel_resolve_requests: usize,
51
52    /// Minimal time-to-live for the resolved peer info.
53    ///
54    /// Default: 600 seconds.
55    pub min_ttl_sec: u32,
56
57    /// Time before the expiration when the peer info should be updated.
58    ///
59    /// Default: 1200 seconds.
60    pub update_before_sec: u32,
61
62    /// Number of fast retries before switching to the stale retry interval.
63    ///
64    /// Default: 10.
65    pub fast_retry_count: u32,
66
67    /// Minimal interval between successful resolves.
68    ///
69    /// Default: 1 minute.
70    #[serde(with = "serde_helpers::humantime")]
71    pub min_successfull_resolve_interval: Duration,
72
73    /// Minimal interval between the fast retries.
74    ///
75    /// Default: 1 second.
76    #[serde(with = "serde_helpers::humantime")]
77    pub min_retry_interval: Duration,
78
79    /// Maximal interval between the fast retries.
80    ///
81    /// Default: 120 seconds.
82    #[serde(with = "serde_helpers::humantime")]
83    pub max_retry_interval: Duration,
84
85    /// Interval between the stale retries.
86    ///
87    /// Default: 600 seconds.
88    #[serde(with = "serde_helpers::humantime")]
89    pub stale_retry_interval: Duration,
90}
91
92impl Default for PeerResolverConfig {
93    fn default() -> Self {
94        Self {
95            max_parallel_resolve_requests: 100,
96            min_ttl_sec: 600,
97            update_before_sec: 1200,
98            fast_retry_count: 10,
99            min_successfull_resolve_interval: Duration::from_secs(60),
100            min_retry_interval: Duration::from_secs(1),
101            max_retry_interval: Duration::from_secs(120),
102            stale_retry_interval: Duration::from_secs(600),
103        }
104    }
105}
106
107#[derive(Clone)]
108pub struct PeerResolver {
109    inner: Arc<PeerResolverInner>,
110}
111
112impl PeerResolver {
113    pub(crate) fn builder(dht_service: DhtService) -> PeerResolverBuilder {
114        PeerResolverBuilder {
115            inner: Default::default(),
116            dht_service,
117        }
118    }
119
120    pub fn dht_service(&self) -> &DhtService {
121        &self.inner.dht_service
122    }
123
124    // TODO: Use affinity flag to increase the handle affinity.
125    pub fn insert(&self, peer_id: &PeerId, _with_affinity: bool) -> PeerResolverHandle {
126        use dashmap::mapref::entry::Entry;
127
128        match self.inner.tasks.entry(*peer_id) {
129            Entry::Vacant(entry) => {
130                let handle = self.inner.make_resolver_handle(peer_id);
131                entry.insert(Arc::downgrade(&handle.inner));
132                handle
133            }
134            Entry::Occupied(mut entry) => match entry.get().upgrade() {
135                Some(inner) => PeerResolverHandle {
136                    inner: ManuallyDrop::new(inner),
137                },
138                None => {
139                    let handle = self.inner.make_resolver_handle(peer_id);
140                    entry.insert(Arc::downgrade(&handle.inner));
141                    handle
142                }
143            },
144        }
145    }
146}
147
148struct PeerResolverInner {
149    weak_network: WeakNetwork,
150    dht_service: DhtService,
151    config: PeerResolverConfig,
152    tasks: FastDashMap<PeerId, Weak<PeerResolverHandleInner>>,
153    semaphore: Semaphore,
154}
155
156impl PeerResolverInner {
157    fn make_resolver_handle(self: &Arc<Self>, peer_id: &PeerId) -> PeerResolverHandle {
158        let handle = match self.weak_network.upgrade() {
159            Some(handle) => handle.known_peers().make_handle(peer_id, false),
160            None => {
161                return PeerResolverHandle::new_noop(peer_id);
162            }
163        };
164        let updater_state = handle
165            .as_ref()
166            .map(|handle| self.compute_timings(&handle.peer_info()));
167
168        let data = Arc::new(PeerResolverHandleData::new(peer_id, handle));
169
170        PeerResolverHandle::new(
171            JoinTask::new(self.clone().run_task(data.clone(), updater_state)),
172            data,
173            self,
174        )
175    }
176
177    async fn run_task(
178        self: Arc<Self>,
179        data: Arc<PeerResolverHandleData>,
180        mut timings: Option<PeerResolverTimings>,
181    ) {
182        tracing::trace!(peer_id = %data.peer_id, "peer resolver task started");
183
184        // TODO: Select between the loop body and `KnownPeers` update event.
185        loop {
186            // Wait if needed.
187            if let Some(t) = timings {
188                let update_at = std::time::UNIX_EPOCH + Duration::from_secs(t.update_at as u64);
189                let now = std::time::SystemTime::now();
190
191                let remaining = std::cmp::max(
192                    update_at.duration_since(now).unwrap_or_default(),
193                    self.config.min_successfull_resolve_interval,
194                );
195                tokio::time::sleep(remaining).await;
196            }
197
198            // Start resolving peer.
199            match self.resolve_peer(&data, &timings).await {
200                Some((network, peer_info)) => {
201                    let mut handle = data.handle.lock().unwrap();
202
203                    let peer_info_guard;
204                    let peer_info = match &*handle {
205                        // TODO: Force write into known peers to keep the handle in it?
206                        Some(handle) => match handle.update_peer_info(&peer_info) {
207                            Ok(()) => peer_info.as_ref(),
208                            Err(KnownPeersError::OutdatedInfo) => {
209                                peer_info_guard = handle.peer_info();
210                                peer_info_guard.as_ref()
211                            }
212                            // TODO: Allow resuming task after ban?
213                            Err(KnownPeersError::PeerBanned(PeerBannedError)) => break,
214                        },
215                        None => match network
216                            .known_peers()
217                            .insert_allow_outdated(peer_info, false)
218                        {
219                            Ok(new_handle) => {
220                                peer_info_guard = handle.insert(new_handle).peer_info();
221                                data.mark_resolved();
222                                peer_info_guard.as_ref()
223                            }
224                            // TODO: Allow resuming task after ban?
225                            Err(PeerBannedError) => break,
226                        },
227                    };
228
229                    timings = Some(self.compute_timings(peer_info));
230                }
231                None => break,
232            }
233        }
234
235        tracing::trace!(peer_id = %data.peer_id, "peer resolver task finished");
236    }
237
238    /// Returns a verified peer info with the strong reference to the network.
239    /// Or `None` if network no longer exists.
240    async fn resolve_peer(
241        &self,
242        data: &PeerResolverHandleData,
243        prev_timings: &Option<PeerResolverTimings>,
244    ) -> Option<(Network, Arc<PeerInfo>)> {
245        struct Iter<'a> {
246            backoff: Option<exponential_backoff::Iter<'a>>,
247            data: &'a PeerResolverHandleData,
248            stale_retry_interval: &'a Duration,
249        }
250
251        impl Iterator for Iter<'_> {
252            type Item = Duration;
253
254            fn next(&mut self) -> Option<Self::Item> {
255                Some(loop {
256                    match self.backoff.as_mut() {
257                        // Get next duration from the backoff iterator.
258                        Some(backoff) => match backoff.next() {
259                            // Use it for the first attempts.
260                            Some(duration) => break duration,
261                            // Set `is_stale` flag on last attempt and continue wih only
262                            // the `stale_retry_interval` for all subsequent iterations.
263                            None => {
264                                self.data.set_stale(true);
265                                self.backoff = None;
266                            }
267                        },
268                        // Use `stale_retry_interval` after the max retry count is reached.
269                        None => break *self.stale_retry_interval,
270                    }
271                })
272            }
273        }
274
275        let backoff = Backoff::new(
276            self.config.fast_retry_count,
277            self.config.min_retry_interval,
278            Some(self.config.max_retry_interval),
279        );
280        let mut iter = Iter {
281            backoff: Some(backoff.iter()),
282            data,
283            stale_retry_interval: &self.config.stale_retry_interval,
284        };
285
286        // "Fast" path
287        let mut attempts = 0usize;
288        loop {
289            attempts += 1;
290            let is_stale = attempts > self.config.fast_retry_count as usize;
291
292            // NOTE: Acquire network ref only during the operation.
293            {
294                let network = self.weak_network.upgrade()?;
295                if let Some(peer_info) = network.known_peers().get(&data.peer_id)
296                    && PeerResolverTimings::is_new_info(prev_timings, &peer_info)
297                {
298                    tracing::trace!(
299                        peer_id = %data.peer_id,
300                        attempts,
301                        is_stale,
302                        "peer info exists",
303                    );
304                    return Some((network, peer_info));
305                }
306
307                let dht_client = self.dht_service.make_client(&network);
308
309                let res = {
310                    let _permit = self.semaphore.acquire().await.unwrap();
311                    dht_client
312                        .entry(dht::PeerValueKeyName::NodeInfo)
313                        .find_value::<PeerInfo>(&data.peer_id)
314                        .await
315                };
316
317                let now = now_sec();
318                match res {
319                    // NOTE: Single blocking signature check here is ok since
320                    //       we are going to wait for some interval anyway.
321                    Ok(peer_info) if peer_info.id == data.peer_id && peer_info.verify(now) => {
322                        // NOTE: We only need a NEW peer info, otherwise the `resolve_peer`
323                        // method will be called again and again and again... without any progress.
324                        if PeerResolverTimings::is_new_info(prev_timings, &peer_info) {
325                            return Some((network, Arc::new(peer_info)));
326                        }
327                    }
328                    Ok(_) => {
329                        tracing::trace!(
330                            peer_id = %data.peer_id,
331                            attempts,
332                            is_stale,
333                            "received an invalid peer info",
334                        );
335                    }
336                    Err(e) => {
337                        tracing::trace!(
338                            peer_id = %data.peer_id,
339                            attempts,
340                            is_stale,
341                            "failed to resolve a peer info: {e:?}",
342                        );
343                    }
344                }
345            }
346
347            let interval = iter.next().expect("retries iterator must be infinite");
348            tokio::time::sleep(interval).await;
349        }
350    }
351
352    fn compute_timings(&self, peer_info: &PeerInfo) -> PeerResolverTimings {
353        let real_ttl = peer_info
354            .expires_at
355            .saturating_sub(self.config.update_before_sec)
356            .saturating_sub(peer_info.created_at);
357
358        let adjusted_ttl = std::cmp::max(real_ttl, self.config.min_ttl_sec);
359        PeerResolverTimings {
360            created_at: peer_info.created_at,
361            expires_at: peer_info.expires_at,
362            update_at: peer_info.created_at.saturating_add(adjusted_ttl),
363        }
364    }
365}
366
367#[derive(Debug, Clone, Copy)]
368struct PeerResolverTimings {
369    created_at: u32,
370    expires_at: u32,
371    update_at: u32,
372}
373
374impl PeerResolverTimings {
375    fn is_new_info(this: &Option<Self>, peer_info: &PeerInfo) -> bool {
376        match this {
377            Some(this) => {
378                peer_info.created_at > this.created_at && peer_info.expires_at > this.expires_at
379            }
380            None => true,
381        }
382    }
383}
384
385#[derive(Clone)]
386#[repr(transparent)]
387pub struct PeerResolverHandle {
388    inner: ManuallyDrop<Arc<PeerResolverHandleInner>>,
389}
390
391impl PeerResolverHandle {
392    fn new(
393        task: JoinTask<()>,
394        data: Arc<PeerResolverHandleData>,
395        resolver: &Arc<PeerResolverInner>,
396    ) -> Self {
397        Self {
398            inner: ManuallyDrop::new(Arc::new(PeerResolverHandleInner {
399                _task: Some(task),
400                data,
401                resolver: Arc::downgrade(resolver),
402            })),
403        }
404    }
405
406    pub fn new_noop(peer_id: &PeerId) -> Self {
407        Self {
408            inner: ManuallyDrop::new(Arc::new(PeerResolverHandleInner {
409                _task: None,
410                data: Arc::new(PeerResolverHandleData::new(peer_id, None)),
411                resolver: Weak::new(),
412            })),
413        }
414    }
415
416    pub fn peer_id(&self) -> &PeerId {
417        &self.inner.data.peer_id
418    }
419
420    pub fn load_handle(&self) -> Option<KnownPeerHandle> {
421        self.inner.data.handle.lock().unwrap().clone()
422    }
423
424    pub fn is_stale(&self) -> bool {
425        self.inner.data.is_stale()
426    }
427
428    pub fn is_resolved(&self) -> bool {
429        self.inner.data.is_resolved()
430    }
431
432    pub async fn wait_resolved(&self) -> KnownPeerHandle {
433        loop {
434            let resolved = self.inner.data.notify_resolved.notified();
435            if let Some(load_handle) = self.load_handle() {
436                break load_handle;
437            }
438            resolved.await;
439        }
440    }
441}
442
443impl Drop for PeerResolverHandle {
444    fn drop(&mut self) {
445        // SAFETY: inner value is dropped only once
446        let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
447
448        // Remove this entry from the resolver if it was the last strong reference.
449        if let Some(inner) = Arc::into_inner(inner) {
450            // NOTE: At this point an `Arc` was dropped, so the `Weak` in the resolver
451            // addresses only the remaining references.
452
453            if let Some(resolver) = inner.resolver.upgrade() {
454                resolver
455                    .tasks
456                    .remove_if(&inner.data.peer_id, |_, value| value.strong_count() == 0);
457            }
458        }
459    }
460}
461
/// State shared by all clones of a `PeerResolverHandle`.
///
/// NOTE: Fields are dropped in declaration order (the task first).
struct PeerResolverHandleInner {
    /// Background resolver task; `None` for no-op handles.
    _task: Option<JoinTask<()>>,
    /// Data shared with the background task.
    data: Arc<PeerResolverHandleData>,
    /// Resolver whose `tasks` map references this handle; empty for no-op handles.
    resolver: Weak<PeerResolverInner>,
}
467
468struct PeerResolverHandleData {
469    peer_id: PeerId,
470    handle: Mutex<Option<KnownPeerHandle>>,
471    flags: AtomicU32,
472    notify_resolved: Notify,
473}
474
475impl PeerResolverHandleData {
476    fn new(peer_id: &PeerId, handle: Option<KnownPeerHandle>) -> Self {
477        let flags = AtomicU32::new(if handle.is_some() { RESOLVED_FLAG } else { 0 });
478
479        Self {
480            peer_id: *peer_id,
481            handle: Mutex::new(handle),
482            flags,
483            notify_resolved: Notify::new(),
484        }
485    }
486
487    fn mark_resolved(&self) {
488        self.flags.fetch_or(RESOLVED_FLAG, Ordering::Release);
489        self.notify_resolved.notify_waiters();
490    }
491
492    fn is_resolved(&self) -> bool {
493        self.flags.load(Ordering::Acquire) & RESOLVED_FLAG != 0
494    }
495
496    fn set_stale(&self, stale: bool) {
497        if stale {
498            self.flags.fetch_or(STALE_FLAG, Ordering::Release);
499        } else {
500            self.flags.fetch_and(!STALE_FLAG, Ordering::Release);
501        }
502    }
503
504    fn is_stale(&self) -> bool {
505        self.flags.load(Ordering::Acquire) & STALE_FLAG != 0
506    }
507}
508
/// Set when the fast retries of a resolver task are exhausted.
const STALE_FLAG: u32 = 0b01;
/// Set once a known peer handle is available.
const RESOLVED_FLAG: u32 = STALE_FLAG << 1;