1use std::{
7 collections::HashMap,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22 sync::{RwLock, Semaphore},
23 time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27 dto::{
28 BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29 EntryPointWithTracingParams, PaginationLimits, PaginationParams, PaginationResponse,
30 ProtocolComponent, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
31 ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
32 ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
33 TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
34 TracedEntryPointRequestResponse, TracingResult, VersionParam,
35 },
36 models::ComponentId,
37 Bytes,
38};
39
40use crate::{
41 feed::synchronizer::{ComponentWithState, Snapshot},
42 TYCHO_SERVER_VERSION,
43};
44
/// Default concurrency level for RPC client requests.
///
/// NOTE(review): presumably meant to be passed as the `concurrency` argument of the paginated
/// `RPCClient` helpers; it is not referenced in this part of the file — confirm with callers.
pub const RPC_CLIENT_CONCURRENCY: usize = 4;
47
48#[derive(Clone, Debug, PartialEq)]
53pub struct SnapshotParameters<'a> {
54 pub chain: Chain,
56 pub protocol_system: &'a str,
58 pub components: &'a HashMap<ComponentId, ProtocolComponent>,
60 pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
62 pub contract_ids: &'a [Bytes],
64 pub block_number: u64,
66 pub include_balances: bool,
68 pub include_tvl: bool,
70}
71
72impl<'a> SnapshotParameters<'a> {
73 pub fn new(
74 chain: Chain,
75 protocol_system: &'a str,
76 components: &'a HashMap<ComponentId, ProtocolComponent>,
77 contract_ids: &'a [Bytes],
78 block_number: u64,
79 ) -> Self {
80 Self {
81 chain,
82 protocol_system,
83 components,
84 entrypoints: None,
85 contract_ids,
86 block_number,
87 include_balances: true,
88 include_tvl: true,
89 }
90 }
91
92 pub fn include_balances(mut self, include_balances: bool) -> Self {
94 self.include_balances = include_balances;
95 self
96 }
97
98 pub fn include_tvl(mut self, include_tvl: bool) -> Self {
100 self.include_tvl = include_tvl;
101 self
102 }
103
104 pub fn entrypoints(
105 mut self,
106 entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
107 ) -> Self {
108 self.entrypoints = Some(entrypoints);
109 self
110 }
111}
112
113#[derive(Error, Debug)]
114pub enum RPCError {
115 #[error("Failed to parse URL: {0}. Error: {1}")]
117 UrlParsing(String, String),
118
119 #[error("Failed to format request: {0}")]
121 FormatRequest(String),
122
123 #[error("Unexpected HTTP client error: {0}")]
125 HttpClient(String, #[source] reqwest::Error),
126
127 #[error("Failed to parse response: {0}")]
129 ParseResponse(String),
130
131 #[error("Fatal error: {0}")]
133 Fatal(String),
134
135 #[error("Rate limited until {0:?}")]
136 RateLimited(Option<SystemTime>),
137
138 #[error("Server unreachable: {0}")]
139 ServerUnreachable(String),
140}
141
142#[cfg_attr(test, automock)]
143#[async_trait]
144pub trait RPCClient: Send + Sync {
145 fn compression(&self) -> bool;
147
148 async fn get_contract_state(
150 &self,
151 request: &StateRequestBody,
152 ) -> Result<StateRequestResponse, RPCError>;
153
154 async fn get_contract_state_paginated(
157 &self,
158 chain: Chain,
159 ids: &[Bytes],
160 protocol_system: &str,
161 version: &VersionParam,
162 chunk_size: Option<usize>,
163 concurrency: usize,
164 ) -> Result<StateRequestResponse, RPCError> {
165 let semaphore = Arc::new(Semaphore::new(concurrency));
166
167 let mut sorted_ids = ids.to_vec();
169 sorted_ids.sort();
170
171 let chunk_size = chunk_size
172 .unwrap_or(StateRequestBody::effective_max_page_size(self.compression()) as usize);
173
174 let chunked_bodies = sorted_ids
175 .chunks(chunk_size)
176 .map(|chunk| StateRequestBody {
177 contract_ids: Some(chunk.to_vec()),
178 protocol_system: protocol_system.to_string(),
179 chain,
180 version: version.clone(),
181 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
182 })
183 .collect::<Vec<_>>();
184
185 let mut tasks = Vec::new();
186 for body in chunked_bodies.iter() {
187 let sem = semaphore.clone();
188 tasks.push(async move {
189 let _permit = sem
190 .acquire()
191 .await
192 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
193 self.get_contract_state(body).await
194 });
195 }
196
197 let responses = try_join_all(tasks).await?;
199
200 let accounts = responses
202 .iter()
203 .flat_map(|r| r.accounts.clone())
204 .collect();
205 let total: i64 = responses
206 .iter()
207 .map(|r| r.pagination.total)
208 .sum();
209
210 Ok(StateRequestResponse {
211 accounts,
212 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
213 })
214 }
215
216 async fn get_protocol_components(
217 &self,
218 request: &ProtocolComponentsRequestBody,
219 ) -> Result<ProtocolComponentRequestResponse, RPCError>;
220
221 async fn get_protocol_components_paginated(
224 &self,
225 request: &ProtocolComponentsRequestBody,
226 chunk_size: Option<usize>,
227 concurrency: usize,
228 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
229 let semaphore = Arc::new(Semaphore::new(concurrency));
230
231 let chunk_size = chunk_size.unwrap_or(
232 ProtocolComponentsRequestBody::effective_max_page_size(self.compression()) as usize,
233 );
234
235 match request.component_ids {
238 Some(ref ids) => {
239 let chunked_bodies = ids
241 .chunks(chunk_size)
242 .enumerate()
243 .map(|(index, _)| ProtocolComponentsRequestBody {
244 protocol_system: request.protocol_system.clone(),
245 component_ids: request.component_ids.clone(),
246 tvl_gt: request.tvl_gt,
247 chain: request.chain,
248 pagination: PaginationParams {
249 page: index as i64,
250 page_size: chunk_size as i64,
251 },
252 })
253 .collect::<Vec<_>>();
254
255 let mut tasks = Vec::new();
256 for body in chunked_bodies.iter() {
257 let sem = semaphore.clone();
258 tasks.push(async move {
259 let _permit = sem
260 .acquire()
261 .await
262 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
263 self.get_protocol_components(body).await
264 });
265 }
266
267 try_join_all(tasks)
268 .await
269 .map(|responses| ProtocolComponentRequestResponse {
270 protocol_components: responses
271 .into_iter()
272 .flat_map(|r| r.protocol_components.into_iter())
273 .collect(),
274 pagination: PaginationResponse {
275 page: 0,
276 page_size: chunk_size as i64,
277 total: ids.len() as i64,
278 },
279 })
280 }
281 _ => {
282 let initial_request = ProtocolComponentsRequestBody {
286 protocol_system: request.protocol_system.clone(),
287 component_ids: request.component_ids.clone(),
288 tvl_gt: request.tvl_gt,
289 chain: request.chain,
290 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
291 };
292 let first_response = self
293 .get_protocol_components(&initial_request)
294 .await
295 .map_err(|err| RPCError::Fatal(err.to_string()))?;
296
297 let total_items = first_response.pagination.total;
298 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
299
300 let mut accumulated_response = ProtocolComponentRequestResponse {
302 protocol_components: first_response.protocol_components,
303 pagination: PaginationResponse {
304 page: 0,
305 page_size: chunk_size as i64,
306 total: total_items,
307 },
308 };
309
310 let mut page = 1;
311 while page < total_pages {
312 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
313
314 let chunked_bodies = (0..requests_in_this_iteration)
316 .map(|iter| ProtocolComponentsRequestBody {
317 protocol_system: request.protocol_system.clone(),
318 component_ids: request.component_ids.clone(),
319 tvl_gt: request.tvl_gt,
320 chain: request.chain,
321 pagination: PaginationParams {
322 page: page + iter,
323 page_size: chunk_size as i64,
324 },
325 })
326 .collect::<Vec<_>>();
327
328 let tasks: Vec<_> = chunked_bodies
329 .iter()
330 .map(|body| {
331 let sem = semaphore.clone();
332 async move {
333 let _permit = sem.acquire().await.map_err(|_| {
334 RPCError::Fatal("Semaphore dropped".to_string())
335 })?;
336 self.get_protocol_components(body).await
337 }
338 })
339 .collect();
340
341 let responses = try_join_all(tasks)
342 .await
343 .map(|responses| {
344 let total = responses[0].pagination.total;
345 ProtocolComponentRequestResponse {
346 protocol_components: responses
347 .into_iter()
348 .flat_map(|r| r.protocol_components.into_iter())
349 .collect(),
350 pagination: PaginationResponse {
351 page,
352 page_size: chunk_size as i64,
353 total,
354 },
355 }
356 });
357
358 match responses {
360 Ok(mut resp) => {
361 accumulated_response
362 .protocol_components
363 .append(&mut resp.protocol_components);
364 }
365 Err(e) => return Err(e),
366 }
367
368 page += concurrency as i64;
369 }
370 Ok(accumulated_response)
371 }
372 }
373 }
374
375 async fn get_protocol_states(
376 &self,
377 request: &ProtocolStateRequestBody,
378 ) -> Result<ProtocolStateRequestResponse, RPCError>;
379
380 #[allow(clippy::too_many_arguments)]
383 async fn get_protocol_states_paginated<T>(
384 &self,
385 chain: Chain,
386 ids: &[T],
387 protocol_system: &str,
388 include_balances: bool,
389 version: &VersionParam,
390 chunk_size: Option<usize>,
391 concurrency: usize,
392 ) -> Result<ProtocolStateRequestResponse, RPCError>
393 where
394 T: AsRef<str> + Sync + 'static,
395 {
396 let semaphore = Arc::new(Semaphore::new(concurrency));
397
398 let chunk_size = chunk_size.unwrap_or(ProtocolStateRequestBody::effective_max_page_size(
399 self.compression(),
400 ) as usize);
401
402 let chunked_bodies = ids
403 .chunks(chunk_size)
404 .map(|c| ProtocolStateRequestBody {
405 protocol_ids: Some(
406 c.iter()
407 .map(|id| id.as_ref().to_string())
408 .collect(),
409 ),
410 protocol_system: protocol_system.to_string(),
411 chain,
412 include_balances,
413 version: version.clone(),
414 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
415 })
416 .collect::<Vec<_>>();
417
418 let mut tasks = Vec::new();
419 for body in chunked_bodies.iter() {
420 let sem = semaphore.clone();
421 tasks.push(async move {
422 let _permit = sem
423 .acquire()
424 .await
425 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
426 self.get_protocol_states(body).await
427 });
428 }
429
430 try_join_all(tasks)
431 .await
432 .map(|responses| {
433 let states = responses
434 .clone()
435 .into_iter()
436 .flat_map(|r| r.states)
437 .collect();
438 let total = responses
439 .iter()
440 .map(|r| r.pagination.total)
441 .sum();
442 ProtocolStateRequestResponse {
443 states,
444 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
445 }
446 })
447 }
448
449 async fn get_tokens(
452 &self,
453 request: &TokensRequestBody,
454 ) -> Result<TokensRequestResponse, RPCError>;
455
456 async fn get_all_tokens(
459 &self,
460 chain: Chain,
461 min_quality: Option<i32>,
462 traded_n_days_ago: Option<u64>,
463 chunk_size: Option<usize>,
464 concurrency: usize,
465 ) -> Result<Vec<ResponseToken>, RPCError> {
466 let chunk_size = chunk_size
467 .unwrap_or(TokensRequestBody::effective_max_page_size(self.compression()) as usize);
468
469 let semaphore = Arc::new(Semaphore::new(concurrency));
470
471 let page_size = chunk_size.try_into().map_err(|_| {
473 RPCError::FormatRequest("Failed to convert chunk_size into i64".to_string())
474 })?;
475
476 let initial_request = TokensRequestBody {
477 token_addresses: None,
478 min_quality,
479 traded_n_days_ago,
480 pagination: PaginationParams { page: 0, page_size },
481 chain,
482 };
483
484 let first_response = self
485 .get_tokens(&initial_request)
486 .await?;
487 let total_items = first_response.pagination.total;
488 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
489
490 let mut all_tokens = first_response.tokens;
491
492 if total_pages <= 1 {
494 return Ok(all_tokens);
495 }
496
497 let tasks: Vec<_> = (1..total_pages)
499 .map(|page| {
500 let sem = semaphore.clone();
501 let request = TokensRequestBody {
502 token_addresses: None,
503 min_quality,
504 traded_n_days_ago,
505 pagination: PaginationParams { page, page_size },
506 chain,
507 };
508
509 async move {
510 let _permit = sem
512 .acquire()
513 .await
514 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
515 self.get_tokens(&request).await
516 }
517 })
518 .collect();
519
520 let responses = try_join_all(tasks).await?;
522
523 for mut response in responses {
525 all_tokens.append(&mut response.tokens);
526 }
527
528 Ok(all_tokens)
529 }
530
531 async fn get_protocol_systems(
532 &self,
533 request: &ProtocolSystemsRequestBody,
534 ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
535
536 async fn get_component_tvl(
537 &self,
538 request: &ComponentTvlRequestBody,
539 ) -> Result<ComponentTvlRequestResponse, RPCError>;
540
541 async fn get_component_tvl_paginated(
544 &self,
545 request: &ComponentTvlRequestBody,
546 chunk_size: Option<usize>,
547 concurrency: usize,
548 ) -> Result<ComponentTvlRequestResponse, RPCError> {
549 let semaphore = Arc::new(Semaphore::new(concurrency));
550
551 let chunk_size = chunk_size.unwrap_or(ComponentTvlRequestBody::effective_max_page_size(
552 self.compression(),
553 ) as usize);
554
555 match request.component_ids {
556 Some(ref ids) => {
557 let chunked_requests = ids
558 .chunks(chunk_size)
559 .enumerate()
560 .map(|(index, _)| ComponentTvlRequestBody {
561 chain: request.chain,
562 protocol_system: request.protocol_system.clone(),
563 component_ids: Some(ids.clone()),
564 pagination: PaginationParams {
565 page: index as i64,
566 page_size: chunk_size as i64,
567 },
568 })
569 .collect::<Vec<_>>();
570
571 let tasks: Vec<_> = chunked_requests
572 .into_iter()
573 .map(|req| {
574 let sem = semaphore.clone();
575 async move {
576 let _permit = sem
577 .acquire()
578 .await
579 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
580 self.get_component_tvl(&req).await
581 }
582 })
583 .collect();
584
585 let responses = try_join_all(tasks).await?;
586
587 let mut merged_tvl = HashMap::new();
588 for resp in responses {
589 for (key, value) in resp.tvl {
590 *merged_tvl.entry(key).or_insert(0.0) = value;
591 }
592 }
593
594 Ok(ComponentTvlRequestResponse {
595 tvl: merged_tvl,
596 pagination: PaginationResponse {
597 page: 0,
598 page_size: chunk_size as i64,
599 total: ids.len() as i64,
600 },
601 })
602 }
603 _ => {
604 let first_request = ComponentTvlRequestBody {
605 chain: request.chain,
606 protocol_system: request.protocol_system.clone(),
607 component_ids: request.component_ids.clone(),
608 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
609 };
610
611 let first_response = self
612 .get_component_tvl(&first_request)
613 .await?;
614 let total_items = first_response.pagination.total;
615 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
616
617 let mut merged_tvl = first_response.tvl;
618
619 let mut page = 1;
620 while page < total_pages {
621 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
622
623 let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
624 .map(|i| ComponentTvlRequestBody {
625 chain: request.chain,
626 protocol_system: request.protocol_system.clone(),
627 component_ids: request.component_ids.clone(),
628 pagination: PaginationParams {
629 page: page + i,
630 page_size: chunk_size as i64,
631 },
632 })
633 .collect();
634
635 let tasks: Vec<_> = chunked_requests
636 .into_iter()
637 .map(|req| {
638 let sem = semaphore.clone();
639 async move {
640 let _permit = sem.acquire().await.map_err(|_| {
641 RPCError::Fatal("Semaphore dropped".to_string())
642 })?;
643 self.get_component_tvl(&req).await
644 }
645 })
646 .collect();
647
648 let responses = try_join_all(tasks).await?;
649
650 for resp in responses {
652 for (key, value) in resp.tvl {
653 *merged_tvl.entry(key).or_insert(0.0) += value;
654 }
655 }
656
657 page += concurrency as i64;
658 }
659
660 Ok(ComponentTvlRequestResponse {
661 tvl: merged_tvl,
662 pagination: PaginationResponse {
663 page: 0,
664 page_size: chunk_size as i64,
665 total: total_items,
666 },
667 })
668 }
669 }
670 }
671
672 async fn get_traced_entry_points(
673 &self,
674 request: &TracedEntryPointRequestBody,
675 ) -> Result<TracedEntryPointRequestResponse, RPCError>;
676
677 async fn get_traced_entry_points_paginated(
680 &self,
681 chain: Chain,
682 protocol_system: &str,
683 component_ids: &[String],
684 chunk_size: Option<usize>,
685 concurrency: usize,
686 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
687 let semaphore = Arc::new(Semaphore::new(concurrency));
688
689 let chunk_size = chunk_size.unwrap_or(
690 TracedEntryPointRequestBody::effective_max_page_size(self.compression()) as usize,
691 );
692
693 let chunked_bodies = component_ids
694 .chunks(chunk_size)
695 .map(|c| TracedEntryPointRequestBody {
696 chain,
697 protocol_system: protocol_system.to_string(),
698 component_ids: Some(c.to_vec()),
699 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
700 })
701 .collect::<Vec<_>>();
702
703 let mut tasks = Vec::new();
704 for body in chunked_bodies.iter() {
705 let sem = semaphore.clone();
706 tasks.push(async move {
707 let _permit = sem
708 .acquire()
709 .await
710 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
711 self.get_traced_entry_points(body).await
712 });
713 }
714
715 try_join_all(tasks)
716 .await
717 .map(|responses| {
718 let traced_entry_points = responses
719 .clone()
720 .into_iter()
721 .flat_map(|r| r.traced_entry_points)
722 .collect();
723 let total = responses
724 .iter()
725 .map(|r| r.pagination.total)
726 .sum();
727 TracedEntryPointRequestResponse {
728 traced_entry_points,
729 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
730 }
731 })
732 }
733
734 async fn get_snapshots<'a>(
735 &self,
736 request: &SnapshotParameters<'a>,
737 chunk_size: Option<usize>,
738 concurrency: usize,
739 ) -> Result<Snapshot, RPCError>;
740}
741
/// Configuration for `HttpRPCClient::new`.
#[derive(Debug, Clone)]
pub struct HttpRPCClientOptions {
    /// Sent verbatim as the `Authorization` header when set.
    pub auth_key: Option<String>,
    /// Whether compressed (zstd) responses are accepted. Defaults to `true`.
    pub compression: bool,
}

impl Default for HttpRPCClientOptions {
    /// Same as `HttpRPCClientOptions::new`: no auth key, compression enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl HttpRPCClientOptions {
    /// Options with no auth key and compression enabled.
    pub fn new() -> Self {
        HttpRPCClientOptions { compression: true, auth_key: None }
    }

    /// Replaces the auth key (pass `None` to send no `Authorization` header).
    pub fn with_auth_key(self, auth_key: Option<String>) -> Self {
        Self { auth_key, ..self }
    }

    /// Enables or disables compressed responses.
    pub fn with_compression(self, compression: bool) -> Self {
        Self { compression, ..self }
    }
}
776
777#[derive(Debug, Clone)]
778pub struct HttpRPCClient {
779 http_client: Client,
780 url: Url,
781 retry_after: Arc<RwLock<Option<SystemTime>>>,
782 backoff_policy: ExponentialBackoff,
783 server_restart_duration: Duration,
784 compression: bool,
785}
786
787impl HttpRPCClient {
788 pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
789 let uri = base_uri
790 .parse::<Url>()
791 .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
792
793 let mut headers = header::HeaderMap::new();
795 headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
796 let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
797 headers.insert(
798 header::USER_AGENT,
799 header::HeaderValue::from_str(&user_agent)
800 .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
801 );
802
803 if let Some(key) = options.auth_key.as_deref() {
805 let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
806 RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
807 })?;
808 auth_value.set_sensitive(true);
809 headers.insert(header::AUTHORIZATION, auth_value);
810 }
811
812 let mut client_builder = ClientBuilder::new()
813 .default_headers(headers)
814 .http2_prior_knowledge();
815
816 if !options.compression {
818 client_builder = client_builder.no_zstd();
819 }
820
821 let client = client_builder
822 .build()
823 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
824
825 Ok(Self {
826 http_client: client,
827 url: uri,
828 retry_after: Arc::new(RwLock::new(None)),
829 backoff_policy: ExponentialBackoffBuilder::new()
830 .with_initial_interval(Duration::from_millis(250))
831 .with_multiplier(1.75)
833 .with_max_interval(Duration::from_secs(30))
835 .with_max_elapsed_time(Some(Duration::from_secs(125)))
837 .build(),
838 server_restart_duration: Duration::from_secs(120),
839 compression: options.compression,
840 })
841 }
842
843 #[cfg(test)]
844 pub fn with_test_backoff_policy(mut self) -> Self {
845 self.backoff_policy = ExponentialBackoffBuilder::new()
847 .with_initial_interval(Duration::from_millis(1))
848 .with_multiplier(1.1)
849 .with_max_interval(Duration::from_millis(5))
850 .with_max_elapsed_time(Some(Duration::from_millis(50)))
851 .build();
852 self.server_restart_duration = Duration::from_millis(50);
853 self
854 }
855
856 async fn error_for_response(
862 &self,
863 response: reqwest::Response,
864 ) -> Result<reqwest::Response, RPCError> {
865 match response.status() {
866 StatusCode::TOO_MANY_REQUESTS => {
867 let retry_after_raw = response
868 .headers()
869 .get(reqwest::header::RETRY_AFTER)
870 .and_then(|h| h.to_str().ok())
871 .and_then(parse_retry_value);
872
873 Err(RPCError::RateLimited(retry_after_raw))
874 }
875 StatusCode::BAD_GATEWAY |
876 StatusCode::SERVICE_UNAVAILABLE |
877 StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
878 response
879 .text()
880 .await
881 .unwrap_or_else(|_| "Server Unreachable".to_string()),
882 )),
883 _ => Ok(response),
884 }
885 }
886
887 async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
893 match e {
894 RPCError::ServerUnreachable(_) => {
895 backoff::Error::retry_after(e, self.server_restart_duration)
896 }
897 RPCError::RateLimited(Some(until)) => {
898 let mut retry_after_guard = self.retry_after.write().await;
899 *retry_after_guard = Some(
900 retry_after_guard
901 .unwrap_or(until)
902 .max(until),
903 );
904
905 if let Ok(duration) = until.duration_since(SystemTime::now()) {
906 backoff::Error::retry_after(e, duration)
907 } else {
908 e.into()
909 }
910 }
911 RPCError::RateLimited(None) => e.into(),
912 _ => backoff::Error::permanent(e),
913 }
914 }
915
916 async fn wait_until_retry_after(&self) {
921 if let Some(&until) = self.retry_after.read().await.as_ref() {
922 let now = SystemTime::now();
923 if until > now {
924 if let Ok(duration) = until.duration_since(now) {
925 sleep(duration).await
926 }
927 }
928 }
929 }
930
931 async fn make_post_request<T: Serialize + ?Sized>(
936 &self,
937 request: &T,
938 uri: &String,
939 ) -> Result<Response, RPCError> {
940 self.wait_until_retry_after().await;
941 let response = backoff::future::retry(self.backoff_policy.clone(), || async {
942 let server_response = self
943 .http_client
944 .post(uri)
945 .json(request)
946 .send()
947 .await
948 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
949
950 match self
951 .error_for_response(server_response)
952 .await
953 {
954 Ok(response) => Ok(response),
955 Err(e) => Err(self.handle_error_for_backoff(e).await),
956 }
957 })
958 .await?;
959 Ok(response)
960 }
961}
962
963fn parse_retry_value(val: &str) -> Option<SystemTime> {
964 if let Ok(secs) = val.parse::<u64>() {
965 return Some(SystemTime::now() + Duration::from_secs(secs));
966 }
967 if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
968 return Some(date.into());
969 }
970 None
971}
972
973#[async_trait]
974impl RPCClient for HttpRPCClient {
975 fn compression(&self) -> bool {
976 self.compression
977 }
978
979 #[instrument(skip(self, request))]
980 async fn get_contract_state(
981 &self,
982 request: &StateRequestBody,
983 ) -> Result<StateRequestResponse, RPCError> {
984 if request
986 .contract_ids
987 .as_ref()
988 .is_none_or(|ids| ids.is_empty())
989 {
990 warn!("No contract ids specified in request.");
991 }
992
993 let uri = format!(
994 "{}/{}/contract_state",
995 self.url
996 .to_string()
997 .trim_end_matches('/'),
998 TYCHO_SERVER_VERSION
999 );
1000 debug!(%uri, "Sending contract_state request to Tycho server");
1001 trace!(?request, "Sending request to Tycho server");
1002 let response = self
1003 .make_post_request(request, &uri)
1004 .await?;
1005 trace!(?response, "Received response from Tycho server");
1006
1007 let body = response
1008 .text()
1009 .await
1010 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011 if body.is_empty() {
1012 return Ok(StateRequestResponse {
1014 accounts: vec![],
1015 pagination: PaginationResponse {
1016 page: request.pagination.page,
1017 page_size: request.pagination.page,
1018 total: 0,
1019 },
1020 });
1021 }
1022
1023 let accounts = serde_json::from_str::<StateRequestResponse>(&body)
1024 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1025 trace!(?accounts, "Received contract_state response from Tycho server");
1026
1027 Ok(accounts)
1028 }
1029
1030 async fn get_protocol_components(
1031 &self,
1032 request: &ProtocolComponentsRequestBody,
1033 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
1034 let uri = format!(
1035 "{}/{}/protocol_components",
1036 self.url
1037 .to_string()
1038 .trim_end_matches('/'),
1039 TYCHO_SERVER_VERSION,
1040 );
1041 debug!(%uri, "Sending protocol_components request to Tycho server");
1042 trace!(?request, "Sending request to Tycho server");
1043
1044 let response = self
1045 .make_post_request(request, &uri)
1046 .await?;
1047
1048 trace!(?response, "Received response from Tycho server");
1049
1050 let body = response
1051 .text()
1052 .await
1053 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1054 let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
1055 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1056 trace!(?components, "Received protocol_components response from Tycho server");
1057
1058 Ok(components)
1059 }
1060
1061 async fn get_protocol_states(
1062 &self,
1063 request: &ProtocolStateRequestBody,
1064 ) -> Result<ProtocolStateRequestResponse, RPCError> {
1065 if request
1067 .protocol_ids
1068 .as_ref()
1069 .is_none_or(|ids| ids.is_empty())
1070 {
1071 warn!("No protocol ids specified in request.");
1072 }
1073
1074 let uri = format!(
1075 "{}/{}/protocol_state",
1076 self.url
1077 .to_string()
1078 .trim_end_matches('/'),
1079 TYCHO_SERVER_VERSION
1080 );
1081 debug!(%uri, "Sending protocol_states request to Tycho server");
1082 trace!(?request, "Sending request to Tycho server");
1083
1084 let response = self
1085 .make_post_request(request, &uri)
1086 .await?;
1087 trace!(?response, "Received response from Tycho server");
1088
1089 let body = response
1090 .text()
1091 .await
1092 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1093
1094 if body.is_empty() {
1095 return Ok(ProtocolStateRequestResponse {
1097 states: vec![],
1098 pagination: PaginationResponse {
1099 page: request.pagination.page,
1100 page_size: request.pagination.page_size,
1101 total: 0,
1102 },
1103 });
1104 }
1105
1106 let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1107 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1108 trace!(?states, "Received protocol_states response from Tycho server");
1109
1110 Ok(states)
1111 }
1112
1113 async fn get_tokens(
1114 &self,
1115 request: &TokensRequestBody,
1116 ) -> Result<TokensRequestResponse, RPCError> {
1117 let uri = format!(
1118 "{}/{}/tokens",
1119 self.url
1120 .to_string()
1121 .trim_end_matches('/'),
1122 TYCHO_SERVER_VERSION
1123 );
1124 debug!(%uri, "Sending tokens request to Tycho server");
1125
1126 let response = self
1127 .make_post_request(request, &uri)
1128 .await?;
1129
1130 let body = response
1131 .text()
1132 .await
1133 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1134 let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1135 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1136
1137 Ok(tokens)
1138 }
1139
1140 async fn get_protocol_systems(
1141 &self,
1142 request: &ProtocolSystemsRequestBody,
1143 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1144 let uri = format!(
1145 "{}/{}/protocol_systems",
1146 self.url
1147 .to_string()
1148 .trim_end_matches('/'),
1149 TYCHO_SERVER_VERSION
1150 );
1151 debug!(%uri, "Sending protocol_systems request to Tycho server");
1152 trace!(?request, "Sending request to Tycho server");
1153 let response = self
1154 .make_post_request(request, &uri)
1155 .await?;
1156 trace!(?response, "Received response from Tycho server");
1157 let body = response
1158 .text()
1159 .await
1160 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1161 let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1162 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1163 trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1164 Ok(protocol_systems)
1165 }
1166
1167 async fn get_component_tvl(
1168 &self,
1169 request: &ComponentTvlRequestBody,
1170 ) -> Result<ComponentTvlRequestResponse, RPCError> {
1171 let uri = format!(
1172 "{}/{}/component_tvl",
1173 self.url
1174 .to_string()
1175 .trim_end_matches('/'),
1176 TYCHO_SERVER_VERSION
1177 );
1178 debug!(%uri, "Sending get_component_tvl request to Tycho server");
1179 trace!(?request, "Sending request to Tycho server");
1180 let response = self
1181 .make_post_request(request, &uri)
1182 .await?;
1183 trace!(?response, "Received response from Tycho server");
1184 let body = response
1185 .text()
1186 .await
1187 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1188 let component_tvl =
1189 serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1190 error!("Failed to parse component_tvl response: {:?}", &body);
1191 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1192 })?;
1193 trace!(?component_tvl, "Received component_tvl response from Tycho server");
1194 Ok(component_tvl)
1195 }
1196
1197 async fn get_traced_entry_points(
1198 &self,
1199 request: &TracedEntryPointRequestBody,
1200 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1201 let uri = format!(
1202 "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1203 self.url
1204 .to_string()
1205 .trim_end_matches('/')
1206 );
1207 debug!(%uri, "Sending traced_entry_points request to Tycho server");
1208 trace!(?request, "Sending request to Tycho server");
1209
1210 let response = self
1211 .make_post_request(request, &uri)
1212 .await?;
1213
1214 trace!(?response, "Received response from Tycho server");
1215
1216 let body = response
1217 .text()
1218 .await
1219 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1220 let entrypoints =
1221 serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1222 error!("Failed to parse traced_entry_points response: {:?}", &body);
1223 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1224 })?;
1225 trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1226 Ok(entrypoints)
1227 }
1228
1229 async fn get_snapshots<'a>(
1230 &self,
1231 request: &SnapshotParameters<'a>,
1232 chunk_size: Option<usize>,
1233 concurrency: usize,
1234 ) -> Result<Snapshot, RPCError> {
1235 let component_ids: Vec<_> = request
1236 .components
1237 .keys()
1238 .cloned()
1239 .collect();
1240
1241 let version = VersionParam::new(
1242 None,
1243 Some({
1244 #[allow(deprecated)]
1245 BlockParam {
1246 hash: None,
1247 chain: Some(request.chain),
1248 number: Some(request.block_number as i64),
1249 }
1250 }),
1251 );
1252
1253 let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1254 let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1255 self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1256 .await?
1257 .tvl
1258 } else {
1259 HashMap::new()
1260 };
1261
1262 let mut protocol_states = if !component_ids.is_empty() {
1263 self.get_protocol_states_paginated(
1264 request.chain,
1265 &component_ids,
1266 request.protocol_system,
1267 request.include_balances,
1268 &version,
1269 chunk_size,
1270 concurrency,
1271 )
1272 .await?
1273 .states
1274 .into_iter()
1275 .map(|state| (state.component_id.clone(), state))
1276 .collect()
1277 } else {
1278 HashMap::new()
1279 };
1280
1281 let states = request
1283 .components
1284 .values()
1285 .filter_map(|component| {
1286 if let Some(state) = protocol_states.remove(&component.id) {
1287 Some((
1288 component.id.clone(),
1289 ComponentWithState {
1290 state,
1291 component: component.clone(),
1292 component_tvl: component_tvl
1293 .get(&component.id)
1294 .cloned(),
1295 entrypoints: request
1296 .entrypoints
1297 .as_ref()
1298 .and_then(|map| map.get(&component.id))
1299 .cloned()
1300 .unwrap_or_default(),
1301 },
1302 ))
1303 } else if component_ids.contains(&component.id) {
1304 let component_id = &component.id;
1306 error!(?component_id, "Missing state for native component!");
1307 None
1308 } else {
1309 None
1310 }
1311 })
1312 .collect();
1313
1314 let vm_storage = if !request.contract_ids.is_empty() {
1315 let contract_states = self
1316 .get_contract_state_paginated(
1317 request.chain,
1318 request.contract_ids,
1319 request.protocol_system,
1320 &version,
1321 chunk_size,
1322 concurrency,
1323 )
1324 .await?
1325 .accounts
1326 .into_iter()
1327 .map(|acc| (acc.address.clone(), acc))
1328 .collect::<HashMap<_, _>>();
1329
1330 trace!(states=?&contract_states, "Retrieved ContractState");
1331
1332 let contract_address_to_components = request
1333 .components
1334 .iter()
1335 .filter_map(|(id, comp)| {
1336 if component_ids.contains(id) {
1337 Some(
1338 comp.contract_ids
1339 .iter()
1340 .map(|address| (address.clone(), comp.id.clone())),
1341 )
1342 } else {
1343 None
1344 }
1345 })
1346 .flatten()
1347 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1348 acc.entry(addr).or_default().push(c_id);
1349 acc
1350 });
1351
1352 request
1353 .contract_ids
1354 .iter()
1355 .filter_map(|address| {
1356 if let Some(state) = contract_states.get(address) {
1357 Some((address.clone(), state.clone()))
1358 } else if let Some(ids) = contract_address_to_components.get(address) {
1359 error!(
1361 ?address,
1362 ?ids,
1363 "Component with lacking contract storage encountered!"
1364 );
1365 None
1366 } else {
1367 None
1368 }
1369 })
1370 .collect()
1371 } else {
1372 HashMap::new()
1373 };
1374
1375 Ok(Snapshot { states, vm_storage })
1376 }
1377}
1378
1379#[cfg(test)]
1380mod tests {
1381 use std::{
1382 collections::{HashMap, HashSet},
1383 str::FromStr,
1384 };
1385
1386 use mockito::Server;
1387 use rstest::rstest;
1388 #[allow(deprecated)]
1390 use tycho_common::dto::ProtocolId;
1391 use tycho_common::dto::{AddressStorageLocation, TracingParams};
1392
1393 use super::*;
1394
1395 impl MockRPCClient {
1398 #[allow(clippy::too_many_arguments)]
1399 async fn test_get_protocol_states_paginated<T>(
1400 &self,
1401 chain: Chain,
1402 ids: &[T],
1403 protocol_system: &str,
1404 include_balances: bool,
1405 version: &VersionParam,
1406 chunk_size: usize,
1407 _concurrency: usize,
1408 ) -> Vec<ProtocolStateRequestBody>
1409 where
1410 T: AsRef<str> + Clone + Send + Sync + 'static,
1411 {
1412 ids.chunks(chunk_size)
1413 .map(|chunk| ProtocolStateRequestBody {
1414 protocol_ids: Some(
1415 chunk
1416 .iter()
1417 .map(|id| id.as_ref().to_string())
1418 .collect(),
1419 ),
1420 protocol_system: protocol_system.to_string(),
1421 chain,
1422 include_balances,
1423 version: version.clone(),
1424 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1425 })
1426 .collect()
1427 }
1428 }
1429
    /// Canned JSON body for the mocked contract_state endpoint.
    ///
    /// It contains a single account at the zero address with native balance `0x01f4` (500),
    /// code `0x00`, no storage slots, and a pagination block.
    const GET_CONTRACT_STATE_RESP: &str = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
1454
1455 #[allow(deprecated)]
1457 #[rstest]
1458 #[case::protocol_id_input(vec![
1459 ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1460 ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1461 ])]
1462 #[case::string_input(vec![
1463 "id1".to_string(),
1464 "id2".to_string()
1465 ])]
1466 #[tokio::test]
1467 async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1468 where
1469 T: AsRef<str> + Clone + Send + Sync + 'static,
1470 {
1471 let mock_client = MockRPCClient::new();
1472
1473 let request_bodies = mock_client
1474 .test_get_protocol_states_paginated(
1475 Chain::Ethereum,
1476 &ids,
1477 "test_system",
1478 true,
1479 &VersionParam::default(),
1480 2,
1481 2,
1482 )
1483 .await;
1484
1485 assert_eq!(request_bodies.len(), 1);
1487 assert_eq!(
1488 request_bodies[0]
1489 .protocol_ids
1490 .as_ref()
1491 .unwrap()
1492 .len(),
1493 2
1494 );
1495 }
1496
1497 #[tokio::test]
1498 async fn test_get_contract_state() {
1499 let mut server = Server::new_async().await;
1500 let server_resp = GET_CONTRACT_STATE_RESP;
1501 serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1503
1504 let mocked_server = server
1505 .mock("POST", "/v1/contract_state")
1506 .expect(1)
1507 .with_body(server_resp)
1508 .create_async()
1509 .await;
1510
1511 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1512 .expect("create client");
1513
1514 let response = client
1515 .get_contract_state(&Default::default())
1516 .await
1517 .expect("get state");
1518 let accounts = response.accounts;
1519
1520 mocked_server.assert();
1521 assert_eq!(accounts.len(), 1);
1522 assert_eq!(accounts[0].slots, HashMap::new());
1523 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1524 assert_eq!(accounts[0].code, [0].to_vec());
1525 assert_eq!(
1526 accounts[0].code_hash,
1527 hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1528 .unwrap()
1529 );
1530 }
1531
1532 #[tokio::test]
1533 async fn test_get_protocol_components() {
1534 let mut server = Server::new_async().await;
1535 let server_resp = r#"
1536 {
1537 "protocol_components": [
1538 {
1539 "id": "State1",
1540 "protocol_system": "ambient",
1541 "protocol_type_name": "Pool",
1542 "chain": "ethereum",
1543 "tokens": [
1544 "0x0000000000000000000000000000000000000000",
1545 "0x0000000000000000000000000000000000000001"
1546 ],
1547 "contract_ids": [
1548 "0x0000000000000000000000000000000000000000"
1549 ],
1550 "static_attributes": {
1551 "attribute_1": "0x00000000000003e8"
1552 },
1553 "change": "Creation",
1554 "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1555 "created_at": "2022-01-01T00:00:00"
1556 }
1557 ],
1558 "pagination": {
1559 "page": 0,
1560 "page_size": 20,
1561 "total": 10
1562 }
1563 }
1564 "#;
1565 serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1567
1568 let mocked_server = server
1569 .mock("POST", "/v1/protocol_components")
1570 .expect(1)
1571 .with_body(server_resp)
1572 .create_async()
1573 .await;
1574
1575 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1576 .expect("create client");
1577
1578 let response = client
1579 .get_protocol_components(&Default::default())
1580 .await
1581 .expect("get state");
1582 let components = response.protocol_components;
1583
1584 mocked_server.assert();
1585 assert_eq!(components.len(), 1);
1586 assert_eq!(components[0].id, "State1");
1587 assert_eq!(components[0].protocol_system, "ambient");
1588 assert_eq!(components[0].protocol_type_name, "Pool");
1589 assert_eq!(components[0].tokens.len(), 2);
1590 let expected_attributes =
1591 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1592 .iter()
1593 .cloned()
1594 .collect::<HashMap<String, Bytes>>();
1595 assert_eq!(components[0].static_attributes, expected_attributes);
1596 }
1597
1598 #[tokio::test]
1599 async fn test_get_protocol_states() {
1600 let mut server = Server::new_async().await;
1601 let server_resp = r#"
1602 {
1603 "states": [
1604 {
1605 "component_id": "State1",
1606 "attributes": {
1607 "attribute_1": "0x00000000000003e8"
1608 },
1609 "balances": {
1610 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1611 }
1612 }
1613 ],
1614 "pagination": {
1615 "page": 0,
1616 "page_size": 20,
1617 "total": 10
1618 }
1619 }
1620 "#;
1621 serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1623
1624 let mocked_server = server
1625 .mock("POST", "/v1/protocol_state")
1626 .expect(1)
1627 .with_body(server_resp)
1628 .create_async()
1629 .await;
1630 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1631 .expect("create client");
1632
1633 let response = client
1634 .get_protocol_states(&Default::default())
1635 .await
1636 .expect("get state");
1637 let states = response.states;
1638
1639 mocked_server.assert();
1640 assert_eq!(states.len(), 1);
1641 assert_eq!(states[0].component_id, "State1");
1642 let expected_attributes =
1643 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1644 .iter()
1645 .cloned()
1646 .collect::<HashMap<String, Bytes>>();
1647 assert_eq!(states[0].attributes, expected_attributes);
1648 let expected_balances = [(
1649 Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1650 .expect("Unsupported address format"),
1651 Bytes::from_str("0x01f4").unwrap(),
1652 )]
1653 .iter()
1654 .cloned()
1655 .collect::<HashMap<Bytes, Bytes>>();
1656 assert_eq!(states[0].balances, expected_balances);
1657 }
1658
1659 #[tokio::test]
1660 async fn test_get_tokens() {
1661 let mut server = Server::new_async().await;
1662 let server_resp = r#"
1663 {
1664 "tokens": [
1665 {
1666 "chain": "ethereum",
1667 "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1668 "symbol": "WETH",
1669 "decimals": 18,
1670 "tax": 0,
1671 "gas": [
1672 29962
1673 ],
1674 "quality": 100
1675 },
1676 {
1677 "chain": "ethereum",
1678 "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1679 "symbol": "USDC",
1680 "decimals": 6,
1681 "tax": 0,
1682 "gas": [
1683 40652
1684 ],
1685 "quality": 100
1686 }
1687 ],
1688 "pagination": {
1689 "page": 0,
1690 "page_size": 20,
1691 "total": 10
1692 }
1693 }
1694 "#;
1695 serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1697
1698 let mocked_server = server
1699 .mock("POST", "/v1/tokens")
1700 .expect(1)
1701 .with_body(server_resp)
1702 .create_async()
1703 .await;
1704 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1705 .expect("create client");
1706
1707 let response = client
1708 .get_tokens(&Default::default())
1709 .await
1710 .expect("get tokens");
1711
1712 let expected = vec![
1713 ResponseToken {
1714 chain: Chain::Ethereum,
1715 address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1716 symbol: "WETH".to_string(),
1717 decimals: 18,
1718 tax: 0,
1719 gas: vec![Some(29962)],
1720 quality: 100,
1721 },
1722 ResponseToken {
1723 chain: Chain::Ethereum,
1724 address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1725 symbol: "USDC".to_string(),
1726 decimals: 6,
1727 tax: 0,
1728 gas: vec![Some(40652)],
1729 quality: 100,
1730 },
1731 ];
1732
1733 mocked_server.assert();
1734 assert_eq!(response.tokens, expected);
1735 assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1736 }
1737
1738 #[tokio::test]
1739 async fn test_get_protocol_systems() {
1740 let mut server = Server::new_async().await;
1741 let server_resp = r#"
1742 {
1743 "protocol_systems": [
1744 "system1",
1745 "system2"
1746 ],
1747 "pagination": {
1748 "page": 0,
1749 "page_size": 20,
1750 "total": 10
1751 }
1752 }
1753 "#;
1754 serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1756
1757 let mocked_server = server
1758 .mock("POST", "/v1/protocol_systems")
1759 .expect(1)
1760 .with_body(server_resp)
1761 .create_async()
1762 .await;
1763 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1764 .expect("create client");
1765
1766 let response = client
1767 .get_protocol_systems(&Default::default())
1768 .await
1769 .expect("get protocol systems");
1770 let protocol_systems = response.protocol_systems;
1771
1772 mocked_server.assert();
1773 assert_eq!(protocol_systems, vec!["system1", "system2"]);
1774 }
1775
1776 #[tokio::test]
1777 async fn test_get_component_tvl() {
1778 let mut server = Server::new_async().await;
1779 let server_resp = r#"
1780 {
1781 "tvl": {
1782 "component1": 100.0
1783 },
1784 "pagination": {
1785 "page": 0,
1786 "page_size": 20,
1787 "total": 10
1788 }
1789 }
1790 "#;
1791 serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1793
1794 let mocked_server = server
1795 .mock("POST", "/v1/component_tvl")
1796 .expect(1)
1797 .with_body(server_resp)
1798 .create_async()
1799 .await;
1800 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1801 .expect("create client");
1802
1803 let response = client
1804 .get_component_tvl(&Default::default())
1805 .await
1806 .expect("get protocol systems");
1807 let component_tvl = response.tvl;
1808
1809 mocked_server.assert();
1810 assert_eq!(component_tvl.get("component1"), Some(&100.0));
1811 }
1812
1813 #[tokio::test]
1814 async fn test_get_traced_entry_points() {
1815 let mut server = Server::new_async().await;
1816 let server_resp = r#"
1817 {
1818 "traced_entry_points": {
1819 "component_1": [
1820 [
1821 {
1822 "entry_point": {
1823 "external_id": "entrypoint_a",
1824 "target": "0x0000000000000000000000000000000000000001",
1825 "signature": "sig()"
1826 },
1827 "params": {
1828 "method": "rpctracer",
1829 "caller": "0x000000000000000000000000000000000000000a",
1830 "calldata": "0x000000000000000000000000000000000000000b"
1831 }
1832 },
1833 {
1834 "retriggers": [
1835 [
1836 "0x00000000000000000000000000000000000000aa",
1837 {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1838 ]
1839 ],
1840 "accessed_slots": {
1841 "0x0000000000000000000000000000000000aaaa": [
1842 "0x0000000000000000000000000000000000aaaa"
1843 ]
1844 }
1845 }
1846 ]
1847 ]
1848 },
1849 "pagination": {
1850 "page": 0,
1851 "page_size": 20,
1852 "total": 1
1853 }
1854 }
1855 "#;
1856 serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1858
1859 let mocked_server = server
1860 .mock("POST", "/v1/traced_entry_points")
1861 .expect(1)
1862 .with_body(server_resp)
1863 .create_async()
1864 .await;
1865 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1866 .expect("create client");
1867
1868 let response = client
1869 .get_traced_entry_points(&Default::default())
1870 .await
1871 .expect("get traced entry points");
1872 let entrypoints = response.traced_entry_points;
1873
1874 mocked_server.assert();
1875 assert_eq!(entrypoints.len(), 1);
1876 let comp1_entrypoints = entrypoints
1877 .get("component_1")
1878 .expect("component_1 entrypoints should exist");
1879 assert_eq!(comp1_entrypoints.len(), 1);
1880
1881 let (entrypoint, trace_result) = &comp1_entrypoints[0];
1882 assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1883 assert_eq!(
1884 entrypoint.entry_point.target,
1885 Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1886 );
1887 assert_eq!(entrypoint.entry_point.signature, "sig()");
1888 let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1889 assert_eq!(
1890 rpc_params.caller,
1891 Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1892 );
1893 assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1894
1895 assert_eq!(
1896 trace_result.retriggers,
1897 HashSet::from([(
1898 Bytes::from("0x00000000000000000000000000000000000000aa"),
1899 AddressStorageLocation::new(
1900 Bytes::from("0x0000000000000000000000000000000000000aaa"),
1901 12
1902 )
1903 )])
1904 );
1905 assert_eq!(trace_result.accessed_slots.len(), 1);
1906 assert_eq!(
1907 trace_result.accessed_slots,
1908 HashMap::from([(
1909 Bytes::from("0x0000000000000000000000000000000000aaaa"),
1910 HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1911 )])
1912 );
1913 }
1914
1915 #[tokio::test]
1916 async fn test_parse_retry_value_numeric() {
1917 let result = parse_retry_value("60");
1918 assert!(result.is_some());
1919
1920 let expected_time = SystemTime::now() + Duration::from_secs(60);
1921 let actual_time = result.unwrap();
1922
1923 let diff = if actual_time > expected_time {
1925 actual_time
1926 .duration_since(expected_time)
1927 .unwrap()
1928 } else {
1929 expected_time
1930 .duration_since(actual_time)
1931 .unwrap()
1932 };
1933 assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1934 }
1935
1936 #[tokio::test]
1937 async fn test_parse_retry_value_rfc2822() {
1938 let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1940 let result = parse_retry_value(rfc2822_date);
1941 assert!(result.is_some());
1942
1943 let parsed_time = result.unwrap();
1944 assert!(parsed_time > SystemTime::now());
1945 }
1946
1947 #[tokio::test]
1948 async fn test_parse_retry_value_invalid_formats() {
1949 assert!(parse_retry_value("invalid").is_none());
1951 assert!(parse_retry_value("").is_none());
1952 assert!(parse_retry_value("not_a_number").is_none());
1953 assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); }
1955
1956 #[tokio::test]
1957 async fn test_parse_retry_value_zero_seconds() {
1958 let result = parse_retry_value("0");
1959 assert!(result.is_some());
1960
1961 let expected_time = SystemTime::now();
1962 let actual_time = result.unwrap();
1963
1964 let diff = if actual_time > expected_time {
1966 actual_time
1967 .duration_since(expected_time)
1968 .unwrap()
1969 } else {
1970 expected_time
1971 .duration_since(actual_time)
1972 .unwrap()
1973 };
1974 assert!(diff < Duration::from_secs(1));
1975 }
1976
1977 #[tokio::test]
1978 async fn test_error_for_response_rate_limited() {
1979 let mut server = Server::new_async().await;
1980 let mock = server
1981 .mock("GET", "/test")
1982 .with_status(429)
1983 .with_header("Retry-After", "60")
1984 .create_async()
1985 .await;
1986
1987 let client = reqwest::Client::new();
1988 let response = client
1989 .get(format!("{}/test", server.url()))
1990 .send()
1991 .await
1992 .unwrap();
1993
1994 let http_client =
1995 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1996 .unwrap()
1997 .with_test_backoff_policy();
1998 let result = http_client
1999 .error_for_response(response)
2000 .await;
2001
2002 mock.assert();
2003 assert!(matches!(result, Err(RPCError::RateLimited(_))));
2004 if let Err(RPCError::RateLimited(retry_after)) = result {
2005 assert!(retry_after.is_some());
2006 }
2007 }
2008
2009 #[tokio::test]
2010 async fn test_error_for_response_rate_limited_no_header() {
2011 let mut server = Server::new_async().await;
2012 let mock = server
2013 .mock("GET", "/test")
2014 .with_status(429)
2015 .create_async()
2016 .await;
2017
2018 let client = reqwest::Client::new();
2019 let response = client
2020 .get(format!("{}/test", server.url()))
2021 .send()
2022 .await
2023 .unwrap();
2024
2025 let http_client =
2026 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2027 .unwrap()
2028 .with_test_backoff_policy();
2029 let result = http_client
2030 .error_for_response(response)
2031 .await;
2032
2033 mock.assert();
2034 assert!(matches!(result, Err(RPCError::RateLimited(None))));
2035 }
2036
2037 #[tokio::test]
2038 async fn test_error_for_response_server_errors() {
2039 let test_cases =
2040 vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
2041
2042 for (status_code, expected_body) in test_cases {
2043 let mut server = Server::new_async().await;
2044 let mock = server
2045 .mock("GET", "/test")
2046 .with_status(status_code)
2047 .with_body(expected_body)
2048 .create_async()
2049 .await;
2050
2051 let client = reqwest::Client::new();
2052 let response = client
2053 .get(format!("{}/test", server.url()))
2054 .send()
2055 .await
2056 .unwrap();
2057
2058 let http_client =
2059 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2060 .unwrap()
2061 .with_test_backoff_policy();
2062 let result = http_client
2063 .error_for_response(response)
2064 .await;
2065
2066 mock.assert();
2067 assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
2068 if let Err(RPCError::ServerUnreachable(body)) = result {
2069 assert_eq!(body, expected_body);
2070 }
2071 }
2072 }
2073
2074 #[tokio::test]
2075 async fn test_error_for_response_success() {
2076 let mut server = Server::new_async().await;
2077 let mock = server
2078 .mock("GET", "/test")
2079 .with_status(200)
2080 .with_body("success")
2081 .create_async()
2082 .await;
2083
2084 let client = reqwest::Client::new();
2085 let response = client
2086 .get(format!("{}/test", server.url()))
2087 .send()
2088 .await
2089 .unwrap();
2090
2091 let http_client =
2092 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2093 .unwrap()
2094 .with_test_backoff_policy();
2095 let result = http_client
2096 .error_for_response(response)
2097 .await;
2098
2099 mock.assert();
2100 assert!(result.is_ok());
2101
2102 let response = result.unwrap();
2103 assert_eq!(response.status(), 200);
2104 }
2105
2106 #[tokio::test]
2107 async fn test_handle_error_for_backoff_server_unreachable() {
2108 let http_client =
2109 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2110 .unwrap()
2111 .with_test_backoff_policy();
2112 let error = RPCError::ServerUnreachable("Service down".to_string());
2113
2114 let backoff_error = http_client
2115 .handle_error_for_backoff(error)
2116 .await;
2117
2118 match backoff_error {
2119 backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2120 assert_eq!(msg, "Service down");
2121 assert_eq!(retry_after, Some(Duration::from_millis(50))); }
2123 _ => panic!("Expected transient error for ServerUnreachable"),
2124 }
2125 }
2126
2127 #[tokio::test]
2128 async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2129 let http_client =
2130 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2131 .unwrap()
2132 .with_test_backoff_policy();
2133 let future_time = SystemTime::now() + Duration::from_secs(30);
2134 let error = RPCError::RateLimited(Some(future_time));
2135
2136 let backoff_error = http_client
2137 .handle_error_for_backoff(error)
2138 .await;
2139
2140 match backoff_error {
2141 backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2142 assert_eq!(retry_after, Some(future_time));
2143 }
2144 _ => panic!("Expected transient error for RateLimited"),
2145 }
2146
2147 let stored_retry_after = http_client.retry_after.read().await;
2149 assert_eq!(*stored_retry_after, Some(future_time));
2150 }
2151
2152 #[tokio::test]
2153 async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2154 let http_client =
2155 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2156 .unwrap()
2157 .with_test_backoff_policy();
2158 let error = RPCError::RateLimited(None);
2159
2160 let backoff_error = http_client
2161 .handle_error_for_backoff(error)
2162 .await;
2163
2164 match backoff_error {
2165 backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2166 }
2168 _ => panic!("Expected transient error for RateLimited without retry-after"),
2169 }
2170 }
2171
2172 #[tokio::test]
2173 async fn test_handle_error_for_backoff_other_errors() {
2174 let http_client =
2175 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2176 .unwrap()
2177 .with_test_backoff_policy();
2178 let error = RPCError::ParseResponse("Invalid JSON".to_string());
2179
2180 let backoff_error = http_client
2181 .handle_error_for_backoff(error)
2182 .await;
2183
2184 match backoff_error {
2185 backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2186 assert_eq!(msg, "Invalid JSON");
2187 }
2188 _ => panic!("Expected permanent error for ParseResponse"),
2189 }
2190 }
2191
2192 #[tokio::test]
2193 async fn test_wait_until_retry_after_no_retry_time() {
2194 let http_client =
2195 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2196 .unwrap()
2197 .with_test_backoff_policy();
2198
2199 let start = std::time::Instant::now();
2200 http_client
2201 .wait_until_retry_after()
2202 .await;
2203 let elapsed = start.elapsed();
2204
2205 assert!(elapsed < Duration::from_millis(100));
2207 }
2208
2209 #[tokio::test]
2210 async fn test_wait_until_retry_after_past_time() {
2211 let http_client =
2212 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2213 .unwrap()
2214 .with_test_backoff_policy();
2215
2216 let past_time = SystemTime::now() - Duration::from_secs(10);
2218 *http_client.retry_after.write().await = Some(past_time);
2219
2220 let start = std::time::Instant::now();
2221 http_client
2222 .wait_until_retry_after()
2223 .await;
2224 let elapsed = start.elapsed();
2225
2226 assert!(elapsed < Duration::from_millis(100));
2228 }
2229
    /// A stored `retry_after` about 100ms in the future must delay `wait_until_retry_after`
    /// until roughly that deadline.
    #[tokio::test]
    async fn test_wait_until_retry_after_future_time() {
        let http_client =
            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();

        // Deadline ~100ms from now, written directly into the client's shared state.
        let future_time = SystemTime::now() + Duration::from_millis(100);
        *http_client.retry_after.write().await = Some(future_time);

        let start = std::time::Instant::now();
        http_client
            .wait_until_retry_after()
            .await;
        let elapsed = start.elapsed();

        // The lower bound sits below 100ms because part of the delay elapses before `start` is
        // taken. The upper bound leaves headroom for scheduler jitter.
        assert!(elapsed >= Duration::from_millis(80));
        assert!(elapsed <= Duration::from_millis(200));
    }
2251
2252 #[tokio::test]
2253 async fn test_make_post_request_success() {
2254 let mut server = Server::new_async().await;
2255 let server_resp = r#"{"success": true}"#;
2256
2257 let mock = server
2258 .mock("POST", "/test")
2259 .with_status(200)
2260 .with_body(server_resp)
2261 .create_async()
2262 .await;
2263
2264 let http_client =
2265 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2266 .unwrap()
2267 .with_test_backoff_policy();
2268 let request_body = serde_json::json!({"test": "data"});
2269 let uri = format!("{}/test", server.url());
2270
2271 let result = http_client
2272 .make_post_request(&request_body, &uri)
2273 .await;
2274
2275 mock.assert();
2276 assert!(result.is_ok());
2277
2278 let response = result.unwrap();
2279 assert_eq!(response.status(), 200);
2280 assert_eq!(response.text().await.unwrap(), server_resp);
2281 }
2282
    /// A transient 503 must be retried: the first POST is answered by the 503 mock, the retry
    /// by the 200 mock, and the overall call succeeds.
    #[tokio::test]
    async fn test_make_post_request_retry_on_server_error() {
        let mut server = Server::new_async().await;
        // Created first and limited to one hit, so it is expected to answer the initial request
        // (this relies on mock creation order — keep it first).
        let error_mock = server
            .mock("POST", "/test")
            .with_status(503)
            .with_body("Service Unavailable")
            .expect(1)
            .create_async()
            .await;

        // Expected to answer the retried request.
        let success_mock = server
            .mock("POST", "/test")
            .with_status(200)
            .with_body(r#"{"success": true}"#)
            .expect(1)
            .create_async()
            .await;

        let http_client =
            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();
        let request_body = serde_json::json!({"test": "data"});
        let uri = format!("{}/test", server.url());

        let result = http_client
            .make_post_request(&request_body, &uri)
            .await;

        // Each mock was hit exactly once: one failure followed by one successful retry.
        error_mock.assert();
        success_mock.assert();
        assert!(result.is_ok());
    }
2318
    /// A 429 carrying `Retry-After: 1` must make the client wait about one second before
    /// retrying, after which the 200 mock answers.
    #[tokio::test]
    async fn test_make_post_request_respect_retry_after_header() {
        let mut server = Server::new_async().await;

        // Created first and limited to one hit, so it answers the initial request (mock creation
        // order matters here).
        let rate_limit_mock = server
            .mock("POST", "/test")
            .with_status(429)
            .with_header("Retry-After", "1")
            .expect(1)
            .create_async()
            .await;

        // Answers the request retried after the Retry-After delay.
        let success_mock = server
            .mock("POST", "/test")
            .with_status(200)
            .with_body(r#"{"success": true}"#)
            .expect(1)
            .create_async()
            .await;

        let http_client =
            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();
        let request_body = serde_json::json!({"test": "data"});
        let uri = format!("{}/test", server.url());

        let start = std::time::Instant::now();
        let result = http_client
            .make_post_request(&request_body, &uri)
            .await;
        let elapsed = start.elapsed();

        rate_limit_mock.assert();
        success_mock.assert();
        assert!(result.is_ok());

        // About one second of waiting: slightly less allows for timer granularity, and up to
        // two seconds allows for scheduling overhead.
        assert!(elapsed >= Duration::from_millis(900));
        assert!(elapsed <= Duration::from_millis(2000));
    }
2361
2362 #[tokio::test]
2363 async fn test_make_post_request_permanent_error() {
2364 let mut server = Server::new_async().await;
2365
2366 let mock = server
2367 .mock("POST", "/test")
2368 .with_status(400) .with_body("Bad Request")
2370 .expect(1)
2371 .create_async()
2372 .await;
2373
2374 let http_client =
2375 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2376 .unwrap()
2377 .with_test_backoff_policy();
2378 let request_body = serde_json::json!({"test": "data"});
2379 let uri = format!("{}/test", server.url());
2380
2381 let result = http_client
2382 .make_post_request(&request_body, &uri)
2383 .await;
2384
2385 mock.assert();
2386 assert!(result.is_ok()); let response = result.unwrap();
2389 assert_eq!(response.status(), 400);
2390 }
2391
    /// Two concurrent requests that are rate limited with different `Retry-After` values
    /// (1s and 2s) must both eventually succeed. The client must be left with a stored
    /// `retry_after` close to now.
    #[tokio::test]
    async fn test_concurrent_requests_with_different_retry_after() {
        let mut server = Server::new_async().await;

        // First response on /test1: rate limited for 1 second.
        let rate_limit_mock_1 = server
            .mock("POST", "/test1")
            .with_status(429)
            .with_header("Retry-After", "1")
            .expect(1)
            .create_async()
            .await;

        // First response on /test2: rate limited for 2 seconds.
        let rate_limit_mock_2 = server
            .mock("POST", "/test2")
            .with_status(429)
            .with_header("Retry-After", "2")
            .expect(1)
            .create_async()
            .await;

        // Retried request on /test1 succeeds (created after the 429 mock; order matters).
        let success_mock_1 = server
            .mock("POST", "/test1")
            .with_status(200)
            .with_body(r#"{"result": "success1"}"#)
            .expect(1)
            .create_async()
            .await;

        // Retried request on /test2 succeeds.
        let success_mock_2 = server
            .mock("POST", "/test2")
            .with_status(200)
            .with_body(r#"{"result": "success2"}"#)
            .expect(1)
            .create_async()
            .await;

        let http_client =
            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();
        let request_body = serde_json::json!({"test": "data"});

        let uri1 = format!("{}/test1", server.url());
        let uri2 = format!("{}/test2", server.url());

        let start = std::time::Instant::now();
        // Both requests share one client and therefore one `retry_after` slot.
        let (result1, result2) = tokio::join!(
            http_client.make_post_request(&request_body, &uri1),
            http_client.make_post_request(&request_body, &uri2)
        );
        let elapsed = start.elapsed();

        rate_limit_mock_1.assert();
        rate_limit_mock_2.assert();
        success_mock_1.assert();
        success_mock_2.assert();

        assert!(result1.is_ok());
        assert!(result2.is_ok());

        // Total time is bounded below by the longer (2s) Retry-After, minus some slack, and
        // above by 3s to allow for scheduling overhead.
        assert!(elapsed >= Duration::from_millis(1800));
        assert!(elapsed <= Duration::from_millis(3000));

        let final_retry_after = http_client.retry_after.read().await;
        assert!(final_retry_after.is_some());

        if let Some(retry_time) = *final_retry_after {
            let now = SystemTime::now();
            // Absolute distance between the stored deadline and now, whichever is later.
            let diff = if retry_time > now {
                retry_time.duration_since(now).unwrap()
            } else {
                now.duration_since(retry_time).unwrap()
            };

            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
        }
    }
2481
2482 #[tokio::test]
2483 async fn test_get_snapshots() {
2484 let mut server = Server::new_async().await;
2485
2486 let protocol_states_resp = r#"
2488 {
2489 "states": [
2490 {
2491 "component_id": "component1",
2492 "attributes": {
2493 "attribute_1": "0x00000000000003e8"
2494 },
2495 "balances": {
2496 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2497 }
2498 }
2499 ],
2500 "pagination": {
2501 "page": 0,
2502 "page_size": 100,
2503 "total": 1
2504 }
2505 }
2506 "#;
2507
2508 let contract_state_resp = r#"
2510 {
2511 "accounts": [
2512 {
2513 "chain": "ethereum",
2514 "address": "0x1111111111111111111111111111111111111111",
2515 "title": "",
2516 "slots": {},
2517 "native_balance": "0x01f4",
2518 "token_balances": {},
2519 "code": "0x00",
2520 "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2521 "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2522 "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2523 "creation_tx": null
2524 }
2525 ],
2526 "pagination": {
2527 "page": 0,
2528 "page_size": 100,
2529 "total": 1
2530 }
2531 }
2532 "#;
2533
2534 let tvl_resp = r#"
2536 {
2537 "tvl": {
2538 "component1": 1000000.0
2539 },
2540 "pagination": {
2541 "page": 0,
2542 "page_size": 100,
2543 "total": 1
2544 }
2545 }
2546 "#;
2547
2548 let protocol_states_mock = server
2549 .mock("POST", "/v1/protocol_state")
2550 .expect(1)
2551 .with_body(protocol_states_resp)
2552 .create_async()
2553 .await;
2554
2555 let contract_state_mock = server
2556 .mock("POST", "/v1/contract_state")
2557 .expect(1)
2558 .with_body(contract_state_resp)
2559 .create_async()
2560 .await;
2561
2562 let tvl_mock = server
2563 .mock("POST", "/v1/component_tvl")
2564 .expect(1)
2565 .with_body(tvl_resp)
2566 .create_async()
2567 .await;
2568
2569 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2570 .expect("create client");
2571
2572 #[allow(deprecated)]
2573 let component = tycho_common::dto::ProtocolComponent {
2574 id: "component1".to_string(),
2575 protocol_system: "test_protocol".to_string(),
2576 protocol_type_name: "test_type".to_string(),
2577 chain: Chain::Ethereum,
2578 tokens: vec![],
2579 contract_ids: vec![
2580 Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2581 ],
2582 static_attributes: HashMap::new(),
2583 change: tycho_common::dto::ChangeType::Creation,
2584 creation_tx: Bytes::from_str(
2585 "0x0000000000000000000000000000000000000000000000000000000000000000",
2586 )
2587 .unwrap(),
2588 created_at: chrono::Utc::now().naive_utc(),
2589 };
2590
2591 let mut components = HashMap::new();
2592 components.insert("component1".to_string(), component);
2593
2594 let contract_ids =
2595 vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2596
2597 let request = SnapshotParameters::new(
2598 Chain::Ethereum,
2599 "test_protocol",
2600 &components,
2601 &contract_ids,
2602 12345,
2603 );
2604
2605 let response = client
2606 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2607 .await
2608 .expect("get snapshots");
2609
2610 protocol_states_mock.assert();
2612 contract_state_mock.assert();
2613 tvl_mock.assert();
2614
2615 assert_eq!(response.states.len(), 1);
2617 assert!(response
2618 .states
2619 .contains_key("component1"));
2620
2621 let component_state = response
2623 .states
2624 .get("component1")
2625 .unwrap();
2626 assert_eq!(component_state.component_tvl, Some(1000000.0));
2627
2628 assert_eq!(response.vm_storage.len(), 1);
2630 let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2631 assert!(response
2632 .vm_storage
2633 .contains_key(&contract_addr));
2634 }
2635
2636 #[tokio::test]
2637 async fn test_get_snapshots_empty_components() {
2638 let server = Server::new_async().await;
2639 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2640 .expect("create client");
2641
2642 let components = HashMap::new();
2643 let contract_ids = vec![];
2644
2645 let request = SnapshotParameters::new(
2646 Chain::Ethereum,
2647 "test_protocol",
2648 &components,
2649 &contract_ids,
2650 12345,
2651 );
2652
2653 let response = client
2654 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2655 .await
2656 .expect("get snapshots");
2657
2658 assert!(response.states.is_empty());
2660 assert!(response.vm_storage.is_empty());
2661 }
2662
2663 #[tokio::test]
2664 async fn test_get_snapshots_without_tvl() {
2665 let mut server = Server::new_async().await;
2666
2667 let protocol_states_resp = r#"
2668 {
2669 "states": [
2670 {
2671 "component_id": "component1",
2672 "attributes": {},
2673 "balances": {}
2674 }
2675 ],
2676 "pagination": {
2677 "page": 0,
2678 "page_size": 100,
2679 "total": 1
2680 }
2681 }
2682 "#;
2683
2684 let protocol_states_mock = server
2685 .mock("POST", "/v1/protocol_state")
2686 .expect(1)
2687 .with_body(protocol_states_resp)
2688 .create_async()
2689 .await;
2690
2691 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2692 .expect("create client");
2693
2694 #[allow(deprecated)]
2696 let component = tycho_common::dto::ProtocolComponent {
2697 id: "component1".to_string(),
2698 protocol_system: "test_protocol".to_string(),
2699 protocol_type_name: "test_type".to_string(),
2700 chain: Chain::Ethereum,
2701 tokens: vec![],
2702 contract_ids: vec![],
2703 static_attributes: HashMap::new(),
2704 change: tycho_common::dto::ChangeType::Creation,
2705 creation_tx: Bytes::from_str(
2706 "0x0000000000000000000000000000000000000000000000000000000000000000",
2707 )
2708 .unwrap(),
2709 created_at: chrono::Utc::now().naive_utc(),
2710 };
2711
2712 let mut components = HashMap::new();
2713 components.insert("component1".to_string(), component);
2714 let contract_ids = vec![];
2715
2716 let request = SnapshotParameters::new(
2717 Chain::Ethereum,
2718 "test_protocol",
2719 &components,
2720 &contract_ids,
2721 12345,
2722 )
2723 .include_balances(false)
2724 .include_tvl(false);
2725
2726 let response = client
2727 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2728 .await
2729 .expect("get snapshots");
2730
2731 protocol_states_mock.assert();
2733 assert_eq!(response.states.len(), 1);
2737 let component_state = response
2739 .states
2740 .get("component1")
2741 .unwrap();
2742 assert_eq!(component_state.component_tvl, None);
2743 }
2744
2745 #[tokio::test]
2746 async fn test_compression_enabled() {
2747 let mut server = Server::new_async().await;
2748 let server_resp = GET_CONTRACT_STATE_RESP;
2749
2750 let compressed_body =
2752 zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2753
2754 let mocked_server = server
2755 .mock("POST", "/v1/contract_state")
2756 .expect(1)
2757 .with_header("Content-Encoding", "zstd")
2758 .with_body(compressed_body)
2759 .create_async()
2760 .await;
2761
2762 let client = HttpRPCClient::new(
2764 server.url().as_str(),
2765 HttpRPCClientOptions::new().with_compression(true),
2766 )
2767 .expect("create client");
2768
2769 let response = client
2770 .get_contract_state(&Default::default())
2771 .await
2772 .expect("get state");
2773 let accounts = response.accounts;
2774
2775 mocked_server.assert();
2776 assert_eq!(accounts.len(), 1);
2777 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2778 }
2779
2780 #[tokio::test]
2781 async fn test_compression_disabled() {
2782 let mut server = Server::new_async().await;
2783 let server_resp = GET_CONTRACT_STATE_RESP;
2784
2785 let mocked_server = server
2788 .mock("POST", "/v1/contract_state")
2789 .expect(1)
2790 .match_header("Accept-Encoding", mockito::Matcher::Missing)
2791 .with_status(200)
2792 .with_body(server_resp)
2793 .create_async()
2794 .await;
2795
2796 let client = HttpRPCClient::new(
2798 server.url().as_str(),
2799 HttpRPCClientOptions::new().with_compression(false),
2800 )
2801 .expect("create client");
2802
2803 let response = client
2804 .get_contract_state(&Default::default())
2805 .await
2806 .expect("get state");
2807 let accounts = response.accounts;
2808
2809 mocked_server.assert();
2811 assert_eq!(accounts.len(), 1);
2812 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2813 }
2814
2815 #[rstest]
2816 #[case::single_page(2, 1000)]
2817 #[case::multiple_pages_within_concurrency(10, 2)]
2818 #[case::exceeds_concurrency_limit(60, 2)]
2819 #[tokio::test]
2820 async fn test_get_all_tokens_pagination_and_concurrency(
2821 #[case] total_tokens: usize,
2822 #[case] page_size: usize,
2823 ) {
2824 use std::sync::atomic::{AtomicUsize, Ordering};
2825
2826 let allowed_concurrency = 10;
2827
2828 let concurrent_requests = Arc::new(AtomicUsize::new(0));
2829 let max_concurrent = Arc::new(AtomicUsize::new(0));
2830
2831 let mut server = Server::new_async().await;
2832
2833 let total_pages = (total_tokens as f64 / page_size as f64).ceil() as i64;
2834
2835 for page in 0..total_pages {
2837 let concurrent = concurrent_requests.clone();
2838 let max_conc = max_concurrent.clone();
2839
2840 let tokens_in_page = {
2841 let start_idx = (page as usize) * page_size;
2842 let end_idx = ((page as usize + 1) * page_size).min(total_tokens);
2843 (start_idx..end_idx)
2844 .map(|i| {
2845 format!(
2846 r#"{{
2847 "chain": "ethereum",
2848 "address": "0x{i:040x}",
2849 "symbol": "TOKEN_{i}",
2850 "decimals": 18,
2851 "tax": 0,
2852 "gas": [30000],
2853 "quality": 100
2854 }}"#
2855 )
2856 })
2857 .collect::<Vec<_>>()
2858 };
2859
2860 let tokens_json = tokens_in_page.join(",");
2861 let response = format!(
2862 r#"{{
2863 "tokens": [{tokens_json}],
2864 "pagination": {{
2865 "page": {page},
2866 "page_size": {page_size},
2867 "total": {total_tokens}
2868 }}
2869 }}"#,
2870 );
2871
2872 server
2873 .mock("POST", "/v1/tokens")
2874 .expect(1)
2875 .with_chunked_body(move |w| {
2876 let current = concurrent.fetch_add(1, Ordering::SeqCst);
2878 max_conc.fetch_max(current + 1, Ordering::SeqCst);
2879
2880 std::thread::sleep(Duration::from_millis(10));
2882
2883 concurrent.fetch_sub(1, Ordering::SeqCst);
2884
2885 w.write_all(response.as_bytes())
2886 })
2887 .create_async()
2888 .await;
2889 }
2890
2891 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2892 .expect("create client");
2893
2894 let tokens = client
2895 .get_all_tokens(Chain::Ethereum, None, None, Some(page_size), allowed_concurrency)
2896 .await
2897 .expect("get all tokens");
2898
2899 let max = max_concurrent.load(Ordering::SeqCst);
2901 let expected_max_concurrency = (total_pages as usize)
2902 .saturating_sub(1)
2903 .min(allowed_concurrency);
2904 assert!(
2905 max <= allowed_concurrency,
2906 "Expected max concurrent requests <= {allowed_concurrency}, got {max}"
2907 );
2908
2909 if total_pages > 1 && expected_max_concurrency > 1 {
2911 assert!(
2912 max > 0,
2913 "Expected some concurrent requests for multi-page response, got {max}"
2914 );
2915 }
2916
2917 assert_eq!(
2919 tokens.len(),
2920 total_tokens,
2921 "Expected {total_tokens} tokens, got {}",
2922 tokens.len()
2923 );
2924
2925 for (i, token) in tokens.iter().enumerate() {
2927 assert_eq!(token.symbol, format!("TOKEN_{i}"), "Token at index {i} has wrong symbol");
2928 }
2929 }
2930}