1use std::future::Future;
2use std::io::Write;
3use std::num::{NonZeroU32, NonZeroU64};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8use bytes::Bytes;
9use bytesize::ByteSize;
10use futures_util::stream::{FuturesUnordered, StreamExt};
11use parking_lot::Mutex;
12use scopeguard::ScopeGuard;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tycho_block_util::archive::ArchiveVerifier;
16use tycho_network::{PublicOverlay, Request};
17use tycho_types::models::BlockId;
18use tycho_util::compression::ZstdDecompressStream;
19use tycho_util::futures::JoinTask;
20use tycho_util::serde_helpers;
21
22use crate::overlay_client::{
23 Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
24};
25use crate::proto::blockchain::*;
26use crate::proto::overlay::BroadcastPrefix;
27use crate::storage::PersistentStateKind;
28
/// Listener that receives a local copy of every external message broadcast
/// via [`BlockchainRpcClient::broadcast_external_message`], before the
/// message is sent to the network.
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Lower bound for the adaptive broadcast timeout.
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Number of `TooNew` replies after which `find_archive` reports the
    /// requested archive as not built yet.
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Maximum number of attempts when downloading a single chunk.
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
59impl Default for BlockchainRpcClientConfig {
60 fn default() -> Self {
61 Self {
62 min_broadcast_timeout: Duration::from_millis(100),
63 too_new_archive_threshold: 4,
64 download_retries: 10,
65 }
66 }
67}
68
/// Builder for [`BlockchainRpcClient`].
///
/// `MandatoryFields` tracks at the type level whether the required
/// [`PublicOverlayClient`] has been supplied: `()` means "not yet";
/// `PublicOverlayClient` means `build()` is available.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
75impl BlockchainRpcClientBuilder<PublicOverlayClient> {
76 pub fn build(self) -> BlockchainRpcClient {
77 BlockchainRpcClient {
78 inner: Arc::new(Inner {
79 config: self.config,
80 overlay_client: self.mandatory_fields,
81 broadcast_listener: self.broadcast_listener,
82 response_tracker: Mutex::new(
83 tycho_util::time::RollingP2Estimator::new_with_config(
85 0.75, Duration::from_secs(60),
87 5,
88 tycho_util::time::RealClock,
89 )
90 .expect("correct quantile"),
91 ),
92 }),
93 }
94 }
95}
96
97impl BlockchainRpcClientBuilder<()> {
98 pub fn with_public_overlay_client(
99 self,
100 client: PublicOverlayClient,
101 ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
102 BlockchainRpcClientBuilder {
103 config: self.config,
104 mandatory_fields: client,
105 broadcast_listener: self.broadcast_listener,
106 }
107 }
108}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111 pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112 self.broadcast_listener = Some(Box::new(listener));
113 self
114 }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118 pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119 Self { config, ..self }
120 }
121}
122
/// Cheaply cloneable RPC client for the blockchain public overlay.
///
/// All clones share the same [`Inner`] state through the `Arc`.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
129impl BlockchainRpcClient {
130 pub fn builder() -> BlockchainRpcClientBuilder<()> {
131 BlockchainRpcClientBuilder {
132 config: Default::default(),
133 mandatory_fields: (),
134 broadcast_listener: None,
135 }
136 }
137
    /// Returns the underlying public overlay.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }
141
    /// Returns the overlay client used for all queries and broadcasts.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
    /// Broadcasts a raw external message to the current broadcast targets
    /// (validators) and returns the number of peers that acknowledged it.
    ///
    /// If a [`SelfBroadcastListener`] is registered it receives a copy of
    /// the message first. Sends run concurrently and the whole round is
    /// bounded by the adaptive timeout from `compute_broadcast_timeout`;
    /// each send's round-trip time is fed back into the response tracker.
    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
        // Serializes the payload as a `MessageBroadcastRef` prefixed with
        // the broadcast TL id, without copying the message bytes.
        struct ExternalMessage<'a> {
            data: &'a [u8],
        }

        impl tl_proto::TlWrite for ExternalMessage<'_> {
            type Repr = tl_proto::Boxed;

            fn max_size_hint(&self) -> usize {
                // 4 extra bytes for the `BroadcastPrefix` id written below.
                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
            }

            fn write_to<P>(&self, packet: &mut P)
            where
                P: tl_proto::TlPacket,
            {
                packet.write_u32(BroadcastPrefix::TL_ID);
                MessageBroadcastRef { data: self.data }.write_to(packet);
            }
        }

        // Deliver a copy to ourselves first, if a listener is attached.
        if let Some(l) = &self.inner.broadcast_listener {
            l.handle_message(Bytes::copy_from_slice(message)).await;
        }

        let client = &self.inner.overlay_client;

        let mut delivered_to = 0;

        let targets = client.get_broadcast_targets();
        let request = Request::from_tl(ExternalMessage { data: message });
        let mut futures = FuturesUnordered::new();

        // One spawned task per target so all sends proceed concurrently;
        // each task records its round-trip time in the response tracker.
        for validator in targets.as_ref() {
            let client = client.clone();
            let validator = validator.clone();
            let request = request.clone();
            let this = self.inner.clone();
            futures.push(tokio::spawn(async move {
                let start = Instant::now();
                let res = client.send_to_validator(validator, request).await;
                this.response_tracker
                    .lock()
                    .append(start.elapsed().as_millis() as i64);
                res
            }));
        }

        let timeout = self.compute_broadcast_timeout();
        tokio::time::timeout(timeout, async {
            // NOTE(review): `Some(Ok(res))` ends the loop on the first task
            // join error (e.g. a panicked send task), silently discarding
            // any remaining results — presumably acceptable for best-effort
            // delivery; confirm.
            while let Some(Ok(res)) = futures.next().await {
                if let Err(e) = res {
                    tracing::warn!("failed to broadcast external message: {e}");
                } else {
                    delivered_to += 1;
                }
            }
        })
        .await
        .ok();

        if delivered_to == 0 {
            tracing::debug!("message was not delivered to any peer");
        }

        delivered_to
    }
222
223 fn compute_broadcast_timeout(&self) -> Duration {
224 let max_broadcast_timeout = std::cmp::max(
225 self.inner.overlay_client.config().validators.send_timeout,
226 self.inner.config.min_broadcast_timeout,
227 );
228
229 if let Some(prev_time) = self
230 .inner
231 .response_tracker
232 .lock()
233 .exponentially_weighted_average()
234 .map(|x| Duration::from_millis(x as _))
235 {
236 metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
237 let value = prev_time.clamp(
238 self.inner.config.min_broadcast_timeout,
239 max_broadcast_timeout,
240 );
241 metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
242 value
243 } else {
244 max_broadcast_timeout
245 }
246 }
247
248 pub async fn get_next_key_block_ids(
249 &self,
250 block: &BlockId,
251 max_size: u32,
252 ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253 let client = &self.inner.overlay_client;
254 let data = client
255 .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256 block_id: *block,
257 max_size,
258 })
259 .await?;
260 Ok(data)
261 }
262
263 pub async fn get_block_full(
264 &self,
265 block: &BlockId,
266 requirement: DataRequirement,
267 ) -> Result<BlockDataFullWithNeighbour, Error> {
268 let overlay_client = self.inner.overlay_client.clone();
269
270 let Some(neighbour) = overlay_client.neighbours().choose() else {
271 return Err(Error::NoNeighbours);
272 };
273
274 let retries = self.inner.config.download_retries;
275
276 download_block_inner(
277 Request::from_tl(rpc::GetBlockFull { block_id: *block }),
278 overlay_client,
279 neighbour,
280 requirement,
281 retries,
282 )
283 .await
284 }
285
286 pub async fn get_next_block_full(
287 &self,
288 prev_block: &BlockId,
289 requirement: DataRequirement,
290 ) -> Result<BlockDataFullWithNeighbour, Error> {
291 let overlay_client = self.inner.overlay_client.clone();
292
293 let Some(neighbour) = overlay_client.neighbours().choose() else {
294 return Err(Error::NoNeighbours);
295 };
296
297 let retries = self.inner.config.download_retries;
298
299 download_block_inner(
300 Request::from_tl(rpc::GetNextBlockFull {
301 prev_block_id: *prev_block,
302 }),
303 overlay_client,
304 neighbour,
305 requirement,
306 retries,
307 )
308 .await
309 }
310
311 pub async fn get_key_block_proof(
312 &self,
313 block_id: &BlockId,
314 ) -> Result<QueryResponse<KeyBlockProof>, Error> {
315 let client = &self.inner.overlay_client;
316 let data = client
317 .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
318 block_id: *block_id,
319 })
320 .await?;
321 Ok(data)
322 }
323
324 pub async fn get_persistent_state_info(
325 &self,
326 block_id: &BlockId,
327 ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
328 let client = &self.inner.overlay_client;
329 let data = client
330 .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
331 block_id: *block_id,
332 })
333 .await?;
334 Ok(data)
335 }
336
337 pub async fn get_persistent_state_part(
338 &self,
339 neighbour: &Neighbour,
340 block_id: &BlockId,
341 offset: u64,
342 ) -> Result<QueryResponse<Data>, Error> {
343 let client = &self.inner.overlay_client;
344 let data = client
345 .query_raw::<Data>(
346 neighbour.clone(),
347 Request::from_tl(rpc::GetPersistentShardStateChunk {
348 block_id: *block_id,
349 offset,
350 }),
351 )
352 .await?;
353 Ok(data)
354 }
355
    /// Asks up to 10 reliable neighbours whether they hold the persistent
    /// state of `kind` for `block_id`.
    ///
    /// Returns a [`PendingPersistentState`] for the first peer reporting
    /// the state as found; otherwise propagates the last query error, or
    /// [`Error::NotFound`] if every reachable peer answered `NotFound`.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        // Request type depends on whether a shard or queue state is wanted.
        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        // Query all chosen neighbours concurrently; take the first hit.
        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    // Remember the error but keep waiting for other peers.
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    // Accept the response and keep the neighbour for the
                    // follow-up chunk downloads.
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
424
    /// Streams the persistent state described by `state` from its neighbour,
    /// decompressing zstd chunks and writing the result into `output`.
    ///
    /// Chunks are requested at increasing offsets with up to
    /// `download_retries` attempts each; `output` is flushed at the end.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_compressed(
            state.size,
            state.chunk_size,
            output,
            // Download: fetch the chunk at `offset`; the request kind
            // depends on whether this is a shard or queue state.
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Process: append each decompressed chunk to the writer.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalize: flush buffered bytes and hand the writer back.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
480
481 pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
482 const NEIGHBOUR_COUNT: usize = 10;
483
484 let neighbours = self
486 .overlay_client()
487 .neighbours()
488 .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
489
490 let pending_archive = 'info: {
492 let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
493
494 let mut new_archive_count = 0usize;
496
497 let mut futures = FuturesUnordered::new();
498 for neighbour in neighbours {
499 futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
500 }
501
502 let mut err = None;
503 while let Some(info) = futures.next().await {
504 let (handle, info) = match info {
505 Ok(res) => res.split(),
506 Err(e) => {
507 err = Some(e);
508 continue;
509 }
510 };
511
512 match info {
513 ArchiveInfo::Found {
514 id,
515 size,
516 chunk_size,
517 } => {
518 break 'info PendingArchive {
519 id,
520 size,
521 chunk_size,
522 neighbour: handle.accept(),
523 };
524 }
525 ArchiveInfo::TooNew => {
526 new_archive_count += 1;
527
528 handle.accept();
529 }
530 ArchiveInfo::NotFound => {
531 handle.accept();
532 }
533 }
534 }
535
536 if new_archive_count >= self.inner.config.too_new_archive_threshold {
538 return Ok(PendingArchiveResponse::TooNew);
539 }
540
541 return match err {
542 None => Err(Error::NotFound),
543 Some(err) => Err(err),
544 };
545 };
546
547 tracing::info!(
548 peer_id = %pending_archive.neighbour.peer_id(),
549 archive_id = pending_archive.id,
550 archive_size = %ByteSize(pending_archive.size.get()),
551 archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
552 "found archive",
553 );
554 Ok(PendingArchiveResponse::Found(pending_archive))
555 }
556
    /// Streams the archive described by `archive` from its neighbour into
    /// `output`, decompressing zstd chunks and verifying the archive
    /// structure on the fly; `output` is flushed after the final check.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_compressed(
            archive.size,
            archive.chunk_size,
            // State: the writer plus a verifier checking the archive layout.
            (output, ArchiveVerifier::default()),
            // Download: fetch the chunk at `offset`, logging the time taken.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Process: verify and append each decompressed chunk.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalize: run the verifier's final check, then flush.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
    }
616}
617
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    // Receives a local copy of every broadcast message, if set.
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Rolling estimator of validator response times (milliseconds), used
    // to derive the adaptive broadcast timeout.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
624
/// Outcome of [`BlockchainRpcClient::find_archive`].
pub enum PendingArchiveResponse {
    /// A peer has the archive and can serve its chunks.
    Found(PendingArchive),
    /// Enough peers reported the archive as not built yet.
    TooNew,
}
629
/// An archive located on a specific neighbour, ready to be downloaded.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id as reported by the peer.
    pub id: u64,
    /// Total size in bytes, as reported by the peer.
    pub size: NonZeroU64,
    /// Size of each download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// Peer that reported (and will serve) the archive.
    pub neighbour: Neighbour,
}
637
/// A persistent state located on a specific neighbour, ready to be
/// downloaded.
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Block the state belongs to.
    pub block_id: BlockId,
    /// Whether this is a shard state or a queue state.
    pub kind: PersistentStateKind,
    /// Total size in bytes, as reported by the peer.
    pub size: NonZeroU64,
    /// Size of each download chunk in bytes.
    pub chunk_size: NonZeroU32,
    /// Peer that reported (and will serve) the state.
    pub neighbour: Neighbour,
}
646
/// Complete data set for a single block.
pub struct BlockDataFull {
    pub block_id: BlockId,
    /// Decompressed block data.
    pub block_data: Bytes,
    pub proof_data: Bytes,
    pub queue_diff_data: Bytes,
}
653
/// Result of a block download: the data (`None` when the peer did not have
/// the block) plus the neighbour that served the request.
pub struct BlockDataFullWithNeighbour {
    pub data: Option<BlockDataFull>,
    pub neighbour: Neighbour,
}
658
/// How strictly a block request expects the response to contain data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Missing data is fine; the response is still accepted.
    Optional,
    /// Missing data is a soft failure; the response is rejected.
    Expected,
    /// Missing data is treated as malicious and the peer is punished.
    Required,
}
674
/// Requests a full block (`req` is either `GetBlockFull` or
/// `GetNextBlockFull`) from `neighbour` and, if found, downloads and
/// decompresses the remaining block data chunks.
///
/// When the peer answers without block data, the response is graded
/// according to `requirement` (accept / reject / punish) and `data: None`
/// is returned.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // No data in the response; how bad that is depends on the caller.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The first chunk arrives inline with the `BlockFull` response and must
    // fit both the declared total size and the declared chunk size.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // Decompression runs on a blocking thread, fed through the channel.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        let mut decompressed = Vec::new();

        // Feed the inline first chunk before the downloaded ones.
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // The response is rejected if validation below bails out early;
            // the guard is disarmed (and the response accepted) only once
            // the chunk has been fully processed.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            ScopeGuard::into_inner(guard).accept();
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Fetch the remaining chunks (the first one was inline) with up to
    // `PARALLEL_REQUESTS` downloads in flight, delivered in offset order.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // Stop early if the processing task bailed and closed the receiver.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing loop terminates.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
810
/// Generic driver for chunked, zstd-compressed downloads.
///
/// Splits `0..target_size` into `chunk_size`-sized offsets, fetches each
/// chunk via `download_fn` (up to 10 in flight, delivered in order),
/// decompresses it and feeds it to `process_fn` together with the mutable
/// `state`; finally `finalize_fn` converts the state into the result.
///
/// Decompression and `process_fn` run on a blocking thread; each chunk's
/// response handle is accepted only after the chunk was processed, and
/// rejected if processing fails.
async fn download_compressed<S, T, DF, DFut, PF, FF>(
    target_size: NonZeroU64,
    chunk_size: NonZeroU32,
    mut state: S,
    mut download_fn: DF,
    mut process_fn: PF,
    finalize_fn: FF,
) -> Result<T, Error>
where
    S: Send + 'static,
    T: Send + 'static,
    DF: FnMut(u64) -> DFut,
    DFut: Future<Output = DownloadedChunkResult> + Send + 'static,
    PF: FnMut(&mut S, &[u8]) -> Result<()> + Send + 'static,
    FF: FnOnce(S) -> Result<T> + Send + 'static,
{
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = target_size.get();
    let chunk_size = chunk_size.get() as usize;

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size)?;

        // Reused across iterations to avoid re-allocating per chunk.
        let mut decompressed_chunk = Vec::new();

        let mut downloaded = 0;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response on early exit; disarmed on success below.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size, "received invalid chunk");

            downloaded += chunk.len() as u64;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded),
                "got chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            decompressed_chunk.clear();
            zstd_decoder.write(chunk.as_ref(), &mut decompressed_chunk)?;

            process_fn(&mut state, &decompressed_chunk)?;

            ScopeGuard::into_inner(guard).accept();
        }

        anyhow::ensure!(
            target_size == downloaded,
            "size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        finalize_fn(state)
    });

    let stream = futures_util::stream::iter((0..target_size).step_by(chunk_size))
        .map(|offset| JoinTask::new(download_fn(offset)))
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        // Stop early if the processing task bailed and closed the receiver.
        if chunks_tx.send(chunk).await.is_err() {
            break;
        }
    }

    // Close the channel so the processing loop terminates.
    drop(chunks_tx);

    let output = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map_err(Error::Internal)?;

    Ok(output)
}
897
898async fn download_with_retries(
899 req: Request,
900 overlay_client: PublicOverlayClient,
901 neighbour: Neighbour,
902 max_retries: usize,
903 name: &'static str,
904) -> DownloadedChunkResult {
905 let mut retries = 0;
906 loop {
907 match overlay_client
908 .query_raw::<Data>(neighbour.clone(), req.clone())
909 .await
910 {
911 Ok(r) => {
912 let (h, res) = r.split();
913 return Ok((h, res.data));
914 }
915 Err(e) => {
916 tracing::error!("failed to download {name}: {e}");
917 retries += 1;
918 if retries >= max_retries || !neighbour.is_reliable() {
919 return Err(e);
920 }
921
922 tokio::time::sleep(Duration::from_millis(100)).await;
923 }
924 }
925 }
926}
927
/// Result of downloading one chunk: the response handle (for accept/reject
/// feedback) together with the raw chunk bytes.
type DownloadedChunkResult = Result<(QueryResponseHandle, Bytes), Error>;
929
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    /// Round-trips random data through zstd compression and
    /// `download_compressed` (serving chunks from an in-memory buffer) and
    /// checks the decompressed output matches the original bytes.
    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        // 1 MiB of random input.
        let mut original_data = vec![0u8; 1 << 20];
        rand::rng().fill_bytes(&mut original_data);

        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        let received = download_compressed(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            Vec::new(),
            // Download: serve chunks straight from the compressed buffer.
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok((handle, chunk)))
            },
            // Process: accumulate decompressed bytes.
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}
977}