1use std::io::Write;
2use std::num::{NonZeroU32, NonZeroU64, NonZeroUsize};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use anyhow::Result;
7use bytes::Bytes;
8use bytesize::ByteSize;
9use futures_util::stream::{FuturesUnordered, StreamExt};
10use parking_lot::Mutex;
11use scopeguard::ScopeGuard;
12use serde::{Deserialize, Serialize};
13use tokio::sync::mpsc;
14use tycho_block_util::archive::ArchiveVerifier;
15use tycho_network::{PublicOverlay, Request};
16use tycho_types::models::BlockId;
17use tycho_util::compression::ZstdDecompressStream;
18use tycho_util::futures::JoinTask;
19use tycho_util::serde_helpers;
20
21use crate::overlay_client::{
22 Error, Neighbour, NeighbourType, PublicOverlayClient, QueryResponse, QueryResponseHandle,
23};
24use crate::proto::blockchain::*;
25use crate::proto::overlay::BroadcastPrefix;
26use crate::storage::PersistentStateKind;
27use crate::util::downloader::{DownloaderError, DownloaderResponseHandle, download_and_decompress};
28
/// Listener that is additionally notified about external messages broadcast
/// by this node itself (see `BlockchainRpcClient::broadcast_external_message`).
#[async_trait::async_trait]
pub trait SelfBroadcastListener: Send + Sync + 'static {
    /// Handles a single raw external message payload.
    async fn handle_message(&self, message: Bytes);
}
36
/// Configuration for [`BlockchainRpcClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct BlockchainRpcClientConfig {
    /// Lower bound for the adaptive external-message broadcast timeout.
    ///
    /// Default: 100 ms.
    #[serde(with = "serde_helpers::humantime")]
    pub min_broadcast_timeout: Duration,

    /// Number of `TooNew` responses after which `find_archive` concludes
    /// the requested archive is not committed yet.
    ///
    /// Default: 4.
    pub too_new_archive_threshold: usize,

    /// Maximum number of attempts when downloading a single chunk
    /// (blocks, archives, persistent states).
    ///
    /// Default: 10.
    pub download_retries: usize,
}
58
impl Default for BlockchainRpcClientConfig {
    fn default() -> Self {
        Self {
            // 100 ms lower bound for the adaptive broadcast timeout.
            min_broadcast_timeout: Duration::from_millis(100),
            too_new_archive_threshold: 4,
            download_retries: 10,
        }
    }
}
68
/// Builder for [`BlockchainRpcClient`].
///
/// `MandatoryFields` tracks at the type level whether the required
/// [`PublicOverlayClient`] has been supplied: it is `()` until
/// `with_public_overlay_client` is called, which unlocks `build`.
pub struct BlockchainRpcClientBuilder<MandatoryFields = PublicOverlayClient> {
    config: BlockchainRpcClientConfig,
    mandatory_fields: MandatoryFields,
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
}
74
impl BlockchainRpcClientBuilder<PublicOverlayClient> {
    /// Finalizes the builder; only available once the overlay client is set.
    pub fn build(self) -> BlockchainRpcClient {
        BlockchainRpcClient {
            inner: Arc::new(Inner {
                config: self.config,
                overlay_client: self.mandatory_fields,
                broadcast_listener: self.broadcast_listener,
                // Rolling estimate of validator response times, fed by
                // `broadcast_external_message` and consumed by
                // `compute_broadcast_timeout`.
                response_tracker: Mutex::new(
                    tycho_util::time::RollingP2Estimator::new_with_config(
                        // 0.75 quantile over a 60 s window; the `5` is
                        // presumably a bucket/sample count — TODO confirm
                        // against `RollingP2Estimator::new_with_config` docs.
                        0.75, Duration::from_secs(60),
                        5,
                        tycho_util::time::RealClock,
                    )
                    .expect("correct quantile"),
                ),
            }),
        }
    }
}
96
impl BlockchainRpcClientBuilder<()> {
    /// Supplies the mandatory overlay client, unlocking `build`.
    pub fn with_public_overlay_client(
        self,
        client: PublicOverlayClient,
    ) -> BlockchainRpcClientBuilder<PublicOverlayClient> {
        BlockchainRpcClientBuilder {
            config: self.config,
            mandatory_fields: client,
            broadcast_listener: self.broadcast_listener,
        }
    }
}
109
110impl<T> BlockchainRpcClientBuilder<T> {
111 pub fn with_self_broadcast_listener(mut self, listener: impl SelfBroadcastListener) -> Self {
112 self.broadcast_listener = Some(Box::new(listener));
113 self
114 }
115}
116
117impl<T> BlockchainRpcClientBuilder<T> {
118 pub fn with_config(self, config: BlockchainRpcClientConfig) -> Self {
119 Self { config, ..self }
120 }
121}
122
/// Cheaply cloneable RPC client for the blockchain public overlay.
#[derive(Clone)]
#[repr(transparent)]
pub struct BlockchainRpcClient {
    inner: Arc<Inner>,
}
128
impl BlockchainRpcClient {
    /// Creates a builder with the default config and no mandatory fields set.
    pub fn builder() -> BlockchainRpcClientBuilder<()> {
        BlockchainRpcClientBuilder {
            config: Default::default(),
            mandatory_fields: (),
            broadcast_listener: None,
        }
    }
137
    /// Returns the underlying public overlay.
    pub fn overlay(&self) -> &PublicOverlay {
        self.inner.overlay_client.overlay()
    }
141
    /// Returns the underlying overlay client.
    pub fn overlay_client(&self) -> &PublicOverlayClient {
        &self.inner.overlay_client
    }
145
    /// Broadcasts an external message to all known broadcast targets.
    ///
    /// The local [`SelfBroadcastListener`] (if any) is notified first, then
    /// the message is sent to every target concurrently, bounded by an
    /// adaptive timeout. Returns the number of peers that acknowledged the
    /// broadcast before the timeout (0 if none did).
    pub async fn broadcast_external_message(&self, message: &[u8]) -> usize {
        // Serializes the payload with the broadcast TL prefix without copying
        // it into an intermediate buffer.
        struct ExternalMessage<'a> {
            data: &'a [u8],
        }

        impl tl_proto::TlWrite for ExternalMessage<'_> {
            type Repr = tl_proto::Boxed;

            fn max_size_hint(&self) -> usize {
                // 4 bytes for the broadcast prefix id.
                4 + MessageBroadcastRef { data: self.data }.max_size_hint()
            }

            fn write_to<P>(&self, packet: &mut P)
            where
                P: tl_proto::TlPacket,
            {
                packet.write_u32(BroadcastPrefix::TL_ID);
                MessageBroadcastRef { data: self.data }.write_to(packet);
            }
        }

        // Let this node handle its own broadcast as well.
        if let Some(l) = &self.inner.broadcast_listener {
            l.handle_message(Bytes::copy_from_slice(message)).await;
        }

        let client = &self.inner.overlay_client;

        let mut delivered_to = 0;

        let targets = client.get_broadcast_targets();
        let request = Request::from_tl(ExternalMessage { data: message });
        let mut futures = FuturesUnordered::new();

        for validator in targets.as_ref() {
            let client = client.clone();
            let validator = validator.clone();
            let request = request.clone();
            let this = self.inner.clone();

            futures.push(JoinTask::new(async move {
                let start = Instant::now();
                let res = client.send_to_validator(validator, request).await;
                // Feed the observed roundtrip into the rolling estimator used
                // by `compute_broadcast_timeout`.
                this.response_tracker
                    .lock()
                    .append(start.elapsed().as_millis() as i64);
                res
            }));
        }

        // Stop waiting once the adaptive timeout elapses; remaining futures
        // are dropped (NOTE(review): whether the spawned JoinTasks keep
        // running in the background depends on JoinTask's drop semantics —
        // confirm before relying on it).
        let timeout = self.compute_broadcast_timeout();
        tokio::time::timeout(timeout, async {
            while let Some(res) = futures.next().await {
                if let Err(e) = res {
                    tracing::warn!("failed to broadcast external message: {e}");
                } else {
                    delivered_to += 1;
                }
            }
        })
        .await
        .ok();

        if delivered_to == 0 {
            tracing::debug!("message was not delivered to any peer");
        }

        delivered_to
    }
222
223 fn compute_broadcast_timeout(&self) -> Duration {
224 let max_broadcast_timeout = std::cmp::max(
225 self.inner.overlay_client.config().validators.send_timeout,
226 self.inner.config.min_broadcast_timeout,
227 );
228
229 if let Some(prev_time) = self
230 .inner
231 .response_tracker
232 .lock()
233 .exponentially_weighted_average()
234 .map(|x| Duration::from_millis(x as _))
235 {
236 metrics::gauge!("tycho_broadcast_timeout", "kind" => "calculated").set(prev_time);
237 let value = prev_time.clamp(
238 self.inner.config.min_broadcast_timeout,
239 max_broadcast_timeout,
240 );
241 metrics::gauge!("tycho_broadcast_timeout", "kind" => "clamped").set(value);
242 value
243 } else {
244 max_broadcast_timeout
245 }
246 }
247
248 pub async fn get_next_key_block_ids(
249 &self,
250 block: &BlockId,
251 max_size: u32,
252 ) -> Result<QueryResponse<KeyBlockIds>, Error> {
253 let client = &self.inner.overlay_client;
254 let data = client
255 .query::<_, KeyBlockIds>(&rpc::GetNextKeyBlockIds {
256 block_id: *block,
257 max_size,
258 })
259 .await?;
260 Ok(data)
261 }
262
263 #[tracing::instrument(skip_all, fields(
264 block_id = %block.as_short_id(),
265 requirement = ?requirement,
266 ))]
267 pub async fn get_block_full(
268 &self,
269 block: &BlockId,
270 requirement: DataRequirement,
271 ) -> Result<BlockDataFullWithNeighbour, Error> {
272 let overlay_client = self.inner.overlay_client.clone();
273
274 let Some(neighbour) = overlay_client.neighbours().choose() else {
275 return Err(Error::NoNeighbours);
276 };
277
278 let retries = self.inner.config.download_retries;
279
280 download_block_inner(
281 Request::from_tl(rpc::GetBlockFull { block_id: *block }),
282 overlay_client,
283 neighbour,
284 requirement,
285 retries,
286 )
287 .await
288 }
289
290 pub async fn get_next_block_full(
291 &self,
292 prev_block: &BlockId,
293 requirement: DataRequirement,
294 ) -> Result<BlockDataFullWithNeighbour, Error> {
295 let overlay_client = self.inner.overlay_client.clone();
296
297 let Some(neighbour) = overlay_client.neighbours().choose() else {
298 return Err(Error::NoNeighbours);
299 };
300
301 let retries = self.inner.config.download_retries;
302
303 download_block_inner(
304 Request::from_tl(rpc::GetNextBlockFull {
305 prev_block_id: *prev_block,
306 }),
307 overlay_client,
308 neighbour,
309 requirement,
310 retries,
311 )
312 .await
313 }
314
315 pub async fn get_key_block_proof(
316 &self,
317 block_id: &BlockId,
318 ) -> Result<QueryResponse<KeyBlockProof>, Error> {
319 let client = &self.inner.overlay_client;
320 let data = client
321 .query::<_, KeyBlockProof>(&rpc::GetKeyBlockProof {
322 block_id: *block_id,
323 })
324 .await?;
325 Ok(data)
326 }
327
328 pub async fn get_zerostate_proof(&self) -> Result<QueryResponse<ZerostateProof>, Error> {
329 let client = &self.inner.overlay_client;
330 let data = client
331 .query::<_, ZerostateProof>(&rpc::GetZerostateProof)
332 .await?;
333 Ok(data)
334 }
335
336 pub async fn get_persistent_state_info(
337 &self,
338 block_id: &BlockId,
339 ) -> Result<QueryResponse<PersistentStateInfo>, Error> {
340 let client = &self.inner.overlay_client;
341 let data = client
342 .query::<_, PersistentStateInfo>(&rpc::GetPersistentShardStateInfo {
343 block_id: *block_id,
344 })
345 .await?;
346 Ok(data)
347 }
348
349 pub async fn get_persistent_state_part(
350 &self,
351 neighbour: &Neighbour,
352 block_id: &BlockId,
353 offset: u64,
354 ) -> Result<QueryResponse<Data>, Error> {
355 let client = &self.inner.overlay_client;
356 let data = client
357 .query_raw::<Data>(
358 neighbour.clone(),
359 Request::from_tl(rpc::GetPersistentShardStateChunk {
360 block_id: *block_id,
361 offset,
362 }),
363 )
364 .await?;
365 Ok(data)
366 }
367
    /// Asks up to 10 reliable neighbours whether they can serve the
    /// persistent state for `block_id` of the given `kind`.
    ///
    /// Returns a handle for the first peer that reports the state as found;
    /// otherwise the last query error, or [`Error::NotFound`] when every
    /// peer answered `NotFound`.
    pub async fn find_persistent_state(
        &self,
        block_id: &BlockId,
        kind: PersistentStateKind,
    ) -> Result<PendingPersistentState, Error> {
        const NEIGHBOUR_COUNT: usize = 10;

        let neighbours = self
            .overlay_client()
            .neighbours()
            .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);

        let req = match kind {
            PersistentStateKind::Shard => Request::from_tl(rpc::GetPersistentShardStateInfo {
                block_id: *block_id,
            }),
            PersistentStateKind::Queue => Request::from_tl(rpc::GetPersistentQueueStateInfo {
                block_id: *block_id,
            }),
        };

        // Query all chosen neighbours concurrently.
        let mut futures = FuturesUnordered::new();
        for neighbour in neighbours {
            futures.push(
                self.overlay_client()
                    .query_raw::<PersistentStateInfo>(neighbour.clone(), req.clone()),
            );
        }

        let mut err = None;
        while let Some(info) = futures.next().await {
            let (handle, info) = match info {
                Ok(res) => res.split(),
                Err(e) => {
                    // Keep the most recent error for the all-failed case.
                    err = Some(e);
                    continue;
                }
            };

            match info {
                PersistentStateInfo::Found { size, chunk_size } => {
                    // Accepting the handle rates the peer positively and
                    // yields the neighbour for the subsequent download.
                    let neighbour = handle.accept();
                    tracing::debug!(
                        peer_id = %neighbour.peer_id(),
                        state_size = size.get(),
                        state_chunk_size = chunk_size.get(),
                        ?kind,
                        "found persistent state",
                    );

                    return Ok(PendingPersistentState {
                        block_id: *block_id,
                        kind,
                        size,
                        chunk_size,
                        neighbour,
                    });
                }
                PersistentStateInfo::NotFound => {}
            }
        }

        match err {
            None => Err(Error::NotFound),
            Some(err) => Err(err),
        }
    }
436
    /// Streams the previously discovered persistent state into `output`.
    ///
    /// Chunks are requested (with retries) from the peer recorded in
    /// `state` and decompressed on the fly; `output` is flushed before
    /// being returned.
    #[tracing::instrument(skip_all, fields(
        peer_id = %state.neighbour.peer_id(),
        block_id = %state.block_id,
        kind = ?state.kind,
    ))]
    pub async fn download_persistent_state<W>(
        &self,
        state: PendingPersistentState,
        output: W,
    ) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let block_id = state.block_id;
        let max_retries = self.inner.config.download_retries;

        download_and_decompress(
            state.size,
            state.chunk_size,
            PARALLEL_REQUESTS,
            output,
            // Chunk fetcher: builds the kind-specific request for `offset`.
            |offset| {
                tracing::debug!("downloading persistent state chunk");

                let req = match state.kind {
                    PersistentStateKind::Shard => {
                        Request::from_tl(rpc::GetPersistentShardStateChunk { block_id, offset })
                    }
                    PersistentStateKind::Queue => {
                        Request::from_tl(rpc::GetPersistentQueueStateChunk { block_id, offset })
                    }
                };
                download_with_retries(
                    req,
                    self.overlay_client().clone(),
                    state.neighbour.clone(),
                    max_retries,
                    "persistent state chunk",
                )
            },
            // Sink: append each decompressed chunk to the writer.
            |output, chunk| {
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: flush and hand the writer back.
            |mut output| {
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
494
495 pub async fn find_archive(&self, mc_seqno: u32) -> Result<PendingArchiveResponse, Error> {
496 const NEIGHBOUR_COUNT: usize = 10;
497
498 let neighbours = self
500 .overlay_client()
501 .neighbours()
502 .choose_multiple(NEIGHBOUR_COUNT, NeighbourType::Reliable);
503
504 let pending_archive = 'info: {
506 let req = Request::from_tl(rpc::GetArchiveInfo { mc_seqno });
507
508 let mut new_archive_count = 0usize;
510
511 let mut futures = FuturesUnordered::new();
512 for neighbour in neighbours {
513 futures.push(self.overlay_client().query_raw(neighbour, req.clone()));
514 }
515
516 let mut err = None;
517 while let Some(info) = futures.next().await {
518 let (handle, info) = match info {
519 Ok(res) => res.split(),
520 Err(e) => {
521 err = Some(e);
522 continue;
523 }
524 };
525
526 match info {
527 ArchiveInfo::Found {
528 id,
529 size,
530 chunk_size,
531 } => {
532 break 'info PendingArchive {
533 id,
534 size,
535 chunk_size,
536 neighbour: handle.accept(),
537 };
538 }
539 ArchiveInfo::TooNew => {
540 new_archive_count += 1;
541
542 handle.accept();
543 }
544 ArchiveInfo::NotFound => {
545 handle.accept();
546 }
547 }
548 }
549
550 if new_archive_count >= self.inner.config.too_new_archive_threshold {
552 return Ok(PendingArchiveResponse::TooNew);
553 }
554
555 return match err {
556 None => Err(Error::NotFound),
557 Some(err) => Err(err),
558 };
559 };
560
561 tracing::info!(
562 peer_id = %pending_archive.neighbour.peer_id(),
563 archive_id = pending_archive.id,
564 archive_size = %ByteSize(pending_archive.size.get()),
565 archuve_chunk_size = %ByteSize(pending_archive.chunk_size.get() as _),
566 "found archive",
567 );
568 Ok(PendingArchiveResponse::Found(pending_archive))
569 }
570
    /// Streams the previously discovered archive into `output`, verifying
    /// the archive structure while writing.
    #[tracing::instrument(skip_all, fields(
        peer_id = %archive.neighbour.peer_id(),
        archive_id = archive.id,
    ))]
    pub async fn download_archive<W>(&self, archive: PendingArchive, output: W) -> Result<W, Error>
    where
        W: Write + Send + 'static,
    {
        use futures_util::FutureExt;

        tracing::debug!("started");
        scopeguard::defer! {
            tracing::debug!("finished");
        }

        let retries = self.inner.config.download_retries;

        download_and_decompress(
            archive.size,
            archive.chunk_size,
            PARALLEL_REQUESTS,
            // The verifier checks archive integrity as chunks are written.
            (output, ArchiveVerifier::default()),
            // Chunk fetcher with per-chunk download timing logs.
            |offset| {
                let archive_id = archive.id;
                let neighbour = archive.neighbour.clone();
                let overlay_client = self.overlay_client().clone();

                let started_at = Instant::now();

                tracing::debug!(archive_id, offset, "downloading archive chunk");
                download_with_retries(
                    Request::from_tl(rpc::GetArchiveChunk { archive_id, offset }),
                    overlay_client,
                    neighbour,
                    retries,
                    "archive chunk",
                )
                .map(move |res| {
                    tracing::info!(
                        archive_id,
                        offset,
                        elapsed = %humantime::format_duration(started_at.elapsed()),
                        "downloaded archive chunk",
                    );
                    res
                })
            },
            // Sink: verify, then write each decompressed chunk.
            |(output, verifier), chunk| {
                verifier.write_verify(chunk)?;
                output.write_all(chunk)?;
                Ok(())
            },
            // Finalizer: complete verification, flush, return the writer.
            |(mut output, verifier)| {
                verifier.final_check()?;
                output.flush()?;
                Ok(output)
            },
        )
        .await
        .map_err(map_downloader_error)
    }
}
633
/// Shared state behind [`BlockchainRpcClient`].
struct Inner {
    config: BlockchainRpcClientConfig,
    overlay_client: PublicOverlayClient,
    broadcast_listener: Option<Box<dyn SelfBroadcastListener>>,
    // Rolling estimate of validator response times (ms), fed by
    // `broadcast_external_message` and read by `compute_broadcast_timeout`.
    response_tracker: Mutex<tycho_util::time::RollingP2Estimator>,
}
640
/// Outcome of [`BlockchainRpcClient::find_archive`].
pub enum PendingArchiveResponse {
    /// A peer has the archive; it can now be downloaded.
    Found(PendingArchive),
    /// Enough peers reported the archive as not committed yet.
    TooNew,
}
645
/// A located archive, ready to be fetched via `download_archive`.
#[derive(Clone)]
pub struct PendingArchive {
    /// Archive id as reported by the peer.
    pub id: u64,
    /// Total (compressed) archive size in bytes.
    pub size: NonZeroU64,
    /// Chunk size to use when requesting archive parts.
    pub chunk_size: NonZeroU32,
    /// Peer that reported the archive as available.
    pub neighbour: Neighbour,
}
653
/// A located persistent state, ready to be fetched via
/// `download_persistent_state`.
#[derive(Clone)]
pub struct PendingPersistentState {
    /// Block the state belongs to.
    pub block_id: BlockId,
    /// Which persistent state (shard or queue).
    pub kind: PersistentStateKind,
    /// Total (compressed) state size in bytes.
    pub size: NonZeroU64,
    /// Chunk size to use when requesting state parts.
    pub chunk_size: NonZeroU32,
    /// Peer that reported the state as available.
    pub neighbour: Neighbour,
}
662
/// Complete data for a single block.
pub struct BlockDataFull {
    pub block_id: BlockId,
    /// Decompressed block data.
    pub block_data: Bytes,
    pub proof_data: Bytes,
    pub queue_diff_data: Bytes,
}
669
/// Result of a full-block request, together with the peer that served it.
pub struct BlockDataFullWithNeighbour {
    /// `None` when the peer did not have the requested block.
    pub data: Option<BlockDataFull>,
    pub neighbour: Neighbour,
}
674
/// How strictly block data is required from a queried peer.
///
/// Determines how the responding neighbour is rated when it answers
/// "not found" (see `download_block_inner`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataRequirement {
    /// Missing data is fine; the response is accepted as-is.
    Optional,
    /// Missing data is tolerated but the response is rejected.
    Expected,
    /// Missing data is treated as malicious and the peer is punished.
    Required,
}
690
/// Fetches a (possibly chunked) full block from `neighbour`.
///
/// `req` must decode into a [`BlockFull`] response. If the response signals
/// "not found", the neighbour is rated according to `requirement` and
/// `data: None` is returned. Otherwise the remaining chunks are downloaded
/// in parallel and zstd-decompressed on a blocking task.
async fn download_block_inner(
    req: Request,
    overlay_client: PublicOverlayClient,
    neighbour: Neighbour,
    requirement: DataRequirement,
    retries: usize,
) -> Result<BlockDataFullWithNeighbour, Error> {
    let response = overlay_client
        .query_raw::<BlockFull>(neighbour.clone(), req)
        .await?;

    let (handle, block_full) = response.split();

    let BlockFull::Found {
        block_id,
        block: block_data,
        proof: proof_data,
        queue_diff: queue_diff_data,
    } = block_full
    else {
        // Rate the peer depending on whether the data was allowed
        // to be missing.
        match requirement {
            DataRequirement::Optional => {
                handle.accept();
            }
            DataRequirement::Expected => {
                handle.reject();
            }
            DataRequirement::Required => {
                neighbour.punish(crate::overlay_client::PunishReason::Malicious);
            }
        }

        return Ok(BlockDataFullWithNeighbour {
            data: None,
            neighbour,
        });
    };

    // NOTE(review): shadows the file-level `PARALLEL_REQUESTS: NonZeroUsize`
    // with a plain `usize` of the same value.
    const PARALLEL_REQUESTS: usize = 10;

    let target_size = block_data.size.get();
    let chunk_size = block_data.chunk_size.get();
    let block_data_size = block_data.data.len() as u32;

    // The first chunk is inlined in the `BlockFull` response and must fit
    // both limits.
    if block_data_size > target_size || block_data_size > chunk_size {
        return Err(Error::Internal(anyhow::anyhow!("invalid first chunk")));
    }

    let (chunks_tx, mut chunks_rx) =
        mpsc::channel::<(QueryResponseHandle, Bytes)>(PARALLEL_REQUESTS);

    // Decompression runs on a blocking thread while chunks are downloaded
    // concurrently below.
    let span = tracing::Span::current();
    let processing_task = tokio::task::spawn_blocking(move || {
        let _span = span.enter();

        let mut zstd_decoder = ZstdDecompressStream::new(chunk_size as usize)?;

        let mut decompressed = Vec::new();

        // Feed the chunk that arrived with the initial response first.
        zstd_decoder.write(block_data.data.as_ref(), &mut decompressed)?;

        let mut downloaded = block_data.data.len() as u32;
        while let Some((h, chunk)) = chunks_rx.blocking_recv() {
            // Reject the response unless the chunk passes all checks below.
            let guard = scopeguard::guard(h, |handle| {
                handle.reject();
            });

            anyhow::ensure!(chunk.len() <= chunk_size as usize, "received invalid chunk");

            downloaded += chunk.len() as u32;
            tracing::debug!(
                downloaded = %bytesize::ByteSize::b(downloaded as _),
                "got block data chunk"
            );

            anyhow::ensure!(downloaded <= target_size, "received too many chunks");

            zstd_decoder.write(chunk.as_ref(), &mut decompressed)?;

            // Chunk accepted; defuse the reject guard.
            ScopeGuard::into_inner(guard).accept();
        }

        anyhow::ensure!(
            target_size == downloaded,
            "block size mismatch (target size: {target_size}; downloaded: {downloaded})",
        );

        Ok(decompressed)
    });

    // Request the remaining chunks (the first one is already inlined),
    // keeping at most PARALLEL_REQUESTS requests in flight.
    let stream = futures_util::stream::iter((chunk_size..target_size).step_by(chunk_size as usize))
        .map(|offset| {
            let neighbour = neighbour.clone();
            let overlay_client = overlay_client.clone();

            tracing::debug!(%block_id, offset, "downloading block data chunk");
            JoinTask::new(download_with_retries(
                Request::from_tl(rpc::GetBlockDataChunk { block_id, offset }),
                overlay_client,
                neighbour,
                retries,
                "block data chunk",
            ))
        })
        .buffered(PARALLEL_REQUESTS);

    let mut stream = std::pin::pin!(stream);
    while let Some(chunk) = stream.next().await.transpose()? {
        if chunks_tx.send(chunk).await.is_err() {
            // The processing task bailed out; its error surfaces below.
            break;
        }
    }

    // Close the channel so the processing task can finish.
    drop(chunks_tx);

    let block_data = processing_task
        .await
        .map_err(|e| Error::Internal(anyhow::anyhow!("Failed to join blocking task: {e}")))?
        .map(Bytes::from)
        .map_err(Error::Internal)?;

    Ok(BlockDataFullWithNeighbour {
        data: Some(BlockDataFull {
            block_id,
            block_data,
            proof_data,
            queue_diff_data,
        }),
        neighbour: neighbour.clone(),
    })
}
826
827async fn download_with_retries(
828 req: Request,
829 overlay_client: PublicOverlayClient,
830 neighbour: Neighbour,
831 max_retries: usize,
832 name: &'static str,
833) -> Result<(QueryResponseHandle, Bytes), Error> {
834 let mut retries = 0;
835 loop {
836 match overlay_client
837 .query_raw::<Data>(neighbour.clone(), req.clone())
838 .await
839 {
840 Ok(r) => {
841 let (h, res) = r.split();
842 return Ok((h, res.data));
843 }
844 Err(e) => {
845 tracing::error!("failed to download {name}: {e:?}");
846 retries += 1;
847 if retries >= max_retries || !neighbour.is_reliable() {
848 return Err(e);
849 }
850
851 tokio::time::sleep(Duration::from_millis(100)).await;
852 }
853 }
854 }
855}
856
// Bridges the overlay-client response handle into the generic downloader's
// accept/reject interface.
impl DownloaderResponseHandle for QueryResponseHandle {
    fn accept(self) {
        QueryResponseHandle::accept(self);
    }

    fn reject(self) {
        QueryResponseHandle::reject(self);
    }
}
866
867fn map_downloader_error(e: DownloaderError<Error>) -> Error {
868 match e {
869 DownloaderError::DownloadFailed(e) => e,
870 e => Error::Internal(e.into()),
871 }
872}
873
/// Maximum number of concurrent chunk requests used by the download helpers.
const PARALLEL_REQUESTS: NonZeroUsize = NonZeroUsize::new(10).unwrap();
876
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use tycho_network::PeerId;
    use tycho_util::compression::zstd_compress;

    use super::*;

    // Placeholder error type for the stub chunk provider below.
    #[derive(Debug, thiserror::Error)]
    #[error("stub")]
    struct StubError;

    #[tokio::test]
    async fn download_compressed_works() -> Result<()> {
        let neighbour = Neighbour::new(PeerId([0; 32]), u32::MAX, &Duration::from_millis(100));

        // 1 MiB of random input, compressed with zstd level 9.
        let mut original_data = vec![0u8; 1 << 20];
        rand::rng().fill_bytes(&mut original_data);

        let mut compressed_data = Vec::new();
        zstd_compress(&original_data, &mut compressed_data, 9);
        let compressed_data = Bytes::from(compressed_data);

        assert_ne!(compressed_data, original_data);

        const CHUNK_SIZE: usize = 128;

        // Serve the compressed buffer in CHUNK_SIZE slices and verify that
        // the decompressed output matches the original input.
        let received = download_and_decompress(
            NonZeroU64::new(compressed_data.len() as _).unwrap(),
            NonZeroU32::new(CHUNK_SIZE as _).unwrap(),
            PARALLEL_REQUESTS,
            Vec::new(),
            |offset| {
                assert_eq!(offset % CHUNK_SIZE as u64, 0);
                assert!(offset < compressed_data.len() as u64);
                let from = offset as usize;
                let to = std::cmp::min(from + CHUNK_SIZE, compressed_data.len());
                let chunk = compressed_data.slice(from..to);
                let handle = QueryResponseHandle::with_roundtrip_ms(neighbour.clone(), 100);
                futures_util::future::ready(Ok::<_, StubError>((handle, chunk)))
            },
            |result, chunk| {
                result.extend_from_slice(chunk);
                Ok(())
            },
            Ok,
        )
        .await?;
        assert_eq!(received, original_data);

        Ok(())
    }
}