1use std::future::Future;
2use std::io::Write;
3use std::num::{NonZeroU32, NonZeroU64};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use bytes::Bytes;
9use bytesize::ByteSize;
10use futures_util::stream::{FuturesUnordered, StreamExt};
11use parking_lot::Mutex;
12use scopeguard::ScopeGuard;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tycho_block_util::archive::ArchiveVerifier;
16use tycho_network::{PublicOverlay, Request};
17use tycho_types::models::BlockId;
18use tycho_util::compression::ZstdDecompressStream;
19use tycho_util::futures::JoinTask;
20use tycho_util::serde_helpers;
21
22use crate::overlay_client::{
23 Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
24};
25use crate::proto::blockchain::*;
26use crate::proto::overlay::BroadcastPrefix;
27use crate::storage::PersistentStateKind;
28
/// Observer for messages this node itself broadcasts through
/// [`BlockchainRpcClient::broadcast_external_message`].
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    /// Called with a copy of the message before it is sent to any broadcast
    /// target; the broadcast awaits this call before proceeding.
    async fn handle_message(&self, message: Bytes);
}
36
/// Tuning parameters for [`BlockchainRpcClient`].
///
/// Fields missing during deserialization fall back to the [`Default`] values
/// (`#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Lower bound of the external-message broadcast timeout.
    ///
    /// The upper bound is the larger of this value and the overlay client's
    /// validator `send_timeout`. (De)serialized with `serde_helpers::humantime`.
    ///
    /// Default: 100ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// When no neighbour has the requested archive, the minimum number of
    /// `TooNew` answers needed for the archive lookup to report
    /// [`PendingArchiveResponse::TooNew`] instead of an error.
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Attempt limit for each chunk download (block data, persistent state and
    /// archive chunks): a chunk is requested at most this many times (always at
    /// least once), and fewer if the neighbour is not considered reliable.
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
59impl Default for BlockchainRpcClientConfig {
60 fn default() -> Self {
61 Self {
62 min_broadcast_timeout: Duration::from_millis(100),
63 too_new_archive_threshold: 4,
64 download_retries: 10,
65 }
66 }
67}
68
/// Builder for [`BlockchainRpcClient`].
///
/// The `MandatoryFields` parameter records whether the required overlay client
/// has been supplied: it is `()` when created by [`BlockchainRpcClient::builder`]
/// and becomes [`PublicOverlayClient`] after `with_public_overlay_client`; only
/// that latter state provides `build`.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    // `()` until the overlay client is provided, then the `PublicOverlayClient`.
    mandatory_fields: MandatoryFields,
    // Optional observer that gets a copy of every message passed to
    // `broadcast_external_message`.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
75impl BlockchainRpcClientBuilder<PublicOverlayClient> {
76 pub fn build(self) -> BlockchainRpcClient {
77 BlockchainRpcClient {
78 inner: Arc::new(Inner {
79 config: self.config,
80 overlay_client: self.mandatory_fields,
81 broadcast_listener: self.broadcast_listener,
82 response_tracker: Mutex::new(
83 tycho_util::time::RollingP2Estimator::new_with_config(
85 0.75, Duration::from_secs(60),
87 5,
88 tycho_util::time::RealClock,
89 )
90 .expect("correct quantile"),
91 ),
92 }),
93 }
94 }
95}
96
97impl BlockchainRpcClientBuilder<()> {
98 pub fn with_public_overlay_client(
99 self,
100 client: PublicOverlayClient,
101 ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
102 BlockchainRpcClientBuilder {
103 config: self.config,
104 mandatory_fields: client,
105 broadcast_listener: self.broadcast_listener,
106 }
107 }
108}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111 pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112 self.broadcast_listener = Some(Box::new(listener));
113 self
114 }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118 pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119 Self { config, ..self }
120 }
121}
122
/// Client for blockchain RPC queries, downloads and broadcasts over a public
/// overlay.
///
/// Cloning is cheap: all state is shared behind an `Arc`.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
130 pub fn builder() -> BlockchainRpcClientBuilder<()> {
131 BlockchainRpcClientBuilder {
132 config: Default::default(),
133 mandatory_fields: (),
134 broadcast_listener: None,
135 }
136 }
137
138 pub fn overlay(&self) -> &PublicOverlay {
139 self.inner.overlay_client.overlay()
140 }
141
142 pub fn overlay_client(&self) -> &PublicOverlayClient {
143 &self.inner.overlay_client
144 }
145
146 pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
150 struct ExternalMessage<'a> {
151 data: &'a [u8],
152 }
153
154 impl tl_proto::TlWrite for ExternalMessage<'_> {
155 type Repr = tl_proto::Boxed;
156
157 fn max_size_hint(&self) -> usize {
158 4 + MessageBroadcastRef { data: self.data }.max_size_hint()
159 }
160
161 fn write_to<P>(&self, packet: &mut P)
162 where
163 P: tl_proto::TlPacket,
164 {
165 packet.write_u32(BroadcastPrefix::TL_ID);
166 MessageBroadcastRef { data: self.data }.write_to(packet);
167 }
168 }
169
170 if let Some(l) = &self.inner.broadcast_listener {
172 l.handle_message(Bytes::copy_from_slice(message)).await;
173 }
174
175 let client = &self.inner.overlay_client;
176
177 let mut delivered_to = 0;
178
179 let targets = client.get_broadcast_targets();
180 let request = Request::from_tl(ExternalMessage { data: message });
181 let mut futures = FuturesUnordered::new();
182
183 for validator in targets.as_ref() {
187 let client = client.clone();
188 let validator = validator.clone();
189 let request = request.clone();
190 let this = self.inner.clone();
191 futures.push(tokio::spawn(async move {
193 let start = Instant::now();
194 let res = client.send_to_validator(validator, request).await;
195 this.response_tracker
196 .lock()
197 .append(start.elapsed().as_millis() as i64);
198 res
199 }));
200 }
201
202 let timeout = self.compute_broadcast_timeout();
203 tokio::time::timeout(timeout, async {
204 while let Some(Ok(res)) = futures.next().await {
206 if let Err(e) = res {
207 tracing::warn!("failed to broadcast external message: {e}");
208 } else {
209 delivered_to += 1;
210 }
211 }
212 })
213 .await
214 .ok();
215
216 if delivered_to == 0 {
217 tracing::debug!("message was not delivered to any peer");
218 }
219
220 delivered_to
221 }
222
223 fn compute_broadcast_timeout(&self) -> Duration {
224 let max_broadcast_timeout = std::cmp::max(
225 self.inner.overlay_client.config().validators.send_timeout,
226 self.inner.config.min_broadcast_timeout,
227 );
228
229 if let Some(prev_time) = self
230 .inner
231 .response_tracker
232 .lock()
233 .exponentially_weighted_average()
234 .map(|x| Duration::from_millis(x as _))
235 {
236 metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
237 let value = prev_time.clamp(
238 self.inner.config.min_broadcast_timeout,
239 max_broadcast_timeout,
240 );
241 metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
242 value
243 } else {
244 max_broadcast_timeout
245 }
246 }
247
248 pub async fn get_next_key_block_ids(
249 &self,
250 block: &BlockId,
251 max_size: u32,
252 ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253 let client = &self.inner.overlay_client;
254 let data = client
255 .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256 block_id: *block,
257 max_size,
258 })
259 .await?;
260 Ok(data)
261 }
262
263 #[tracing::instrument(skip_all, fields(
264 block_id = %block.as_short_id(),
265 requirement = ?requirement,
266 ))]
267 pub async fn get_block_full(
268 &self,
269 block: &BlockId,
270 requirement: DataRequirement,
271 ) -> Result<BlockDataFullWithNeighbour, Error> {
272 let overlay_client = self.inner.overlay_client.clone();
273
274 let Some(neighbour) = overlay_client.neighbours().choose() else {
275 return Err(Error::NoNeighbours);
276 };
277
278 let retries = self.inner.config.download_retries;
279
280 download_block_inner(
281 Request::from_tl(rpc::GetBlockFull { block_id: *block }),
282 overlay_client,
283 neighbour,
284 requirement,
285 retries,
286 )
287 .await
288 }
289
290 pub async fn get_next_block_full(
291 &self,
292 prev_block: &BlockId,
293 requirement: DataRequirement,
294 ) -> Result<BlockDataFullWithNeighbour, Error> {
295 let overlay_client = self.inner.overlay_client.clone();
296
297 let Some(neighbour) = overlay_client.neighbours().choose() else {
298 return Err(Error::NoNeighbours);
299 };
300
301 let retries = self.inner.config.download_retries;
302
303 download_block_inner(
304 Request::from_tl(rpc::GetNextBlockFull {
305 prev_block_id: *prev_block,
306 }),
307 overlay_client,
308 neighbour,
309 requirement,
310 retries,
311 )
312 .await
313 }
314
315 pub async fn get_key_block_proof(
316 &self,
317 block_id: &BlockId,
318 ) -> Result<QueryResponse<KeyBlockProof>, Error> {
319 let client = &self.inner.overlay_client;
320 let data = client
321 .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
322 block_id: *block_id,
323 })
324 .await?;
325 Ok(data)
326 }
327
328 pub async fn get_persistent_state_info(
329 &self,
330 block_id: &BlockId,
331 ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
332 let client = &self.inner.overlay_client;
333 let data = client
334 .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
335 block_id: *block_id,
336 })
337 .await?;
338 Ok(data)
339 }
340
341 pub async fn get_persistent_state_part(
342 &self,
343 neighbour: &Neighbour,
344 block_id: &BlockId,
345 offset: u64,
346 ) -> Result<QueryResponse<Data>, Error> {
347 let client = &self.inner.overlay_client;
348 let data = client
349 .query_raw::<Data>(
350 neighbour.clone(),
351 Request::from_tl(rpc::GetPersistentShardStateChunk {
352 block_id: *block_id,
353 offset,
354 }),
355 )
356 .await?;
357 Ok(data)
358 }
359
360 pub async fn find_persistent_state(
361 &self,
362 block_id: &BlockId,
363 kind: PersistentStateKind,
364 ) -> Result<PendingPersistentState, Error> {
365 const NEIGHBOUR_COUNT: usize = 10;
366
367 let neighbours = self
369 .overlay_client()
370 .neighbours()
371 .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
372
373 let req = match kind {
374 PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
375 block_id: *block_id,
376 }),
377 PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
378 block_id: *block_id,
379 }),
380 };
381
382 let mut futures = FuturesUnordered::new();
383 for neighbour in neighbours {
384 futures.push(
385 self.overlay_client()
386 .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
387 );
388 }
389
390 let mut err = None;
391 while let Some(info) = futures.next().await {
392 let (handle, info) = match info {
393 Ok(res) => res.split(),
394 Err(e) => {
395 err = Some(e);
396 continue;
397 }
398 };
399
400 match info {
401 PersistentStateInfo::Found { size, chunk_size } => {
402 let neighbour = handle.accept();
403 tracing::debug!(
404 peer_id = %neighbour.peer_id(),
405 state_size = size.get(),
406 state_chunk_size = chunk_size.get(),
407 ?kind,
408 "found persistent state",
409 );
410
411 return Ok(PendingPersistentState {
412 block_id: *block_id,
413 kind,
414 size,
415 chunk_size,
416 neighbour,
417 });
418 }
419 PersistentStateInfo::NotFound => {}
420 }
421 }
422
423 match err {
424 None => Err(Error::NotFound),
425 Some(err) => Err(err),
426 }
427 }
428
429 #[tracing::instrument(skip_all, fields(
430 peer_id = %state.neighbour.peer_id(),
431 block_id = %state.block_id,
432 kind = ?state.kind,
433 ))]
434 pub async fn download_persistent_state<W>(
435 &self,
436 state: PendingPersistentState,
437 output: W,
438 ) -> Result<W, Error>
439 where
440 W: Write + Send + 'static,
441 {
442 tracing::debug!("started");
443 scopeguard::defer! {
444 tracing::debug!("finished");
445 }
446
447 let block_id = state.block_id;
448 let max_retries = self.inner.config.download_retries;
449
450 download_compressed(
451 state.size,
452 state.chunk_size,
453 output,
454 |offset| {
455 tracing::debug!("downloading persistent state chunk");
456
457 let req = match state.kind {
458 PersistentStateKind::Shard => {
459 Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
460 }
461 PersistentStateKind::Queue => {
462 Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
463 }
464 };
465 download_with_retries(
466 req,
467 self.overlay_client().clone(),
468 state.neighbour.clone(),
469 max_retries,
470 "persistent state chunk",
471 )
472 },
473 |output, chunk| {
474 output.write_all(chunk)?;
475 Ok(())
476 },
477 |mut output| {
478 output.flush()?;
479 Ok(output)
480 },
481 )
482 .await
483 }
484
485 pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
486 const NEIGHBOUR_COUNT: usize = 10;
487
488 let neighbours = self
490 .overlay_client()
491 .neighbours()
492 .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
493
494 let pending_archive = 'info: {
496 let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
497
498 let mut new_archive_count = 0usize;
500
501 let mut futures = FuturesUnordered::new();
502 for neighbour in neighbours {
503 futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
504 }
505
506 let mut err = None;
507 while let Some(info) = futures.next().await {
508 let (handle, info) = match info {
509 Ok(res) => res.split(),
510 Err(e) => {
511 err = Some(e);
512 continue;
513 }
514 };
515
516 match info {
517 ArchiveInfo::Found {
518 id,
519 size,
520 chunk_size,
521 } => {
522 break 'info PendingArchive {
523 id,
524 size,
525 chunk_size,
526 neighbour: handle.accept(),
527 };
528 }
529 ArchiveInfo::TooNew => {
530 new_archive_count += 1;
531
532 handle.accept();
533 }
534 ArchiveInfo::NotFound => {
535 handle.accept();
536 }
537 }
538 }
539
540 if new_archive_count >= self.inner.config.too_new_archive_threshold {
542 return Ok(PendingArchiveResponse::TooNew);
543 }
544
545 return match err {
546 None => Err(Error::NotFound),
547 Some(err) => Err(err),
548 };
549 };
550
551 tracing::info!(
552 peer_id = %pending_archive.neighbour.peer_id(),
553 archive_id = pending_archive.id,
554 archive_size = %ByteSize(pending_archive.size.get()),
555 archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
556 "found archive",
557 );
558 Ok(PendingArchiveResponse::Found(pending_archive))
559 }
560
561 #[tracing::instrument(skip_all, fields(
562 peer_id = %archive.neighbour.peer_id(),
563 archive_id = archive.id,
564 ))]
565 pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
566 where
567 W: Write + Send + 'static,
568 {
569 use futures_util::FutureExt;
570
571 tracing::debug!("started");
572 scopeguard::defer! {
573 tracing::debug!("finished");
574 }
575
576 let retries = self.inner.config.download_retries;
577
578 download_compressed(
579 archive.size,
580 archive.chunk_size,
581 (output, ArchiveVerifier::default()),
582 |offset| {
583 let archive_id = archive.id;
584 let neighbour = archive.neighbour.clone();
585 let overlay_client = self.overlay_client().clone();
586
587 let started_at = Instant::now();
588
589 tracing::debug!(archive_id, offset, "downloading archive chunk");
590 download_with_retries(
591 Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
592 overlay_client,
593 neighbour,
594 retries,
595 "archive chunk",
596 )
597 .map(move |res| {
598 tracing::info!(
599 archive_id,
600 offset,
601 elapsed = %humantime::format_duration(started_at.elapsed()),
602 "downloaded archive chunk",
603 );
604 res
605 })
606 },
607 |(output, verifier), chunk| {
608 verifier.write_verify(chunk)?;
609 output.write_all(chunk)?;
610 Ok(())
611 },
612 |(mut output, verifier)| {
613 verifier.final_check()?;
614 output.flush()?;
615 Ok(output)
616 },
617 )
618 .await
619 }
620}
621
/// Shared state behind [`BlockchainRpcClient`]'s `Arc`.
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    // Receives a copy of each message before it is broadcast.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Round-trip times (milliseconds) of broadcast sends; its exponentially
    // weighted average drives the adaptive broadcast timeout.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
628
/// Outcome of [`BlockchainRpcClient::find_archive`].
pub enum PendingArchiveResponse {
    /// A neighbour reported having the archive.
    Found(PendingArchive),
    /// No neighbour had the archive, and at least `too_new_archive_threshold`
    /// of them answered that the requested seqno is too new.
    TooNew,
}
633
/// An archive located by [`BlockchainRpcClient::find_archive`], ready to be
/// fetched with [`BlockchainRpcClient::download_archive`].
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id as reported by the neighbour; sent back in chunk requests.
    pub id: u64,
    /// Total size in bytes of the compressed archive data to download.
    pub size: NonZeroU64,
    /// Maximum chunk length; chunk offsets are multiples of this value.
    pub chunk_size: NonZeroU32,
    /// Neighbour that reported the archive; all chunks are requested from it.
    pub neighbour: Neighbour,
}
641
/// A persistent state located by [`BlockchainRpcClient::find_persistent_state`],
/// ready to be fetched with [`BlockchainRpcClient::download_persistent_state`].
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Block the state belongs to.
    pub block_id: BlockId,
    /// Whether this is a shard state or a queue state.
    pub kind: PersistentStateKind,
    /// Total size in bytes of the compressed state data to download.
    pub size: NonZeroU64,
    /// Maximum chunk length; chunk offsets are multiples of this value.
    pub chunk_size: NonZeroU32,
    /// Neighbour that reported the state; all chunks are requested from it.
    pub neighbour: Neighbour,
}
650
/// A fully downloaded block together with its proof and queue diff.
pub struct BlockDataFull {
    pub block_id: BlockId,
    /// Block data, zstd-decompressed and assembled from all chunks.
    pub block_data: Bytes,
    /// Proof bytes as returned in the initial `BlockFull::Found` response.
    pub proof_data: Bytes,
    /// Queue diff bytes as returned in the initial `BlockFull::Found` response.
    pub queue_diff_data: Bytes,
}
657
/// Result of a full-block download plus the neighbour that was queried.
pub struct BlockDataFullWithNeighbour {
    /// `None` if the neighbour answered that it does not have the block.
    pub data: Option<BlockDataFull>,
    /// The queried neighbour (presumably so callers can rate it after
    /// validating the data — confirm with callers).
    pub neighbour: Neighbour,
}
662
/// How a "block not found" answer is treated when downloading a full block.
///
/// In every case the download returns `Ok` with `data: None`; the variants only
/// differ in how the responding neighbour is rated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// The block may legitimately be missing: the response is accepted.
    Optional,
    /// The block should be present: the response is rejected.
    Expected,
    /// The block must be present: the neighbour is punished as malicious.
    Required,
}
678
/// Downloads a full block from `neighbour` using `req` (a `GetBlockFull` or
/// `GetNextBlockFull` request).
///
/// The initial response carries the block's proof, its queue diff and the first
/// chunk of the compressed block data. If the block is not found, the neighbour
/// is rated according to `requirement` and `Ok` with `data: None` is returned.
/// Otherwise the remaining chunks are requested from the same neighbour (up to
/// `PARALLEL_REQUESTS` in flight, each attempted up to `retries` times). They
/// are streamed in order into a blocking task that validates sizes and
/// zstd-decompresses everything into a single buffer.
///
/// # Errors
/// - the initial query fails;
/// - the first chunk is larger than the declared total size or chunk size;
/// - a chunk download still fails after its retries (the processing task's
///   result is then discarded);
/// - the assembled data is malformed (oversized chunk, too many bytes, total
///   size mismatch, decompression failure) or the blocking task cannot be
///   joined; these are reported as [`Error::Internal`].
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    // NOTE(review): on the `Found` path the initial response `handle` is dropped
    // without an explicit accept/reject — confirm this is intended.

    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The first chunk arrives inline with the initial response; sanity-check it.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    // Bounded channel: downloads cannot run more than `PARALLEL_REQUESTS` chunks
    // ahead of decompression.
    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        let mut decompressed = Vec::new();

        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Rejects the response on any early exit; disarmed and accepted only
            // once the chunk has been fully processed.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept();
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Offsets start at `chunk_size`: the range `[0, chunk_size)` was delivered
    // with the initial response. `buffered` keeps results in offset order.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // A send error means the processing task already exited (e.g. with a
        // validation error); its result is collected below.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing loop terminates.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
814
815async fn download_compressed<S, T, DF, DFut, PF, FF>(
816 target_size: NonZeroU64,
817 chunk_size: NonZeroU32,
818 mut state: S,
819 mut download_fn: DF,
820 mut process_fn: PF,
821 finalize_fn: FF,
822) -> Result<T, Error>
823where
824 S: Send + 'static,
825 T: Send + 'static,
826 DF: FnMut(u64) -> DFut,
827 DFut: Future<Output = DownloadedChunkResult> + Send + 'static,
828 PF: FnMut(&mut S, &[u8]) -> Result<()> + Send + 'static,
829 FF: FnOnce(S) -> Result<T> + Send + 'static,
830{
831 const PARALLEL_REQUESTS: usize = 10;
832
833 let target_size = target_size.get();
834 let chunk_size = chunk_size.get() as usize;
835
836 let (chunks_tx, mut chunks_rx) =
837 mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);
838
839 let span = tracing::Span::current();
840 let processing_task = tokio::task::spawn_blocking(move || {
841 let _span = span.enter();
842
843 let mut zstd_decoder = ZstdDecompressStream::new(chunk_size)?;
844
845 let mut decompressed_chunk = Vec::new();
847
848 let mut downloaded = 0;
850 while let Some((h, chunk)) = chunks_rx.blocking_recv() {
851 let guard = scopeguard::guard(h, |handle| {
852 handle.reject();
853 });
854
855 anyhow::ensure!(chunk.len() <= chunk_size, "received invalid chunk");
856
857 downloaded += chunk.len() as u64;
858 tracing::debug!(
859 downloaded = %bytesize::ByteSize::b(downloaded),
860 "got chunk"
861 );
862
863 anyhow::ensure!(downloaded <= target_size, "received too many chunks");
864
865 decompressed_chunk.clear();
866 zstd_decoder.write(chunk.as_ref(), &mut decompressed_chunk)?;
867
868 process_fn(&mut state, &decompressed_chunk)?;
869
870 ScopeGuard::into_inner(guard).accept(); }
872
873 anyhow::ensure!(
874 target_size == downloaded,
875 "size mismatch (target size: {target_size}; downloaded: {downloaded})",
876 );
877
878 finalize_fn(state)
879 });
880
881 let stream = futures_util::stream::iter((0..target_size).step_by(chunk_size))
882 .map(|offset| JoinTask::new(download_fn(offset)))
883 .buffered(PARALLEL_REQUESTS);
884
885 let mut stream = std::pin::pin!(stream);
886 while let Some(chunk) = stream.next().await.transpose()? {
887 if chunks_tx.send(chunk).await.is_err() {
888 break;
889 }
890 }
891
892 drop(chunks_tx);
893
894 let output = processing_task
895 .await
896 .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
897 .map_err(Error::Internal)?;
898
899 Ok(output)
900}
901
902async fn download_with_retries(
903 req: Request,
904 overlay_client: PublicOverlayClient,
905 neighbour: Neighbour,
906 max_retries: usize,
907 name: &'static str,
908) -> DownloadedChunkResult {
909 let mut retries = 0;
910 loop {
911 match overlay_client
912 .query_raw::<Data>(neighbour.clone(), req.clone())
913 .await
914 {
915 Ok(r) => {
916 let (h, res) = r.split();
917 return Ok((h, res.data));
918 }
919 Err(e) => {
920 tracing::error!("failed to download {name}: {e}");
921 retries += 1;
922 if retries >= max_retries || !neighbour.is_reliable() {
923 return Err(e);
924 }
925
926 tokio::time::sleep(Duration::from_millis(100)).await;
927 }
928 }
929 }
930}
931
/// Outcome of downloading one raw chunk: the response handle (accepted or
/// rejected once the chunk has been processed) and the chunk bytes.
type DownloadedChunkResult = Result<(QueryResponseHandle, Bytes), Error>;
933
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    /// Round-trips 1 MiB of random data through zstd and `download_compressed`,
    /// serving 128-byte chunks from memory.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        const CHUNK_SIZE: usize = 128;

        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        // Random bytes barely compress, so the stream spans many chunks.
        let mut original_data = vec![0u8; 1 << 20];
        rand::rng().fill_bytes(&mut original_data);

        let compressed_data = {
            let mut buf = Vec::new();
            zstd_compress(&original_data, &mut buf, 9);
            Bytes::from(buf)
        };
        assert_ne!(compressed_data, original_data);

        let total_len = compressed_data.len();

        let received = download_compressed(
            NonZeroU64::new(total_len as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            Vec::new(),
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < total_len as u64);

                let start = offset as usize;
                let end = std::cmp::min(start + CHUNK_SIZE, total_len);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok((handle, compressed_data.slice(start..end))))
            },
            |acc, chunk| {
                acc.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;

        assert_eq!(received, original_data);
        Ok(())
    }
}
981}