tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
3//! The objective of this module is to provide swift and simplified access to the Remote Procedure
4//! Call (RPC) endpoints of Tycho. These endpoints are chiefly responsible for facilitating data
5//! queries, especially querying snapshots of data.
6use std::{
7    collections::HashMap,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22    sync::{RwLock, Semaphore},
23    time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27    dto::{
28        BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29        EntryPointWithTracingParams, PaginationParams, PaginationResponse, ProtocolComponent,
30        ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
31        ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
32        ResponseToken, StateRequestBody, StateRequestResponse, TokensRequestBody,
33        TokensRequestResponse, TracedEntryPointRequestBody, TracedEntryPointRequestResponse,
34        TracingResult, VersionParam,
35    },
36    models::ComponentId,
37    Bytes,
38};
39
40use crate::{
41    feed::synchronizer::{ComponentWithState, Snapshot},
42    TYCHO_SERVER_VERSION,
43};
44
/// Parameters describing a snapshot of protocol states and VM storage to fetch.
///
/// This struct is passed to [`RPCClient::get_snapshots`] and bundles the inputs needed to
/// fetch multiple pieces of related data (protocol states, contract storage, TVL, entry
/// points) together. It borrows the component map, contract ids and entry points rather
/// than owning them.
#[derive(Clone, Debug, PartialEq)]
pub struct SnapshotParameters<'a> {
    /// Which chain to fetch snapshots for
    pub chain: Chain,
    /// Protocol system name, required for correct state resolution
    pub protocol_system: &'a str,
    /// Components to fetch protocol states for, keyed by component id
    pub components: &'a HashMap<ComponentId, ProtocolComponent>,
    /// Traced entry points data mapped by component id. `SnapshotParameters::new` leaves
    /// this as `None`; attach data with `SnapshotParameters::entrypoints`.
    pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
    /// Contract addresses to fetch VM storage for
    pub contract_ids: &'a [Bytes],
    /// Block number for versioning
    pub block_number: u64,
    /// Whether to include balance information (`true` when built with `new`)
    pub include_balances: bool,
    /// Whether to fetch TVL data (`true` when built with `new`)
    pub include_tvl: bool,
}
68
69impl<'a> SnapshotParameters<'a> {
70    pub fn new(
71        chain: Chain,
72        protocol_system: &'a str,
73        components: &'a HashMap<ComponentId, ProtocolComponent>,
74        contract_ids: &'a [Bytes],
75        block_number: u64,
76    ) -> Self {
77        Self {
78            chain,
79            protocol_system,
80            components,
81            entrypoints: None,
82            contract_ids,
83            block_number,
84            include_balances: true,
85            include_tvl: true,
86        }
87    }
88
89    /// Set whether to include balance information (default: true)
90    pub fn include_balances(mut self, include_balances: bool) -> Self {
91        self.include_balances = include_balances;
92        self
93    }
94
95    /// Set whether to fetch TVL data (default: true)
96    pub fn include_tvl(mut self, include_tvl: bool) -> Self {
97        self.include_tvl = include_tvl;
98        self
99    }
100
101    pub fn entrypoints(
102        mut self,
103        entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
104    ) -> Self {
105        self.entrypoints = Some(entrypoints);
106        self
107    }
108}
109
/// Errors produced by the Tycho RPC client.
#[derive(Error, Debug)]
pub enum RPCError {
    /// The passed tycho url failed to parse.
    #[error("Failed to parse URL: {0}. Error: {1}")]
    UrlParsing(String, String),

    /// The request data is not correctly formed.
    #[error("Failed to format request: {0}")]
    FormatRequest(String),

    /// Errors forwarded from the HTTP protocol.
    #[error("Unexpected HTTP client error: {0}")]
    HttpClient(String, #[source] reqwest::Error),

    /// The response from the server could not be parsed correctly.
    #[error("Failed to parse response: {0}")]
    ParseResponse(String),

    /// Other fatal errors.
    #[error("Fatal error: {0}")]
    Fatal(String),

    /// The server rate-limited the request (e.g. an HTTP 429 response). Carries the time
    /// until which requests should be held back, or `None` if the server's `Retry-After`
    /// header was missing or could not be parsed.
    #[error("Rate limited until {0:?}")]
    RateLimited(Option<SystemTime>),

    /// The server is unreachable or temporarily unavailable (e.g. an HTTP 502, 503 or 504
    /// response). Carries a description such as the response body.
    #[error("Server unreachable: {0}")]
    ServerUnreachable(String),
}
138
139#[cfg_attr(test, automock)]
140#[async_trait]
141pub trait RPCClient: Send + Sync {
142    /// Retrieves a snapshot of contract state.
143    async fn get_contract_state(
144        &self,
145        request: &StateRequestBody,
146    ) -> Result<StateRequestResponse, RPCError>;
147
148    async fn get_contract_state_paginated(
149        &self,
150        chain: Chain,
151        ids: &[Bytes],
152        protocol_system: &str,
153        version: &VersionParam,
154        chunk_size: usize,
155        concurrency: usize,
156    ) -> Result<StateRequestResponse, RPCError> {
157        let semaphore = Arc::new(Semaphore::new(concurrency));
158
159        // Sort the ids to maximize server-side cache hits
160        let mut sorted_ids = ids.to_vec();
161        sorted_ids.sort();
162
163        let chunked_bodies = sorted_ids
164            .chunks(chunk_size)
165            .map(|chunk| StateRequestBody {
166                contract_ids: Some(chunk.to_vec()),
167                protocol_system: protocol_system.to_string(),
168                chain,
169                version: version.clone(),
170                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
171            })
172            .collect::<Vec<_>>();
173
174        let mut tasks = Vec::new();
175        for body in chunked_bodies.iter() {
176            let sem = semaphore.clone();
177            tasks.push(async move {
178                let _permit = sem
179                    .acquire()
180                    .await
181                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
182                self.get_contract_state(body).await
183            });
184        }
185
186        // Execute all tasks concurrently with the defined concurrency limit.
187        let responses = try_join_all(tasks).await?;
188
189        // Aggregate the responses into a single result.
190        let accounts = responses
191            .iter()
192            .flat_map(|r| r.accounts.clone())
193            .collect();
194        let total: i64 = responses
195            .iter()
196            .map(|r| r.pagination.total)
197            .sum();
198
199        Ok(StateRequestResponse {
200            accounts,
201            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
202        })
203    }
204
205    async fn get_protocol_components(
206        &self,
207        request: &ProtocolComponentsRequestBody,
208    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
209
210    async fn get_protocol_components_paginated(
211        &self,
212        request: &ProtocolComponentsRequestBody,
213        chunk_size: usize,
214        concurrency: usize,
215    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
216        let semaphore = Arc::new(Semaphore::new(concurrency));
217
218        // If a set of component IDs is specified, the maximum return size is already known,
219        // allowing us to pre-compute the number of requests to be made.
220        match request.component_ids {
221            Some(ref ids) => {
222                // We can divide the component_ids into chunks of size chunk_size
223                let chunked_bodies = ids
224                    .chunks(chunk_size)
225                    .enumerate()
226                    .map(|(index, _)| ProtocolComponentsRequestBody {
227                        protocol_system: request.protocol_system.clone(),
228                        component_ids: request.component_ids.clone(),
229                        tvl_gt: request.tvl_gt,
230                        chain: request.chain,
231                        pagination: PaginationParams {
232                            page: index as i64,
233                            page_size: chunk_size as i64,
234                        },
235                    })
236                    .collect::<Vec<_>>();
237
238                let mut tasks = Vec::new();
239                for body in chunked_bodies.iter() {
240                    let sem = semaphore.clone();
241                    tasks.push(async move {
242                        let _permit = sem
243                            .acquire()
244                            .await
245                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
246                        self.get_protocol_components(body).await
247                    });
248                }
249
250                try_join_all(tasks)
251                    .await
252                    .map(|responses| ProtocolComponentRequestResponse {
253                        protocol_components: responses
254                            .into_iter()
255                            .flat_map(|r| r.protocol_components.into_iter())
256                            .collect(),
257                        pagination: PaginationResponse {
258                            page: 0,
259                            page_size: chunk_size as i64,
260                            total: ids.len() as i64,
261                        },
262                    })
263            }
264            _ => {
265                // If no component ids are specified, we need to make requests based on the total
266                // number of results from the first response.
267
268                let initial_request = ProtocolComponentsRequestBody {
269                    protocol_system: request.protocol_system.clone(),
270                    component_ids: request.component_ids.clone(),
271                    tvl_gt: request.tvl_gt,
272                    chain: request.chain,
273                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
274                };
275                let first_response = self
276                    .get_protocol_components(&initial_request)
277                    .await
278                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
279
280                let total_items = first_response.pagination.total;
281                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
282
283                // Initialize the final response accumulator
284                let mut accumulated_response = ProtocolComponentRequestResponse {
285                    protocol_components: first_response.protocol_components,
286                    pagination: PaginationResponse {
287                        page: 0,
288                        page_size: chunk_size as i64,
289                        total: total_items,
290                    },
291                };
292
293                let mut page = 1;
294                while page < total_pages {
295                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
296
297                    // Create request bodies for parallel requests, respecting the concurrency limit
298                    let chunked_bodies = (0..requests_in_this_iteration)
299                        .map(|iter| ProtocolComponentsRequestBody {
300                            protocol_system: request.protocol_system.clone(),
301                            component_ids: request.component_ids.clone(),
302                            tvl_gt: request.tvl_gt,
303                            chain: request.chain,
304                            pagination: PaginationParams {
305                                page: page + iter,
306                                page_size: chunk_size as i64,
307                            },
308                        })
309                        .collect::<Vec<_>>();
310
311                    let tasks: Vec<_> = chunked_bodies
312                        .iter()
313                        .map(|body| {
314                            let sem = semaphore.clone();
315                            async move {
316                                let _permit = sem.acquire().await.map_err(|_| {
317                                    RPCError::Fatal("Semaphore dropped".to_string())
318                                })?;
319                                self.get_protocol_components(body).await
320                            }
321                        })
322                        .collect();
323
324                    let responses = try_join_all(tasks)
325                        .await
326                        .map(|responses| {
327                            let total = responses[0].pagination.total;
328                            ProtocolComponentRequestResponse {
329                                protocol_components: responses
330                                    .into_iter()
331                                    .flat_map(|r| r.protocol_components.into_iter())
332                                    .collect(),
333                                pagination: PaginationResponse {
334                                    page,
335                                    page_size: chunk_size as i64,
336                                    total,
337                                },
338                            }
339                        });
340
341                    // Update the accumulated response or set the initial response
342                    match responses {
343                        Ok(mut resp) => {
344                            accumulated_response
345                                .protocol_components
346                                .append(&mut resp.protocol_components);
347                        }
348                        Err(e) => return Err(e),
349                    }
350
351                    page += concurrency as i64;
352                }
353                Ok(accumulated_response)
354            }
355        }
356    }
357
358    async fn get_protocol_states(
359        &self,
360        request: &ProtocolStateRequestBody,
361    ) -> Result<ProtocolStateRequestResponse, RPCError>;
362
363    #[allow(clippy::too_many_arguments)]
364    async fn get_protocol_states_paginated<T>(
365        &self,
366        chain: Chain,
367        ids: &[T],
368        protocol_system: &str,
369        include_balances: bool,
370        version: &VersionParam,
371        chunk_size: usize,
372        concurrency: usize,
373    ) -> Result<ProtocolStateRequestResponse, RPCError>
374    where
375        T: AsRef<str> + Sync + 'static,
376    {
377        let semaphore = Arc::new(Semaphore::new(concurrency));
378        let chunked_bodies = ids
379            .chunks(chunk_size)
380            .map(|c| ProtocolStateRequestBody {
381                protocol_ids: Some(
382                    c.iter()
383                        .map(|id| id.as_ref().to_string())
384                        .collect(),
385                ),
386                protocol_system: protocol_system.to_string(),
387                chain,
388                include_balances,
389                version: version.clone(),
390                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
391            })
392            .collect::<Vec<_>>();
393
394        let mut tasks = Vec::new();
395        for body in chunked_bodies.iter() {
396            let sem = semaphore.clone();
397            tasks.push(async move {
398                let _permit = sem
399                    .acquire()
400                    .await
401                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
402                self.get_protocol_states(body).await
403            });
404        }
405
406        try_join_all(tasks)
407            .await
408            .map(|responses| {
409                let states = responses
410                    .clone()
411                    .into_iter()
412                    .flat_map(|r| r.states)
413                    .collect();
414                let total = responses
415                    .iter()
416                    .map(|r| r.pagination.total)
417                    .sum();
418                ProtocolStateRequestResponse {
419                    states,
420                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
421                }
422            })
423    }
424
425    /// This function returns only one chunk of tokens. To get all tokens please call
426    /// get_all_tokens.
427    async fn get_tokens(
428        &self,
429        request: &TokensRequestBody,
430    ) -> Result<TokensRequestResponse, RPCError>;
431
432    async fn get_all_tokens(
433        &self,
434        chain: Chain,
435        min_quality: Option<i32>,
436        traded_n_days_ago: Option<u64>,
437        chunk_size: usize,
438    ) -> Result<Vec<ResponseToken>, RPCError> {
439        let mut request_page = 0;
440        let mut all_tokens = Vec::new();
441        loop {
442            let mut response = self
443                .get_tokens(&TokensRequestBody {
444                    token_addresses: None,
445                    min_quality,
446                    traded_n_days_ago,
447                    pagination: PaginationParams {
448                        page: request_page,
449                        page_size: chunk_size.try_into().map_err(|_| {
450                            RPCError::FormatRequest(
451                                "Failed to convert chunk_size into i64".to_string(),
452                            )
453                        })?,
454                    },
455                    chain,
456                })
457                .await?;
458
459            let num_tokens = response.tokens.len();
460            all_tokens.append(&mut response.tokens);
461            request_page += 1;
462
463            if num_tokens < chunk_size {
464                break;
465            }
466        }
467        Ok(all_tokens)
468    }
469
470    async fn get_protocol_systems(
471        &self,
472        request: &ProtocolSystemsRequestBody,
473    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
474
475    async fn get_component_tvl(
476        &self,
477        request: &ComponentTvlRequestBody,
478    ) -> Result<ComponentTvlRequestResponse, RPCError>;
479
480    async fn get_component_tvl_paginated(
481        &self,
482        request: &ComponentTvlRequestBody,
483        chunk_size: usize,
484        concurrency: usize,
485    ) -> Result<ComponentTvlRequestResponse, RPCError> {
486        let semaphore = Arc::new(Semaphore::new(concurrency));
487
488        match request.component_ids {
489            Some(ref ids) => {
490                let chunked_requests = ids
491                    .chunks(chunk_size)
492                    .enumerate()
493                    .map(|(index, _)| ComponentTvlRequestBody {
494                        chain: request.chain,
495                        protocol_system: request.protocol_system.clone(),
496                        component_ids: Some(ids.clone()),
497                        pagination: PaginationParams {
498                            page: index as i64,
499                            page_size: chunk_size as i64,
500                        },
501                    })
502                    .collect::<Vec<_>>();
503
504                let tasks: Vec<_> = chunked_requests
505                    .into_iter()
506                    .map(|req| {
507                        let sem = semaphore.clone();
508                        async move {
509                            let _permit = sem
510                                .acquire()
511                                .await
512                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
513                            self.get_component_tvl(&req).await
514                        }
515                    })
516                    .collect();
517
518                let responses = try_join_all(tasks).await?;
519
520                let mut merged_tvl = HashMap::new();
521                for resp in responses {
522                    for (key, value) in resp.tvl {
523                        *merged_tvl.entry(key).or_insert(0.0) = value;
524                    }
525                }
526
527                Ok(ComponentTvlRequestResponse {
528                    tvl: merged_tvl,
529                    pagination: PaginationResponse {
530                        page: 0,
531                        page_size: chunk_size as i64,
532                        total: ids.len() as i64,
533                    },
534                })
535            }
536            _ => {
537                let first_request = ComponentTvlRequestBody {
538                    chain: request.chain,
539                    protocol_system: request.protocol_system.clone(),
540                    component_ids: request.component_ids.clone(),
541                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
542                };
543
544                let first_response = self
545                    .get_component_tvl(&first_request)
546                    .await?;
547                let total_items = first_response.pagination.total;
548                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
549
550                let mut merged_tvl = first_response.tvl;
551
552                let mut page = 1;
553                while page < total_pages {
554                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
555
556                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
557                        .map(|i| ComponentTvlRequestBody {
558                            chain: request.chain,
559                            protocol_system: request.protocol_system.clone(),
560                            component_ids: request.component_ids.clone(),
561                            pagination: PaginationParams {
562                                page: page + i,
563                                page_size: chunk_size as i64,
564                            },
565                        })
566                        .collect();
567
568                    let tasks: Vec<_> = chunked_requests
569                        .into_iter()
570                        .map(|req| {
571                            let sem = semaphore.clone();
572                            async move {
573                                let _permit = sem.acquire().await.map_err(|_| {
574                                    RPCError::Fatal("Semaphore dropped".to_string())
575                                })?;
576                                self.get_component_tvl(&req).await
577                            }
578                        })
579                        .collect();
580
581                    let responses = try_join_all(tasks).await?;
582
583                    // merge hashmap
584                    for resp in responses {
585                        for (key, value) in resp.tvl {
586                            *merged_tvl.entry(key).or_insert(0.0) += value;
587                        }
588                    }
589
590                    page += concurrency as i64;
591                }
592
593                Ok(ComponentTvlRequestResponse {
594                    tvl: merged_tvl,
595                    pagination: PaginationResponse {
596                        page: 0,
597                        page_size: chunk_size as i64,
598                        total: total_items,
599                    },
600                })
601            }
602        }
603    }
604
605    async fn get_traced_entry_points(
606        &self,
607        request: &TracedEntryPointRequestBody,
608    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
609
610    async fn get_traced_entry_points_paginated(
611        &self,
612        chain: Chain,
613        protocol_system: &str,
614        component_ids: &[String],
615        chunk_size: usize,
616        concurrency: usize,
617    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
618        let semaphore = Arc::new(Semaphore::new(concurrency));
619        let chunked_bodies = component_ids
620            .chunks(chunk_size)
621            .map(|c| TracedEntryPointRequestBody {
622                chain,
623                protocol_system: protocol_system.to_string(),
624                component_ids: Some(c.to_vec()),
625                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
626            })
627            .collect::<Vec<_>>();
628
629        let mut tasks = Vec::new();
630        for body in chunked_bodies.iter() {
631            let sem = semaphore.clone();
632            tasks.push(async move {
633                let _permit = sem
634                    .acquire()
635                    .await
636                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
637                self.get_traced_entry_points(body).await
638            });
639        }
640
641        try_join_all(tasks)
642            .await
643            .map(|responses| {
644                let traced_entry_points = responses
645                    .clone()
646                    .into_iter()
647                    .flat_map(|r| r.traced_entry_points)
648                    .collect();
649                let total = responses
650                    .iter()
651                    .map(|r| r.pagination.total)
652                    .sum();
653                TracedEntryPointRequestResponse {
654                    traced_entry_points,
655                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
656                }
657            })
658    }
659
660    async fn get_snapshots<'a>(
661        &self,
662        request: &SnapshotParameters<'a>,
663        chunk_size: usize,
664        concurrency: usize,
665    ) -> Result<Snapshot, RPCError>;
666}
667
/// Configuration options for `HttpRPCClient`.
#[derive(Clone)]
pub struct HttpRPCClientOptions {
    /// Optional API key, sent as the `Authorization` header on every request.
    pub auth_key: Option<String>,
    /// Whether to ask the server for compressed responses (default: true).
    /// When enabled, requests carry an `Accept-Encoding: zstd` header.
    pub compression: bool,
}

// Written by hand instead of derived so that `{:?}` (logs, tracing fields, panics) never
// prints the API key.
impl std::fmt::Debug for HttpRPCClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpRPCClientOptions")
            .field(
                "auth_key",
                &self
                    .auth_key
                    .as_ref()
                    .map(|_| "<redacted>"),
            )
            .field("compression", &self.compression)
            .finish()
    }
}
677
678impl Default for HttpRPCClientOptions {
679    fn default() -> Self {
680        Self::new()
681    }
682}
683
684impl HttpRPCClientOptions {
685    /// Create new options with default values (compression enabled)
686    pub fn new() -> Self {
687        Self { auth_key: None, compression: true }
688    }
689
690    /// Set the authentication key
691    pub fn with_auth_key(mut self, auth_key: Option<String>) -> Self {
692        self.auth_key = auth_key;
693        self
694    }
695
696    /// Set whether to enable compression (default: true)
697    pub fn with_compression(mut self, compression: bool) -> Self {
698        self.compression = compression;
699        self
700    }
701}
702
/// HTTP client for the Tycho RPC server.
///
/// Clones share the same rate-limit deadline, because `retry_after` sits behind an `Arc`.
#[derive(Debug, Clone)]
pub struct HttpRPCClient {
    /// Underlying HTTP client, preconfigured with the default headers built in `new`.
    http_client: Client,
    /// Base URL of the Tycho RPC server, parsed from the `base_uri` given to `new`.
    url: Url,
    /// Time until which requests should wait because the server rate-limited us;
    /// `None` when no rate limit is known.
    retry_after: Arc<RwLock<Option<SystemTime>>>,
    /// Exponential backoff policy used when retrying failed requests.
    backoff_policy: ExponentialBackoff,
    /// Delay before retrying after an `RPCError::ServerUnreachable` error.
    server_restart_duration: Duration,
}
711
712impl HttpRPCClient {
713    pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
714        let uri = base_uri
715            .parse::<Url>()
716            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
717
718        // Add default headers
719        let mut headers = header::HeaderMap::new();
720        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
721        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
722        headers.insert(
723            header::USER_AGENT,
724            header::HeaderValue::from_str(&user_agent)
725                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
726        );
727
728        // Add Accept-Encoding header when compression is enabled
729        // Note: reqwest with zstd feature will automatically decompress responses
730        if options.compression {
731            headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_static("zstd"));
732        }
733
734        // Add Authorization if one is given
735        if let Some(key) = options.auth_key.as_deref() {
736            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
737                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
738            })?;
739            auth_value.set_sensitive(true);
740            headers.insert(header::AUTHORIZATION, auth_value);
741        }
742
743        let client = ClientBuilder::new()
744            .default_headers(headers)
745            .http2_prior_knowledge()
746            .build()
747            .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
748        Ok(Self {
749            http_client: client,
750            url: uri,
751            retry_after: Arc::new(RwLock::new(None)),
752            backoff_policy: ExponentialBackoffBuilder::new()
753                .with_initial_interval(Duration::from_millis(250))
754                // increase backoff time by 75% each failure
755                .with_multiplier(1.75)
756                // keep retrying every 30s
757                .with_max_interval(Duration::from_secs(30))
758                // if all retries take longer than 2m, give up
759                .with_max_elapsed_time(Some(Duration::from_secs(125)))
760                .build(),
761            server_restart_duration: Duration::from_secs(120),
762        })
763    }
764
765    #[cfg(test)]
766    pub fn with_test_backoff_policy(mut self) -> Self {
767        // Extremely short intervals for very fast testing
768        self.backoff_policy = ExponentialBackoffBuilder::new()
769            .with_initial_interval(Duration::from_millis(1))
770            .with_multiplier(1.1)
771            .with_max_interval(Duration::from_millis(5))
772            .with_max_elapsed_time(Some(Duration::from_millis(50)))
773            .build();
774        self.server_restart_duration = Duration::from_millis(50);
775        self
776    }
777
778    /// Converts a error response to a Result.
779    ///
780    /// Raises an error if the response status code id 429, 502, 503 or 504. In the 429
781    /// case it will try to look for a retry-after header an parse it accordingly. The
782    /// parsed value is then passed as part of the error.
783    async fn error_for_response(
784        &self,
785        response: reqwest::Response,
786    ) -> Result<reqwest::Response, RPCError> {
787        match response.status() {
788            StatusCode::TOO_MANY_REQUESTS => {
789                let retry_after_raw = response
790                    .headers()
791                    .get(reqwest::header::RETRY_AFTER)
792                    .and_then(|h| h.to_str().ok())
793                    .and_then(parse_retry_value);
794
795                Err(RPCError::RateLimited(retry_after_raw))
796            }
797            StatusCode::BAD_GATEWAY |
798            StatusCode::SERVICE_UNAVAILABLE |
799            StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
800                response
801                    .text()
802                    .await
803                    .unwrap_or_else(|_| "Server Unreachable".to_string()),
804            )),
805            _ => Ok(response),
806        }
807    }
808
809    /// Classifies errors into transient or permanent ones.
810    ///
811    /// Transient errors are retried with a potential backoff, permanent ones are not.
812    /// If the error is RateLimited, this method will set the self.retry_after value so
813    /// future requests wait until the rate limit has been reset.
814    async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
815        match e {
816            RPCError::ServerUnreachable(_) => {
817                backoff::Error::retry_after(e, self.server_restart_duration)
818            }
819            RPCError::RateLimited(Some(until)) => {
820                let mut retry_after_guard = self.retry_after.write().await;
821                *retry_after_guard = Some(
822                    retry_after_guard
823                        .unwrap_or(until)
824                        .max(until),
825                );
826
827                if let Ok(duration) = until.duration_since(SystemTime::now()) {
828                    backoff::Error::retry_after(e, duration)
829                } else {
830                    e.into()
831                }
832            }
833            RPCError::RateLimited(None) => e.into(),
834            _ => backoff::Error::permanent(e),
835        }
836    }
837
838    /// Waits until the current rate limit time has passed.
839    ///
840    /// Only waits if there is a time and that time is in the future, else return
841    /// immediately.
842    async fn wait_until_retry_after(&self) {
843        if let Some(&until) = self.retry_after.read().await.as_ref() {
844            let now = SystemTime::now();
845            if until > now {
846                if let Ok(duration) = until.duration_since(now) {
847                    sleep(duration).await
848                }
849            }
850        }
851    }
852
853    /// Makes a post request handling transient failures.
854    ///
855    /// If a retry-after header is received it will be respected. Else the configured
856    /// backoff policy is used to deal with transient network or server errors.
857    async fn make_post_request<T: Serialize + ?Sized>(
858        &self,
859        request: &T,
860        uri: &String,
861    ) -> Result<Response, RPCError> {
862        self.wait_until_retry_after().await;
863        let response = backoff::future::retry(self.backoff_policy.clone(), || async {
864            let server_response = self
865                .http_client
866                .post(uri)
867                .json(request)
868                .send()
869                .await
870                .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
871
872            match self
873                .error_for_response(server_response)
874                .await
875            {
876                Ok(response) => Ok(response),
877                Err(e) => Err(self.handle_error_for_backoff(e).await),
878            }
879        })
880        .await?;
881        Ok(response)
882    }
883}
884
885fn parse_retry_value(val: &str) -> Option<SystemTime> {
886    if let Ok(secs) = val.parse::<u64>() {
887        return Some(SystemTime::now() + Duration::from_secs(secs));
888    }
889    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
890        return Some(date.into());
891    }
892    None
893}
894
895#[async_trait]
896impl RPCClient for HttpRPCClient {
897    #[instrument(skip(self, request))]
898    async fn get_contract_state(
899        &self,
900        request: &StateRequestBody,
901    ) -> Result<StateRequestResponse, RPCError> {
902        // Check if contract ids are specified
903        if request
904            .contract_ids
905            .as_ref()
906            .is_none_or(|ids| ids.is_empty())
907        {
908            warn!("No contract ids specified in request.");
909        }
910
911        let uri = format!(
912            "{}/{}/contract_state",
913            self.url
914                .to_string()
915                .trim_end_matches('/'),
916            TYCHO_SERVER_VERSION
917        );
918        debug!(%uri, "Sending contract_state request to Tycho server");
919        trace!(?request, "Sending request to Tycho server");
920        let response = self
921            .make_post_request(request, &uri)
922            .await?;
923        trace!(?response, "Received response from Tycho server");
924
925        let body = response
926            .text()
927            .await
928            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
929        if body.is_empty() {
930            // Pure native protocols will return empty contract states
931            return Ok(StateRequestResponse {
932                accounts: vec![],
933                pagination: PaginationResponse {
934                    page: request.pagination.page,
935                    page_size: request.pagination.page,
936                    total: 0,
937                },
938            });
939        }
940
941        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
942            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
943        trace!(?accounts, "Received contract_state response from Tycho server");
944
945        Ok(accounts)
946    }
947
948    async fn get_protocol_components(
949        &self,
950        request: &ProtocolComponentsRequestBody,
951    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
952        let uri = format!(
953            "{}/{}/protocol_components",
954            self.url
955                .to_string()
956                .trim_end_matches('/'),
957            TYCHO_SERVER_VERSION,
958        );
959        debug!(%uri, "Sending protocol_components request to Tycho server");
960        trace!(?request, "Sending request to Tycho server");
961
962        let response = self
963            .make_post_request(request, &uri)
964            .await?;
965
966        trace!(?response, "Received response from Tycho server");
967
968        let body = response
969            .text()
970            .await
971            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
972        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
973            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
974        trace!(?components, "Received protocol_components response from Tycho server");
975
976        Ok(components)
977    }
978
979    async fn get_protocol_states(
980        &self,
981        request: &ProtocolStateRequestBody,
982    ) -> Result<ProtocolStateRequestResponse, RPCError> {
983        // Check if protocol ids are specified
984        if request
985            .protocol_ids
986            .as_ref()
987            .is_none_or(|ids| ids.is_empty())
988        {
989            warn!("No protocol ids specified in request.");
990        }
991
992        let uri = format!(
993            "{}/{}/protocol_state",
994            self.url
995                .to_string()
996                .trim_end_matches('/'),
997            TYCHO_SERVER_VERSION
998        );
999        debug!(%uri, "Sending protocol_states request to Tycho server");
1000        trace!(?request, "Sending request to Tycho server");
1001
1002        let response = self
1003            .make_post_request(request, &uri)
1004            .await?;
1005        trace!(?response, "Received response from Tycho server");
1006
1007        let body = response
1008            .text()
1009            .await
1010            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011
1012        if body.is_empty() {
1013            // Pure VM protocols will return empty states
1014            return Ok(ProtocolStateRequestResponse {
1015                states: vec![],
1016                pagination: PaginationResponse {
1017                    page: request.pagination.page,
1018                    page_size: request.pagination.page_size,
1019                    total: 0,
1020                },
1021            });
1022        }
1023
1024        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1025            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1026        trace!(?states, "Received protocol_states response from Tycho server");
1027
1028        Ok(states)
1029    }
1030
1031    async fn get_tokens(
1032        &self,
1033        request: &TokensRequestBody,
1034    ) -> Result<TokensRequestResponse, RPCError> {
1035        let uri = format!(
1036            "{}/{}/tokens",
1037            self.url
1038                .to_string()
1039                .trim_end_matches('/'),
1040            TYCHO_SERVER_VERSION
1041        );
1042        debug!(%uri, "Sending tokens request to Tycho server");
1043
1044        let response = self
1045            .make_post_request(request, &uri)
1046            .await?;
1047
1048        let body = response
1049            .text()
1050            .await
1051            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1052        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1053            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1054
1055        Ok(tokens)
1056    }
1057
1058    async fn get_protocol_systems(
1059        &self,
1060        request: &ProtocolSystemsRequestBody,
1061    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1062        let uri = format!(
1063            "{}/{}/protocol_systems",
1064            self.url
1065                .to_string()
1066                .trim_end_matches('/'),
1067            TYCHO_SERVER_VERSION
1068        );
1069        debug!(%uri, "Sending protocol_systems request to Tycho server");
1070        trace!(?request, "Sending request to Tycho server");
1071        let response = self
1072            .make_post_request(request, &uri)
1073            .await?;
1074        trace!(?response, "Received response from Tycho server");
1075        let body = response
1076            .text()
1077            .await
1078            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1079        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1080            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1081        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1082        Ok(protocol_systems)
1083    }
1084
1085    async fn get_component_tvl(
1086        &self,
1087        request: &ComponentTvlRequestBody,
1088    ) -> Result<ComponentTvlRequestResponse, RPCError> {
1089        let uri = format!(
1090            "{}/{}/component_tvl",
1091            self.url
1092                .to_string()
1093                .trim_end_matches('/'),
1094            TYCHO_SERVER_VERSION
1095        );
1096        debug!(%uri, "Sending get_component_tvl request to Tycho server");
1097        trace!(?request, "Sending request to Tycho server");
1098        let response = self
1099            .make_post_request(request, &uri)
1100            .await?;
1101        trace!(?response, "Received response from Tycho server");
1102        let body = response
1103            .text()
1104            .await
1105            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1106        let component_tvl =
1107            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1108                error!("Failed to parse component_tvl response: {:?}", &body);
1109                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1110            })?;
1111        trace!(?component_tvl, "Received component_tvl response from Tycho server");
1112        Ok(component_tvl)
1113    }
1114
1115    async fn get_traced_entry_points(
1116        &self,
1117        request: &TracedEntryPointRequestBody,
1118    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1119        let uri = format!(
1120            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1121            self.url
1122                .to_string()
1123                .trim_end_matches('/')
1124        );
1125        debug!(%uri, "Sending traced_entry_points request to Tycho server");
1126        trace!(?request, "Sending request to Tycho server");
1127
1128        let response = self
1129            .make_post_request(request, &uri)
1130            .await?;
1131
1132        trace!(?response, "Received response from Tycho server");
1133
1134        let body = response
1135            .text()
1136            .await
1137            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1138        let entrypoints =
1139            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1140                error!("Failed to parse traced_entry_points response: {:?}", &body);
1141                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1142            })?;
1143        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1144        Ok(entrypoints)
1145    }
1146
1147    async fn get_snapshots<'a>(
1148        &self,
1149        request: &SnapshotParameters<'a>,
1150        chunk_size: usize,
1151        concurrency: usize,
1152    ) -> Result<Snapshot, RPCError> {
1153        let component_ids: Vec<_> = request
1154            .components
1155            .keys()
1156            .cloned()
1157            .collect();
1158
1159        let version = VersionParam::new(
1160            None,
1161            Some({
1162                #[allow(deprecated)]
1163                BlockParam {
1164                    hash: None,
1165                    chain: Some(request.chain),
1166                    number: Some(request.block_number as i64),
1167                }
1168            }),
1169        );
1170
1171        let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1172            let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1173            self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1174                .await?
1175                .tvl
1176        } else {
1177            HashMap::new()
1178        };
1179
1180        let mut protocol_states = if !component_ids.is_empty() {
1181            self.get_protocol_states_paginated(
1182                request.chain,
1183                &component_ids,
1184                request.protocol_system,
1185                request.include_balances,
1186                &version,
1187                chunk_size,
1188                concurrency,
1189            )
1190            .await?
1191            .states
1192            .into_iter()
1193            .map(|state| (state.component_id.clone(), state))
1194            .collect()
1195        } else {
1196            HashMap::new()
1197        };
1198
1199        // Convert to ComponentWithState, which includes entrypoint information.
1200        let states = request
1201            .components
1202            .values()
1203            .filter_map(|component| {
1204                if let Some(state) = protocol_states.remove(&component.id) {
1205                    Some((
1206                        component.id.clone(),
1207                        ComponentWithState {
1208                            state,
1209                            component: component.clone(),
1210                            component_tvl: component_tvl
1211                                .get(&component.id)
1212                                .cloned(),
1213                            entrypoints: request
1214                                .entrypoints
1215                                .as_ref()
1216                                .and_then(|map| map.get(&component.id))
1217                                .cloned()
1218                                .unwrap_or_default(),
1219                        },
1220                    ))
1221                } else if component_ids.contains(&component.id) {
1222                    // only emit error event if we requested this component
1223                    let component_id = &component.id;
1224                    error!(?component_id, "Missing state for native component!");
1225                    None
1226                } else {
1227                    None
1228                }
1229            })
1230            .collect();
1231
1232        let vm_storage = if !request.contract_ids.is_empty() {
1233            let contract_states = self
1234                .get_contract_state_paginated(
1235                    request.chain,
1236                    request.contract_ids,
1237                    request.protocol_system,
1238                    &version,
1239                    chunk_size,
1240                    concurrency,
1241                )
1242                .await?
1243                .accounts
1244                .into_iter()
1245                .map(|acc| (acc.address.clone(), acc))
1246                .collect::<HashMap<_, _>>();
1247
1248            trace!(states=?&contract_states, "Retrieved ContractState");
1249
1250            let contract_address_to_components = request
1251                .components
1252                .iter()
1253                .filter_map(|(id, comp)| {
1254                    if component_ids.contains(id) {
1255                        Some(
1256                            comp.contract_ids
1257                                .iter()
1258                                .map(|address| (address.clone(), comp.id.clone())),
1259                        )
1260                    } else {
1261                        None
1262                    }
1263                })
1264                .flatten()
1265                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1266                    acc.entry(addr).or_default().push(c_id);
1267                    acc
1268                });
1269
1270            request
1271                .contract_ids
1272                .iter()
1273                .filter_map(|address| {
1274                    if let Some(state) = contract_states.get(address) {
1275                        Some((address.clone(), state.clone()))
1276                    } else if let Some(ids) = contract_address_to_components.get(address) {
1277                        // only emit error even if we did actually request this address
1278                        error!(
1279                            ?address,
1280                            ?ids,
1281                            "Component with lacking contract storage encountered!"
1282                        );
1283                        None
1284                    } else {
1285                        None
1286                    }
1287                })
1288                .collect()
1289        } else {
1290            HashMap::new()
1291        };
1292
1293        Ok(Snapshot { states, vm_storage })
1294    }
1295}
1296
1297#[cfg(test)]
1298mod tests {
1299    use std::{
1300        collections::{HashMap, HashSet},
1301        str::FromStr,
1302    };
1303
1304    use mockito::Server;
1305    use rstest::rstest;
1306    // TODO: remove once deprecated ProtocolId struct is removed
1307    #[allow(deprecated)]
1308    use tycho_common::dto::ProtocolId;
1309    use tycho_common::dto::{AddressStorageLocation, TracingParams};
1310
1311    use super::*;
1312
1313    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
1314    // purposes
1315    impl MockRPCClient {
1316        #[allow(clippy::too_many_arguments)]
1317        async fn test_get_protocol_states_paginated<T>(
1318            &self,
1319            chain: Chain,
1320            ids: &[T],
1321            protocol_system: &str,
1322            include_balances: bool,
1323            version: &VersionParam,
1324            chunk_size: usize,
1325            _concurrency: usize,
1326        ) -> Vec<ProtocolStateRequestBody>
1327        where
1328            T: AsRef<str> + Clone + Send + Sync + 'static,
1329        {
1330            ids.chunks(chunk_size)
1331                .map(|chunk| ProtocolStateRequestBody {
1332                    protocol_ids: Some(
1333                        chunk
1334                            .iter()
1335                            .map(|id| id.as_ref().to_string())
1336                            .collect(),
1337                    ),
1338                    protocol_system: protocol_system.to_string(),
1339                    chain,
1340                    include_balances,
1341                    version: version.clone(),
1342                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1343                })
1344                .collect()
1345        }
1346    }
1347
1348    const GET_CONTRACT_STATE_RESP: &str = r#"
1349        {
1350            "accounts": [
1351                {
1352                    "chain": "ethereum",
1353                    "address": "0x0000000000000000000000000000000000000000",
1354                    "title": "",
1355                    "slots": {},
1356                    "native_balance": "0x01f4",
1357                    "token_balances": {},
1358                    "code": "0x00",
1359                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
1360                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1361                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1362                    "creation_tx": null
1363                }
1364            ],
1365            "pagination": {
1366                "page": 0,
1367                "page_size": 20,
1368                "total": 10
1369            }
1370        }
1371        "#;
1372
1373    // TODO: remove once deprecated ProtocolId struct is removed
1374    #[allow(deprecated)]
1375    #[rstest]
1376    #[case::protocol_id_input(vec![
1377        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1378        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1379    ])]
1380    #[case::string_input(vec![
1381        "id1".to_string(),
1382        "id2".to_string()
1383    ])]
1384    #[tokio::test]
1385    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1386    where
1387        T: AsRef<str> + Clone + Send + Sync + 'static,
1388    {
1389        let mock_client = MockRPCClient::new();
1390
1391        let request_bodies = mock_client
1392            .test_get_protocol_states_paginated(
1393                Chain::Ethereum,
1394                &ids,
1395                "test_system",
1396                true,
1397                &VersionParam::default(),
1398                2,
1399                2,
1400            )
1401            .await;
1402
1403        // Verify that the request bodies have been created correctly
1404        assert_eq!(request_bodies.len(), 1);
1405        assert_eq!(
1406            request_bodies[0]
1407                .protocol_ids
1408                .as_ref()
1409                .unwrap()
1410                .len(),
1411            2
1412        );
1413    }
1414
1415    #[tokio::test]
1416    async fn test_get_contract_state() {
1417        let mut server = Server::new_async().await;
1418        let server_resp = GET_CONTRACT_STATE_RESP;
1419        // test that the response is deserialized correctly
1420        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1421
1422        let mocked_server = server
1423            .mock("POST", "/v1/contract_state")
1424            .expect(1)
1425            .with_body(server_resp)
1426            .create_async()
1427            .await;
1428
1429        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1430            .expect("create client");
1431
1432        let response = client
1433            .get_contract_state(&Default::default())
1434            .await
1435            .expect("get state");
1436        let accounts = response.accounts;
1437
1438        mocked_server.assert();
1439        assert_eq!(accounts.len(), 1);
1440        assert_eq!(accounts[0].slots, HashMap::new());
1441        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1442        assert_eq!(accounts[0].code, [0].to_vec());
1443        assert_eq!(
1444            accounts[0].code_hash,
1445            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1446                .unwrap()
1447        );
1448    }
1449
1450    #[tokio::test]
1451    async fn test_get_protocol_components() {
1452        let mut server = Server::new_async().await;
1453        let server_resp = r#"
1454        {
1455            "protocol_components": [
1456                {
1457                    "id": "State1",
1458                    "protocol_system": "ambient",
1459                    "protocol_type_name": "Pool",
1460                    "chain": "ethereum",
1461                    "tokens": [
1462                        "0x0000000000000000000000000000000000000000",
1463                        "0x0000000000000000000000000000000000000001"
1464                    ],
1465                    "contract_ids": [
1466                        "0x0000000000000000000000000000000000000000"
1467                    ],
1468                    "static_attributes": {
1469                        "attribute_1": "0x00000000000003e8"
1470                    },
1471                    "change": "Creation",
1472                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1473                    "created_at": "2022-01-01T00:00:00"
1474                }
1475            ],
1476            "pagination": {
1477                "page": 0,
1478                "page_size": 20,
1479                "total": 10
1480            }
1481        }
1482        "#;
1483        // test that the response is deserialized correctly
1484        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1485
1486        let mocked_server = server
1487            .mock("POST", "/v1/protocol_components")
1488            .expect(1)
1489            .with_body(server_resp)
1490            .create_async()
1491            .await;
1492
1493        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1494            .expect("create client");
1495
1496        let response = client
1497            .get_protocol_components(&Default::default())
1498            .await
1499            .expect("get state");
1500        let components = response.protocol_components;
1501
1502        mocked_server.assert();
1503        assert_eq!(components.len(), 1);
1504        assert_eq!(components[0].id, "State1");
1505        assert_eq!(components[0].protocol_system, "ambient");
1506        assert_eq!(components[0].protocol_type_name, "Pool");
1507        assert_eq!(components[0].tokens.len(), 2);
1508        let expected_attributes =
1509            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1510                .iter()
1511                .cloned()
1512                .collect::<HashMap<String, Bytes>>();
1513        assert_eq!(components[0].static_attributes, expected_attributes);
1514    }
1515
1516    #[tokio::test]
1517    async fn test_get_protocol_states() {
1518        let mut server = Server::new_async().await;
1519        let server_resp = r#"
1520        {
1521            "states": [
1522                {
1523                    "component_id": "State1",
1524                    "attributes": {
1525                        "attribute_1": "0x00000000000003e8"
1526                    },
1527                    "balances": {
1528                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1529                    }
1530                }
1531            ],
1532            "pagination": {
1533                "page": 0,
1534                "page_size": 20,
1535                "total": 10
1536            }
1537        }
1538        "#;
1539        // test that the response is deserialized correctly
1540        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1541
1542        let mocked_server = server
1543            .mock("POST", "/v1/protocol_state")
1544            .expect(1)
1545            .with_body(server_resp)
1546            .create_async()
1547            .await;
1548        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1549            .expect("create client");
1550
1551        let response = client
1552            .get_protocol_states(&Default::default())
1553            .await
1554            .expect("get state");
1555        let states = response.states;
1556
1557        mocked_server.assert();
1558        assert_eq!(states.len(), 1);
1559        assert_eq!(states[0].component_id, "State1");
1560        let expected_attributes =
1561            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1562                .iter()
1563                .cloned()
1564                .collect::<HashMap<String, Bytes>>();
1565        assert_eq!(states[0].attributes, expected_attributes);
1566        let expected_balances = [(
1567            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1568                .expect("Unsupported address format"),
1569            Bytes::from_str("0x01f4").unwrap(),
1570        )]
1571        .iter()
1572        .cloned()
1573        .collect::<HashMap<Bytes, Bytes>>();
1574        assert_eq!(states[0].balances, expected_balances);
1575    }
1576
1577    #[tokio::test]
1578    async fn test_get_tokens() {
1579        let mut server = Server::new_async().await;
1580        let server_resp = r#"
1581        {
1582            "tokens": [
1583              {
1584                "chain": "ethereum",
1585                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1586                "symbol": "WETH",
1587                "decimals": 18,
1588                "tax": 0,
1589                "gas": [
1590                  29962
1591                ],
1592                "quality": 100
1593              },
1594              {
1595                "chain": "ethereum",
1596                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1597                "symbol": "USDC",
1598                "decimals": 6,
1599                "tax": 0,
1600                "gas": [
1601                  40652
1602                ],
1603                "quality": 100
1604              }
1605            ],
1606            "pagination": {
1607              "page": 0,
1608              "page_size": 20,
1609              "total": 10
1610            }
1611          }
1612        "#;
1613        // test that the response is deserialized correctly
1614        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1615
1616        let mocked_server = server
1617            .mock("POST", "/v1/tokens")
1618            .expect(1)
1619            .with_body(server_resp)
1620            .create_async()
1621            .await;
1622        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1623            .expect("create client");
1624
1625        let response = client
1626            .get_tokens(&Default::default())
1627            .await
1628            .expect("get tokens");
1629
1630        let expected = vec![
1631            ResponseToken {
1632                chain: Chain::Ethereum,
1633                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1634                symbol: "WETH".to_string(),
1635                decimals: 18,
1636                tax: 0,
1637                gas: vec![Some(29962)],
1638                quality: 100,
1639            },
1640            ResponseToken {
1641                chain: Chain::Ethereum,
1642                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1643                symbol: "USDC".to_string(),
1644                decimals: 6,
1645                tax: 0,
1646                gas: vec![Some(40652)],
1647                quality: 100,
1648            },
1649        ];
1650
1651        mocked_server.assert();
1652        assert_eq!(response.tokens, expected);
1653        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1654    }
1655
1656    #[tokio::test]
1657    async fn test_get_protocol_systems() {
1658        let mut server = Server::new_async().await;
1659        let server_resp = r#"
1660        {
1661            "protocol_systems": [
1662                "system1",
1663                "system2"
1664            ],
1665            "pagination": {
1666                "page": 0,
1667                "page_size": 20,
1668                "total": 10
1669            }
1670        }
1671        "#;
1672        // test that the response is deserialized correctly
1673        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1674
1675        let mocked_server = server
1676            .mock("POST", "/v1/protocol_systems")
1677            .expect(1)
1678            .with_body(server_resp)
1679            .create_async()
1680            .await;
1681        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1682            .expect("create client");
1683
1684        let response = client
1685            .get_protocol_systems(&Default::default())
1686            .await
1687            .expect("get protocol systems");
1688        let protocol_systems = response.protocol_systems;
1689
1690        mocked_server.assert();
1691        assert_eq!(protocol_systems, vec!["system1", "system2"]);
1692    }
1693
1694    #[tokio::test]
1695    async fn test_get_component_tvl() {
1696        let mut server = Server::new_async().await;
1697        let server_resp = r#"
1698        {
1699            "tvl": {
1700                "component1": 100.0
1701            },
1702            "pagination": {
1703                "page": 0,
1704                "page_size": 20,
1705                "total": 10
1706            }
1707        }
1708        "#;
1709        // test that the response is deserialized correctly
1710        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1711
1712        let mocked_server = server
1713            .mock("POST", "/v1/component_tvl")
1714            .expect(1)
1715            .with_body(server_resp)
1716            .create_async()
1717            .await;
1718        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1719            .expect("create client");
1720
1721        let response = client
1722            .get_component_tvl(&Default::default())
1723            .await
1724            .expect("get protocol systems");
1725        let component_tvl = response.tvl;
1726
1727        mocked_server.assert();
1728        assert_eq!(component_tvl.get("component1"), Some(&100.0));
1729    }
1730
1731    #[tokio::test]
1732    async fn test_get_traced_entry_points() {
1733        let mut server = Server::new_async().await;
1734        let server_resp = r#"
1735        {
1736            "traced_entry_points": {
1737                "component_1": [
1738                    [
1739                        {
1740                            "entry_point": {
1741                                "external_id": "entrypoint_a",
1742                                "target": "0x0000000000000000000000000000000000000001",
1743                                "signature": "sig()"
1744                            },
1745                            "params": {
1746                                "method": "rpctracer",
1747                                "caller": "0x000000000000000000000000000000000000000a",
1748                                "calldata": "0x000000000000000000000000000000000000000b"
1749                            }
1750                        },
1751                        {
1752                            "retriggers": [
1753                                [
1754                                    "0x00000000000000000000000000000000000000aa",
1755                                    {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1756                                ]
1757                            ],
1758                            "accessed_slots": {
1759                                "0x0000000000000000000000000000000000aaaa": [
1760                                    "0x0000000000000000000000000000000000aaaa"
1761                                ]
1762                            }
1763                        }
1764                    ]
1765                ]
1766            },
1767            "pagination": {
1768                "page": 0,
1769                "page_size": 20,
1770                "total": 1
1771            }
1772        }
1773        "#;
1774        // test that the response is deserialized correctly
1775        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1776
1777        let mocked_server = server
1778            .mock("POST", "/v1/traced_entry_points")
1779            .expect(1)
1780            .with_body(server_resp)
1781            .create_async()
1782            .await;
1783        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1784            .expect("create client");
1785
1786        let response = client
1787            .get_traced_entry_points(&Default::default())
1788            .await
1789            .expect("get traced entry points");
1790        let entrypoints = response.traced_entry_points;
1791
1792        mocked_server.assert();
1793        assert_eq!(entrypoints.len(), 1);
1794        let comp1_entrypoints = entrypoints
1795            .get("component_1")
1796            .expect("component_1 entrypoints should exist");
1797        assert_eq!(comp1_entrypoints.len(), 1);
1798
1799        let (entrypoint, trace_result) = &comp1_entrypoints[0];
1800        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1801        assert_eq!(
1802            entrypoint.entry_point.target,
1803            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1804        );
1805        assert_eq!(entrypoint.entry_point.signature, "sig()");
1806        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1807        assert_eq!(
1808            rpc_params.caller,
1809            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1810        );
1811        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1812
1813        assert_eq!(
1814            trace_result.retriggers,
1815            HashSet::from([(
1816                Bytes::from("0x00000000000000000000000000000000000000aa"),
1817                AddressStorageLocation::new(
1818                    Bytes::from("0x0000000000000000000000000000000000000aaa"),
1819                    12
1820                )
1821            )])
1822        );
1823        assert_eq!(trace_result.accessed_slots.len(), 1);
1824        assert_eq!(
1825            trace_result.accessed_slots,
1826            HashMap::from([(
1827                Bytes::from("0x0000000000000000000000000000000000aaaa"),
1828                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1829            )])
1830        );
1831    }
1832
1833    #[tokio::test]
1834    async fn test_parse_retry_value_numeric() {
1835        let result = parse_retry_value("60");
1836        assert!(result.is_some());
1837
1838        let expected_time = SystemTime::now() + Duration::from_secs(60);
1839        let actual_time = result.unwrap();
1840
1841        // Allow for small timing differences during test execution
1842        let diff = if actual_time > expected_time {
1843            actual_time
1844                .duration_since(expected_time)
1845                .unwrap()
1846        } else {
1847            expected_time
1848                .duration_since(actual_time)
1849                .unwrap()
1850        };
1851        assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1852    }
1853
1854    #[tokio::test]
1855    async fn test_parse_retry_value_rfc2822() {
1856        // Use a fixed future date in RFC2822 format
1857        let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1858        let result = parse_retry_value(rfc2822_date);
1859        assert!(result.is_some());
1860
1861        let parsed_time = result.unwrap();
1862        assert!(parsed_time > SystemTime::now());
1863    }
1864
1865    #[tokio::test]
1866    async fn test_parse_retry_value_invalid_formats() {
1867        // Test various invalid formats
1868        assert!(parse_retry_value("invalid").is_none());
1869        assert!(parse_retry_value("").is_none());
1870        assert!(parse_retry_value("not_a_number").is_none());
1871        assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); // Invalid date
1872    }
1873
1874    #[tokio::test]
1875    async fn test_parse_retry_value_zero_seconds() {
1876        let result = parse_retry_value("0");
1877        assert!(result.is_some());
1878
1879        let expected_time = SystemTime::now();
1880        let actual_time = result.unwrap();
1881
1882        // Should be very close to current time
1883        let diff = if actual_time > expected_time {
1884            actual_time
1885                .duration_since(expected_time)
1886                .unwrap()
1887        } else {
1888            expected_time
1889                .duration_since(actual_time)
1890                .unwrap()
1891        };
1892        assert!(diff < Duration::from_secs(1));
1893    }
1894
1895    #[tokio::test]
1896    async fn test_error_for_response_rate_limited() {
1897        let mut server = Server::new_async().await;
1898        let mock = server
1899            .mock("GET", "/test")
1900            .with_status(429)
1901            .with_header("Retry-After", "60")
1902            .create_async()
1903            .await;
1904
1905        let client = reqwest::Client::new();
1906        let response = client
1907            .get(format!("{}/test", server.url()))
1908            .send()
1909            .await
1910            .unwrap();
1911
1912        let http_client =
1913            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1914                .unwrap()
1915                .with_test_backoff_policy();
1916        let result = http_client
1917            .error_for_response(response)
1918            .await;
1919
1920        mock.assert();
1921        assert!(matches!(result, Err(RPCError::RateLimited(_))));
1922        if let Err(RPCError::RateLimited(retry_after)) = result {
1923            assert!(retry_after.is_some());
1924        }
1925    }
1926
1927    #[tokio::test]
1928    async fn test_error_for_response_rate_limited_no_header() {
1929        let mut server = Server::new_async().await;
1930        let mock = server
1931            .mock("GET", "/test")
1932            .with_status(429)
1933            .create_async()
1934            .await;
1935
1936        let client = reqwest::Client::new();
1937        let response = client
1938            .get(format!("{}/test", server.url()))
1939            .send()
1940            .await
1941            .unwrap();
1942
1943        let http_client =
1944            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1945                .unwrap()
1946                .with_test_backoff_policy();
1947        let result = http_client
1948            .error_for_response(response)
1949            .await;
1950
1951        mock.assert();
1952        assert!(matches!(result, Err(RPCError::RateLimited(None))));
1953    }
1954
1955    #[tokio::test]
1956    async fn test_error_for_response_server_errors() {
1957        let test_cases =
1958            vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1959
1960        for (status_code, expected_body) in test_cases {
1961            let mut server = Server::new_async().await;
1962            let mock = server
1963                .mock("GET", "/test")
1964                .with_status(status_code)
1965                .with_body(expected_body)
1966                .create_async()
1967                .await;
1968
1969            let client = reqwest::Client::new();
1970            let response = client
1971                .get(format!("{}/test", server.url()))
1972                .send()
1973                .await
1974                .unwrap();
1975
1976            let http_client =
1977                HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1978                    .unwrap()
1979                    .with_test_backoff_policy();
1980            let result = http_client
1981                .error_for_response(response)
1982                .await;
1983
1984            mock.assert();
1985            assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
1986            if let Err(RPCError::ServerUnreachable(body)) = result {
1987                assert_eq!(body, expected_body);
1988            }
1989        }
1990    }
1991
1992    #[tokio::test]
1993    async fn test_error_for_response_success() {
1994        let mut server = Server::new_async().await;
1995        let mock = server
1996            .mock("GET", "/test")
1997            .with_status(200)
1998            .with_body("success")
1999            .create_async()
2000            .await;
2001
2002        let client = reqwest::Client::new();
2003        let response = client
2004            .get(format!("{}/test", server.url()))
2005            .send()
2006            .await
2007            .unwrap();
2008
2009        let http_client =
2010            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2011                .unwrap()
2012                .with_test_backoff_policy();
2013        let result = http_client
2014            .error_for_response(response)
2015            .await;
2016
2017        mock.assert();
2018        assert!(result.is_ok());
2019
2020        let response = result.unwrap();
2021        assert_eq!(response.status(), 200);
2022    }
2023
2024    #[tokio::test]
2025    async fn test_handle_error_for_backoff_server_unreachable() {
2026        let http_client =
2027            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2028                .unwrap()
2029                .with_test_backoff_policy();
2030        let error = RPCError::ServerUnreachable("Service down".to_string());
2031
2032        let backoff_error = http_client
2033            .handle_error_for_backoff(error)
2034            .await;
2035
2036        match backoff_error {
2037            backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2038                assert_eq!(msg, "Service down");
2039                assert_eq!(retry_after, Some(Duration::from_millis(50))); // Fast test duration
2040            }
2041            _ => panic!("Expected transient error for ServerUnreachable"),
2042        }
2043    }
2044
2045    #[tokio::test]
2046    async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2047        let http_client =
2048            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2049                .unwrap()
2050                .with_test_backoff_policy();
2051        let future_time = SystemTime::now() + Duration::from_secs(30);
2052        let error = RPCError::RateLimited(Some(future_time));
2053
2054        let backoff_error = http_client
2055            .handle_error_for_backoff(error)
2056            .await;
2057
2058        match backoff_error {
2059            backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2060                assert_eq!(retry_after, Some(future_time));
2061            }
2062            _ => panic!("Expected transient error for RateLimited"),
2063        }
2064
2065        // Verify that retry_after was stored in the client state
2066        let stored_retry_after = http_client.retry_after.read().await;
2067        assert_eq!(*stored_retry_after, Some(future_time));
2068    }
2069
2070    #[tokio::test]
2071    async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2072        let http_client =
2073            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2074                .unwrap()
2075                .with_test_backoff_policy();
2076        let error = RPCError::RateLimited(None);
2077
2078        let backoff_error = http_client
2079            .handle_error_for_backoff(error)
2080            .await;
2081
2082        match backoff_error {
2083            backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2084                // This is expected - no retry-after still allows retries with default policy
2085            }
2086            _ => panic!("Expected transient error for RateLimited without retry-after"),
2087        }
2088    }
2089
2090    #[tokio::test]
2091    async fn test_handle_error_for_backoff_other_errors() {
2092        let http_client =
2093            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2094                .unwrap()
2095                .with_test_backoff_policy();
2096        let error = RPCError::ParseResponse("Invalid JSON".to_string());
2097
2098        let backoff_error = http_client
2099            .handle_error_for_backoff(error)
2100            .await;
2101
2102        match backoff_error {
2103            backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2104                assert_eq!(msg, "Invalid JSON");
2105            }
2106            _ => panic!("Expected permanent error for ParseResponse"),
2107        }
2108    }
2109
2110    #[tokio::test]
2111    async fn test_wait_until_retry_after_no_retry_time() {
2112        let http_client =
2113            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2114                .unwrap()
2115                .with_test_backoff_policy();
2116
2117        let start = std::time::Instant::now();
2118        http_client
2119            .wait_until_retry_after()
2120            .await;
2121        let elapsed = start.elapsed();
2122
2123        // Should return immediately if no retry time is set
2124        assert!(elapsed < Duration::from_millis(100));
2125    }
2126
2127    #[tokio::test]
2128    async fn test_wait_until_retry_after_past_time() {
2129        let http_client =
2130            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2131                .unwrap()
2132                .with_test_backoff_policy();
2133
2134        // Set a retry time in the past
2135        let past_time = SystemTime::now() - Duration::from_secs(10);
2136        *http_client.retry_after.write().await = Some(past_time);
2137
2138        let start = std::time::Instant::now();
2139        http_client
2140            .wait_until_retry_after()
2141            .await;
2142        let elapsed = start.elapsed();
2143
2144        // Should return immediately if retry time is in the past
2145        assert!(elapsed < Duration::from_millis(100));
2146    }
2147
2148    #[tokio::test]
2149    async fn test_wait_until_retry_after_future_time() {
2150        let http_client =
2151            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2152                .unwrap()
2153                .with_test_backoff_policy();
2154
2155        // Set a retry time 100ms in the future
2156        let future_time = SystemTime::now() + Duration::from_millis(100);
2157        *http_client.retry_after.write().await = Some(future_time);
2158
2159        let start = std::time::Instant::now();
2160        http_client
2161            .wait_until_retry_after()
2162            .await;
2163        let elapsed = start.elapsed();
2164
2165        // Should wait approximately the specified duration
2166        assert!(elapsed >= Duration::from_millis(80)); // Allow some tolerance
2167        assert!(elapsed <= Duration::from_millis(200)); // Upper bound for test stability
2168    }
2169
2170    #[tokio::test]
2171    async fn test_make_post_request_success() {
2172        let mut server = Server::new_async().await;
2173        let server_resp = r#"{"success": true}"#;
2174
2175        let mock = server
2176            .mock("POST", "/test")
2177            .with_status(200)
2178            .with_body(server_resp)
2179            .create_async()
2180            .await;
2181
2182        let http_client =
2183            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2184                .unwrap()
2185                .with_test_backoff_policy();
2186        let request_body = serde_json::json!({"test": "data"});
2187        let uri = format!("{}/test", server.url());
2188
2189        let result = http_client
2190            .make_post_request(&request_body, &uri)
2191            .await;
2192
2193        mock.assert();
2194        assert!(result.is_ok());
2195
2196        let response = result.unwrap();
2197        assert_eq!(response.status(), 200);
2198        assert_eq!(response.text().await.unwrap(), server_resp);
2199    }
2200
2201    #[tokio::test]
2202    async fn test_make_post_request_retry_on_server_error() {
2203        let mut server = Server::new_async().await;
2204        // First request fails with 503, second succeeds
2205        let error_mock = server
2206            .mock("POST", "/test")
2207            .with_status(503)
2208            .with_body("Service Unavailable")
2209            .expect(1)
2210            .create_async()
2211            .await;
2212
2213        let success_mock = server
2214            .mock("POST", "/test")
2215            .with_status(200)
2216            .with_body(r#"{"success": true}"#)
2217            .expect(1)
2218            .create_async()
2219            .await;
2220
2221        let http_client =
2222            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2223                .unwrap()
2224                .with_test_backoff_policy();
2225        let request_body = serde_json::json!({"test": "data"});
2226        let uri = format!("{}/test", server.url());
2227
2228        let result = http_client
2229            .make_post_request(&request_body, &uri)
2230            .await;
2231
2232        error_mock.assert();
2233        success_mock.assert();
2234        assert!(result.is_ok());
2235    }
2236
2237    #[tokio::test]
2238    async fn test_make_post_request_respect_retry_after_header() {
2239        let mut server = Server::new_async().await;
2240
2241        // First request returns 429 with retry-after, second succeeds
2242        let rate_limit_mock = server
2243            .mock("POST", "/test")
2244            .with_status(429)
2245            .with_header("Retry-After", "1") // 1 second
2246            .expect(1)
2247            .create_async()
2248            .await;
2249
2250        let success_mock = server
2251            .mock("POST", "/test")
2252            .with_status(200)
2253            .with_body(r#"{"success": true}"#)
2254            .expect(1)
2255            .create_async()
2256            .await;
2257
2258        let http_client =
2259            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2260                .unwrap()
2261                .with_test_backoff_policy();
2262        let request_body = serde_json::json!({"test": "data"});
2263        let uri = format!("{}/test", server.url());
2264
2265        let start = std::time::Instant::now();
2266        let result = http_client
2267            .make_post_request(&request_body, &uri)
2268            .await;
2269        let elapsed = start.elapsed();
2270
2271        rate_limit_mock.assert();
2272        success_mock.assert();
2273        assert!(result.is_ok());
2274
2275        // Should have waited at least 1 second due to retry-after header
2276        assert!(elapsed >= Duration::from_millis(900)); // Allow some tolerance
2277        assert!(elapsed <= Duration::from_millis(2000)); // Upper bound for test stability
2278    }
2279
2280    #[tokio::test]
2281    async fn test_make_post_request_permanent_error() {
2282        let mut server = Server::new_async().await;
2283
2284        let mock = server
2285            .mock("POST", "/test")
2286            .with_status(400) // Bad Request - should not be retried
2287            .with_body("Bad Request")
2288            .expect(1)
2289            .create_async()
2290            .await;
2291
2292        let http_client =
2293            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2294                .unwrap()
2295                .with_test_backoff_policy();
2296        let request_body = serde_json::json!({"test": "data"});
2297        let uri = format!("{}/test", server.url());
2298
2299        let result = http_client
2300            .make_post_request(&request_body, &uri)
2301            .await;
2302
2303        mock.assert();
2304        assert!(result.is_ok()); // 400 doesn't trigger retry logic, just returns the response
2305
2306        let response = result.unwrap();
2307        assert_eq!(response.status(), 400);
2308    }
2309
2310    #[tokio::test]
2311    async fn test_concurrent_requests_with_different_retry_after() {
2312        let mut server = Server::new_async().await;
2313
2314        // First request gets rate limited with 1 second retry-after
2315        let rate_limit_mock_1 = server
2316            .mock("POST", "/test1")
2317            .with_status(429)
2318            .with_header("Retry-After", "1")
2319            .expect(1)
2320            .create_async()
2321            .await;
2322
2323        // Second request gets rate limited with 2 second retry-after
2324        let rate_limit_mock_2 = server
2325            .mock("POST", "/test2")
2326            .with_status(429)
2327            .with_header("Retry-After", "2")
2328            .expect(1)
2329            .create_async()
2330            .await;
2331
2332        // Success mocks for retries
2333        let success_mock_1 = server
2334            .mock("POST", "/test1")
2335            .with_status(200)
2336            .with_body(r#"{"result": "success1"}"#)
2337            .expect(1)
2338            .create_async()
2339            .await;
2340
2341        let success_mock_2 = server
2342            .mock("POST", "/test2")
2343            .with_status(200)
2344            .with_body(r#"{"result": "success2"}"#)
2345            .expect(1)
2346            .create_async()
2347            .await;
2348
2349        let http_client =
2350            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2351                .unwrap()
2352                .with_test_backoff_policy();
2353        let request_body = serde_json::json!({"test": "data"});
2354
2355        let uri1 = format!("{}/test1", server.url());
2356        let uri2 = format!("{}/test2", server.url());
2357
2358        // Start both requests concurrently
2359        let start = std::time::Instant::now();
2360        let (result1, result2) = tokio::join!(
2361            http_client.make_post_request(&request_body, &uri1),
2362            http_client.make_post_request(&request_body, &uri2)
2363        );
2364        let elapsed = start.elapsed();
2365
2366        rate_limit_mock_1.assert();
2367        rate_limit_mock_2.assert();
2368        success_mock_1.assert();
2369        success_mock_2.assert();
2370
2371        assert!(result1.is_ok());
2372        assert!(result2.is_ok());
2373
2374        // Both requests should succeed, but the second should take longer due to the 2s retry-after
2375        // The total time should be at least 2 seconds since the shared retry_after state
2376        // gets updated by both requests
2377        assert!(elapsed >= Duration::from_millis(1800)); // Allow some tolerance
2378        assert!(elapsed <= Duration::from_millis(3000)); // Upper bound
2379
2380        // Check the final retry_after state - should be the latest (higher) value
2381        let final_retry_after = http_client.retry_after.read().await;
2382        assert!(final_retry_after.is_some());
2383
2384        // The retry_after should be set to the latest (higher) value from the two requests
2385        if let Some(retry_time) = *final_retry_after {
2386            // The retry_after time might be in the past now since we waited,
2387            // but it should be reasonable (not too far in past/future)
2388            let now = SystemTime::now();
2389            let diff = if retry_time > now {
2390                retry_time.duration_since(now).unwrap()
2391            } else {
2392                now.duration_since(retry_time).unwrap()
2393            };
2394
2395            // Should be within a reasonable range (the 2s retry-after plus some buffer)
2396            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2397        }
2398    }
2399
2400    #[tokio::test]
2401    async fn test_get_snapshots() {
2402        let mut server = Server::new_async().await;
2403
2404        // Mock protocol states response
2405        let protocol_states_resp = r#"
2406        {
2407            "states": [
2408                {
2409                    "component_id": "component1",
2410                    "attributes": {
2411                        "attribute_1": "0x00000000000003e8"
2412                    },
2413                    "balances": {
2414                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2415                    }
2416                }
2417            ],
2418            "pagination": {
2419                "page": 0,
2420                "page_size": 100,
2421                "total": 1
2422            }
2423        }
2424        "#;
2425
2426        // Mock contract state response
2427        let contract_state_resp = r#"
2428        {
2429            "accounts": [
2430                {
2431                    "chain": "ethereum",
2432                    "address": "0x1111111111111111111111111111111111111111",
2433                    "title": "",
2434                    "slots": {},
2435                    "native_balance": "0x01f4",
2436                    "token_balances": {},
2437                    "code": "0x00",
2438                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2439                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2440                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2441                    "creation_tx": null
2442                }
2443            ],
2444            "pagination": {
2445                "page": 0,
2446                "page_size": 100,
2447                "total": 1
2448            }
2449        }
2450        "#;
2451
2452        // Mock component TVL response
2453        let tvl_resp = r#"
2454        {
2455            "tvl": {
2456                "component1": 1000000.0
2457            },
2458            "pagination": {
2459                "page": 0,
2460                "page_size": 100,
2461                "total": 1
2462            }
2463        }
2464        "#;
2465
2466        let protocol_states_mock = server
2467            .mock("POST", "/v1/protocol_state")
2468            .expect(1)
2469            .with_body(protocol_states_resp)
2470            .create_async()
2471            .await;
2472
2473        let contract_state_mock = server
2474            .mock("POST", "/v1/contract_state")
2475            .expect(1)
2476            .with_body(contract_state_resp)
2477            .create_async()
2478            .await;
2479
2480        let tvl_mock = server
2481            .mock("POST", "/v1/component_tvl")
2482            .expect(1)
2483            .with_body(tvl_resp)
2484            .create_async()
2485            .await;
2486
2487        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2488            .expect("create client");
2489
2490        #[allow(deprecated)]
2491        let component = tycho_common::dto::ProtocolComponent {
2492            id: "component1".to_string(),
2493            protocol_system: "test_protocol".to_string(),
2494            protocol_type_name: "test_type".to_string(),
2495            chain: Chain::Ethereum,
2496            tokens: vec![],
2497            contract_ids: vec![
2498                Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2499            ],
2500            static_attributes: HashMap::new(),
2501            change: tycho_common::dto::ChangeType::Creation,
2502            creation_tx: Bytes::from_str(
2503                "0x0000000000000000000000000000000000000000000000000000000000000000",
2504            )
2505            .unwrap(),
2506            created_at: chrono::Utc::now().naive_utc(),
2507        };
2508
2509        let mut components = HashMap::new();
2510        components.insert("component1".to_string(), component);
2511
2512        let contract_ids =
2513            vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2514
2515        let request = SnapshotParameters::new(
2516            Chain::Ethereum,
2517            "test_protocol",
2518            &components,
2519            &contract_ids,
2520            12345,
2521        );
2522
2523        let response = client
2524            .get_snapshots(&request, 100, 4)
2525            .await
2526            .expect("get snapshots");
2527
2528        // Verify all mocks were called
2529        protocol_states_mock.assert();
2530        contract_state_mock.assert();
2531        tvl_mock.assert();
2532
2533        // Assert states
2534        assert_eq!(response.states.len(), 1);
2535        assert!(response
2536            .states
2537            .contains_key("component1"));
2538
2539        // Check that the state has the expected TVL
2540        let component_state = response
2541            .states
2542            .get("component1")
2543            .unwrap();
2544        assert_eq!(component_state.component_tvl, Some(1000000.0));
2545
2546        // Assert VM storage
2547        assert_eq!(response.vm_storage.len(), 1);
2548        let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2549        assert!(response
2550            .vm_storage
2551            .contains_key(&contract_addr));
2552    }
2553
2554    #[tokio::test]
2555    async fn test_get_snapshots_empty_components() {
2556        let server = Server::new_async().await;
2557        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2558            .expect("create client");
2559
2560        let components = HashMap::new();
2561        let contract_ids = vec![];
2562
2563        let request = SnapshotParameters::new(
2564            Chain::Ethereum,
2565            "test_protocol",
2566            &components,
2567            &contract_ids,
2568            12345,
2569        );
2570
2571        let response = client
2572            .get_snapshots(&request, 100, 4)
2573            .await
2574            .expect("get snapshots");
2575
2576        // Should return empty response without making any requests
2577        assert!(response.states.is_empty());
2578        assert!(response.vm_storage.is_empty());
2579    }
2580
2581    #[tokio::test]
2582    async fn test_get_snapshots_without_tvl() {
2583        let mut server = Server::new_async().await;
2584
2585        let protocol_states_resp = r#"
2586        {
2587            "states": [
2588                {
2589                    "component_id": "component1",
2590                    "attributes": {},
2591                    "balances": {}
2592                }
2593            ],
2594            "pagination": {
2595                "page": 0,
2596                "page_size": 100,
2597                "total": 1
2598            }
2599        }
2600        "#;
2601
2602        let protocol_states_mock = server
2603            .mock("POST", "/v1/protocol_state")
2604            .expect(1)
2605            .with_body(protocol_states_resp)
2606            .create_async()
2607            .await;
2608
2609        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2610            .expect("create client");
2611
2612        // Create test component
2613        #[allow(deprecated)]
2614        let component = tycho_common::dto::ProtocolComponent {
2615            id: "component1".to_string(),
2616            protocol_system: "test_protocol".to_string(),
2617            protocol_type_name: "test_type".to_string(),
2618            chain: Chain::Ethereum,
2619            tokens: vec![],
2620            contract_ids: vec![],
2621            static_attributes: HashMap::new(),
2622            change: tycho_common::dto::ChangeType::Creation,
2623            creation_tx: Bytes::from_str(
2624                "0x0000000000000000000000000000000000000000000000000000000000000000",
2625            )
2626            .unwrap(),
2627            created_at: chrono::Utc::now().naive_utc(),
2628        };
2629
2630        let mut components = HashMap::new();
2631        components.insert("component1".to_string(), component);
2632        let contract_ids = vec![];
2633
2634        let request = SnapshotParameters::new(
2635            Chain::Ethereum,
2636            "test_protocol",
2637            &components,
2638            &contract_ids,
2639            12345,
2640        )
2641        .include_balances(false)
2642        .include_tvl(false);
2643
2644        let response = client
2645            .get_snapshots(&request, 100, 4)
2646            .await
2647            .expect("get snapshots");
2648
2649        // Verify only necessary mocks were called
2650        protocol_states_mock.assert();
2651        // No contract_state_mock.assert() since contract_ids is empty
2652        // No tvl_mock.assert() since include_tvl is false
2653
2654        assert_eq!(response.states.len(), 1);
2655        // Check that TVL is None since we didn't request it
2656        let component_state = response
2657            .states
2658            .get("component1")
2659            .unwrap();
2660        assert_eq!(component_state.component_tvl, None);
2661    }
2662
2663    #[tokio::test]
2664    async fn test_compression_enabled() {
2665        let mut server = Server::new_async().await;
2666        let server_resp = GET_CONTRACT_STATE_RESP;
2667
2668        // Compress the response using zstd
2669        let compressed_body =
2670            zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2671
2672        let mocked_server = server
2673            .mock("POST", "/v1/contract_state")
2674            .expect(1)
2675            .with_header("Content-Encoding", "zstd")
2676            .with_body(compressed_body)
2677            .create_async()
2678            .await;
2679
2680        // Create client with compression enabled
2681        let client = HttpRPCClient::new(
2682            server.url().as_str(),
2683            HttpRPCClientOptions::new().with_compression(true),
2684        )
2685        .expect("create client");
2686
2687        let response = client
2688            .get_contract_state(&Default::default())
2689            .await
2690            .expect("get state");
2691        let accounts = response.accounts;
2692
2693        mocked_server.assert();
2694        assert_eq!(accounts.len(), 1);
2695        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2696    }
2697
2698    #[tokio::test]
2699    async fn test_compression_disabled() {
2700        let mut server = Server::new_async().await;
2701        let server_resp = GET_CONTRACT_STATE_RESP;
2702
2703        // Server sends plain text response
2704        let mocked_server = server
2705            .mock("POST", "/v1/contract_state")
2706            .expect(1)
2707            .with_body(server_resp)
2708            .create_async()
2709            .await;
2710
2711        // Create client with compression disabled
2712        let client = HttpRPCClient::new(
2713            server.url().as_str(),
2714            HttpRPCClientOptions::new().with_compression(false),
2715        )
2716        .expect("create client");
2717
2718        let response = client
2719            .get_contract_state(&Default::default())
2720            .await
2721            .expect("get state");
2722        let accounts = response.accounts;
2723
2724        mocked_server.assert();
2725        assert_eq!(accounts.len(), 1);
2726        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2727    }
2728}