1use std::{
7 collections::HashMap,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22 sync::{RwLock, Semaphore},
23 time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27 dto::{
28 BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29 EntryPointWithTracingParams, PaginationParams, PaginationResponse, ProtocolComponent,
30 ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
31 ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
32 ResponseToken, StateRequestBody, StateRequestResponse, TokensRequestBody,
33 TokensRequestResponse, TracedEntryPointRequestBody, TracedEntryPointRequestResponse,
34 TracingResult, VersionParam,
35 },
36 models::ComponentId,
37 Bytes,
38};
39
40use crate::{
41 feed::synchronizer::{ComponentWithState, Snapshot},
42 TYCHO_SERVER_VERSION,
43};
44
45#[derive(Clone, Debug, PartialEq)]
50pub struct SnapshotParameters<'a> {
51 pub chain: Chain,
53 pub protocol_system: &'a str,
55 pub components: &'a HashMap<ComponentId, ProtocolComponent>,
57 pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
59 pub contract_ids: &'a [Bytes],
61 pub block_number: u64,
63 pub include_balances: bool,
65 pub include_tvl: bool,
67}
68
69impl<'a> SnapshotParameters<'a> {
70 pub fn new(
71 chain: Chain,
72 protocol_system: &'a str,
73 components: &'a HashMap<ComponentId, ProtocolComponent>,
74 contract_ids: &'a [Bytes],
75 block_number: u64,
76 ) -> Self {
77 Self {
78 chain,
79 protocol_system,
80 components,
81 entrypoints: None,
82 contract_ids,
83 block_number,
84 include_balances: true,
85 include_tvl: true,
86 }
87 }
88
89 pub fn include_balances(mut self, include_balances: bool) -> Self {
91 self.include_balances = include_balances;
92 self
93 }
94
95 pub fn include_tvl(mut self, include_tvl: bool) -> Self {
97 self.include_tvl = include_tvl;
98 self
99 }
100
101 pub fn entrypoints(
102 mut self,
103 entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
104 ) -> Self {
105 self.entrypoints = Some(entrypoints);
106 self
107 }
108}
109
110#[derive(Error, Debug)]
111pub enum RPCError {
112 #[error("Failed to parse URL: {0}. Error: {1}")]
114 UrlParsing(String, String),
115
116 #[error("Failed to format request: {0}")]
118 FormatRequest(String),
119
120 #[error("Unexpected HTTP client error: {0}")]
122 HttpClient(String, #[source] reqwest::Error),
123
124 #[error("Failed to parse response: {0}")]
126 ParseResponse(String),
127
128 #[error("Fatal error: {0}")]
130 Fatal(String),
131
132 #[error("Rate limited until {0:?}")]
133 RateLimited(Option<SystemTime>),
134
135 #[error("Server unreachable: {0}")]
136 ServerUnreachable(String),
137}
138
139#[cfg_attr(test, automock)]
140#[async_trait]
141pub trait RPCClient: Send + Sync {
142 async fn get_contract_state(
144 &self,
145 request: &StateRequestBody,
146 ) -> Result<StateRequestResponse, RPCError>;
147
148 async fn get_contract_state_paginated(
149 &self,
150 chain: Chain,
151 ids: &[Bytes],
152 protocol_system: &str,
153 version: &VersionParam,
154 chunk_size: usize,
155 concurrency: usize,
156 ) -> Result<StateRequestResponse, RPCError> {
157 let semaphore = Arc::new(Semaphore::new(concurrency));
158
159 let mut sorted_ids = ids.to_vec();
161 sorted_ids.sort();
162
163 let chunked_bodies = sorted_ids
164 .chunks(chunk_size)
165 .map(|chunk| StateRequestBody {
166 contract_ids: Some(chunk.to_vec()),
167 protocol_system: protocol_system.to_string(),
168 chain,
169 version: version.clone(),
170 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
171 })
172 .collect::<Vec<_>>();
173
174 let mut tasks = Vec::new();
175 for body in chunked_bodies.iter() {
176 let sem = semaphore.clone();
177 tasks.push(async move {
178 let _permit = sem
179 .acquire()
180 .await
181 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
182 self.get_contract_state(body).await
183 });
184 }
185
186 let responses = try_join_all(tasks).await?;
188
189 let accounts = responses
191 .iter()
192 .flat_map(|r| r.accounts.clone())
193 .collect();
194 let total: i64 = responses
195 .iter()
196 .map(|r| r.pagination.total)
197 .sum();
198
199 Ok(StateRequestResponse {
200 accounts,
201 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
202 })
203 }
204
205 async fn get_protocol_components(
206 &self,
207 request: &ProtocolComponentsRequestBody,
208 ) -> Result<ProtocolComponentRequestResponse, RPCError>;
209
210 async fn get_protocol_components_paginated(
211 &self,
212 request: &ProtocolComponentsRequestBody,
213 chunk_size: usize,
214 concurrency: usize,
215 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
216 let semaphore = Arc::new(Semaphore::new(concurrency));
217
218 match request.component_ids {
221 Some(ref ids) => {
222 let chunked_bodies = ids
224 .chunks(chunk_size)
225 .enumerate()
226 .map(|(index, _)| ProtocolComponentsRequestBody {
227 protocol_system: request.protocol_system.clone(),
228 component_ids: request.component_ids.clone(),
229 tvl_gt: request.tvl_gt,
230 chain: request.chain,
231 pagination: PaginationParams {
232 page: index as i64,
233 page_size: chunk_size as i64,
234 },
235 })
236 .collect::<Vec<_>>();
237
238 let mut tasks = Vec::new();
239 for body in chunked_bodies.iter() {
240 let sem = semaphore.clone();
241 tasks.push(async move {
242 let _permit = sem
243 .acquire()
244 .await
245 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
246 self.get_protocol_components(body).await
247 });
248 }
249
250 try_join_all(tasks)
251 .await
252 .map(|responses| ProtocolComponentRequestResponse {
253 protocol_components: responses
254 .into_iter()
255 .flat_map(|r| r.protocol_components.into_iter())
256 .collect(),
257 pagination: PaginationResponse {
258 page: 0,
259 page_size: chunk_size as i64,
260 total: ids.len() as i64,
261 },
262 })
263 }
264 _ => {
265 let initial_request = ProtocolComponentsRequestBody {
269 protocol_system: request.protocol_system.clone(),
270 component_ids: request.component_ids.clone(),
271 tvl_gt: request.tvl_gt,
272 chain: request.chain,
273 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
274 };
275 let first_response = self
276 .get_protocol_components(&initial_request)
277 .await
278 .map_err(|err| RPCError::Fatal(err.to_string()))?;
279
280 let total_items = first_response.pagination.total;
281 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
282
283 let mut accumulated_response = ProtocolComponentRequestResponse {
285 protocol_components: first_response.protocol_components,
286 pagination: PaginationResponse {
287 page: 0,
288 page_size: chunk_size as i64,
289 total: total_items,
290 },
291 };
292
293 let mut page = 1;
294 while page < total_pages {
295 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
296
297 let chunked_bodies = (0..requests_in_this_iteration)
299 .map(|iter| ProtocolComponentsRequestBody {
300 protocol_system: request.protocol_system.clone(),
301 component_ids: request.component_ids.clone(),
302 tvl_gt: request.tvl_gt,
303 chain: request.chain,
304 pagination: PaginationParams {
305 page: page + iter,
306 page_size: chunk_size as i64,
307 },
308 })
309 .collect::<Vec<_>>();
310
311 let tasks: Vec<_> = chunked_bodies
312 .iter()
313 .map(|body| {
314 let sem = semaphore.clone();
315 async move {
316 let _permit = sem.acquire().await.map_err(|_| {
317 RPCError::Fatal("Semaphore dropped".to_string())
318 })?;
319 self.get_protocol_components(body).await
320 }
321 })
322 .collect();
323
324 let responses = try_join_all(tasks)
325 .await
326 .map(|responses| {
327 let total = responses[0].pagination.total;
328 ProtocolComponentRequestResponse {
329 protocol_components: responses
330 .into_iter()
331 .flat_map(|r| r.protocol_components.into_iter())
332 .collect(),
333 pagination: PaginationResponse {
334 page,
335 page_size: chunk_size as i64,
336 total,
337 },
338 }
339 });
340
341 match responses {
343 Ok(mut resp) => {
344 accumulated_response
345 .protocol_components
346 .append(&mut resp.protocol_components);
347 }
348 Err(e) => return Err(e),
349 }
350
351 page += concurrency as i64;
352 }
353 Ok(accumulated_response)
354 }
355 }
356 }
357
358 async fn get_protocol_states(
359 &self,
360 request: &ProtocolStateRequestBody,
361 ) -> Result<ProtocolStateRequestResponse, RPCError>;
362
363 #[allow(clippy::too_many_arguments)]
364 async fn get_protocol_states_paginated<T>(
365 &self,
366 chain: Chain,
367 ids: &[T],
368 protocol_system: &str,
369 include_balances: bool,
370 version: &VersionParam,
371 chunk_size: usize,
372 concurrency: usize,
373 ) -> Result<ProtocolStateRequestResponse, RPCError>
374 where
375 T: AsRef<str> + Sync + 'static,
376 {
377 let semaphore = Arc::new(Semaphore::new(concurrency));
378 let chunked_bodies = ids
379 .chunks(chunk_size)
380 .map(|c| ProtocolStateRequestBody {
381 protocol_ids: Some(
382 c.iter()
383 .map(|id| id.as_ref().to_string())
384 .collect(),
385 ),
386 protocol_system: protocol_system.to_string(),
387 chain,
388 include_balances,
389 version: version.clone(),
390 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
391 })
392 .collect::<Vec<_>>();
393
394 let mut tasks = Vec::new();
395 for body in chunked_bodies.iter() {
396 let sem = semaphore.clone();
397 tasks.push(async move {
398 let _permit = sem
399 .acquire()
400 .await
401 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
402 self.get_protocol_states(body).await
403 });
404 }
405
406 try_join_all(tasks)
407 .await
408 .map(|responses| {
409 let states = responses
410 .clone()
411 .into_iter()
412 .flat_map(|r| r.states)
413 .collect();
414 let total = responses
415 .iter()
416 .map(|r| r.pagination.total)
417 .sum();
418 ProtocolStateRequestResponse {
419 states,
420 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
421 }
422 })
423 }
424
425 async fn get_tokens(
428 &self,
429 request: &TokensRequestBody,
430 ) -> Result<TokensRequestResponse, RPCError>;
431
432 async fn get_all_tokens(
433 &self,
434 chain: Chain,
435 min_quality: Option<i32>,
436 traded_n_days_ago: Option<u64>,
437 chunk_size: usize,
438 ) -> Result<Vec<ResponseToken>, RPCError> {
439 let mut request_page = 0;
440 let mut all_tokens = Vec::new();
441 loop {
442 let mut response = self
443 .get_tokens(&TokensRequestBody {
444 token_addresses: None,
445 min_quality,
446 traded_n_days_ago,
447 pagination: PaginationParams {
448 page: request_page,
449 page_size: chunk_size.try_into().map_err(|_| {
450 RPCError::FormatRequest(
451 "Failed to convert chunk_size into i64".to_string(),
452 )
453 })?,
454 },
455 chain,
456 })
457 .await?;
458
459 let num_tokens = response.tokens.len();
460 all_tokens.append(&mut response.tokens);
461 request_page += 1;
462
463 if num_tokens < chunk_size {
464 break;
465 }
466 }
467 Ok(all_tokens)
468 }
469
470 async fn get_protocol_systems(
471 &self,
472 request: &ProtocolSystemsRequestBody,
473 ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
474
475 async fn get_component_tvl(
476 &self,
477 request: &ComponentTvlRequestBody,
478 ) -> Result<ComponentTvlRequestResponse, RPCError>;
479
480 async fn get_component_tvl_paginated(
481 &self,
482 request: &ComponentTvlRequestBody,
483 chunk_size: usize,
484 concurrency: usize,
485 ) -> Result<ComponentTvlRequestResponse, RPCError> {
486 let semaphore = Arc::new(Semaphore::new(concurrency));
487
488 match request.component_ids {
489 Some(ref ids) => {
490 let chunked_requests = ids
491 .chunks(chunk_size)
492 .enumerate()
493 .map(|(index, _)| ComponentTvlRequestBody {
494 chain: request.chain,
495 protocol_system: request.protocol_system.clone(),
496 component_ids: Some(ids.clone()),
497 pagination: PaginationParams {
498 page: index as i64,
499 page_size: chunk_size as i64,
500 },
501 })
502 .collect::<Vec<_>>();
503
504 let tasks: Vec<_> = chunked_requests
505 .into_iter()
506 .map(|req| {
507 let sem = semaphore.clone();
508 async move {
509 let _permit = sem
510 .acquire()
511 .await
512 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
513 self.get_component_tvl(&req).await
514 }
515 })
516 .collect();
517
518 let responses = try_join_all(tasks).await?;
519
520 let mut merged_tvl = HashMap::new();
521 for resp in responses {
522 for (key, value) in resp.tvl {
523 *merged_tvl.entry(key).or_insert(0.0) = value;
524 }
525 }
526
527 Ok(ComponentTvlRequestResponse {
528 tvl: merged_tvl,
529 pagination: PaginationResponse {
530 page: 0,
531 page_size: chunk_size as i64,
532 total: ids.len() as i64,
533 },
534 })
535 }
536 _ => {
537 let first_request = ComponentTvlRequestBody {
538 chain: request.chain,
539 protocol_system: request.protocol_system.clone(),
540 component_ids: request.component_ids.clone(),
541 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
542 };
543
544 let first_response = self
545 .get_component_tvl(&first_request)
546 .await?;
547 let total_items = first_response.pagination.total;
548 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
549
550 let mut merged_tvl = first_response.tvl;
551
552 let mut page = 1;
553 while page < total_pages {
554 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
555
556 let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
557 .map(|i| ComponentTvlRequestBody {
558 chain: request.chain,
559 protocol_system: request.protocol_system.clone(),
560 component_ids: request.component_ids.clone(),
561 pagination: PaginationParams {
562 page: page + i,
563 page_size: chunk_size as i64,
564 },
565 })
566 .collect();
567
568 let tasks: Vec<_> = chunked_requests
569 .into_iter()
570 .map(|req| {
571 let sem = semaphore.clone();
572 async move {
573 let _permit = sem.acquire().await.map_err(|_| {
574 RPCError::Fatal("Semaphore dropped".to_string())
575 })?;
576 self.get_component_tvl(&req).await
577 }
578 })
579 .collect();
580
581 let responses = try_join_all(tasks).await?;
582
583 for resp in responses {
585 for (key, value) in resp.tvl {
586 *merged_tvl.entry(key).or_insert(0.0) += value;
587 }
588 }
589
590 page += concurrency as i64;
591 }
592
593 Ok(ComponentTvlRequestResponse {
594 tvl: merged_tvl,
595 pagination: PaginationResponse {
596 page: 0,
597 page_size: chunk_size as i64,
598 total: total_items,
599 },
600 })
601 }
602 }
603 }
604
605 async fn get_traced_entry_points(
606 &self,
607 request: &TracedEntryPointRequestBody,
608 ) -> Result<TracedEntryPointRequestResponse, RPCError>;
609
610 async fn get_traced_entry_points_paginated(
611 &self,
612 chain: Chain,
613 protocol_system: &str,
614 component_ids: &[String],
615 chunk_size: usize,
616 concurrency: usize,
617 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
618 let semaphore = Arc::new(Semaphore::new(concurrency));
619 let chunked_bodies = component_ids
620 .chunks(chunk_size)
621 .map(|c| TracedEntryPointRequestBody {
622 chain,
623 protocol_system: protocol_system.to_string(),
624 component_ids: Some(c.to_vec()),
625 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
626 })
627 .collect::<Vec<_>>();
628
629 let mut tasks = Vec::new();
630 for body in chunked_bodies.iter() {
631 let sem = semaphore.clone();
632 tasks.push(async move {
633 let _permit = sem
634 .acquire()
635 .await
636 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
637 self.get_traced_entry_points(body).await
638 });
639 }
640
641 try_join_all(tasks)
642 .await
643 .map(|responses| {
644 let traced_entry_points = responses
645 .clone()
646 .into_iter()
647 .flat_map(|r| r.traced_entry_points)
648 .collect();
649 let total = responses
650 .iter()
651 .map(|r| r.pagination.total)
652 .sum();
653 TracedEntryPointRequestResponse {
654 traced_entry_points,
655 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
656 }
657 })
658 }
659
660 async fn get_snapshots<'a>(
661 &self,
662 request: &SnapshotParameters<'a>,
663 chunk_size: usize,
664 concurrency: usize,
665 ) -> Result<Snapshot, RPCError>;
666}
667
668#[derive(Debug, Clone)]
669pub struct HttpRPCClient {
670 http_client: Client,
671 url: Url,
672 retry_after: Arc<RwLock<Option<SystemTime>>>,
673 backoff_policy: ExponentialBackoff,
674 server_restart_duration: Duration,
675}
676
677impl HttpRPCClient {
678 pub fn new(base_uri: &str, auth_key: Option<&str>) -> Result<Self, RPCError> {
679 let uri = base_uri
680 .parse::<Url>()
681 .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
682
683 let mut headers = header::HeaderMap::new();
685 headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
686 let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
687 headers.insert(
688 header::USER_AGENT,
689 header::HeaderValue::from_str(&user_agent)
690 .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
691 );
692
693 if let Some(key) = auth_key {
695 let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
696 RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
697 })?;
698 auth_value.set_sensitive(true);
699 headers.insert(header::AUTHORIZATION, auth_value);
700 }
701
702 let client = ClientBuilder::new()
703 .default_headers(headers)
704 .http2_prior_knowledge()
705 .build()
706 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
707 Ok(Self {
708 http_client: client,
709 url: uri,
710 retry_after: Arc::new(RwLock::new(None)),
711 backoff_policy: ExponentialBackoffBuilder::new()
712 .with_initial_interval(Duration::from_millis(250))
713 .with_multiplier(1.75)
715 .with_max_interval(Duration::from_secs(30))
717 .with_max_elapsed_time(Some(Duration::from_secs(125)))
719 .build(),
720 server_restart_duration: Duration::from_secs(120),
721 })
722 }
723
724 #[cfg(test)]
725 pub fn with_test_backoff_policy(mut self) -> Self {
726 self.backoff_policy = ExponentialBackoffBuilder::new()
728 .with_initial_interval(Duration::from_millis(1))
729 .with_multiplier(1.1)
730 .with_max_interval(Duration::from_millis(5))
731 .with_max_elapsed_time(Some(Duration::from_millis(50)))
732 .build();
733 self.server_restart_duration = Duration::from_millis(50);
734 self
735 }
736
737 async fn error_for_response(
743 &self,
744 response: reqwest::Response,
745 ) -> Result<reqwest::Response, RPCError> {
746 match response.status() {
747 StatusCode::TOO_MANY_REQUESTS => {
748 let retry_after_raw = response
749 .headers()
750 .get(reqwest::header::RETRY_AFTER)
751 .and_then(|h| h.to_str().ok())
752 .and_then(parse_retry_value);
753
754 Err(RPCError::RateLimited(retry_after_raw))
755 }
756 StatusCode::BAD_GATEWAY |
757 StatusCode::SERVICE_UNAVAILABLE |
758 StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
759 response
760 .text()
761 .await
762 .unwrap_or_else(|_| "Server Unreachable".to_string()),
763 )),
764 _ => Ok(response),
765 }
766 }
767
768 async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
774 match e {
775 RPCError::ServerUnreachable(_) => {
776 backoff::Error::retry_after(e, self.server_restart_duration)
777 }
778 RPCError::RateLimited(Some(until)) => {
779 let mut retry_after_guard = self.retry_after.write().await;
780 *retry_after_guard = Some(
781 retry_after_guard
782 .unwrap_or(until)
783 .max(until),
784 );
785
786 if let Ok(duration) = until.duration_since(SystemTime::now()) {
787 backoff::Error::retry_after(e, duration)
788 } else {
789 e.into()
790 }
791 }
792 RPCError::RateLimited(None) => e.into(),
793 _ => backoff::Error::permanent(e),
794 }
795 }
796
797 async fn wait_until_retry_after(&self) {
802 if let Some(&until) = self.retry_after.read().await.as_ref() {
803 let now = SystemTime::now();
804 if until > now {
805 if let Ok(duration) = until.duration_since(now) {
806 sleep(duration).await
807 }
808 }
809 }
810 }
811
812 async fn make_post_request<T: Serialize + ?Sized>(
817 &self,
818 request: &T,
819 uri: &String,
820 ) -> Result<Response, RPCError> {
821 self.wait_until_retry_after().await;
822 let response = backoff::future::retry(self.backoff_policy.clone(), || async {
823 let server_response = self
824 .http_client
825 .post(uri)
826 .json(request)
827 .send()
828 .await
829 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
830
831 match self
832 .error_for_response(server_response)
833 .await
834 {
835 Ok(response) => Ok(response),
836 Err(e) => Err(self.handle_error_for_backoff(e).await),
837 }
838 })
839 .await?;
840 Ok(response)
841 }
842}
843
844fn parse_retry_value(val: &str) -> Option<SystemTime> {
845 if let Ok(secs) = val.parse::<u64>() {
846 return Some(SystemTime::now() + Duration::from_secs(secs));
847 }
848 if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
849 return Some(date.into());
850 }
851 None
852}
853
854#[async_trait]
855impl RPCClient for HttpRPCClient {
856 #[instrument(skip(self, request))]
857 async fn get_contract_state(
858 &self,
859 request: &StateRequestBody,
860 ) -> Result<StateRequestResponse, RPCError> {
861 if request
863 .contract_ids
864 .as_ref()
865 .is_none_or(|ids| ids.is_empty())
866 {
867 warn!("No contract ids specified in request.");
868 }
869
870 let uri = format!(
871 "{}/{}/contract_state",
872 self.url
873 .to_string()
874 .trim_end_matches('/'),
875 TYCHO_SERVER_VERSION
876 );
877 debug!(%uri, "Sending contract_state request to Tycho server");
878 trace!(?request, "Sending request to Tycho server");
879 let response = self
880 .make_post_request(request, &uri)
881 .await?;
882 trace!(?response, "Received response from Tycho server");
883
884 let body = response
885 .text()
886 .await
887 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
888 if body.is_empty() {
889 return Ok(StateRequestResponse {
891 accounts: vec![],
892 pagination: PaginationResponse {
893 page: request.pagination.page,
894 page_size: request.pagination.page,
895 total: 0,
896 },
897 });
898 }
899
900 let accounts = serde_json::from_str::<StateRequestResponse>(&body)
901 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
902 trace!(?accounts, "Received contract_state response from Tycho server");
903
904 Ok(accounts)
905 }
906
907 async fn get_protocol_components(
908 &self,
909 request: &ProtocolComponentsRequestBody,
910 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
911 let uri = format!(
912 "{}/{}/protocol_components",
913 self.url
914 .to_string()
915 .trim_end_matches('/'),
916 TYCHO_SERVER_VERSION,
917 );
918 debug!(%uri, "Sending protocol_components request to Tycho server");
919 trace!(?request, "Sending request to Tycho server");
920
921 let response = self
922 .make_post_request(request, &uri)
923 .await?;
924
925 trace!(?response, "Received response from Tycho server");
926
927 let body = response
928 .text()
929 .await
930 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
931 let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
932 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
933 trace!(?components, "Received protocol_components response from Tycho server");
934
935 Ok(components)
936 }
937
938 async fn get_protocol_states(
939 &self,
940 request: &ProtocolStateRequestBody,
941 ) -> Result<ProtocolStateRequestResponse, RPCError> {
942 if request
944 .protocol_ids
945 .as_ref()
946 .is_none_or(|ids| ids.is_empty())
947 {
948 warn!("No protocol ids specified in request.");
949 }
950
951 let uri = format!(
952 "{}/{}/protocol_state",
953 self.url
954 .to_string()
955 .trim_end_matches('/'),
956 TYCHO_SERVER_VERSION
957 );
958 debug!(%uri, "Sending protocol_states request to Tycho server");
959 trace!(?request, "Sending request to Tycho server");
960
961 let response = self
962 .make_post_request(request, &uri)
963 .await?;
964 trace!(?response, "Received response from Tycho server");
965
966 let body = response
967 .text()
968 .await
969 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
970
971 if body.is_empty() {
972 return Ok(ProtocolStateRequestResponse {
974 states: vec![],
975 pagination: PaginationResponse {
976 page: request.pagination.page,
977 page_size: request.pagination.page_size,
978 total: 0,
979 },
980 });
981 }
982
983 let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
984 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
985 trace!(?states, "Received protocol_states response from Tycho server");
986
987 Ok(states)
988 }
989
990 async fn get_tokens(
991 &self,
992 request: &TokensRequestBody,
993 ) -> Result<TokensRequestResponse, RPCError> {
994 let uri = format!(
995 "{}/{}/tokens",
996 self.url
997 .to_string()
998 .trim_end_matches('/'),
999 TYCHO_SERVER_VERSION
1000 );
1001 debug!(%uri, "Sending tokens request to Tycho server");
1002
1003 let response = self
1004 .make_post_request(request, &uri)
1005 .await?;
1006
1007 let body = response
1008 .text()
1009 .await
1010 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011 let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1012 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1013
1014 Ok(tokens)
1015 }
1016
1017 async fn get_protocol_systems(
1018 &self,
1019 request: &ProtocolSystemsRequestBody,
1020 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1021 let uri = format!(
1022 "{}/{}/protocol_systems",
1023 self.url
1024 .to_string()
1025 .trim_end_matches('/'),
1026 TYCHO_SERVER_VERSION
1027 );
1028 debug!(%uri, "Sending protocol_systems request to Tycho server");
1029 trace!(?request, "Sending request to Tycho server");
1030 let response = self
1031 .make_post_request(request, &uri)
1032 .await?;
1033 trace!(?response, "Received response from Tycho server");
1034 let body = response
1035 .text()
1036 .await
1037 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1038 let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1039 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1040 trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1041 Ok(protocol_systems)
1042 }
1043
1044 async fn get_component_tvl(
1045 &self,
1046 request: &ComponentTvlRequestBody,
1047 ) -> Result<ComponentTvlRequestResponse, RPCError> {
1048 let uri = format!(
1049 "{}/{}/component_tvl",
1050 self.url
1051 .to_string()
1052 .trim_end_matches('/'),
1053 TYCHO_SERVER_VERSION
1054 );
1055 debug!(%uri, "Sending get_component_tvl request to Tycho server");
1056 trace!(?request, "Sending request to Tycho server");
1057 let response = self
1058 .make_post_request(request, &uri)
1059 .await?;
1060 trace!(?response, "Received response from Tycho server");
1061 let body = response
1062 .text()
1063 .await
1064 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1065 let component_tvl =
1066 serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1067 error!("Failed to parse component_tvl response: {:?}", &body);
1068 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1069 })?;
1070 trace!(?component_tvl, "Received component_tvl response from Tycho server");
1071 Ok(component_tvl)
1072 }
1073
1074 async fn get_traced_entry_points(
1075 &self,
1076 request: &TracedEntryPointRequestBody,
1077 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1078 let uri = format!(
1079 "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1080 self.url
1081 .to_string()
1082 .trim_end_matches('/')
1083 );
1084 debug!(%uri, "Sending traced_entry_points request to Tycho server");
1085 trace!(?request, "Sending request to Tycho server");
1086
1087 let response = self
1088 .make_post_request(request, &uri)
1089 .await?;
1090
1091 trace!(?response, "Received response from Tycho server");
1092
1093 let body = response
1094 .text()
1095 .await
1096 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1097 let entrypoints =
1098 serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1099 error!("Failed to parse traced_entry_points response: {:?}", &body);
1100 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1101 })?;
1102 trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1103 Ok(entrypoints)
1104 }
1105
1106 async fn get_snapshots<'a>(
1107 &self,
1108 request: &SnapshotParameters<'a>,
1109 chunk_size: usize,
1110 concurrency: usize,
1111 ) -> Result<Snapshot, RPCError> {
1112 let component_ids: Vec<_> = request
1113 .components
1114 .keys()
1115 .cloned()
1116 .collect();
1117
1118 let version = VersionParam::new(
1119 None,
1120 Some({
1121 #[allow(deprecated)]
1122 BlockParam { hash: None, chain: None, number: Some(request.block_number as i64) }
1123 }),
1124 );
1125
1126 let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1127 let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1128 self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1129 .await?
1130 .tvl
1131 } else {
1132 HashMap::new()
1133 };
1134
1135 let mut protocol_states = if !component_ids.is_empty() {
1136 self.get_protocol_states_paginated(
1137 request.chain,
1138 &component_ids,
1139 request.protocol_system,
1140 request.include_balances,
1141 &version,
1142 chunk_size,
1143 concurrency,
1144 )
1145 .await?
1146 .states
1147 .into_iter()
1148 .map(|state| (state.component_id.clone(), state))
1149 .collect()
1150 } else {
1151 HashMap::new()
1152 };
1153
1154 let states = request
1156 .components
1157 .values()
1158 .filter_map(|component| {
1159 if let Some(state) = protocol_states.remove(&component.id) {
1160 Some((
1161 component.id.clone(),
1162 ComponentWithState {
1163 state,
1164 component: component.clone(),
1165 component_tvl: component_tvl
1166 .get(&component.id)
1167 .cloned(),
1168 entrypoints: request
1169 .entrypoints
1170 .as_ref()
1171 .and_then(|map| map.get(&component.id))
1172 .cloned()
1173 .unwrap_or_default(),
1174 },
1175 ))
1176 } else if component_ids.contains(&component.id) {
1177 let component_id = &component.id;
1179 error!(?component_id, "Missing state for native component!");
1180 None
1181 } else {
1182 None
1183 }
1184 })
1185 .collect();
1186
1187 let vm_storage = if !request.contract_ids.is_empty() {
1188 let contract_states = self
1189 .get_contract_state_paginated(
1190 request.chain,
1191 request.contract_ids,
1192 request.protocol_system,
1193 &version,
1194 chunk_size,
1195 concurrency,
1196 )
1197 .await?
1198 .accounts
1199 .into_iter()
1200 .map(|acc| (acc.address.clone(), acc))
1201 .collect::<HashMap<_, _>>();
1202
1203 trace!(states=?&contract_states, "Retrieved ContractState");
1204
1205 let contract_address_to_components = request
1206 .components
1207 .iter()
1208 .filter_map(|(id, comp)| {
1209 if component_ids.contains(id) {
1210 Some(
1211 comp.contract_ids
1212 .iter()
1213 .map(|address| (address.clone(), comp.id.clone())),
1214 )
1215 } else {
1216 None
1217 }
1218 })
1219 .flatten()
1220 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1221 acc.entry(addr).or_default().push(c_id);
1222 acc
1223 });
1224
1225 request
1226 .contract_ids
1227 .iter()
1228 .filter_map(|address| {
1229 if let Some(state) = contract_states.get(address) {
1230 Some((address.clone(), state.clone()))
1231 } else if let Some(ids) = contract_address_to_components.get(address) {
1232 error!(
1234 ?address,
1235 ?ids,
1236 "Component with lacking contract storage encountered!"
1237 );
1238 None
1239 } else {
1240 None
1241 }
1242 })
1243 .collect()
1244 } else {
1245 HashMap::new()
1246 };
1247
1248 Ok(Snapshot { states, vm_storage })
1249 }
1250}
1251
1252#[cfg(test)]
1253mod tests {
1254 use std::{
1255 collections::{HashMap, HashSet},
1256 str::FromStr,
1257 };
1258
1259 use mockito::Server;
1260 use rstest::rstest;
1261 #[allow(deprecated)]
1263 use tycho_common::dto::ProtocolId;
1264 use tycho_common::dto::{AddressStorageLocation, TracingParams};
1265
1266 use super::*;
1267
1268 impl MockRPCClient {
1271 #[allow(clippy::too_many_arguments)]
1272 async fn test_get_protocol_states_paginated<T>(
1273 &self,
1274 chain: Chain,
1275 ids: &[T],
1276 protocol_system: &str,
1277 include_balances: bool,
1278 version: &VersionParam,
1279 chunk_size: usize,
1280 _concurrency: usize,
1281 ) -> Vec<ProtocolStateRequestBody>
1282 where
1283 T: AsRef<str> + Clone + Send + Sync + 'static,
1284 {
1285 ids.chunks(chunk_size)
1286 .map(|chunk| ProtocolStateRequestBody {
1287 protocol_ids: Some(
1288 chunk
1289 .iter()
1290 .map(|id| id.as_ref().to_string())
1291 .collect(),
1292 ),
1293 protocol_system: protocol_system.to_string(),
1294 chain,
1295 include_balances,
1296 version: version.clone(),
1297 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1298 })
1299 .collect()
1300 }
1301 }
1302
1303 #[allow(deprecated)]
1305 #[rstest]
1306 #[case::protocol_id_input(vec![
1307 ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1308 ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1309 ])]
1310 #[case::string_input(vec![
1311 "id1".to_string(),
1312 "id2".to_string()
1313 ])]
1314 #[tokio::test]
1315 async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1316 where
1317 T: AsRef<str> + Clone + Send + Sync + 'static,
1318 {
1319 let mock_client = MockRPCClient::new();
1320
1321 let request_bodies = mock_client
1322 .test_get_protocol_states_paginated(
1323 Chain::Ethereum,
1324 &ids,
1325 "test_system",
1326 true,
1327 &VersionParam::default(),
1328 2,
1329 2,
1330 )
1331 .await;
1332
1333 assert_eq!(request_bodies.len(), 1);
1335 assert_eq!(
1336 request_bodies[0]
1337 .protocol_ids
1338 .as_ref()
1339 .unwrap()
1340 .len(),
1341 2
1342 );
1343 }
1344
1345 #[tokio::test]
1346 async fn test_get_contract_state() {
1347 let mut server = Server::new_async().await;
1348 let server_resp = r#"
1349 {
1350 "accounts": [
1351 {
1352 "chain": "ethereum",
1353 "address": "0x0000000000000000000000000000000000000000",
1354 "title": "",
1355 "slots": {},
1356 "native_balance": "0x01f4",
1357 "token_balances": {},
1358 "code": "0x00",
1359 "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
1360 "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1361 "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1362 "creation_tx": null
1363 }
1364 ],
1365 "pagination": {
1366 "page": 0,
1367 "page_size": 20,
1368 "total": 10
1369 }
1370 }
1371 "#;
1372 serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1374
1375 let mocked_server = server
1376 .mock("POST", "/v1/contract_state")
1377 .expect(1)
1378 .with_body(server_resp)
1379 .create_async()
1380 .await;
1381
1382 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1383
1384 let response = client
1385 .get_contract_state(&Default::default())
1386 .await
1387 .expect("get state");
1388 let accounts = response.accounts;
1389
1390 mocked_server.assert();
1391 assert_eq!(accounts.len(), 1);
1392 assert_eq!(accounts[0].slots, HashMap::new());
1393 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1394 assert_eq!(accounts[0].code, [0].to_vec());
1395 assert_eq!(
1396 accounts[0].code_hash,
1397 hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1398 .unwrap()
1399 );
1400 }
1401
1402 #[tokio::test]
1403 async fn test_get_protocol_components() {
1404 let mut server = Server::new_async().await;
1405 let server_resp = r#"
1406 {
1407 "protocol_components": [
1408 {
1409 "id": "State1",
1410 "protocol_system": "ambient",
1411 "protocol_type_name": "Pool",
1412 "chain": "ethereum",
1413 "tokens": [
1414 "0x0000000000000000000000000000000000000000",
1415 "0x0000000000000000000000000000000000000001"
1416 ],
1417 "contract_ids": [
1418 "0x0000000000000000000000000000000000000000"
1419 ],
1420 "static_attributes": {
1421 "attribute_1": "0x00000000000003e8"
1422 },
1423 "change": "Creation",
1424 "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1425 "created_at": "2022-01-01T00:00:00"
1426 }
1427 ],
1428 "pagination": {
1429 "page": 0,
1430 "page_size": 20,
1431 "total": 10
1432 }
1433 }
1434 "#;
1435 serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1437
1438 let mocked_server = server
1439 .mock("POST", "/v1/protocol_components")
1440 .expect(1)
1441 .with_body(server_resp)
1442 .create_async()
1443 .await;
1444
1445 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1446
1447 let response = client
1448 .get_protocol_components(&Default::default())
1449 .await
1450 .expect("get state");
1451 let components = response.protocol_components;
1452
1453 mocked_server.assert();
1454 assert_eq!(components.len(), 1);
1455 assert_eq!(components[0].id, "State1");
1456 assert_eq!(components[0].protocol_system, "ambient");
1457 assert_eq!(components[0].protocol_type_name, "Pool");
1458 assert_eq!(components[0].tokens.len(), 2);
1459 let expected_attributes =
1460 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1461 .iter()
1462 .cloned()
1463 .collect::<HashMap<String, Bytes>>();
1464 assert_eq!(components[0].static_attributes, expected_attributes);
1465 }
1466
1467 #[tokio::test]
1468 async fn test_get_protocol_states() {
1469 let mut server = Server::new_async().await;
1470 let server_resp = r#"
1471 {
1472 "states": [
1473 {
1474 "component_id": "State1",
1475 "attributes": {
1476 "attribute_1": "0x00000000000003e8"
1477 },
1478 "balances": {
1479 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1480 }
1481 }
1482 ],
1483 "pagination": {
1484 "page": 0,
1485 "page_size": 20,
1486 "total": 10
1487 }
1488 }
1489 "#;
1490 serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1492
1493 let mocked_server = server
1494 .mock("POST", "/v1/protocol_state")
1495 .expect(1)
1496 .with_body(server_resp)
1497 .create_async()
1498 .await;
1499 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1500
1501 let response = client
1502 .get_protocol_states(&Default::default())
1503 .await
1504 .expect("get state");
1505 let states = response.states;
1506
1507 mocked_server.assert();
1508 assert_eq!(states.len(), 1);
1509 assert_eq!(states[0].component_id, "State1");
1510 let expected_attributes =
1511 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1512 .iter()
1513 .cloned()
1514 .collect::<HashMap<String, Bytes>>();
1515 assert_eq!(states[0].attributes, expected_attributes);
1516 let expected_balances = [(
1517 Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1518 .expect("Unsupported address format"),
1519 Bytes::from_str("0x01f4").unwrap(),
1520 )]
1521 .iter()
1522 .cloned()
1523 .collect::<HashMap<Bytes, Bytes>>();
1524 assert_eq!(states[0].balances, expected_balances);
1525 }
1526
1527 #[tokio::test]
1528 async fn test_get_tokens() {
1529 let mut server = Server::new_async().await;
1530 let server_resp = r#"
1531 {
1532 "tokens": [
1533 {
1534 "chain": "ethereum",
1535 "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1536 "symbol": "WETH",
1537 "decimals": 18,
1538 "tax": 0,
1539 "gas": [
1540 29962
1541 ],
1542 "quality": 100
1543 },
1544 {
1545 "chain": "ethereum",
1546 "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1547 "symbol": "USDC",
1548 "decimals": 6,
1549 "tax": 0,
1550 "gas": [
1551 40652
1552 ],
1553 "quality": 100
1554 }
1555 ],
1556 "pagination": {
1557 "page": 0,
1558 "page_size": 20,
1559 "total": 10
1560 }
1561 }
1562 "#;
1563 serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1565
1566 let mocked_server = server
1567 .mock("POST", "/v1/tokens")
1568 .expect(1)
1569 .with_body(server_resp)
1570 .create_async()
1571 .await;
1572 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1573
1574 let response = client
1575 .get_tokens(&Default::default())
1576 .await
1577 .expect("get tokens");
1578
1579 let expected = vec![
1580 ResponseToken {
1581 chain: Chain::Ethereum,
1582 address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1583 symbol: "WETH".to_string(),
1584 decimals: 18,
1585 tax: 0,
1586 gas: vec![Some(29962)],
1587 quality: 100,
1588 },
1589 ResponseToken {
1590 chain: Chain::Ethereum,
1591 address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1592 symbol: "USDC".to_string(),
1593 decimals: 6,
1594 tax: 0,
1595 gas: vec![Some(40652)],
1596 quality: 100,
1597 },
1598 ];
1599
1600 mocked_server.assert();
1601 assert_eq!(response.tokens, expected);
1602 assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1603 }
1604
1605 #[tokio::test]
1606 async fn test_get_protocol_systems() {
1607 let mut server = Server::new_async().await;
1608 let server_resp = r#"
1609 {
1610 "protocol_systems": [
1611 "system1",
1612 "system2"
1613 ],
1614 "pagination": {
1615 "page": 0,
1616 "page_size": 20,
1617 "total": 10
1618 }
1619 }
1620 "#;
1621 serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1623
1624 let mocked_server = server
1625 .mock("POST", "/v1/protocol_systems")
1626 .expect(1)
1627 .with_body(server_resp)
1628 .create_async()
1629 .await;
1630 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1631
1632 let response = client
1633 .get_protocol_systems(&Default::default())
1634 .await
1635 .expect("get protocol systems");
1636 let protocol_systems = response.protocol_systems;
1637
1638 mocked_server.assert();
1639 assert_eq!(protocol_systems, vec!["system1", "system2"]);
1640 }
1641
1642 #[tokio::test]
1643 async fn test_get_component_tvl() {
1644 let mut server = Server::new_async().await;
1645 let server_resp = r#"
1646 {
1647 "tvl": {
1648 "component1": 100.0
1649 },
1650 "pagination": {
1651 "page": 0,
1652 "page_size": 20,
1653 "total": 10
1654 }
1655 }
1656 "#;
1657 serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1659
1660 let mocked_server = server
1661 .mock("POST", "/v1/component_tvl")
1662 .expect(1)
1663 .with_body(server_resp)
1664 .create_async()
1665 .await;
1666 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1667
1668 let response = client
1669 .get_component_tvl(&Default::default())
1670 .await
1671 .expect("get protocol systems");
1672 let component_tvl = response.tvl;
1673
1674 mocked_server.assert();
1675 assert_eq!(component_tvl.get("component1"), Some(&100.0));
1676 }
1677
1678 #[tokio::test]
1679 async fn test_get_traced_entry_points() {
1680 let mut server = Server::new_async().await;
1681 let server_resp = r#"
1682 {
1683 "traced_entry_points": {
1684 "component_1": [
1685 [
1686 {
1687 "entry_point": {
1688 "external_id": "entrypoint_a",
1689 "target": "0x0000000000000000000000000000000000000001",
1690 "signature": "sig()"
1691 },
1692 "params": {
1693 "method": "rpctracer",
1694 "caller": "0x000000000000000000000000000000000000000a",
1695 "calldata": "0x000000000000000000000000000000000000000b"
1696 }
1697 },
1698 {
1699 "retriggers": [
1700 [
1701 "0x00000000000000000000000000000000000000aa",
1702 {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1703 ]
1704 ],
1705 "accessed_slots": {
1706 "0x0000000000000000000000000000000000aaaa": [
1707 "0x0000000000000000000000000000000000aaaa"
1708 ]
1709 }
1710 }
1711 ]
1712 ]
1713 },
1714 "pagination": {
1715 "page": 0,
1716 "page_size": 20,
1717 "total": 1
1718 }
1719 }
1720 "#;
1721 serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1723
1724 let mocked_server = server
1725 .mock("POST", "/v1/traced_entry_points")
1726 .expect(1)
1727 .with_body(server_resp)
1728 .create_async()
1729 .await;
1730 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1731
1732 let response = client
1733 .get_traced_entry_points(&Default::default())
1734 .await
1735 .expect("get traced entry points");
1736 let entrypoints = response.traced_entry_points;
1737
1738 mocked_server.assert();
1739 assert_eq!(entrypoints.len(), 1);
1740 let comp1_entrypoints = entrypoints
1741 .get("component_1")
1742 .expect("component_1 entrypoints should exist");
1743 assert_eq!(comp1_entrypoints.len(), 1);
1744
1745 let (entrypoint, trace_result) = &comp1_entrypoints[0];
1746 assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1747 assert_eq!(
1748 entrypoint.entry_point.target,
1749 Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1750 );
1751 assert_eq!(entrypoint.entry_point.signature, "sig()");
1752 let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1753 assert_eq!(
1754 rpc_params.caller,
1755 Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1756 );
1757 assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1758
1759 assert_eq!(
1760 trace_result.retriggers,
1761 HashSet::from([(
1762 Bytes::from("0x00000000000000000000000000000000000000aa"),
1763 AddressStorageLocation::new(
1764 Bytes::from("0x0000000000000000000000000000000000000aaa"),
1765 12
1766 )
1767 )])
1768 );
1769 assert_eq!(trace_result.accessed_slots.len(), 1);
1770 assert_eq!(
1771 trace_result.accessed_slots,
1772 HashMap::from([(
1773 Bytes::from("0x0000000000000000000000000000000000aaaa"),
1774 HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1775 )])
1776 );
1777 }
1778
1779 #[tokio::test]
1780 async fn test_parse_retry_value_numeric() {
1781 let result = parse_retry_value("60");
1782 assert!(result.is_some());
1783
1784 let expected_time = SystemTime::now() + Duration::from_secs(60);
1785 let actual_time = result.unwrap();
1786
1787 let diff = if actual_time > expected_time {
1789 actual_time
1790 .duration_since(expected_time)
1791 .unwrap()
1792 } else {
1793 expected_time
1794 .duration_since(actual_time)
1795 .unwrap()
1796 };
1797 assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1798 }
1799
1800 #[tokio::test]
1801 async fn test_parse_retry_value_rfc2822() {
1802 let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1804 let result = parse_retry_value(rfc2822_date);
1805 assert!(result.is_some());
1806
1807 let parsed_time = result.unwrap();
1808 assert!(parsed_time > SystemTime::now());
1809 }
1810
1811 #[tokio::test]
1812 async fn test_parse_retry_value_invalid_formats() {
1813 assert!(parse_retry_value("invalid").is_none());
1815 assert!(parse_retry_value("").is_none());
1816 assert!(parse_retry_value("not_a_number").is_none());
1817 assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); }
1819
1820 #[tokio::test]
1821 async fn test_parse_retry_value_zero_seconds() {
1822 let result = parse_retry_value("0");
1823 assert!(result.is_some());
1824
1825 let expected_time = SystemTime::now();
1826 let actual_time = result.unwrap();
1827
1828 let diff = if actual_time > expected_time {
1830 actual_time
1831 .duration_since(expected_time)
1832 .unwrap()
1833 } else {
1834 expected_time
1835 .duration_since(actual_time)
1836 .unwrap()
1837 };
1838 assert!(diff < Duration::from_secs(1));
1839 }
1840
1841 #[tokio::test]
1842 async fn test_error_for_response_rate_limited() {
1843 let mut server = Server::new_async().await;
1844 let mock = server
1845 .mock("GET", "/test")
1846 .with_status(429)
1847 .with_header("Retry-After", "60")
1848 .create_async()
1849 .await;
1850
1851 let client = reqwest::Client::new();
1852 let response = client
1853 .get(format!("{}/test", server.url()))
1854 .send()
1855 .await
1856 .unwrap();
1857
1858 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1859 .unwrap()
1860 .with_test_backoff_policy();
1861 let result = http_client
1862 .error_for_response(response)
1863 .await;
1864
1865 mock.assert();
1866 assert!(matches!(result, Err(RPCError::RateLimited(_))));
1867 if let Err(RPCError::RateLimited(retry_after)) = result {
1868 assert!(retry_after.is_some());
1869 }
1870 }
1871
1872 #[tokio::test]
1873 async fn test_error_for_response_rate_limited_no_header() {
1874 let mut server = Server::new_async().await;
1875 let mock = server
1876 .mock("GET", "/test")
1877 .with_status(429)
1878 .create_async()
1879 .await;
1880
1881 let client = reqwest::Client::new();
1882 let response = client
1883 .get(format!("{}/test", server.url()))
1884 .send()
1885 .await
1886 .unwrap();
1887
1888 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1889 .unwrap()
1890 .with_test_backoff_policy();
1891 let result = http_client
1892 .error_for_response(response)
1893 .await;
1894
1895 mock.assert();
1896 assert!(matches!(result, Err(RPCError::RateLimited(None))));
1897 }
1898
1899 #[tokio::test]
1900 async fn test_error_for_response_server_errors() {
1901 let test_cases =
1902 vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1903
1904 for (status_code, expected_body) in test_cases {
1905 let mut server = Server::new_async().await;
1906 let mock = server
1907 .mock("GET", "/test")
1908 .with_status(status_code)
1909 .with_body(expected_body)
1910 .create_async()
1911 .await;
1912
1913 let client = reqwest::Client::new();
1914 let response = client
1915 .get(format!("{}/test", server.url()))
1916 .send()
1917 .await
1918 .unwrap();
1919
1920 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1921 .unwrap()
1922 .with_test_backoff_policy();
1923 let result = http_client
1924 .error_for_response(response)
1925 .await;
1926
1927 mock.assert();
1928 assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
1929 if let Err(RPCError::ServerUnreachable(body)) = result {
1930 assert_eq!(body, expected_body);
1931 }
1932 }
1933 }
1934
1935 #[tokio::test]
1936 async fn test_error_for_response_success() {
1937 let mut server = Server::new_async().await;
1938 let mock = server
1939 .mock("GET", "/test")
1940 .with_status(200)
1941 .with_body("success")
1942 .create_async()
1943 .await;
1944
1945 let client = reqwest::Client::new();
1946 let response = client
1947 .get(format!("{}/test", server.url()))
1948 .send()
1949 .await
1950 .unwrap();
1951
1952 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1953 .unwrap()
1954 .with_test_backoff_policy();
1955 let result = http_client
1956 .error_for_response(response)
1957 .await;
1958
1959 mock.assert();
1960 assert!(result.is_ok());
1961
1962 let response = result.unwrap();
1963 assert_eq!(response.status(), 200);
1964 }
1965
1966 #[tokio::test]
1967 async fn test_handle_error_for_backoff_server_unreachable() {
1968 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1969 .unwrap()
1970 .with_test_backoff_policy();
1971 let error = RPCError::ServerUnreachable("Service down".to_string());
1972
1973 let backoff_error = http_client
1974 .handle_error_for_backoff(error)
1975 .await;
1976
1977 match backoff_error {
1978 backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
1979 assert_eq!(msg, "Service down");
1980 assert_eq!(retry_after, Some(Duration::from_millis(50))); }
1982 _ => panic!("Expected transient error for ServerUnreachable"),
1983 }
1984 }
1985
1986 #[tokio::test]
1987 async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
1988 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1989 .unwrap()
1990 .with_test_backoff_policy();
1991 let future_time = SystemTime::now() + Duration::from_secs(30);
1992 let error = RPCError::RateLimited(Some(future_time));
1993
1994 let backoff_error = http_client
1995 .handle_error_for_backoff(error)
1996 .await;
1997
1998 match backoff_error {
1999 backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2000 assert_eq!(retry_after, Some(future_time));
2001 }
2002 _ => panic!("Expected transient error for RateLimited"),
2003 }
2004
2005 let stored_retry_after = http_client.retry_after.read().await;
2007 assert_eq!(*stored_retry_after, Some(future_time));
2008 }
2009
2010 #[tokio::test]
2011 async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2012 let http_client = HttpRPCClient::new("http://localhost:8080", None)
2013 .unwrap()
2014 .with_test_backoff_policy();
2015 let error = RPCError::RateLimited(None);
2016
2017 let backoff_error = http_client
2018 .handle_error_for_backoff(error)
2019 .await;
2020
2021 match backoff_error {
2022 backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2023 }
2025 _ => panic!("Expected transient error for RateLimited without retry-after"),
2026 }
2027 }
2028
2029 #[tokio::test]
2030 async fn test_handle_error_for_backoff_other_errors() {
2031 let http_client = HttpRPCClient::new("http://localhost:8080", None)
2032 .unwrap()
2033 .with_test_backoff_policy();
2034 let error = RPCError::ParseResponse("Invalid JSON".to_string());
2035
2036 let backoff_error = http_client
2037 .handle_error_for_backoff(error)
2038 .await;
2039
2040 match backoff_error {
2041 backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2042 assert_eq!(msg, "Invalid JSON");
2043 }
2044 _ => panic!("Expected permanent error for ParseResponse"),
2045 }
2046 }
2047
2048 #[tokio::test]
2049 async fn test_wait_until_retry_after_no_retry_time() {
2050 let http_client = HttpRPCClient::new("http://localhost:8080", None)
2051 .unwrap()
2052 .with_test_backoff_policy();
2053
2054 let start = std::time::Instant::now();
2055 http_client
2056 .wait_until_retry_after()
2057 .await;
2058 let elapsed = start.elapsed();
2059
2060 assert!(elapsed < Duration::from_millis(100));
2062 }
2063
2064 #[tokio::test]
2065 async fn test_wait_until_retry_after_past_time() {
2066 let http_client = HttpRPCClient::new("http://localhost:8080", None)
2067 .unwrap()
2068 .with_test_backoff_policy();
2069
2070 let past_time = SystemTime::now() - Duration::from_secs(10);
2072 *http_client.retry_after.write().await = Some(past_time);
2073
2074 let start = std::time::Instant::now();
2075 http_client
2076 .wait_until_retry_after()
2077 .await;
2078 let elapsed = start.elapsed();
2079
2080 assert!(elapsed < Duration::from_millis(100));
2082 }
2083
2084 #[tokio::test]
2085 async fn test_wait_until_retry_after_future_time() {
2086 let http_client = HttpRPCClient::new("http://localhost:8080", None)
2087 .unwrap()
2088 .with_test_backoff_policy();
2089
2090 let future_time = SystemTime::now() + Duration::from_millis(100);
2092 *http_client.retry_after.write().await = Some(future_time);
2093
2094 let start = std::time::Instant::now();
2095 http_client
2096 .wait_until_retry_after()
2097 .await;
2098 let elapsed = start.elapsed();
2099
2100 assert!(elapsed >= Duration::from_millis(80)); assert!(elapsed <= Duration::from_millis(200)); }
2104
2105 #[tokio::test]
2106 async fn test_make_post_request_success() {
2107 let mut server = Server::new_async().await;
2108 let server_resp = r#"{"success": true}"#;
2109
2110 let mock = server
2111 .mock("POST", "/test")
2112 .with_status(200)
2113 .with_body(server_resp)
2114 .create_async()
2115 .await;
2116
2117 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2118 .unwrap()
2119 .with_test_backoff_policy();
2120 let request_body = serde_json::json!({"test": "data"});
2121 let uri = format!("{}/test", server.url());
2122
2123 let result = http_client
2124 .make_post_request(&request_body, &uri)
2125 .await;
2126
2127 mock.assert();
2128 assert!(result.is_ok());
2129
2130 let response = result.unwrap();
2131 assert_eq!(response.status(), 200);
2132 assert_eq!(response.text().await.unwrap(), server_resp);
2133 }
2134
2135 #[tokio::test]
2136 async fn test_make_post_request_retry_on_server_error() {
2137 let mut server = Server::new_async().await;
2138 let error_mock = server
2140 .mock("POST", "/test")
2141 .with_status(503)
2142 .with_body("Service Unavailable")
2143 .expect(1)
2144 .create_async()
2145 .await;
2146
2147 let success_mock = server
2148 .mock("POST", "/test")
2149 .with_status(200)
2150 .with_body(r#"{"success": true}"#)
2151 .expect(1)
2152 .create_async()
2153 .await;
2154
2155 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2156 .unwrap()
2157 .with_test_backoff_policy();
2158 let request_body = serde_json::json!({"test": "data"});
2159 let uri = format!("{}/test", server.url());
2160
2161 let result = http_client
2162 .make_post_request(&request_body, &uri)
2163 .await;
2164
2165 error_mock.assert();
2166 success_mock.assert();
2167 assert!(result.is_ok());
2168 }
2169
2170 #[tokio::test]
2171 async fn test_make_post_request_respect_retry_after_header() {
2172 let mut server = Server::new_async().await;
2173
2174 let rate_limit_mock = server
2176 .mock("POST", "/test")
2177 .with_status(429)
2178 .with_header("Retry-After", "1") .expect(1)
2180 .create_async()
2181 .await;
2182
2183 let success_mock = server
2184 .mock("POST", "/test")
2185 .with_status(200)
2186 .with_body(r#"{"success": true}"#)
2187 .expect(1)
2188 .create_async()
2189 .await;
2190
2191 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2192 .unwrap()
2193 .with_test_backoff_policy();
2194 let request_body = serde_json::json!({"test": "data"});
2195 let uri = format!("{}/test", server.url());
2196
2197 let start = std::time::Instant::now();
2198 let result = http_client
2199 .make_post_request(&request_body, &uri)
2200 .await;
2201 let elapsed = start.elapsed();
2202
2203 rate_limit_mock.assert();
2204 success_mock.assert();
2205 assert!(result.is_ok());
2206
2207 assert!(elapsed >= Duration::from_millis(900)); assert!(elapsed <= Duration::from_millis(2000)); }
2211
2212 #[tokio::test]
2213 async fn test_make_post_request_permanent_error() {
2214 let mut server = Server::new_async().await;
2215
2216 let mock = server
2217 .mock("POST", "/test")
2218 .with_status(400) .with_body("Bad Request")
2220 .expect(1)
2221 .create_async()
2222 .await;
2223
2224 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2225 .unwrap()
2226 .with_test_backoff_policy();
2227 let request_body = serde_json::json!({"test": "data"});
2228 let uri = format!("{}/test", server.url());
2229
2230 let result = http_client
2231 .make_post_request(&request_body, &uri)
2232 .await;
2233
2234 mock.assert();
2235 assert!(result.is_ok()); let response = result.unwrap();
2238 assert_eq!(response.status(), 400);
2239 }
2240
2241 #[tokio::test]
2242 async fn test_concurrent_requests_with_different_retry_after() {
2243 let mut server = Server::new_async().await;
2244
2245 let rate_limit_mock_1 = server
2247 .mock("POST", "/test1")
2248 .with_status(429)
2249 .with_header("Retry-After", "1")
2250 .expect(1)
2251 .create_async()
2252 .await;
2253
2254 let rate_limit_mock_2 = server
2256 .mock("POST", "/test2")
2257 .with_status(429)
2258 .with_header("Retry-After", "2")
2259 .expect(1)
2260 .create_async()
2261 .await;
2262
2263 let success_mock_1 = server
2265 .mock("POST", "/test1")
2266 .with_status(200)
2267 .with_body(r#"{"result": "success1"}"#)
2268 .expect(1)
2269 .create_async()
2270 .await;
2271
2272 let success_mock_2 = server
2273 .mock("POST", "/test2")
2274 .with_status(200)
2275 .with_body(r#"{"result": "success2"}"#)
2276 .expect(1)
2277 .create_async()
2278 .await;
2279
2280 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2281 .unwrap()
2282 .with_test_backoff_policy();
2283 let request_body = serde_json::json!({"test": "data"});
2284
2285 let uri1 = format!("{}/test1", server.url());
2286 let uri2 = format!("{}/test2", server.url());
2287
2288 let start = std::time::Instant::now();
2290 let (result1, result2) = tokio::join!(
2291 http_client.make_post_request(&request_body, &uri1),
2292 http_client.make_post_request(&request_body, &uri2)
2293 );
2294 let elapsed = start.elapsed();
2295
2296 rate_limit_mock_1.assert();
2297 rate_limit_mock_2.assert();
2298 success_mock_1.assert();
2299 success_mock_2.assert();
2300
2301 assert!(result1.is_ok());
2302 assert!(result2.is_ok());
2303
2304 assert!(elapsed >= Duration::from_millis(1800)); assert!(elapsed <= Duration::from_millis(3000)); let final_retry_after = http_client.retry_after.read().await;
2312 assert!(final_retry_after.is_some());
2313
2314 if let Some(retry_time) = *final_retry_after {
2316 let now = SystemTime::now();
2319 let diff = if retry_time > now {
2320 retry_time.duration_since(now).unwrap()
2321 } else {
2322 now.duration_since(retry_time).unwrap()
2323 };
2324
2325 assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2327 }
2328 }
2329
2330 #[tokio::test]
2331 async fn test_get_snapshots() {
2332 let mut server = Server::new_async().await;
2333
2334 let protocol_states_resp = r#"
2336 {
2337 "states": [
2338 {
2339 "component_id": "component1",
2340 "attributes": {
2341 "attribute_1": "0x00000000000003e8"
2342 },
2343 "balances": {
2344 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2345 }
2346 }
2347 ],
2348 "pagination": {
2349 "page": 0,
2350 "page_size": 100,
2351 "total": 1
2352 }
2353 }
2354 "#;
2355
2356 let contract_state_resp = r#"
2358 {
2359 "accounts": [
2360 {
2361 "chain": "ethereum",
2362 "address": "0x1111111111111111111111111111111111111111",
2363 "title": "",
2364 "slots": {},
2365 "native_balance": "0x01f4",
2366 "token_balances": {},
2367 "code": "0x00",
2368 "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2369 "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2370 "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2371 "creation_tx": null
2372 }
2373 ],
2374 "pagination": {
2375 "page": 0,
2376 "page_size": 100,
2377 "total": 1
2378 }
2379 }
2380 "#;
2381
2382 let tvl_resp = r#"
2384 {
2385 "tvl": {
2386 "component1": 1000000.0
2387 },
2388 "pagination": {
2389 "page": 0,
2390 "page_size": 100,
2391 "total": 1
2392 }
2393 }
2394 "#;
2395
2396 let protocol_states_mock = server
2397 .mock("POST", "/v1/protocol_state")
2398 .expect(1)
2399 .with_body(protocol_states_resp)
2400 .create_async()
2401 .await;
2402
2403 let contract_state_mock = server
2404 .mock("POST", "/v1/contract_state")
2405 .expect(1)
2406 .with_body(contract_state_resp)
2407 .create_async()
2408 .await;
2409
2410 let tvl_mock = server
2411 .mock("POST", "/v1/component_tvl")
2412 .expect(1)
2413 .with_body(tvl_resp)
2414 .create_async()
2415 .await;
2416
2417 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2418
2419 #[allow(deprecated)]
2420 let component = tycho_common::dto::ProtocolComponent {
2421 id: "component1".to_string(),
2422 protocol_system: "test_protocol".to_string(),
2423 protocol_type_name: "test_type".to_string(),
2424 chain: Chain::Ethereum,
2425 tokens: vec![],
2426 contract_ids: vec![
2427 Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2428 ],
2429 static_attributes: HashMap::new(),
2430 change: tycho_common::dto::ChangeType::Creation,
2431 creation_tx: Bytes::from_str(
2432 "0x0000000000000000000000000000000000000000000000000000000000000000",
2433 )
2434 .unwrap(),
2435 created_at: chrono::Utc::now().naive_utc(),
2436 };
2437
2438 let mut components = HashMap::new();
2439 components.insert("component1".to_string(), component);
2440
2441 let contract_ids =
2442 vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2443
2444 let request = SnapshotParameters::new(
2445 Chain::Ethereum,
2446 "test_protocol",
2447 &components,
2448 &contract_ids,
2449 12345,
2450 );
2451
2452 let response = client
2453 .get_snapshots(&request, 100, 4)
2454 .await
2455 .expect("get snapshots");
2456
2457 protocol_states_mock.assert();
2459 contract_state_mock.assert();
2460 tvl_mock.assert();
2461
2462 assert_eq!(response.states.len(), 1);
2464 assert!(response
2465 .states
2466 .contains_key("component1"));
2467
2468 let component_state = response
2470 .states
2471 .get("component1")
2472 .unwrap();
2473 assert_eq!(component_state.component_tvl, Some(1000000.0));
2474
2475 assert_eq!(response.vm_storage.len(), 1);
2477 let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2478 assert!(response
2479 .vm_storage
2480 .contains_key(&contract_addr));
2481 }
2482
2483 #[tokio::test]
2484 async fn test_get_snapshots_empty_components() {
2485 let server = Server::new_async().await;
2486 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2487
2488 let components = HashMap::new();
2489 let contract_ids = vec![];
2490
2491 let request = SnapshotParameters::new(
2492 Chain::Ethereum,
2493 "test_protocol",
2494 &components,
2495 &contract_ids,
2496 12345,
2497 );
2498
2499 let response = client
2500 .get_snapshots(&request, 100, 4)
2501 .await
2502 .expect("get snapshots");
2503
2504 assert!(response.states.is_empty());
2506 assert!(response.vm_storage.is_empty());
2507 }
2508
2509 #[tokio::test]
2510 async fn test_get_snapshots_without_tvl() {
2511 let mut server = Server::new_async().await;
2512
2513 let protocol_states_resp = r#"
2514 {
2515 "states": [
2516 {
2517 "component_id": "component1",
2518 "attributes": {},
2519 "balances": {}
2520 }
2521 ],
2522 "pagination": {
2523 "page": 0,
2524 "page_size": 100,
2525 "total": 1
2526 }
2527 }
2528 "#;
2529
2530 let protocol_states_mock = server
2531 .mock("POST", "/v1/protocol_state")
2532 .expect(1)
2533 .with_body(protocol_states_resp)
2534 .create_async()
2535 .await;
2536
2537 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2538
2539 #[allow(deprecated)]
2541 let component = tycho_common::dto::ProtocolComponent {
2542 id: "component1".to_string(),
2543 protocol_system: "test_protocol".to_string(),
2544 protocol_type_name: "test_type".to_string(),
2545 chain: Chain::Ethereum,
2546 tokens: vec![],
2547 contract_ids: vec![],
2548 static_attributes: HashMap::new(),
2549 change: tycho_common::dto::ChangeType::Creation,
2550 creation_tx: Bytes::from_str(
2551 "0x0000000000000000000000000000000000000000000000000000000000000000",
2552 )
2553 .unwrap(),
2554 created_at: chrono::Utc::now().naive_utc(),
2555 };
2556
2557 let mut components = HashMap::new();
2558 components.insert("component1".to_string(), component);
2559 let contract_ids = vec![];
2560
2561 let request = SnapshotParameters::new(
2562 Chain::Ethereum,
2563 "test_protocol",
2564 &components,
2565 &contract_ids,
2566 12345,
2567 )
2568 .include_balances(false)
2569 .include_tvl(false);
2570
2571 let response = client
2572 .get_snapshots(&request, 100, 4)
2573 .await
2574 .expect("get snapshots");
2575
2576 protocol_states_mock.assert();
2578 assert_eq!(response.states.len(), 1);
2582 let component_state = response
2584 .states
2585 .get("component1")
2586 .unwrap();
2587 assert_eq!(component_state.component_tvl, None);
2588 }
2589}