1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use anyhow::Result;
5use bytes::Bytes;
6use rand::prelude::IndexedRandom;
7use tokio::task::AbortHandle;
8use tycho_network::{ConnectionError, Network, PublicOverlay, Request, UnknownPeerError};
9
10pub use self::config::{NeighborsConfig, PublicOverlayClientConfig, ValidatorsConfig};
11pub use self::neighbour::{Neighbour, NeighbourStats, PunishReason};
12pub use self::neighbours::{NeighbourType, Neighbours};
13pub use self::validators::{Validator, ValidatorSetPeers, ValidatorsResolver};
14use crate::proto::overlay;
15
16mod config;
17mod neighbour;
18mod neighbours;
19mod validators;
20
/// Cheaply clonable client for talking to peers of a public overlay.
///
/// All state lives in a shared `Inner` behind an [`Arc`], so every clone of
/// the client refers to the same neighbours, validators resolver and
/// background task handles.
#[derive(Clone)]
#[repr(transparent)]
pub struct PublicOverlayClient {
    inner: Arc<Inner>,
}
26
27impl PublicOverlayClient {
28 pub fn new(
29 network: Network,
30 overlay: PublicOverlay,
31 config: PublicOverlayClientConfig,
32 ) -> Self {
33 let ttl = overlay.entry_ttl_sec();
34
35 let neighbors_config = &config.neighbors;
36
37 let entries = overlay
38 .read_entries()
39 .choose_multiple(&mut rand::rng(), neighbors_config.keep)
40 .map(|entry_data| {
41 Neighbour::new(
42 entry_data.entry.peer_id,
43 entry_data.expires_at(ttl),
44 &neighbors_config.default_roundtrip,
45 )
46 })
47 .collect::<Vec<_>>();
48
49 let neighbours = Neighbours::new(entries, config.neighbors.keep);
50 let validators_resolver =
51 ValidatorsResolver::new(network.clone(), overlay.clone(), config.validators.clone());
52
53 let enable_neighbors_metrics = config.neighbors.enable_metrics;
54
55 let mut res = Inner {
56 network,
57 overlay,
58 neighbours,
59 config,
60 validators_resolver,
61 ping_task: None,
62 update_task: None,
63 score_task: None,
64 cleanup_task: None,
65 metrics_task: None,
66 };
67
68 res.ping_task = Some(tokio::spawn(res.clone().ping_neighbours_task()).abort_handle());
71 res.update_task = Some(tokio::spawn(res.clone().update_neighbours_task()).abort_handle());
72 res.score_task = Some(tokio::spawn(res.clone().apply_score_task()).abort_handle());
73 res.cleanup_task = Some(tokio::spawn(res.clone().cleanup_neighbours_task()).abort_handle());
74 res.metrics_task = if enable_neighbors_metrics {
75 Some(tokio::spawn(res.clone().update_metrics_task()).abort_handle())
76 } else {
77 None
78 };
79
80 Self {
81 inner: Arc::new(res),
82 }
83 }
84
85 pub fn config(&self) -> &PublicOverlayClientConfig {
86 &self.inner.config
87 }
88
89 pub fn neighbours(&self) -> &Neighbours {
90 &self.inner.neighbours
91 }
92
93 pub fn update_validator_set<T: ValidatorSetPeers>(&self, vset: &T) {
94 self.inner.validators_resolver.update_validator_set(vset);
95 }
96
97 pub fn get_broadcast_targets(&self) -> Arc<Vec<Validator>> {
99 self.inner.validators_resolver.get_broadcast_targets()
100 }
101
102 pub fn validators_resolver(&self) -> &ValidatorsResolver {
103 &self.inner.validators_resolver
104 }
105
106 pub fn overlay(&self) -> &PublicOverlay {
107 &self.inner.overlay
108 }
109
110 pub fn network(&self) -> &Network {
111 &self.inner.network
112 }
113
114 pub async fn send<R>(&self, data: R) -> Result<(), Error>
115 where
116 R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
117 {
118 self.inner.send(data).await
119 }
120
121 pub async fn send_to_validator(
122 &self,
123 validator: Validator,
124 data: Request,
125 ) -> Result<(), Error> {
126 self.inner.send_to_validator(validator.clone(), data).await
127 }
128
129 #[inline]
130 pub async fn send_raw(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
131 self.inner.send_impl(neighbour, req).await
132 }
133
134 pub async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
135 where
136 R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
137 for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
138 {
139 self.inner.query(data).await
140 }
141
142 #[inline]
143 pub async fn query_raw<A>(
144 &self,
145 neighbour: Neighbour,
146 req: Request,
147 ) -> Result<QueryResponse<A>, Error>
148 where
149 for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
150 {
151 self.inner.query_impl(neighbour, req).await?.parse()
152 }
153}
154
/// Errors returned by `PublicOverlayClient` requests.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// No neighbour could be chosen as the request target.
    #[error("no active neighbours found")]
    NoNeighbours,
    /// None of the neighbours has the requested data.
    #[error("no neighbour has the requested data")]
    NotFound,
    /// The overlay failed to deliver the request or to receive the answer.
    #[error("network error")]
    NetworkError(#[source] anyhow::Error),
    /// The answer could not be deserialized.
    #[error("invalid response")]
    InvalidResponse(#[source] tl_proto::TlError),
    /// The peer answered with an error code instead of data.
    #[error("request failed with code: {0}")]
    RequestFailed(u32),
    /// An internal error.
    #[error("internal error")]
    Internal(#[source] anyhow::Error),
    /// The request did not complete within the configured timeout.
    #[error("timeout")]
    Timeout,
}
172
/// State shared by all clones of `PublicOverlayClient`.
///
/// The background tasks also run on clones of this type; see the `Clone`
/// implementation for how task handles are treated there.
struct Inner {
    network: Network,
    overlay: PublicOverlay,
    neighbours: Neighbours,
    config: PublicOverlayClientConfig,

    validators_resolver: ValidatorsResolver,

    // Abort handles of the background tasks spawned in
    // `PublicOverlayClient::new`. `metrics_task` stays `None` when neighbour
    // metrics are disabled in the config.
    ping_task: Option<AbortHandle>,
    update_task: Option<AbortHandle>,
    score_task: Option<AbortHandle>,
    cleanup_task: Option<AbortHandle>,
    metrics_task: Option<AbortHandle>,
}
187
188impl Clone for Inner {
189 fn clone(&self) -> Self {
190 Self {
191 network: self.network.clone(),
192 overlay: self.overlay.clone(),
193 neighbours: self.neighbours.clone(),
194 config: self.config.clone(),
195 validators_resolver: self.validators_resolver.clone(),
196 ping_task: None,
197 update_task: None,
198 score_task: None,
199 cleanup_task: None,
200 metrics_task: None,
201 }
202 }
203}
204
205impl Inner {
206 #[tracing::instrument(name = "ping_neighbours", skip_all)]
207 async fn ping_neighbours_task(self) {
208 tracing::info!("started");
209 scopeguard::defer! { tracing::info!("finished"); };
210
211 let req = Request::from_tl(overlay::Ping);
212
213 let mut interval = tokio::time::interval(self.config.neighbors.ping_interval);
215 loop {
216 interval.tick().await;
217
218 let neighbours = self.neighbours.get_active_neighbours();
220 let Some(neighbour) = neighbours.choose(&mut rand::rng()) else {
221 continue;
222 };
223
224 let peer_id = *neighbour.peer_id();
225 match self.query_impl(neighbour.clone(), req.clone()).await {
226 Ok(res) => match tl_proto::deserialize::<overlay::Pong>(&res.data) {
227 Ok(_) => {
228 res.accept();
229 tracing::debug!(%peer_id, "pinged neighbour");
230 }
231 Err(e) => {
232 tracing::warn!(
233 %peer_id,
234 "received an invalid ping response: {e}",
235 );
236 res.reject();
237 }
238 },
239 Err(e) => {
240 tracing::warn!(
241 %peer_id,
242 "failed to ping neighbour: {e}",
243 );
244 }
245 }
246 }
247 }
248
249 #[tracing::instrument(name = "update_neighbours", skip_all)]
250 async fn update_neighbours_task(self) {
251 tracing::info!("started");
252 scopeguard::defer! { tracing::info!("finished"); };
253
254 let ttl = self.overlay.entry_ttl_sec();
255 let max_neighbours = self.config.neighbors.keep;
256 let default_roundtrip = self.config.neighbors.default_roundtrip;
257
258 let mut overlay_peers_added = self.overlay.entires_added().notified();
259 let mut overlay_peer_count = self.overlay.read_entries().len();
260
261 let mut interval = tokio::time::interval(self.config.neighbors.update_interval);
262
263 loop {
264 if overlay_peer_count == 0 {
265 tracing::info!("not enough neighbours, waiting for more");
266
267 overlay_peers_added.await;
268 overlay_peers_added = self.overlay.entires_added().notified();
269
270 overlay_peer_count = self.overlay.read_entries().len();
271 } else {
272 interval.tick().await;
273 }
274
275 let active_neighbours = self.neighbours.get_active_neighbours().len();
276 let neighbours_to_get = max_neighbours + (max_neighbours - active_neighbours);
277
278 let neighbours = {
279 self.overlay
280 .read_entries()
281 .choose_multiple(&mut rand::rng(), neighbours_to_get)
282 .map(|x| Neighbour::new(x.entry.peer_id, x.expires_at(ttl), &default_roundtrip))
283 .collect::<Vec<_>>()
284 };
285 self.neighbours.update(neighbours);
286 }
287 }
288
289 #[tracing::instrument(name = "apply_score", skip_all)]
290 async fn apply_score_task(self) {
291 tracing::info!("started");
292 scopeguard::defer! { tracing::info!("finished"); };
293
294 let mut interval = tokio::time::interval(self.config.neighbors.apply_score_interval);
295
296 loop {
297 interval.tick().await;
298
299 let now = tycho_util::time::now_sec();
300 let applied = self.neighbours.try_apply_score(now);
301 tracing::debug!(now, applied, "tried to apply neighbours score");
302 }
303 }
304
305 #[tracing::instrument(name = "cleanup_neighbours", skip_all)]
306 async fn cleanup_neighbours_task(self) {
307 tracing::info!("started");
308 scopeguard::defer! { tracing::info!("finished"); };
309
310 loop {
311 self.overlay.entries_removed().notified().await;
312
313 let now = tycho_util::time::now_sec();
314 let applied = self.neighbours.try_apply_score(now);
315 tracing::debug!(
316 now,
317 applied,
318 "tried to apply neighbours score after some overlay entry was removed"
319 );
320 }
321 }
322
323 #[tracing::instrument(name = "update_neighbour_metrics", skip_all)]
324 async fn update_metrics_task(self) {
325 tracing::info!("started");
326 scopeguard::defer! { tracing::info!("finished"); };
327
328 let mut interval = tokio::time::interval(self.config.neighbors.update_metrics_interval);
329
330 loop {
331 interval.tick().await;
332 self.neighbours.update_metrics(self.network.peer_id());
333 }
334 }
335
336 async fn send<R>(&self, data: R) -> Result<(), Error>
337 where
338 R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
339 {
340 let Some(neighbour) = self.neighbours.choose() else {
341 return Err(Error::NoNeighbours);
342 };
343
344 self.send_impl(neighbour, Request::from_tl(data)).await
345 }
346
347 async fn send_to_validator(&self, validator: Validator, data: Request) -> Result<(), Error> {
348 let res = self
349 .overlay
350 .send(&self.network, &validator.peer_id(), data)
351 .await;
352 res.map_err(Error::NetworkError)
353 }
354
355 async fn query<R, A>(&self, data: R) -> Result<QueryResponse<A>, Error>
356 where
357 R: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
358 for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
359 {
360 let Some(neighbour) = self.neighbours.choose() else {
361 return Err(Error::NoNeighbours);
362 };
363
364 self.query_impl(neighbour, Request::from_tl(data))
365 .await?
366 .parse()
367 }
368
369 async fn send_impl(&self, neighbour: Neighbour, req: Request) -> Result<(), Error> {
370 let started_at = Instant::now();
371
372 let res = tokio::time::timeout(
373 self.config.neighbors.send_timeout,
374 self.overlay.send(&self.network, neighbour.peer_id(), req),
375 )
376 .await;
377
378 let roundtrip = started_at.elapsed() * 2; match res {
381 Ok(response) => {
382 neighbour.track_request(&roundtrip, response.is_ok());
383
384 if let Err(e) = &response {
385 apply_network_error(e, &neighbour);
386 }
387
388 response.map_err(Error::NetworkError)
389 }
390 Err(_) => {
391 neighbour.track_request(&roundtrip, false);
392 neighbour.punish(PunishReason::Slow);
393 Err(Error::Timeout)
394 }
395 }
396 }
397
398 async fn query_impl(
399 &self,
400 neighbour: Neighbour,
401 req: Request,
402 ) -> Result<QueryResponse<Bytes>, Error> {
403 let started_at = Instant::now();
404
405 let res = tokio::time::timeout(
406 self.config.neighbors.query_timeout,
407 self.overlay.query(&self.network, neighbour.peer_id(), req),
408 )
409 .await;
410
411 let roundtrip = started_at.elapsed();
412
413 match res {
414 Ok(Ok(response)) => Ok(QueryResponse {
415 data: response.body,
416 roundtrip_ms: roundtrip.as_millis() as u64,
417 neighbour,
418 }),
419 Ok(Err(e)) => {
420 neighbour.track_request(&roundtrip, false);
421 apply_network_error(&e, &neighbour);
422 Err(Error::NetworkError(e))
423 }
424 Err(_) => {
425 neighbour.track_request(&roundtrip, false);
426 neighbour.punish(PunishReason::Slow);
427 Err(Error::Timeout)
428 }
429 }
430 }
431}
432
433impl Drop for Inner {
434 fn drop(&mut self) {
435 if let Some(handle) = self.ping_task.take() {
436 handle.abort();
437 }
438
439 if let Some(handle) = self.update_task.take() {
440 handle.abort();
441 }
442
443 if let Some(handle) = self.cleanup_task.take() {
444 handle.abort();
445 }
446 }
447}
448
/// Answer to a query together with the neighbour that produced it.
///
/// The measured roundtrip is recorded in the neighbour's statistics only
/// when the response is accepted or rejected.
pub struct QueryResponse<A> {
    data: A,
    neighbour: Neighbour,
    // Measured query duration in milliseconds.
    roundtrip_ms: u64,
}
454
455impl<A> QueryResponse<A> {
456 pub fn data(&self) -> &A {
457 &self.data
458 }
459
460 pub fn split(self) -> (QueryResponseHandle, A) {
461 let handle = QueryResponseHandle::with_roundtrip_ms(self.neighbour, self.roundtrip_ms);
462 (handle, self.data)
463 }
464
465 pub fn accept(self) -> (Neighbour, A) {
466 self.track_request(true);
467 (self.neighbour, self.data)
468 }
469
470 pub fn reject(self) -> (Neighbour, A) {
471 self.track_request(false);
472 (self.neighbour, self.data)
473 }
474
475 fn track_request(&self, success: bool) {
476 self.neighbour
477 .track_request(&Duration::from_millis(self.roundtrip_ms), success);
478 }
479}
480
481impl QueryResponse<Bytes> {
482 pub fn parse<A>(self) -> Result<QueryResponse<A>, Error>
483 where
484 for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
485 {
486 let response = match tl_proto::deserialize::<overlay::Response<A>>(&self.data) {
487 Ok(r) => r,
488 Err(e) => {
489 self.reject();
490 return Err(Error::InvalidResponse(e));
491 }
492 };
493
494 match response {
495 overlay::Response::Ok(data) => Ok(QueryResponse {
496 data,
497 roundtrip_ms: self.roundtrip_ms,
498 neighbour: self.neighbour,
499 }),
500 overlay::Response::Err(code) => {
501 self.reject();
502 Err(Error::RequestFailed(code))
503 }
504 }
505 }
506}
507
/// Deferred verdict for a query: holds the neighbour and the measured
/// roundtrip until the response is accepted or rejected.
pub struct QueryResponseHandle {
    neighbour: Neighbour,
    // Measured query duration in milliseconds.
    roundtrip_ms: u64,
}
512
513impl QueryResponseHandle {
514 pub fn with_roundtrip_ms(neighbour: Neighbour, roundtrip_ms: u64) -> Self {
515 Self {
516 neighbour,
517 roundtrip_ms,
518 }
519 }
520
521 pub fn accept(self) -> Neighbour {
522 self.track_request(true);
523 self.neighbour
524 }
525
526 pub fn reject(self) -> Neighbour {
527 self.track_request(false);
528 self.neighbour
529 }
530
531 fn track_request(&self, success: bool) {
532 self.neighbour
533 .track_request(&Duration::from_millis(self.roundtrip_ms), success);
534 }
535}
536
537fn apply_network_error(error: &anyhow::Error, neighbour: &Neighbour) {
538 let Some(error) = (*error).downcast_ref() else {
540 if let Some(UnknownPeerError { .. }) = (*error).downcast_ref() {
541 neighbour.punish(PunishReason::Malicious);
542 }
543
544 return;
546 };
547
548 match error {
549 ConnectionError::InvalidAddress | ConnectionError::InvalidCertificate => {
550 neighbour.punish(PunishReason::Malicious);
551 }
552 ConnectionError::Timeout => {
553 neighbour.punish(PunishReason::Slow);
554 }
555 _ => {}
556 }
557}