1use std::{
7 collections::HashMap,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22 sync::{RwLock, Semaphore},
23 time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27 dto::{
28 Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, PaginationParams,
29 PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
30 ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
31 ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
32 TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
33 TracedEntryPointRequestResponse, VersionParam,
34 },
35 Bytes,
36};
37
38use crate::TYCHO_SERVER_VERSION;
39
40#[derive(Error, Debug)]
41pub enum RPCError {
42 #[error("Failed to parse URL: {0}. Error: {1}")]
44 UrlParsing(String, String),
45
46 #[error("Failed to format request: {0}")]
48 FormatRequest(String),
49
50 #[error("Unexpected HTTP client error: {0}")]
52 HttpClient(String, #[source] reqwest::Error),
53
54 #[error("Failed to parse response: {0}")]
56 ParseResponse(String),
57
58 #[error("Fatal error: {0}")]
60 Fatal(String),
61
62 #[error("Rate limited until {0:?}")]
63 RateLimited(Option<SystemTime>),
64
65 #[error("Server unreachable: {0}")]
66 ServerUnreachable(String),
67}
68
69#[cfg_attr(test, automock)]
70#[async_trait]
71pub trait RPCClient: Send + Sync {
72 async fn get_contract_state(
74 &self,
75 request: &StateRequestBody,
76 ) -> Result<StateRequestResponse, RPCError>;
77
78 async fn get_contract_state_paginated(
79 &self,
80 chain: Chain,
81 ids: &[Bytes],
82 protocol_system: &str,
83 version: &VersionParam,
84 chunk_size: usize,
85 concurrency: usize,
86 ) -> Result<StateRequestResponse, RPCError> {
87 let semaphore = Arc::new(Semaphore::new(concurrency));
88
89 let mut sorted_ids = ids.to_vec();
91 sorted_ids.sort();
92
93 let chunked_bodies = sorted_ids
94 .chunks(chunk_size)
95 .map(|chunk| StateRequestBody {
96 contract_ids: Some(chunk.to_vec()),
97 protocol_system: protocol_system.to_string(),
98 chain,
99 version: version.clone(),
100 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
101 })
102 .collect::<Vec<_>>();
103
104 let mut tasks = Vec::new();
105 for body in chunked_bodies.iter() {
106 let sem = semaphore.clone();
107 tasks.push(async move {
108 let _permit = sem
109 .acquire()
110 .await
111 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
112 self.get_contract_state(body).await
113 });
114 }
115
116 let responses = try_join_all(tasks).await?;
118
119 let accounts = responses
121 .iter()
122 .flat_map(|r| r.accounts.clone())
123 .collect();
124 let total: i64 = responses
125 .iter()
126 .map(|r| r.pagination.total)
127 .sum();
128
129 Ok(StateRequestResponse {
130 accounts,
131 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
132 })
133 }
134
135 async fn get_protocol_components(
136 &self,
137 request: &ProtocolComponentsRequestBody,
138 ) -> Result<ProtocolComponentRequestResponse, RPCError>;
139
140 async fn get_protocol_components_paginated(
141 &self,
142 request: &ProtocolComponentsRequestBody,
143 chunk_size: usize,
144 concurrency: usize,
145 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
146 let semaphore = Arc::new(Semaphore::new(concurrency));
147
148 match request.component_ids {
151 Some(ref ids) => {
152 let chunked_bodies = ids
154 .chunks(chunk_size)
155 .enumerate()
156 .map(|(index, _)| ProtocolComponentsRequestBody {
157 protocol_system: request.protocol_system.clone(),
158 component_ids: request.component_ids.clone(),
159 tvl_gt: request.tvl_gt,
160 chain: request.chain,
161 pagination: PaginationParams {
162 page: index as i64,
163 page_size: chunk_size as i64,
164 },
165 })
166 .collect::<Vec<_>>();
167
168 let mut tasks = Vec::new();
169 for body in chunked_bodies.iter() {
170 let sem = semaphore.clone();
171 tasks.push(async move {
172 let _permit = sem
173 .acquire()
174 .await
175 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
176 self.get_protocol_components(body).await
177 });
178 }
179
180 try_join_all(tasks)
181 .await
182 .map(|responses| ProtocolComponentRequestResponse {
183 protocol_components: responses
184 .into_iter()
185 .flat_map(|r| r.protocol_components.into_iter())
186 .collect(),
187 pagination: PaginationResponse {
188 page: 0,
189 page_size: chunk_size as i64,
190 total: ids.len() as i64,
191 },
192 })
193 }
194 _ => {
195 let initial_request = ProtocolComponentsRequestBody {
199 protocol_system: request.protocol_system.clone(),
200 component_ids: request.component_ids.clone(),
201 tvl_gt: request.tvl_gt,
202 chain: request.chain,
203 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
204 };
205 let first_response = self
206 .get_protocol_components(&initial_request)
207 .await
208 .map_err(|err| RPCError::Fatal(err.to_string()))?;
209
210 let total_items = first_response.pagination.total;
211 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
212
213 let mut accumulated_response = ProtocolComponentRequestResponse {
215 protocol_components: first_response.protocol_components,
216 pagination: PaginationResponse {
217 page: 0,
218 page_size: chunk_size as i64,
219 total: total_items,
220 },
221 };
222
223 let mut page = 1;
224 while page < total_pages {
225 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
226
227 let chunked_bodies = (0..requests_in_this_iteration)
229 .map(|iter| ProtocolComponentsRequestBody {
230 protocol_system: request.protocol_system.clone(),
231 component_ids: request.component_ids.clone(),
232 tvl_gt: request.tvl_gt,
233 chain: request.chain,
234 pagination: PaginationParams {
235 page: page + iter,
236 page_size: chunk_size as i64,
237 },
238 })
239 .collect::<Vec<_>>();
240
241 let tasks: Vec<_> = chunked_bodies
242 .iter()
243 .map(|body| {
244 let sem = semaphore.clone();
245 async move {
246 let _permit = sem.acquire().await.map_err(|_| {
247 RPCError::Fatal("Semaphore dropped".to_string())
248 })?;
249 self.get_protocol_components(body).await
250 }
251 })
252 .collect();
253
254 let responses = try_join_all(tasks)
255 .await
256 .map(|responses| {
257 let total = responses[0].pagination.total;
258 ProtocolComponentRequestResponse {
259 protocol_components: responses
260 .into_iter()
261 .flat_map(|r| r.protocol_components.into_iter())
262 .collect(),
263 pagination: PaginationResponse {
264 page,
265 page_size: chunk_size as i64,
266 total,
267 },
268 }
269 });
270
271 match responses {
273 Ok(mut resp) => {
274 accumulated_response
275 .protocol_components
276 .append(&mut resp.protocol_components);
277 }
278 Err(e) => return Err(e),
279 }
280
281 page += concurrency as i64;
282 }
283 Ok(accumulated_response)
284 }
285 }
286 }
287
288 async fn get_protocol_states(
289 &self,
290 request: &ProtocolStateRequestBody,
291 ) -> Result<ProtocolStateRequestResponse, RPCError>;
292
293 #[allow(clippy::too_many_arguments)]
294 async fn get_protocol_states_paginated<T>(
295 &self,
296 chain: Chain,
297 ids: &[T],
298 protocol_system: &str,
299 include_balances: bool,
300 version: &VersionParam,
301 chunk_size: usize,
302 concurrency: usize,
303 ) -> Result<ProtocolStateRequestResponse, RPCError>
304 where
305 T: AsRef<str> + Sync + 'static,
306 {
307 let semaphore = Arc::new(Semaphore::new(concurrency));
308 let chunked_bodies = ids
309 .chunks(chunk_size)
310 .map(|c| ProtocolStateRequestBody {
311 protocol_ids: Some(
312 c.iter()
313 .map(|id| id.as_ref().to_string())
314 .collect(),
315 ),
316 protocol_system: protocol_system.to_string(),
317 chain,
318 include_balances,
319 version: version.clone(),
320 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
321 })
322 .collect::<Vec<_>>();
323
324 let mut tasks = Vec::new();
325 for body in chunked_bodies.iter() {
326 let sem = semaphore.clone();
327 tasks.push(async move {
328 let _permit = sem
329 .acquire()
330 .await
331 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
332 self.get_protocol_states(body).await
333 });
334 }
335
336 try_join_all(tasks)
337 .await
338 .map(|responses| {
339 let states = responses
340 .clone()
341 .into_iter()
342 .flat_map(|r| r.states)
343 .collect();
344 let total = responses
345 .iter()
346 .map(|r| r.pagination.total)
347 .sum();
348 ProtocolStateRequestResponse {
349 states,
350 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
351 }
352 })
353 }
354
355 async fn get_tokens(
358 &self,
359 request: &TokensRequestBody,
360 ) -> Result<TokensRequestResponse, RPCError>;
361
362 async fn get_all_tokens(
363 &self,
364 chain: Chain,
365 min_quality: Option<i32>,
366 traded_n_days_ago: Option<u64>,
367 chunk_size: usize,
368 ) -> Result<Vec<ResponseToken>, RPCError> {
369 let mut request_page = 0;
370 let mut all_tokens = Vec::new();
371 loop {
372 let mut response = self
373 .get_tokens(&TokensRequestBody {
374 token_addresses: None,
375 min_quality,
376 traded_n_days_ago,
377 pagination: PaginationParams {
378 page: request_page,
379 page_size: chunk_size.try_into().map_err(|_| {
380 RPCError::FormatRequest(
381 "Failed to convert chunk_size into i64".to_string(),
382 )
383 })?,
384 },
385 chain,
386 })
387 .await?;
388
389 let num_tokens = response.tokens.len();
390 all_tokens.append(&mut response.tokens);
391 request_page += 1;
392
393 if num_tokens < chunk_size {
394 break;
395 }
396 }
397 Ok(all_tokens)
398 }
399
400 async fn get_protocol_systems(
401 &self,
402 request: &ProtocolSystemsRequestBody,
403 ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
404
405 async fn get_component_tvl(
406 &self,
407 request: &ComponentTvlRequestBody,
408 ) -> Result<ComponentTvlRequestResponse, RPCError>;
409
410 async fn get_component_tvl_paginated(
411 &self,
412 request: &ComponentTvlRequestBody,
413 chunk_size: usize,
414 concurrency: usize,
415 ) -> Result<ComponentTvlRequestResponse, RPCError> {
416 let semaphore = Arc::new(Semaphore::new(concurrency));
417
418 match request.component_ids {
419 Some(ref ids) => {
420 let chunked_requests = ids
421 .chunks(chunk_size)
422 .enumerate()
423 .map(|(index, _)| ComponentTvlRequestBody {
424 chain: request.chain,
425 protocol_system: request.protocol_system.clone(),
426 component_ids: Some(ids.clone()),
427 pagination: PaginationParams {
428 page: index as i64,
429 page_size: chunk_size as i64,
430 },
431 })
432 .collect::<Vec<_>>();
433
434 let tasks: Vec<_> = chunked_requests
435 .into_iter()
436 .map(|req| {
437 let sem = semaphore.clone();
438 async move {
439 let _permit = sem
440 .acquire()
441 .await
442 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
443 self.get_component_tvl(&req).await
444 }
445 })
446 .collect();
447
448 let responses = try_join_all(tasks).await?;
449
450 let mut merged_tvl = HashMap::new();
451 for resp in responses {
452 for (key, value) in resp.tvl {
453 *merged_tvl.entry(key).or_insert(0.0) = value;
454 }
455 }
456
457 Ok(ComponentTvlRequestResponse {
458 tvl: merged_tvl,
459 pagination: PaginationResponse {
460 page: 0,
461 page_size: chunk_size as i64,
462 total: ids.len() as i64,
463 },
464 })
465 }
466 _ => {
467 let first_request = ComponentTvlRequestBody {
468 chain: request.chain,
469 protocol_system: request.protocol_system.clone(),
470 component_ids: request.component_ids.clone(),
471 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
472 };
473
474 let first_response = self
475 .get_component_tvl(&first_request)
476 .await?;
477 let total_items = first_response.pagination.total;
478 let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
479
480 let mut merged_tvl = first_response.tvl;
481
482 let mut page = 1;
483 while page < total_pages {
484 let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
485
486 let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
487 .map(|i| ComponentTvlRequestBody {
488 chain: request.chain,
489 protocol_system: request.protocol_system.clone(),
490 component_ids: request.component_ids.clone(),
491 pagination: PaginationParams {
492 page: page + i,
493 page_size: chunk_size as i64,
494 },
495 })
496 .collect();
497
498 let tasks: Vec<_> = chunked_requests
499 .into_iter()
500 .map(|req| {
501 let sem = semaphore.clone();
502 async move {
503 let _permit = sem.acquire().await.map_err(|_| {
504 RPCError::Fatal("Semaphore dropped".to_string())
505 })?;
506 self.get_component_tvl(&req).await
507 }
508 })
509 .collect();
510
511 let responses = try_join_all(tasks).await?;
512
513 for resp in responses {
515 for (key, value) in resp.tvl {
516 *merged_tvl.entry(key).or_insert(0.0) += value;
517 }
518 }
519
520 page += concurrency as i64;
521 }
522
523 Ok(ComponentTvlRequestResponse {
524 tvl: merged_tvl,
525 pagination: PaginationResponse {
526 page: 0,
527 page_size: chunk_size as i64,
528 total: total_items,
529 },
530 })
531 }
532 }
533 }
534
535 async fn get_traced_entry_points(
536 &self,
537 request: &TracedEntryPointRequestBody,
538 ) -> Result<TracedEntryPointRequestResponse, RPCError>;
539
540 async fn get_traced_entry_points_paginated(
541 &self,
542 chain: Chain,
543 protocol_system: &str,
544 component_ids: &[String],
545 chunk_size: usize,
546 concurrency: usize,
547 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
548 let semaphore = Arc::new(Semaphore::new(concurrency));
549 let chunked_bodies = component_ids
550 .chunks(chunk_size)
551 .map(|c| TracedEntryPointRequestBody {
552 chain,
553 protocol_system: protocol_system.to_string(),
554 component_ids: Some(c.to_vec()),
555 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
556 })
557 .collect::<Vec<_>>();
558
559 let mut tasks = Vec::new();
560 for body in chunked_bodies.iter() {
561 let sem = semaphore.clone();
562 tasks.push(async move {
563 let _permit = sem
564 .acquire()
565 .await
566 .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
567 self.get_traced_entry_points(body).await
568 });
569 }
570
571 try_join_all(tasks)
572 .await
573 .map(|responses| {
574 let traced_entry_points = responses
575 .clone()
576 .into_iter()
577 .flat_map(|r| r.traced_entry_points)
578 .collect();
579 let total = responses
580 .iter()
581 .map(|r| r.pagination.total)
582 .sum();
583 TracedEntryPointRequestResponse {
584 traced_entry_points,
585 pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
586 }
587 })
588 }
589}
590
591#[derive(Debug, Clone)]
592pub struct HttpRPCClient {
593 http_client: Client,
594 url: Url,
595 retry_after: Arc<RwLock<Option<SystemTime>>>,
596 backoff_policy: ExponentialBackoff,
597 server_restart_duration: Duration,
598}
599
600impl HttpRPCClient {
601 pub fn new(base_uri: &str, auth_key: Option<&str>) -> Result<Self, RPCError> {
602 let uri = base_uri
603 .parse::<Url>()
604 .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
605
606 let mut headers = header::HeaderMap::new();
608 headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
609 let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
610 headers.insert(
611 header::USER_AGENT,
612 header::HeaderValue::from_str(&user_agent)
613 .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
614 );
615
616 if let Some(key) = auth_key {
618 let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
619 RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
620 })?;
621 auth_value.set_sensitive(true);
622 headers.insert(header::AUTHORIZATION, auth_value);
623 }
624
625 let client = ClientBuilder::new()
626 .default_headers(headers)
627 .http2_prior_knowledge()
628 .build()
629 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
630 Ok(Self {
631 http_client: client,
632 url: uri,
633 retry_after: Arc::new(RwLock::new(None)),
634 backoff_policy: ExponentialBackoffBuilder::new()
635 .with_initial_interval(Duration::from_millis(250))
636 .with_multiplier(1.75)
638 .with_max_interval(Duration::from_secs(30))
640 .with_max_elapsed_time(Some(Duration::from_secs(125)))
642 .build(),
643 server_restart_duration: Duration::from_secs(120),
644 })
645 }
646
647 #[cfg(test)]
648 pub fn with_test_backoff_policy(mut self) -> Self {
649 self.backoff_policy = ExponentialBackoffBuilder::new()
651 .with_initial_interval(Duration::from_millis(1))
652 .with_multiplier(1.1)
653 .with_max_interval(Duration::from_millis(5))
654 .with_max_elapsed_time(Some(Duration::from_millis(50)))
655 .build();
656 self.server_restart_duration = Duration::from_millis(50);
657 self
658 }
659
660 async fn error_for_response(
666 &self,
667 response: reqwest::Response,
668 ) -> Result<reqwest::Response, RPCError> {
669 match response.status() {
670 StatusCode::TOO_MANY_REQUESTS => {
671 let retry_after_raw = response
672 .headers()
673 .get(reqwest::header::RETRY_AFTER)
674 .and_then(|h| h.to_str().ok())
675 .and_then(parse_retry_value);
676
677 Err(RPCError::RateLimited(retry_after_raw))
678 }
679 StatusCode::BAD_GATEWAY |
680 StatusCode::SERVICE_UNAVAILABLE |
681 StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
682 response
683 .text()
684 .await
685 .unwrap_or_else(|_| "Server Unreachable".to_string()),
686 )),
687 _ => Ok(response),
688 }
689 }
690
691 async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
697 match e {
698 RPCError::ServerUnreachable(_) => {
699 backoff::Error::retry_after(e, self.server_restart_duration)
700 }
701 RPCError::RateLimited(Some(until)) => {
702 let mut retry_after_guard = self.retry_after.write().await;
703 *retry_after_guard = Some(
704 retry_after_guard
705 .unwrap_or(until)
706 .max(until),
707 );
708
709 if let Ok(duration) = until.duration_since(SystemTime::now()) {
710 backoff::Error::retry_after(e, duration)
711 } else {
712 e.into()
713 }
714 }
715 RPCError::RateLimited(None) => e.into(),
716 _ => backoff::Error::permanent(e),
717 }
718 }
719
720 async fn wait_until_retry_after(&self) {
725 if let Some(&until) = self.retry_after.read().await.as_ref() {
726 let now = SystemTime::now();
727 if until > now {
728 if let Ok(duration) = until.duration_since(now) {
729 sleep(duration).await
730 }
731 }
732 }
733 }
734
735 async fn make_post_request<T: Serialize + ?Sized>(
740 &self,
741 request: &T,
742 uri: &String,
743 ) -> Result<Response, RPCError> {
744 self.wait_until_retry_after().await;
745 let response = backoff::future::retry(self.backoff_policy.clone(), || async {
746 let server_response = self
747 .http_client
748 .post(uri)
749 .json(request)
750 .send()
751 .await
752 .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
753
754 match self
755 .error_for_response(server_response)
756 .await
757 {
758 Ok(response) => Ok(response),
759 Err(e) => Err(self.handle_error_for_backoff(e).await),
760 }
761 })
762 .await?;
763 Ok(response)
764 }
765}
766
767fn parse_retry_value(val: &str) -> Option<SystemTime> {
768 if let Ok(secs) = val.parse::<u64>() {
769 return Some(SystemTime::now() + Duration::from_secs(secs));
770 }
771 if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
772 return Some(date.into());
773 }
774 None
775}
776
777#[async_trait]
778impl RPCClient for HttpRPCClient {
779 #[instrument(skip(self, request))]
780 async fn get_contract_state(
781 &self,
782 request: &StateRequestBody,
783 ) -> Result<StateRequestResponse, RPCError> {
784 if request
786 .contract_ids
787 .as_ref()
788 .is_none_or(|ids| ids.is_empty())
789 {
790 warn!("No contract ids specified in request.");
791 }
792
793 let uri = format!(
794 "{}/{}/contract_state",
795 self.url
796 .to_string()
797 .trim_end_matches('/'),
798 TYCHO_SERVER_VERSION
799 );
800 debug!(%uri, "Sending contract_state request to Tycho server");
801 trace!(?request, "Sending request to Tycho server");
802 let response = self
803 .make_post_request(request, &uri)
804 .await?;
805 trace!(?response, "Received response from Tycho server");
806
807 let body = response
808 .text()
809 .await
810 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
811 if body.is_empty() {
812 return Ok(StateRequestResponse {
814 accounts: vec![],
815 pagination: PaginationResponse {
816 page: request.pagination.page,
817 page_size: request.pagination.page,
818 total: 0,
819 },
820 });
821 }
822
823 let accounts = serde_json::from_str::<StateRequestResponse>(&body)
824 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
825 trace!(?accounts, "Received contract_state response from Tycho server");
826
827 Ok(accounts)
828 }
829
830 async fn get_protocol_components(
831 &self,
832 request: &ProtocolComponentsRequestBody,
833 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
834 let uri = format!(
835 "{}/{}/protocol_components",
836 self.url
837 .to_string()
838 .trim_end_matches('/'),
839 TYCHO_SERVER_VERSION,
840 );
841 debug!(%uri, "Sending protocol_components request to Tycho server");
842 trace!(?request, "Sending request to Tycho server");
843
844 let response = self
845 .make_post_request(request, &uri)
846 .await?;
847
848 trace!(?response, "Received response from Tycho server");
849
850 let body = response
851 .text()
852 .await
853 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
854 let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
855 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
856 trace!(?components, "Received protocol_components response from Tycho server");
857
858 Ok(components)
859 }
860
861 async fn get_protocol_states(
862 &self,
863 request: &ProtocolStateRequestBody,
864 ) -> Result<ProtocolStateRequestResponse, RPCError> {
865 if request
867 .protocol_ids
868 .as_ref()
869 .is_none_or(|ids| ids.is_empty())
870 {
871 warn!("No protocol ids specified in request.");
872 }
873
874 let uri = format!(
875 "{}/{}/protocol_state",
876 self.url
877 .to_string()
878 .trim_end_matches('/'),
879 TYCHO_SERVER_VERSION
880 );
881 debug!(%uri, "Sending protocol_states request to Tycho server");
882 trace!(?request, "Sending request to Tycho server");
883
884 let response = self
885 .make_post_request(request, &uri)
886 .await?;
887 trace!(?response, "Received response from Tycho server");
888
889 let body = response
890 .text()
891 .await
892 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
893
894 if body.is_empty() {
895 return Ok(ProtocolStateRequestResponse {
897 states: vec![],
898 pagination: PaginationResponse {
899 page: request.pagination.page,
900 page_size: request.pagination.page_size,
901 total: 0,
902 },
903 });
904 }
905
906 let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
907 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
908 trace!(?states, "Received protocol_states response from Tycho server");
909
910 Ok(states)
911 }
912
913 async fn get_tokens(
914 &self,
915 request: &TokensRequestBody,
916 ) -> Result<TokensRequestResponse, RPCError> {
917 let uri = format!(
918 "{}/{}/tokens",
919 self.url
920 .to_string()
921 .trim_end_matches('/'),
922 TYCHO_SERVER_VERSION
923 );
924 debug!(%uri, "Sending tokens request to Tycho server");
925
926 let response = self
927 .make_post_request(request, &uri)
928 .await?;
929
930 let body = response
931 .text()
932 .await
933 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
934 let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
935 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
936
937 Ok(tokens)
938 }
939
940 async fn get_protocol_systems(
941 &self,
942 request: &ProtocolSystemsRequestBody,
943 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
944 let uri = format!(
945 "{}/{}/protocol_systems",
946 self.url
947 .to_string()
948 .trim_end_matches('/'),
949 TYCHO_SERVER_VERSION
950 );
951 debug!(%uri, "Sending protocol_systems request to Tycho server");
952 trace!(?request, "Sending request to Tycho server");
953 let response = self
954 .make_post_request(request, &uri)
955 .await?;
956 trace!(?response, "Received response from Tycho server");
957 let body = response
958 .text()
959 .await
960 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
961 let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
962 .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
963 trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
964 Ok(protocol_systems)
965 }
966
967 async fn get_component_tvl(
968 &self,
969 request: &ComponentTvlRequestBody,
970 ) -> Result<ComponentTvlRequestResponse, RPCError> {
971 let uri = format!(
972 "{}/{}/component_tvl",
973 self.url
974 .to_string()
975 .trim_end_matches('/'),
976 TYCHO_SERVER_VERSION
977 );
978 debug!(%uri, "Sending get_component_tvl request to Tycho server");
979 trace!(?request, "Sending request to Tycho server");
980 let response = self
981 .make_post_request(request, &uri)
982 .await?;
983 trace!(?response, "Received response from Tycho server");
984 let body = response
985 .text()
986 .await
987 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
988 let component_tvl =
989 serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
990 error!("Failed to parse component_tvl response: {:?}", &body);
991 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
992 })?;
993 trace!(?component_tvl, "Received component_tvl response from Tycho server");
994 Ok(component_tvl)
995 }
996
997 async fn get_traced_entry_points(
998 &self,
999 request: &TracedEntryPointRequestBody,
1000 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1001 let uri = format!(
1002 "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1003 self.url
1004 .to_string()
1005 .trim_end_matches('/')
1006 );
1007 debug!(%uri, "Sending traced_entry_points request to Tycho server");
1008 trace!(?request, "Sending request to Tycho server");
1009
1010 let response = self
1011 .make_post_request(request, &uri)
1012 .await?;
1013
1014 trace!(?response, "Received response from Tycho server");
1015
1016 let body = response
1017 .text()
1018 .await
1019 .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1020 let entrypoints =
1021 serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1022 error!("Failed to parse traced_entry_points response: {:?}", &body);
1023 RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1024 })?;
1025 trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1026 Ok(entrypoints)
1027 }
1028}
1029
1030#[cfg(test)]
1031mod tests {
1032 use std::{
1033 collections::{HashMap, HashSet},
1034 str::FromStr,
1035 };
1036
1037 use mockito::Server;
1038 use rstest::rstest;
1039 #[allow(deprecated)]
1041 use tycho_common::dto::ProtocolId;
1042 use tycho_common::dto::{AddressStorageLocation, TracingParams};
1043
1044 use super::*;
1045
1046 impl MockRPCClient {
1049 #[allow(clippy::too_many_arguments)]
1050 async fn test_get_protocol_states_paginated<T>(
1051 &self,
1052 chain: Chain,
1053 ids: &[T],
1054 protocol_system: &str,
1055 include_balances: bool,
1056 version: &VersionParam,
1057 chunk_size: usize,
1058 _concurrency: usize,
1059 ) -> Vec<ProtocolStateRequestBody>
1060 where
1061 T: AsRef<str> + Clone + Send + Sync + 'static,
1062 {
1063 ids.chunks(chunk_size)
1064 .map(|chunk| ProtocolStateRequestBody {
1065 protocol_ids: Some(
1066 chunk
1067 .iter()
1068 .map(|id| id.as_ref().to_string())
1069 .collect(),
1070 ),
1071 protocol_system: protocol_system.to_string(),
1072 chain,
1073 include_balances,
1074 version: version.clone(),
1075 pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1076 })
1077 .collect()
1078 }
1079 }
1080
1081 #[allow(deprecated)]
1083 #[rstest]
1084 #[case::protocol_id_input(vec![
1085 ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1086 ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1087 ])]
1088 #[case::string_input(vec![
1089 "id1".to_string(),
1090 "id2".to_string()
1091 ])]
1092 #[tokio::test]
1093 async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1094 where
1095 T: AsRef<str> + Clone + Send + Sync + 'static,
1096 {
1097 let mock_client = MockRPCClient::new();
1098
1099 let request_bodies = mock_client
1100 .test_get_protocol_states_paginated(
1101 Chain::Ethereum,
1102 &ids,
1103 "test_system",
1104 true,
1105 &VersionParam::default(),
1106 2,
1107 2,
1108 )
1109 .await;
1110
1111 assert_eq!(request_bodies.len(), 1);
1113 assert_eq!(
1114 request_bodies[0]
1115 .protocol_ids
1116 .as_ref()
1117 .unwrap()
1118 .len(),
1119 2
1120 );
1121 }
1122
1123 #[tokio::test]
1124 async fn test_get_contract_state() {
1125 let mut server = Server::new_async().await;
1126 let server_resp = r#"
1127 {
1128 "accounts": [
1129 {
1130 "chain": "ethereum",
1131 "address": "0x0000000000000000000000000000000000000000",
1132 "title": "",
1133 "slots": {},
1134 "native_balance": "0x01f4",
1135 "token_balances": {},
1136 "code": "0x00",
1137 "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
1138 "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1139 "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1140 "creation_tx": null
1141 }
1142 ],
1143 "pagination": {
1144 "page": 0,
1145 "page_size": 20,
1146 "total": 10
1147 }
1148 }
1149 "#;
1150 serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1152
1153 let mocked_server = server
1154 .mock("POST", "/v1/contract_state")
1155 .expect(1)
1156 .with_body(server_resp)
1157 .create_async()
1158 .await;
1159
1160 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1161
1162 let response = client
1163 .get_contract_state(&Default::default())
1164 .await
1165 .expect("get state");
1166 let accounts = response.accounts;
1167
1168 mocked_server.assert();
1169 assert_eq!(accounts.len(), 1);
1170 assert_eq!(accounts[0].slots, HashMap::new());
1171 assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1172 assert_eq!(accounts[0].code, [0].to_vec());
1173 assert_eq!(
1174 accounts[0].code_hash,
1175 hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1176 .unwrap()
1177 );
1178 }
1179
1180 #[tokio::test]
1181 async fn test_get_protocol_components() {
1182 let mut server = Server::new_async().await;
1183 let server_resp = r#"
1184 {
1185 "protocol_components": [
1186 {
1187 "id": "State1",
1188 "protocol_system": "ambient",
1189 "protocol_type_name": "Pool",
1190 "chain": "ethereum",
1191 "tokens": [
1192 "0x0000000000000000000000000000000000000000",
1193 "0x0000000000000000000000000000000000000001"
1194 ],
1195 "contract_ids": [
1196 "0x0000000000000000000000000000000000000000"
1197 ],
1198 "static_attributes": {
1199 "attribute_1": "0x00000000000003e8"
1200 },
1201 "change": "Creation",
1202 "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1203 "created_at": "2022-01-01T00:00:00"
1204 }
1205 ],
1206 "pagination": {
1207 "page": 0,
1208 "page_size": 20,
1209 "total": 10
1210 }
1211 }
1212 "#;
1213 serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1215
1216 let mocked_server = server
1217 .mock("POST", "/v1/protocol_components")
1218 .expect(1)
1219 .with_body(server_resp)
1220 .create_async()
1221 .await;
1222
1223 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1224
1225 let response = client
1226 .get_protocol_components(&Default::default())
1227 .await
1228 .expect("get state");
1229 let components = response.protocol_components;
1230
1231 mocked_server.assert();
1232 assert_eq!(components.len(), 1);
1233 assert_eq!(components[0].id, "State1");
1234 assert_eq!(components[0].protocol_system, "ambient");
1235 assert_eq!(components[0].protocol_type_name, "Pool");
1236 assert_eq!(components[0].tokens.len(), 2);
1237 let expected_attributes =
1238 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1239 .iter()
1240 .cloned()
1241 .collect::<HashMap<String, Bytes>>();
1242 assert_eq!(components[0].static_attributes, expected_attributes);
1243 }
1244
1245 #[tokio::test]
1246 async fn test_get_protocol_states() {
1247 let mut server = Server::new_async().await;
1248 let server_resp = r#"
1249 {
1250 "states": [
1251 {
1252 "component_id": "State1",
1253 "attributes": {
1254 "attribute_1": "0x00000000000003e8"
1255 },
1256 "balances": {
1257 "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1258 }
1259 }
1260 ],
1261 "pagination": {
1262 "page": 0,
1263 "page_size": 20,
1264 "total": 10
1265 }
1266 }
1267 "#;
1268 serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1270
1271 let mocked_server = server
1272 .mock("POST", "/v1/protocol_state")
1273 .expect(1)
1274 .with_body(server_resp)
1275 .create_async()
1276 .await;
1277 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1278
1279 let response = client
1280 .get_protocol_states(&Default::default())
1281 .await
1282 .expect("get state");
1283 let states = response.states;
1284
1285 mocked_server.assert();
1286 assert_eq!(states.len(), 1);
1287 assert_eq!(states[0].component_id, "State1");
1288 let expected_attributes =
1289 [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1290 .iter()
1291 .cloned()
1292 .collect::<HashMap<String, Bytes>>();
1293 assert_eq!(states[0].attributes, expected_attributes);
1294 let expected_balances = [(
1295 Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1296 .expect("Unsupported address format"),
1297 Bytes::from_str("0x01f4").unwrap(),
1298 )]
1299 .iter()
1300 .cloned()
1301 .collect::<HashMap<Bytes, Bytes>>();
1302 assert_eq!(states[0].balances, expected_balances);
1303 }
1304
1305 #[tokio::test]
1306 async fn test_get_tokens() {
1307 let mut server = Server::new_async().await;
1308 let server_resp = r#"
1309 {
1310 "tokens": [
1311 {
1312 "chain": "ethereum",
1313 "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1314 "symbol": "WETH",
1315 "decimals": 18,
1316 "tax": 0,
1317 "gas": [
1318 29962
1319 ],
1320 "quality": 100
1321 },
1322 {
1323 "chain": "ethereum",
1324 "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1325 "symbol": "USDC",
1326 "decimals": 6,
1327 "tax": 0,
1328 "gas": [
1329 40652
1330 ],
1331 "quality": 100
1332 }
1333 ],
1334 "pagination": {
1335 "page": 0,
1336 "page_size": 20,
1337 "total": 10
1338 }
1339 }
1340 "#;
1341 serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1343
1344 let mocked_server = server
1345 .mock("POST", "/v1/tokens")
1346 .expect(1)
1347 .with_body(server_resp)
1348 .create_async()
1349 .await;
1350 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1351
1352 let response = client
1353 .get_tokens(&Default::default())
1354 .await
1355 .expect("get tokens");
1356
1357 let expected = vec![
1358 ResponseToken {
1359 chain: Chain::Ethereum,
1360 address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1361 symbol: "WETH".to_string(),
1362 decimals: 18,
1363 tax: 0,
1364 gas: vec![Some(29962)],
1365 quality: 100,
1366 },
1367 ResponseToken {
1368 chain: Chain::Ethereum,
1369 address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1370 symbol: "USDC".to_string(),
1371 decimals: 6,
1372 tax: 0,
1373 gas: vec![Some(40652)],
1374 quality: 100,
1375 },
1376 ];
1377
1378 mocked_server.assert();
1379 assert_eq!(response.tokens, expected);
1380 assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1381 }
1382
1383 #[tokio::test]
1384 async fn test_get_protocol_systems() {
1385 let mut server = Server::new_async().await;
1386 let server_resp = r#"
1387 {
1388 "protocol_systems": [
1389 "system1",
1390 "system2"
1391 ],
1392 "pagination": {
1393 "page": 0,
1394 "page_size": 20,
1395 "total": 10
1396 }
1397 }
1398 "#;
1399 serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1401
1402 let mocked_server = server
1403 .mock("POST", "/v1/protocol_systems")
1404 .expect(1)
1405 .with_body(server_resp)
1406 .create_async()
1407 .await;
1408 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1409
1410 let response = client
1411 .get_protocol_systems(&Default::default())
1412 .await
1413 .expect("get protocol systems");
1414 let protocol_systems = response.protocol_systems;
1415
1416 mocked_server.assert();
1417 assert_eq!(protocol_systems, vec!["system1", "system2"]);
1418 }
1419
1420 #[tokio::test]
1421 async fn test_get_component_tvl() {
1422 let mut server = Server::new_async().await;
1423 let server_resp = r#"
1424 {
1425 "tvl": {
1426 "component1": 100.0
1427 },
1428 "pagination": {
1429 "page": 0,
1430 "page_size": 20,
1431 "total": 10
1432 }
1433 }
1434 "#;
1435 serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1437
1438 let mocked_server = server
1439 .mock("POST", "/v1/component_tvl")
1440 .expect(1)
1441 .with_body(server_resp)
1442 .create_async()
1443 .await;
1444 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1445
1446 let response = client
1447 .get_component_tvl(&Default::default())
1448 .await
1449 .expect("get protocol systems");
1450 let component_tvl = response.tvl;
1451
1452 mocked_server.assert();
1453 assert_eq!(component_tvl.get("component1"), Some(&100.0));
1454 }
1455
1456 #[tokio::test]
1457 async fn test_get_traced_entry_points() {
1458 let mut server = Server::new_async().await;
1459 let server_resp = r#"
1460 {
1461 "traced_entry_points": {
1462 "component_1": [
1463 [
1464 {
1465 "entry_point": {
1466 "external_id": "entrypoint_a",
1467 "target": "0x0000000000000000000000000000000000000001",
1468 "signature": "sig()"
1469 },
1470 "params": {
1471 "method": "rpctracer",
1472 "caller": "0x000000000000000000000000000000000000000a",
1473 "calldata": "0x000000000000000000000000000000000000000b"
1474 }
1475 },
1476 {
1477 "retriggers": [
1478 [
1479 "0x00000000000000000000000000000000000000aa",
1480 {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1481 ]
1482 ],
1483 "accessed_slots": {
1484 "0x0000000000000000000000000000000000aaaa": [
1485 "0x0000000000000000000000000000000000aaaa"
1486 ]
1487 }
1488 }
1489 ]
1490 ]
1491 },
1492 "pagination": {
1493 "page": 0,
1494 "page_size": 20,
1495 "total": 1
1496 }
1497 }
1498 "#;
1499 serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1501
1502 let mocked_server = server
1503 .mock("POST", "/v1/traced_entry_points")
1504 .expect(1)
1505 .with_body(server_resp)
1506 .create_async()
1507 .await;
1508 let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1509
1510 let response = client
1511 .get_traced_entry_points(&Default::default())
1512 .await
1513 .expect("get traced entry points");
1514 let entrypoints = response.traced_entry_points;
1515
1516 mocked_server.assert();
1517 assert_eq!(entrypoints.len(), 1);
1518 let comp1_entrypoints = entrypoints
1519 .get("component_1")
1520 .expect("component_1 entrypoints should exist");
1521 assert_eq!(comp1_entrypoints.len(), 1);
1522
1523 let (entrypoint, trace_result) = &comp1_entrypoints[0];
1524 assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1525 assert_eq!(
1526 entrypoint.entry_point.target,
1527 Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1528 );
1529 assert_eq!(entrypoint.entry_point.signature, "sig()");
1530 let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1531 assert_eq!(
1532 rpc_params.caller,
1533 Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1534 );
1535 assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1536
1537 assert_eq!(
1538 trace_result.retriggers,
1539 HashSet::from([(
1540 Bytes::from("0x00000000000000000000000000000000000000aa"),
1541 AddressStorageLocation::new(
1542 Bytes::from("0x0000000000000000000000000000000000000aaa"),
1543 12
1544 )
1545 )])
1546 );
1547 assert_eq!(trace_result.accessed_slots.len(), 1);
1548 assert_eq!(
1549 trace_result.accessed_slots,
1550 HashMap::from([(
1551 Bytes::from("0x0000000000000000000000000000000000aaaa"),
1552 HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1553 )])
1554 );
1555 }
1556
1557 #[tokio::test]
1558 async fn test_parse_retry_value_numeric() {
1559 let result = parse_retry_value("60");
1560 assert!(result.is_some());
1561
1562 let expected_time = SystemTime::now() + Duration::from_secs(60);
1563 let actual_time = result.unwrap();
1564
1565 let diff = if actual_time > expected_time {
1567 actual_time
1568 .duration_since(expected_time)
1569 .unwrap()
1570 } else {
1571 expected_time
1572 .duration_since(actual_time)
1573 .unwrap()
1574 };
1575 assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1576 }
1577
1578 #[tokio::test]
1579 async fn test_parse_retry_value_rfc2822() {
1580 let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1582 let result = parse_retry_value(rfc2822_date);
1583 assert!(result.is_some());
1584
1585 let parsed_time = result.unwrap();
1586 assert!(parsed_time > SystemTime::now());
1587 }
1588
1589 #[tokio::test]
1590 async fn test_parse_retry_value_invalid_formats() {
1591 assert!(parse_retry_value("invalid").is_none());
1593 assert!(parse_retry_value("").is_none());
1594 assert!(parse_retry_value("not_a_number").is_none());
1595 assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); }
1597
1598 #[tokio::test]
1599 async fn test_parse_retry_value_zero_seconds() {
1600 let result = parse_retry_value("0");
1601 assert!(result.is_some());
1602
1603 let expected_time = SystemTime::now();
1604 let actual_time = result.unwrap();
1605
1606 let diff = if actual_time > expected_time {
1608 actual_time
1609 .duration_since(expected_time)
1610 .unwrap()
1611 } else {
1612 expected_time
1613 .duration_since(actual_time)
1614 .unwrap()
1615 };
1616 assert!(diff < Duration::from_secs(1));
1617 }
1618
1619 #[tokio::test]
1620 async fn test_error_for_response_rate_limited() {
1621 let mut server = Server::new_async().await;
1622 let mock = server
1623 .mock("GET", "/test")
1624 .with_status(429)
1625 .with_header("Retry-After", "60")
1626 .create_async()
1627 .await;
1628
1629 let client = reqwest::Client::new();
1630 let response = client
1631 .get(format!("{}/test", server.url()))
1632 .send()
1633 .await
1634 .unwrap();
1635
1636 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1637 .unwrap()
1638 .with_test_backoff_policy();
1639 let result = http_client
1640 .error_for_response(response)
1641 .await;
1642
1643 mock.assert();
1644 assert!(matches!(result, Err(RPCError::RateLimited(_))));
1645 if let Err(RPCError::RateLimited(retry_after)) = result {
1646 assert!(retry_after.is_some());
1647 }
1648 }
1649
1650 #[tokio::test]
1651 async fn test_error_for_response_rate_limited_no_header() {
1652 let mut server = Server::new_async().await;
1653 let mock = server
1654 .mock("GET", "/test")
1655 .with_status(429)
1656 .create_async()
1657 .await;
1658
1659 let client = reqwest::Client::new();
1660 let response = client
1661 .get(format!("{}/test", server.url()))
1662 .send()
1663 .await
1664 .unwrap();
1665
1666 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1667 .unwrap()
1668 .with_test_backoff_policy();
1669 let result = http_client
1670 .error_for_response(response)
1671 .await;
1672
1673 mock.assert();
1674 assert!(matches!(result, Err(RPCError::RateLimited(None))));
1675 }
1676
1677 #[tokio::test]
1678 async fn test_error_for_response_server_errors() {
1679 let test_cases =
1680 vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1681
1682 for (status_code, expected_body) in test_cases {
1683 let mut server = Server::new_async().await;
1684 let mock = server
1685 .mock("GET", "/test")
1686 .with_status(status_code)
1687 .with_body(expected_body)
1688 .create_async()
1689 .await;
1690
1691 let client = reqwest::Client::new();
1692 let response = client
1693 .get(format!("{}/test", server.url()))
1694 .send()
1695 .await
1696 .unwrap();
1697
1698 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1699 .unwrap()
1700 .with_test_backoff_policy();
1701 let result = http_client
1702 .error_for_response(response)
1703 .await;
1704
1705 mock.assert();
1706 assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
1707 if let Err(RPCError::ServerUnreachable(body)) = result {
1708 assert_eq!(body, expected_body);
1709 }
1710 }
1711 }
1712
1713 #[tokio::test]
1714 async fn test_error_for_response_success() {
1715 let mut server = Server::new_async().await;
1716 let mock = server
1717 .mock("GET", "/test")
1718 .with_status(200)
1719 .with_body("success")
1720 .create_async()
1721 .await;
1722
1723 let client = reqwest::Client::new();
1724 let response = client
1725 .get(format!("{}/test", server.url()))
1726 .send()
1727 .await
1728 .unwrap();
1729
1730 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1731 .unwrap()
1732 .with_test_backoff_policy();
1733 let result = http_client
1734 .error_for_response(response)
1735 .await;
1736
1737 mock.assert();
1738 assert!(result.is_ok());
1739
1740 let response = result.unwrap();
1741 assert_eq!(response.status(), 200);
1742 }
1743
1744 #[tokio::test]
1745 async fn test_handle_error_for_backoff_server_unreachable() {
1746 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1747 .unwrap()
1748 .with_test_backoff_policy();
1749 let error = RPCError::ServerUnreachable("Service down".to_string());
1750
1751 let backoff_error = http_client
1752 .handle_error_for_backoff(error)
1753 .await;
1754
1755 match backoff_error {
1756 backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
1757 assert_eq!(msg, "Service down");
1758 assert_eq!(retry_after, Some(Duration::from_millis(50))); }
1760 _ => panic!("Expected transient error for ServerUnreachable"),
1761 }
1762 }
1763
1764 #[tokio::test]
1765 async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
1766 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1767 .unwrap()
1768 .with_test_backoff_policy();
1769 let future_time = SystemTime::now() + Duration::from_secs(30);
1770 let error = RPCError::RateLimited(Some(future_time));
1771
1772 let backoff_error = http_client
1773 .handle_error_for_backoff(error)
1774 .await;
1775
1776 match backoff_error {
1777 backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
1778 assert_eq!(retry_after, Some(future_time));
1779 }
1780 _ => panic!("Expected transient error for RateLimited"),
1781 }
1782
1783 let stored_retry_after = http_client.retry_after.read().await;
1785 assert_eq!(*stored_retry_after, Some(future_time));
1786 }
1787
1788 #[tokio::test]
1789 async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
1790 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1791 .unwrap()
1792 .with_test_backoff_policy();
1793 let error = RPCError::RateLimited(None);
1794
1795 let backoff_error = http_client
1796 .handle_error_for_backoff(error)
1797 .await;
1798
1799 match backoff_error {
1800 backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
1801 }
1803 _ => panic!("Expected transient error for RateLimited without retry-after"),
1804 }
1805 }
1806
1807 #[tokio::test]
1808 async fn test_handle_error_for_backoff_other_errors() {
1809 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1810 .unwrap()
1811 .with_test_backoff_policy();
1812 let error = RPCError::ParseResponse("Invalid JSON".to_string());
1813
1814 let backoff_error = http_client
1815 .handle_error_for_backoff(error)
1816 .await;
1817
1818 match backoff_error {
1819 backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
1820 assert_eq!(msg, "Invalid JSON");
1821 }
1822 _ => panic!("Expected permanent error for ParseResponse"),
1823 }
1824 }
1825
1826 #[tokio::test]
1827 async fn test_wait_until_retry_after_no_retry_time() {
1828 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1829 .unwrap()
1830 .with_test_backoff_policy();
1831
1832 let start = std::time::Instant::now();
1833 http_client
1834 .wait_until_retry_after()
1835 .await;
1836 let elapsed = start.elapsed();
1837
1838 assert!(elapsed < Duration::from_millis(100));
1840 }
1841
1842 #[tokio::test]
1843 async fn test_wait_until_retry_after_past_time() {
1844 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1845 .unwrap()
1846 .with_test_backoff_policy();
1847
1848 let past_time = SystemTime::now() - Duration::from_secs(10);
1850 *http_client.retry_after.write().await = Some(past_time);
1851
1852 let start = std::time::Instant::now();
1853 http_client
1854 .wait_until_retry_after()
1855 .await;
1856 let elapsed = start.elapsed();
1857
1858 assert!(elapsed < Duration::from_millis(100));
1860 }
1861
1862 #[tokio::test]
1863 async fn test_wait_until_retry_after_future_time() {
1864 let http_client = HttpRPCClient::new("http://localhost:8080", None)
1865 .unwrap()
1866 .with_test_backoff_policy();
1867
1868 let future_time = SystemTime::now() + Duration::from_millis(100);
1870 *http_client.retry_after.write().await = Some(future_time);
1871
1872 let start = std::time::Instant::now();
1873 http_client
1874 .wait_until_retry_after()
1875 .await;
1876 let elapsed = start.elapsed();
1877
1878 assert!(elapsed >= Duration::from_millis(80)); assert!(elapsed <= Duration::from_millis(200)); }
1882
1883 #[tokio::test]
1884 async fn test_make_post_request_success() {
1885 let mut server = Server::new_async().await;
1886 let server_resp = r#"{"success": true}"#;
1887
1888 let mock = server
1889 .mock("POST", "/test")
1890 .with_status(200)
1891 .with_body(server_resp)
1892 .create_async()
1893 .await;
1894
1895 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1896 .unwrap()
1897 .with_test_backoff_policy();
1898 let request_body = serde_json::json!({"test": "data"});
1899 let uri = format!("{}/test", server.url());
1900
1901 let result = http_client
1902 .make_post_request(&request_body, &uri)
1903 .await;
1904
1905 mock.assert();
1906 assert!(result.is_ok());
1907
1908 let response = result.unwrap();
1909 assert_eq!(response.status(), 200);
1910 assert_eq!(response.text().await.unwrap(), server_resp);
1911 }
1912
1913 #[tokio::test]
1914 async fn test_make_post_request_retry_on_server_error() {
1915 let mut server = Server::new_async().await;
1916 let error_mock = server
1918 .mock("POST", "/test")
1919 .with_status(503)
1920 .with_body("Service Unavailable")
1921 .expect(1)
1922 .create_async()
1923 .await;
1924
1925 let success_mock = server
1926 .mock("POST", "/test")
1927 .with_status(200)
1928 .with_body(r#"{"success": true}"#)
1929 .expect(1)
1930 .create_async()
1931 .await;
1932
1933 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1934 .unwrap()
1935 .with_test_backoff_policy();
1936 let request_body = serde_json::json!({"test": "data"});
1937 let uri = format!("{}/test", server.url());
1938
1939 let result = http_client
1940 .make_post_request(&request_body, &uri)
1941 .await;
1942
1943 error_mock.assert();
1944 success_mock.assert();
1945 assert!(result.is_ok());
1946 }
1947
1948 #[tokio::test]
1949 async fn test_make_post_request_respect_retry_after_header() {
1950 let mut server = Server::new_async().await;
1951
1952 let rate_limit_mock = server
1954 .mock("POST", "/test")
1955 .with_status(429)
1956 .with_header("Retry-After", "1") .expect(1)
1958 .create_async()
1959 .await;
1960
1961 let success_mock = server
1962 .mock("POST", "/test")
1963 .with_status(200)
1964 .with_body(r#"{"success": true}"#)
1965 .expect(1)
1966 .create_async()
1967 .await;
1968
1969 let http_client = HttpRPCClient::new(server.url().as_str(), None)
1970 .unwrap()
1971 .with_test_backoff_policy();
1972 let request_body = serde_json::json!({"test": "data"});
1973 let uri = format!("{}/test", server.url());
1974
1975 let start = std::time::Instant::now();
1976 let result = http_client
1977 .make_post_request(&request_body, &uri)
1978 .await;
1979 let elapsed = start.elapsed();
1980
1981 rate_limit_mock.assert();
1982 success_mock.assert();
1983 assert!(result.is_ok());
1984
1985 assert!(elapsed >= Duration::from_millis(900)); assert!(elapsed <= Duration::from_millis(2000)); }
1989
1990 #[tokio::test]
1991 async fn test_make_post_request_permanent_error() {
1992 let mut server = Server::new_async().await;
1993
1994 let mock = server
1995 .mock("POST", "/test")
1996 .with_status(400) .with_body("Bad Request")
1998 .expect(1)
1999 .create_async()
2000 .await;
2001
2002 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2003 .unwrap()
2004 .with_test_backoff_policy();
2005 let request_body = serde_json::json!({"test": "data"});
2006 let uri = format!("{}/test", server.url());
2007
2008 let result = http_client
2009 .make_post_request(&request_body, &uri)
2010 .await;
2011
2012 mock.assert();
2013 assert!(result.is_ok()); let response = result.unwrap();
2016 assert_eq!(response.status(), 400);
2017 }
2018
2019 #[tokio::test]
2020 async fn test_concurrent_requests_with_different_retry_after() {
2021 let mut server = Server::new_async().await;
2022
2023 let rate_limit_mock_1 = server
2025 .mock("POST", "/test1")
2026 .with_status(429)
2027 .with_header("Retry-After", "1")
2028 .expect(1)
2029 .create_async()
2030 .await;
2031
2032 let rate_limit_mock_2 = server
2034 .mock("POST", "/test2")
2035 .with_status(429)
2036 .with_header("Retry-After", "2")
2037 .expect(1)
2038 .create_async()
2039 .await;
2040
2041 let success_mock_1 = server
2043 .mock("POST", "/test1")
2044 .with_status(200)
2045 .with_body(r#"{"result": "success1"}"#)
2046 .expect(1)
2047 .create_async()
2048 .await;
2049
2050 let success_mock_2 = server
2051 .mock("POST", "/test2")
2052 .with_status(200)
2053 .with_body(r#"{"result": "success2"}"#)
2054 .expect(1)
2055 .create_async()
2056 .await;
2057
2058 let http_client = HttpRPCClient::new(server.url().as_str(), None)
2059 .unwrap()
2060 .with_test_backoff_policy();
2061 let request_body = serde_json::json!({"test": "data"});
2062
2063 let uri1 = format!("{}/test1", server.url());
2064 let uri2 = format!("{}/test2", server.url());
2065
2066 let start = std::time::Instant::now();
2068 let (result1, result2) = tokio::join!(
2069 http_client.make_post_request(&request_body, &uri1),
2070 http_client.make_post_request(&request_body, &uri2)
2071 );
2072 let elapsed = start.elapsed();
2073
2074 rate_limit_mock_1.assert();
2075 rate_limit_mock_2.assert();
2076 success_mock_1.assert();
2077 success_mock_2.assert();
2078
2079 assert!(result1.is_ok());
2080 assert!(result2.is_ok());
2081
2082 assert!(elapsed >= Duration::from_millis(1800)); assert!(elapsed <= Duration::from_millis(3000)); let final_retry_after = http_client.retry_after.read().await;
2090 assert!(final_retry_after.is_some());
2091
2092 if let Some(retry_time) = *final_retry_after {
2094 let now = SystemTime::now();
2097 let diff = if retry_time > now {
2098 retry_time.duration_since(now).unwrap()
2099 } else {
2100 now.duration_since(retry_time).unwrap()
2101 };
2102
2103 assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2105 }
2106 }
2107}