tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
3//! The objective of this module is to provide swift and simplified access to the Remote Procedure
4//! Call (RPC) endpoints of Tycho. These endpoints are chiefly responsible for facilitating data
5//! queries, especially querying snapshots of data.
6use std::{
7    collections::HashMap,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22    sync::{RwLock, Semaphore},
23    time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27    dto::{
28        BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29        EntryPointWithTracingParams, PaginationParams, PaginationResponse, ProtocolComponent,
30        ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
31        ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
32        ResponseToken, StateRequestBody, StateRequestResponse, TokensRequestBody,
33        TokensRequestResponse, TracedEntryPointRequestBody, TracedEntryPointRequestResponse,
34        TracingResult, VersionParam,
35    },
36    models::ComponentId,
37    Bytes,
38};
39
40use crate::{
41    feed::synchronizer::{ComponentWithState, Snapshot},
42    TYCHO_SERVER_VERSION,
43};
44
45/// Request body for fetching a snapshot of protocol states and VM storage.
46///
47/// This struct helps to coordinate fetching  multiple pieces of related data
48/// (protocol states, contract storage, TVL, entry points).
49#[derive(Clone, Debug, PartialEq)]
50pub struct SnapshotParameters<'a> {
51    /// Which chain to fetch snapshots for
52    pub chain: Chain,
53    /// Protocol system name, required for correct state resolution
54    pub protocol_system: &'a str,
55    /// Components to fetch protocol states for
56    pub components: &'a HashMap<ComponentId, ProtocolComponent>,
57    /// Traced entry points data mapped by component id
58    pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
59    /// Contract addresses to fetch VM storage for
60    pub contract_ids: &'a [Bytes],
61    /// Block number for versioning
62    pub block_number: u64,
63    /// Whether to include balance information
64    pub include_balances: bool,
65    /// Whether to fetch TVL data
66    pub include_tvl: bool,
67}
68
69impl<'a> SnapshotParameters<'a> {
70    pub fn new(
71        chain: Chain,
72        protocol_system: &'a str,
73        components: &'a HashMap<ComponentId, ProtocolComponent>,
74        contract_ids: &'a [Bytes],
75        block_number: u64,
76    ) -> Self {
77        Self {
78            chain,
79            protocol_system,
80            components,
81            entrypoints: None,
82            contract_ids,
83            block_number,
84            include_balances: true,
85            include_tvl: true,
86        }
87    }
88
89    /// Set whether to include balance information (default: true)
90    pub fn include_balances(mut self, include_balances: bool) -> Self {
91        self.include_balances = include_balances;
92        self
93    }
94
95    /// Set whether to fetch TVL data (default: true)
96    pub fn include_tvl(mut self, include_tvl: bool) -> Self {
97        self.include_tvl = include_tvl;
98        self
99    }
100
101    pub fn entrypoints(
102        mut self,
103        entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
104    ) -> Self {
105        self.entrypoints = Some(entrypoints);
106        self
107    }
108}
109
110#[derive(Error, Debug)]
111pub enum RPCError {
112    /// The passed tycho url failed to parse.
113    #[error("Failed to parse URL: {0}. Error: {1}")]
114    UrlParsing(String, String),
115
116    /// The request data is not correctly formed.
117    #[error("Failed to format request: {0}")]
118    FormatRequest(String),
119
120    /// Errors forwarded from the HTTP protocol.
121    #[error("Unexpected HTTP client error: {0}")]
122    HttpClient(String, #[source] reqwest::Error),
123
124    /// The response from the server could not be parsed correctly.
125    #[error("Failed to parse response: {0}")]
126    ParseResponse(String),
127
128    /// Other fatal errors.
129    #[error("Fatal error: {0}")]
130    Fatal(String),
131
132    #[error("Rate limited until {0:?}")]
133    RateLimited(Option<SystemTime>),
134
135    #[error("Server unreachable: {0}")]
136    ServerUnreachable(String),
137}
138
139#[cfg_attr(test, automock)]
140#[async_trait]
141pub trait RPCClient: Send + Sync {
142    /// Retrieves a snapshot of contract state.
143    async fn get_contract_state(
144        &self,
145        request: &StateRequestBody,
146    ) -> Result<StateRequestResponse, RPCError>;
147
148    async fn get_contract_state_paginated(
149        &self,
150        chain: Chain,
151        ids: &[Bytes],
152        protocol_system: &str,
153        version: &VersionParam,
154        chunk_size: usize,
155        concurrency: usize,
156    ) -> Result<StateRequestResponse, RPCError> {
157        let semaphore = Arc::new(Semaphore::new(concurrency));
158
159        // Sort the ids to maximize server-side cache hits
160        let mut sorted_ids = ids.to_vec();
161        sorted_ids.sort();
162
163        let chunked_bodies = sorted_ids
164            .chunks(chunk_size)
165            .map(|chunk| StateRequestBody {
166                contract_ids: Some(chunk.to_vec()),
167                protocol_system: protocol_system.to_string(),
168                chain,
169                version: version.clone(),
170                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
171            })
172            .collect::<Vec<_>>();
173
174        let mut tasks = Vec::new();
175        for body in chunked_bodies.iter() {
176            let sem = semaphore.clone();
177            tasks.push(async move {
178                let _permit = sem
179                    .acquire()
180                    .await
181                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
182                self.get_contract_state(body).await
183            });
184        }
185
186        // Execute all tasks concurrently with the defined concurrency limit.
187        let responses = try_join_all(tasks).await?;
188
189        // Aggregate the responses into a single result.
190        let accounts = responses
191            .iter()
192            .flat_map(|r| r.accounts.clone())
193            .collect();
194        let total: i64 = responses
195            .iter()
196            .map(|r| r.pagination.total)
197            .sum();
198
199        Ok(StateRequestResponse {
200            accounts,
201            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
202        })
203    }
204
205    async fn get_protocol_components(
206        &self,
207        request: &ProtocolComponentsRequestBody,
208    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
209
210    async fn get_protocol_components_paginated(
211        &self,
212        request: &ProtocolComponentsRequestBody,
213        chunk_size: usize,
214        concurrency: usize,
215    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
216        let semaphore = Arc::new(Semaphore::new(concurrency));
217
218        // If a set of component IDs is specified, the maximum return size is already known,
219        // allowing us to pre-compute the number of requests to be made.
220        match request.component_ids {
221            Some(ref ids) => {
222                // We can divide the component_ids into chunks of size chunk_size
223                let chunked_bodies = ids
224                    .chunks(chunk_size)
225                    .enumerate()
226                    .map(|(index, _)| ProtocolComponentsRequestBody {
227                        protocol_system: request.protocol_system.clone(),
228                        component_ids: request.component_ids.clone(),
229                        tvl_gt: request.tvl_gt,
230                        chain: request.chain,
231                        pagination: PaginationParams {
232                            page: index as i64,
233                            page_size: chunk_size as i64,
234                        },
235                    })
236                    .collect::<Vec<_>>();
237
238                let mut tasks = Vec::new();
239                for body in chunked_bodies.iter() {
240                    let sem = semaphore.clone();
241                    tasks.push(async move {
242                        let _permit = sem
243                            .acquire()
244                            .await
245                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
246                        self.get_protocol_components(body).await
247                    });
248                }
249
250                try_join_all(tasks)
251                    .await
252                    .map(|responses| ProtocolComponentRequestResponse {
253                        protocol_components: responses
254                            .into_iter()
255                            .flat_map(|r| r.protocol_components.into_iter())
256                            .collect(),
257                        pagination: PaginationResponse {
258                            page: 0,
259                            page_size: chunk_size as i64,
260                            total: ids.len() as i64,
261                        },
262                    })
263            }
264            _ => {
265                // If no component ids are specified, we need to make requests based on the total
266                // number of results from the first response.
267
268                let initial_request = ProtocolComponentsRequestBody {
269                    protocol_system: request.protocol_system.clone(),
270                    component_ids: request.component_ids.clone(),
271                    tvl_gt: request.tvl_gt,
272                    chain: request.chain,
273                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
274                };
275                let first_response = self
276                    .get_protocol_components(&initial_request)
277                    .await
278                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
279
280                let total_items = first_response.pagination.total;
281                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
282
283                // Initialize the final response accumulator
284                let mut accumulated_response = ProtocolComponentRequestResponse {
285                    protocol_components: first_response.protocol_components,
286                    pagination: PaginationResponse {
287                        page: 0,
288                        page_size: chunk_size as i64,
289                        total: total_items,
290                    },
291                };
292
293                let mut page = 1;
294                while page < total_pages {
295                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
296
297                    // Create request bodies for parallel requests, respecting the concurrency limit
298                    let chunked_bodies = (0..requests_in_this_iteration)
299                        .map(|iter| ProtocolComponentsRequestBody {
300                            protocol_system: request.protocol_system.clone(),
301                            component_ids: request.component_ids.clone(),
302                            tvl_gt: request.tvl_gt,
303                            chain: request.chain,
304                            pagination: PaginationParams {
305                                page: page + iter,
306                                page_size: chunk_size as i64,
307                            },
308                        })
309                        .collect::<Vec<_>>();
310
311                    let tasks: Vec<_> = chunked_bodies
312                        .iter()
313                        .map(|body| {
314                            let sem = semaphore.clone();
315                            async move {
316                                let _permit = sem.acquire().await.map_err(|_| {
317                                    RPCError::Fatal("Semaphore dropped".to_string())
318                                })?;
319                                self.get_protocol_components(body).await
320                            }
321                        })
322                        .collect();
323
324                    let responses = try_join_all(tasks)
325                        .await
326                        .map(|responses| {
327                            let total = responses[0].pagination.total;
328                            ProtocolComponentRequestResponse {
329                                protocol_components: responses
330                                    .into_iter()
331                                    .flat_map(|r| r.protocol_components.into_iter())
332                                    .collect(),
333                                pagination: PaginationResponse {
334                                    page,
335                                    page_size: chunk_size as i64,
336                                    total,
337                                },
338                            }
339                        });
340
341                    // Update the accumulated response or set the initial response
342                    match responses {
343                        Ok(mut resp) => {
344                            accumulated_response
345                                .protocol_components
346                                .append(&mut resp.protocol_components);
347                        }
348                        Err(e) => return Err(e),
349                    }
350
351                    page += concurrency as i64;
352                }
353                Ok(accumulated_response)
354            }
355        }
356    }
357
358    async fn get_protocol_states(
359        &self,
360        request: &ProtocolStateRequestBody,
361    ) -> Result<ProtocolStateRequestResponse, RPCError>;
362
363    #[allow(clippy::too_many_arguments)]
364    async fn get_protocol_states_paginated<T>(
365        &self,
366        chain: Chain,
367        ids: &[T],
368        protocol_system: &str,
369        include_balances: bool,
370        version: &VersionParam,
371        chunk_size: usize,
372        concurrency: usize,
373    ) -> Result<ProtocolStateRequestResponse, RPCError>
374    where
375        T: AsRef<str> + Sync + 'static,
376    {
377        let semaphore = Arc::new(Semaphore::new(concurrency));
378        let chunked_bodies = ids
379            .chunks(chunk_size)
380            .map(|c| ProtocolStateRequestBody {
381                protocol_ids: Some(
382                    c.iter()
383                        .map(|id| id.as_ref().to_string())
384                        .collect(),
385                ),
386                protocol_system: protocol_system.to_string(),
387                chain,
388                include_balances,
389                version: version.clone(),
390                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
391            })
392            .collect::<Vec<_>>();
393
394        let mut tasks = Vec::new();
395        for body in chunked_bodies.iter() {
396            let sem = semaphore.clone();
397            tasks.push(async move {
398                let _permit = sem
399                    .acquire()
400                    .await
401                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
402                self.get_protocol_states(body).await
403            });
404        }
405
406        try_join_all(tasks)
407            .await
408            .map(|responses| {
409                let states = responses
410                    .clone()
411                    .into_iter()
412                    .flat_map(|r| r.states)
413                    .collect();
414                let total = responses
415                    .iter()
416                    .map(|r| r.pagination.total)
417                    .sum();
418                ProtocolStateRequestResponse {
419                    states,
420                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
421                }
422            })
423    }
424
425    /// This function returns only one chunk of tokens. To get all tokens please call
426    /// get_all_tokens.
427    async fn get_tokens(
428        &self,
429        request: &TokensRequestBody,
430    ) -> Result<TokensRequestResponse, RPCError>;
431
432    async fn get_all_tokens(
433        &self,
434        chain: Chain,
435        min_quality: Option<i32>,
436        traded_n_days_ago: Option<u64>,
437        chunk_size: usize,
438    ) -> Result<Vec<ResponseToken>, RPCError> {
439        let mut request_page = 0;
440        let mut all_tokens = Vec::new();
441        loop {
442            let mut response = self
443                .get_tokens(&TokensRequestBody {
444                    token_addresses: None,
445                    min_quality,
446                    traded_n_days_ago,
447                    pagination: PaginationParams {
448                        page: request_page,
449                        page_size: chunk_size.try_into().map_err(|_| {
450                            RPCError::FormatRequest(
451                                "Failed to convert chunk_size into i64".to_string(),
452                            )
453                        })?,
454                    },
455                    chain,
456                })
457                .await?;
458
459            let num_tokens = response.tokens.len();
460            all_tokens.append(&mut response.tokens);
461            request_page += 1;
462
463            if num_tokens < chunk_size {
464                break;
465            }
466        }
467        Ok(all_tokens)
468    }
469
470    async fn get_protocol_systems(
471        &self,
472        request: &ProtocolSystemsRequestBody,
473    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
474
475    async fn get_component_tvl(
476        &self,
477        request: &ComponentTvlRequestBody,
478    ) -> Result<ComponentTvlRequestResponse, RPCError>;
479
480    async fn get_component_tvl_paginated(
481        &self,
482        request: &ComponentTvlRequestBody,
483        chunk_size: usize,
484        concurrency: usize,
485    ) -> Result<ComponentTvlRequestResponse, RPCError> {
486        let semaphore = Arc::new(Semaphore::new(concurrency));
487
488        match request.component_ids {
489            Some(ref ids) => {
490                let chunked_requests = ids
491                    .chunks(chunk_size)
492                    .enumerate()
493                    .map(|(index, _)| ComponentTvlRequestBody {
494                        chain: request.chain,
495                        protocol_system: request.protocol_system.clone(),
496                        component_ids: Some(ids.clone()),
497                        pagination: PaginationParams {
498                            page: index as i64,
499                            page_size: chunk_size as i64,
500                        },
501                    })
502                    .collect::<Vec<_>>();
503
504                let tasks: Vec<_> = chunked_requests
505                    .into_iter()
506                    .map(|req| {
507                        let sem = semaphore.clone();
508                        async move {
509                            let _permit = sem
510                                .acquire()
511                                .await
512                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
513                            self.get_component_tvl(&req).await
514                        }
515                    })
516                    .collect();
517
518                let responses = try_join_all(tasks).await?;
519
520                let mut merged_tvl = HashMap::new();
521                for resp in responses {
522                    for (key, value) in resp.tvl {
523                        *merged_tvl.entry(key).or_insert(0.0) = value;
524                    }
525                }
526
527                Ok(ComponentTvlRequestResponse {
528                    tvl: merged_tvl,
529                    pagination: PaginationResponse {
530                        page: 0,
531                        page_size: chunk_size as i64,
532                        total: ids.len() as i64,
533                    },
534                })
535            }
536            _ => {
537                let first_request = ComponentTvlRequestBody {
538                    chain: request.chain,
539                    protocol_system: request.protocol_system.clone(),
540                    component_ids: request.component_ids.clone(),
541                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
542                };
543
544                let first_response = self
545                    .get_component_tvl(&first_request)
546                    .await?;
547                let total_items = first_response.pagination.total;
548                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
549
550                let mut merged_tvl = first_response.tvl;
551
552                let mut page = 1;
553                while page < total_pages {
554                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
555
556                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
557                        .map(|i| ComponentTvlRequestBody {
558                            chain: request.chain,
559                            protocol_system: request.protocol_system.clone(),
560                            component_ids: request.component_ids.clone(),
561                            pagination: PaginationParams {
562                                page: page + i,
563                                page_size: chunk_size as i64,
564                            },
565                        })
566                        .collect();
567
568                    let tasks: Vec<_> = chunked_requests
569                        .into_iter()
570                        .map(|req| {
571                            let sem = semaphore.clone();
572                            async move {
573                                let _permit = sem.acquire().await.map_err(|_| {
574                                    RPCError::Fatal("Semaphore dropped".to_string())
575                                })?;
576                                self.get_component_tvl(&req).await
577                            }
578                        })
579                        .collect();
580
581                    let responses = try_join_all(tasks).await?;
582
583                    // merge hashmap
584                    for resp in responses {
585                        for (key, value) in resp.tvl {
586                            *merged_tvl.entry(key).or_insert(0.0) += value;
587                        }
588                    }
589
590                    page += concurrency as i64;
591                }
592
593                Ok(ComponentTvlRequestResponse {
594                    tvl: merged_tvl,
595                    pagination: PaginationResponse {
596                        page: 0,
597                        page_size: chunk_size as i64,
598                        total: total_items,
599                    },
600                })
601            }
602        }
603    }
604
605    async fn get_traced_entry_points(
606        &self,
607        request: &TracedEntryPointRequestBody,
608    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
609
610    async fn get_traced_entry_points_paginated(
611        &self,
612        chain: Chain,
613        protocol_system: &str,
614        component_ids: &[String],
615        chunk_size: usize,
616        concurrency: usize,
617    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
618        let semaphore = Arc::new(Semaphore::new(concurrency));
619        let chunked_bodies = component_ids
620            .chunks(chunk_size)
621            .map(|c| TracedEntryPointRequestBody {
622                chain,
623                protocol_system: protocol_system.to_string(),
624                component_ids: Some(c.to_vec()),
625                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
626            })
627            .collect::<Vec<_>>();
628
629        let mut tasks = Vec::new();
630        for body in chunked_bodies.iter() {
631            let sem = semaphore.clone();
632            tasks.push(async move {
633                let _permit = sem
634                    .acquire()
635                    .await
636                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
637                self.get_traced_entry_points(body).await
638            });
639        }
640
641        try_join_all(tasks)
642            .await
643            .map(|responses| {
644                let traced_entry_points = responses
645                    .clone()
646                    .into_iter()
647                    .flat_map(|r| r.traced_entry_points)
648                    .collect();
649                let total = responses
650                    .iter()
651                    .map(|r| r.pagination.total)
652                    .sum();
653                TracedEntryPointRequestResponse {
654                    traced_entry_points,
655                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
656                }
657            })
658    }
659
660    async fn get_snapshots<'a>(
661        &self,
662        request: &SnapshotParameters<'a>,
663        chunk_size: usize,
664        concurrency: usize,
665    ) -> Result<Snapshot, RPCError>;
666}
667
668#[derive(Debug, Clone)]
669pub struct HttpRPCClient {
670    http_client: Client,
671    url: Url,
672    retry_after: Arc<RwLock<Option<SystemTime>>>,
673    backoff_policy: ExponentialBackoff,
674    server_restart_duration: Duration,
675}
676
677impl HttpRPCClient {
678    pub fn new(base_uri: &str, auth_key: Option<&str>) -> Result<Self, RPCError> {
679        let uri = base_uri
680            .parse::<Url>()
681            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
682
683        // Add default headers
684        let mut headers = header::HeaderMap::new();
685        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
686        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
687        headers.insert(
688            header::USER_AGENT,
689            header::HeaderValue::from_str(&user_agent)
690                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
691        );
692
693        // Add Authorization if one is given
694        if let Some(key) = auth_key {
695            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
696                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
697            })?;
698            auth_value.set_sensitive(true);
699            headers.insert(header::AUTHORIZATION, auth_value);
700        }
701
702        let client = ClientBuilder::new()
703            .default_headers(headers)
704            .http2_prior_knowledge()
705            .build()
706            .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
707        Ok(Self {
708            http_client: client,
709            url: uri,
710            retry_after: Arc::new(RwLock::new(None)),
711            backoff_policy: ExponentialBackoffBuilder::new()
712                .with_initial_interval(Duration::from_millis(250))
713                // increase backoff time by 75% each failure
714                .with_multiplier(1.75)
715                // keep retrying every 30s
716                .with_max_interval(Duration::from_secs(30))
717                // if all retries take longer than 2m, give up
718                .with_max_elapsed_time(Some(Duration::from_secs(125)))
719                .build(),
720            server_restart_duration: Duration::from_secs(120),
721        })
722    }
723
724    #[cfg(test)]
725    pub fn with_test_backoff_policy(mut self) -> Self {
726        // Extremely short intervals for very fast testing
727        self.backoff_policy = ExponentialBackoffBuilder::new()
728            .with_initial_interval(Duration::from_millis(1))
729            .with_multiplier(1.1)
730            .with_max_interval(Duration::from_millis(5))
731            .with_max_elapsed_time(Some(Duration::from_millis(50)))
732            .build();
733        self.server_restart_duration = Duration::from_millis(50);
734        self
735    }
736
737    /// Converts a error response to a Result.
738    ///
739    /// Raises an error if the response status code id 429, 502, 503 or 504. In the 429
740    /// case it will try to look for a retry-after header an parse it accordingly. The
741    /// parsed value is then passed as part of the error.
742    async fn error_for_response(
743        &self,
744        response: reqwest::Response,
745    ) -> Result<reqwest::Response, RPCError> {
746        match response.status() {
747            StatusCode::TOO_MANY_REQUESTS => {
748                let retry_after_raw = response
749                    .headers()
750                    .get(reqwest::header::RETRY_AFTER)
751                    .and_then(|h| h.to_str().ok())
752                    .and_then(parse_retry_value);
753
754                Err(RPCError::RateLimited(retry_after_raw))
755            }
756            StatusCode::BAD_GATEWAY |
757            StatusCode::SERVICE_UNAVAILABLE |
758            StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
759                response
760                    .text()
761                    .await
762                    .unwrap_or_else(|_| "Server Unreachable".to_string()),
763            )),
764            _ => Ok(response),
765        }
766    }
767
768    /// Classifies errors into transient or permanent ones.
769    ///
770    /// Transient errors are retried with a potential backoff, permanent ones are not.
771    /// If the error is RateLimited, this method will set the self.retry_after value so
772    /// future requests wait until the rate limit has been reset.
773    async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
774        match e {
775            RPCError::ServerUnreachable(_) => {
776                backoff::Error::retry_after(e, self.server_restart_duration)
777            }
778            RPCError::RateLimited(Some(until)) => {
779                let mut retry_after_guard = self.retry_after.write().await;
780                *retry_after_guard = Some(
781                    retry_after_guard
782                        .unwrap_or(until)
783                        .max(until),
784                );
785
786                if let Ok(duration) = until.duration_since(SystemTime::now()) {
787                    backoff::Error::retry_after(e, duration)
788                } else {
789                    e.into()
790                }
791            }
792            RPCError::RateLimited(None) => e.into(),
793            _ => backoff::Error::permanent(e),
794        }
795    }
796
797    /// Waits until the current rate limit time has passed.
798    ///
799    /// Only waits if there is a time and that time is in the future, else return
800    /// immediately.
801    async fn wait_until_retry_after(&self) {
802        if let Some(&until) = self.retry_after.read().await.as_ref() {
803            let now = SystemTime::now();
804            if until > now {
805                if let Ok(duration) = until.duration_since(now) {
806                    sleep(duration).await
807                }
808            }
809        }
810    }
811
812    /// Makes a post request handling transient failures.
813    ///
814    /// If a retry-after header is received it will be respected. Else the configured
815    /// backoff policy is used to deal with transient network or server errors.
816    async fn make_post_request<T: Serialize + ?Sized>(
817        &self,
818        request: &T,
819        uri: &String,
820    ) -> Result<Response, RPCError> {
821        self.wait_until_retry_after().await;
822        let response = backoff::future::retry(self.backoff_policy.clone(), || async {
823            let server_response = self
824                .http_client
825                .post(uri)
826                .json(request)
827                .send()
828                .await
829                .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
830
831            match self
832                .error_for_response(server_response)
833                .await
834            {
835                Ok(response) => Ok(response),
836                Err(e) => Err(self.handle_error_for_backoff(e).await),
837            }
838        })
839        .await?;
840        Ok(response)
841    }
842}
843
844fn parse_retry_value(val: &str) -> Option<SystemTime> {
845    if let Ok(secs) = val.parse::<u64>() {
846        return Some(SystemTime::now() + Duration::from_secs(secs));
847    }
848    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
849        return Some(date.into());
850    }
851    None
852}
853
854#[async_trait]
855impl RPCClient for HttpRPCClient {
856    #[instrument(skip(self, request))]
857    async fn get_contract_state(
858        &self,
859        request: &StateRequestBody,
860    ) -> Result<StateRequestResponse, RPCError> {
861        // Check if contract ids are specified
862        if request
863            .contract_ids
864            .as_ref()
865            .is_none_or(|ids| ids.is_empty())
866        {
867            warn!("No contract ids specified in request.");
868        }
869
870        let uri = format!(
871            "{}/{}/contract_state",
872            self.url
873                .to_string()
874                .trim_end_matches('/'),
875            TYCHO_SERVER_VERSION
876        );
877        debug!(%uri, "Sending contract_state request to Tycho server");
878        trace!(?request, "Sending request to Tycho server");
879        let response = self
880            .make_post_request(request, &uri)
881            .await?;
882        trace!(?response, "Received response from Tycho server");
883
884        let body = response
885            .text()
886            .await
887            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
888        if body.is_empty() {
889            // Pure native protocols will return empty contract states
890            return Ok(StateRequestResponse {
891                accounts: vec![],
892                pagination: PaginationResponse {
893                    page: request.pagination.page,
894                    page_size: request.pagination.page,
895                    total: 0,
896                },
897            });
898        }
899
900        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
901            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
902        trace!(?accounts, "Received contract_state response from Tycho server");
903
904        Ok(accounts)
905    }
906
907    async fn get_protocol_components(
908        &self,
909        request: &ProtocolComponentsRequestBody,
910    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
911        let uri = format!(
912            "{}/{}/protocol_components",
913            self.url
914                .to_string()
915                .trim_end_matches('/'),
916            TYCHO_SERVER_VERSION,
917        );
918        debug!(%uri, "Sending protocol_components request to Tycho server");
919        trace!(?request, "Sending request to Tycho server");
920
921        let response = self
922            .make_post_request(request, &uri)
923            .await?;
924
925        trace!(?response, "Received response from Tycho server");
926
927        let body = response
928            .text()
929            .await
930            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
931        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
932            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
933        trace!(?components, "Received protocol_components response from Tycho server");
934
935        Ok(components)
936    }
937
938    async fn get_protocol_states(
939        &self,
940        request: &ProtocolStateRequestBody,
941    ) -> Result<ProtocolStateRequestResponse, RPCError> {
942        // Check if protocol ids are specified
943        if request
944            .protocol_ids
945            .as_ref()
946            .is_none_or(|ids| ids.is_empty())
947        {
948            warn!("No protocol ids specified in request.");
949        }
950
951        let uri = format!(
952            "{}/{}/protocol_state",
953            self.url
954                .to_string()
955                .trim_end_matches('/'),
956            TYCHO_SERVER_VERSION
957        );
958        debug!(%uri, "Sending protocol_states request to Tycho server");
959        trace!(?request, "Sending request to Tycho server");
960
961        let response = self
962            .make_post_request(request, &uri)
963            .await?;
964        trace!(?response, "Received response from Tycho server");
965
966        let body = response
967            .text()
968            .await
969            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
970
971        if body.is_empty() {
972            // Pure VM protocols will return empty states
973            return Ok(ProtocolStateRequestResponse {
974                states: vec![],
975                pagination: PaginationResponse {
976                    page: request.pagination.page,
977                    page_size: request.pagination.page_size,
978                    total: 0,
979                },
980            });
981        }
982
983        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
984            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
985        trace!(?states, "Received protocol_states response from Tycho server");
986
987        Ok(states)
988    }
989
990    async fn get_tokens(
991        &self,
992        request: &TokensRequestBody,
993    ) -> Result<TokensRequestResponse, RPCError> {
994        let uri = format!(
995            "{}/{}/tokens",
996            self.url
997                .to_string()
998                .trim_end_matches('/'),
999            TYCHO_SERVER_VERSION
1000        );
1001        debug!(%uri, "Sending tokens request to Tycho server");
1002
1003        let response = self
1004            .make_post_request(request, &uri)
1005            .await?;
1006
1007        let body = response
1008            .text()
1009            .await
1010            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1012            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1013
1014        Ok(tokens)
1015    }
1016
1017    async fn get_protocol_systems(
1018        &self,
1019        request: &ProtocolSystemsRequestBody,
1020    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1021        let uri = format!(
1022            "{}/{}/protocol_systems",
1023            self.url
1024                .to_string()
1025                .trim_end_matches('/'),
1026            TYCHO_SERVER_VERSION
1027        );
1028        debug!(%uri, "Sending protocol_systems request to Tycho server");
1029        trace!(?request, "Sending request to Tycho server");
1030        let response = self
1031            .make_post_request(request, &uri)
1032            .await?;
1033        trace!(?response, "Received response from Tycho server");
1034        let body = response
1035            .text()
1036            .await
1037            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1038        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1039            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1040        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1041        Ok(protocol_systems)
1042    }
1043
1044    async fn get_component_tvl(
1045        &self,
1046        request: &ComponentTvlRequestBody,
1047    ) -> Result<ComponentTvlRequestResponse, RPCError> {
1048        let uri = format!(
1049            "{}/{}/component_tvl",
1050            self.url
1051                .to_string()
1052                .trim_end_matches('/'),
1053            TYCHO_SERVER_VERSION
1054        );
1055        debug!(%uri, "Sending get_component_tvl request to Tycho server");
1056        trace!(?request, "Sending request to Tycho server");
1057        let response = self
1058            .make_post_request(request, &uri)
1059            .await?;
1060        trace!(?response, "Received response from Tycho server");
1061        let body = response
1062            .text()
1063            .await
1064            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1065        let component_tvl =
1066            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1067                error!("Failed to parse component_tvl response: {:?}", &body);
1068                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1069            })?;
1070        trace!(?component_tvl, "Received component_tvl response from Tycho server");
1071        Ok(component_tvl)
1072    }
1073
1074    async fn get_traced_entry_points(
1075        &self,
1076        request: &TracedEntryPointRequestBody,
1077    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1078        let uri = format!(
1079            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1080            self.url
1081                .to_string()
1082                .trim_end_matches('/')
1083        );
1084        debug!(%uri, "Sending traced_entry_points request to Tycho server");
1085        trace!(?request, "Sending request to Tycho server");
1086
1087        let response = self
1088            .make_post_request(request, &uri)
1089            .await?;
1090
1091        trace!(?response, "Received response from Tycho server");
1092
1093        let body = response
1094            .text()
1095            .await
1096            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1097        let entrypoints =
1098            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1099                error!("Failed to parse traced_entry_points response: {:?}", &body);
1100                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1101            })?;
1102        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1103        Ok(entrypoints)
1104    }
1105
1106    async fn get_snapshots<'a>(
1107        &self,
1108        request: &SnapshotParameters<'a>,
1109        chunk_size: usize,
1110        concurrency: usize,
1111    ) -> Result<Snapshot, RPCError> {
1112        let component_ids: Vec<_> = request
1113            .components
1114            .keys()
1115            .cloned()
1116            .collect();
1117
1118        let version = VersionParam::new(
1119            None,
1120            Some({
1121                #[allow(deprecated)]
1122                BlockParam { hash: None, chain: None, number: Some(request.block_number as i64) }
1123            }),
1124        );
1125
1126        let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1127            let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1128            self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1129                .await?
1130                .tvl
1131        } else {
1132            HashMap::new()
1133        };
1134
1135        let mut protocol_states = if !component_ids.is_empty() {
1136            self.get_protocol_states_paginated(
1137                request.chain,
1138                &component_ids,
1139                request.protocol_system,
1140                request.include_balances,
1141                &version,
1142                chunk_size,
1143                concurrency,
1144            )
1145            .await?
1146            .states
1147            .into_iter()
1148            .map(|state| (state.component_id.clone(), state))
1149            .collect()
1150        } else {
1151            HashMap::new()
1152        };
1153
1154        // Convert to ComponentWithState, which includes entrypoint information.
1155        let states = request
1156            .components
1157            .values()
1158            .filter_map(|component| {
1159                if let Some(state) = protocol_states.remove(&component.id) {
1160                    Some((
1161                        component.id.clone(),
1162                        ComponentWithState {
1163                            state,
1164                            component: component.clone(),
1165                            component_tvl: component_tvl
1166                                .get(&component.id)
1167                                .cloned(),
1168                            entrypoints: request
1169                                .entrypoints
1170                                .as_ref()
1171                                .and_then(|map| map.get(&component.id))
1172                                .cloned()
1173                                .unwrap_or_default(),
1174                        },
1175                    ))
1176                } else if component_ids.contains(&component.id) {
1177                    // only emit error event if we requested this component
1178                    let component_id = &component.id;
1179                    error!(?component_id, "Missing state for native component!");
1180                    None
1181                } else {
1182                    None
1183                }
1184            })
1185            .collect();
1186
1187        let vm_storage = if !request.contract_ids.is_empty() {
1188            let contract_states = self
1189                .get_contract_state_paginated(
1190                    request.chain,
1191                    request.contract_ids,
1192                    request.protocol_system,
1193                    &version,
1194                    chunk_size,
1195                    concurrency,
1196                )
1197                .await?
1198                .accounts
1199                .into_iter()
1200                .map(|acc| (acc.address.clone(), acc))
1201                .collect::<HashMap<_, _>>();
1202
1203            trace!(states=?&contract_states, "Retrieved ContractState");
1204
1205            let contract_address_to_components = request
1206                .components
1207                .iter()
1208                .filter_map(|(id, comp)| {
1209                    if component_ids.contains(id) {
1210                        Some(
1211                            comp.contract_ids
1212                                .iter()
1213                                .map(|address| (address.clone(), comp.id.clone())),
1214                        )
1215                    } else {
1216                        None
1217                    }
1218                })
1219                .flatten()
1220                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1221                    acc.entry(addr).or_default().push(c_id);
1222                    acc
1223                });
1224
1225            request
1226                .contract_ids
1227                .iter()
1228                .filter_map(|address| {
1229                    if let Some(state) = contract_states.get(address) {
1230                        Some((address.clone(), state.clone()))
1231                    } else if let Some(ids) = contract_address_to_components.get(address) {
1232                        // only emit error even if we did actually request this address
1233                        error!(
1234                            ?address,
1235                            ?ids,
1236                            "Component with lacking contract storage encountered!"
1237                        );
1238                        None
1239                    } else {
1240                        None
1241                    }
1242                })
1243                .collect()
1244        } else {
1245            HashMap::new()
1246        };
1247
1248        Ok(Snapshot { states, vm_storage })
1249    }
1250}
1251
1252#[cfg(test)]
1253mod tests {
1254    use std::{
1255        collections::{HashMap, HashSet},
1256        str::FromStr,
1257    };
1258
1259    use mockito::Server;
1260    use rstest::rstest;
1261    // TODO: remove once deprecated ProtocolId struct is removed
1262    #[allow(deprecated)]
1263    use tycho_common::dto::ProtocolId;
1264    use tycho_common::dto::{AddressStorageLocation, TracingParams};
1265
1266    use super::*;
1267
1268    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
1269    // purposes
1270    impl MockRPCClient {
1271        #[allow(clippy::too_many_arguments)]
1272        async fn test_get_protocol_states_paginated<T>(
1273            &self,
1274            chain: Chain,
1275            ids: &[T],
1276            protocol_system: &str,
1277            include_balances: bool,
1278            version: &VersionParam,
1279            chunk_size: usize,
1280            _concurrency: usize,
1281        ) -> Vec<ProtocolStateRequestBody>
1282        where
1283            T: AsRef<str> + Clone + Send + Sync + 'static,
1284        {
1285            ids.chunks(chunk_size)
1286                .map(|chunk| ProtocolStateRequestBody {
1287                    protocol_ids: Some(
1288                        chunk
1289                            .iter()
1290                            .map(|id| id.as_ref().to_string())
1291                            .collect(),
1292                    ),
1293                    protocol_system: protocol_system.to_string(),
1294                    chain,
1295                    include_balances,
1296                    version: version.clone(),
1297                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1298                })
1299                .collect()
1300        }
1301    }
1302
1303    // TODO: remove once deprecated ProtocolId struct is removed
1304    #[allow(deprecated)]
1305    #[rstest]
1306    #[case::protocol_id_input(vec![
1307        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1308        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1309    ])]
1310    #[case::string_input(vec![
1311        "id1".to_string(),
1312        "id2".to_string()
1313    ])]
1314    #[tokio::test]
1315    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1316    where
1317        T: AsRef<str> + Clone + Send + Sync + 'static,
1318    {
1319        let mock_client = MockRPCClient::new();
1320
1321        let request_bodies = mock_client
1322            .test_get_protocol_states_paginated(
1323                Chain::Ethereum,
1324                &ids,
1325                "test_system",
1326                true,
1327                &VersionParam::default(),
1328                2,
1329                2,
1330            )
1331            .await;
1332
1333        // Verify that the request bodies have been created correctly
1334        assert_eq!(request_bodies.len(), 1);
1335        assert_eq!(
1336            request_bodies[0]
1337                .protocol_ids
1338                .as_ref()
1339                .unwrap()
1340                .len(),
1341            2
1342        );
1343    }
1344
1345    #[tokio::test]
1346    async fn test_get_contract_state() {
1347        let mut server = Server::new_async().await;
1348        let server_resp = r#"
1349        {
1350            "accounts": [
1351                {
1352                    "chain": "ethereum",
1353                    "address": "0x0000000000000000000000000000000000000000",
1354                    "title": "",
1355                    "slots": {},
1356                    "native_balance": "0x01f4",
1357                    "token_balances": {},
1358                    "code": "0x00",
1359                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
1360                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1361                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1362                    "creation_tx": null
1363                }
1364            ],
1365            "pagination": {
1366                "page": 0,
1367                "page_size": 20,
1368                "total": 10
1369            }
1370        }
1371        "#;
1372        // test that the response is deserialized correctly
1373        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1374
1375        let mocked_server = server
1376            .mock("POST", "/v1/contract_state")
1377            .expect(1)
1378            .with_body(server_resp)
1379            .create_async()
1380            .await;
1381
1382        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1383
1384        let response = client
1385            .get_contract_state(&Default::default())
1386            .await
1387            .expect("get state");
1388        let accounts = response.accounts;
1389
1390        mocked_server.assert();
1391        assert_eq!(accounts.len(), 1);
1392        assert_eq!(accounts[0].slots, HashMap::new());
1393        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1394        assert_eq!(accounts[0].code, [0].to_vec());
1395        assert_eq!(
1396            accounts[0].code_hash,
1397            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1398                .unwrap()
1399        );
1400    }
1401
1402    #[tokio::test]
1403    async fn test_get_protocol_components() {
1404        let mut server = Server::new_async().await;
1405        let server_resp = r#"
1406        {
1407            "protocol_components": [
1408                {
1409                    "id": "State1",
1410                    "protocol_system": "ambient",
1411                    "protocol_type_name": "Pool",
1412                    "chain": "ethereum",
1413                    "tokens": [
1414                        "0x0000000000000000000000000000000000000000",
1415                        "0x0000000000000000000000000000000000000001"
1416                    ],
1417                    "contract_ids": [
1418                        "0x0000000000000000000000000000000000000000"
1419                    ],
1420                    "static_attributes": {
1421                        "attribute_1": "0x00000000000003e8"
1422                    },
1423                    "change": "Creation",
1424                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1425                    "created_at": "2022-01-01T00:00:00"
1426                }
1427            ],
1428            "pagination": {
1429                "page": 0,
1430                "page_size": 20,
1431                "total": 10
1432            }
1433        }
1434        "#;
1435        // test that the response is deserialized correctly
1436        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1437
1438        let mocked_server = server
1439            .mock("POST", "/v1/protocol_components")
1440            .expect(1)
1441            .with_body(server_resp)
1442            .create_async()
1443            .await;
1444
1445        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1446
1447        let response = client
1448            .get_protocol_components(&Default::default())
1449            .await
1450            .expect("get state");
1451        let components = response.protocol_components;
1452
1453        mocked_server.assert();
1454        assert_eq!(components.len(), 1);
1455        assert_eq!(components[0].id, "State1");
1456        assert_eq!(components[0].protocol_system, "ambient");
1457        assert_eq!(components[0].protocol_type_name, "Pool");
1458        assert_eq!(components[0].tokens.len(), 2);
1459        let expected_attributes =
1460            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1461                .iter()
1462                .cloned()
1463                .collect::<HashMap<String, Bytes>>();
1464        assert_eq!(components[0].static_attributes, expected_attributes);
1465    }
1466
1467    #[tokio::test]
1468    async fn test_get_protocol_states() {
1469        let mut server = Server::new_async().await;
1470        let server_resp = r#"
1471        {
1472            "states": [
1473                {
1474                    "component_id": "State1",
1475                    "attributes": {
1476                        "attribute_1": "0x00000000000003e8"
1477                    },
1478                    "balances": {
1479                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1480                    }
1481                }
1482            ],
1483            "pagination": {
1484                "page": 0,
1485                "page_size": 20,
1486                "total": 10
1487            }
1488        }
1489        "#;
1490        // test that the response is deserialized correctly
1491        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1492
1493        let mocked_server = server
1494            .mock("POST", "/v1/protocol_state")
1495            .expect(1)
1496            .with_body(server_resp)
1497            .create_async()
1498            .await;
1499        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1500
1501        let response = client
1502            .get_protocol_states(&Default::default())
1503            .await
1504            .expect("get state");
1505        let states = response.states;
1506
1507        mocked_server.assert();
1508        assert_eq!(states.len(), 1);
1509        assert_eq!(states[0].component_id, "State1");
1510        let expected_attributes =
1511            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1512                .iter()
1513                .cloned()
1514                .collect::<HashMap<String, Bytes>>();
1515        assert_eq!(states[0].attributes, expected_attributes);
1516        let expected_balances = [(
1517            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1518                .expect("Unsupported address format"),
1519            Bytes::from_str("0x01f4").unwrap(),
1520        )]
1521        .iter()
1522        .cloned()
1523        .collect::<HashMap<Bytes, Bytes>>();
1524        assert_eq!(states[0].balances, expected_balances);
1525    }
1526
1527    #[tokio::test]
1528    async fn test_get_tokens() {
1529        let mut server = Server::new_async().await;
1530        let server_resp = r#"
1531        {
1532            "tokens": [
1533              {
1534                "chain": "ethereum",
1535                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1536                "symbol": "WETH",
1537                "decimals": 18,
1538                "tax": 0,
1539                "gas": [
1540                  29962
1541                ],
1542                "quality": 100
1543              },
1544              {
1545                "chain": "ethereum",
1546                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1547                "symbol": "USDC",
1548                "decimals": 6,
1549                "tax": 0,
1550                "gas": [
1551                  40652
1552                ],
1553                "quality": 100
1554              }
1555            ],
1556            "pagination": {
1557              "page": 0,
1558              "page_size": 20,
1559              "total": 10
1560            }
1561          }
1562        "#;
1563        // test that the response is deserialized correctly
1564        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1565
1566        let mocked_server = server
1567            .mock("POST", "/v1/tokens")
1568            .expect(1)
1569            .with_body(server_resp)
1570            .create_async()
1571            .await;
1572        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1573
1574        let response = client
1575            .get_tokens(&Default::default())
1576            .await
1577            .expect("get tokens");
1578
1579        let expected = vec![
1580            ResponseToken {
1581                chain: Chain::Ethereum,
1582                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1583                symbol: "WETH".to_string(),
1584                decimals: 18,
1585                tax: 0,
1586                gas: vec![Some(29962)],
1587                quality: 100,
1588            },
1589            ResponseToken {
1590                chain: Chain::Ethereum,
1591                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1592                symbol: "USDC".to_string(),
1593                decimals: 6,
1594                tax: 0,
1595                gas: vec![Some(40652)],
1596                quality: 100,
1597            },
1598        ];
1599
1600        mocked_server.assert();
1601        assert_eq!(response.tokens, expected);
1602        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1603    }
1604
1605    #[tokio::test]
1606    async fn test_get_protocol_systems() {
1607        let mut server = Server::new_async().await;
1608        let server_resp = r#"
1609        {
1610            "protocol_systems": [
1611                "system1",
1612                "system2"
1613            ],
1614            "pagination": {
1615                "page": 0,
1616                "page_size": 20,
1617                "total": 10
1618            }
1619        }
1620        "#;
1621        // test that the response is deserialized correctly
1622        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1623
1624        let mocked_server = server
1625            .mock("POST", "/v1/protocol_systems")
1626            .expect(1)
1627            .with_body(server_resp)
1628            .create_async()
1629            .await;
1630        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1631
1632        let response = client
1633            .get_protocol_systems(&Default::default())
1634            .await
1635            .expect("get protocol systems");
1636        let protocol_systems = response.protocol_systems;
1637
1638        mocked_server.assert();
1639        assert_eq!(protocol_systems, vec!["system1", "system2"]);
1640    }
1641
1642    #[tokio::test]
1643    async fn test_get_component_tvl() {
1644        let mut server = Server::new_async().await;
1645        let server_resp = r#"
1646        {
1647            "tvl": {
1648                "component1": 100.0
1649            },
1650            "pagination": {
1651                "page": 0,
1652                "page_size": 20,
1653                "total": 10
1654            }
1655        }
1656        "#;
1657        // test that the response is deserialized correctly
1658        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1659
1660        let mocked_server = server
1661            .mock("POST", "/v1/component_tvl")
1662            .expect(1)
1663            .with_body(server_resp)
1664            .create_async()
1665            .await;
1666        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1667
1668        let response = client
1669            .get_component_tvl(&Default::default())
1670            .await
1671            .expect("get protocol systems");
1672        let component_tvl = response.tvl;
1673
1674        mocked_server.assert();
1675        assert_eq!(component_tvl.get("component1"), Some(&100.0));
1676    }
1677
1678    #[tokio::test]
1679    async fn test_get_traced_entry_points() {
1680        let mut server = Server::new_async().await;
1681        let server_resp = r#"
1682        {
1683            "traced_entry_points": {
1684                "component_1": [
1685                    [
1686                        {
1687                            "entry_point": {
1688                                "external_id": "entrypoint_a",
1689                                "target": "0x0000000000000000000000000000000000000001",
1690                                "signature": "sig()"
1691                            },
1692                            "params": {
1693                                "method": "rpctracer",
1694                                "caller": "0x000000000000000000000000000000000000000a",
1695                                "calldata": "0x000000000000000000000000000000000000000b"
1696                            }
1697                        },
1698                        {
1699                            "retriggers": [
1700                                [
1701                                    "0x00000000000000000000000000000000000000aa",
1702                                    {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1703                                ]
1704                            ],
1705                            "accessed_slots": {
1706                                "0x0000000000000000000000000000000000aaaa": [
1707                                    "0x0000000000000000000000000000000000aaaa"
1708                                ]
1709                            }
1710                        }
1711                    ]
1712                ]
1713            },
1714            "pagination": {
1715                "page": 0,
1716                "page_size": 20,
1717                "total": 1
1718            }
1719        }
1720        "#;
1721        // test that the response is deserialized correctly
1722        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1723
1724        let mocked_server = server
1725            .mock("POST", "/v1/traced_entry_points")
1726            .expect(1)
1727            .with_body(server_resp)
1728            .create_async()
1729            .await;
1730        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
1731
1732        let response = client
1733            .get_traced_entry_points(&Default::default())
1734            .await
1735            .expect("get traced entry points");
1736        let entrypoints = response.traced_entry_points;
1737
1738        mocked_server.assert();
1739        assert_eq!(entrypoints.len(), 1);
1740        let comp1_entrypoints = entrypoints
1741            .get("component_1")
1742            .expect("component_1 entrypoints should exist");
1743        assert_eq!(comp1_entrypoints.len(), 1);
1744
1745        let (entrypoint, trace_result) = &comp1_entrypoints[0];
1746        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1747        assert_eq!(
1748            entrypoint.entry_point.target,
1749            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1750        );
1751        assert_eq!(entrypoint.entry_point.signature, "sig()");
1752        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1753        assert_eq!(
1754            rpc_params.caller,
1755            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1756        );
1757        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1758
1759        assert_eq!(
1760            trace_result.retriggers,
1761            HashSet::from([(
1762                Bytes::from("0x00000000000000000000000000000000000000aa"),
1763                AddressStorageLocation::new(
1764                    Bytes::from("0x0000000000000000000000000000000000000aaa"),
1765                    12
1766                )
1767            )])
1768        );
1769        assert_eq!(trace_result.accessed_slots.len(), 1);
1770        assert_eq!(
1771            trace_result.accessed_slots,
1772            HashMap::from([(
1773                Bytes::from("0x0000000000000000000000000000000000aaaa"),
1774                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1775            )])
1776        );
1777    }
1778
1779    #[tokio::test]
1780    async fn test_parse_retry_value_numeric() {
1781        let result = parse_retry_value("60");
1782        assert!(result.is_some());
1783
1784        let expected_time = SystemTime::now() + Duration::from_secs(60);
1785        let actual_time = result.unwrap();
1786
1787        // Allow for small timing differences during test execution
1788        let diff = if actual_time > expected_time {
1789            actual_time
1790                .duration_since(expected_time)
1791                .unwrap()
1792        } else {
1793            expected_time
1794                .duration_since(actual_time)
1795                .unwrap()
1796        };
1797        assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1798    }
1799
1800    #[tokio::test]
1801    async fn test_parse_retry_value_rfc2822() {
1802        // Use a fixed future date in RFC2822 format
1803        let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1804        let result = parse_retry_value(rfc2822_date);
1805        assert!(result.is_some());
1806
1807        let parsed_time = result.unwrap();
1808        assert!(parsed_time > SystemTime::now());
1809    }
1810
1811    #[tokio::test]
1812    async fn test_parse_retry_value_invalid_formats() {
1813        // Test various invalid formats
1814        assert!(parse_retry_value("invalid").is_none());
1815        assert!(parse_retry_value("").is_none());
1816        assert!(parse_retry_value("not_a_number").is_none());
1817        assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); // Invalid date
1818    }
1819
1820    #[tokio::test]
1821    async fn test_parse_retry_value_zero_seconds() {
1822        let result = parse_retry_value("0");
1823        assert!(result.is_some());
1824
1825        let expected_time = SystemTime::now();
1826        let actual_time = result.unwrap();
1827
1828        // Should be very close to current time
1829        let diff = if actual_time > expected_time {
1830            actual_time
1831                .duration_since(expected_time)
1832                .unwrap()
1833        } else {
1834            expected_time
1835                .duration_since(actual_time)
1836                .unwrap()
1837        };
1838        assert!(diff < Duration::from_secs(1));
1839    }
1840
1841    #[tokio::test]
1842    async fn test_error_for_response_rate_limited() {
1843        let mut server = Server::new_async().await;
1844        let mock = server
1845            .mock("GET", "/test")
1846            .with_status(429)
1847            .with_header("Retry-After", "60")
1848            .create_async()
1849            .await;
1850
1851        let client = reqwest::Client::new();
1852        let response = client
1853            .get(format!("{}/test", server.url()))
1854            .send()
1855            .await
1856            .unwrap();
1857
1858        let http_client = HttpRPCClient::new(server.url().as_str(), None)
1859            .unwrap()
1860            .with_test_backoff_policy();
1861        let result = http_client
1862            .error_for_response(response)
1863            .await;
1864
1865        mock.assert();
1866        assert!(matches!(result, Err(RPCError::RateLimited(_))));
1867        if let Err(RPCError::RateLimited(retry_after)) = result {
1868            assert!(retry_after.is_some());
1869        }
1870    }
1871
1872    #[tokio::test]
1873    async fn test_error_for_response_rate_limited_no_header() {
1874        let mut server = Server::new_async().await;
1875        let mock = server
1876            .mock("GET", "/test")
1877            .with_status(429)
1878            .create_async()
1879            .await;
1880
1881        let client = reqwest::Client::new();
1882        let response = client
1883            .get(format!("{}/test", server.url()))
1884            .send()
1885            .await
1886            .unwrap();
1887
1888        let http_client = HttpRPCClient::new(server.url().as_str(), None)
1889            .unwrap()
1890            .with_test_backoff_policy();
1891        let result = http_client
1892            .error_for_response(response)
1893            .await;
1894
1895        mock.assert();
1896        assert!(matches!(result, Err(RPCError::RateLimited(None))));
1897    }
1898
1899    #[tokio::test]
1900    async fn test_error_for_response_server_errors() {
1901        let test_cases =
1902            vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1903
1904        for (status_code, expected_body) in test_cases {
1905            let mut server = Server::new_async().await;
1906            let mock = server
1907                .mock("GET", "/test")
1908                .with_status(status_code)
1909                .with_body(expected_body)
1910                .create_async()
1911                .await;
1912
1913            let client = reqwest::Client::new();
1914            let response = client
1915                .get(format!("{}/test", server.url()))
1916                .send()
1917                .await
1918                .unwrap();
1919
1920            let http_client = HttpRPCClient::new(server.url().as_str(), None)
1921                .unwrap()
1922                .with_test_backoff_policy();
1923            let result = http_client
1924                .error_for_response(response)
1925                .await;
1926
1927            mock.assert();
1928            assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
1929            if let Err(RPCError::ServerUnreachable(body)) = result {
1930                assert_eq!(body, expected_body);
1931            }
1932        }
1933    }
1934
1935    #[tokio::test]
1936    async fn test_error_for_response_success() {
1937        let mut server = Server::new_async().await;
1938        let mock = server
1939            .mock("GET", "/test")
1940            .with_status(200)
1941            .with_body("success")
1942            .create_async()
1943            .await;
1944
1945        let client = reqwest::Client::new();
1946        let response = client
1947            .get(format!("{}/test", server.url()))
1948            .send()
1949            .await
1950            .unwrap();
1951
1952        let http_client = HttpRPCClient::new(server.url().as_str(), None)
1953            .unwrap()
1954            .with_test_backoff_policy();
1955        let result = http_client
1956            .error_for_response(response)
1957            .await;
1958
1959        mock.assert();
1960        assert!(result.is_ok());
1961
1962        let response = result.unwrap();
1963        assert_eq!(response.status(), 200);
1964    }
1965
1966    #[tokio::test]
1967    async fn test_handle_error_for_backoff_server_unreachable() {
1968        let http_client = HttpRPCClient::new("http://localhost:8080", None)
1969            .unwrap()
1970            .with_test_backoff_policy();
1971        let error = RPCError::ServerUnreachable("Service down".to_string());
1972
1973        let backoff_error = http_client
1974            .handle_error_for_backoff(error)
1975            .await;
1976
1977        match backoff_error {
1978            backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
1979                assert_eq!(msg, "Service down");
1980                assert_eq!(retry_after, Some(Duration::from_millis(50))); // Fast test duration
1981            }
1982            _ => panic!("Expected transient error for ServerUnreachable"),
1983        }
1984    }
1985
1986    #[tokio::test]
1987    async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
1988        let http_client = HttpRPCClient::new("http://localhost:8080", None)
1989            .unwrap()
1990            .with_test_backoff_policy();
1991        let future_time = SystemTime::now() + Duration::from_secs(30);
1992        let error = RPCError::RateLimited(Some(future_time));
1993
1994        let backoff_error = http_client
1995            .handle_error_for_backoff(error)
1996            .await;
1997
1998        match backoff_error {
1999            backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2000                assert_eq!(retry_after, Some(future_time));
2001            }
2002            _ => panic!("Expected transient error for RateLimited"),
2003        }
2004
2005        // Verify that retry_after was stored in the client state
2006        let stored_retry_after = http_client.retry_after.read().await;
2007        assert_eq!(*stored_retry_after, Some(future_time));
2008    }
2009
2010    #[tokio::test]
2011    async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2012        let http_client = HttpRPCClient::new("http://localhost:8080", None)
2013            .unwrap()
2014            .with_test_backoff_policy();
2015        let error = RPCError::RateLimited(None);
2016
2017        let backoff_error = http_client
2018            .handle_error_for_backoff(error)
2019            .await;
2020
2021        match backoff_error {
2022            backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2023                // This is expected - no retry-after still allows retries with default policy
2024            }
2025            _ => panic!("Expected transient error for RateLimited without retry-after"),
2026        }
2027    }
2028
2029    #[tokio::test]
2030    async fn test_handle_error_for_backoff_other_errors() {
2031        let http_client = HttpRPCClient::new("http://localhost:8080", None)
2032            .unwrap()
2033            .with_test_backoff_policy();
2034        let error = RPCError::ParseResponse("Invalid JSON".to_string());
2035
2036        let backoff_error = http_client
2037            .handle_error_for_backoff(error)
2038            .await;
2039
2040        match backoff_error {
2041            backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2042                assert_eq!(msg, "Invalid JSON");
2043            }
2044            _ => panic!("Expected permanent error for ParseResponse"),
2045        }
2046    }
2047
2048    #[tokio::test]
2049    async fn test_wait_until_retry_after_no_retry_time() {
2050        let http_client = HttpRPCClient::new("http://localhost:8080", None)
2051            .unwrap()
2052            .with_test_backoff_policy();
2053
2054        let start = std::time::Instant::now();
2055        http_client
2056            .wait_until_retry_after()
2057            .await;
2058        let elapsed = start.elapsed();
2059
2060        // Should return immediately if no retry time is set
2061        assert!(elapsed < Duration::from_millis(100));
2062    }
2063
2064    #[tokio::test]
2065    async fn test_wait_until_retry_after_past_time() {
2066        let http_client = HttpRPCClient::new("http://localhost:8080", None)
2067            .unwrap()
2068            .with_test_backoff_policy();
2069
2070        // Set a retry time in the past
2071        let past_time = SystemTime::now() - Duration::from_secs(10);
2072        *http_client.retry_after.write().await = Some(past_time);
2073
2074        let start = std::time::Instant::now();
2075        http_client
2076            .wait_until_retry_after()
2077            .await;
2078        let elapsed = start.elapsed();
2079
2080        // Should return immediately if retry time is in the past
2081        assert!(elapsed < Duration::from_millis(100));
2082    }
2083
2084    #[tokio::test]
2085    async fn test_wait_until_retry_after_future_time() {
2086        let http_client = HttpRPCClient::new("http://localhost:8080", None)
2087            .unwrap()
2088            .with_test_backoff_policy();
2089
2090        // Set a retry time 100ms in the future
2091        let future_time = SystemTime::now() + Duration::from_millis(100);
2092        *http_client.retry_after.write().await = Some(future_time);
2093
2094        let start = std::time::Instant::now();
2095        http_client
2096            .wait_until_retry_after()
2097            .await;
2098        let elapsed = start.elapsed();
2099
2100        // Should wait approximately the specified duration
2101        assert!(elapsed >= Duration::from_millis(80)); // Allow some tolerance
2102        assert!(elapsed <= Duration::from_millis(200)); // Upper bound for test stability
2103    }
2104
2105    #[tokio::test]
2106    async fn test_make_post_request_success() {
2107        let mut server = Server::new_async().await;
2108        let server_resp = r#"{"success": true}"#;
2109
2110        let mock = server
2111            .mock("POST", "/test")
2112            .with_status(200)
2113            .with_body(server_resp)
2114            .create_async()
2115            .await;
2116
2117        let http_client = HttpRPCClient::new(server.url().as_str(), None)
2118            .unwrap()
2119            .with_test_backoff_policy();
2120        let request_body = serde_json::json!({"test": "data"});
2121        let uri = format!("{}/test", server.url());
2122
2123        let result = http_client
2124            .make_post_request(&request_body, &uri)
2125            .await;
2126
2127        mock.assert();
2128        assert!(result.is_ok());
2129
2130        let response = result.unwrap();
2131        assert_eq!(response.status(), 200);
2132        assert_eq!(response.text().await.unwrap(), server_resp);
2133    }
2134
2135    #[tokio::test]
2136    async fn test_make_post_request_retry_on_server_error() {
2137        let mut server = Server::new_async().await;
2138        // First request fails with 503, second succeeds
2139        let error_mock = server
2140            .mock("POST", "/test")
2141            .with_status(503)
2142            .with_body("Service Unavailable")
2143            .expect(1)
2144            .create_async()
2145            .await;
2146
2147        let success_mock = server
2148            .mock("POST", "/test")
2149            .with_status(200)
2150            .with_body(r#"{"success": true}"#)
2151            .expect(1)
2152            .create_async()
2153            .await;
2154
2155        let http_client = HttpRPCClient::new(server.url().as_str(), None)
2156            .unwrap()
2157            .with_test_backoff_policy();
2158        let request_body = serde_json::json!({"test": "data"});
2159        let uri = format!("{}/test", server.url());
2160
2161        let result = http_client
2162            .make_post_request(&request_body, &uri)
2163            .await;
2164
2165        error_mock.assert();
2166        success_mock.assert();
2167        assert!(result.is_ok());
2168    }
2169
2170    #[tokio::test]
2171    async fn test_make_post_request_respect_retry_after_header() {
2172        let mut server = Server::new_async().await;
2173
2174        // First request returns 429 with retry-after, second succeeds
2175        let rate_limit_mock = server
2176            .mock("POST", "/test")
2177            .with_status(429)
2178            .with_header("Retry-After", "1") // 1 second
2179            .expect(1)
2180            .create_async()
2181            .await;
2182
2183        let success_mock = server
2184            .mock("POST", "/test")
2185            .with_status(200)
2186            .with_body(r#"{"success": true}"#)
2187            .expect(1)
2188            .create_async()
2189            .await;
2190
2191        let http_client = HttpRPCClient::new(server.url().as_str(), None)
2192            .unwrap()
2193            .with_test_backoff_policy();
2194        let request_body = serde_json::json!({"test": "data"});
2195        let uri = format!("{}/test", server.url());
2196
2197        let start = std::time::Instant::now();
2198        let result = http_client
2199            .make_post_request(&request_body, &uri)
2200            .await;
2201        let elapsed = start.elapsed();
2202
2203        rate_limit_mock.assert();
2204        success_mock.assert();
2205        assert!(result.is_ok());
2206
2207        // Should have waited at least 1 second due to retry-after header
2208        assert!(elapsed >= Duration::from_millis(900)); // Allow some tolerance
2209        assert!(elapsed <= Duration::from_millis(2000)); // Upper bound for test stability
2210    }
2211
2212    #[tokio::test]
2213    async fn test_make_post_request_permanent_error() {
2214        let mut server = Server::new_async().await;
2215
2216        let mock = server
2217            .mock("POST", "/test")
2218            .with_status(400) // Bad Request - should not be retried
2219            .with_body("Bad Request")
2220            .expect(1)
2221            .create_async()
2222            .await;
2223
2224        let http_client = HttpRPCClient::new(server.url().as_str(), None)
2225            .unwrap()
2226            .with_test_backoff_policy();
2227        let request_body = serde_json::json!({"test": "data"});
2228        let uri = format!("{}/test", server.url());
2229
2230        let result = http_client
2231            .make_post_request(&request_body, &uri)
2232            .await;
2233
2234        mock.assert();
2235        assert!(result.is_ok()); // 400 doesn't trigger retry logic, just returns the response
2236
2237        let response = result.unwrap();
2238        assert_eq!(response.status(), 400);
2239    }
2240
2241    #[tokio::test]
2242    async fn test_concurrent_requests_with_different_retry_after() {
2243        let mut server = Server::new_async().await;
2244
2245        // First request gets rate limited with 1 second retry-after
2246        let rate_limit_mock_1 = server
2247            .mock("POST", "/test1")
2248            .with_status(429)
2249            .with_header("Retry-After", "1")
2250            .expect(1)
2251            .create_async()
2252            .await;
2253
2254        // Second request gets rate limited with 2 second retry-after
2255        let rate_limit_mock_2 = server
2256            .mock("POST", "/test2")
2257            .with_status(429)
2258            .with_header("Retry-After", "2")
2259            .expect(1)
2260            .create_async()
2261            .await;
2262
2263        // Success mocks for retries
2264        let success_mock_1 = server
2265            .mock("POST", "/test1")
2266            .with_status(200)
2267            .with_body(r#"{"result": "success1"}"#)
2268            .expect(1)
2269            .create_async()
2270            .await;
2271
2272        let success_mock_2 = server
2273            .mock("POST", "/test2")
2274            .with_status(200)
2275            .with_body(r#"{"result": "success2"}"#)
2276            .expect(1)
2277            .create_async()
2278            .await;
2279
2280        let http_client = HttpRPCClient::new(server.url().as_str(), None)
2281            .unwrap()
2282            .with_test_backoff_policy();
2283        let request_body = serde_json::json!({"test": "data"});
2284
2285        let uri1 = format!("{}/test1", server.url());
2286        let uri2 = format!("{}/test2", server.url());
2287
2288        // Start both requests concurrently
2289        let start = std::time::Instant::now();
2290        let (result1, result2) = tokio::join!(
2291            http_client.make_post_request(&request_body, &uri1),
2292            http_client.make_post_request(&request_body, &uri2)
2293        );
2294        let elapsed = start.elapsed();
2295
2296        rate_limit_mock_1.assert();
2297        rate_limit_mock_2.assert();
2298        success_mock_1.assert();
2299        success_mock_2.assert();
2300
2301        assert!(result1.is_ok());
2302        assert!(result2.is_ok());
2303
2304        // Both requests should succeed, but the second should take longer due to the 2s retry-after
2305        // The total time should be at least 2 seconds since the shared retry_after state
2306        // gets updated by both requests
2307        assert!(elapsed >= Duration::from_millis(1800)); // Allow some tolerance
2308        assert!(elapsed <= Duration::from_millis(3000)); // Upper bound
2309
2310        // Check the final retry_after state - should be the latest (higher) value
2311        let final_retry_after = http_client.retry_after.read().await;
2312        assert!(final_retry_after.is_some());
2313
2314        // The retry_after should be set to the latest (higher) value from the two requests
2315        if let Some(retry_time) = *final_retry_after {
2316            // The retry_after time might be in the past now since we waited,
2317            // but it should be reasonable (not too far in past/future)
2318            let now = SystemTime::now();
2319            let diff = if retry_time > now {
2320                retry_time.duration_since(now).unwrap()
2321            } else {
2322                now.duration_since(retry_time).unwrap()
2323            };
2324
2325            // Should be within a reasonable range (the 2s retry-after plus some buffer)
2326            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2327        }
2328    }
2329
2330    #[tokio::test]
2331    async fn test_get_snapshots() {
2332        let mut server = Server::new_async().await;
2333
2334        // Mock protocol states response
2335        let protocol_states_resp = r#"
2336        {
2337            "states": [
2338                {
2339                    "component_id": "component1",
2340                    "attributes": {
2341                        "attribute_1": "0x00000000000003e8"
2342                    },
2343                    "balances": {
2344                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2345                    }
2346                }
2347            ],
2348            "pagination": {
2349                "page": 0,
2350                "page_size": 100,
2351                "total": 1
2352            }
2353        }
2354        "#;
2355
2356        // Mock contract state response
2357        let contract_state_resp = r#"
2358        {
2359            "accounts": [
2360                {
2361                    "chain": "ethereum",
2362                    "address": "0x1111111111111111111111111111111111111111",
2363                    "title": "",
2364                    "slots": {},
2365                    "native_balance": "0x01f4",
2366                    "token_balances": {},
2367                    "code": "0x00",
2368                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2369                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2370                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2371                    "creation_tx": null
2372                }
2373            ],
2374            "pagination": {
2375                "page": 0,
2376                "page_size": 100,
2377                "total": 1
2378            }
2379        }
2380        "#;
2381
2382        // Mock component TVL response
2383        let tvl_resp = r#"
2384        {
2385            "tvl": {
2386                "component1": 1000000.0
2387            },
2388            "pagination": {
2389                "page": 0,
2390                "page_size": 100,
2391                "total": 1
2392            }
2393        }
2394        "#;
2395
2396        let protocol_states_mock = server
2397            .mock("POST", "/v1/protocol_state")
2398            .expect(1)
2399            .with_body(protocol_states_resp)
2400            .create_async()
2401            .await;
2402
2403        let contract_state_mock = server
2404            .mock("POST", "/v1/contract_state")
2405            .expect(1)
2406            .with_body(contract_state_resp)
2407            .create_async()
2408            .await;
2409
2410        let tvl_mock = server
2411            .mock("POST", "/v1/component_tvl")
2412            .expect(1)
2413            .with_body(tvl_resp)
2414            .create_async()
2415            .await;
2416
2417        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2418
2419        #[allow(deprecated)]
2420        let component = tycho_common::dto::ProtocolComponent {
2421            id: "component1".to_string(),
2422            protocol_system: "test_protocol".to_string(),
2423            protocol_type_name: "test_type".to_string(),
2424            chain: Chain::Ethereum,
2425            tokens: vec![],
2426            contract_ids: vec![
2427                Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2428            ],
2429            static_attributes: HashMap::new(),
2430            change: tycho_common::dto::ChangeType::Creation,
2431            creation_tx: Bytes::from_str(
2432                "0x0000000000000000000000000000000000000000000000000000000000000000",
2433            )
2434            .unwrap(),
2435            created_at: chrono::Utc::now().naive_utc(),
2436        };
2437
2438        let mut components = HashMap::new();
2439        components.insert("component1".to_string(), component);
2440
2441        let contract_ids =
2442            vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2443
2444        let request = SnapshotParameters::new(
2445            Chain::Ethereum,
2446            "test_protocol",
2447            &components,
2448            &contract_ids,
2449            12345,
2450        );
2451
2452        let response = client
2453            .get_snapshots(&request, 100, 4)
2454            .await
2455            .expect("get snapshots");
2456
2457        // Verify all mocks were called
2458        protocol_states_mock.assert();
2459        contract_state_mock.assert();
2460        tvl_mock.assert();
2461
2462        // Assert states
2463        assert_eq!(response.states.len(), 1);
2464        assert!(response
2465            .states
2466            .contains_key("component1"));
2467
2468        // Check that the state has the expected TVL
2469        let component_state = response
2470            .states
2471            .get("component1")
2472            .unwrap();
2473        assert_eq!(component_state.component_tvl, Some(1000000.0));
2474
2475        // Assert VM storage
2476        assert_eq!(response.vm_storage.len(), 1);
2477        let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2478        assert!(response
2479            .vm_storage
2480            .contains_key(&contract_addr));
2481    }
2482
2483    #[tokio::test]
2484    async fn test_get_snapshots_empty_components() {
2485        let server = Server::new_async().await;
2486        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2487
2488        let components = HashMap::new();
2489        let contract_ids = vec![];
2490
2491        let request = SnapshotParameters::new(
2492            Chain::Ethereum,
2493            "test_protocol",
2494            &components,
2495            &contract_ids,
2496            12345,
2497        );
2498
2499        let response = client
2500            .get_snapshots(&request, 100, 4)
2501            .await
2502            .expect("get snapshots");
2503
2504        // Should return empty response without making any requests
2505        assert!(response.states.is_empty());
2506        assert!(response.vm_storage.is_empty());
2507    }
2508
2509    #[tokio::test]
2510    async fn test_get_snapshots_without_tvl() {
2511        let mut server = Server::new_async().await;
2512
2513        let protocol_states_resp = r#"
2514        {
2515            "states": [
2516                {
2517                    "component_id": "component1",
2518                    "attributes": {},
2519                    "balances": {}
2520                }
2521            ],
2522            "pagination": {
2523                "page": 0,
2524                "page_size": 100,
2525                "total": 1
2526            }
2527        }
2528        "#;
2529
2530        let protocol_states_mock = server
2531            .mock("POST", "/v1/protocol_state")
2532            .expect(1)
2533            .with_body(protocol_states_resp)
2534            .create_async()
2535            .await;
2536
2537        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");
2538
2539        // Create test component
2540        #[allow(deprecated)]
2541        let component = tycho_common::dto::ProtocolComponent {
2542            id: "component1".to_string(),
2543            protocol_system: "test_protocol".to_string(),
2544            protocol_type_name: "test_type".to_string(),
2545            chain: Chain::Ethereum,
2546            tokens: vec![],
2547            contract_ids: vec![],
2548            static_attributes: HashMap::new(),
2549            change: tycho_common::dto::ChangeType::Creation,
2550            creation_tx: Bytes::from_str(
2551                "0x0000000000000000000000000000000000000000000000000000000000000000",
2552            )
2553            .unwrap(),
2554            created_at: chrono::Utc::now().naive_utc(),
2555        };
2556
2557        let mut components = HashMap::new();
2558        components.insert("component1".to_string(), component);
2559        let contract_ids = vec![];
2560
2561        let request = SnapshotParameters::new(
2562            Chain::Ethereum,
2563            "test_protocol",
2564            &components,
2565            &contract_ids,
2566            12345,
2567        )
2568        .include_balances(false)
2569        .include_tvl(false);
2570
2571        let response = client
2572            .get_snapshots(&request, 100, 4)
2573            .await
2574            .expect("get snapshots");
2575
2576        // Verify only necessary mocks were called
2577        protocol_states_mock.assert();
2578        // No contract_state_mock.assert() since contract_ids is empty
2579        // No tvl_mock.assert() since include_tvl is false
2580
2581        assert_eq!(response.states.len(), 1);
2582        // Check that TVL is None since we didn't request it
2583        let component_state = response
2584            .states
2585            .get("component1")
2586            .unwrap();
2587        assert_eq!(component_state.component_tvl, None);
2588    }
2589}