tycho_core/blockchain_rpc/
client.rs

1use std::future::Future;
2use std::io::Write;
3use std::num::{NonZeroU32, NonZeroU64};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use bytes::Bytes;
9use bytesize::ByteSize;
10use futures_util::stream::{FuturesUnordered, StreamExt};
11use parking_lot::Mutex;
12use scopeguard::ScopeGuard;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tycho_block_util::archive::ArchiveVerifier;
16use tycho_network::{PublicOverlay, Request};
17use tycho_types::models::BlockId;
18use tycho_util::compression::ZstdDecompressStream;
19use tycho_util::futures::JoinTask;
20use tycho_util::serde_helpers;
21
22use crate::overlay_client::{
23    Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
24};
25use crate::proto::blockchain::*;
26use crate::proto::overlay::BroadcastPrefix;
27use crate::storage::PersistentStateKind;
28
/// A listener for self-broadcasted messages.
///
/// Invoked by [`BlockchainRpcClient::broadcast_external_message`] before the
/// message is sent to remote peers, so the local node also observes it.
///
/// NOTE: `async_trait` is used to add object safety to the trait.
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    /// Handles a raw external message payload.
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
///
/// All fields have defaults (see [`Default`] impl), and missing fields are
/// filled from defaults during deserialization (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Timeout to broadcast external messages
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Minimum number of neighbours with `TooNew`
    /// response required to switch provider
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Number of retries to download blocks/archives
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
59impl Default for BlockchainRpcClientConfig {
60    fn default() -> Self {
61        Self {
62            min_broadcast_timeout: Duration::from_millis(100),
63            too_new_archive_threshold: 4,
64            download_retries: 10,
65        }
66    }
67}
68
/// Builder for [`BlockchainRpcClient`].
///
/// `MandatoryFields` tracks at the type level whether the required
/// [`PublicOverlayClient`] has already been supplied; `build` is only
/// available once it has been.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
75impl BlockchainRpcClientBuilder<PublicOverlayClient> {
76    pub fn build(self) -> BlockchainRpcClient {
77        BlockchainRpcClient {
78            inner: Arc::new(Inner {
79                config: self.config,
80                overlay_client: self.mandatory_fields,
81                broadcast_listener: self.broadcast_listener,
82                response_tracker: Mutex::new(
83                    // 5 windows, 60 seconds each, 0.75 quantile
84                    tycho_util::time::RollingP2Estimator::new_with_config(
85                        0.75, // should be enough to filter most of the outliers
86                        Duration::from_secs(60),
87                        5,
88                        tycho_util::time::RealClock,
89                    )
90                    .expect("correct quantile"),
91                ),
92            }),
93        }
94    }
95}
96
97impl BlockchainRpcClientBuilder<()> {
98    pub fn with_public_overlay_client(
99        self,
100        client: PublicOverlayClient,
101    ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
102        BlockchainRpcClientBuilder {
103            config: self.config,
104            mandatory_fields: client,
105            broadcast_listener: self.broadcast_listener,
106        }
107    }
108}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111    pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112        self.broadcast_listener = Some(Box::new(listener));
113        self
114    }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118    pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119        Self { config, ..self }
120    }
121}
122
/// Client for blockchain data exchange over a public overlay.
///
/// Cheaply cloneable handle; all clones share the same inner state.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
130    pub fn builder() -> BlockchainRpcClientBuilder<()> {
131        BlockchainRpcClientBuilder {
132            config: Default::default(),
133            mandatory_fields: (),
134            broadcast_listener: None,
135        }
136    }
137
    /// Returns the underlying public overlay.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }
141
    /// Returns the overlay client used for all queries and broadcasts.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
    // TODO: Add rate limiting
    /// Broadcasts a message to the current targets list and
    /// returns the number of peers the message was delivered to.
    ///
    /// The message is first handed to the local [`SelfBroadcastListener`]
    /// (if any), then sent to every current broadcast target concurrently.
    /// Waiting for acknowledgements is capped by an adaptive timeout derived
    /// from previously observed response times (see `compute_broadcast_timeout`).
    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
        // TL wrapper that serializes the broadcast prefix followed by the
        // message bytes, borrowing the payload instead of copying it.
        struct ExternalMessage<'a> {
            data: &'a [u8],
        }

        impl tl_proto::TlWrite for ExternalMessage<'_> {
            type Repr = tl_proto::Boxed;

            fn max_size_hint(&self) -> usize {
                // 4 extra bytes for the broadcast prefix id.
                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
            }

            fn write_to<P>(&self, packet: &mut P)
            where
                P: tl_proto::TlPacket,
            {
                packet.write_u32(BroadcastPrefix::TL_ID);
                MessageBroadcastRef { data: self.data }.write_to(packet);
            }
        }

        // Broadcast to yourself
        if let Some(l) = &self.inner.broadcast_listener {
            l.handle_message(Bytes::copy_from_slice(message)).await;
        }

        let client = &self.inner.overlay_client;

        let mut delivered_to = 0;

        let targets = client.get_broadcast_targets();
        let request = Request::from_tl(ExternalMessage { data: message });
        let mut futures = FuturesUnordered::new();

        // we wait for all the responses to come back but cap them at `broadcast_timeout_upper_bound`
        // all peers timeouts are calculated based on p90 of the previous responses time weighted average
        // This will make broadcast timeout to be adaptive based on the network conditions
        for validator in targets.as_ref() {
            let client = client.clone();
            let validator = validator.clone();
            let request = request.clone();
            let this = self.inner.clone();
            // we are not using `JoinTask` here because we want to measure the time taken by the broadcast
            futures.push(tokio::spawn(async move {
                let start = Instant::now();
                let res = client.send_to_validator(validator, request).await;
                // Feed the observed round-trip time (ms) into the tracker
                // that drives the adaptive broadcast timeout.
                this.response_tracker
                    .lock()
                    .append(start.elapsed().as_millis() as i64);
                res
            }));
        }

        let timeout = self.compute_broadcast_timeout();
        tokio::time::timeout(timeout, async {
            // inner task timeout won't happen because outer task timeout is always <= inner task timeout
            while let Some(Ok(res)) = futures.next().await {
                if let Err(e) = res {
                    tracing::warn!("failed to broadcast external message: {e}");
                } else {
                    delivered_to += 1;
                }
            }
        })
        .await
        .ok();

        if delivered_to == 0 {
            tracing::debug!("message was not delivered to any peer");
        }

        delivered_to
    }
222
    /// Computes the timeout for one broadcast round.
    ///
    /// The upper bound is the larger of the configured validator send timeout
    /// and `min_broadcast_timeout`. When the response tracker has history, its
    /// exponentially weighted average (in ms) is clamped into
    /// `[min_broadcast_timeout, upper bound]`; otherwise the upper bound is used.
    fn compute_broadcast_timeout(&self) -> Duration {
        let max_broadcast_timeout = std::cmp::max(
            self.inner.overlay_client.config().validators.send_timeout,
            self.inner.config.min_broadcast_timeout,
        );

        if let Some(prev_time) = self
            .inner
            .response_tracker
            .lock()
            .exponentially_weighted_average()
            .map(|x| Duration::from_millis(x as _))
        {
            // Export both the raw estimate and the clamped value for observability.
            metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
            let value = prev_time.clamp(
                self.inner.config.min_broadcast_timeout,
                max_broadcast_timeout,
            );
            metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
            value
        } else {
            // No history yet: be conservative and wait the full upper bound.
            max_broadcast_timeout
        }
    }
247
248    pub async fn get_next_key_block_ids(
249        &self,
250        block: &BlockId,
251        max_size: u32,
252    ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253        let client = &self.inner.overlay_client;
254        let data = client
255            .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256                block_id: *block,
257                max_size,
258            })
259            .await?;
260        Ok(data)
261    }
262
263    #[tracing::instrument(skip_all, fields(
264        block_id = %block.as_short_id(),
265        requirement = ?requirement,
266    ))]
267    pub async fn get_block_full(
268        &self,
269        block: &BlockId,
270        requirement: DataRequirement,
271    ) -> Result<BlockDataFullWithNeighbour, Error> {
272        let overlay_client = self.inner.overlay_client.clone();
273
274        let Some(neighbour) = overlay_client.neighbours().choose() else {
275            return Err(Error::NoNeighbours);
276        };
277
278        let retries = self.inner.config.download_retries;
279
280        download_block_inner(
281            Request::from_tl(rpc::GetBlockFull { block_id: *block }),
282            overlay_client,
283            neighbour,
284            requirement,
285            retries,
286        )
287        .await
288    }
289
290    pub async fn get_next_block_full(
291        &self,
292        prev_block: &BlockId,
293        requirement: DataRequirement,
294    ) -> Result<BlockDataFullWithNeighbour, Error> {
295        let overlay_client = self.inner.overlay_client.clone();
296
297        let Some(neighbour) = overlay_client.neighbours().choose() else {
298            return Err(Error::NoNeighbours);
299        };
300
301        let retries = self.inner.config.download_retries;
302
303        download_block_inner(
304            Request::from_tl(rpc::GetNextBlockFull {
305                prev_block_id: *prev_block,
306            }),
307            overlay_client,
308            neighbour,
309            requirement,
310            retries,
311        )
312        .await
313    }
314
315    pub async fn get_key_block_proof(
316        &self,
317        block_id: &BlockId,
318    ) -> Result<QueryResponse<KeyBlockProof>, Error> {
319        let client = &self.inner.overlay_client;
320        let data = client
321            .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
322                block_id: *block_id,
323            })
324            .await?;
325        Ok(data)
326    }
327
328    pub async fn get_persistent_state_info(
329        &self,
330        block_id: &BlockId,
331    ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
332        let client = &self.inner.overlay_client;
333        let data = client
334            .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
335                block_id: *block_id,
336            })
337            .await?;
338        Ok(data)
339    }
340
341    pub async fn get_persistent_state_part(
342        &self,
343        neighbour: &Neighbour,
344        block_id: &BlockId,
345        offset: u64,
346    ) -> Result<QueryResponse<Data>, Error> {
347        let client = &self.inner.overlay_client;
348        let data = client
349            .query_raw::<Data>(
350                neighbour.clone(),
351                Request::from_tl(rpc::GetPersistentShardStateChunk {
352                    block_id: *block_id,
353                    offset,
354                }),
355            )
356            .await?;
357        Ok(data)
358    }
359
    /// Searches reliable neighbours for a persistent state of the given kind.
    ///
    /// Queries up to 10 reliable neighbours concurrently and returns a
    /// [`PendingPersistentState`] for the first one that reports the state.
    /// Fails with [`Error::NotFound`] if nobody has it, or with the last
    /// query error if all queries failed.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        // Get reliable neighbours with higher weight
        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        // Shard and queue states are looked up via different RPCs.
        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        // Query all selected neighbours concurrently.
        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        // Keep only the last error; any single success short-circuits.
        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    // Accepting the handle marks the response as valid and
                    // returns the neighbour for follow-up chunk downloads.
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
428
    /// Downloads a persistent state previously located via
    /// [`Self::find_persistent_state`], streaming the decompressed bytes into `output`.
    ///
    /// Chunks are fetched from the neighbour recorded in `state`, each with up
    /// to the configured number of retries; the writer is flushed at the end
    /// and returned on success.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_compressed(
            state.size,
            state.chunk_size,
            output,
            // Chunk producer: builds the kind-specific RPC and downloads with retries.
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Chunk consumer: append each decompressed chunk to the writer.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: flush once everything has been written.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
484
485    pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
486        const NEIGHBOUR_COUNT: usize = 10;
487
488        // Get reliable neighbours with higher weight
489        let neighbours = self
490            .overlay_client()
491            .neighbours()
492            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
493
494        // Find a neighbour which has the requested archive
495        let pending_archive = 'info: {
496            let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
497
498            // Number of ArchiveInfo::TooNew responses
499            let mut new_archive_count = 0usize;
500
501            let mut futures = FuturesUnordered::new();
502            for neighbour in neighbours {
503                futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
504            }
505
506            let mut err = None;
507            while let Some(info) = futures.next().await {
508                let (handle, info) = match info {
509                    Ok(res) => res.split(),
510                    Err(e) => {
511                        err = Some(e);
512                        continue;
513                    }
514                };
515
516                match info {
517                    ArchiveInfo::Found {
518                        id,
519                        size,
520                        chunk_size,
521                    } => {
522                        break 'info PendingArchive {
523                            id,
524                            size,
525                            chunk_size,
526                            neighbour: handle.accept(),
527                        };
528                    }
529                    ArchiveInfo::TooNew => {
530                        new_archive_count += 1;
531
532                        handle.accept();
533                    }
534                    ArchiveInfo::NotFound => {
535                        handle.accept();
536                    }
537                }
538            }
539
540            // Stop using archives if enough neighbors responded TooNew
541            if new_archive_count >= self.inner.config.too_new_archive_threshold {
542                return Ok(PendingArchiveResponse::TooNew);
543            }
544
545            return match err {
546                None => Err(Error::NotFound),
547                Some(err) => Err(err),
548            };
549        };
550
551        tracing::info!(
552            peer_id = %pending_archive.neighbour.peer_id(),
553            archive_id = pending_archive.id,
554            archive_size = %ByteSize(pending_archive.size.get()),
555            archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
556            "found archive",
557        );
558        Ok(PendingArchiveResponse::Found(pending_archive))
559    }
560
    /// Downloads an archive previously located via [`Self::find_archive`],
    /// streaming the decompressed bytes into `output`.
    ///
    /// Every decompressed chunk is also fed to an [`ArchiveVerifier`]; its
    /// final check must pass before the flushed writer is returned.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_compressed(
            archive.size,
            archive.chunk_size,
            // State carries both the writer and the integrity verifier.
            (output, ArchiveVerifier::default()),
            // Chunk producer: download one chunk with retries and log timing.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Chunk consumer: verify first, then write.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: the archive must pass the final integrity check.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
620}
621
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    // Client configuration.
    config: BlockchainRpcClientConfig,
    // Overlay client used for all queries and broadcasts.
    overlay_client: PublicOverlayClient,
    // Optional local listener for self-broadcasted messages.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Rolling estimator of broadcast response times (ms), used to
    // compute an adaptive broadcast timeout.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
628
/// Result of searching neighbours for an archive.
pub enum PendingArchiveResponse {
    /// A neighbour has the archive; it can now be downloaded.
    Found(PendingArchive),
    /// Enough neighbours reported the requested block as too new.
    TooNew,
}
633
/// A located archive, ready to be downloaded from `neighbour`.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id.
    pub id: u64,
    /// Total archive size in bytes.
    pub size: NonZeroU64,
    /// Chunk size used when downloading, in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that reported having the archive.
    pub neighbour: Neighbour,
}
641
/// A located persistent state, ready to be downloaded from `neighbour`.
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Id of the block the state belongs to.
    pub block_id: BlockId,
    /// Kind of the persistent state (shard or queue).
    pub kind: PersistentStateKind,
    /// Total state size in bytes.
    pub size: NonZeroU64,
    /// Chunk size used when downloading, in bytes.
    pub chunk_size: NonZeroU32,
    /// The neighbour that reported having the state.
    pub neighbour: Neighbour,
}
650
/// Fully downloaded data for a single block.
pub struct BlockDataFull {
    /// Id of the block.
    pub block_id: BlockId,
    /// Block data.
    pub block_data: Bytes,
    /// Block proof data.
    pub proof_data: Bytes,
    /// Queue diff data.
    pub queue_diff_data: Bytes,
}
657
/// Block data (if the neighbour had it) together with the queried neighbour.
pub struct BlockDataFullWithNeighbour {
    /// Downloaded block data, or `None` if the neighbour did not have it.
    pub data: Option<BlockDataFull>,
    /// The neighbour the request was sent to.
    pub neighbour: Neighbour,
}
662
/// How strictly the requested data is expected to be present on a neighbour,
/// and how the neighbour is treated when it is missing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Data is not required to be present on the neighbour (mostly for polling).
    ///
    /// NOTE: Node will not be punished if the data is not present.
    Optional,
    /// We assume that the node has the data, but it's not required.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Dumb`] if the data is not present.
    Expected,
    /// Data must be present on the requested neighbour.
    ///
    /// NOTE: Node will be punished as [`PunishReason::Malicious`] if the data is not present.
    Required,
}
678
/// Sends a "block full" query and, if the block is present, downloads the
/// remaining compressed block-data chunks and decompresses them.
///
/// When the neighbour does not have the block, the outcome depends on
/// `requirement` (accept / reject / punish), and `data` is `None` in the
/// returned value. Otherwise the first chunk comes inline in the response
/// and the rest are fetched concurrently (up to 10 in flight) while a
/// blocking task performs streaming zstd decompression.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // Block is missing: grade the neighbour according to how strongly
        // we expected the data to be there.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The inline first chunk must not exceed the declared total or chunk size.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    // Bounded channel provides back-pressure between the downloader and
    // the blocking decompression task.
    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        // Buffer for decompressed data
        let mut decompressed = Vec::new();

        // Decompress chunk
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        // Receive and process chunks
        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response by default; only an explicitly defused
            // guard (after successful processing) accepts it.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            // Decompress chunk
            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Remaining chunks start after the inline first chunk, one request per
    // `chunk_size` stride, at most PARALLEL_REQUESTS in flight.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // The receiver side closing means the processing task has exited.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing task finishes its loop.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
814
/// Downloads `target_size` bytes of zstd-compressed data in `chunk_size`
/// pieces, decompressing and processing them in a blocking task.
///
/// - `download_fn(offset)` produces a future yielding one compressed chunk;
///   up to 10 requests run concurrently, results are consumed in order.
/// - `process_fn(&mut state, chunk)` is called with each decompressed chunk.
/// - `finalize_fn(state)` is called once the declared size has been received
///   exactly, producing the final value.
///
/// Each chunk's response handle is rejected unless processing succeeds.
async fn download_compressed<S, T, DF, DFut, PF, FF>(
    target_size: NonZeroU64,
    chunk_size: NonZeroU32,
    mut state: S,
    mut download_fn: DF,
    mut process_fn: PF,
    finalize_fn: FF,
) -> Result<T, Error>
where
    S: Send + 'static,
    T: Send + 'static,
    DF: FnMut(u64) -> DFut,
    DFut: Future<Output = DownloadedChunkResult> + Send + 'static,
    PF: FnMut(&mut S, &[u8]) -> Result<()> + Send + 'static,
    FF: FnOnce(S) -> Result<T> + Send + 'static,
{
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = target_size.get();
    let chunk_size = chunk_size.get() as usize;

    // Bounded channel provides back-pressure between the downloader and
    // the blocking decompression task.
    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size)?;

        // Reuse buffer for decompressed data
        let mut decompressed_chunk = Vec::new();

        // Receive and process chunks
        let mut downloaded = 0;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response by default; only an explicitly defused
            // guard (after successful processing) accepts it.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size, "received invalid chunk");

            downloaded += chunk.len() as u64;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded),
                "got chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            decompressed_chunk.clear();
            zstd_decoder.write(chunk.as_ref(), &mut decompressed_chunk)?;

            process_fn(&mut state, &decompressed_chunk)?;

            ScopeGuard::into_inner(guard).accept(); // defuse the guard
        }

        anyhow::ensure!(
            target_size == downloaded,
            "size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        finalize_fn(state)
    });

    // One request per `chunk_size` stride, at most PARALLEL_REQUESTS in flight;
    // `buffered` preserves the offset order of results.
    let stream = futures_util::stream::iter((0..target_size).step_by(chunk_size))
        .map(|offset| JoinTask::new(download_fn(offset)))
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // The receiver side closing means the processing task has exited.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing task finishes its loop.
    drop(chunks_tx);

    let output = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map_err(Error::Internal)?;

    Ok(output)
}
901
902async fn download_with_retries(
903    req: Request,
904    overlay_client: PublicOverlayClient,
905    neighbour: Neighbour,
906    max_retries: usize,
907    name: &'static str,
908) -> DownloadedChunkResult {
909    let mut retries = 0;
910    loop {
911        match overlay_client
912            .query_raw::<Data>(neighbour.clone(), req.clone())
913            .await
914        {
915            Ok(r) => {
916                let (h, res) = r.split();
917                return Ok((h, res.data));
918            }
919            Err(e) => {
920                tracing::error!("failed to download {name}: {e}");
921                retries += 1;
922                if retries >= max_retries || !neighbour.is_reliable() {
923                    return Err(e);
924                }
925
926                tokio::time::sleep(Duration::from_millis(100)).await;
927            }
928        }
929    }
930}
931
932type DownloadedChunkResult = Result<(QueryResponseHandle, Bytes), Error>;
933
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    /// End-to-end check of `download_compressed`: compress random data,
    /// serve it back in fixed-size chunks, and verify the decompressed
    /// result matches the original.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        let mut original_data = vec![0u8; 1 << 20]; // 1 MB of garbage
        rand::rng().fill_bytes(&mut original_data);

        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        let received = download_compressed(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            Vec::new(),
            // Fake downloader: slice the compressed buffer at the requested offset.
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok((handle, chunk)))
            },
            // Accumulate decompressed chunks into a single buffer.
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}
981}