// tycho_core/blockchain_rpc/client.rs
1use std::io::Write;
2use std::num::{NonZeroU32, NonZeroU64, NonZeroUsize};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use anyhow::Result;
7use bytes::Bytes;
8use bytesize::ByteSize;
9use futures_util::stream::{FuturesUnordered, StreamExt};
10use parking_lot::Mutex;
11use scopeguard::ScopeGuard;
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tycho_block_util::archive::ArchiveVerifier;
15use tycho_network::{PublicOverlay, Request};
16use tycho_types::models::BlockId;
17use tycho_util::compression::ZstdDecompressStream;
18use tycho_util::futures::JoinTask;
19use tycho_util::serde_helpers;
20
21use crate::overlay_client::{
22    Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
23};
24use crate::proto::blockchain::*;
25use crate::proto::overlay::BroadcastPrefix;
26use crate::storage::PersistentStateKind;
27use crate::util::downloader::{DownloaderError, DownloaderResponseHandle, download_and_decompress};
28
/// A listener for self-broadcasted messages.
///
/// When registered via `BlockchainRpcClientBuilder::with_self_broadcast_listener`,
/// it receives a copy of every message passed to
/// `BlockchainRpcClient::broadcast_external_message` before the network broadcast.
///
/// NOTE: `async_trait` is used to add object safety to the trait.
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    /// Handles a single self-broadcasted message payload.
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Timeout to broadcast external messages.
    ///
    /// Also acts as the lower bound when clamping the adaptive
    /// broadcast timeout (see `compute_broadcast_timeout`).
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Minimum number of neighbours with `TooNew`
    /// response required to switch provider
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Number of retries to download blocks/archives
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
59impl Default for BlockchainRpcClientConfig {
60    fn default() -> Self {
61        Self {
62            min_broadcast_timeout: Duration::from_millis(100),
63            too_new_archive_threshold: 4,
64            download_retries: 10,
65        }
66    }
67}
68
/// Builder for [`BlockchainRpcClient`].
///
/// `MandatoryFields` is `()` until a [`PublicOverlayClient`] is supplied via
/// `with_public_overlay_client`, after which `build` becomes available.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    // Optional hook that receives a copy of every self-broadcasted message.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
75impl BlockchainRpcClientBuilder<PublicOverlayClient> {
76    pub fn build(self) -> BlockchainRpcClient {
77        BlockchainRpcClient {
78            inner: Arc::new(Inner {
79                config: self.config,
80                overlay_client: self.mandatory_fields,
81                broadcast_listener: self.broadcast_listener,
82                response_tracker: Mutex::new(
83                    // 5 windows, 60 seconds each, 0.75 quantile
84                    tycho_util::time::RollingP2Estimator::new_with_config(
85                        0.75, // should be enough to filter most of the outliers
86                        Duration::from_secs(60),
87                        5,
88                        tycho_util::time::RealClock,
89                    )
90                    .expect("correct quantile"),
91                ),
92            }),
93        }
94    }
95}
96
97impl BlockchainRpcClientBuilder<()> {
98    pub fn with_public_overlay_client(
99        self,
100        client: PublicOverlayClient,
101    ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
102        BlockchainRpcClientBuilder {
103            config: self.config,
104            mandatory_fields: client,
105            broadcast_listener: self.broadcast_listener,
106        }
107    }
108}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111    pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112        self.broadcast_listener = Some(Box::new(listener));
113        self
114    }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118    pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119        Self { config, ..self }
120    }
121}
122
/// A cheaply-cloneable client for blockchain RPC queries over a public overlay.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
130    pub fn builder() -> BlockchainRpcClientBuilder<()> {
131        BlockchainRpcClientBuilder {
132            config: Default::default(),
133            mandatory_fields: (),
134            broadcast_listener: None,
135        }
136    }
137
    /// Returns the underlying public overlay.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }

    /// Returns the overlay client used for all queries.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
    // TODO: Add rate limiting
    /// Broadcasts a message to the current targets list and
    /// returns the number of peers the message was delivered to.
    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
        // Zero-copy TL wrapper: serializes the broadcast prefix followed by
        // the borrowed message payload.
        struct ExternalMessage<'a> {
            data: &'a [u8],
        }

        impl tl_proto::TlWrite for ExternalMessage<'_> {
            type Repr = tl_proto::Boxed;

            fn max_size_hint(&self) -> usize {
                // 4 bytes for the broadcast prefix TL id.
                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
            }

            fn write_to<P>(&self, packet: &mut P)
            where
                P: tl_proto::TlPacket,
            {
                packet.write_u32(BroadcastPrefix::TL_ID);
                MessageBroadcastRef { data: self.data }.write_to(packet);
            }
        }

        // Broadcast to yourself
        if let Some(l) = &self.inner.broadcast_listener {
            l.handle_message(Bytes::copy_from_slice(message)).await;
        }

        let client = &self.inner.overlay_client;

        let mut delivered_to = 0;

        let targets = client.get_broadcast_targets();
        let request = Request::from_tl(ExternalMessage { data: message });
        let mut futures = FuturesUnordered::new();

        // We wait for all the responses to come back, but cap the total wait
        // with an adaptive timeout derived from previously observed response
        // times (the `response_tracker` estimator — configured at the 0.75
        // quantile in `build`). This makes the broadcast timeout adapt to
        // current network conditions.
        for validator in targets.as_ref() {
            let client = client.clone();
            let validator = validator.clone();
            let request = request.clone();
            let this = self.inner.clone();

            futures.push(JoinTask::new(async move {
                let start = Instant::now();
                let res = client.send_to_validator(validator, request).await;
                // Record the round-trip time to refine future timeout estimates.
                this.response_tracker
                    .lock()
                    .append(start.elapsed().as_millis() as i64);
                res
            }));
        }

        let timeout = self.compute_broadcast_timeout();
        tokio::time::timeout(timeout, async {
            while let Some(res) = futures.next().await {
                if let Err(e) = res {
                    tracing::warn!("failed to broadcast external message: {e}");
                } else {
                    delivered_to += 1;
                }
            }
        })
        .await
        .ok();

        if delivered_to == 0 {
            tracing::debug!("message was not delivered to any peer");
        }

        delivered_to
    }
222
223    fn compute_broadcast_timeout(&self) -> Duration {
224        let max_broadcast_timeout = std::cmp::max(
225            self.inner.overlay_client.config().validators.send_timeout,
226            self.inner.config.min_broadcast_timeout,
227        );
228
229        if let Some(prev_time) = self
230            .inner
231            .response_tracker
232            .lock()
233            .exponentially_weighted_average()
234            .map(|x| Duration::from_millis(x as _))
235        {
236            metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
237            let value = prev_time.clamp(
238                self.inner.config.min_broadcast_timeout,
239                max_broadcast_timeout,
240            );
241            metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
242            value
243        } else {
244            max_broadcast_timeout
245        }
246    }
247
248    pub async fn get_next_key_block_ids(
249        &self,
250        block: &BlockId,
251        max_size: u32,
252    ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253        let client = &self.inner.overlay_client;
254        let data = client
255            .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256                block_id: *block,
257                max_size,
258            })
259            .await?;
260        Ok(data)
261    }
262
263    #[tracing::instrument(skip_all, fields(
264        block_id = %block.as_short_id(),
265        requirement = ?requirement,
266    ))]
267    pub async fn get_block_full(
268        &self,
269        block: &BlockId,
270        requirement: DataRequirement,
271    ) -> Result<BlockDataFullWithNeighbour, Error> {
272        let overlay_client = self.inner.overlay_client.clone();
273
274        let Some(neighbour) = overlay_client.neighbours().choose() else {
275            return Err(Error::NoNeighbours);
276        };
277
278        let retries = self.inner.config.download_retries;
279
280        download_block_inner(
281            Request::from_tl(rpc::GetBlockFull { block_id: *block }),
282            overlay_client,
283            neighbour,
284            requirement,
285            retries,
286        )
287        .await
288    }
289
290    pub async fn get_next_block_full(
291        &self,
292        prev_block: &BlockId,
293        requirement: DataRequirement,
294    ) -> Result<BlockDataFullWithNeighbour, Error> {
295        let overlay_client = self.inner.overlay_client.clone();
296
297        let Some(neighbour) = overlay_client.neighbours().choose() else {
298            return Err(Error::NoNeighbours);
299        };
300
301        let retries = self.inner.config.download_retries;
302
303        download_block_inner(
304            Request::from_tl(rpc::GetNextBlockFull {
305                prev_block_id: *prev_block,
306            }),
307            overlay_client,
308            neighbour,
309            requirement,
310            retries,
311        )
312        .await
313    }
314
315    pub async fn get_key_block_proof(
316        &self,
317        block_id: &BlockId,
318    ) -> Result<QueryResponse<KeyBlockProof>, Error> {
319        let client = &self.inner.overlay_client;
320        let data = client
321            .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
322                block_id: *block_id,
323            })
324            .await?;
325        Ok(data)
326    }
327
328    pub async fn get_zerostate_proof(&self) -> Result<QueryResponse<ZerostateProof>, Error> {
329        let client = &self.inner.overlay_client;
330        let data = client
331            .query::<_, ZerostateProof>(&rpc::GetZerostateProof)
332            .await?;
333        Ok(data)
334    }
335
336    pub async fn get_persistent_state_info(
337        &self,
338        block_id: &BlockId,
339    ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
340        let client = &self.inner.overlay_client;
341        let data = client
342            .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
343                block_id: *block_id,
344            })
345            .await?;
346        Ok(data)
347    }
348
349    pub async fn get_persistent_state_part(
350        &self,
351        neighbour: &Neighbour,
352        block_id: &BlockId,
353        offset: u64,
354    ) -> Result<QueryResponse<Data>, Error> {
355        let client = &self.inner.overlay_client;
356        let data = client
357            .query_raw::<Data>(
358                neighbour.clone(),
359                Request::from_tl(rpc::GetPersistentShardStateChunk {
360                    block_id: *block_id,
361                    offset,
362                }),
363            )
364            .await?;
365        Ok(data)
366    }
367
    /// Locates a neighbour that has the requested persistent state.
    ///
    /// Queries up to 10 reliable neighbours concurrently and returns a
    /// [`PendingPersistentState`] describing the first one that reports the
    /// state as present. Returns the last query error (or `Error::NotFound`
    /// if there were none) when no neighbour has the state.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        // Get reliable neighbours with higher weight
        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        // The request shape depends on the kind of the persistent state.
        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        // Keep only the last error; it is returned if no neighbour has the state.
        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    // Accept the response, getting back the serving neighbour.
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
436
    /// Downloads a persistent state previously located via
    /// [`Self::find_persistent_state`], writing the decompressed bytes to `output`.
    ///
    /// Chunks are fetched with bounded parallelism ([`PARALLEL_REQUESTS`]);
    /// each chunk download is retried up to `download_retries` times.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_and_decompress(
            state.size,
            state.chunk_size,
            PARALLEL_REQUESTS,
            output,
            // Chunk producer: builds the proper request for the state kind
            // and downloads one chunk at `offset` with retries.
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Chunk consumer: appends decompressed bytes to the writer.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: flush buffered data before handing the writer back.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
494
495    pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
496        const NEIGHBOUR_COUNT: usize = 10;
497
498        // Get reliable neighbours with higher weight
499        let neighbours = self
500            .overlay_client()
501            .neighbours()
502            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
503
504        // Find a neighbour which has the requested archive
505        let pending_archive = 'info: {
506            let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
507
508            // Number of ArchiveInfo::TooNew responses
509            let mut new_archive_count = 0usize;
510
511            let mut futures = FuturesUnordered::new();
512            for neighbour in neighbours {
513                futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
514            }
515
516            let mut err = None;
517            while let Some(info) = futures.next().await {
518                let (handle, info) = match info {
519                    Ok(res) => res.split(),
520                    Err(e) => {
521                        err = Some(e);
522                        continue;
523                    }
524                };
525
526                match info {
527                    ArchiveInfo::Found {
528                        id,
529                        size,
530                        chunk_size,
531                    } => {
532                        break 'info PendingArchive {
533                            id,
534                            size,
535                            chunk_size,
536                            neighbour: handle.accept(),
537                        };
538                    }
539                    ArchiveInfo::TooNew => {
540                        new_archive_count += 1;
541
542                        handle.accept();
543                    }
544                    ArchiveInfo::NotFound => {
545                        handle.accept();
546                    }
547                }
548            }
549
550            // Stop using archives if enough neighbors responded TooNew
551            if new_archive_count >= self.inner.config.too_new_archive_threshold {
552                return Ok(PendingArchiveResponse::TooNew);
553            }
554
555            return match err {
556                None => Err(Error::NotFound),
557                Some(err) => Err(err),
558            };
559        };
560
561        tracing::info!(
562            peer_id = %pending_archive.neighbour.peer_id(),
563            archive_id = pending_archive.id,
564            archive_size = %ByteSize(pending_archive.size.get()),
565            archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
566            "found archive",
567        );
568        Ok(PendingArchiveResponse::Found(pending_archive))
569    }
570
    /// Downloads a previously located archive, writing decompressed bytes to
    /// `output` while running [`ArchiveVerifier`] over the data on the fly.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_and_decompress(
            archive.size,
            archive.chunk_size,
            PARALLEL_REQUESTS,
            // Carry the verifier alongside the writer through the pipeline.
            (output, ArchiveVerifier::default()),
            // Chunk producer: downloads one chunk at `offset` with retries,
            // logging the elapsed time for each completed chunk.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Chunk consumer: verify first, then append the bytes.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: run the final structural check and flush.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
632}
633
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    // Optional hook invoked with a copy of every self-broadcasted message.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Tracks validator response times to derive adaptive broadcast timeouts.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}

/// Result of probing neighbours for an archive.
pub enum PendingArchiveResponse {
    /// A neighbour that has the archive was found.
    Found(PendingArchive),
    /// Enough neighbours reported the archive as too new
    /// (see `too_new_archive_threshold`).
    TooNew,
}

/// An archive located on a specific neighbour, ready to be downloaded.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id.
    pub id: u64,
    /// Total size of the archive data in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that reported having this archive.
    pub neighbour: Neighbour,
}

/// A persistent state located on a specific neighbour, ready to be downloaded.
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Id of the block the state belongs to.
    pub block_id: BlockId,
    /// Which persistent state (shard or queue).
    pub kind: PersistentStateKind,
    /// Total size of the state data in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that reported having this state.
    pub neighbour: Neighbour,
}

/// Full block payload: block data, proof and queue diff.
pub struct BlockDataFull {
    pub block_id: BlockId,
    pub block_data: Bytes,
    pub proof_data: Bytes,
    pub queue_diff_data: Bytes,
}

/// Block download result together with the neighbour that served the request.
/// `data` is `None` when the neighbour did not have the block.
pub struct BlockDataFullWithNeighbour {
    pub data: Option<BlockDataFull>,
    pub neighbour: Neighbour,
}

/// How strictly a queried neighbour is expected to have the requested data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Data is not required to be present on the neighbour (mostly for polling).
    ///
    /// NOTE: Node will not be punished if the data is not present.
    Optional,
    /// We assume that the node has the data, but it's not required.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Dumb`] if the data is not present.
    Expected,
    /// Data must be present on the requested neighbour.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Malicious`] if the data is not present.
    Required,
}
690
/// Downloads a full block (data, proof, queue diff) via `req` from `neighbour`.
///
/// The first response carries the block metadata plus the first data chunk;
/// the remaining chunks are fetched in parallel while a blocking task
/// decompresses them in order. Returns `data: None` when the neighbour does
/// not have the block, treating the neighbour according to `requirement`.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // Block is absent: how the neighbour is treated depends on whether
        // the data was required to be present.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    // NOTE(review): this shadows the file-level `PARALLEL_REQUESTS`
    // (`NonZeroUsize`) with a plain `usize` of the same value.
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The first chunk arrives inline; it must respect the declared sizes.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // Decompress on a blocking thread while chunks keep downloading.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        // Buffer for decompressed data
        let mut decompressed = Vec::new();

        // Decompress chunk
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        // Receive and process chunks
        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response unless processing below succeeds.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            // Decompress chunk
            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Fetch the remaining chunks (the first one came inline) with bounded
    // parallelism; `buffered` preserves chunk order.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // Stop early if the processing task bailed out (receiver dropped).
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing task can finish.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        // The clone is required: `neighbour` is still borrowed by the
        // (pinned) download stream's closure until the end of the scope.
        neighbour: neighbour.clone(),
    })
}
826
827async fn download_with_retries(
828    req: Request,
829    overlay_client: PublicOverlayClient,
830    neighbour: Neighbour,
831    max_retries: usize,
832    name: &'static str,
833) -> Result<(QueryResponseHandle, Bytes), Error> {
834    let mut retries = 0;
835    loop {
836        match overlay_client
837            .query_raw::<Data>(neighbour.clone(), req.clone())
838            .await
839        {
840            Ok(r) => {
841                let (h, res) = r.split();
842                return Ok((h, res.data));
843            }
844            Err(e) => {
845                tracing::error!("failed to download {name}: {e:?}");
846                retries += 1;
847                if retries >= max_retries || !neighbour.is_reliable() {
848                    return Err(e);
849                }
850
851                tokio::time::sleep(Duration::from_millis(100)).await;
852            }
853        }
854    }
855}
856
// Bridges overlay query response handles into the generic downloader's
// accept/reject interface.
impl DownloaderResponseHandle for QueryResponseHandle {
    fn accept(self) {
        QueryResponseHandle::accept(self);
    }

    fn reject(self) {
        QueryResponseHandle::reject(self);
    }
}
866
867fn map_downloader_error(e: DownloaderError<Error>) -> Error {
868    match e {
869        DownloaderError::DownloadFailed(e) => e,
870        e => Error::Internal(e.into()),
871    }
872}
873
// TODO: Move into config?
/// Maximum number of chunk downloads performed in parallel.
const PARALLEL_REQUESTS: NonZeroUsize = NonZeroUsize::new(10).unwrap();
876
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    /// Error type for the stub chunk producer below.
    #[derive(Debug, thiserror::Error)]
    #[error("stub")]
    struct StubError;

    /// Round-trip check for `download_and_decompress`: compresses 1 MB of
    /// random bytes, serves it back in 128-byte chunks, and verifies the
    /// decompressed result matches the original data.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        let mut original_data = vec![0u8; 1 << 20]; // 1 MB of garbage
        rand::rng().fill_bytes(&mut original_data);

        // Compress with zstd (level 9); the downloader must undo this.
        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        let received = download_and_decompress(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            PARALLEL_REQUESTS,
            Vec::new(),
            // Chunk producer: slices the compressed buffer at `offset`.
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok::<_, StubError>((handle, chunk)))
            },
            // Chunk consumer: accumulate decompressed bytes.
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}
929}