1use std::{
7 collections::HashMap,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22 sync::{RwLock, Semaphore},
23 time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27 dto::{
28 BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29 EntryPointWithTracingParams, PaginationLimits, PaginationParams, PaginationResponse,
30 ProtocolComponent, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
31 ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
32 ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
33 TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
34 TracedEntryPointRequestResponse, TracingResult, VersionParam,
35 },
36 models::ComponentId,
37 Bytes,
38};
39
40use crate::{
41 feed::synchronizer::{ComponentWithState, Snapshot},
42 TYCHO_SERVER_VERSION,
43};
44
/// Number of concurrent RPC requests.
///
/// NOTE(review): presumably the default `concurrency` argument for the paginated
/// [`RPCClient`] methods; its callers are not visible here — confirm.
pub const RPC_CLIENT_CONCURRENCY: usize = 4;
47
48#[derive(Clone, Debug, PartialEq)]
53pub struct SnapshotParameters<'a> {
54 pub chain: Chain,
56 pub protocol_system: &'a str,
58 pub components: &'a HashMap<ComponentId, ProtocolComponent>,
60 pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
62 pub contract_ids: &'a [Bytes],
64 pub block_number: u64,
66 pub include_balances: bool,
68 pub include_tvl: bool,
70}
71
72impl<'a> SnapshotParameters<'a> {
73 pub fn new(
74 chain: Chain,
75 protocol_system: &'a str,
76 components: &'a HashMap<ComponentId, ProtocolComponent>,
77 contract_ids: &'a [Bytes],
78 block_number: u64,
79 ) -> Self {
80 Self {
81 chain,
82 protocol_system,
83 components,
84 entrypoints: None,
85 contract_ids,
86 block_number,
87 include_balances: true,
88 include_tvl: true,
89 }
90 }
91
92 pub fn include_balances(mut self, include_balances: bool) -> Self {
94 self.include_balances = include_balances;
95 self
96 }
97
98 pub fn include_tvl(mut self, include_tvl: bool) -> Self {
100 self.include_tvl = include_tvl;
101 self
102 }
103
104 pub fn entrypoints(
105 mut self,
106 entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
107 ) -> Self {
108 self.entrypoints = Some(entrypoints);
109 self
110 }
111}
112
/// Errors returned by [`RPCClient`] operations.
#[derive(Error, Debug)]
pub enum RPCError {
    /// The base URL (first field) could not be parsed; the second field is the parser's message.
    #[error("Failed to parse URL: {0}. Error: {1}")]
    UrlParsing(String, String),

    /// A request could not be built, e.g. an invalid header value or an unrepresentable page size.
    #[error("Failed to format request: {0}")]
    FormatRequest(String),

    /// The HTTP client could not be built or failed to send a request.
    #[error("Unexpected HTTP client error: {0}")]
    HttpClient(String, #[source] reqwest::Error),

    /// The response body could not be read or deserialized.
    #[error("Failed to parse response: {0}")]
    ParseResponse(String),

    /// Used for a closed concurrency semaphore and, in the paginated component fetch, for a failed
    /// first page.
    #[error("Fatal error: {0}")]
    Fatal(String),

    /// The server answered 429; carries the `Retry-After` deadline when it could be parsed.
    #[error("Rate limited until {0:?}")]
    RateLimited(Option<SystemTime>),

    /// The server answered 502, 503 or 504; carries the response body text.
    #[error("Server unreachable: {0}")]
    ServerUnreachable(String),
}
141
142#[cfg_attr(test, automock)]
143#[async_trait]
144pub trait RPCClient: Send + Sync {
145 fn compression(&self) -> bool;
147
148 async fn get_contract_state(
150 &self,
151 request: &StateRequestBody,
152 ) -> Result<StateRequestResponse, RPCError>;
153
154 async fn get_contract_state_paginated(
157 &self,
158 chain: Chain,
159 ids: &[Bytes],
160 protocol_system: &str,
161 version: &VersionParam,
162 chunk_size: Option<usize>,
163 concurrency: usize,
164 ) -> Result<StateRequestResponse, RPCError> {
165 let semaphore = Arc::new(Semaphore::new(concurrency));
166
167 let mut sorted_ids = ids.to_vec();
169 sorted_ids.sort();
170
171 let chunk_size = chunk_size
172 .unwrap_or(StateRequestBody::effective_max_page_size(self.compression()) as usize);
173
174 let chunked_bodies = sorted_ids
175 .chunks(chunk_size)
176 .map(|chunk| StateRequestBody {
177 contract_ids: Some(chunk.to_vec()),
178 protocol_system: protocol_system.to_string(),
179 chain,
180 version: version.clone(),
181 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
182 })
183 .collect::<Vec<_>>();
184
185 let mut tasks = Vec::new();
186 for body in chunked_bodies.iter() {
187 let sem = semaphore.clone();
188 tasks.push(async move {
189 let _permit = sem
190 .acquire()
191 .await
192 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
193 self.get_contract_state(body).await
194 });
195 }
196
197 let responses = try_join_all(tasks).await?;
199
200 let accounts = responses
202 .iter()
203 .flat_map(|r| r.accounts.clone())
204 .collect();
205 let total: i64 = responses
206 .iter()
207 .map(|r| r.pagination.total)
208 .sum();
209
210 Ok(StateRequestResponse {
211 accounts,
212 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
213 })
214 }
215
216 async fn get_protocol_components(
217 &self,
218 request: &ProtocolComponentsRequestBody,
219 ) -> Result<ProtocolComponentRequestResponse, RPCError>;
220
221 async fn get_protocol_components_paginated(
224 &self,
225 request: &ProtocolComponentsRequestBody,
226 chunk_size: Option<usize>,
227 concurrency: usize,
228 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
229 let semaphore = Arc::new(Semaphore::new(concurrency));
230
231 let chunk_size = chunk_size.unwrap_or(
232 ProtocolComponentsRequestBody::effective_max_page_size(self.compression()) as usize,
233 );
234
235 match request.component_ids {
238 Some(ref ids) => {
239 let chunked_bodies = ids
241 .chunks(chunk_size)
242 .enumerate()
243 .map(|(index, _)| ProtocolComponentsRequestBody {
244 protocol_system: request.protocol_system.clone(),
245 component_ids: request.component_ids.clone(),
246 tvl_gt: request.tvl_gt,
247 chain: request.chain,
248 pagination: PaginationParams {
249 page: index as i64,
250 page_size: chunk_size as i64,
251 },
252 })
253 .collect::<Vec<_>>();
254
255 let mut tasks = Vec::new();
256 for body in chunked_bodies.iter() {
257 let sem = semaphore.clone();
258 tasks.push(async move {
259 let _permit = sem
260 .acquire()
261 .await
262 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
263 self.get_protocol_components(body).await
264 });
265 }
266
267 try_join_all(tasks)
268 .await
269 .map(|responses| ProtocolComponentRequestResponse {
270 protocol_components: responses
271 .into_iter()
272 .flat_map(|r| r.protocol_components.into_iter())
273 .collect(),
274 pagination: PaginationResponse {
275 page: 0,
276 page_size: chunk_size as i64,
277 total: ids.len() as i64,
278 },
279 })
280 }
281 _ => {
282 let initial_request = ProtocolComponentsRequestBody {
286 protocol_system: request.protocol_system.clone(),
287 component_ids: request.component_ids.clone(),
288 tvl_gt: request.tvl_gt,
289 chain: request.chain,
290 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
291 };
292 let first_response = self
293 .get_protocol_components(&initial_request)
294 .await
295 .map_err(|err| RPCError::Fatal(err.to_string()))?;
296
297 let total_items = first_response.pagination.total;
298 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
299
300 let mut accumulated_response = ProtocolComponentRequestResponse {
302 protocol_components: first_response.protocol_components,
303 pagination: PaginationResponse {
304 page: 0,
305 page_size: chunk_size as i64,
306 total: total_items,
307 },
308 };
309
310 let mut page = 1;
311 while page < total_pages {
312 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
313
314 let chunked_bodies = (0..requests_in_this_iteration)
316 .map(|iter| ProtocolComponentsRequestBody {
317 protocol_system: request.protocol_system.clone(),
318 component_ids: request.component_ids.clone(),
319 tvl_gt: request.tvl_gt,
320 chain: request.chain,
321 pagination: PaginationParams {
322 page: page + iter,
323 page_size: chunk_size as i64,
324 },
325 })
326 .collect::<Vec<_>>();
327
328 let tasks: Vec<_> = chunked_bodies
329 .iter()
330 .map(|body| {
331 let sem = semaphore.clone();
332 async move {
333 let _permit = sem.acquire().await.map_err(|_| {
334 RPCError::Fatal("Semaphore dropped".to_string())
335 })?;
336 self.get_protocol_components(body).await
337 }
338 })
339 .collect();
340
341 let responses = try_join_all(tasks)
342 .await
343 .map(|responses| {
344 let total = responses[0].pagination.total;
345 ProtocolComponentRequestResponse {
346 protocol_components: responses
347 .into_iter()
348 .flat_map(|r| r.protocol_components.into_iter())
349 .collect(),
350 pagination: PaginationResponse {
351 page,
352 page_size: chunk_size as i64,
353 total,
354 },
355 }
356 });
357
358 match responses {
360 Ok(mut resp) => {
361 accumulated_response
362 .protocol_components
363 .append(&mut resp.protocol_components);
364 }
365 Err(e) => return Err(e),
366 }
367
368 page += concurrency as i64;
369 }
370 Ok(accumulated_response)
371 }
372 }
373 }
374
375 async fn get_protocol_states(
376 &self,
377 request: &ProtocolStateRequestBody,
378 ) -> Result<ProtocolStateRequestResponse, RPCError>;
379
380 #[allow(clippy::too_many_arguments)]
383 async fn get_protocol_states_paginated<T>(
384 &self,
385 chain: Chain,
386 ids: &[T],
387 protocol_system: &str,
388 include_balances: bool,
389 version: &VersionParam,
390 chunk_size: Option<usize>,
391 concurrency: usize,
392 ) -> Result<ProtocolStateRequestResponse, RPCError>
393 where
394 T: AsRef<str> + Sync + 'static,
395 {
396 let semaphore = Arc::new(Semaphore::new(concurrency));
397
398 let chunk_size = chunk_size.unwrap_or(ProtocolStateRequestBody::effective_max_page_size(
399 self.compression(),
400 ) as usize);
401
402 let chunked_bodies = ids
403 .chunks(chunk_size)
404 .map(|c| ProtocolStateRequestBody {
405 protocol_ids: Some(
406 c.iter()
407 .map(|id| id.as_ref().to_string())
408 .collect(),
409 ),
410 protocol_system: protocol_system.to_string(),
411 chain,
412 include_balances,
413 version: version.clone(),
414 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
415 })
416 .collect::<Vec<_>>();
417
418 let mut tasks = Vec::new();
419 for body in chunked_bodies.iter() {
420 let sem = semaphore.clone();
421 tasks.push(async move {
422 let _permit = sem
423 .acquire()
424 .await
425 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
426 self.get_protocol_states(body).await
427 });
428 }
429
430 try_join_all(tasks)
431 .await
432 .map(|responses| {
433 let states = responses
434 .clone()
435 .into_iter()
436 .flat_map(|r| r.states)
437 .collect();
438 let total = responses
439 .iter()
440 .map(|r| r.pagination.total)
441 .sum();
442 ProtocolStateRequestResponse {
443 states,
444 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
445 }
446 })
447 }
448
449 async fn get_tokens(
452 &self,
453 request: &TokensRequestBody,
454 ) -> Result<TokensRequestResponse, RPCError>;
455
456 async fn get_all_tokens(
459 &self,
460 chain: Chain,
461 min_quality: Option<i32>,
462 traded_n_days_ago: Option<u64>,
463 chunk_size: Option<usize>,
464 concurrency: usize,
465 ) -> Result<Vec<ResponseToken>, RPCError> {
466 let chunk_size = chunk_size
467 .unwrap_or(TokensRequestBody::effective_max_page_size(self.compression()) as usize);
468
469 let semaphore = Arc::new(Semaphore::new(concurrency));
470
471 let page_size = chunk_size.try_into().map_err(|_| {
473 RPCError::FormatRequest("Failed to convert chunk_size into i64".to_string())
474 })?;
475
476 let initial_request = TokensRequestBody {
477 token_addresses: None,
478 min_quality,
479 traded_n_days_ago,
480 pagination: PaginationParams { page: 0, page_size },
481 chain,
482 };
483
484 let first_response = self
485 .get_tokens(&initial_request)
486 .await?;
487 let total_items = first_response.pagination.total;
488 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
489
490 let mut all_tokens = first_response.tokens;
491
492 if total_pages <= 1 {
494 return Ok(all_tokens);
495 }
496
497 let tasks: Vec<_> = (1..total_pages)
499 .map(|page| {
500 let sem = semaphore.clone();
501 let request = TokensRequestBody {
502 token_addresses: None,
503 min_quality,
504 traded_n_days_ago,
505 pagination: PaginationParams { page, page_size },
506 chain,
507 };
508
509 async move {
510 let _permit = sem
512 .acquire()
513 .await
514 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
515 self.get_tokens(&request).await
516 }
517 })
518 .collect();
519
520 let responses = try_join_all(tasks).await?;
522
523 for mut response in responses {
525 all_tokens.append(&mut response.tokens);
526 }
527
528 Ok(all_tokens)
529 }
530
531 async fn get_protocol_systems(
532 &self,
533 request: &ProtocolSystemsRequestBody,
534 ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
535
536 async fn get_component_tvl(
537 &self,
538 request: &ComponentTvlRequestBody,
539 ) -> Result<ComponentTvlRequestResponse, RPCError>;
540
541 async fn get_component_tvl_paginated(
544 &self,
545 request: &ComponentTvlRequestBody,
546 chunk_size: Option<usize>,
547 concurrency: usize,
548 ) -> Result<ComponentTvlRequestResponse, RPCError> {
549 let semaphore = Arc::new(Semaphore::new(concurrency));
550
551 let chunk_size = chunk_size.unwrap_or(ComponentTvlRequestBody::effective_max_page_size(
552 self.compression(),
553 ) as usize);
554
555 match request.component_ids {
556 Some(ref ids) => {
557 let chunked_requests = ids
558 .chunks(chunk_size)
559 .enumerate()
560 .map(|(index, _)| ComponentTvlRequestBody {
561 chain: request.chain,
562 protocol_system: request.protocol_system.clone(),
563 component_ids: Some(ids.clone()),
564 pagination: PaginationParams {
565 page: index as i64,
566 page_size: chunk_size as i64,
567 },
568 })
569 .collect::<Vec<_>>();
570
571 let tasks: Vec<_> = chunked_requests
572 .into_iter()
573 .map(|req| {
574 let sem = semaphore.clone();
575 async move {
576 let _permit = sem
577 .acquire()
578 .await
579 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
580 self.get_component_tvl(&req).await
581 }
582 })
583 .collect();
584
585 let responses = try_join_all(tasks).await?;
586
587 let mut merged_tvl = HashMap::new();
588 for resp in responses {
589 for (key, value) in resp.tvl {
590 *merged_tvl.entry(key).or_insert(0.0) = value;
591 }
592 }
593
594 Ok(ComponentTvlRequestResponse {
595 tvl: merged_tvl,
596 pagination: PaginationResponse {
597 page: 0,
598 page_size: chunk_size as i64,
599 total: ids.len() as i64,
600 },
601 })
602 }
603 _ => {
604 let first_request = ComponentTvlRequestBody {
605 chain: request.chain,
606 protocol_system: request.protocol_system.clone(),
607 component_ids: request.component_ids.clone(),
608 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
609 };
610
611 let first_response = self
612 .get_component_tvl(&first_request)
613 .await?;
614 let total_items = first_response.pagination.total;
615 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
616
617 let mut merged_tvl = first_response.tvl;
618
619 let mut page = 1;
620 while page < total_pages {
621 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
622
623 let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
624 .map(|i| ComponentTvlRequestBody {
625 chain: request.chain,
626 protocol_system: request.protocol_system.clone(),
627 component_ids: request.component_ids.clone(),
628 pagination: PaginationParams {
629 page: page + i,
630 page_size: chunk_size as i64,
631 },
632 })
633 .collect();
634
635 let tasks: Vec<_> = chunked_requests
636 .into_iter()
637 .map(|req| {
638 let sem = semaphore.clone();
639 async move {
640 let _permit = sem.acquire().await.map_err(|_| {
641 RPCError::Fatal("Semaphore dropped".to_string())
642 })?;
643 self.get_component_tvl(&req).await
644 }
645 })
646 .collect();
647
648 let responses = try_join_all(tasks).await?;
649
650 for resp in responses {
652 for (key, value) in resp.tvl {
653 *merged_tvl.entry(key).or_insert(0.0) += value;
654 }
655 }
656
657 page += concurrency as i64;
658 }
659
660 Ok(ComponentTvlRequestResponse {
661 tvl: merged_tvl,
662 pagination: PaginationResponse {
663 page: 0,
664 page_size: chunk_size as i64,
665 total: total_items,
666 },
667 })
668 }
669 }
670 }
671
672 async fn get_traced_entry_points(
673 &self,
674 request: &TracedEntryPointRequestBody,
675 ) -> Result<TracedEntryPointRequestResponse, RPCError>;
676
677 async fn get_traced_entry_points_paginated(
680 &self,
681 chain: Chain,
682 protocol_system: &str,
683 component_ids: &[String],
684 chunk_size: Option<usize>,
685 concurrency: usize,
686 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
687 let semaphore = Arc::new(Semaphore::new(concurrency));
688
689 let chunk_size = chunk_size.unwrap_or(
690 TracedEntryPointRequestBody::effective_max_page_size(self.compression()) as usize,
691 );
692
693 let chunked_bodies = component_ids
694 .chunks(chunk_size)
695 .map(|c| TracedEntryPointRequestBody {
696 chain,
697 protocol_system: protocol_system.to_string(),
698 component_ids: Some(c.to_vec()),
699 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
700 })
701 .collect::<Vec<_>>();
702
703 let mut tasks = Vec::new();
704 for body in chunked_bodies.iter() {
705 let sem = semaphore.clone();
706 tasks.push(async move {
707 let _permit = sem
708 .acquire()
709 .await
710 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
711 self.get_traced_entry_points(body).await
712 });
713 }
714
715 try_join_all(tasks)
716 .await
717 .map(|responses| {
718 let traced_entry_points = responses
719 .clone()
720 .into_iter()
721 .flat_map(|r| r.traced_entry_points)
722 .collect();
723 let total = responses
724 .iter()
725 .map(|r| r.pagination.total)
726 .sum();
727 TracedEntryPointRequestResponse {
728 traced_entry_points,
729 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
730 }
731 })
732 }
733
734 async fn get_snapshots<'a>(
735 &self,
736 request: &SnapshotParameters<'a>,
737 chunk_size: Option<usize>,
738 concurrency: usize,
739 ) -> Result<Snapshot, RPCError>;
740}
741
/// Configuration for [`HttpRPCClient::new`].
#[derive(Debug, Clone)]
pub struct HttpRPCClientOptions {
    /// Value sent as the `Authorization` header, if any.
    pub auth_key: Option<String>,
    /// Whether compression is enabled; when `false` the HTTP client's zstd decoding is disabled.
    pub compression: bool,
}

impl Default for HttpRPCClientOptions {
    /// No auth key, compression enabled.
    fn default() -> Self {
        HttpRPCClientOptions { auth_key: None, compression: true }
    }
}

impl HttpRPCClientOptions {
    /// Equivalent to [`Default::default`]: no auth key, compression enabled.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the auth key.
    pub fn with_auth_key(self, auth_key: Option<String>) -> Self {
        Self { auth_key, ..self }
    }

    /// Enables or disables compression.
    pub fn with_compression(self, compression: bool) -> Self {
        Self { compression, ..self }
    }
}
776
/// [`RPCClient`] implementation that talks to a Tycho server over HTTP (configured for HTTP/2 in
/// `new`).
#[derive(Debug, Clone)]
pub struct HttpRPCClient {
    // reqwest client carrying the default headers set up in `new`.
    http_client: Client,
    // Base server URL; endpoint paths are appended to it.
    url: Url,
    // Latest rate-limit deadline reported by the server. Shared through the `Arc`, so clones of
    // this client all wait on it before sending.
    retry_after: Arc<RwLock<Option<SystemTime>>>,
    // Retry policy applied to every POST request.
    backoff_policy: ExponentialBackoff,
    // Delay before retrying after the server was reported unreachable (502/503/504).
    server_restart_duration: Duration,
    // Whether response compression is enabled.
    compression: bool,
}
786
787impl HttpRPCClient {
788 pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
789 let uri = base_uri
790 .parse::<Url>()
791 .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
792
793 let mut headers = header::HeaderMap::new();
795 headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
796 let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
797 headers.insert(
798 header::USER_AGENT,
799 header::HeaderValue::from_str(&user_agent)
800 .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
801 );
802
803 if let Some(key) = options.auth_key.as_deref() {
805 let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
806 RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
807 })?;
808 auth_value.set_sensitive(true);
809 headers.insert(header::AUTHORIZATION, auth_value);
810 }
811
812 let mut client_builder = ClientBuilder::new()
813 .default_headers(headers)
814 .http2_prior_knowledge();
815
816 if !options.compression {
818 client_builder = client_builder.no_zstd();
819 }
820
821 let client = client_builder
822 .build()
823 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
824
825 Ok(Self {
826 http_client: client,
827 url: uri,
828 retry_after: Arc::new(RwLock::new(None)),
829 backoff_policy: ExponentialBackoffBuilder::new()
830 .with_initial_interval(Duration::from_millis(250))
831 .with_multiplier(1.75)
833 .with_max_interval(Duration::from_secs(30))
835 .with_max_elapsed_time(Some(Duration::from_secs(125)))
837 .build(),
838 server_restart_duration: Duration::from_secs(120),
839 compression: options.compression,
840 })
841 }
842
843 #[cfg(test)]
844 pub fn with_test_backoff_policy(mut self) -> Self {
845 self.backoff_policy = ExponentialBackoffBuilder::new()
847 .with_initial_interval(Duration::from_millis(1))
848 .with_multiplier(1.1)
849 .with_max_interval(Duration::from_millis(5))
850 .with_max_elapsed_time(Some(Duration::from_millis(50)))
851 .build();
852 self.server_restart_duration = Duration::from_millis(50);
853 self
854 }
855
856 async fn error_for_response(
862 &self,
863 response: reqwest::Response,
864 ) -> Result<reqwest::Response, RPCError> {
865 match response.status() {
866 StatusCode::TOO_MANY_REQUESTS => {
867 let retry_after_raw = response
868 .headers()
869 .get(reqwest::header::RETRY_AFTER)
870 .and_then(|h| h.to_str().ok())
871 .and_then(parse_retry_value);
872
873 Err(RPCError::RateLimited(retry_after_raw))
874 }
875 StatusCode::BAD_GATEWAY |
876 StatusCode::SERVICE_UNAVAILABLE |
877 StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
878 response
879 .text()
880 .await
881 .unwrap_or_else(|_| "Server Unreachable".to_string()),
882 )),
883 _ => Ok(response),
884 }
885 }
886
887 async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
893 match e {
894 RPCError::ServerUnreachable(_) => {
895 backoff::Error::retry_after(e, self.server_restart_duration)
896 }
897 RPCError::RateLimited(Some(until)) => {
898 let mut retry_after_guard = self.retry_after.write().await;
899 *retry_after_guard = Some(
900 retry_after_guard
901 .unwrap_or(until)
902 .max(until),
903 );
904
905 if let Ok(duration) = until.duration_since(SystemTime::now()) {
906 backoff::Error::retry_after(e, duration)
907 } else {
908 e.into()
909 }
910 }
911 RPCError::RateLimited(None) => e.into(),
912 _ => backoff::Error::permanent(e),
913 }
914 }
915
916 async fn wait_until_retry_after(&self) {
921 if let Some(&until) = self.retry_after.read().await.as_ref() {
922 let now = SystemTime::now();
923 if until > now {
924 if let Ok(duration) = until.duration_since(now) {
925 sleep(duration).await
926 }
927 }
928 }
929 }
930
931 async fn make_post_request<T: Serialize + ?Sized>(
936 &self,
937 request: &T,
938 uri: &String,
939 ) -> Result<Response, RPCError> {
940 self.wait_until_retry_after().await;
941 let response = backoff::future::retry(self.backoff_policy.clone(), || async {
942 let server_response = self
943 .http_client
944 .post(uri)
945 .json(request)
946 .send()
947 .await
948 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
949
950 match self
951 .error_for_response(server_response)
952 .await
953 {
954 Ok(response) => Ok(response),
955 Err(e) => Err(self.handle_error_for_backoff(e).await),
956 }
957 })
958 .await?;
959 Ok(response)
960 }
961}
962
963fn parse_retry_value(val: &str) -> Option<SystemTime> {
964 if let Ok(secs) = val.parse::<u64>() {
965 return Some(SystemTime::now() + Duration::from_secs(secs));
966 }
967 if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
968 return Some(date.into());
969 }
970 None
971}
972
973#[async_trait]
974impl RPCClient for HttpRPCClient {
975 fn compression(&self) -> bool {
976 self.compression
977 }
978
979 #[instrument(skip(self, request))]
980 async fn get_contract_state(
981 &self,
982 request: &StateRequestBody,
983 ) -> Result<StateRequestResponse, RPCError> {
984 if request
986 .contract_ids
987 .as_ref()
988 .is_none_or(|ids| ids.is_empty())
989 {
990 warn!("No contract ids specified in request.");
991 }
992
993 let uri = format!(
994 "{}/{}/contract_state",
995 self.url
996 .to_string()
997 .trim_end_matches('/'),
998 TYCHO_SERVER_VERSION
999 );
1000 debug!(%uri, "Sending contract_state request to Tycho server");
1001 trace!(?request, "Sending request to Tycho server");
1002 let response = self
1003 .make_post_request(request, &uri)
1004 .await?;
1005 trace!(?response, "Received response from Tycho server");
1006
1007 let body = response
1008 .text()
1009 .await
1010 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011 if body.is_empty() {
1012 return Ok(StateRequestResponse {
1014 accounts: vec![],
1015 pagination: PaginationResponse {
1016 page: request.pagination.page,
1017 page_size: request.pagination.page,
1018 total: 0,
1019 },
1020 });
1021 }
1022
1023 let accounts = serde_json::from_str::<StateRequestResponse>(&body)
1024 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1025 trace!(?accounts, "Received contract_state response from Tycho server");
1026
1027 Ok(accounts)
1028 }
1029
1030 async fn get_protocol_components(
1031 &self,
1032 request: &ProtocolComponentsRequestBody,
1033 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
1034 let uri = format!(
1035 "{}/{}/protocol_components",
1036 self.url
1037 .to_string()
1038 .trim_end_matches('/'),
1039 TYCHO_SERVER_VERSION,
1040 );
1041 debug!(%uri, "Sending protocol_components request to Tycho server");
1042 trace!(?request, "Sending request to Tycho server");
1043
1044 let response = self
1045 .make_post_request(request, &uri)
1046 .await?;
1047
1048 trace!(?response, "Received response from Tycho server");
1049
1050 let body = response
1051 .text()
1052 .await
1053 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1054 let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
1055 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1056 trace!(?components, "Received protocol_components response from Tycho server");
1057
1058 Ok(components)
1059 }
1060
1061 async fn get_protocol_states(
1062 &self,
1063 request: &ProtocolStateRequestBody,
1064 ) -> Result<ProtocolStateRequestResponse, RPCError> {
1065 if request
1067 .protocol_ids
1068 .as_ref()
1069 .is_none_or(|ids| ids.is_empty())
1070 {
1071 warn!("No protocol ids specified in request.");
1072 }
1073
1074 let uri = format!(
1075 "{}/{}/protocol_state",
1076 self.url
1077 .to_string()
1078 .trim_end_matches('/'),
1079 TYCHO_SERVER_VERSION
1080 );
1081 debug!(%uri, "Sending protocol_states request to Tycho server");
1082 trace!(?request, "Sending request to Tycho server");
1083
1084 let response = self
1085 .make_post_request(request, &uri)
1086 .await?;
1087 trace!(?response, "Received response from Tycho server");
1088
1089 let body = response
1090 .text()
1091 .await
1092 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1093
1094 if body.is_empty() {
1095 return Ok(ProtocolStateRequestResponse {
1097 states: vec![],
1098 pagination: PaginationResponse {
1099 page: request.pagination.page,
1100 page_size: request.pagination.page_size,
1101 total: 0,
1102 },
1103 });
1104 }
1105
1106 let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1107 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1108 trace!(?states, "Received protocol_states response from Tycho server");
1109
1110 Ok(states)
1111 }
1112
1113 async fn get_tokens(
1114 &self,
1115 request: &TokensRequestBody,
1116 ) -> Result<TokensRequestResponse, RPCError> {
1117 let uri = format!(
1118 "{}/{}/tokens",
1119 self.url
1120 .to_string()
1121 .trim_end_matches('/'),
1122 TYCHO_SERVER_VERSION
1123 );
1124 debug!(%uri, "Sending tokens request to Tycho server");
1125
1126 let response = self
1127 .make_post_request(request, &uri)
1128 .await?;
1129
1130 let body = response
1131 .text()
1132 .await
1133 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1134 let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1135 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1136
1137 Ok(tokens)
1138 }
1139
1140 async fn get_protocol_systems(
1141 &self,
1142 request: &ProtocolSystemsRequestBody,
1143 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1144 let uri = format!(
1145 "{}/{}/protocol_systems",
1146 self.url
1147 .to_string()
1148 .trim_end_matches('/'),
1149 TYCHO_SERVER_VERSION
1150 );
1151 debug!(%uri, "Sending protocol_systems request to Tycho server");
1152 trace!(?request, "Sending request to Tycho server");
1153 let response = self
1154 .make_post_request(request, &uri)
1155 .await?;
1156 trace!(?response, "Received response from Tycho server");
1157 let body = response
1158 .text()
1159 .await
1160 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1161 let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1162 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1163 trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1164 Ok(protocol_systems)
1165 }
1166
1167 async fn get_component_tvl(
1168 &self,
1169 request: &ComponentTvlRequestBody,
1170 ) -> Result<ComponentTvlRequestResponse, RPCError> {
1171 let uri = format!(
1172 "{}/{}/component_tvl",
1173 self.url
1174 .to_string()
1175 .trim_end_matches('/'),
1176 TYCHO_SERVER_VERSION
1177 );
1178 debug!(%uri, "Sending get_component_tvl request to Tycho server");
1179 trace!(?request, "Sending request to Tycho server");
1180 let response = self
1181 .make_post_request(request, &uri)
1182 .await?;
1183 trace!(?response, "Received response from Tycho server");
1184 let body = response
1185 .text()
1186 .await
1187 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1188 let component_tvl =
1189 serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1190 error!("Failed to parse component_tvl response: {:?}", &body);
1191 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1192 })?;
1193 trace!(?component_tvl, "Received component_tvl response from Tycho server");
1194 Ok(component_tvl)
1195 }
1196
1197 async fn get_traced_entry_points(
1198 &self,
1199 request: &TracedEntryPointRequestBody,
1200 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1201 let uri = format!(
1202 "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1203 self.url
1204 .to_string()
1205 .trim_end_matches('/')
1206 );
1207 debug!(%uri, "Sending traced_entry_points request to Tycho server");
1208 trace!(?request, "Sending request to Tycho server");
1209
1210 let response = self
1211 .make_post_request(request, &uri)
1212 .await?;
1213
1214 trace!(?response, "Received response from Tycho server");
1215
1216 let body = response
1217 .text()
1218 .await
1219 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1220 let entrypoints =
1221 serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1222 error!("Failed to parse traced_entry_points response: {:?}", &body);
1223 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1224 })?;
1225 trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1226 Ok(entrypoints)
1227 }
1228
1229 async fn get_snapshots<'a>(
1230 &self,
1231 request: &SnapshotParameters<'a>,
1232 chunk_size: Option<usize>,
1233 concurrency: usize,
1234 ) -> Result<Snapshot, RPCError> {
1235 let component_ids: Vec<_> = request
1236 .components
1237 .keys()
1238 .cloned()
1239 .collect();
1240
1241 let version = VersionParam::new(
1242 None,
1243 Some({
1244 #[allow(deprecated)]
1245 BlockParam {
1246 hash: None,
1247 chain: Some(request.chain),
1248 number: Some(request.block_number as i64),
1249 }
1250 }),
1251 );
1252
1253 let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1254 let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1255 self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1256 .await?
1257 .tvl
1258 } else {
1259 HashMap::new()
1260 };
1261
1262 let mut protocol_states = if !component_ids.is_empty() {
1263 self.get_protocol_states_paginated(
1264 request.chain,
1265 &component_ids,
1266 request.protocol_system,
1267 request.include_balances,
1268 &version,
1269 chunk_size,
1270 concurrency,
1271 )
1272 .await?
1273 .states
1274 .into_iter()
1275 .map(|state| (state.component_id.clone(), state))
1276 .collect()
1277 } else {
1278 HashMap::new()
1279 };
1280
1281 let states = request
1283 .components
1284 .values()
1285 .filter_map(|component| {
1286 if let Some(state) = protocol_states.remove(&component.id) {
1287 Some((
1288 component.id.clone(),
1289 ComponentWithState {
1290 state,
1291 component: component.clone(),
1292 component_tvl: component_tvl
1293 .get(&component.id)
1294 .cloned(),
1295 entrypoints: request
1296 .entrypoints
1297 .as_ref()
1298 .and_then(|map| map.get(&component.id))
1299 .cloned()
1300 .unwrap_or_default(),
1301 },
1302 ))
1303 } else if component_ids.contains(&component.id) {
1304 let component_id = &component.id;
1306 error!(?component_id, "Missing state for native component!");
1307 None
1308 } else {
1309 None
1310 }
1311 })
1312 .collect();
1313
1314 let vm_storage = if !request.contract_ids.is_empty() {
1315 let contract_states = self
1316 .get_contract_state_paginated(
1317 request.chain,
1318 request.contract_ids,
1319 request.protocol_system,
1320 &version,
1321 chunk_size,
1322 concurrency,
1323 )
1324 .await?
1325 .accounts
1326 .into_iter()
1327 .map(|acc| (acc.address.clone(), acc))
1328 .collect::<HashMap<_, _>>();
1329
1330 trace!(states=?&contract_states, "Retrieved ContractState");
1331
1332 let contract_address_to_components = request
1333 .components
1334 .iter()
1335 .filter_map(|(id, comp)| {
1336 if component_ids.contains(id) {
1337 Some(
1338 comp.contract_ids
1339 .iter()
1340 .map(|address| (address.clone(), comp.id.clone())),
1341 )
1342 } else {
1343 None
1344 }
1345 })
1346 .flatten()
1347 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1348 acc.entry(addr).or_default().push(c_id);
1349 acc
1350 });
1351
1352 request
1353 .contract_ids
1354 .iter()
1355 .filter_map(|address| {
1356 if let Some(state) = contract_states.get(address) {
1357 Some((address.clone(), state.clone()))
1358 } else if let Some(ids) = contract_address_to_components.get(address) {
1359 error!(
1361 ?address,
1362 ?ids,
1363 "Component with lacking contract storage encountered!"
1364 );
1365 None
1366 } else {
1367 None
1368 }
1369 })
1370 .collect()
1371 } else {
1372 HashMap::new()
1373 };
1374
1375 Ok(Snapshot { states, vm_storage })
1376 }
1377}
1378
1379#[cfg(test)]
1380mod tests {
1381 use std::{
1382 collections::{HashMap, HashSet},
1383 str::FromStr,
1384 };
1385
1386 use mockito::Server;
1387 use rstest::rstest;
1388 #[allow(deprecated)]
1390 use tycho_common::dto::ProtocolId;
1391 use tycho_common::dto::{AddressStorageLocation, TracingParams};
1392
1393 use super::*;
1394
1395 impl MockRPCClient {
1398 #[allow(clippy::too_many_arguments)]
1399 async fn test_get_protocol_states_paginated<T>(
1400 &self,
1401 chain: Chain,
1402 ids: &[T],
1403 protocol_system: &str,
1404 include_balances: bool,
1405 version: &VersionParam,
1406 chunk_size: usize,
1407 _concurrency: usize,
1408 ) -> Vec<ProtocolStateRequestBody>
1409 where
1410 T: AsRef<str> + Clone + Send + Sync + 'static,
1411 {
1412 ids.chunks(chunk_size)
1413 .map(|chunk| ProtocolStateRequestBody {
1414 protocol_ids: Some(
1415 chunk
1416 .iter()
1417 .map(|id| id.as_ref().to_string())
1418 .collect(),
1419 ),
1420 protocol_system: protocol_system.to_string(),
1421 chain,
1422 include_balances,
1423 version: version.clone(),
1424 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1425 })
1426 .collect()
1427 }
1428 }
1429
    /// Canned `/v1/contract_state` response body used by the contract-state tests
    /// (e.g. `test_get_contract_state`).
    ///
    /// It holds a single Ethereum account at the zero address with empty slots, native balance
    /// `0x01f4` (500), code `0x00` and no token balances. The pagination block reports page 0,
    /// page size 20 and total 10.
    const GET_CONTRACT_STATE_RESP: &str = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
    "#;
1454
1455 #[allow(deprecated)]
1457 #[rstest]
1458 #[case::protocol_id_input(vec![
1459 ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1460 ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1461 ])]
1462 #[case::string_input(vec![
1463 "id1".to_string(),
1464 "id2".to_string()
1465 ])]
1466 #[tokio::test]
1467 async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1468 where
1469 T: AsRef<str> + Clone + Send + Sync + 'static,
1470 {
1471 let mock_client = MockRPCClient::new();
1472
1473 let request_bodies = mock_client
1474 .test_get_protocol_states_paginated(
1475 Chain::Ethereum,
1476 &ids,
1477 "test_system",
1478 true,
1479 &VersionParam::default(),
1480 2,
1481 2,
1482 )
1483 .await;
1484
1485 assert_eq!(request_bodies.len(), 1);
1487 assert_eq!(
1488 request_bodies[0]
1489 .protocol_ids
1490 .as_ref()
1491 .unwrap()
1492 .len(),
1493 2
1494 );
1495 }
1496
1497 #[tokio::test]
1498 async fn test_get_contract_state() {
1499 let mut server = Server::new_async().await;
1500 let server_resp = GET_CONTRACT_STATE_RESP;
1501 serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1503
1504 let mocked_server = server
1505 .mock("POST", "/v1/contract_state")
1506 .expect(1)
1507 .with_body(server_resp)
1508 .create_async()
1509 .await;
1510
1511 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1512 .expect("create client");
1513
1514 let response = client
1515 .get_contract_state(&Default::default())
1516 .await
1517 .expect("get state");
1518 let accounts = response.accounts;
1519
1520 mocked_server.assert();
1521 assert_eq!(accounts.len(), 1);
1522 assert_eq!(accounts[0].slots, HashMap::new());
1523 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1524 assert_eq!(accounts[0].code, [0].to_vec());
1525 assert_eq!(
1526 accounts[0].code_hash,
1527 hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1528 .unwrap()
1529 );
1530 }
1531
1532 #[tokio::test]
1533 async fn test_get_protocol_components() {
1534 let mut server = Server::new_async().await;
1535 let server_resp = r#"
1536 {
1537 "protocol_components": [
1538 {
1539 "id": "State1",
1540 "protocol_system": "ambient",
1541 "protocol_type_name": "Pool",
1542 "chain": "ethereum",
1543 "tokens": [
1544 "0x0000000000000000000000000000000000000000",
1545 "0x0000000000000000000000000000000000000001"
1546 ],
1547 "contract_ids": [
1548 "0x0000000000000000000000000000000000000000"
1549 ],
1550 "static_attributes": {
1551 "attribute_1": "0x00000000000003e8"
1552 },
1553 "change": "Creation",
1554 "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1555 "created_at": "2022-01-01T00:00:00"
1556 }
1557 ],
1558 "pagination": {
1559 "page": 0,
1560 "page_size": 20,
1561 "total": 10
1562 }
1563 }
1564 "#;
1565 serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1567
1568 let mocked_server = server
1569 .mock("POST", "/v1/protocol_components")
1570 .expect(1)
1571 .with_body(server_resp)
1572 .create_async()
1573 .await;
1574
1575 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1576 .expect("create client");
1577
1578 let response = client
1579 .get_protocol_components(&Default::default())
1580 .await
1581 .expect("get state");
1582 let components = response.protocol_components;
1583
1584 mocked_server.assert();
1585 assert_eq!(components.len(), 1);
1586 assert_eq!(components[0].id, "State1");
1587 assert_eq!(components[0].protocol_system, "ambient");
1588 assert_eq!(components[0].protocol_type_name, "Pool");
1589 assert_eq!(components[0].tokens.len(), 2);
1590 let expected_attributes =
1591 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1592 .iter()
1593 .cloned()
1594 .collect::<HashMap<String, Bytes>>();
1595 assert_eq!(components[0].static_attributes, expected_attributes);
1596 }
1597
1598 #[tokio::test]
1599 async fn test_get_protocol_states() {
1600 let mut server = Server::new_async().await;
1601 let server_resp = r#"
1602 {
1603 "states": [
1604 {
1605 "component_id": "State1",
1606 "attributes": {
1607 "attribute_1": "0x00000000000003e8"
1608 },
1609 "balances": {
1610 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1611 }
1612 }
1613 ],
1614 "pagination": {
1615 "page": 0,
1616 "page_size": 20,
1617 "total": 10
1618 }
1619 }
1620 "#;
1621 serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1623
1624 let mocked_server = server
1625 .mock("POST", "/v1/protocol_state")
1626 .expect(1)
1627 .with_body(server_resp)
1628 .create_async()
1629 .await;
1630 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1631 .expect("create client");
1632
1633 let response = client
1634 .get_protocol_states(&Default::default())
1635 .await
1636 .expect("get state");
1637 let states = response.states;
1638
1639 mocked_server.assert();
1640 assert_eq!(states.len(), 1);
1641 assert_eq!(states[0].component_id, "State1");
1642 let expected_attributes =
1643 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1644 .iter()
1645 .cloned()
1646 .collect::<HashMap<String, Bytes>>();
1647 assert_eq!(states[0].attributes, expected_attributes);
1648 let expected_balances = [(
1649 Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1650 .expect("Unsupported address format"),
1651 Bytes::from_str("0x01f4").unwrap(),
1652 )]
1653 .iter()
1654 .cloned()
1655 .collect::<HashMap<Bytes, Bytes>>();
1656 assert_eq!(states[0].balances, expected_balances);
1657 }
1658
1659 #[tokio::test]
1660 async fn test_get_tokens() {
1661 let mut server = Server::new_async().await;
1662 let server_resp = r#"
1663 {
1664 "tokens": [
1665 {
1666 "chain": "ethereum",
1667 "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1668 "symbol": "WETH",
1669 "decimals": 18,
1670 "tax": 0,
1671 "gas": [
1672 29962
1673 ],
1674 "quality": 100
1675 },
1676 {
1677 "chain": "ethereum",
1678 "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1679 "symbol": "USDC",
1680 "decimals": 6,
1681 "tax": 0,
1682 "gas": [
1683 40652
1684 ],
1685 "quality": 100
1686 }
1687 ],
1688 "pagination": {
1689 "page": 0,
1690 "page_size": 20,
1691 "total": 10
1692 }
1693 }
1694 "#;
1695 serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1697
1698 let mocked_server = server
1699 .mock("POST", "/v1/tokens")
1700 .expect(1)
1701 .with_body(server_resp)
1702 .create_async()
1703 .await;
1704 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1705 .expect("create client");
1706
1707 let response = client
1708 .get_tokens(&Default::default())
1709 .await
1710 .expect("get tokens");
1711
1712 let expected = vec![
1713 ResponseToken {
1714 chain: Chain::Ethereum,
1715 address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1716 symbol: "WETH".to_string(),
1717 decimals: 18,
1718 tax: 0,
1719 gas: vec![Some(29962)],
1720 quality: 100,
1721 },
1722 ResponseToken {
1723 chain: Chain::Ethereum,
1724 address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1725 symbol: "USDC".to_string(),
1726 decimals: 6,
1727 tax: 0,
1728 gas: vec![Some(40652)],
1729 quality: 100,
1730 },
1731 ];
1732
1733 mocked_server.assert();
1734 assert_eq!(response.tokens, expected);
1735 assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1736 }
1737
1738 #[rstest]
1739 #[case::with_dci(Some(vec!["system2"]), vec!["system2"])]
1740 #[case::backward_compat(None, vec![])]
1741 #[tokio::test]
1742 async fn test_get_protocol_systems(
1743 #[case] dci_protocols: Option<Vec<&str>>,
1744 #[case] expected_dci: Vec<&str>,
1745 ) {
1746 use serde_json::json;
1747
1748 let mut json_value = json!({
1749 "protocol_systems": ["system1", "system2"],
1750 "pagination": { "page": 0, "page_size": 20, "total": 2 }
1751 });
1752 if let Some(dci) = dci_protocols {
1753 json_value["dci_protocols"] = json!(dci);
1754 }
1755 let server_resp = serde_json::to_string(&json_value).unwrap();
1756
1757 let mut server = Server::new_async().await;
1758 let mocked_server = server
1759 .mock("POST", "/v1/protocol_systems")
1760 .expect(1)
1761 .with_body(&server_resp)
1762 .create_async()
1763 .await;
1764 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1765 .expect("create client");
1766
1767 let response = client
1768 .get_protocol_systems(&Default::default())
1769 .await
1770 .expect("get protocol systems");
1771
1772 mocked_server.assert();
1773 assert_eq!(response.protocol_systems, vec!["system1", "system2"]);
1774 assert_eq!(response.dci_protocols, expected_dci);
1775 }
1776
1777 #[tokio::test]
1778 async fn test_get_component_tvl() {
1779 let mut server = Server::new_async().await;
1780 let server_resp = r#"
1781 {
1782 "tvl": {
1783 "component1": 100.0
1784 },
1785 "pagination": {
1786 "page": 0,
1787 "page_size": 20,
1788 "total": 10
1789 }
1790 }
1791 "#;
1792 serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1794
1795 let mocked_server = server
1796 .mock("POST", "/v1/component_tvl")
1797 .expect(1)
1798 .with_body(server_resp)
1799 .create_async()
1800 .await;
1801 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1802 .expect("create client");
1803
1804 let response = client
1805 .get_component_tvl(&Default::default())
1806 .await
1807 .expect("get protocol systems");
1808 let component_tvl = response.tvl;
1809
1810 mocked_server.assert();
1811 assert_eq!(component_tvl.get("component1"), Some(&100.0));
1812 }
1813
1814 #[tokio::test]
1815 async fn test_get_traced_entry_points() {
1816 let mut server = Server::new_async().await;
1817 let server_resp = r#"
1818 {
1819 "traced_entry_points": {
1820 "component_1": [
1821 [
1822 {
1823 "entry_point": {
1824 "external_id": "entrypoint_a",
1825 "target": "0x0000000000000000000000000000000000000001",
1826 "signature": "sig()"
1827 },
1828 "params": {
1829 "method": "rpctracer",
1830 "caller": "0x000000000000000000000000000000000000000a",
1831 "calldata": "0x000000000000000000000000000000000000000b"
1832 }
1833 },
1834 {
1835 "retriggers": [
1836 [
1837 "0x00000000000000000000000000000000000000aa",
1838 {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1839 ]
1840 ],
1841 "accessed_slots": {
1842 "0x0000000000000000000000000000000000aaaa": [
1843 "0x0000000000000000000000000000000000aaaa"
1844 ]
1845 }
1846 }
1847 ]
1848 ]
1849 },
1850 "pagination": {
1851 "page": 0,
1852 "page_size": 20,
1853 "total": 1
1854 }
1855 }
1856 "#;
1857 serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1859
1860 let mocked_server = server
1861 .mock("POST", "/v1/traced_entry_points")
1862 .expect(1)
1863 .with_body(server_resp)
1864 .create_async()
1865 .await;
1866 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1867 .expect("create client");
1868
1869 let response = client
1870 .get_traced_entry_points(&Default::default())
1871 .await
1872 .expect("get traced entry points");
1873 let entrypoints = response.traced_entry_points;
1874
1875 mocked_server.assert();
1876 assert_eq!(entrypoints.len(), 1);
1877 let comp1_entrypoints = entrypoints
1878 .get("component_1")
1879 .expect("component_1 entrypoints should exist");
1880 assert_eq!(comp1_entrypoints.len(), 1);
1881
1882 let (entrypoint, trace_result) = &comp1_entrypoints[0];
1883 assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1884 assert_eq!(
1885 entrypoint.entry_point.target,
1886 Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1887 );
1888 assert_eq!(entrypoint.entry_point.signature, "sig()");
1889 let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1890 assert_eq!(
1891 rpc_params.caller,
1892 Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1893 );
1894 assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1895
1896 assert_eq!(
1897 trace_result.retriggers,
1898 HashSet::from([(
1899 Bytes::from("0x00000000000000000000000000000000000000aa"),
1900 AddressStorageLocation::new(
1901 Bytes::from("0x0000000000000000000000000000000000000aaa"),
1902 12
1903 )
1904 )])
1905 );
1906 assert_eq!(trace_result.accessed_slots.len(), 1);
1907 assert_eq!(
1908 trace_result.accessed_slots,
1909 HashMap::from([(
1910 Bytes::from("0x0000000000000000000000000000000000aaaa"),
1911 HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1912 )])
1913 );
1914 }
1915
1916 #[tokio::test]
1917 async fn test_parse_retry_value_numeric() {
1918 let result = parse_retry_value("60");
1919 assert!(result.is_some());
1920
1921 let expected_time = SystemTime::now() + Duration::from_secs(60);
1922 let actual_time = result.unwrap();
1923
1924 let diff = if actual_time > expected_time {
1926 actual_time
1927 .duration_since(expected_time)
1928 .unwrap()
1929 } else {
1930 expected_time
1931 .duration_since(actual_time)
1932 .unwrap()
1933 };
1934 assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1935 }
1936
1937 #[tokio::test]
1938 async fn test_parse_retry_value_rfc2822() {
1939 let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1941 let result = parse_retry_value(rfc2822_date);
1942 assert!(result.is_some());
1943
1944 let parsed_time = result.unwrap();
1945 assert!(parsed_time > SystemTime::now());
1946 }
1947
1948 #[tokio::test]
1949 async fn test_parse_retry_value_invalid_formats() {
1950 assert!(parse_retry_value("invalid").is_none());
1952 assert!(parse_retry_value("").is_none());
1953 assert!(parse_retry_value("not_a_number").is_none());
1954 assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); }
1956
1957 #[tokio::test]
1958 async fn test_parse_retry_value_zero_seconds() {
1959 let result = parse_retry_value("0");
1960 assert!(result.is_some());
1961
1962 let expected_time = SystemTime::now();
1963 let actual_time = result.unwrap();
1964
1965 let diff = if actual_time > expected_time {
1967 actual_time
1968 .duration_since(expected_time)
1969 .unwrap()
1970 } else {
1971 expected_time
1972 .duration_since(actual_time)
1973 .unwrap()
1974 };
1975 assert!(diff < Duration::from_secs(1));
1976 }
1977
1978 #[tokio::test]
1979 async fn test_error_for_response_rate_limited() {
1980 let mut server = Server::new_async().await;
1981 let mock = server
1982 .mock("GET", "/test")
1983 .with_status(429)
1984 .with_header("Retry-After", "60")
1985 .create_async()
1986 .await;
1987
1988 let client = reqwest::Client::new();
1989 let response = client
1990 .get(format!("{}/test", server.url()))
1991 .send()
1992 .await
1993 .unwrap();
1994
1995 let http_client =
1996 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1997 .unwrap()
1998 .with_test_backoff_policy();
1999 let result = http_client
2000 .error_for_response(response)
2001 .await;
2002
2003 mock.assert();
2004 assert!(matches!(result, Err(RPCError::RateLimited(_))));
2005 if let Err(RPCError::RateLimited(retry_after)) = result {
2006 assert!(retry_after.is_some());
2007 }
2008 }
2009
2010 #[tokio::test]
2011 async fn test_error_for_response_rate_limited_no_header() {
2012 let mut server = Server::new_async().await;
2013 let mock = server
2014 .mock("GET", "/test")
2015 .with_status(429)
2016 .create_async()
2017 .await;
2018
2019 let client = reqwest::Client::new();
2020 let response = client
2021 .get(format!("{}/test", server.url()))
2022 .send()
2023 .await
2024 .unwrap();
2025
2026 let http_client =
2027 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2028 .unwrap()
2029 .with_test_backoff_policy();
2030 let result = http_client
2031 .error_for_response(response)
2032 .await;
2033
2034 mock.assert();
2035 assert!(matches!(result, Err(RPCError::RateLimited(None))));
2036 }
2037
2038 #[tokio::test]
2039 async fn test_error_for_response_server_errors() {
2040 let test_cases =
2041 vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
2042
2043 for (status_code, expected_body) in test_cases {
2044 let mut server = Server::new_async().await;
2045 let mock = server
2046 .mock("GET", "/test")
2047 .with_status(status_code)
2048 .with_body(expected_body)
2049 .create_async()
2050 .await;
2051
2052 let client = reqwest::Client::new();
2053 let response = client
2054 .get(format!("{}/test", server.url()))
2055 .send()
2056 .await
2057 .unwrap();
2058
2059 let http_client =
2060 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2061 .unwrap()
2062 .with_test_backoff_policy();
2063 let result = http_client
2064 .error_for_response(response)
2065 .await;
2066
2067 mock.assert();
2068 assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
2069 if let Err(RPCError::ServerUnreachable(body)) = result {
2070 assert_eq!(body, expected_body);
2071 }
2072 }
2073 }
2074
2075 #[tokio::test]
2076 async fn test_error_for_response_success() {
2077 let mut server = Server::new_async().await;
2078 let mock = server
2079 .mock("GET", "/test")
2080 .with_status(200)
2081 .with_body("success")
2082 .create_async()
2083 .await;
2084
2085 let client = reqwest::Client::new();
2086 let response = client
2087 .get(format!("{}/test", server.url()))
2088 .send()
2089 .await
2090 .unwrap();
2091
2092 let http_client =
2093 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2094 .unwrap()
2095 .with_test_backoff_policy();
2096 let result = http_client
2097 .error_for_response(response)
2098 .await;
2099
2100 mock.assert();
2101 assert!(result.is_ok());
2102
2103 let response = result.unwrap();
2104 assert_eq!(response.status(), 200);
2105 }
2106
2107 #[tokio::test]
2108 async fn test_handle_error_for_backoff_server_unreachable() {
2109 let http_client =
2110 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2111 .unwrap()
2112 .with_test_backoff_policy();
2113 let error = RPCError::ServerUnreachable("Service down".to_string());
2114
2115 let backoff_error = http_client
2116 .handle_error_for_backoff(error)
2117 .await;
2118
2119 match backoff_error {
2120 backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2121 assert_eq!(msg, "Service down");
2122 assert_eq!(retry_after, Some(Duration::from_millis(50))); }
2124 _ => panic!("Expected transient error for ServerUnreachable"),
2125 }
2126 }
2127
2128 #[tokio::test]
2129 async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2130 let http_client =
2131 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2132 .unwrap()
2133 .with_test_backoff_policy();
2134 let future_time = SystemTime::now() + Duration::from_secs(30);
2135 let error = RPCError::RateLimited(Some(future_time));
2136
2137 let backoff_error = http_client
2138 .handle_error_for_backoff(error)
2139 .await;
2140
2141 match backoff_error {
2142 backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2143 assert_eq!(retry_after, Some(future_time));
2144 }
2145 _ => panic!("Expected transient error for RateLimited"),
2146 }
2147
2148 let stored_retry_after = http_client.retry_after.read().await;
2150 assert_eq!(*stored_retry_after, Some(future_time));
2151 }
2152
2153 #[tokio::test]
2154 async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2155 let http_client =
2156 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2157 .unwrap()
2158 .with_test_backoff_policy();
2159 let error = RPCError::RateLimited(None);
2160
2161 let backoff_error = http_client
2162 .handle_error_for_backoff(error)
2163 .await;
2164
2165 match backoff_error {
2166 backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2167 }
2169 _ => panic!("Expected transient error for RateLimited without retry-after"),
2170 }
2171 }
2172
2173 #[tokio::test]
2174 async fn test_handle_error_for_backoff_other_errors() {
2175 let http_client =
2176 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2177 .unwrap()
2178 .with_test_backoff_policy();
2179 let error = RPCError::ParseResponse("Invalid JSON".to_string());
2180
2181 let backoff_error = http_client
2182 .handle_error_for_backoff(error)
2183 .await;
2184
2185 match backoff_error {
2186 backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2187 assert_eq!(msg, "Invalid JSON");
2188 }
2189 _ => panic!("Expected permanent error for ParseResponse"),
2190 }
2191 }
2192
2193 #[tokio::test]
2194 async fn test_wait_until_retry_after_no_retry_time() {
2195 let http_client =
2196 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2197 .unwrap()
2198 .with_test_backoff_policy();
2199
2200 let start = std::time::Instant::now();
2201 http_client
2202 .wait_until_retry_after()
2203 .await;
2204 let elapsed = start.elapsed();
2205
2206 assert!(elapsed < Duration::from_millis(100));
2208 }
2209
2210 #[tokio::test]
2211 async fn test_wait_until_retry_after_past_time() {
2212 let http_client =
2213 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2214 .unwrap()
2215 .with_test_backoff_policy();
2216
2217 let past_time = SystemTime::now() - Duration::from_secs(10);
2219 *http_client.retry_after.write().await = Some(past_time);
2220
2221 let start = std::time::Instant::now();
2222 http_client
2223 .wait_until_retry_after()
2224 .await;
2225 let elapsed = start.elapsed();
2226
2227 assert!(elapsed < Duration::from_millis(100));
2229 }
2230
2231 #[tokio::test]
2232 async fn test_wait_until_retry_after_future_time() {
2233 let http_client =
2234 HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2235 .unwrap()
2236 .with_test_backoff_policy();
2237
2238 let future_time = SystemTime::now() + Duration::from_millis(100);
2240 *http_client.retry_after.write().await = Some(future_time);
2241
2242 let start = std::time::Instant::now();
2243 http_client
2244 .wait_until_retry_after()
2245 .await;
2246 let elapsed = start.elapsed();
2247
2248 assert!(elapsed >= Duration::from_millis(80)); assert!(elapsed <= Duration::from_millis(200)); }
2252
2253 #[tokio::test]
2254 async fn test_make_post_request_success() {
2255 let mut server = Server::new_async().await;
2256 let server_resp = r#"{"success": true}"#;
2257
2258 let mock = server
2259 .mock("POST", "/test")
2260 .with_status(200)
2261 .with_body(server_resp)
2262 .create_async()
2263 .await;
2264
2265 let http_client =
2266 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2267 .unwrap()
2268 .with_test_backoff_policy();
2269 let request_body = serde_json::json!({"test": "data"});
2270 let uri = format!("{}/test", server.url());
2271
2272 let result = http_client
2273 .make_post_request(&request_body, &uri)
2274 .await;
2275
2276 mock.assert();
2277 assert!(result.is_ok());
2278
2279 let response = result.unwrap();
2280 assert_eq!(response.status(), 200);
2281 assert_eq!(response.text().await.unwrap(), server_resp);
2282 }
2283
2284 #[tokio::test]
2285 async fn test_make_post_request_retry_on_server_error() {
2286 let mut server = Server::new_async().await;
2287 let error_mock = server
2289 .mock("POST", "/test")
2290 .with_status(503)
2291 .with_body("Service Unavailable")
2292 .expect(1)
2293 .create_async()
2294 .await;
2295
2296 let success_mock = server
2297 .mock("POST", "/test")
2298 .with_status(200)
2299 .with_body(r#"{"success": true}"#)
2300 .expect(1)
2301 .create_async()
2302 .await;
2303
2304 let http_client =
2305 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2306 .unwrap()
2307 .with_test_backoff_policy();
2308 let request_body = serde_json::json!({"test": "data"});
2309 let uri = format!("{}/test", server.url());
2310
2311 let result = http_client
2312 .make_post_request(&request_body, &uri)
2313 .await;
2314
2315 error_mock.assert();
2316 success_mock.assert();
2317 assert!(result.is_ok());
2318 }
2319
2320 #[tokio::test]
2321 async fn test_make_post_request_respect_retry_after_header() {
2322 let mut server = Server::new_async().await;
2323
2324 let rate_limit_mock = server
2326 .mock("POST", "/test")
2327 .with_status(429)
2328 .with_header("Retry-After", "1") .expect(1)
2330 .create_async()
2331 .await;
2332
2333 let success_mock = server
2334 .mock("POST", "/test")
2335 .with_status(200)
2336 .with_body(r#"{"success": true}"#)
2337 .expect(1)
2338 .create_async()
2339 .await;
2340
2341 let http_client =
2342 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2343 .unwrap()
2344 .with_test_backoff_policy();
2345 let request_body = serde_json::json!({"test": "data"});
2346 let uri = format!("{}/test", server.url());
2347
2348 let start = std::time::Instant::now();
2349 let result = http_client
2350 .make_post_request(&request_body, &uri)
2351 .await;
2352 let elapsed = start.elapsed();
2353
2354 rate_limit_mock.assert();
2355 success_mock.assert();
2356 assert!(result.is_ok());
2357
2358 assert!(elapsed >= Duration::from_millis(900)); assert!(elapsed <= Duration::from_millis(2000)); }
2362
2363 #[tokio::test]
2364 async fn test_make_post_request_permanent_error() {
2365 let mut server = Server::new_async().await;
2366
2367 let mock = server
2368 .mock("POST", "/test")
2369 .with_status(400) .with_body("Bad Request")
2371 .expect(1)
2372 .create_async()
2373 .await;
2374
2375 let http_client =
2376 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2377 .unwrap()
2378 .with_test_backoff_policy();
2379 let request_body = serde_json::json!({"test": "data"});
2380 let uri = format!("{}/test", server.url());
2381
2382 let result = http_client
2383 .make_post_request(&request_body, &uri)
2384 .await;
2385
2386 mock.assert();
2387 assert!(result.is_ok()); let response = result.unwrap();
2390 assert_eq!(response.status(), 400);
2391 }
2392
2393 #[tokio::test]
2394 async fn test_concurrent_requests_with_different_retry_after() {
2395 let mut server = Server::new_async().await;
2396
2397 let rate_limit_mock_1 = server
2399 .mock("POST", "/test1")
2400 .with_status(429)
2401 .with_header("Retry-After", "1")
2402 .expect(1)
2403 .create_async()
2404 .await;
2405
2406 let rate_limit_mock_2 = server
2408 .mock("POST", "/test2")
2409 .with_status(429)
2410 .with_header("Retry-After", "2")
2411 .expect(1)
2412 .create_async()
2413 .await;
2414
2415 let success_mock_1 = server
2417 .mock("POST", "/test1")
2418 .with_status(200)
2419 .with_body(r#"{"result": "success1"}"#)
2420 .expect(1)
2421 .create_async()
2422 .await;
2423
2424 let success_mock_2 = server
2425 .mock("POST", "/test2")
2426 .with_status(200)
2427 .with_body(r#"{"result": "success2"}"#)
2428 .expect(1)
2429 .create_async()
2430 .await;
2431
2432 let http_client =
2433 HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2434 .unwrap()
2435 .with_test_backoff_policy();
2436 let request_body = serde_json::json!({"test": "data"});
2437
2438 let uri1 = format!("{}/test1", server.url());
2439 let uri2 = format!("{}/test2", server.url());
2440
2441 let start = std::time::Instant::now();
2443 let (result1, result2) = tokio::join!(
2444 http_client.make_post_request(&request_body, &uri1),
2445 http_client.make_post_request(&request_body, &uri2)
2446 );
2447 let elapsed = start.elapsed();
2448
2449 rate_limit_mock_1.assert();
2450 rate_limit_mock_2.assert();
2451 success_mock_1.assert();
2452 success_mock_2.assert();
2453
2454 assert!(result1.is_ok());
2455 assert!(result2.is_ok());
2456
2457 assert!(elapsed >= Duration::from_millis(1800)); assert!(elapsed <= Duration::from_millis(3000)); let final_retry_after = http_client.retry_after.read().await;
2465 assert!(final_retry_after.is_some());
2466
2467 if let Some(retry_time) = *final_retry_after {
2469 let now = SystemTime::now();
2472 let diff = if retry_time > now {
2473 retry_time.duration_since(now).unwrap()
2474 } else {
2475 now.duration_since(retry_time).unwrap()
2476 };
2477
2478 assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2480 }
2481 }
2482
2483 #[tokio::test]
2484 async fn test_get_snapshots() {
2485 let mut server = Server::new_async().await;
2486
2487 let protocol_states_resp = r#"
2489 {
2490 "states": [
2491 {
2492 "component_id": "component1",
2493 "attributes": {
2494 "attribute_1": "0x00000000000003e8"
2495 },
2496 "balances": {
2497 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2498 }
2499 }
2500 ],
2501 "pagination": {
2502 "page": 0,
2503 "page_size": 100,
2504 "total": 1
2505 }
2506 }
2507 "#;
2508
2509 let contract_state_resp = r#"
2511 {
2512 "accounts": [
2513 {
2514 "chain": "ethereum",
2515 "address": "0x1111111111111111111111111111111111111111",
2516 "title": "",
2517 "slots": {},
2518 "native_balance": "0x01f4",
2519 "token_balances": {},
2520 "code": "0x00",
2521 "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2522 "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2523 "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2524 "creation_tx": null
2525 }
2526 ],
2527 "pagination": {
2528 "page": 0,
2529 "page_size": 100,
2530 "total": 1
2531 }
2532 }
2533 "#;
2534
2535 let tvl_resp = r#"
2537 {
2538 "tvl": {
2539 "component1": 1000000.0
2540 },
2541 "pagination": {
2542 "page": 0,
2543 "page_size": 100,
2544 "total": 1
2545 }
2546 }
2547 "#;
2548
2549 let protocol_states_mock = server
2550 .mock("POST", "/v1/protocol_state")
2551 .expect(1)
2552 .with_body(protocol_states_resp)
2553 .create_async()
2554 .await;
2555
2556 let contract_state_mock = server
2557 .mock("POST", "/v1/contract_state")
2558 .expect(1)
2559 .with_body(contract_state_resp)
2560 .create_async()
2561 .await;
2562
2563 let tvl_mock = server
2564 .mock("POST", "/v1/component_tvl")
2565 .expect(1)
2566 .with_body(tvl_resp)
2567 .create_async()
2568 .await;
2569
2570 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2571 .expect("create client");
2572
2573 #[allow(deprecated)]
2574 let component = tycho_common::dto::ProtocolComponent {
2575 id: "component1".to_string(),
2576 protocol_system: "test_protocol".to_string(),
2577 protocol_type_name: "test_type".to_string(),
2578 chain: Chain::Ethereum,
2579 tokens: vec![],
2580 contract_ids: vec![
2581 Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2582 ],
2583 static_attributes: HashMap::new(),
2584 change: tycho_common::dto::ChangeType::Creation,
2585 creation_tx: Bytes::from_str(
2586 "0x0000000000000000000000000000000000000000000000000000000000000000",
2587 )
2588 .unwrap(),
2589 created_at: chrono::Utc::now().naive_utc(),
2590 };
2591
2592 let mut components = HashMap::new();
2593 components.insert("component1".to_string(), component);
2594
2595 let contract_ids =
2596 vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2597
2598 let request = SnapshotParameters::new(
2599 Chain::Ethereum,
2600 "test_protocol",
2601 &components,
2602 &contract_ids,
2603 12345,
2604 );
2605
2606 let response = client
2607 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2608 .await
2609 .expect("get snapshots");
2610
2611 protocol_states_mock.assert();
2613 contract_state_mock.assert();
2614 tvl_mock.assert();
2615
2616 assert_eq!(response.states.len(), 1);
2618 assert!(response
2619 .states
2620 .contains_key("component1"));
2621
2622 let component_state = response
2624 .states
2625 .get("component1")
2626 .unwrap();
2627 assert_eq!(component_state.component_tvl, Some(1000000.0));
2628
2629 assert_eq!(response.vm_storage.len(), 1);
2631 let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2632 assert!(response
2633 .vm_storage
2634 .contains_key(&contract_addr));
2635 }
2636
2637 #[tokio::test]
2638 async fn test_get_snapshots_empty_components() {
2639 let server = Server::new_async().await;
2640 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2641 .expect("create client");
2642
2643 let components = HashMap::new();
2644 let contract_ids = vec![];
2645
2646 let request = SnapshotParameters::new(
2647 Chain::Ethereum,
2648 "test_protocol",
2649 &components,
2650 &contract_ids,
2651 12345,
2652 );
2653
2654 let response = client
2655 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2656 .await
2657 .expect("get snapshots");
2658
2659 assert!(response.states.is_empty());
2661 assert!(response.vm_storage.is_empty());
2662 }
2663
2664 #[tokio::test]
2665 async fn test_get_snapshots_without_tvl() {
2666 let mut server = Server::new_async().await;
2667
2668 let protocol_states_resp = r#"
2669 {
2670 "states": [
2671 {
2672 "component_id": "component1",
2673 "attributes": {},
2674 "balances": {}
2675 }
2676 ],
2677 "pagination": {
2678 "page": 0,
2679 "page_size": 100,
2680 "total": 1
2681 }
2682 }
2683 "#;
2684
2685 let protocol_states_mock = server
2686 .mock("POST", "/v1/protocol_state")
2687 .expect(1)
2688 .with_body(protocol_states_resp)
2689 .create_async()
2690 .await;
2691
2692 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2693 .expect("create client");
2694
2695 #[allow(deprecated)]
2697 let component = tycho_common::dto::ProtocolComponent {
2698 id: "component1".to_string(),
2699 protocol_system: "test_protocol".to_string(),
2700 protocol_type_name: "test_type".to_string(),
2701 chain: Chain::Ethereum,
2702 tokens: vec![],
2703 contract_ids: vec![],
2704 static_attributes: HashMap::new(),
2705 change: tycho_common::dto::ChangeType::Creation,
2706 creation_tx: Bytes::from_str(
2707 "0x0000000000000000000000000000000000000000000000000000000000000000",
2708 )
2709 .unwrap(),
2710 created_at: chrono::Utc::now().naive_utc(),
2711 };
2712
2713 let mut components = HashMap::new();
2714 components.insert("component1".to_string(), component);
2715 let contract_ids = vec![];
2716
2717 let request = SnapshotParameters::new(
2718 Chain::Ethereum,
2719 "test_protocol",
2720 &components,
2721 &contract_ids,
2722 12345,
2723 )
2724 .include_balances(false)
2725 .include_tvl(false);
2726
2727 let response = client
2728 .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2729 .await
2730 .expect("get snapshots");
2731
2732 protocol_states_mock.assert();
2734 assert_eq!(response.states.len(), 1);
2738 let component_state = response
2740 .states
2741 .get("component1")
2742 .unwrap();
2743 assert_eq!(component_state.component_tvl, None);
2744 }
2745
2746 #[tokio::test]
2747 async fn test_compression_enabled() {
2748 let mut server = Server::new_async().await;
2749 let server_resp = GET_CONTRACT_STATE_RESP;
2750
2751 let compressed_body =
2753 zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2754
2755 let mocked_server = server
2756 .mock("POST", "/v1/contract_state")
2757 .expect(1)
2758 .with_header("Content-Encoding", "zstd")
2759 .with_body(compressed_body)
2760 .create_async()
2761 .await;
2762
2763 let client = HttpRPCClient::new(
2765 server.url().as_str(),
2766 HttpRPCClientOptions::new().with_compression(true),
2767 )
2768 .expect("create client");
2769
2770 let response = client
2771 .get_contract_state(&Default::default())
2772 .await
2773 .expect("get state");
2774 let accounts = response.accounts;
2775
2776 mocked_server.assert();
2777 assert_eq!(accounts.len(), 1);
2778 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2779 }
2780
2781 #[tokio::test]
2782 async fn test_compression_disabled() {
2783 let mut server = Server::new_async().await;
2784 let server_resp = GET_CONTRACT_STATE_RESP;
2785
2786 let mocked_server = server
2789 .mock("POST", "/v1/contract_state")
2790 .expect(1)
2791 .match_header("Accept-Encoding", mockito::Matcher::Missing)
2792 .with_status(200)
2793 .with_body(server_resp)
2794 .create_async()
2795 .await;
2796
2797 let client = HttpRPCClient::new(
2799 server.url().as_str(),
2800 HttpRPCClientOptions::new().with_compression(false),
2801 )
2802 .expect("create client");
2803
2804 let response = client
2805 .get_contract_state(&Default::default())
2806 .await
2807 .expect("get state");
2808 let accounts = response.accounts;
2809
2810 mocked_server.assert();
2812 assert_eq!(accounts.len(), 1);
2813 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2814 }
2815
2816 #[rstest]
2817 #[case::single_page(2, 1000)]
2818 #[case::multiple_pages_within_concurrency(10, 2)]
2819 #[case::exceeds_concurrency_limit(60, 2)]
2820 #[tokio::test]
2821 async fn test_get_all_tokens_pagination_and_concurrency(
2822 #[case] total_tokens: usize,
2823 #[case] page_size: usize,
2824 ) {
2825 use std::sync::atomic::{AtomicUsize, Ordering};
2826
2827 let allowed_concurrency = 10;
2828
2829 let concurrent_requests = Arc::new(AtomicUsize::new(0));
2830 let max_concurrent = Arc::new(AtomicUsize::new(0));
2831
2832 let mut server = Server::new_async().await;
2833
2834 let total_pages = (total_tokens as f64 / page_size as f64).ceil() as i64;
2835
2836 for page in 0..total_pages {
2838 let concurrent = concurrent_requests.clone();
2839 let max_conc = max_concurrent.clone();
2840
2841 let tokens_in_page = {
2842 let start_idx = (page as usize) * page_size;
2843 let end_idx = ((page as usize + 1) * page_size).min(total_tokens);
2844 (start_idx..end_idx)
2845 .map(|i| {
2846 format!(
2847 r#"{{
2848 "chain": "ethereum",
2849 "address": "0x{i:040x}",
2850 "symbol": "TOKEN_{i}",
2851 "decimals": 18,
2852 "tax": 0,
2853 "gas": [30000],
2854 "quality": 100
2855 }}"#
2856 )
2857 })
2858 .collect::<Vec<_>>()
2859 };
2860
2861 let tokens_json = tokens_in_page.join(",");
2862 let response = format!(
2863 r#"{{
2864 "tokens": [{tokens_json}],
2865 "pagination": {{
2866 "page": {page},
2867 "page_size": {page_size},
2868 "total": {total_tokens}
2869 }}
2870 }}"#,
2871 );
2872
2873 server
2874 .mock("POST", "/v1/tokens")
2875 .expect(1)
2876 .with_chunked_body(move |w| {
2877 let current = concurrent.fetch_add(1, Ordering::SeqCst);
2879 max_conc.fetch_max(current + 1, Ordering::SeqCst);
2880
2881 std::thread::sleep(Duration::from_millis(10));
2883
2884 concurrent.fetch_sub(1, Ordering::SeqCst);
2885
2886 w.write_all(response.as_bytes())
2887 })
2888 .create_async()
2889 .await;
2890 }
2891
2892 let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2893 .expect("create client");
2894
2895 let tokens = client
2896 .get_all_tokens(Chain::Ethereum, None, None, Some(page_size), allowed_concurrency)
2897 .await
2898 .expect("get all tokens");
2899
2900 let max = max_concurrent.load(Ordering::SeqCst);
2902 let expected_max_concurrency = (total_pages as usize)
2903 .saturating_sub(1)
2904 .min(allowed_concurrency);
2905 assert!(
2906 max <= allowed_concurrency,
2907 "Expected max concurrent requests <= {allowed_concurrency}, got {max}"
2908 );
2909
2910 if total_pages > 1 && expected_max_concurrency > 1 {
2912 assert!(
2913 max > 0,
2914 "Expected some concurrent requests for multi-page response, got {max}"
2915 );
2916 }
2917
2918 assert_eq!(
2920 tokens.len(),
2921 total_tokens,
2922 "Expected {total_tokens} tokens, got {}",
2923 tokens.len()
2924 );
2925
2926 for (i, token) in tokens.iter().enumerate() {
2928 assert_eq!(token.symbol, format!("TOKEN_{i}"), "Token at index {i} has wrong symbol");
2929 }
2930 }
2931}