Skip to main content

tycho_core/overlay_client/
mod.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use anyhow::Result;
5use bytes::Bytes;
6use rand::prelude::IndexedRandom;
7use tokio::task::AbortHandle;
8use tycho_network::{ConnectionError, Network, PublicOverlay, Request, UnknownPeerError};
9
10pub use self::config::{NeighborsConfig, PublicOverlayClientConfig, ValidatorsConfig};
11pub use self::neighbour::{Neighbour, NeighbourStats, PunishReason};
12pub use self::neighbours::{NeighbourType, Neighbours};
13pub use self::validators::{Validator, ValidatorSetPeers, ValidatorsResolver};
14use crate::proto::overlay;
15
16mod config;
17mod neighbour;
18mod neighbours;
19mod validators;
20
/// Client for a public overlay that keeps a scored set of neighbours.
///
/// The handle only wraps an `Arc`, so clones share the same inner state.
#[derive(Clone)]
#[repr(transparent)]
pub struct PublicOverlayClient {
    // Shared state; also owns the abort handles of the background tasks.
    inner: Arc<Inner>,
}
26
27impl PublicOverlayClient {
28    pub fn new(
29        network: Network,
30        overlay: PublicOverlay,
31        config: PublicOverlayClientConfig,
32    ) -> Self {
33        let ttl = overlay.entry_ttl_sec();
34
35        let neighbors_config = &config.neighbors;
36
37        let entries = overlay
38            .read_entries()
39            .choose_multiple(&mut rand::rng(), neighbors_config.keep)
40            .map(|entry_data| {
41                Neighbour::new(
42                    entry_data.entry.peer_id,
43                    entry_data.expires_at(ttl),
44                    &neighbors_config.default_roundtrip,
45                )
46            })
47            .collect::<Vec<_>>();
48
49        let neighbours = Neighbours::new(entries, config.neighbors.keep);
50        let validators_resolver =
51            ValidatorsResolver::new(network.clone(), overlay.clone(), config.validators.clone());
52
53        let enable_neighbors_metrics = config.neighbors.enable_metrics;
54
55        let mut res = Inner {
56            network,
57            overlay,
58            neighbours,
59            config,
60            validators_resolver,
61            ping_task: None,
62            update_task: None,
63            score_task: None,
64            cleanup_task: None,
65            metrics_task: None,
66        };
67
68        // NOTE: Reuse same `Inner` type to avoid introducing a new type for shard state
69        // NOTE: Clone does not clone the tasks
70        res.ping_task = Some(tokio::spawn(res.clone().ping_neighbours_task()).abort_handle());
71        res.update_task = Some(tokio::spawn(res.clone().update_neighbours_task()).abort_handle());
72        res.score_task = Some(tokio::spawn(res.clone().apply_score_task()).abort_handle());
73        res.cleanup_task = Some(tokio::spawn(res.clone().cleanup_neighbours_task()).abort_handle());
74        res.metrics_task = if enable_neighbors_metrics {
75            Some(tokio::spawn(res.clone().update_metrics_task()).abort_handle())
76        } else {
77            None
78        };
79
80        Self {
81            inner: Arc::new(res),
82        }
83    }
84
85    pub fn config(&self) -> &PublicOverlayClientConfig {
86        &self.inner.config
87    }
88
89    pub fn neighbours(&self) -> &Neighbours {
90        &self.inner.neighbours
91    }
92
93    pub fn update_validator_set<T: ValidatorSetPeers>(&self, vset: &T) {
94        self.inner.validators_resolver.update_validator_set(vset);
95    }
96
97    // Returns a small random subset of possibly alive validators.
98    pub fn get_broadcast_targets(&self) -> Arc<Vec<Validator>> {
99        self.inner.validators_resolver.get_broadcast_targets()
100    }
101
102    pub fn validators_resolver(&self) -> &ValidatorsResolver {
103        &self.inner.validators_resolver
104    }
105
106    pub fn overlay(&self) -> &PublicOverlay {
107        &self.inner.overlay
108    }
109
110    pub fn network(&self) -> &Network {
111        &self.inner.network
112    }
113
114    pub async fn send<R>(&self, data: R) -> Result<(), Error>
115    where
116        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
117    {
118        self.inner.send(data).await
119    }
120
121    pub async fn send_to_validator(
122        &self,
123        validator: Validator,
124        data: Request,
125    ) -> Result<(), Error> {
126        self.inner.send_to_validator(validator.clone(), data).await
127    }
128
129    #[inline]
130    pub async fn send_raw(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
131        self.inner.send_impl(neighbour, req).await
132    }
133
134    pub async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
135    where
136        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
137        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
138    {
139        self.inner.query(data).await
140    }
141
142    #[inline]
143    pub async fn query_raw<A>(
144        &self,
145        neighbour: Neighbour,
146        req: Request,
147    ) -> Result<QueryResponse<A>, Error>
148    where
149        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
150    {
151        self.inner.query_impl(neighbour, req).await?.parse()
152    }
153}
154
/// Errors returned by [`PublicOverlayClient`] requests.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The neighbour set had no neighbour to choose for the request.
    #[error("no active neighbours found")]
    NoNeighbours,
    /// Not constructed in the code visible in this module.
    // NOTE(review): presumably produced by higher-level callers — confirm.
    #[error("no neighbour has the requested data")]
    NotFound,
    /// The overlay send/query call returned an error.
    #[error("network error")]
    NetworkError(#[source] anyhow::Error),
    /// The response bytes could not be deserialized.
    #[error("invalid response")]
    InvalidResponse(#[source] tl_proto::TlError),
    /// The peer answered with an error code instead of data.
    #[error("request failed with code: {0}")]
    RequestFailed(u32),
    /// Not constructed in the code visible in this module.
    #[error("internal error")]
    Internal(#[source] anyhow::Error),
    /// The send/query did not finish within the configured timeout.
    #[error("timeout")]
    Timeout,
}
172
/// Shared state of [`PublicOverlayClient`].
///
/// The same type is also moved into each background task as a clone whose task
/// handles are all `None` (see the manual `Clone` impl).
struct Inner {
    network: Network,
    overlay: PublicOverlay,
    neighbours: Neighbours,
    config: PublicOverlayClientConfig,

    validators_resolver: ValidatorsResolver,

    // Abort handles of the background tasks spawned in `PublicOverlayClient::new`.
    ping_task: Option<AbortHandle>,
    update_task: Option<AbortHandle>,
    score_task: Option<AbortHandle>,
    cleanup_task: Option<AbortHandle>,
    // Only set when `config.neighbors.enable_metrics` is enabled.
    metrics_task: Option<AbortHandle>,
}
187
188impl Clone for Inner {
189    fn clone(&self) -> Self {
190        Self {
191            network: self.network.clone(),
192            overlay: self.overlay.clone(),
193            neighbours: self.neighbours.clone(),
194            config: self.config.clone(),
195            validators_resolver: self.validators_resolver.clone(),
196            ping_task: None,
197            update_task: None,
198            score_task: None,
199            cleanup_task: None,
200            metrics_task: None,
201        }
202    }
203}
204
205impl Inner {
206    #[tracing::instrument(name = "ping_neighbours", skip_all)]
207    async fn ping_neighbours_task(self) {
208        tracing::info!("started");
209        scopeguard::defer! { tracing::info!("finished"); };
210
211        let req = Request::from_tl(overlay::Ping);
212
213        // Start pinging neighbours
214        let mut interval = tokio::time::interval(self.config.neighbors.ping_interval);
215        loop {
216            interval.tick().await;
217
218            // Ping all known neighbours (not only reliable from selection_index)
219            let neighbours = self.neighbours.get_active_neighbours();
220            let Some(neighbour) = neighbours.choose(&mut rand::rng()) else {
221                continue;
222            };
223
224            let peer_id = *neighbour.peer_id();
225            match self.query_impl(neighbour.clone(), req.clone()).await {
226                Ok(res) => match tl_proto::deserialize::<overlay::Pong>(&res.data) {
227                    Ok(_) => {
228                        res.accept();
229                        tracing::debug!(%peer_id, "pinged neighbour");
230                    }
231                    Err(e) => {
232                        tracing::warn!(
233                            %peer_id,
234                            "received an invalid ping response: {e}",
235                        );
236                        res.reject();
237                    }
238                },
239                Err(e) => {
240                    tracing::warn!(
241                        %peer_id,
242                        "failed to ping neighbour: {e}",
243                    );
244                }
245            }
246        }
247    }
248
249    #[tracing::instrument(name = "update_neighbours", skip_all)]
250    async fn update_neighbours_task(self) {
251        tracing::info!("started");
252        scopeguard::defer! { tracing::info!("finished"); };
253
254        let ttl = self.overlay.entry_ttl_sec();
255        let max_neighbours = self.config.neighbors.keep;
256        let default_roundtrip = self.config.neighbors.default_roundtrip;
257
258        let mut overlay_peers_added = self.overlay.entires_added().notified();
259        let mut overlay_peer_count = self.overlay.read_entries().len();
260
261        let mut interval = tokio::time::interval(self.config.neighbors.update_interval);
262
263        loop {
264            if overlay_peer_count == 0 {
265                tracing::info!("not enough neighbours, waiting for more");
266
267                overlay_peers_added.await;
268                overlay_peers_added = self.overlay.entires_added().notified();
269
270                overlay_peer_count = self.overlay.read_entries().len();
271            } else {
272                interval.tick().await;
273            }
274
275            let active_neighbours = self.neighbours.get_active_neighbours().len();
276            let neighbours_to_get = max_neighbours + (max_neighbours - active_neighbours);
277
278            let neighbours = {
279                self.overlay
280                    .read_entries()
281                    .choose_multiple(&mut rand::rng(), neighbours_to_get)
282                    .map(|x| Neighbour::new(x.entry.peer_id, x.expires_at(ttl), &default_roundtrip))
283                    .collect::<Vec<_>>()
284            };
285            self.neighbours.update(neighbours);
286        }
287    }
288
289    #[tracing::instrument(name = "apply_score", skip_all)]
290    async fn apply_score_task(self) {
291        tracing::info!("started");
292        scopeguard::defer! { tracing::info!("finished"); };
293
294        let mut interval = tokio::time::interval(self.config.neighbors.apply_score_interval);
295
296        loop {
297            interval.tick().await;
298
299            let now = tycho_util::time::now_sec();
300            let applied = self.neighbours.try_apply_score(now);
301            tracing::debug!(now, applied, "tried to apply neighbours score");
302        }
303    }
304
305    #[tracing::instrument(name = "cleanup_neighbours", skip_all)]
306    async fn cleanup_neighbours_task(self) {
307        tracing::info!("started");
308        scopeguard::defer! { tracing::info!("finished"); };
309
310        loop {
311            self.overlay.entries_removed().notified().await;
312
313            let now = tycho_util::time::now_sec();
314            let applied = self.neighbours.try_apply_score(now);
315            tracing::debug!(
316                now,
317                applied,
318                "tried to apply neighbours score after some overlay entry was removed"
319            );
320        }
321    }
322
323    #[tracing::instrument(name = "update_neighbour_metrics", skip_all)]
324    async fn update_metrics_task(self) {
325        tracing::info!("started");
326        scopeguard::defer! { tracing::info!("finished"); };
327
328        let mut interval = tokio::time::interval(self.config.neighbors.update_metrics_interval);
329
330        loop {
331            interval.tick().await;
332            self.neighbours.update_metrics(self.network.peer_id());
333        }
334    }
335
336    async fn send<R>(&self, data: R) -> Result<(), Error>
337    where
338        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
339    {
340        let Some(neighbour) = self.neighbours.choose() else {
341            return Err(Error::NoNeighbours);
342        };
343
344        self.send_impl(neighbour, Request::from_tl(data)).await
345    }
346
347    async fn send_to_validator(&self, validator: Validator, data: Request) -> Result<(), Error> {
348        let res = self
349            .overlay
350            .send(&self.network, &validator.peer_id(), data)
351            .await;
352        res.map_err(Error::NetworkError)
353    }
354
355    async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
356    where
357        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
358        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
359    {
360        let Some(neighbour) = self.neighbours.choose() else {
361            return Err(Error::NoNeighbours);
362        };
363
364        self.query_impl(neighbour, Request::from_tl(data))
365            .await?
366            .parse()
367    }
368
369    async fn send_impl(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
370        let started_at = Instant::now();
371
372        let res = tokio::time::timeout(
373            self.config.neighbors.send_timeout,
374            self.overlay.send(&self.network, neighbour.peer_id(), req),
375        )
376        .await;
377
378        let roundtrip = started_at.elapsed() * 2; // Multiply by 2 to estimate the roundtrip time
379
380        match res {
381            Ok(response) => {
382                neighbour.track_request(&roundtrip, response.is_ok());
383
384                if let Err(e) = &response {
385                    apply_network_error(e, &neighbour);
386                }
387
388                response.map_err(Error::NetworkError)
389            }
390            Err(_) => {
391                neighbour.track_request(&roundtrip, false);
392                neighbour.punish(PunishReason::Slow);
393                Err(Error::Timeout)
394            }
395        }
396    }
397
398    async fn query_impl(
399        &self,
400        neighbour: Neighbour,
401        req: Request,
402    ) -> Result<QueryResponse<Bytes>, Error> {
403        let started_at = Instant::now();
404
405        let res = tokio::time::timeout(
406            self.config.neighbors.query_timeout,
407            self.overlay.query(&self.network, neighbour.peer_id(), req),
408        )
409        .await;
410
411        let roundtrip = started_at.elapsed();
412
413        match res {
414            Ok(Ok(response)) => Ok(QueryResponse {
415                data: response.body,
416                roundtrip_ms: roundtrip.as_millis() as u64,
417                neighbour,
418            }),
419            Ok(Err(e)) => {
420                neighbour.track_request(&roundtrip, false);
421                apply_network_error(&e, &neighbour);
422                Err(Error::NetworkError(e))
423            }
424            Err(_) => {
425                neighbour.track_request(&roundtrip, false);
426                neighbour.punish(PunishReason::Slow);
427                Err(Error::Timeout)
428            }
429        }
430    }
431}
432
433impl Drop for Inner {
434    fn drop(&mut self) {
435        if let Some(handle) = self.ping_task.take() {
436            handle.abort();
437        }
438
439        if let Some(handle) = self.update_task.take() {
440            handle.abort();
441        }
442
443        if let Some(handle) = self.cleanup_task.take() {
444            handle.abort();
445        }
446    }
447}
448
/// A query answer together with the neighbour that produced it.
///
/// The request is counted in the neighbour's statistics only once the response
/// is accepted or rejected.
pub struct QueryResponse<A> {
    // Response payload (raw bytes before parsing, typed data after).
    data: A,
    // Neighbour that answered the query.
    neighbour: Neighbour,
    // Measured duration of the query, in milliseconds.
    roundtrip_ms: u64,
}
454
455impl<A> QueryResponse<A> {
456    pub fn data(&self) -> &A {
457        &self.data
458    }
459
460    pub fn split(self) -> (QueryResponseHandle, A) {
461        let handle = QueryResponseHandle::with_roundtrip_ms(self.neighbour, self.roundtrip_ms);
462        (handle, self.data)
463    }
464
465    pub fn accept(self) -> (Neighbour, A) {
466        self.track_request(true);
467        (self.neighbour, self.data)
468    }
469
470    pub fn reject(self) -> (Neighbour, A) {
471        self.track_request(false);
472        (self.neighbour, self.data)
473    }
474
475    fn track_request(&self, success: bool) {
476        self.neighbour
477            .track_request(&Duration::from_millis(self.roundtrip_ms), success);
478    }
479}
480
481impl QueryResponse<Bytes> {
482    pub fn parse<A>(self) -> Result<QueryResponse<A>, Error>
483    where
484        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
485    {
486        let response = match tl_proto::deserialize::<overlay::Response<A>>(&self.data) {
487            Ok(r) => r,
488            Err(e) => {
489                self.reject();
490                return Err(Error::InvalidResponse(e));
491            }
492        };
493
494        match response {
495            overlay::Response::Ok(data) => Ok(QueryResponse {
496                data,
497                roundtrip_ms: self.roundtrip_ms,
498                neighbour: self.neighbour,
499            }),
500            overlay::Response::Err(code) => {
501                self.reject();
502                Err(Error::RequestFailed(code))
503            }
504        }
505    }
506}
507
/// Handle for accepting or rejecting a response after its payload has been
/// split off.
pub struct QueryResponseHandle {
    // Neighbour that answered the query.
    neighbour: Neighbour,
    // Roundtrip reported to the neighbour on accept/reject, in milliseconds.
    roundtrip_ms: u64,
}
512
513impl QueryResponseHandle {
514    pub fn with_roundtrip_ms(neighbour: Neighbour, roundtrip_ms: u64) -> Self {
515        Self {
516            neighbour,
517            roundtrip_ms,
518        }
519    }
520
521    pub fn accept(self) -> Neighbour {
522        self.track_request(true);
523        self.neighbour
524    }
525
526    pub fn reject(self) -> Neighbour {
527        self.track_request(false);
528        self.neighbour
529    }
530
531    fn track_request(&self, success: bool) {
532        self.neighbour
533            .track_request(&Duration::from_millis(self.roundtrip_ms), success);
534    }
535}
536
537fn apply_network_error(error: &anyhow::Error, neighbour: &Neighbour) {
538    // NOTE: `(*error)` is a non-recurisve downcast
539    let Some(error) = (*error).downcast_ref() else {
540        if let Some(UnknownPeerError { .. }) = (*error).downcast_ref() {
541            neighbour.punish(PunishReason::Malicious);
542        }
543
544        // TODO: Handle other errors as well
545        return;
546    };
547
548    match error {
549        ConnectionError::InvalidAddress | ConnectionError::InvalidCertificate => {
550            neighbour.punish(PunishReason::Malicious);
551        }
552        ConnectionError::Timeout => {
553            neighbour.punish(PunishReason::Slow);
554        }
555        _ => {}
556    }
557}