tycho_core/blockchain_rpc/
client.rs

1use std::future::Future;
2use std::io::Write;
3use std::num::{NonZeroU32, NonZeroU64};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use bytes::Bytes;
9use bytesize::ByteSize;
10use futures_util::stream::{FuturesUnordered, StreamExt};
11use parking_lot::Mutex;
12use scopeguard::ScopeGuard;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tycho_block_util::archive::ArchiveVerifier;
16use tycho_network::{PublicOverlay, Request};
17use tycho_types::models::BlockId;
18use tycho_util::compression::ZstdDecompressStream;
19use tycho_util::futures::JoinTask;
20use tycho_util::serde_helpers;
21
22use crate::overlay_client::{
23    Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
24};
25use crate::proto::blockchain::*;
26use crate::proto::overlay::BroadcastPrefix;
27use crate::storage::PersistentStateKind;
28
/// A listener for self-broadcasted messages.
///
/// Implementors receive a copy of every external message this node sends via
/// [`BlockchainRpcClient::broadcast_external_message`].
///
/// NOTE: `async_trait` is used to add object safety to the trait.
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
///
/// All fields have defaults (see the [`Default`] impl below); missing config
/// keys fall back to them via `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Timeout to broadcast external messages.
    ///
    /// Acts as the lower bound for the adaptive broadcast timeout.
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Minimum number of neighbours with `TooNew`
    /// response required to switch provider.
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Number of retries to download blocks/archives.
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
59impl Default for BlockchainRpcClientConfig {
60    fn default() -> Self {
61        Self {
62            min_broadcast_timeout: Duration::from_millis(100),
63            too_new_archive_threshold: 4,
64            download_retries: 10,
65        }
66    }
67}
68
/// Builder for [`BlockchainRpcClient`].
///
/// The `MandatoryFields` type parameter tracks at compile time whether the
/// required [`PublicOverlayClient`] has been supplied: it starts as `()` and
/// `build` only becomes available once it is set.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
impl BlockchainRpcClientBuilder<PublicOverlayClient> {
    /// Finalizes the builder. Only available once the overlay client is set.
    pub fn build(self) -> BlockchainRpcClient {
        BlockchainRpcClient {
            inner: Arc::new(Inner {
                config: self.config,
                overlay_client: self.mandatory_fields,
                broadcast_listener: self.broadcast_listener,
                // Tracks validator response times to derive the adaptive
                // broadcast timeout (see `compute_broadcast_timeout`).
                response_tracker: Mutex::new(
                    // 5 windows, 60 seconds each, 0.75 quantile
                    tycho_util::time::RollingP2Estimator::new_with_config(
                        0.75, // should be enough to filter most of the outliers
                        Duration::from_secs(60),
                        5,
                        tycho_util::time::RealClock,
                    )
                    // 0.75 is a valid quantile, so construction cannot fail.
                    .expect("correct quantile"),
                ),
            }),
        }
    }
}
96
97impl BlockchainRpcClientBuilder<()> {
98    pub fn with_public_overlay_client(
99        self,
100        client: PublicOverlayClient,
101    ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
102        BlockchainRpcClientBuilder {
103            config: self.config,
104            mandatory_fields: client,
105            broadcast_listener: self.broadcast_listener,
106        }
107    }
108}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111    pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112        self.broadcast_listener = Some(Box::new(listener));
113        self
114    }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118    pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119        Self { config, ..self }
120    }
121}
122
/// Cheaply-cloneable handle to the blockchain RPC layer.
///
/// All clones share the same [`Inner`] state via `Arc`.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
130    pub fn builder() -> BlockchainRpcClientBuilder<()> {
131        BlockchainRpcClientBuilder {
132            config: Default::default(),
133            mandatory_fields: (),
134            broadcast_listener: None,
135        }
136    }
137
    /// Returns the public overlay this client operates on.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }

    /// Returns the underlying overlay client.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
146    // TODO: Add rate limiting
147    /// Broadcasts a message to the current targets list and
148    /// returns the number of peers the message was delivered to.
149    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
150        struct ExternalMessage<'a> {
151            data: &'a [u8],
152        }
153
154        impl tl_proto::TlWrite for ExternalMessage<'_> {
155            type Repr = tl_proto::Boxed;
156
157            fn max_size_hint(&self) -> usize {
158                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
159            }
160
161            fn write_to<P>(&self, packet: &mut P)
162            where
163                P: tl_proto::TlPacket,
164            {
165                packet.write_u32(BroadcastPrefix::TL_ID);
166                MessageBroadcastRef { data: self.data }.write_to(packet);
167            }
168        }
169
170        // Broadcast to yourself
171        if let Some(l) = &self.inner.broadcast_listener {
172            l.handle_message(Bytes::copy_from_slice(message)).await;
173        }
174
175        let client = &self.inner.overlay_client;
176
177        let mut delivered_to = 0;
178
179        let targets = client.get_broadcast_targets();
180        let request = Request::from_tl(ExternalMessage { data: message });
181        let mut futures = FuturesUnordered::new();
182
183        // we wait for all the responses to come back but cap them at `broadcast_timeout_upper_bound`
184        // all peers timeouts are calculated based on p90 of the previous responses time weighted average
185        // This will make broadcast timeout to be adaptive based on the network conditions
186        for validator in targets.as_ref() {
187            let client = client.clone();
188            let validator = validator.clone();
189            let request = request.clone();
190            let this = self.inner.clone();
191            // we are not using `JoinTask` here because we want to measure the time taken by the broadcast
192            futures.push(tokio::spawn(async move {
193                let start = Instant::now();
194                let res = client.send_to_validator(validator, request).await;
195                this.response_tracker
196                    .lock()
197                    .append(start.elapsed().as_millis() as i64);
198                res
199            }));
200        }
201
202        let timeout = self.compute_broadcast_timeout();
203        tokio::time::timeout(timeout, async {
204            // inner task timeout won't happen because outer task timeout is always <= inner task timeout
205            while let Some(Ok(res)) = futures.next().await {
206                if let Err(e) = res {
207                    tracing::warn!("failed to broadcast external message: {e}");
208                } else {
209                    delivered_to += 1;
210                }
211            }
212        })
213        .await
214        .ok();
215
216        if delivered_to == 0 {
217            tracing::debug!("message was not delivered to any peer");
218        }
219
220        delivered_to
221    }
222
223    fn compute_broadcast_timeout(&self) -> Duration {
224        let max_broadcast_timeout = std::cmp::max(
225            self.inner.overlay_client.config().validators.send_timeout,
226            self.inner.config.min_broadcast_timeout,
227        );
228
229        if let Some(prev_time) = self
230            .inner
231            .response_tracker
232            .lock()
233            .exponentially_weighted_average()
234            .map(|x| Duration::from_millis(x as _))
235        {
236            metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
237            let value = prev_time.clamp(
238                self.inner.config.min_broadcast_timeout,
239                max_broadcast_timeout,
240            );
241            metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
242            value
243        } else {
244            max_broadcast_timeout
245        }
246    }
247
248    pub async fn get_next_key_block_ids(
249        &self,
250        block: &BlockId,
251        max_size: u32,
252    ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253        let client = &self.inner.overlay_client;
254        let data = client
255            .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256                block_id: *block,
257                max_size,
258            })
259            .await?;
260        Ok(data)
261    }
262
263    pub async fn get_block_full(
264        &self,
265        block: &BlockId,
266        requirement: DataRequirement,
267    ) -> Result<BlockDataFullWithNeighbour, Error> {
268        let overlay_client = self.inner.overlay_client.clone();
269
270        let Some(neighbour) = overlay_client.neighbours().choose() else {
271            return Err(Error::NoNeighbours);
272        };
273
274        let retries = self.inner.config.download_retries;
275
276        download_block_inner(
277            Request::from_tl(rpc::GetBlockFull { block_id: *block }),
278            overlay_client,
279            neighbour,
280            requirement,
281            retries,
282        )
283        .await
284    }
285
286    pub async fn get_next_block_full(
287        &self,
288        prev_block: &BlockId,
289        requirement: DataRequirement,
290    ) -> Result<BlockDataFullWithNeighbour, Error> {
291        let overlay_client = self.inner.overlay_client.clone();
292
293        let Some(neighbour) = overlay_client.neighbours().choose() else {
294            return Err(Error::NoNeighbours);
295        };
296
297        let retries = self.inner.config.download_retries;
298
299        download_block_inner(
300            Request::from_tl(rpc::GetNextBlockFull {
301                prev_block_id: *prev_block,
302            }),
303            overlay_client,
304            neighbour,
305            requirement,
306            retries,
307        )
308        .await
309    }
310
311    pub async fn get_key_block_proof(
312        &self,
313        block_id: &BlockId,
314    ) -> Result<QueryResponse<KeyBlockProof>, Error> {
315        let client = &self.inner.overlay_client;
316        let data = client
317            .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
318                block_id: *block_id,
319            })
320            .await?;
321        Ok(data)
322    }
323
324    pub async fn get_persistent_state_info(
325        &self,
326        block_id: &BlockId,
327    ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
328        let client = &self.inner.overlay_client;
329        let data = client
330            .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
331                block_id: *block_id,
332            })
333            .await?;
334        Ok(data)
335    }
336
337    pub async fn get_persistent_state_part(
338        &self,
339        neighbour: &Neighbour,
340        block_id: &BlockId,
341        offset: u64,
342    ) -> Result<QueryResponse<Data>, Error> {
343        let client = &self.inner.overlay_client;
344        let data = client
345            .query_raw::<Data>(
346                neighbour.clone(),
347                Request::from_tl(rpc::GetPersistentShardStateChunk {
348                    block_id: *block_id,
349                    offset,
350                }),
351            )
352            .await?;
353        Ok(data)
354    }
355
    /// Polls up to 10 reliable neighbours for a persistent state of `block_id`
    /// and returns a download handle to the first neighbour reporting it found.
    ///
    /// Returns [`Error::NotFound`] if every responding neighbour answered
    /// `NotFound`, or the last query error if at least one query failed.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        // Get reliable neighbours with higher weight
        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        // Request shape depends on which kind of state we are looking for.
        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        // Query all selected neighbours concurrently; the first `Found` wins.
        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        // Keep only the last error; individual failures are not fatal as long
        // as some other neighbour has the state.
        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    // Mark the response as valid (improves the neighbour's score).
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                // NOTE(review): unlike `find_archive`, the handle is dropped
                // here without an explicit `accept` — confirm drop semantics.
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
424
    /// Downloads the zstd-compressed persistent state described by `state`
    /// (as found by [`Self::find_persistent_state`]), decompressing on the fly
    /// and writing the decompressed bytes into `output`.
    ///
    /// Chunks are fetched concurrently with per-chunk retries; the returned
    /// writer has been flushed.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_compressed(
            state.size,
            state.chunk_size,
            output,
            // Download one compressed chunk at `offset` (with retries).
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Write each decompressed chunk straight to the output.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalize: flush buffered data before handing the writer back.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
480
481    pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
482        const NEIGHBOUR_COUNT: usize = 10;
483
484        // Get reliable neighbours with higher weight
485        let neighbours = self
486            .overlay_client()
487            .neighbours()
488            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
489
490        // Find a neighbour which has the requested archive
491        let pending_archive = 'info: {
492            let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
493
494            // Number of ArchiveInfo::TooNew responses
495            let mut new_archive_count = 0usize;
496
497            let mut futures = FuturesUnordered::new();
498            for neighbour in neighbours {
499                futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
500            }
501
502            let mut err = None;
503            while let Some(info) = futures.next().await {
504                let (handle, info) = match info {
505                    Ok(res) => res.split(),
506                    Err(e) => {
507                        err = Some(e);
508                        continue;
509                    }
510                };
511
512                match info {
513                    ArchiveInfo::Found {
514                        id,
515                        size,
516                        chunk_size,
517                    } => {
518                        break 'info PendingArchive {
519                            id,
520                            size,
521                            chunk_size,
522                            neighbour: handle.accept(),
523                        };
524                    }
525                    ArchiveInfo::TooNew => {
526                        new_archive_count += 1;
527
528                        handle.accept();
529                    }
530                    ArchiveInfo::NotFound => {
531                        handle.accept();
532                    }
533                }
534            }
535
536            // Stop using archives if enough neighbors responded TooNew
537            if new_archive_count >= self.inner.config.too_new_archive_threshold {
538                return Ok(PendingArchiveResponse::TooNew);
539            }
540
541            return match err {
542                None => Err(Error::NotFound),
543                Some(err) => Err(err),
544            };
545        };
546
547        tracing::info!(
548            peer_id = %pending_archive.neighbour.peer_id(),
549            archive_id = pending_archive.id,
550            archive_size = %ByteSize(pending_archive.size.get()),
551            archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
552            "found archive",
553        );
554        Ok(PendingArchiveResponse::Found(pending_archive))
555    }
556
    /// Downloads the archive described by `archive` into `output`, verifying
    /// the archive structure on the fly with [`ArchiveVerifier`].
    ///
    /// Chunks are fetched concurrently with per-chunk retries; the returned
    /// writer has been flushed and the verifier's final check has passed.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_compressed(
            archive.size,
            archive.chunk_size,
            (output, ArchiveVerifier::default()),
            // Download one compressed chunk at `offset` (with retries),
            // logging how long the download took.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Verify archive structure, then write each decompressed chunk.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalize: run the verifier's final check and flush the writer.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
616}
617
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    // Optional loopback for self-broadcasted external messages.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Rolling quantile estimator of validator response times (in ms),
    // used to derive the adaptive broadcast timeout.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
624
/// Result of [`BlockchainRpcClient::find_archive`].
pub enum PendingArchiveResponse {
    /// An archive was located on some neighbour.
    Found(PendingArchive),
    /// Enough neighbours consider the requested block too new;
    /// the caller should switch to another data provider.
    TooNew,
}

/// Handle to an archive that a neighbour has offered for download.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id as reported by the neighbour.
    pub id: u64,
    /// Total compressed size in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that has the archive.
    pub neighbour: Neighbour,
}

/// Handle to a persistent state that a neighbour has offered for download.
#[derive(Clone)]
pub struct PendingPersistentState {
    pub block_id: BlockId,
    pub kind: PersistentStateKind,
    /// Total compressed size in bytes.
    pub size: NonZeroU64,
    /// Size of a single download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that has the state.
    pub neighbour: Neighbour,
}

/// A fully-downloaded block: data, proof, and queue diff.
pub struct BlockDataFull {
    pub block_id: BlockId,
    pub block_data: Bytes,
    pub proof_data: Bytes,
    pub queue_diff_data: Bytes,
}

/// Result of a block download together with the neighbour that served it.
/// `data` is `None` when the neighbour did not have the block.
pub struct BlockDataFullWithNeighbour {
    pub data: Option<BlockDataFull>,
    pub neighbour: Neighbour,
}

/// How strictly the requested data is expected to be present on a neighbour.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Data is not required to be present on the neighbour (mostly for polling).
    ///
    /// NOTE: Node will not be punished if the data is not present.
    Optional,
    /// We assume that the node has the data, but it's not required.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Dumb`] if the data is not present.
    Expected,
    /// Data must be present on the requested neighbour.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Malicious`] if the data is not present.
    Required,
}
674
/// Shared implementation of `get_block_full`/`get_next_block_full`.
///
/// Sends `req` and, if the neighbour has the block, downloads the remaining
/// compressed block-data chunks in parallel while a blocking task decompresses
/// them in order. When the block is absent, the neighbour is graded according
/// to `requirement` and `data: None` is returned.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // Block is absent: grade the neighbour depending on how strongly we
        // expected it to have the data.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    const PARALLEL_REQUESTS: usize = 10;

    // `block_data.data` holds the first compressed chunk; the rest are
    // requested separately below.
    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // Decompression is CPU-bound, so it runs on a blocking task consuming
    // chunks from the channel in offset order.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        // Buffer for decompressed data
        let mut decompressed = Vec::new();

        // Decompress chunk
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        // Receive and process chunks
        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response on any error/early-return path below.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            // Decompress chunk
            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Fetch the remaining chunks (offsets after the first chunk) with bounded
    // parallelism; `buffered` preserves offset order for the decompressor.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // The receiver closes early if decompression failed; stop sending then.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the blocking task can finish its loop.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
810
/// Generic driver for chunked downloads of zstd-compressed payloads
/// (archives and persistent states).
///
/// Chunks at offsets `0, chunk_size, 2*chunk_size, ...` are produced by
/// `download_fn` with bounded parallelism, decompressed in order on a blocking
/// task, fed to `process_fn`, and `state` is finally converted into the result
/// by `finalize_fn`.
async fn download_compressed<S, T, DF, DFut, PF, FF>(
    target_size: NonZeroU64,
    chunk_size: NonZeroU32,
    mut state: S,
    mut download_fn: DF,
    mut process_fn: PF,
    finalize_fn: FF,
) -> Result<T, Error>
where
    S: Send + 'static,
    T: Send + 'static,
    DF: FnMut(u64) -> DFut,
    DFut: Future<Output = DownloadedChunkResult> + Send + 'static,
    PF: FnMut(&mut S, &[u8]) -> Result<()> + Send + 'static,
    FF: FnOnce(S) -> Result<T> + Send + 'static,
{
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = target_size.get();
    let chunk_size = chunk_size.get() as usize;

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // CPU-bound decompression runs on a blocking task, consuming chunks from
    // the channel in offset order.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size)?;

        // Reuse buffer for decompressed data
        let mut decompressed_chunk = Vec::new();

        // Receive and process chunks
        let mut downloaded = 0;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response on any error/early-return path below.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size, "received invalid chunk");

            downloaded += chunk.len() as u64;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded),
                "got chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            decompressed_chunk.clear();
            zstd_decoder.write(chunk.as_ref(), &mut decompressed_chunk)?;

            process_fn(&mut state, &decompressed_chunk)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        finalize_fn(state)
    });

    // Fetch chunks with bounded parallelism; `buffered` preserves offset order.
    let stream = futures_util::stream::iter((0..target_size).step_by(chunk_size))
        .map(|offset| JoinTask::new(download_fn(offset)))
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // The receiver closes early if processing failed; stop downloading then.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the blocking task can finish its loop.
    drop(chunks_tx);

    let output = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map_err(Error::Internal)?;

    Ok(output)
}
897
898async fn download_with_retries(
899    req: Request,
900    overlay_client: PublicOverlayClient,
901    neighbour: Neighbour,
902    max_retries: usize,
903    name: &'static str,
904) -> DownloadedChunkResult {
905    let mut retries = 0;
906    loop {
907        match overlay_client
908            .query_raw::<Data>(neighbour.clone(), req.clone())
909            .await
910        {
911            Ok(r) => {
912                let (h, res) = r.split();
913                return Ok((h, res.data));
914            }
915            Err(e) => {
916                tracing::error!("failed to download {name}: {e}");
917                retries += 1;
918                if retries >= max_retries || !neighbour.is_reliable() {
919                    return Err(e);
920                }
921
922                tokio::time::sleep(Duration::from_millis(100)).await;
923            }
924        }
925    }
926}
927
928type DownloadedChunkResult = Result<(QueryResponseHandle, Bytes), Error>;
929
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    /// End-to-end check of `download_compressed`: compresses random data,
    /// serves it back in fixed-size chunks from memory, and verifies the
    /// decompressed output matches the original bytes.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        let mut original_data = vec![0u8; 1 << 20]; // 1 MB of garbage
        rand::rng().fill_bytes(&mut original_data);

        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        let received = download_compressed(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            Vec::new(),
            // Serve chunks straight from the in-memory compressed buffer.
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok((handle, chunk)))
            },
            // Accumulate decompressed bytes into the Vec state.
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}
977}