tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
3//! The objective of this module is to provide swift and simplified access to the Remote Procedure
4//! Call (RPC) endpoints of Tycho. These endpoints are chiefly responsible for facilitating data
5//! queries, especially querying snapshots of data.
6use std::{
7    collections::HashMap,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22    sync::{RwLock, Semaphore},
23    time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27    dto::{
28        BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29        EntryPointWithTracingParams, PaginationLimits, PaginationParams, PaginationResponse,
30        ProtocolComponent, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
31        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
32        ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
33        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
34        TracedEntryPointRequestResponse, TracingResult, VersionParam,
35    },
36    models::ComponentId,
37    Bytes,
38};
39
40use crate::{
41    feed::synchronizer::{ComponentWithState, Snapshot},
42    TYCHO_SERVER_VERSION,
43};
44
/// Recommended value for the `concurrency` argument of the paginated [`RPCClient`] helpers,
/// i.e. how many RPC requests to keep in flight at once.
pub const RPC_CLIENT_CONCURRENCY: usize = 4;
47
48/// Request body for fetching a snapshot of protocol states and VM storage.
49///
50/// This struct helps to coordinate fetching  multiple pieces of related data
51/// (protocol states, contract storage, TVL, entry points).
52#[derive(Clone, Debug, PartialEq)]
53pub struct SnapshotParameters<'a> {
54    /// Which chain to fetch snapshots for
55    pub chain: Chain,
56    /// Protocol system name, required for correct state resolution
57    pub protocol_system: &'a str,
58    /// Components to fetch protocol states for
59    pub components: &'a HashMap<ComponentId, ProtocolComponent>,
60    /// Traced entry points data mapped by component id
61    pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
62    /// Contract addresses to fetch VM storage for
63    pub contract_ids: &'a [Bytes],
64    /// Block number for versioning
65    pub block_number: u64,
66    /// Whether to include balance information
67    pub include_balances: bool,
68    /// Whether to fetch TVL data
69    pub include_tvl: bool,
70}
71
72impl<'a> SnapshotParameters<'a> {
73    pub fn new(
74        chain: Chain,
75        protocol_system: &'a str,
76        components: &'a HashMap<ComponentId, ProtocolComponent>,
77        contract_ids: &'a [Bytes],
78        block_number: u64,
79    ) -> Self {
80        Self {
81            chain,
82            protocol_system,
83            components,
84            entrypoints: None,
85            contract_ids,
86            block_number,
87            include_balances: true,
88            include_tvl: true,
89        }
90    }
91
92    /// Set whether to include balance information (default: true)
93    pub fn include_balances(mut self, include_balances: bool) -> Self {
94        self.include_balances = include_balances;
95        self
96    }
97
98    /// Set whether to fetch TVL data (default: true)
99    pub fn include_tvl(mut self, include_tvl: bool) -> Self {
100        self.include_tvl = include_tvl;
101        self
102    }
103
104    pub fn entrypoints(
105        mut self,
106        entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
107    ) -> Self {
108        self.entrypoints = Some(entrypoints);
109        self
110    }
111}
112
113#[derive(Error, Debug)]
114pub enum RPCError {
115    /// The passed tycho url failed to parse.
116    #[error("Failed to parse URL: {0}. Error: {1}")]
117    UrlParsing(String, String),
118
119    /// The request data is not correctly formed.
120    #[error("Failed to format request: {0}")]
121    FormatRequest(String),
122
123    /// Errors forwarded from the HTTP protocol.
124    #[error("Unexpected HTTP client error: {0}")]
125    HttpClient(String, #[source] reqwest::Error),
126
127    /// The response from the server could not be parsed correctly.
128    #[error("Failed to parse response: {0}")]
129    ParseResponse(String),
130
131    /// Other fatal errors.
132    #[error("Fatal error: {0}")]
133    Fatal(String),
134
135    #[error("Rate limited until {0:?}")]
136    RateLimited(Option<SystemTime>),
137
138    #[error("Server unreachable: {0}")]
139    ServerUnreachable(String),
140}
141
142#[cfg_attr(test, automock)]
143#[async_trait]
144pub trait RPCClient: Send + Sync {
145    /// Returns whether compression is enabled for requests.
146    fn compression(&self) -> bool;
147
148    /// Retrieves a snapshot of contract state.
149    async fn get_contract_state(
150        &self,
151        request: &StateRequestBody,
152    ) -> Result<StateRequestResponse, RPCError>;
153
154    /// Retrieves a snapshot of contract state for a set of contract IDs.
155    /// If the `chunk_size` is `None`, it defaults to the maximum page size
156    async fn get_contract_state_paginated(
157        &self,
158        chain: Chain,
159        ids: &[Bytes],
160        protocol_system: &str,
161        version: &VersionParam,
162        chunk_size: Option<usize>,
163        concurrency: usize,
164    ) -> Result<StateRequestResponse, RPCError> {
165        let semaphore = Arc::new(Semaphore::new(concurrency));
166
167        // Sort the ids to maximize server-side cache hits
168        let mut sorted_ids = ids.to_vec();
169        sorted_ids.sort();
170
171        let chunk_size = chunk_size
172            .unwrap_or(StateRequestBody::effective_max_page_size(self.compression()) as usize);
173
174        let chunked_bodies = sorted_ids
175            .chunks(chunk_size)
176            .map(|chunk| StateRequestBody {
177                contract_ids: Some(chunk.to_vec()),
178                protocol_system: protocol_system.to_string(),
179                chain,
180                version: version.clone(),
181                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
182            })
183            .collect::<Vec<_>>();
184
185        let mut tasks = Vec::new();
186        for body in chunked_bodies.iter() {
187            let sem = semaphore.clone();
188            tasks.push(async move {
189                let _permit = sem
190                    .acquire()
191                    .await
192                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
193                self.get_contract_state(body).await
194            });
195        }
196
197        // Execute all tasks concurrently with the defined concurrency limit.
198        let responses = try_join_all(tasks).await?;
199
200        // Aggregate the responses into a single result.
201        let accounts = responses
202            .iter()
203            .flat_map(|r| r.accounts.clone())
204            .collect();
205        let total: i64 = responses
206            .iter()
207            .map(|r| r.pagination.total)
208            .sum();
209
210        Ok(StateRequestResponse {
211            accounts,
212            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
213        })
214    }
215
216    async fn get_protocol_components(
217        &self,
218        request: &ProtocolComponentsRequestBody,
219    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
220
221    /// Retrieves protocol components.
222    /// If the `chunk_size` is `None`, it defaults to the maximum page size.
223    async fn get_protocol_components_paginated(
224        &self,
225        request: &ProtocolComponentsRequestBody,
226        chunk_size: Option<usize>,
227        concurrency: usize,
228    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
229        let semaphore = Arc::new(Semaphore::new(concurrency));
230
231        let chunk_size = chunk_size.unwrap_or(
232            ProtocolComponentsRequestBody::effective_max_page_size(self.compression()) as usize,
233        );
234
235        // If a set of component IDs is specified, the maximum return size is already known,
236        // allowing us to pre-compute the number of requests to be made.
237        match request.component_ids {
238            Some(ref ids) => {
239                // We can divide the component_ids into chunks of size chunk_size
240                let chunked_bodies = ids
241                    .chunks(chunk_size)
242                    .enumerate()
243                    .map(|(index, _)| ProtocolComponentsRequestBody {
244                        protocol_system: request.protocol_system.clone(),
245                        component_ids: request.component_ids.clone(),
246                        tvl_gt: request.tvl_gt,
247                        chain: request.chain,
248                        pagination: PaginationParams {
249                            page: index as i64,
250                            page_size: chunk_size as i64,
251                        },
252                    })
253                    .collect::<Vec<_>>();
254
255                let mut tasks = Vec::new();
256                for body in chunked_bodies.iter() {
257                    let sem = semaphore.clone();
258                    tasks.push(async move {
259                        let _permit = sem
260                            .acquire()
261                            .await
262                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
263                        self.get_protocol_components(body).await
264                    });
265                }
266
267                try_join_all(tasks)
268                    .await
269                    .map(|responses| ProtocolComponentRequestResponse {
270                        protocol_components: responses
271                            .into_iter()
272                            .flat_map(|r| r.protocol_components.into_iter())
273                            .collect(),
274                        pagination: PaginationResponse {
275                            page: 0,
276                            page_size: chunk_size as i64,
277                            total: ids.len() as i64,
278                        },
279                    })
280            }
281            _ => {
282                // If no component ids are specified, we need to make requests based on the total
283                // number of results from the first response.
284
285                let initial_request = ProtocolComponentsRequestBody {
286                    protocol_system: request.protocol_system.clone(),
287                    component_ids: request.component_ids.clone(),
288                    tvl_gt: request.tvl_gt,
289                    chain: request.chain,
290                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
291                };
292                let first_response = self
293                    .get_protocol_components(&initial_request)
294                    .await
295                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
296
297                let total_items = first_response.pagination.total;
298                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
299
300                // Initialize the final response accumulator
301                let mut accumulated_response = ProtocolComponentRequestResponse {
302                    protocol_components: first_response.protocol_components,
303                    pagination: PaginationResponse {
304                        page: 0,
305                        page_size: chunk_size as i64,
306                        total: total_items,
307                    },
308                };
309
310                let mut page = 1;
311                while page < total_pages {
312                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
313
314                    // Create request bodies for parallel requests, respecting the concurrency limit
315                    let chunked_bodies = (0..requests_in_this_iteration)
316                        .map(|iter| ProtocolComponentsRequestBody {
317                            protocol_system: request.protocol_system.clone(),
318                            component_ids: request.component_ids.clone(),
319                            tvl_gt: request.tvl_gt,
320                            chain: request.chain,
321                            pagination: PaginationParams {
322                                page: page + iter,
323                                page_size: chunk_size as i64,
324                            },
325                        })
326                        .collect::<Vec<_>>();
327
328                    let tasks: Vec<_> = chunked_bodies
329                        .iter()
330                        .map(|body| {
331                            let sem = semaphore.clone();
332                            async move {
333                                let _permit = sem.acquire().await.map_err(|_| {
334                                    RPCError::Fatal("Semaphore dropped".to_string())
335                                })?;
336                                self.get_protocol_components(body).await
337                            }
338                        })
339                        .collect();
340
341                    let responses = try_join_all(tasks)
342                        .await
343                        .map(|responses| {
344                            let total = responses[0].pagination.total;
345                            ProtocolComponentRequestResponse {
346                                protocol_components: responses
347                                    .into_iter()
348                                    .flat_map(|r| r.protocol_components.into_iter())
349                                    .collect(),
350                                pagination: PaginationResponse {
351                                    page,
352                                    page_size: chunk_size as i64,
353                                    total,
354                                },
355                            }
356                        });
357
358                    // Update the accumulated response or set the initial response
359                    match responses {
360                        Ok(mut resp) => {
361                            accumulated_response
362                                .protocol_components
363                                .append(&mut resp.protocol_components);
364                        }
365                        Err(e) => return Err(e),
366                    }
367
368                    page += concurrency as i64;
369                }
370                Ok(accumulated_response)
371            }
372        }
373    }
374
375    async fn get_protocol_states(
376        &self,
377        request: &ProtocolStateRequestBody,
378    ) -> Result<ProtocolStateRequestResponse, RPCError>;
379
380    /// Retrieves protocol states for a set of protocol IDs.
381    /// If the `chunk_size` is `None`, it defaults to the maximum page size.
382    #[allow(clippy::too_many_arguments)]
383    async fn get_protocol_states_paginated<T>(
384        &self,
385        chain: Chain,
386        ids: &[T],
387        protocol_system: &str,
388        include_balances: bool,
389        version: &VersionParam,
390        chunk_size: Option<usize>,
391        concurrency: usize,
392    ) -> Result<ProtocolStateRequestResponse, RPCError>
393    where
394        T: AsRef<str> + Sync + 'static,
395    {
396        let semaphore = Arc::new(Semaphore::new(concurrency));
397
398        let chunk_size = chunk_size.unwrap_or(ProtocolStateRequestBody::effective_max_page_size(
399            self.compression(),
400        ) as usize);
401
402        let chunked_bodies = ids
403            .chunks(chunk_size)
404            .map(|c| ProtocolStateRequestBody {
405                protocol_ids: Some(
406                    c.iter()
407                        .map(|id| id.as_ref().to_string())
408                        .collect(),
409                ),
410                protocol_system: protocol_system.to_string(),
411                chain,
412                include_balances,
413                version: version.clone(),
414                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
415            })
416            .collect::<Vec<_>>();
417
418        let mut tasks = Vec::new();
419        for body in chunked_bodies.iter() {
420            let sem = semaphore.clone();
421            tasks.push(async move {
422                let _permit = sem
423                    .acquire()
424                    .await
425                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
426                self.get_protocol_states(body).await
427            });
428        }
429
430        try_join_all(tasks)
431            .await
432            .map(|responses| {
433                let states = responses
434                    .clone()
435                    .into_iter()
436                    .flat_map(|r| r.states)
437                    .collect();
438                let total = responses
439                    .iter()
440                    .map(|r| r.pagination.total)
441                    .sum();
442                ProtocolStateRequestResponse {
443                    states,
444                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
445                }
446            })
447    }
448
449    /// This function returns only one chunk of tokens. To get all tokens please call
450    /// get_all_tokens.
451    async fn get_tokens(
452        &self,
453        request: &TokensRequestBody,
454    ) -> Result<TokensRequestResponse, RPCError>;
455
456    /// Retrieves all tokens matching the given criteria.
457    /// If the `chunk_size` is `None`, it defaults to the maximum page size.
458    async fn get_all_tokens(
459        &self,
460        chain: Chain,
461        min_quality: Option<i32>,
462        traded_n_days_ago: Option<u64>,
463        chunk_size: Option<usize>,
464        concurrency: usize,
465    ) -> Result<Vec<ResponseToken>, RPCError> {
466        let chunk_size = chunk_size
467            .unwrap_or(TokensRequestBody::effective_max_page_size(self.compression()) as usize);
468
469        let semaphore = Arc::new(Semaphore::new(concurrency));
470
471        // Make initial request to get total count
472        let page_size = chunk_size.try_into().map_err(|_| {
473            RPCError::FormatRequest("Failed to convert chunk_size into i64".to_string())
474        })?;
475
476        let initial_request = TokensRequestBody {
477            token_addresses: None,
478            min_quality,
479            traded_n_days_ago,
480            pagination: PaginationParams { page: 0, page_size },
481            chain,
482        };
483
484        let first_response = self
485            .get_tokens(&initial_request)
486            .await?;
487        let total_items = first_response.pagination.total;
488        let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
489
490        let mut all_tokens = first_response.tokens;
491
492        // If only one page, return immediately
493        if total_pages <= 1 {
494            return Ok(all_tokens);
495        }
496
497        // Create a task for each remaining page
498        let tasks: Vec<_> = (1..total_pages)
499            .map(|page| {
500                let sem = semaphore.clone();
501                let request = TokensRequestBody {
502                    token_addresses: None,
503                    min_quality,
504                    traded_n_days_ago,
505                    pagination: PaginationParams { page, page_size },
506                    chain,
507                };
508
509                async move {
510                    // Semaphore controls how many requests are actually in-flight
511                    let _permit = sem
512                        .acquire()
513                        .await
514                        .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
515                    self.get_tokens(&request).await
516                }
517            })
518            .collect();
519
520        // Join all tasks - semaphore ensures only 'concurrency' execute at once
521        let responses = try_join_all(tasks).await?;
522
523        // Aggregate all tokens from all pages
524        for mut response in responses {
525            all_tokens.append(&mut response.tokens);
526        }
527
528        Ok(all_tokens)
529    }
530
531    async fn get_protocol_systems(
532        &self,
533        request: &ProtocolSystemsRequestBody,
534    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
535
536    async fn get_component_tvl(
537        &self,
538        request: &ComponentTvlRequestBody,
539    ) -> Result<ComponentTvlRequestResponse, RPCError>;
540
541    /// Retrieves component TVL for a set of component IDs.
542    /// If the `chunk_size` is `None`, it defaults to the maximum page size.
543    async fn get_component_tvl_paginated(
544        &self,
545        request: &ComponentTvlRequestBody,
546        chunk_size: Option<usize>,
547        concurrency: usize,
548    ) -> Result<ComponentTvlRequestResponse, RPCError> {
549        let semaphore = Arc::new(Semaphore::new(concurrency));
550
551        let chunk_size = chunk_size.unwrap_or(ComponentTvlRequestBody::effective_max_page_size(
552            self.compression(),
553        ) as usize);
554
555        match request.component_ids {
556            Some(ref ids) => {
557                let chunked_requests = ids
558                    .chunks(chunk_size)
559                    .enumerate()
560                    .map(|(index, _)| ComponentTvlRequestBody {
561                        chain: request.chain,
562                        protocol_system: request.protocol_system.clone(),
563                        component_ids: Some(ids.clone()),
564                        pagination: PaginationParams {
565                            page: index as i64,
566                            page_size: chunk_size as i64,
567                        },
568                    })
569                    .collect::<Vec<_>>();
570
571                let tasks: Vec<_> = chunked_requests
572                    .into_iter()
573                    .map(|req| {
574                        let sem = semaphore.clone();
575                        async move {
576                            let _permit = sem
577                                .acquire()
578                                .await
579                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
580                            self.get_component_tvl(&req).await
581                        }
582                    })
583                    .collect();
584
585                let responses = try_join_all(tasks).await?;
586
587                let mut merged_tvl = HashMap::new();
588                for resp in responses {
589                    for (key, value) in resp.tvl {
590                        *merged_tvl.entry(key).or_insert(0.0) = value;
591                    }
592                }
593
594                Ok(ComponentTvlRequestResponse {
595                    tvl: merged_tvl,
596                    pagination: PaginationResponse {
597                        page: 0,
598                        page_size: chunk_size as i64,
599                        total: ids.len() as i64,
600                    },
601                })
602            }
603            _ => {
604                let first_request = ComponentTvlRequestBody {
605                    chain: request.chain,
606                    protocol_system: request.protocol_system.clone(),
607                    component_ids: request.component_ids.clone(),
608                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
609                };
610
611                let first_response = self
612                    .get_component_tvl(&first_request)
613                    .await?;
614                let total_items = first_response.pagination.total;
615                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
616
617                let mut merged_tvl = first_response.tvl;
618
619                let mut page = 1;
620                while page < total_pages {
621                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
622
623                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
624                        .map(|i| ComponentTvlRequestBody {
625                            chain: request.chain,
626                            protocol_system: request.protocol_system.clone(),
627                            component_ids: request.component_ids.clone(),
628                            pagination: PaginationParams {
629                                page: page + i,
630                                page_size: chunk_size as i64,
631                            },
632                        })
633                        .collect();
634
635                    let tasks: Vec<_> = chunked_requests
636                        .into_iter()
637                        .map(|req| {
638                            let sem = semaphore.clone();
639                            async move {
640                                let _permit = sem.acquire().await.map_err(|_| {
641                                    RPCError::Fatal("Semaphore dropped".to_string())
642                                })?;
643                                self.get_component_tvl(&req).await
644                            }
645                        })
646                        .collect();
647
648                    let responses = try_join_all(tasks).await?;
649
650                    // merge hashmap
651                    for resp in responses {
652                        for (key, value) in resp.tvl {
653                            *merged_tvl.entry(key).or_insert(0.0) += value;
654                        }
655                    }
656
657                    page += concurrency as i64;
658                }
659
660                Ok(ComponentTvlRequestResponse {
661                    tvl: merged_tvl,
662                    pagination: PaginationResponse {
663                        page: 0,
664                        page_size: chunk_size as i64,
665                        total: total_items,
666                    },
667                })
668            }
669        }
670    }
671
672    async fn get_traced_entry_points(
673        &self,
674        request: &TracedEntryPointRequestBody,
675    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
676
677    /// Retrieves traced entry points for a set of component IDs.
678    /// If the `chunk_size` is `None`, it defaults to the maximum page size.
679    async fn get_traced_entry_points_paginated(
680        &self,
681        chain: Chain,
682        protocol_system: &str,
683        component_ids: &[String],
684        chunk_size: Option<usize>,
685        concurrency: usize,
686    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
687        let semaphore = Arc::new(Semaphore::new(concurrency));
688
689        let chunk_size = chunk_size.unwrap_or(
690            TracedEntryPointRequestBody::effective_max_page_size(self.compression()) as usize,
691        );
692
693        let chunked_bodies = component_ids
694            .chunks(chunk_size)
695            .map(|c| TracedEntryPointRequestBody {
696                chain,
697                protocol_system: protocol_system.to_string(),
698                component_ids: Some(c.to_vec()),
699                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
700            })
701            .collect::<Vec<_>>();
702
703        let mut tasks = Vec::new();
704        for body in chunked_bodies.iter() {
705            let sem = semaphore.clone();
706            tasks.push(async move {
707                let _permit = sem
708                    .acquire()
709                    .await
710                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
711                self.get_traced_entry_points(body).await
712            });
713        }
714
715        try_join_all(tasks)
716            .await
717            .map(|responses| {
718                let traced_entry_points = responses
719                    .clone()
720                    .into_iter()
721                    .flat_map(|r| r.traced_entry_points)
722                    .collect();
723                let total = responses
724                    .iter()
725                    .map(|r| r.pagination.total)
726                    .sum();
727                TracedEntryPointRequestResponse {
728                    traced_entry_points,
729                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
730                }
731            })
732    }
733
734    async fn get_snapshots<'a>(
735        &self,
736        request: &SnapshotParameters<'a>,
737        chunk_size: Option<usize>,
738        concurrency: usize,
739    ) -> Result<Snapshot, RPCError>;
740}
741
/// Options controlling how an `HttpRPCClient` is built.
#[derive(Debug, Clone)]
pub struct HttpRPCClientOptions {
    /// API key sent as the `Authorization` header; `None` sends no such header.
    pub auth_key: Option<String>,
    /// Whether to accept zstd-compressed responses (default: `true`).
    ///
    /// When `false`, the HTTP client is built with zstd support turned off. The client also
    /// reports this flag through its `compression()` accessor.
    pub compression: bool,
}
751
752impl Default for HttpRPCClientOptions {
753    fn default() -> Self {
754        Self::new()
755    }
756}
757
758impl HttpRPCClientOptions {
759    /// Create new options with default values (compression enabled)
760    pub fn new() -> Self {
761        Self { auth_key: None, compression: true }
762    }
763
764    /// Set the authentication key
765    pub fn with_auth_key(mut self, auth_key: Option<String>) -> Self {
766        self.auth_key = auth_key;
767        self
768    }
769
770    /// Set whether to enable compression (default: true)
771    pub fn with_compression(mut self, compression: bool) -> Self {
772        self.compression = compression;
773        self
774    }
775}
776
777#[derive(Debug, Clone)]
778pub struct HttpRPCClient {
779    http_client: Client,
780    url: Url,
781    retry_after: Arc<RwLock<Option<SystemTime>>>,
782    backoff_policy: ExponentialBackoff,
783    server_restart_duration: Duration,
784    compression: bool,
785}
786
787impl HttpRPCClient {
788    pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
789        let uri = base_uri
790            .parse::<Url>()
791            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
792
793        // Add default headers
794        let mut headers = header::HeaderMap::new();
795        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
796        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
797        headers.insert(
798            header::USER_AGENT,
799            header::HeaderValue::from_str(&user_agent)
800                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
801        );
802
803        // Add Authorization if one is given
804        if let Some(key) = options.auth_key.as_deref() {
805            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
806                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
807            })?;
808            auth_value.set_sensitive(true);
809            headers.insert(header::AUTHORIZATION, auth_value);
810        }
811
812        let mut client_builder = ClientBuilder::new()
813            .default_headers(headers)
814            .http2_prior_knowledge();
815
816        // When compression is disabled, turn off all automatic compression
817        if !options.compression {
818            client_builder = client_builder.no_zstd();
819        }
820
821        let client = client_builder
822            .build()
823            .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
824
825        Ok(Self {
826            http_client: client,
827            url: uri,
828            retry_after: Arc::new(RwLock::new(None)),
829            backoff_policy: ExponentialBackoffBuilder::new()
830                .with_initial_interval(Duration::from_millis(250))
831                // increase backoff time by 75% each failure
832                .with_multiplier(1.75)
833                // keep retrying every 30s
834                .with_max_interval(Duration::from_secs(30))
835                // if all retries take longer than 2m, give up
836                .with_max_elapsed_time(Some(Duration::from_secs(125)))
837                .build(),
838            server_restart_duration: Duration::from_secs(120),
839            compression: options.compression,
840        })
841    }
842
843    #[cfg(test)]
844    pub fn with_test_backoff_policy(mut self) -> Self {
845        // Extremely short intervals for very fast testing
846        self.backoff_policy = ExponentialBackoffBuilder::new()
847            .with_initial_interval(Duration::from_millis(1))
848            .with_multiplier(1.1)
849            .with_max_interval(Duration::from_millis(5))
850            .with_max_elapsed_time(Some(Duration::from_millis(50)))
851            .build();
852        self.server_restart_duration = Duration::from_millis(50);
853        self
854    }
855
856    /// Converts a error response to a Result.
857    ///
858    /// Raises an error if the response status code id 429, 502, 503 or 504. In the 429
859    /// case it will try to look for a retry-after header an parse it accordingly. The
860    /// parsed value is then passed as part of the error.
861    async fn error_for_response(
862        &self,
863        response: reqwest::Response,
864    ) -> Result<reqwest::Response, RPCError> {
865        match response.status() {
866            StatusCode::TOO_MANY_REQUESTS => {
867                let retry_after_raw = response
868                    .headers()
869                    .get(reqwest::header::RETRY_AFTER)
870                    .and_then(|h| h.to_str().ok())
871                    .and_then(parse_retry_value);
872
873                Err(RPCError::RateLimited(retry_after_raw))
874            }
875            StatusCode::BAD_GATEWAY |
876            StatusCode::SERVICE_UNAVAILABLE |
877            StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
878                response
879                    .text()
880                    .await
881                    .unwrap_or_else(|_| "Server Unreachable".to_string()),
882            )),
883            _ => Ok(response),
884        }
885    }
886
887    /// Classifies errors into transient or permanent ones.
888    ///
889    /// Transient errors are retried with a potential backoff, permanent ones are not.
890    /// If the error is RateLimited, this method will set the self.retry_after value so
891    /// future requests wait until the rate limit has been reset.
892    async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
893        match e {
894            RPCError::ServerUnreachable(_) => {
895                backoff::Error::retry_after(e, self.server_restart_duration)
896            }
897            RPCError::RateLimited(Some(until)) => {
898                let mut retry_after_guard = self.retry_after.write().await;
899                *retry_after_guard = Some(
900                    retry_after_guard
901                        .unwrap_or(until)
902                        .max(until),
903                );
904
905                if let Ok(duration) = until.duration_since(SystemTime::now()) {
906                    backoff::Error::retry_after(e, duration)
907                } else {
908                    e.into()
909                }
910            }
911            RPCError::RateLimited(None) => e.into(),
912            _ => backoff::Error::permanent(e),
913        }
914    }
915
916    /// Waits until the current rate limit time has passed.
917    ///
918    /// Only waits if there is a time and that time is in the future, else return
919    /// immediately.
920    async fn wait_until_retry_after(&self) {
921        if let Some(&until) = self.retry_after.read().await.as_ref() {
922            let now = SystemTime::now();
923            if until > now {
924                if let Ok(duration) = until.duration_since(now) {
925                    sleep(duration).await
926                }
927            }
928        }
929    }
930
931    /// Makes a post request handling transient failures.
932    ///
933    /// If a retry-after header is received it will be respected. Else the configured
934    /// backoff policy is used to deal with transient network or server errors.
935    async fn make_post_request<T: Serialize + ?Sized>(
936        &self,
937        request: &T,
938        uri: &String,
939    ) -> Result<Response, RPCError> {
940        self.wait_until_retry_after().await;
941        let response = backoff::future::retry(self.backoff_policy.clone(), || async {
942            let server_response = self
943                .http_client
944                .post(uri)
945                .json(request)
946                .send()
947                .await
948                .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
949
950            match self
951                .error_for_response(server_response)
952                .await
953            {
954                Ok(response) => Ok(response),
955                Err(e) => Err(self.handle_error_for_backoff(e).await),
956            }
957        })
958        .await?;
959        Ok(response)
960    }
961}
962
963fn parse_retry_value(val: &str) -> Option<SystemTime> {
964    if let Ok(secs) = val.parse::<u64>() {
965        return Some(SystemTime::now() + Duration::from_secs(secs));
966    }
967    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
968        return Some(date.into());
969    }
970    None
971}
972
973#[async_trait]
974impl RPCClient for HttpRPCClient {
975    fn compression(&self) -> bool {
976        self.compression
977    }
978
979    #[instrument(skip(self, request))]
980    async fn get_contract_state(
981        &self,
982        request: &StateRequestBody,
983    ) -> Result<StateRequestResponse, RPCError> {
984        // Check if contract ids are specified
985        if request
986            .contract_ids
987            .as_ref()
988            .is_none_or(|ids| ids.is_empty())
989        {
990            warn!("No contract ids specified in request.");
991        }
992
993        let uri = format!(
994            "{}/{}/contract_state",
995            self.url
996                .to_string()
997                .trim_end_matches('/'),
998            TYCHO_SERVER_VERSION
999        );
1000        debug!(%uri, "Sending contract_state request to Tycho server");
1001        trace!(?request, "Sending request to Tycho server");
1002        let response = self
1003            .make_post_request(request, &uri)
1004            .await?;
1005        trace!(?response, "Received response from Tycho server");
1006
1007        let body = response
1008            .text()
1009            .await
1010            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1011        if body.is_empty() {
1012            // Pure native protocols will return empty contract states
1013            return Ok(StateRequestResponse {
1014                accounts: vec![],
1015                pagination: PaginationResponse {
1016                    page: request.pagination.page,
1017                    page_size: request.pagination.page,
1018                    total: 0,
1019                },
1020            });
1021        }
1022
1023        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
1024            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1025        trace!(?accounts, "Received contract_state response from Tycho server");
1026
1027        Ok(accounts)
1028    }
1029
1030    async fn get_protocol_components(
1031        &self,
1032        request: &ProtocolComponentsRequestBody,
1033    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
1034        let uri = format!(
1035            "{}/{}/protocol_components",
1036            self.url
1037                .to_string()
1038                .trim_end_matches('/'),
1039            TYCHO_SERVER_VERSION,
1040        );
1041        debug!(%uri, "Sending protocol_components request to Tycho server");
1042        trace!(?request, "Sending request to Tycho server");
1043
1044        let response = self
1045            .make_post_request(request, &uri)
1046            .await?;
1047
1048        trace!(?response, "Received response from Tycho server");
1049
1050        let body = response
1051            .text()
1052            .await
1053            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1054        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
1055            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1056        trace!(?components, "Received protocol_components response from Tycho server");
1057
1058        Ok(components)
1059    }
1060
1061    async fn get_protocol_states(
1062        &self,
1063        request: &ProtocolStateRequestBody,
1064    ) -> Result<ProtocolStateRequestResponse, RPCError> {
1065        // Check if protocol ids are specified
1066        if request
1067            .protocol_ids
1068            .as_ref()
1069            .is_none_or(|ids| ids.is_empty())
1070        {
1071            warn!("No protocol ids specified in request.");
1072        }
1073
1074        let uri = format!(
1075            "{}/{}/protocol_state",
1076            self.url
1077                .to_string()
1078                .trim_end_matches('/'),
1079            TYCHO_SERVER_VERSION
1080        );
1081        debug!(%uri, "Sending protocol_states request to Tycho server");
1082        trace!(?request, "Sending request to Tycho server");
1083
1084        let response = self
1085            .make_post_request(request, &uri)
1086            .await?;
1087        trace!(?response, "Received response from Tycho server");
1088
1089        let body = response
1090            .text()
1091            .await
1092            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1093
1094        if body.is_empty() {
1095            // Pure VM protocols will return empty states
1096            return Ok(ProtocolStateRequestResponse {
1097                states: vec![],
1098                pagination: PaginationResponse {
1099                    page: request.pagination.page,
1100                    page_size: request.pagination.page_size,
1101                    total: 0,
1102                },
1103            });
1104        }
1105
1106        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1107            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1108        trace!(?states, "Received protocol_states response from Tycho server");
1109
1110        Ok(states)
1111    }
1112
1113    async fn get_tokens(
1114        &self,
1115        request: &TokensRequestBody,
1116    ) -> Result<TokensRequestResponse, RPCError> {
1117        let uri = format!(
1118            "{}/{}/tokens",
1119            self.url
1120                .to_string()
1121                .trim_end_matches('/'),
1122            TYCHO_SERVER_VERSION
1123        );
1124        debug!(%uri, "Sending tokens request to Tycho server");
1125
1126        let response = self
1127            .make_post_request(request, &uri)
1128            .await?;
1129
1130        let body = response
1131            .text()
1132            .await
1133            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1134        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1135            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1136
1137        Ok(tokens)
1138    }
1139
1140    async fn get_protocol_systems(
1141        &self,
1142        request: &ProtocolSystemsRequestBody,
1143    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1144        let uri = format!(
1145            "{}/{}/protocol_systems",
1146            self.url
1147                .to_string()
1148                .trim_end_matches('/'),
1149            TYCHO_SERVER_VERSION
1150        );
1151        debug!(%uri, "Sending protocol_systems request to Tycho server");
1152        trace!(?request, "Sending request to Tycho server");
1153        let response = self
1154            .make_post_request(request, &uri)
1155            .await?;
1156        trace!(?response, "Received response from Tycho server");
1157        let body = response
1158            .text()
1159            .await
1160            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1161        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1162            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1163        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1164        Ok(protocol_systems)
1165    }
1166
1167    async fn get_component_tvl(
1168        &self,
1169        request: &ComponentTvlRequestBody,
1170    ) -> Result<ComponentTvlRequestResponse, RPCError> {
1171        let uri = format!(
1172            "{}/{}/component_tvl",
1173            self.url
1174                .to_string()
1175                .trim_end_matches('/'),
1176            TYCHO_SERVER_VERSION
1177        );
1178        debug!(%uri, "Sending get_component_tvl request to Tycho server");
1179        trace!(?request, "Sending request to Tycho server");
1180        let response = self
1181            .make_post_request(request, &uri)
1182            .await?;
1183        trace!(?response, "Received response from Tycho server");
1184        let body = response
1185            .text()
1186            .await
1187            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1188        let component_tvl =
1189            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1190                error!("Failed to parse component_tvl response: {:?}", &body);
1191                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1192            })?;
1193        trace!(?component_tvl, "Received component_tvl response from Tycho server");
1194        Ok(component_tvl)
1195    }
1196
1197    async fn get_traced_entry_points(
1198        &self,
1199        request: &TracedEntryPointRequestBody,
1200    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1201        let uri = format!(
1202            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1203            self.url
1204                .to_string()
1205                .trim_end_matches('/')
1206        );
1207        debug!(%uri, "Sending traced_entry_points request to Tycho server");
1208        trace!(?request, "Sending request to Tycho server");
1209
1210        let response = self
1211            .make_post_request(request, &uri)
1212            .await?;
1213
1214        trace!(?response, "Received response from Tycho server");
1215
1216        let body = response
1217            .text()
1218            .await
1219            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1220        let entrypoints =
1221            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1222                error!("Failed to parse traced_entry_points response: {:?}", &body);
1223                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1224            })?;
1225        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1226        Ok(entrypoints)
1227    }
1228
1229    async fn get_snapshots<'a>(
1230        &self,
1231        request: &SnapshotParameters<'a>,
1232        chunk_size: Option<usize>,
1233        concurrency: usize,
1234    ) -> Result<Snapshot, RPCError> {
1235        let component_ids: Vec<_> = request
1236            .components
1237            .keys()
1238            .cloned()
1239            .collect();
1240
1241        let version = VersionParam::new(
1242            None,
1243            Some({
1244                #[allow(deprecated)]
1245                BlockParam {
1246                    hash: None,
1247                    chain: Some(request.chain),
1248                    number: Some(request.block_number as i64),
1249                }
1250            }),
1251        );
1252
1253        let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1254            let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1255            self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1256                .await?
1257                .tvl
1258        } else {
1259            HashMap::new()
1260        };
1261
1262        let mut protocol_states = if !component_ids.is_empty() {
1263            self.get_protocol_states_paginated(
1264                request.chain,
1265                &component_ids,
1266                request.protocol_system,
1267                request.include_balances,
1268                &version,
1269                chunk_size,
1270                concurrency,
1271            )
1272            .await?
1273            .states
1274            .into_iter()
1275            .map(|state| (state.component_id.clone(), state))
1276            .collect()
1277        } else {
1278            HashMap::new()
1279        };
1280
1281        // Convert to ComponentWithState, which includes entrypoint information.
1282        let states = request
1283            .components
1284            .values()
1285            .filter_map(|component| {
1286                if let Some(state) = protocol_states.remove(&component.id) {
1287                    Some((
1288                        component.id.clone(),
1289                        ComponentWithState {
1290                            state,
1291                            component: component.clone(),
1292                            component_tvl: component_tvl
1293                                .get(&component.id)
1294                                .cloned(),
1295                            entrypoints: request
1296                                .entrypoints
1297                                .as_ref()
1298                                .and_then(|map| map.get(&component.id))
1299                                .cloned()
1300                                .unwrap_or_default(),
1301                        },
1302                    ))
1303                } else if component_ids.contains(&component.id) {
1304                    // only emit error event if we requested this component
1305                    let component_id = &component.id;
1306                    error!(?component_id, "Missing state for native component!");
1307                    None
1308                } else {
1309                    None
1310                }
1311            })
1312            .collect();
1313
1314        let vm_storage = if !request.contract_ids.is_empty() {
1315            let contract_states = self
1316                .get_contract_state_paginated(
1317                    request.chain,
1318                    request.contract_ids,
1319                    request.protocol_system,
1320                    &version,
1321                    chunk_size,
1322                    concurrency,
1323                )
1324                .await?
1325                .accounts
1326                .into_iter()
1327                .map(|acc| (acc.address.clone(), acc))
1328                .collect::<HashMap<_, _>>();
1329
1330            trace!(states=?&contract_states, "Retrieved ContractState");
1331
1332            let contract_address_to_components = request
1333                .components
1334                .iter()
1335                .filter_map(|(id, comp)| {
1336                    if component_ids.contains(id) {
1337                        Some(
1338                            comp.contract_ids
1339                                .iter()
1340                                .map(|address| (address.clone(), comp.id.clone())),
1341                        )
1342                    } else {
1343                        None
1344                    }
1345                })
1346                .flatten()
1347                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1348                    acc.entry(addr).or_default().push(c_id);
1349                    acc
1350                });
1351
1352            request
1353                .contract_ids
1354                .iter()
1355                .filter_map(|address| {
1356                    if let Some(state) = contract_states.get(address) {
1357                        Some((address.clone(), state.clone()))
1358                    } else if let Some(ids) = contract_address_to_components.get(address) {
1359                        // only emit error even if we did actually request this address
1360                        error!(
1361                            ?address,
1362                            ?ids,
1363                            "Component with lacking contract storage encountered!"
1364                        );
1365                        None
1366                    } else {
1367                        None
1368                    }
1369                })
1370                .collect()
1371        } else {
1372            HashMap::new()
1373        };
1374
1375        Ok(Snapshot { states, vm_storage })
1376    }
1377}
1378
1379#[cfg(test)]
1380mod tests {
1381    use std::{
1382        collections::{HashMap, HashSet},
1383        str::FromStr,
1384    };
1385
1386    use mockito::Server;
1387    use rstest::rstest;
1388    // TODO: remove once deprecated ProtocolId struct is removed
1389    #[allow(deprecated)]
1390    use tycho_common::dto::ProtocolId;
1391    use tycho_common::dto::{AddressStorageLocation, TracingParams};
1392
1393    use super::*;
1394
1395    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
1396    // purposes
1397    impl MockRPCClient {
1398        #[allow(clippy::too_many_arguments)]
1399        async fn test_get_protocol_states_paginated<T>(
1400            &self,
1401            chain: Chain,
1402            ids: &[T],
1403            protocol_system: &str,
1404            include_balances: bool,
1405            version: &VersionParam,
1406            chunk_size: usize,
1407            _concurrency: usize,
1408        ) -> Vec<ProtocolStateRequestBody>
1409        where
1410            T: AsRef<str> + Clone + Send + Sync + 'static,
1411        {
1412            ids.chunks(chunk_size)
1413                .map(|chunk| ProtocolStateRequestBody {
1414                    protocol_ids: Some(
1415                        chunk
1416                            .iter()
1417                            .map(|id| id.as_ref().to_string())
1418                            .collect(),
1419                    ),
1420                    protocol_system: protocol_system.to_string(),
1421                    chain,
1422                    include_balances,
1423                    version: version.clone(),
1424                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1425                })
1426                .collect()
1427        }
1428    }
1429
1430    const GET_CONTRACT_STATE_RESP: &str = r#"
1431        {
1432            "accounts": [
1433                {
1434                    "chain": "ethereum",
1435                    "address": "0x0000000000000000000000000000000000000000",
1436                    "title": "",
1437                    "slots": {},
1438                    "native_balance": "0x01f4",
1439                    "token_balances": {},
1440                    "code": "0x00",
1441                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
1442                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1443                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1444                    "creation_tx": null
1445                }
1446            ],
1447            "pagination": {
1448                "page": 0,
1449                "page_size": 20,
1450                "total": 10
1451            }
1452        }
1453        "#;
1454
1455    // TODO: remove once deprecated ProtocolId struct is removed
1456    #[allow(deprecated)]
1457    #[rstest]
1458    #[case::protocol_id_input(vec![
1459        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1460        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1461    ])]
1462    #[case::string_input(vec![
1463        "id1".to_string(),
1464        "id2".to_string()
1465    ])]
1466    #[tokio::test]
1467    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1468    where
1469        T: AsRef<str> + Clone + Send + Sync + 'static,
1470    {
1471        let mock_client = MockRPCClient::new();
1472
1473        let request_bodies = mock_client
1474            .test_get_protocol_states_paginated(
1475                Chain::Ethereum,
1476                &ids,
1477                "test_system",
1478                true,
1479                &VersionParam::default(),
1480                2,
1481                2,
1482            )
1483            .await;
1484
1485        // Verify that the request bodies have been created correctly
1486        assert_eq!(request_bodies.len(), 1);
1487        assert_eq!(
1488            request_bodies[0]
1489                .protocol_ids
1490                .as_ref()
1491                .unwrap()
1492                .len(),
1493            2
1494        );
1495    }
1496
1497    #[tokio::test]
1498    async fn test_get_contract_state() {
1499        let mut server = Server::new_async().await;
1500        let server_resp = GET_CONTRACT_STATE_RESP;
1501        // test that the response is deserialized correctly
1502        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1503
1504        let mocked_server = server
1505            .mock("POST", "/v1/contract_state")
1506            .expect(1)
1507            .with_body(server_resp)
1508            .create_async()
1509            .await;
1510
1511        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1512            .expect("create client");
1513
1514        let response = client
1515            .get_contract_state(&Default::default())
1516            .await
1517            .expect("get state");
1518        let accounts = response.accounts;
1519
1520        mocked_server.assert();
1521        assert_eq!(accounts.len(), 1);
1522        assert_eq!(accounts[0].slots, HashMap::new());
1523        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1524        assert_eq!(accounts[0].code, [0].to_vec());
1525        assert_eq!(
1526            accounts[0].code_hash,
1527            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1528                .unwrap()
1529        );
1530    }
1531
1532    #[tokio::test]
1533    async fn test_get_protocol_components() {
1534        let mut server = Server::new_async().await;
1535        let server_resp = r#"
1536        {
1537            "protocol_components": [
1538                {
1539                    "id": "State1",
1540                    "protocol_system": "ambient",
1541                    "protocol_type_name": "Pool",
1542                    "chain": "ethereum",
1543                    "tokens": [
1544                        "0x0000000000000000000000000000000000000000",
1545                        "0x0000000000000000000000000000000000000001"
1546                    ],
1547                    "contract_ids": [
1548                        "0x0000000000000000000000000000000000000000"
1549                    ],
1550                    "static_attributes": {
1551                        "attribute_1": "0x00000000000003e8"
1552                    },
1553                    "change": "Creation",
1554                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1555                    "created_at": "2022-01-01T00:00:00"
1556                }
1557            ],
1558            "pagination": {
1559                "page": 0,
1560                "page_size": 20,
1561                "total": 10
1562            }
1563        }
1564        "#;
1565        // test that the response is deserialized correctly
1566        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1567
1568        let mocked_server = server
1569            .mock("POST", "/v1/protocol_components")
1570            .expect(1)
1571            .with_body(server_resp)
1572            .create_async()
1573            .await;
1574
1575        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1576            .expect("create client");
1577
1578        let response = client
1579            .get_protocol_components(&Default::default())
1580            .await
1581            .expect("get state");
1582        let components = response.protocol_components;
1583
1584        mocked_server.assert();
1585        assert_eq!(components.len(), 1);
1586        assert_eq!(components[0].id, "State1");
1587        assert_eq!(components[0].protocol_system, "ambient");
1588        assert_eq!(components[0].protocol_type_name, "Pool");
1589        assert_eq!(components[0].tokens.len(), 2);
1590        let expected_attributes =
1591            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1592                .iter()
1593                .cloned()
1594                .collect::<HashMap<String, Bytes>>();
1595        assert_eq!(components[0].static_attributes, expected_attributes);
1596    }
1597
1598    #[tokio::test]
1599    async fn test_get_protocol_states() {
1600        let mut server = Server::new_async().await;
1601        let server_resp = r#"
1602        {
1603            "states": [
1604                {
1605                    "component_id": "State1",
1606                    "attributes": {
1607                        "attribute_1": "0x00000000000003e8"
1608                    },
1609                    "balances": {
1610                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1611                    }
1612                }
1613            ],
1614            "pagination": {
1615                "page": 0,
1616                "page_size": 20,
1617                "total": 10
1618            }
1619        }
1620        "#;
1621        // test that the response is deserialized correctly
1622        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1623
1624        let mocked_server = server
1625            .mock("POST", "/v1/protocol_state")
1626            .expect(1)
1627            .with_body(server_resp)
1628            .create_async()
1629            .await;
1630        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1631            .expect("create client");
1632
1633        let response = client
1634            .get_protocol_states(&Default::default())
1635            .await
1636            .expect("get state");
1637        let states = response.states;
1638
1639        mocked_server.assert();
1640        assert_eq!(states.len(), 1);
1641        assert_eq!(states[0].component_id, "State1");
1642        let expected_attributes =
1643            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1644                .iter()
1645                .cloned()
1646                .collect::<HashMap<String, Bytes>>();
1647        assert_eq!(states[0].attributes, expected_attributes);
1648        let expected_balances = [(
1649            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1650                .expect("Unsupported address format"),
1651            Bytes::from_str("0x01f4").unwrap(),
1652        )]
1653        .iter()
1654        .cloned()
1655        .collect::<HashMap<Bytes, Bytes>>();
1656        assert_eq!(states[0].balances, expected_balances);
1657    }
1658
1659    #[tokio::test]
1660    async fn test_get_tokens() {
1661        let mut server = Server::new_async().await;
1662        let server_resp = r#"
1663        {
1664            "tokens": [
1665              {
1666                "chain": "ethereum",
1667                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1668                "symbol": "WETH",
1669                "decimals": 18,
1670                "tax": 0,
1671                "gas": [
1672                  29962
1673                ],
1674                "quality": 100
1675              },
1676              {
1677                "chain": "ethereum",
1678                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1679                "symbol": "USDC",
1680                "decimals": 6,
1681                "tax": 0,
1682                "gas": [
1683                  40652
1684                ],
1685                "quality": 100
1686              }
1687            ],
1688            "pagination": {
1689              "page": 0,
1690              "page_size": 20,
1691              "total": 10
1692            }
1693          }
1694        "#;
1695        // test that the response is deserialized correctly
1696        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1697
1698        let mocked_server = server
1699            .mock("POST", "/v1/tokens")
1700            .expect(1)
1701            .with_body(server_resp)
1702            .create_async()
1703            .await;
1704        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1705            .expect("create client");
1706
1707        let response = client
1708            .get_tokens(&Default::default())
1709            .await
1710            .expect("get tokens");
1711
1712        let expected = vec![
1713            ResponseToken {
1714                chain: Chain::Ethereum,
1715                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1716                symbol: "WETH".to_string(),
1717                decimals: 18,
1718                tax: 0,
1719                gas: vec![Some(29962)],
1720                quality: 100,
1721            },
1722            ResponseToken {
1723                chain: Chain::Ethereum,
1724                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1725                symbol: "USDC".to_string(),
1726                decimals: 6,
1727                tax: 0,
1728                gas: vec![Some(40652)],
1729                quality: 100,
1730            },
1731        ];
1732
1733        mocked_server.assert();
1734        assert_eq!(response.tokens, expected);
1735        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1736    }
1737
1738    #[tokio::test]
1739    async fn test_get_protocol_systems() {
1740        let mut server = Server::new_async().await;
1741        let server_resp = r#"
1742        {
1743            "protocol_systems": [
1744                "system1",
1745                "system2"
1746            ],
1747            "pagination": {
1748                "page": 0,
1749                "page_size": 20,
1750                "total": 10
1751            }
1752        }
1753        "#;
1754        // test that the response is deserialized correctly
1755        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1756
1757        let mocked_server = server
1758            .mock("POST", "/v1/protocol_systems")
1759            .expect(1)
1760            .with_body(server_resp)
1761            .create_async()
1762            .await;
1763        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1764            .expect("create client");
1765
1766        let response = client
1767            .get_protocol_systems(&Default::default())
1768            .await
1769            .expect("get protocol systems");
1770        let protocol_systems = response.protocol_systems;
1771
1772        mocked_server.assert();
1773        assert_eq!(protocol_systems, vec!["system1", "system2"]);
1774    }
1775
1776    #[tokio::test]
1777    async fn test_get_component_tvl() {
1778        let mut server = Server::new_async().await;
1779        let server_resp = r#"
1780        {
1781            "tvl": {
1782                "component1": 100.0
1783            },
1784            "pagination": {
1785                "page": 0,
1786                "page_size": 20,
1787                "total": 10
1788            }
1789        }
1790        "#;
1791        // test that the response is deserialized correctly
1792        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1793
1794        let mocked_server = server
1795            .mock("POST", "/v1/component_tvl")
1796            .expect(1)
1797            .with_body(server_resp)
1798            .create_async()
1799            .await;
1800        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1801            .expect("create client");
1802
1803        let response = client
1804            .get_component_tvl(&Default::default())
1805            .await
1806            .expect("get protocol systems");
1807        let component_tvl = response.tvl;
1808
1809        mocked_server.assert();
1810        assert_eq!(component_tvl.get("component1"), Some(&100.0));
1811    }
1812
1813    #[tokio::test]
1814    async fn test_get_traced_entry_points() {
1815        let mut server = Server::new_async().await;
1816        let server_resp = r#"
1817        {
1818            "traced_entry_points": {
1819                "component_1": [
1820                    [
1821                        {
1822                            "entry_point": {
1823                                "external_id": "entrypoint_a",
1824                                "target": "0x0000000000000000000000000000000000000001",
1825                                "signature": "sig()"
1826                            },
1827                            "params": {
1828                                "method": "rpctracer",
1829                                "caller": "0x000000000000000000000000000000000000000a",
1830                                "calldata": "0x000000000000000000000000000000000000000b"
1831                            }
1832                        },
1833                        {
1834                            "retriggers": [
1835                                [
1836                                    "0x00000000000000000000000000000000000000aa",
1837                                    {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1838                                ]
1839                            ],
1840                            "accessed_slots": {
1841                                "0x0000000000000000000000000000000000aaaa": [
1842                                    "0x0000000000000000000000000000000000aaaa"
1843                                ]
1844                            }
1845                        }
1846                    ]
1847                ]
1848            },
1849            "pagination": {
1850                "page": 0,
1851                "page_size": 20,
1852                "total": 1
1853            }
1854        }
1855        "#;
1856        // test that the response is deserialized correctly
1857        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1858
1859        let mocked_server = server
1860            .mock("POST", "/v1/traced_entry_points")
1861            .expect(1)
1862            .with_body(server_resp)
1863            .create_async()
1864            .await;
1865        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1866            .expect("create client");
1867
1868        let response = client
1869            .get_traced_entry_points(&Default::default())
1870            .await
1871            .expect("get traced entry points");
1872        let entrypoints = response.traced_entry_points;
1873
1874        mocked_server.assert();
1875        assert_eq!(entrypoints.len(), 1);
1876        let comp1_entrypoints = entrypoints
1877            .get("component_1")
1878            .expect("component_1 entrypoints should exist");
1879        assert_eq!(comp1_entrypoints.len(), 1);
1880
1881        let (entrypoint, trace_result) = &comp1_entrypoints[0];
1882        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1883        assert_eq!(
1884            entrypoint.entry_point.target,
1885            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1886        );
1887        assert_eq!(entrypoint.entry_point.signature, "sig()");
1888        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1889        assert_eq!(
1890            rpc_params.caller,
1891            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1892        );
1893        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1894
1895        assert_eq!(
1896            trace_result.retriggers,
1897            HashSet::from([(
1898                Bytes::from("0x00000000000000000000000000000000000000aa"),
1899                AddressStorageLocation::new(
1900                    Bytes::from("0x0000000000000000000000000000000000000aaa"),
1901                    12
1902                )
1903            )])
1904        );
1905        assert_eq!(trace_result.accessed_slots.len(), 1);
1906        assert_eq!(
1907            trace_result.accessed_slots,
1908            HashMap::from([(
1909                Bytes::from("0x0000000000000000000000000000000000aaaa"),
1910                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1911            )])
1912        );
1913    }
1914
1915    #[tokio::test]
1916    async fn test_parse_retry_value_numeric() {
1917        let result = parse_retry_value("60");
1918        assert!(result.is_some());
1919
1920        let expected_time = SystemTime::now() + Duration::from_secs(60);
1921        let actual_time = result.unwrap();
1922
1923        // Allow for small timing differences during test execution
1924        let diff = if actual_time > expected_time {
1925            actual_time
1926                .duration_since(expected_time)
1927                .unwrap()
1928        } else {
1929            expected_time
1930                .duration_since(actual_time)
1931                .unwrap()
1932        };
1933        assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1934    }
1935
1936    #[tokio::test]
1937    async fn test_parse_retry_value_rfc2822() {
1938        // Use a fixed future date in RFC2822 format
1939        let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1940        let result = parse_retry_value(rfc2822_date);
1941        assert!(result.is_some());
1942
1943        let parsed_time = result.unwrap();
1944        assert!(parsed_time > SystemTime::now());
1945    }
1946
1947    #[tokio::test]
1948    async fn test_parse_retry_value_invalid_formats() {
1949        // Test various invalid formats
1950        assert!(parse_retry_value("invalid").is_none());
1951        assert!(parse_retry_value("").is_none());
1952        assert!(parse_retry_value("not_a_number").is_none());
1953        assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); // Invalid date
1954    }
1955
1956    #[tokio::test]
1957    async fn test_parse_retry_value_zero_seconds() {
1958        let result = parse_retry_value("0");
1959        assert!(result.is_some());
1960
1961        let expected_time = SystemTime::now();
1962        let actual_time = result.unwrap();
1963
1964        // Should be very close to current time
1965        let diff = if actual_time > expected_time {
1966            actual_time
1967                .duration_since(expected_time)
1968                .unwrap()
1969        } else {
1970            expected_time
1971                .duration_since(actual_time)
1972                .unwrap()
1973        };
1974        assert!(diff < Duration::from_secs(1));
1975    }
1976
1977    #[tokio::test]
1978    async fn test_error_for_response_rate_limited() {
1979        let mut server = Server::new_async().await;
1980        let mock = server
1981            .mock("GET", "/test")
1982            .with_status(429)
1983            .with_header("Retry-After", "60")
1984            .create_async()
1985            .await;
1986
1987        let client = reqwest::Client::new();
1988        let response = client
1989            .get(format!("{}/test", server.url()))
1990            .send()
1991            .await
1992            .unwrap();
1993
1994        let http_client =
1995            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1996                .unwrap()
1997                .with_test_backoff_policy();
1998        let result = http_client
1999            .error_for_response(response)
2000            .await;
2001
2002        mock.assert();
2003        assert!(matches!(result, Err(RPCError::RateLimited(_))));
2004        if let Err(RPCError::RateLimited(retry_after)) = result {
2005            assert!(retry_after.is_some());
2006        }
2007    }
2008
2009    #[tokio::test]
2010    async fn test_error_for_response_rate_limited_no_header() {
2011        let mut server = Server::new_async().await;
2012        let mock = server
2013            .mock("GET", "/test")
2014            .with_status(429)
2015            .create_async()
2016            .await;
2017
2018        let client = reqwest::Client::new();
2019        let response = client
2020            .get(format!("{}/test", server.url()))
2021            .send()
2022            .await
2023            .unwrap();
2024
2025        let http_client =
2026            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2027                .unwrap()
2028                .with_test_backoff_policy();
2029        let result = http_client
2030            .error_for_response(response)
2031            .await;
2032
2033        mock.assert();
2034        assert!(matches!(result, Err(RPCError::RateLimited(None))));
2035    }
2036
2037    #[tokio::test]
2038    async fn test_error_for_response_server_errors() {
2039        let test_cases =
2040            vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
2041
2042        for (status_code, expected_body) in test_cases {
2043            let mut server = Server::new_async().await;
2044            let mock = server
2045                .mock("GET", "/test")
2046                .with_status(status_code)
2047                .with_body(expected_body)
2048                .create_async()
2049                .await;
2050
2051            let client = reqwest::Client::new();
2052            let response = client
2053                .get(format!("{}/test", server.url()))
2054                .send()
2055                .await
2056                .unwrap();
2057
2058            let http_client =
2059                HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2060                    .unwrap()
2061                    .with_test_backoff_policy();
2062            let result = http_client
2063                .error_for_response(response)
2064                .await;
2065
2066            mock.assert();
2067            assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
2068            if let Err(RPCError::ServerUnreachable(body)) = result {
2069                assert_eq!(body, expected_body);
2070            }
2071        }
2072    }
2073
2074    #[tokio::test]
2075    async fn test_error_for_response_success() {
2076        let mut server = Server::new_async().await;
2077        let mock = server
2078            .mock("GET", "/test")
2079            .with_status(200)
2080            .with_body("success")
2081            .create_async()
2082            .await;
2083
2084        let client = reqwest::Client::new();
2085        let response = client
2086            .get(format!("{}/test", server.url()))
2087            .send()
2088            .await
2089            .unwrap();
2090
2091        let http_client =
2092            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2093                .unwrap()
2094                .with_test_backoff_policy();
2095        let result = http_client
2096            .error_for_response(response)
2097            .await;
2098
2099        mock.assert();
2100        assert!(result.is_ok());
2101
2102        let response = result.unwrap();
2103        assert_eq!(response.status(), 200);
2104    }
2105
2106    #[tokio::test]
2107    async fn test_handle_error_for_backoff_server_unreachable() {
2108        let http_client =
2109            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2110                .unwrap()
2111                .with_test_backoff_policy();
2112        let error = RPCError::ServerUnreachable("Service down".to_string());
2113
2114        let backoff_error = http_client
2115            .handle_error_for_backoff(error)
2116            .await;
2117
2118        match backoff_error {
2119            backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2120                assert_eq!(msg, "Service down");
2121                assert_eq!(retry_after, Some(Duration::from_millis(50))); // Fast test duration
2122            }
2123            _ => panic!("Expected transient error for ServerUnreachable"),
2124        }
2125    }
2126
2127    #[tokio::test]
2128    async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2129        let http_client =
2130            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2131                .unwrap()
2132                .with_test_backoff_policy();
2133        let future_time = SystemTime::now() + Duration::from_secs(30);
2134        let error = RPCError::RateLimited(Some(future_time));
2135
2136        let backoff_error = http_client
2137            .handle_error_for_backoff(error)
2138            .await;
2139
2140        match backoff_error {
2141            backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2142                assert_eq!(retry_after, Some(future_time));
2143            }
2144            _ => panic!("Expected transient error for RateLimited"),
2145        }
2146
2147        // Verify that retry_after was stored in the client state
2148        let stored_retry_after = http_client.retry_after.read().await;
2149        assert_eq!(*stored_retry_after, Some(future_time));
2150    }
2151
2152    #[tokio::test]
2153    async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2154        let http_client =
2155            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2156                .unwrap()
2157                .with_test_backoff_policy();
2158        let error = RPCError::RateLimited(None);
2159
2160        let backoff_error = http_client
2161            .handle_error_for_backoff(error)
2162            .await;
2163
2164        match backoff_error {
2165            backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2166                // This is expected - no retry-after still allows retries with default policy
2167            }
2168            _ => panic!("Expected transient error for RateLimited without retry-after"),
2169        }
2170    }
2171
2172    #[tokio::test]
2173    async fn test_handle_error_for_backoff_other_errors() {
2174        let http_client =
2175            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2176                .unwrap()
2177                .with_test_backoff_policy();
2178        let error = RPCError::ParseResponse("Invalid JSON".to_string());
2179
2180        let backoff_error = http_client
2181            .handle_error_for_backoff(error)
2182            .await;
2183
2184        match backoff_error {
2185            backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2186                assert_eq!(msg, "Invalid JSON");
2187            }
2188            _ => panic!("Expected permanent error for ParseResponse"),
2189        }
2190    }
2191
2192    #[tokio::test]
2193    async fn test_wait_until_retry_after_no_retry_time() {
2194        let http_client =
2195            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2196                .unwrap()
2197                .with_test_backoff_policy();
2198
2199        let start = std::time::Instant::now();
2200        http_client
2201            .wait_until_retry_after()
2202            .await;
2203        let elapsed = start.elapsed();
2204
2205        // Should return immediately if no retry time is set
2206        assert!(elapsed < Duration::from_millis(100));
2207    }
2208
2209    #[tokio::test]
2210    async fn test_wait_until_retry_after_past_time() {
2211        let http_client =
2212            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2213                .unwrap()
2214                .with_test_backoff_policy();
2215
2216        // Set a retry time in the past
2217        let past_time = SystemTime::now() - Duration::from_secs(10);
2218        *http_client.retry_after.write().await = Some(past_time);
2219
2220        let start = std::time::Instant::now();
2221        http_client
2222            .wait_until_retry_after()
2223            .await;
2224        let elapsed = start.elapsed();
2225
2226        // Should return immediately if retry time is in the past
2227        assert!(elapsed < Duration::from_millis(100));
2228    }
2229
2230    #[tokio::test]
2231    async fn test_wait_until_retry_after_future_time() {
2232        let http_client =
2233            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2234                .unwrap()
2235                .with_test_backoff_policy();
2236
2237        // Set a retry time 100ms in the future
2238        let future_time = SystemTime::now() + Duration::from_millis(100);
2239        *http_client.retry_after.write().await = Some(future_time);
2240
2241        let start = std::time::Instant::now();
2242        http_client
2243            .wait_until_retry_after()
2244            .await;
2245        let elapsed = start.elapsed();
2246
2247        // Should wait approximately the specified duration
2248        assert!(elapsed >= Duration::from_millis(80)); // Allow some tolerance
2249        assert!(elapsed <= Duration::from_millis(200)); // Upper bound for test stability
2250    }
2251
2252    #[tokio::test]
2253    async fn test_make_post_request_success() {
2254        let mut server = Server::new_async().await;
2255        let server_resp = r#"{"success": true}"#;
2256
2257        let mock = server
2258            .mock("POST", "/test")
2259            .with_status(200)
2260            .with_body(server_resp)
2261            .create_async()
2262            .await;
2263
2264        let http_client =
2265            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2266                .unwrap()
2267                .with_test_backoff_policy();
2268        let request_body = serde_json::json!({"test": "data"});
2269        let uri = format!("{}/test", server.url());
2270
2271        let result = http_client
2272            .make_post_request(&request_body, &uri)
2273            .await;
2274
2275        mock.assert();
2276        assert!(result.is_ok());
2277
2278        let response = result.unwrap();
2279        assert_eq!(response.status(), 200);
2280        assert_eq!(response.text().await.unwrap(), server_resp);
2281    }
2282
2283    #[tokio::test]
2284    async fn test_make_post_request_retry_on_server_error() {
2285        let mut server = Server::new_async().await;
2286        // First request fails with 503, second succeeds
2287        let error_mock = server
2288            .mock("POST", "/test")
2289            .with_status(503)
2290            .with_body("Service Unavailable")
2291            .expect(1)
2292            .create_async()
2293            .await;
2294
2295        let success_mock = server
2296            .mock("POST", "/test")
2297            .with_status(200)
2298            .with_body(r#"{"success": true}"#)
2299            .expect(1)
2300            .create_async()
2301            .await;
2302
2303        let http_client =
2304            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2305                .unwrap()
2306                .with_test_backoff_policy();
2307        let request_body = serde_json::json!({"test": "data"});
2308        let uri = format!("{}/test", server.url());
2309
2310        let result = http_client
2311            .make_post_request(&request_body, &uri)
2312            .await;
2313
2314        error_mock.assert();
2315        success_mock.assert();
2316        assert!(result.is_ok());
2317    }
2318
2319    #[tokio::test]
2320    async fn test_make_post_request_respect_retry_after_header() {
2321        let mut server = Server::new_async().await;
2322
2323        // First request returns 429 with retry-after, second succeeds
2324        let rate_limit_mock = server
2325            .mock("POST", "/test")
2326            .with_status(429)
2327            .with_header("Retry-After", "1") // 1 second
2328            .expect(1)
2329            .create_async()
2330            .await;
2331
2332        let success_mock = server
2333            .mock("POST", "/test")
2334            .with_status(200)
2335            .with_body(r#"{"success": true}"#)
2336            .expect(1)
2337            .create_async()
2338            .await;
2339
2340        let http_client =
2341            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2342                .unwrap()
2343                .with_test_backoff_policy();
2344        let request_body = serde_json::json!({"test": "data"});
2345        let uri = format!("{}/test", server.url());
2346
2347        let start = std::time::Instant::now();
2348        let result = http_client
2349            .make_post_request(&request_body, &uri)
2350            .await;
2351        let elapsed = start.elapsed();
2352
2353        rate_limit_mock.assert();
2354        success_mock.assert();
2355        assert!(result.is_ok());
2356
2357        // Should have waited at least 1 second due to retry-after header
2358        assert!(elapsed >= Duration::from_millis(900)); // Allow some tolerance
2359        assert!(elapsed <= Duration::from_millis(2000)); // Upper bound for test stability
2360    }
2361
2362    #[tokio::test]
2363    async fn test_make_post_request_permanent_error() {
2364        let mut server = Server::new_async().await;
2365
2366        let mock = server
2367            .mock("POST", "/test")
2368            .with_status(400) // Bad Request - should not be retried
2369            .with_body("Bad Request")
2370            .expect(1)
2371            .create_async()
2372            .await;
2373
2374        let http_client =
2375            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2376                .unwrap()
2377                .with_test_backoff_policy();
2378        let request_body = serde_json::json!({"test": "data"});
2379        let uri = format!("{}/test", server.url());
2380
2381        let result = http_client
2382            .make_post_request(&request_body, &uri)
2383            .await;
2384
2385        mock.assert();
2386        assert!(result.is_ok()); // 400 doesn't trigger retry logic, just returns the response
2387
2388        let response = result.unwrap();
2389        assert_eq!(response.status(), 400);
2390    }
2391
    /// Two concurrent POSTs to different paths are each rate limited once, with different
    /// `Retry-After` values (1s and 2s). Both must eventually succeed. The shared client-wide
    /// `retry_after` state must end up holding a deadline close to the current time.
    #[tokio::test]
    async fn test_concurrent_requests_with_different_retry_after() {
        let mut server = Server::new_async().await;

        // First request gets rate limited with 1 second retry-after
        let rate_limit_mock_1 = server
            .mock("POST", "/test1")
            .with_status(429)
            .with_header("Retry-After", "1")
            .expect(1)
            .create_async()
            .await;

        // Second request gets rate limited with 2 second retry-after
        let rate_limit_mock_2 = server
            .mock("POST", "/test2")
            .with_status(429)
            .with_header("Retry-After", "2")
            .expect(1)
            .create_async()
            .await;

        // Success mocks for the retried requests. They are registered after the single-hit 429
        // mocks for the same paths.
        let success_mock_1 = server
            .mock("POST", "/test1")
            .with_status(200)
            .with_body(r#"{"result": "success1"}"#)
            .expect(1)
            .create_async()
            .await;

        let success_mock_2 = server
            .mock("POST", "/test2")
            .with_status(200)
            .with_body(r#"{"result": "success2"}"#)
            .expect(1)
            .create_async()
            .await;

        let http_client =
            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();
        let request_body = serde_json::json!({"test": "data"});

        let uri1 = format!("{}/test1", server.url());
        let uri2 = format!("{}/test2", server.url());

        // Start both requests concurrently on the same client, so they share `retry_after`.
        let start = std::time::Instant::now();
        let (result1, result2) = tokio::join!(
            http_client.make_post_request(&request_body, &uri1),
            http_client.make_post_request(&request_body, &uri2)
        );
        let elapsed = start.elapsed();

        // Each path must have been hit exactly once by its 429 and once by its retry.
        rate_limit_mock_1.assert();
        rate_limit_mock_2.assert();
        success_mock_1.assert();
        success_mock_2.assert();

        assert!(result1.is_ok());
        assert!(result2.is_ok());

        // Request 2 is expected to wait out its own 2s Retry-After, which alone bounds the total
        // time from below. `tokio::join!` finishes only when both requests are done. The shared
        // retry_after state may also delay request 1, but this bound does not rely on that.
        assert!(elapsed >= Duration::from_millis(1800)); // Allow some tolerance
        assert!(elapsed <= Duration::from_millis(3000)); // Upper bound

        // The rate-limited responses must have left a deadline in the shared state.
        let final_retry_after = http_client.retry_after.read().await;
        assert!(final_retry_after.is_some());

        // NOTE(review): the original intent was that this holds the later/higher (2s) deadline.
        // The check below only verifies that the stored deadline lies within 3s of "now", in
        // either direction. Which request's deadline wins is not pinned down here.
        if let Some(retry_time) = *final_retry_after {
            // The retry_after time might be in the past now since we waited,
            // but it should be reasonable (not too far in past/future)
            let now = SystemTime::now();
            let diff = if retry_time > now {
                retry_time.duration_since(now).unwrap()
            } else {
                now.duration_since(retry_time).unwrap()
            };

            // Should be within a reasonable range (the 2s retry-after plus some buffer)
            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
        }
    }
2481
2482    #[tokio::test]
2483    async fn test_get_snapshots() {
2484        let mut server = Server::new_async().await;
2485
2486        // Mock protocol states response
2487        let protocol_states_resp = r#"
2488        {
2489            "states": [
2490                {
2491                    "component_id": "component1",
2492                    "attributes": {
2493                        "attribute_1": "0x00000000000003e8"
2494                    },
2495                    "balances": {
2496                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2497                    }
2498                }
2499            ],
2500            "pagination": {
2501                "page": 0,
2502                "page_size": 100,
2503                "total": 1
2504            }
2505        }
2506        "#;
2507
2508        // Mock contract state response
2509        let contract_state_resp = r#"
2510        {
2511            "accounts": [
2512                {
2513                    "chain": "ethereum",
2514                    "address": "0x1111111111111111111111111111111111111111",
2515                    "title": "",
2516                    "slots": {},
2517                    "native_balance": "0x01f4",
2518                    "token_balances": {},
2519                    "code": "0x00",
2520                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2521                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2522                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2523                    "creation_tx": null
2524                }
2525            ],
2526            "pagination": {
2527                "page": 0,
2528                "page_size": 100,
2529                "total": 1
2530            }
2531        }
2532        "#;
2533
2534        // Mock component TVL response
2535        let tvl_resp = r#"
2536        {
2537            "tvl": {
2538                "component1": 1000000.0
2539            },
2540            "pagination": {
2541                "page": 0,
2542                "page_size": 100,
2543                "total": 1
2544            }
2545        }
2546        "#;
2547
2548        let protocol_states_mock = server
2549            .mock("POST", "/v1/protocol_state")
2550            .expect(1)
2551            .with_body(protocol_states_resp)
2552            .create_async()
2553            .await;
2554
2555        let contract_state_mock = server
2556            .mock("POST", "/v1/contract_state")
2557            .expect(1)
2558            .with_body(contract_state_resp)
2559            .create_async()
2560            .await;
2561
2562        let tvl_mock = server
2563            .mock("POST", "/v1/component_tvl")
2564            .expect(1)
2565            .with_body(tvl_resp)
2566            .create_async()
2567            .await;
2568
2569        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2570            .expect("create client");
2571
2572        #[allow(deprecated)]
2573        let component = tycho_common::dto::ProtocolComponent {
2574            id: "component1".to_string(),
2575            protocol_system: "test_protocol".to_string(),
2576            protocol_type_name: "test_type".to_string(),
2577            chain: Chain::Ethereum,
2578            tokens: vec![],
2579            contract_ids: vec![
2580                Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2581            ],
2582            static_attributes: HashMap::new(),
2583            change: tycho_common::dto::ChangeType::Creation,
2584            creation_tx: Bytes::from_str(
2585                "0x0000000000000000000000000000000000000000000000000000000000000000",
2586            )
2587            .unwrap(),
2588            created_at: chrono::Utc::now().naive_utc(),
2589        };
2590
2591        let mut components = HashMap::new();
2592        components.insert("component1".to_string(), component);
2593
2594        let contract_ids =
2595            vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2596
2597        let request = SnapshotParameters::new(
2598            Chain::Ethereum,
2599            "test_protocol",
2600            &components,
2601            &contract_ids,
2602            12345,
2603        );
2604
2605        let response = client
2606            .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2607            .await
2608            .expect("get snapshots");
2609
2610        // Verify all mocks were called
2611        protocol_states_mock.assert();
2612        contract_state_mock.assert();
2613        tvl_mock.assert();
2614
2615        // Assert states
2616        assert_eq!(response.states.len(), 1);
2617        assert!(response
2618            .states
2619            .contains_key("component1"));
2620
2621        // Check that the state has the expected TVL
2622        let component_state = response
2623            .states
2624            .get("component1")
2625            .unwrap();
2626        assert_eq!(component_state.component_tvl, Some(1000000.0));
2627
2628        // Assert VM storage
2629        assert_eq!(response.vm_storage.len(), 1);
2630        let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2631        assert!(response
2632            .vm_storage
2633            .contains_key(&contract_addr));
2634    }
2635
2636    #[tokio::test]
2637    async fn test_get_snapshots_empty_components() {
2638        let server = Server::new_async().await;
2639        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2640            .expect("create client");
2641
2642        let components = HashMap::new();
2643        let contract_ids = vec![];
2644
2645        let request = SnapshotParameters::new(
2646            Chain::Ethereum,
2647            "test_protocol",
2648            &components,
2649            &contract_ids,
2650            12345,
2651        );
2652
2653        let response = client
2654            .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2655            .await
2656            .expect("get snapshots");
2657
2658        // Should return empty response without making any requests
2659        assert!(response.states.is_empty());
2660        assert!(response.vm_storage.is_empty());
2661    }
2662
2663    #[tokio::test]
2664    async fn test_get_snapshots_without_tvl() {
2665        let mut server = Server::new_async().await;
2666
2667        let protocol_states_resp = r#"
2668        {
2669            "states": [
2670                {
2671                    "component_id": "component1",
2672                    "attributes": {},
2673                    "balances": {}
2674                }
2675            ],
2676            "pagination": {
2677                "page": 0,
2678                "page_size": 100,
2679                "total": 1
2680            }
2681        }
2682        "#;
2683
2684        let protocol_states_mock = server
2685            .mock("POST", "/v1/protocol_state")
2686            .expect(1)
2687            .with_body(protocol_states_resp)
2688            .create_async()
2689            .await;
2690
2691        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2692            .expect("create client");
2693
2694        // Create test component
2695        #[allow(deprecated)]
2696        let component = tycho_common::dto::ProtocolComponent {
2697            id: "component1".to_string(),
2698            protocol_system: "test_protocol".to_string(),
2699            protocol_type_name: "test_type".to_string(),
2700            chain: Chain::Ethereum,
2701            tokens: vec![],
2702            contract_ids: vec![],
2703            static_attributes: HashMap::new(),
2704            change: tycho_common::dto::ChangeType::Creation,
2705            creation_tx: Bytes::from_str(
2706                "0x0000000000000000000000000000000000000000000000000000000000000000",
2707            )
2708            .unwrap(),
2709            created_at: chrono::Utc::now().naive_utc(),
2710        };
2711
2712        let mut components = HashMap::new();
2713        components.insert("component1".to_string(), component);
2714        let contract_ids = vec![];
2715
2716        let request = SnapshotParameters::new(
2717            Chain::Ethereum,
2718            "test_protocol",
2719            &components,
2720            &contract_ids,
2721            12345,
2722        )
2723        .include_balances(false)
2724        .include_tvl(false);
2725
2726        let response = client
2727            .get_snapshots(&request, None, RPC_CLIENT_CONCURRENCY)
2728            .await
2729            .expect("get snapshots");
2730
2731        // Verify only necessary mocks were called
2732        protocol_states_mock.assert();
2733        // No contract_state_mock.assert() since contract_ids is empty
2734        // No tvl_mock.assert() since include_tvl is false
2735
2736        assert_eq!(response.states.len(), 1);
2737        // Check that TVL is None since we didn't request it
2738        let component_state = response
2739            .states
2740            .get("component1")
2741            .unwrap();
2742        assert_eq!(component_state.component_tvl, None);
2743    }
2744
2745    #[tokio::test]
2746    async fn test_compression_enabled() {
2747        let mut server = Server::new_async().await;
2748        let server_resp = GET_CONTRACT_STATE_RESP;
2749
2750        // Compress the response using zstd
2751        let compressed_body =
2752            zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2753
2754        let mocked_server = server
2755            .mock("POST", "/v1/contract_state")
2756            .expect(1)
2757            .with_header("Content-Encoding", "zstd")
2758            .with_body(compressed_body)
2759            .create_async()
2760            .await;
2761
2762        // Create client with compression enabled
2763        let client = HttpRPCClient::new(
2764            server.url().as_str(),
2765            HttpRPCClientOptions::new().with_compression(true),
2766        )
2767        .expect("create client");
2768
2769        let response = client
2770            .get_contract_state(&Default::default())
2771            .await
2772            .expect("get state");
2773        let accounts = response.accounts;
2774
2775        mocked_server.assert();
2776        assert_eq!(accounts.len(), 1);
2777        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2778    }
2779
2780    #[tokio::test]
2781    async fn test_compression_disabled() {
2782        let mut server = Server::new_async().await;
2783        let server_resp = GET_CONTRACT_STATE_RESP;
2784
2785        // Verify client does NOT send Accept-Encoding: zstd when compression is disabled
2786        // Instead, server should receive request without compression headers
2787        let mocked_server = server
2788            .mock("POST", "/v1/contract_state")
2789            .expect(1)
2790            .match_header("Accept-Encoding", mockito::Matcher::Missing)
2791            .with_status(200)
2792            .with_body(server_resp)
2793            .create_async()
2794            .await;
2795
2796        // Create client with compression disabled
2797        let client = HttpRPCClient::new(
2798            server.url().as_str(),
2799            HttpRPCClientOptions::new().with_compression(false),
2800        )
2801        .expect("create client");
2802
2803        let response = client
2804            .get_contract_state(&Default::default())
2805            .await
2806            .expect("get state");
2807        let accounts = response.accounts;
2808
2809        // Verify the mock was called (client sent request without Accept-Encoding header)
2810        mocked_server.assert();
2811        assert_eq!(accounts.len(), 1);
2812        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2813    }
2814
2815    #[rstest]
2816    #[case::single_page(2, 1000)]
2817    #[case::multiple_pages_within_concurrency(10, 2)]
2818    #[case::exceeds_concurrency_limit(60, 2)]
2819    #[tokio::test]
2820    async fn test_get_all_tokens_pagination_and_concurrency(
2821        #[case] total_tokens: usize,
2822        #[case] page_size: usize,
2823    ) {
2824        use std::sync::atomic::{AtomicUsize, Ordering};
2825
2826        let allowed_concurrency = 10;
2827
2828        let concurrent_requests = Arc::new(AtomicUsize::new(0));
2829        let max_concurrent = Arc::new(AtomicUsize::new(0));
2830
2831        let mut server = Server::new_async().await;
2832
2833        let total_pages = (total_tokens as f64 / page_size as f64).ceil() as i64;
2834
2835        // Mock all required pages
2836        for page in 0..total_pages {
2837            let concurrent = concurrent_requests.clone();
2838            let max_conc = max_concurrent.clone();
2839
2840            let tokens_in_page = {
2841                let start_idx = (page as usize) * page_size;
2842                let end_idx = ((page as usize + 1) * page_size).min(total_tokens);
2843                (start_idx..end_idx)
2844                    .map(|i| {
2845                        format!(
2846                            r#"{{
2847                            "chain": "ethereum",
2848                            "address": "0x{i:040x}",
2849                            "symbol": "TOKEN_{i}",
2850                            "decimals": 18,
2851                            "tax": 0,
2852                            "gas": [30000],
2853                            "quality": 100
2854                        }}"#
2855                        )
2856                    })
2857                    .collect::<Vec<_>>()
2858            };
2859
2860            let tokens_json = tokens_in_page.join(",");
2861            let response = format!(
2862                r#"{{
2863                    "tokens": [{tokens_json}],
2864                    "pagination": {{
2865                        "page": {page},
2866                        "page_size": {page_size},
2867                        "total": {total_tokens}
2868                    }}
2869                }}"#,
2870            );
2871
2872            server
2873                .mock("POST", "/v1/tokens")
2874                .expect(1)
2875                .with_chunked_body(move |w| {
2876                    // Track concurrent requests
2877                    let current = concurrent.fetch_add(1, Ordering::SeqCst);
2878                    max_conc.fetch_max(current + 1, Ordering::SeqCst);
2879
2880                    // Simulate some work to increase likelihood of concurrent requests
2881                    std::thread::sleep(Duration::from_millis(10));
2882
2883                    concurrent.fetch_sub(1, Ordering::SeqCst);
2884
2885                    w.write_all(response.as_bytes())
2886                })
2887                .create_async()
2888                .await;
2889        }
2890
2891        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2892            .expect("create client");
2893
2894        let tokens = client
2895            .get_all_tokens(Chain::Ethereum, None, None, Some(page_size), allowed_concurrency)
2896            .await
2897            .expect("get all tokens");
2898
2899        // Verify concurrency was respected
2900        let max = max_concurrent.load(Ordering::SeqCst);
2901        let expected_max_concurrency = (total_pages as usize)
2902            .saturating_sub(1)
2903            .min(allowed_concurrency);
2904        assert!(
2905            max <= allowed_concurrency,
2906            "Expected max concurrent requests <= {allowed_concurrency}, got {max}"
2907        );
2908
2909        // For cases with multiple pages, verify we actually used concurrency
2910        if total_pages > 1 && expected_max_concurrency > 1 {
2911            assert!(
2912                max > 0,
2913                "Expected some concurrent requests for multi-page response, got {max}"
2914            );
2915        }
2916
2917        // Verify we got all expected tokens
2918        assert_eq!(
2919            tokens.len(),
2920            total_tokens,
2921            "Expected {total_tokens} tokens, got {}",
2922            tokens.len()
2923        );
2924
2925        // Verify tokens are in the expected order
2926        for (i, token) in tokens.iter().enumerate() {
2927            assert_eq!(token.symbol, format!("TOKEN_{i}"), "Token at index {i} has wrong symbol");
2928        }
2929    }
2930}