tycho_core/blockchain_rpc/
client.rs

1use std::io::Write;
2use std::num::{NonZeroU32, NonZeroU64, NonZeroUsize};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use anyhow::Result;
7use bytes::Bytes;
8use bytesize::ByteSize;
9use futures_util::stream::{FuturesUnordered, StreamExt};
10use parking_lot::Mutex;
11use scopeguard::ScopeGuard;
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tycho_block_util::archive::ArchiveVerifier;
15use tycho_network::{PublicOverlay, Request};
16use tycho_types::models::BlockId;
17use tycho_util::compression::ZstdDecompressStream;
18use tycho_util::futures::JoinTask;
19use tycho_util::serde_helpers;
20
21use crate::overlay_client::{
22    Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
23};
24use crate::proto::blockchain::*;
25use crate::proto::overlay::BroadcastPrefix;
26use crate::storage::PersistentStateKind;
27use crate::util::downloader::{DownloaderError, DownloaderResponseHandle, download_and_decompress};
28
/// A listener for self-broadcasted messages.
///
/// Invoked with the raw message bytes before the message is sent to peers
/// (see `BlockchainRpcClient::broadcast_external_message`).
///
/// NOTE: `async_trait` is used to add object safety to the trait.
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    /// Handles a message that this node is about to broadcast to the overlay.
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
///
/// Missing fields fall back to their [`Default`] values (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Timeout to broadcast external messages
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Minimum number of neighbours with `TooNew`
    /// response required to switch provider
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Number of retries to download blocks/archives
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
impl Default for BlockchainRpcClientConfig {
    fn default() -> Self {
        Self {
            // Values must stay in sync with the defaults documented
            // on the struct fields above.
            min_broadcast_timeout: Duration::from_millis(100),
            too_new_archive_threshold: 4,
            download_retries: 10,
        }
    }
}
68
/// Builder for [`BlockchainRpcClient`].
///
/// `MandatoryFields` is a type-state marker: it is `()` until a
/// [`PublicOverlayClient`] is supplied, after which `build` becomes available.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    // Optional hook invoked with every self-broadcasted message.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
impl BlockchainRpcClientBuilder<PublicOverlayClient> {
    /// Finalizes the builder into a [`BlockchainRpcClient`].
    ///
    /// Only available once a [`PublicOverlayClient`] has been supplied via
    /// `with_public_overlay_client`.
    pub fn build(self) -> BlockchainRpcClient {
        BlockchainRpcClient {
            inner: Arc::new(Inner {
                config: self.config,
                overlay_client: self.mandatory_fields,
                broadcast_listener: self.broadcast_listener,
                // Tracks broadcast response times to drive the adaptive
                // timeout in `compute_broadcast_timeout`.
                response_tracker: Mutex::new(
                    // 5 windows, 60 seconds each, 0.75 quantile
                    tycho_util::time::RollingP2Estimator::new_with_config(
                        0.75, // should be enough to filter most of the outliers
                        Duration::from_secs(60),
                        5,
                        tycho_util::time::RealClock,
                    )
                    .expect("correct quantile"),
                ),
            }),
        }
    }
}
96
impl BlockchainRpcClientBuilder<()> {
    /// Supplies the mandatory [`PublicOverlayClient`], enabling `build`.
    pub fn with_public_overlay_client(
        self,
        client: PublicOverlayClient,
    ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
        BlockchainRpcClientBuilder {
            config: self.config,
            mandatory_fields: client,
            broadcast_listener: self.broadcast_listener,
        }
    }
}
109
impl<T> BlockchainRpcClientBuilder<T> {
    /// Registers a listener that receives every message broadcast by this
    /// node (replaces any previously registered listener).
    pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
        self.broadcast_listener = Some(Box::new(listener));
        self
    }
}
116
impl<T> BlockchainRpcClientBuilder<T> {
    /// Overrides the client configuration (defaults are used otherwise).
    pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
        Self { config, ..self }
    }
}
122
/// Client for the blockchain RPC public overlay.
///
/// Cheap to clone; all state lives behind an [`Arc`].
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
    /// Creates an empty builder; a [`PublicOverlayClient`] must be supplied
    /// before `build` becomes available.
    pub fn builder() -> BlockchainRpcClientBuilder<()> {
        BlockchainRpcClientBuilder {
            config: Default::default(),
            mandatory_fields: (),
            broadcast_listener: None,
        }
    }
137
    /// Returns the underlying public overlay.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }
141
    /// Returns the underlying public overlay client.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
    // TODO: Add rate limiting
    /// Broadcasts a message to the current targets list and
    /// returns the number of peers the message was delivered to.
    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
        // Local TL wrapper that prepends the broadcast prefix id to the
        // serialized `MessageBroadcastRef` without an intermediate buffer.
        struct ExternalMessage<'a> {
            data: &'a [u8],
        }

        impl tl_proto::TlWrite for ExternalMessage<'_> {
            type Repr = tl_proto::Boxed;

            fn max_size_hint(&self) -> usize {
                // 4 extra bytes for the `BroadcastPrefix` id written below.
                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
            }

            fn write_to<P>(&self, packet: &mut P)
            where
                P: tl_proto::TlPacket,
            {
                packet.write_u32(BroadcastPrefix::TL_ID);
                MessageBroadcastRef { data: self.data }.write_to(packet);
            }
        }

        // Broadcast to yourself
        if let Some(l) = &self.inner.broadcast_listener {
            l.handle_message(Bytes::copy_from_slice(message)).await;
        }

        let client = &self.inner.overlay_client;

        let mut delivered_to = 0;

        let targets = client.get_broadcast_targets();
        let request = Request::from_tl(ExternalMessage { data: message });
        let mut futures = FuturesUnordered::new();

        // we wait for all the responses to come back but cap them at `broadcast_timeout_upper_bound`
        // all peers timeouts are calculated based on p90 of the previous responses time weighted average
        // NOTE(review): the estimator is actually configured with the 0.75
        // quantile (see the builder), not p90 — confirm which is intended.
        // This will make broadcast timeout to be adaptive based on the network conditions
        for validator in targets.as_ref() {
            let client = client.clone();
            let validator = validator.clone();
            let request = request.clone();
            let this = self.inner.clone();

            futures.push(JoinTask::new(async move {
                let start = Instant::now();
                let res = client.send_to_validator(validator, request).await;
                // Record the round-trip time (ms) to feed the adaptive
                // timeout computed in `compute_broadcast_timeout`.
                this.response_tracker
                    .lock()
                    .append(start.elapsed().as_millis() as i64);
                res
            }));
        }

        let timeout = self.compute_broadcast_timeout();
        tokio::time::timeout(timeout, async {
            // inner task timeout won't happen because outer task timeout is always <= inner task timeout
            while let Some(res) = futures.next().await {
                if let Err(e) = res {
                    tracing::warn!("failed to broadcast external message: {e}");
                } else {
                    delivered_to += 1;
                }
            }
        })
        .await
        .ok();

        if delivered_to == 0 {
            tracing::debug!("message was not delivered to any peer");
        }

        delivered_to
    }
222
    /// Computes the adaptive timeout for a single broadcast round.
    ///
    /// Uses the rolling weighted average of previous response times, clamped
    /// between `min_broadcast_timeout` and the upper bound (the larger of the
    /// validator send timeout and `min_broadcast_timeout`). Falls back to the
    /// upper bound when no samples have been recorded yet.
    fn compute_broadcast_timeout(&self) -> Duration {
        // Upper bound; taking the max guarantees it is never below the
        // configured minimum (so `clamp` below cannot panic).
        let max_broadcast_timeout = std::cmp::max(
            self.inner.overlay_client.config().validators.send_timeout,
            self.inner.config.min_broadcast_timeout,
        );

        if let Some(prev_time) = self
            .inner
            .response_tracker
            .lock()
            .exponentially_weighted_average()
            .map(|x| Duration::from_millis(x as _))
        {
            metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
            let value = prev_time.clamp(
                self.inner.config.min_broadcast_timeout,
                max_broadcast_timeout,
            );
            metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
            value
        } else {
            // No response samples yet: use the most generous timeout.
            max_broadcast_timeout
        }
    }
247
    /// Requests ids of key blocks following `block` (at most `max_size`)
    /// via the overlay client.
    pub async fn get_next_key_block_ids(
        &self,
        block: &BlockId,
        max_size: u32,
    ) -> Result<QueryResponse<KeyBlockIds>, Error> {
        let client = &self.inner.overlay_client;
        let data = client
            .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
                block_id: *block,
                max_size,
            })
            .await?;
        Ok(data)
    }
262
    /// Downloads a full block (data + proof + queue diff) from a randomly
    /// chosen neighbour.
    ///
    /// Returns the neighbour that served the request along with the data;
    /// `data` is `None` when the neighbour does not have the block. How the
    /// neighbour is treated in that case depends on `requirement`
    /// (see [`DataRequirement`]).
    #[tracing::instrument(skip_all, fields(
        block_id = %block.as_short_id(),
        requirement = ?requirement,
    ))]
    pub async fn get_block_full(
        &self,
        block: &BlockId,
        requirement: DataRequirement,
    ) -> Result<BlockDataFullWithNeighbour, Error> {
        let overlay_client = self.inner.overlay_client.clone();

        let Some(neighbour) = overlay_client.neighbours().choose() else {
            return Err(Error::NoNeighbours);
        };

        let retries = self.inner.config.download_retries;

        download_block_inner(
            Request::from_tl(rpc::GetBlockFull { block_id: *block }),
            overlay_client,
            neighbour,
            requirement,
            retries,
        )
        .await
    }
289
    /// Downloads the block that follows `prev_block` from a randomly chosen
    /// neighbour.
    ///
    /// Same semantics as [`BlockchainRpcClient::get_block_full`], but the
    /// block is addressed by its predecessor id.
    pub async fn get_next_block_full(
        &self,
        prev_block: &BlockId,
        requirement: DataRequirement,
    ) -> Result<BlockDataFullWithNeighbour, Error> {
        let overlay_client = self.inner.overlay_client.clone();

        let Some(neighbour) = overlay_client.neighbours().choose() else {
            return Err(Error::NoNeighbours);
        };

        let retries = self.inner.config.download_retries;

        download_block_inner(
            Request::from_tl(rpc::GetNextBlockFull {
                prev_block_id: *prev_block,
            }),
            overlay_client,
            neighbour,
            requirement,
            retries,
        )
        .await
    }
314
    /// Requests the proof for the given key block via the overlay client.
    pub async fn get_key_block_proof(
        &self,
        block_id: &BlockId,
    ) -> Result<QueryResponse<KeyBlockProof>, Error> {
        let client = &self.inner.overlay_client;
        let data = client
            .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
                block_id: *block_id,
            })
            .await?;
        Ok(data)
    }
327
    /// Requests info about a persistent shard state for the given block
    /// via the overlay client.
    pub async fn get_persistent_state_info(
        &self,
        block_id: &BlockId,
    ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
        let client = &self.inner.overlay_client;
        let data = client
            .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            })
            .await?;
        Ok(data)
    }
340
    /// Requests a single chunk of a persistent shard state, starting at
    /// `offset`, from a specific neighbour.
    pub async fn get_persistent_state_part(
        &self,
        neighbour: &Neighbour,
        block_id: &BlockId,
        offset: u64,
    ) -> Result<QueryResponse<Data>, Error> {
        let client = &self.inner.overlay_client;
        let data = client
            .query_raw::<Data>(
                neighbour.clone(),
                Request::from_tl(rpc::GetPersistentShardStateChunk {
                    block_id: *block_id,
                    offset,
                }),
            )
            .await?;
        Ok(data)
    }
359
    /// Searches reliable neighbours for a persistent state of the given kind.
    ///
    /// Queries up to `NEIGHBOUR_COUNT` neighbours in parallel and returns a
    /// [`PendingPersistentState`] for the first one that reports the state as
    /// found; remaining in-flight queries are dropped.
    ///
    /// # Errors
    /// Returns [`Error::NotFound`] if no queried neighbour has the state, or
    /// the last query error if queries failed.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        // Get reliable neighbours with higher weight
        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        // Shard and queue states are requested with different RPCs.
        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        // Only the most recent error is kept for the failure case.
        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
428
    /// Downloads a previously located persistent state into `output`,
    /// decompressing on the fly.
    ///
    /// Chunks are fetched in parallel (up to [`PARALLEL_REQUESTS`]) from the
    /// neighbour recorded in `state`, each chunk retried up to
    /// `download_retries` times. The writer is flushed before being returned.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_and_decompress(
            state.size,
            state.chunk_size,
            PARALLEL_REQUESTS,
            output,
            // Chunk provider: issues the kind-specific RPC for each offset.
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Sink: write each decompressed chunk to the output.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: flush buffered data before handing the writer back.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
486
487    pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
488        const NEIGHBOUR_COUNT: usize = 10;
489
490        // Get reliable neighbours with higher weight
491        let neighbours = self
492            .overlay_client()
493            .neighbours()
494            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
495
496        // Find a neighbour which has the requested archive
497        let pending_archive = 'info: {
498            let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
499
500            // Number of ArchiveInfo::TooNew responses
501            let mut new_archive_count = 0usize;
502
503            let mut futures = FuturesUnordered::new();
504            for neighbour in neighbours {
505                futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
506            }
507
508            let mut err = None;
509            while let Some(info) = futures.next().await {
510                let (handle, info) = match info {
511                    Ok(res) => res.split(),
512                    Err(e) => {
513                        err = Some(e);
514                        continue;
515                    }
516                };
517
518                match info {
519                    ArchiveInfo::Found {
520                        id,
521                        size,
522                        chunk_size,
523                    } => {
524                        break 'info PendingArchive {
525                            id,
526                            size,
527                            chunk_size,
528                            neighbour: handle.accept(),
529                        };
530                    }
531                    ArchiveInfo::TooNew => {
532                        new_archive_count += 1;
533
534                        handle.accept();
535                    }
536                    ArchiveInfo::NotFound => {
537                        handle.accept();
538                    }
539                }
540            }
541
542            // Stop using archives if enough neighbors responded TooNew
543            if new_archive_count >= self.inner.config.too_new_archive_threshold {
544                return Ok(PendingArchiveResponse::TooNew);
545            }
546
547            return match err {
548                None => Err(Error::NotFound),
549                Some(err) => Err(err),
550            };
551        };
552
553        tracing::info!(
554            peer_id = %pending_archive.neighbour.peer_id(),
555            archive_id = pending_archive.id,
556            archive_size = %ByteSize(pending_archive.size.get()),
557            archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
558            "found archive",
559        );
560        Ok(PendingArchiveResponse::Found(pending_archive))
561    }
562
    /// Downloads a previously located archive into `output`.
    ///
    /// Chunks are fetched in parallel (up to [`PARALLEL_REQUESTS`]) from the
    /// neighbour recorded in `archive`, verified with [`ArchiveVerifier`] and
    /// decompressed on the fly; the writer is flushed after the verifier's
    /// final check succeeds.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_and_decompress(
            archive.size,
            archive.chunk_size,
            PARALLEL_REQUESTS,
            (output, ArchiveVerifier::default()),
            // Chunk provider: fetch one chunk, logging how long it took.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Sink: verify each decompressed chunk before writing it out.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: run the verifier's final check, then flush.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
624}
625
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    // Optional hook invoked with every self-broadcasted message.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Rolling quantile estimator of broadcast response times (milliseconds);
    // drives the adaptive timeout in `compute_broadcast_timeout`.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
632
/// Result of [`BlockchainRpcClient::find_archive`].
pub enum PendingArchiveResponse {
    /// A neighbour with the requested archive was found.
    Found(PendingArchive),
    /// Enough neighbours reported the archive as too new;
    /// the caller should stop using archives.
    TooNew,
}
637
/// A located archive that is ready to be downloaded.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id as reported by the neighbour.
    pub id: u64,
    /// Total (compressed) size of the archive in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that stores this archive.
    pub neighbour: Neighbour,
}
645
/// A located persistent state that is ready to be downloaded.
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Block this state belongs to.
    pub block_id: BlockId,
    /// Whether this is a shard state or a queue state.
    pub kind: PersistentStateKind,
    /// Total (compressed) size of the state in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that stores this state.
    pub neighbour: Neighbour,
}
654
/// A fully downloaded block: data, proof and queue diff.
pub struct BlockDataFull {
    pub block_id: BlockId,
    /// Decompressed block data.
    pub block_data: Bytes,
    /// Block proof as received from the neighbour.
    pub proof_data: Bytes,
    /// Queue diff as received from the neighbour.
    pub queue_diff_data: Bytes,
}
661
/// Block download result paired with the neighbour that served the request.
pub struct BlockDataFullWithNeighbour {
    /// `None` when the neighbour did not have the requested block.
    pub data: Option<BlockDataFull>,
    pub neighbour: Neighbour,
}
666
/// How strictly a data request treats a "not found" response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Data is not required to be present on the neighbour (mostly for polling).
    ///
    /// NOTE: Node will not be punished if the data is not present.
    Optional,
    /// We assume that the node has the data, but it's not required.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Dumb`] if the data is not present.
    Expected,
    /// Data must be present on the requested neighbour.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Malicious`] if the data is not present.
    Required,
}
682
/// Downloads a full block (data + proof + queue diff) from `neighbour`.
///
/// The first chunk of block data arrives inline with the `BlockFull`
/// response; remaining chunks are fetched in parallel (up to
/// `PARALLEL_REQUESTS` at a time, each with up to `retries` retries) and
/// decompressed on a blocking task as they arrive.
///
/// Returns `data: None` when the neighbour reports the block as not found;
/// in that case the neighbour is accepted/rejected/punished according to
/// `requirement`.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // Block not found: decide the neighbour's fate based on how strongly
        // the caller expected the data to be present.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    // NOTE(review): shadows the file-level `PARALLEL_REQUESTS: NonZeroUsize`
    // with a plain `usize` of the same value — consider unifying.
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The inline (first) chunk must fit within both the declared total size
    // and the declared chunk size.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // Decompression runs on a blocking task so the async runtime isn't stalled.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        // Buffer for decompressed data
        let mut decompressed = Vec::new();

        // Decompress chunk
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        // Receive and process chunks
        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // The response is rejected if any check below bails early;
            // the guard is defused only after the chunk is fully processed.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            // Decompress chunk
            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Download the remaining chunks (the first came inline) in parallel;
    // `buffered` preserves offset order for the decompressor.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        if chunks_tx.send(chunk).await.is_err() {
            // The processing task bailed out early; stop downloading.
            break;
        }
    }

    // Close the channel so the processing task's receive loop can finish.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
818
819async fn download_with_retries(
820    req: Request,
821    overlay_client: PublicOverlayClient,
822    neighbour: Neighbour,
823    max_retries: usize,
824    name: &'static str,
825) -> Result<(QueryResponseHandle, Bytes), Error> {
826    let mut retries = 0;
827    loop {
828        match overlay_client
829            .query_raw::<Data>(neighbour.clone(), req.clone())
830            .await
831        {
832            Ok(r) => {
833                let (h, res) = r.split();
834                return Ok((h, res.data));
835            }
836            Err(e) => {
837                tracing::error!("failed to download {name}: {e:?}");
838                retries += 1;
839                if retries >= max_retries || !neighbour.is_reliable() {
840                    return Err(e);
841                }
842
843                tokio::time::sleep(Duration::from_millis(100)).await;
844            }
845        }
846    }
847}
848
// Lets the generic downloader accept/reject overlay query responses
// via `QueryResponseHandle`.
impl DownloaderResponseHandle for QueryResponseHandle {
    fn accept(self) {
        QueryResponseHandle::accept(self);
    }

    fn reject(self) {
        QueryResponseHandle::reject(self);
    }
}
858
859fn map_downloader_error(e: DownloaderError<Error>) -> Error {
860    match e {
861        DownloaderError::DownloadFailed(e) => e,
862        e => Error::Internal(e.into()),
863    }
864}
865
// TODO: Move into config?
/// Maximum number of concurrent chunk download requests.
const PARALLEL_REQUESTS: NonZeroUsize = NonZeroUsize::new(10).unwrap();
868
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    // Error type for the stubbed chunk provider in the test below.
    #[derive(Debug, thiserror::Error)]
    #[error("stub")]
    struct StubError;

    /// Compresses 1 MB of random bytes, serves it back in fixed-size chunks
    /// and checks that `download_and_decompress` reconstructs the original.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        let mut original_data = vec![0u8; 1 << 20]; // 1 MB of garbage
        rand::rng().fill_bytes(&mut original_data);

        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        let received = download_and_decompress(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            PARALLEL_REQUESTS,
            Vec::new(),
            // Chunk provider: slice the compressed buffer at `offset`.
            |offset| {
                // Offsets must be chunk-aligned and within bounds.
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok::<_, StubError>((handle, chunk)))
            },
            // Sink: accumulate decompressed bytes in a Vec.
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}
921}