tycho_core/overlay_client/
mod.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use anyhow::Result;
5use bytes::Bytes;
6use tokio::task::AbortHandle;
7use tycho_network::{ConnectionError, Network, PublicOverlay, Request, UnknownPeerError};
8
9pub use self::config::{NeighborsConfig, PublicOverlayClientConfig, ValidatorsConfig};
10pub use self::neighbour::{Neighbour, NeighbourStats, PunishReason};
11pub use self::neighbours::{NeighbourType, Neighbours};
12pub use self::validators::{Validator, ValidatorSetPeers, ValidatorsResolver};
13use crate::proto::overlay;
14
15mod config;
16mod neighbour;
17mod neighbours;
18mod validators;
19
20#[derive(Clone)]
21#[repr(transparent)]
22pub struct PublicOverlayClient {
23    inner: Arc<Inner>,
24}
25
26impl PublicOverlayClient {
27    pub fn new(
28        network: Network,
29        overlay: PublicOverlay,
30        config: PublicOverlayClientConfig,
31    ) -> Self {
32        let ttl = overlay.entry_ttl_sec();
33
34        let neighbors_config = &config.neighbors;
35
36        let entries = overlay
37            .read_entries()
38            .choose_multiple(&mut rand::rng(), neighbors_config.keep)
39            .map(|entry_data| {
40                Neighbour::new(
41                    entry_data.entry.peer_id,
42                    entry_data.expires_at(ttl),
43                    &neighbors_config.default_roundtrip,
44                )
45            })
46            .collect::<Vec<_>>();
47
48        let neighbours = Neighbours::new(entries, config.neighbors.keep);
49        let validators_resolver =
50            ValidatorsResolver::new(network.clone(), overlay.clone(), config.validators.clone());
51
52        let mut res = Inner {
53            network,
54            overlay,
55            neighbours,
56            config,
57            validators_resolver,
58            ping_task: None,
59            update_task: None,
60            score_task: None,
61            cleanup_task: None,
62        };
63
64        // NOTE: Reuse same `Inner` type to avoid introducing a new type for shard state
65        // NOTE: Clone does not clone the tasks
66        res.ping_task = Some(tokio::spawn(res.clone().ping_neighbours_task()).abort_handle());
67        res.update_task = Some(tokio::spawn(res.clone().update_neighbours_task()).abort_handle());
68        res.score_task = Some(tokio::spawn(res.clone().apply_score_task()).abort_handle());
69        res.cleanup_task = Some(tokio::spawn(res.clone().cleanup_neighbours_task()).abort_handle());
70
71        Self {
72            inner: Arc::new(res),
73        }
74    }
75
76    pub fn config(&self) -> &PublicOverlayClientConfig {
77        &self.inner.config
78    }
79
80    pub fn neighbours(&self) -> &Neighbours {
81        &self.inner.neighbours
82    }
83
84    pub fn update_validator_set<T: ValidatorSetPeers>(&self, vset: &T) {
85        self.inner.validators_resolver.update_validator_set(vset);
86    }
87
88    // Returns a small random subset of possibly alive validators.
89    pub fn get_broadcast_targets(&self) -> Arc<Vec<Validator>> {
90        self.inner.validators_resolver.get_broadcast_targets()
91    }
92
93    pub fn validators_resolver(&self) -> &ValidatorsResolver {
94        &self.inner.validators_resolver
95    }
96
97    pub fn overlay(&self) -> &PublicOverlay {
98        &self.inner.overlay
99    }
100
101    pub fn network(&self) -> &Network {
102        &self.inner.network
103    }
104
105    pub async fn send<R>(&self, data: R) -> Result<(), Error>
106    where
107        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
108    {
109        self.inner.send(data).await
110    }
111
112    pub async fn send_to_validator(
113        &self,
114        validator: Validator,
115        data: Request,
116    ) -> Result<(), Error> {
117        self.inner.send_to_validator(validator.clone(), data).await
118    }
119
120    #[inline]
121    pub async fn send_raw(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
122        self.inner.send_impl(neighbour, req).await
123    }
124
125    pub async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
126    where
127        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
128        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
129    {
130        self.inner.query(data).await
131    }
132
133    #[inline]
134    pub async fn query_raw<A>(
135        &self,
136        neighbour: Neighbour,
137        req: Request,
138    ) -> Result<QueryResponse<A>, Error>
139    where
140        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
141    {
142        self.inner.query_impl(neighbour, req).await?.parse()
143    }
144}
145
146#[derive(thiserror::Error, Debug)]
147pub enum Error {
148    #[error("no active neighbours found")]
149    NoNeighbours,
150    #[error("no neighbour has the requested data")]
151    NotFound,
152    #[error("network error: {0}")]
153    NetworkError(#[source] anyhow::Error),
154    #[error("invalid response: {0}")]
155    InvalidResponse(#[source] tl_proto::TlError),
156    #[error("request failed with code: {0}")]
157    RequestFailed(u32),
158    #[error("internal error: {0}")]
159    Internal(#[source] anyhow::Error),
160    #[error("timeout")]
161    Timeout,
162}
163
164struct Inner {
165    network: Network,
166    overlay: PublicOverlay,
167    neighbours: Neighbours,
168    config: PublicOverlayClientConfig,
169
170    validators_resolver: ValidatorsResolver,
171
172    ping_task: Option<AbortHandle>,
173    update_task: Option<AbortHandle>,
174    score_task: Option<AbortHandle>,
175    cleanup_task: Option<AbortHandle>,
176}
177
178impl Clone for Inner {
179    fn clone(&self) -> Self {
180        Self {
181            network: self.network.clone(),
182            overlay: self.overlay.clone(),
183            neighbours: self.neighbours.clone(),
184            config: self.config.clone(),
185            validators_resolver: self.validators_resolver.clone(),
186            ping_task: None,
187            update_task: None,
188            score_task: None,
189            cleanup_task: None,
190        }
191    }
192}
193
194impl Inner {
195    #[tracing::instrument(name = "ping_neighbours", skip_all)]
196    async fn ping_neighbours_task(self) {
197        tracing::info!("started");
198        scopeguard::defer! { tracing::info!("finished"); };
199
200        let req = Request::from_tl(overlay::Ping);
201
202        // Start pinging neighbours
203        let mut interval = tokio::time::interval(self.config.neighbors.ping_interval);
204        loop {
205            interval.tick().await;
206
207            let Some(neighbour) = self.neighbours.choose() else {
208                continue;
209            };
210
211            let peer_id = *neighbour.peer_id();
212            match self.query_impl(neighbour.clone(), req.clone()).await {
213                Ok(res) => match tl_proto::deserialize::<overlay::Pong>(&res.data) {
214                    Ok(_) => {
215                        res.accept();
216                        tracing::debug!(%peer_id, "pinged neighbour");
217                    }
218                    Err(e) => {
219                        tracing::warn!(
220                            %peer_id,
221                            "received an invalid ping response: {e}",
222                        );
223                        res.reject();
224                    }
225                },
226                Err(e) => {
227                    tracing::warn!(
228                        %peer_id,
229                        "failed to ping neighbour: {e}",
230                    );
231                }
232            }
233        }
234    }
235
236    #[tracing::instrument(name = "update_neighbours", skip_all)]
237    async fn update_neighbours_task(self) {
238        tracing::info!("started");
239        scopeguard::defer! { tracing::info!("finished"); };
240
241        let ttl = self.overlay.entry_ttl_sec();
242        let max_neighbours = self.config.neighbors.keep;
243        let default_roundtrip = self.config.neighbors.default_roundtrip;
244
245        let mut overlay_peers_added = self.overlay.entires_added().notified();
246        let mut overlay_peer_count = self.overlay.read_entries().len();
247
248        let mut interval = tokio::time::interval(self.config.neighbors.update_interval);
249
250        loop {
251            if overlay_peer_count < max_neighbours {
252                tracing::info!("not enough neighbours, waiting for more");
253
254                overlay_peers_added.await;
255                overlay_peers_added = self.overlay.entires_added().notified();
256
257                overlay_peer_count = self.overlay.read_entries().len();
258            } else {
259                interval.tick().await;
260            }
261
262            let active_neighbours = self.neighbours.get_active_neighbours().len();
263            let neighbours_to_get = max_neighbours + (max_neighbours - active_neighbours);
264
265            let neighbours = {
266                self.overlay
267                    .read_entries()
268                    .choose_multiple(&mut rand::rng(), neighbours_to_get)
269                    .map(|x| Neighbour::new(x.entry.peer_id, x.expires_at(ttl), &default_roundtrip))
270                    .collect::<Vec<_>>()
271            };
272            self.neighbours.update(neighbours);
273        }
274    }
275
276    #[tracing::instrument(name = "apply_score", skip_all)]
277    async fn apply_score_task(self) {
278        tracing::info!("started");
279        scopeguard::defer! { tracing::info!("finished"); };
280
281        let mut interval = tokio::time::interval(self.config.neighbors.apply_score_interval);
282
283        loop {
284            interval.tick().await;
285
286            let now = tycho_util::time::now_sec();
287            let applied = self.neighbours.try_apply_score(now);
288            tracing::debug!(now, applied, "tried to apply neighbours score");
289        }
290    }
291
292    #[tracing::instrument(name = "cleanup_neighbours", skip_all)]
293    async fn cleanup_neighbours_task(self) {
294        tracing::info!("started");
295        scopeguard::defer! { tracing::info!("finished"); };
296
297        loop {
298            self.overlay.entries_removed().notified().await;
299
300            let now = tycho_util::time::now_sec();
301            let applied = self.neighbours.try_apply_score(now);
302            tracing::debug!(
303                now,
304                applied,
305                "tried to apply neighbours score after some overlay entry was removed"
306            );
307        }
308    }
309
310    async fn send<R>(&self, data: R) -> Result<(), Error>
311    where
312        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
313    {
314        let Some(neighbour) = self.neighbours.choose() else {
315            return Err(Error::NoNeighbours);
316        };
317
318        self.send_impl(neighbour, Request::from_tl(data)).await
319    }
320
321    async fn send_to_validator(&self, validator: Validator, data: Request) -> Result<(), Error> {
322        let res = self
323            .overlay
324            .send(&self.network, &validator.peer_id(), data)
325            .await;
326        res.map_err(Error::NetworkError)
327    }
328
329    async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
330    where
331        R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
332        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
333    {
334        let Some(neighbour) = self.neighbours.choose() else {
335            return Err(Error::NoNeighbours);
336        };
337
338        self.query_impl(neighbour, Request::from_tl(data))
339            .await?
340            .parse()
341    }
342
343    async fn send_impl(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
344        let started_at = Instant::now();
345
346        let res = tokio::time::timeout(
347            self.config.neighbors.send_timeout,
348            self.overlay.send(&self.network, neighbour.peer_id(), req),
349        )
350        .await;
351
352        let roundtrip = started_at.elapsed() * 2; // Multiply by 2 to estimate the roundtrip time
353
354        match res {
355            Ok(response) => {
356                neighbour.track_request(&roundtrip, response.is_ok());
357
358                if let Err(e) = &response {
359                    apply_network_error(e, &neighbour);
360                }
361
362                response.map_err(Error::NetworkError)
363            }
364            Err(_) => {
365                neighbour.track_request(&roundtrip, false);
366                neighbour.punish(PunishReason::Slow);
367                Err(Error::Timeout)
368            }
369        }
370    }
371
372    async fn query_impl(
373        &self,
374        neighbour: Neighbour,
375        req: Request,
376    ) -> Result<QueryResponse<Bytes>, Error> {
377        let started_at = Instant::now();
378
379        let res = tokio::time::timeout(
380            self.config.neighbors.query_timeout,
381            self.overlay.query(&self.network, neighbour.peer_id(), req),
382        )
383        .await;
384
385        let roundtrip = started_at.elapsed();
386
387        match res {
388            Ok(Ok(response)) => Ok(QueryResponse {
389                data: response.body,
390                roundtrip_ms: roundtrip.as_millis() as u64,
391                neighbour,
392            }),
393            Ok(Err(e)) => {
394                neighbour.track_request(&roundtrip, false);
395                apply_network_error(&e, &neighbour);
396                Err(Error::NetworkError(e))
397            }
398            Err(_) => {
399                neighbour.track_request(&roundtrip, false);
400                neighbour.punish(PunishReason::Slow);
401                Err(Error::Timeout)
402            }
403        }
404    }
405}
406
407impl Drop for Inner {
408    fn drop(&mut self) {
409        if let Some(handle) = self.ping_task.take() {
410            handle.abort();
411        }
412
413        if let Some(handle) = self.update_task.take() {
414            handle.abort();
415        }
416
417        if let Some(handle) = self.cleanup_task.take() {
418            handle.abort();
419        }
420    }
421}
422
423pub struct QueryResponse<A> {
424    data: A,
425    neighbour: Neighbour,
426    roundtrip_ms: u64,
427}
428
429impl<A> QueryResponse<A> {
430    pub fn data(&self) -> &A {
431        &self.data
432    }
433
434    pub fn split(self) -> (QueryResponseHandle, A) {
435        let handle = QueryResponseHandle::with_roundtrip_ms(self.neighbour, self.roundtrip_ms);
436        (handle, self.data)
437    }
438
439    pub fn accept(self) -> (Neighbour, A) {
440        self.track_request(true);
441        (self.neighbour, self.data)
442    }
443
444    pub fn reject(self) -> (Neighbour, A) {
445        self.track_request(false);
446        (self.neighbour, self.data)
447    }
448
449    fn track_request(&self, success: bool) {
450        self.neighbour
451            .track_request(&Duration::from_millis(self.roundtrip_ms), success);
452    }
453}
454
455impl QueryResponse<Bytes> {
456    pub fn parse<A>(self) -> Result<QueryResponse<A>, Error>
457    where
458        for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
459    {
460        let response = match tl_proto::deserialize::<overlay::Response<A>>(&self.data) {
461            Ok(r) => r,
462            Err(e) => {
463                self.reject();
464                return Err(Error::InvalidResponse(e));
465            }
466        };
467
468        match response {
469            overlay::Response::Ok(data) => Ok(QueryResponse {
470                data,
471                roundtrip_ms: self.roundtrip_ms,
472                neighbour: self.neighbour,
473            }),
474            overlay::Response::Err(code) => {
475                self.reject();
476                Err(Error::RequestFailed(code))
477            }
478        }
479    }
480}
481
482pub struct QueryResponseHandle {
483    neighbour: Neighbour,
484    roundtrip_ms: u64,
485}
486
487impl QueryResponseHandle {
488    pub fn with_roundtrip_ms(neighbour: Neighbour, roundtrip_ms: u64) -> Self {
489        Self {
490            neighbour,
491            roundtrip_ms,
492        }
493    }
494
495    pub fn accept(self) -> Neighbour {
496        self.track_request(true);
497        self.neighbour
498    }
499
500    pub fn reject(self) -> Neighbour {
501        self.track_request(false);
502        self.neighbour
503    }
504
505    fn track_request(&self, success: bool) {
506        self.neighbour
507            .track_request(&Duration::from_millis(self.roundtrip_ms), success);
508    }
509}
510
511fn apply_network_error(error: &anyhow::Error, neighbour: &Neighbour) {
512    // NOTE: `(*error)` is a non-recurisve downcast
513    let Some(error) = (*error).downcast_ref() else {
514        if let Some(UnknownPeerError { .. }) = (*error).downcast_ref() {
515            neighbour.punish(PunishReason::Malicious);
516        }
517
518        // TODO: Handle other errors as well
519        return;
520    };
521
522    match error {
523        ConnectionError::InvalidAddress | ConnectionError::InvalidCertificate => {
524            neighbour.punish(PunishReason::Malicious);
525        }
526        ConnectionError::Timeout => {
527            neighbour.punish(PunishReason::Slow);
528        }
529        _ => {}
530    }
531}