tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
//! The objective of this module is to provide fast, simplified access to the Remote Procedure
//! Call (RPC) endpoints of Tycho. These endpoints chiefly serve data queries, in particular
//! queries for snapshots of data.
6use std::{collections::HashMap, sync::Arc};
7
8use async_trait::async_trait;
9use futures03::future::try_join_all;
10#[cfg(test)]
11use mockall::automock;
12use reqwest::{header, Client, ClientBuilder, Url};
13use thiserror::Error;
14use tokio::sync::Semaphore;
15use tracing::{debug, error, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, PaginationParams,
19        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
20        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
21        ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
22        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
23        TracedEntryPointRequestResponse, VersionParam,
24    },
25    Bytes,
26};
27
28use crate::TYCHO_SERVER_VERSION;
29
/// Errors produced by the Tycho RPC client ([`RPCClient`] methods and [`HttpRPCClient::new`]).
#[derive(Error, Debug)]
pub enum RPCError {
    /// The passed tycho url failed to parse.
    ///
    /// Holds the offending URL string and the parser's error message.
    #[error("Failed to parse URL: {0}. Error: {1}")]
    UrlParsing(String, String),

    /// The request data is not correctly formed, e.g. a header value (user agent, authorization
    /// key) or a pagination parameter that cannot be encoded.
    #[error("Failed to format request: {0}")]
    FormatRequest(String),

    /// Errors forwarded from the HTTP client, either while building it or while sending a
    /// request.
    #[error("Unexpected HTTP client error: {0}")]
    HttpClient(String),

    /// The response from the server could not be parsed correctly: reading the body failed or
    /// the body could not be decoded as the expected JSON type.
    #[error("Failed to parse response: {0}")]
    ParseResponse(String),

    /// Other fatal errors, e.g. a closed concurrency semaphore in the paginated helpers.
    #[error("Fatal error: {0}")]
    Fatal(String),
}
52
53#[cfg_attr(test, automock)]
54#[async_trait]
55pub trait RPCClient: Send + Sync {
56    /// Retrieves a snapshot of contract state.
57    async fn get_contract_state(
58        &self,
59        request: &StateRequestBody,
60    ) -> Result<StateRequestResponse, RPCError>;
61
62    async fn get_contract_state_paginated(
63        &self,
64        chain: Chain,
65        ids: &[Bytes],
66        protocol_system: &str,
67        version: &VersionParam,
68        chunk_size: usize,
69        concurrency: usize,
70    ) -> Result<StateRequestResponse, RPCError> {
71        let semaphore = Arc::new(Semaphore::new(concurrency));
72
73        // Sort the ids to maximize server-side cache hits
74        let mut sorted_ids = ids.to_vec();
75        sorted_ids.sort();
76
77        let chunked_bodies = sorted_ids
78            .chunks(chunk_size)
79            .map(|chunk| StateRequestBody {
80                contract_ids: Some(chunk.to_vec()),
81                protocol_system: protocol_system.to_string(),
82                chain,
83                version: version.clone(),
84                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
85            })
86            .collect::<Vec<_>>();
87
88        let mut tasks = Vec::new();
89        for body in chunked_bodies.iter() {
90            let sem = semaphore.clone();
91            tasks.push(async move {
92                let _permit = sem
93                    .acquire()
94                    .await
95                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
96                self.get_contract_state(body).await
97            });
98        }
99
100        // Execute all tasks concurrently with the defined concurrency limit.
101        let responses = try_join_all(tasks).await?;
102
103        // Aggregate the responses into a single result.
104        let accounts = responses
105            .iter()
106            .flat_map(|r| r.accounts.clone())
107            .collect();
108        let total: i64 = responses
109            .iter()
110            .map(|r| r.pagination.total)
111            .sum();
112
113        Ok(StateRequestResponse {
114            accounts,
115            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
116        })
117    }
118
119    async fn get_protocol_components(
120        &self,
121        request: &ProtocolComponentsRequestBody,
122    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
123
124    async fn get_protocol_components_paginated(
125        &self,
126        request: &ProtocolComponentsRequestBody,
127        chunk_size: usize,
128        concurrency: usize,
129    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
130        let semaphore = Arc::new(Semaphore::new(concurrency));
131
132        // If a set of component IDs is specified, the maximum return size is already known,
133        // allowing us to pre-compute the number of requests to be made.
134        match request.component_ids {
135            Some(ref ids) => {
136                // We can divide the component_ids into chunks of size chunk_size
137                let chunked_bodies = ids
138                    .chunks(chunk_size)
139                    .enumerate()
140                    .map(|(index, _)| ProtocolComponentsRequestBody {
141                        protocol_system: request.protocol_system.clone(),
142                        component_ids: request.component_ids.clone(),
143                        tvl_gt: request.tvl_gt,
144                        chain: request.chain,
145                        pagination: PaginationParams {
146                            page: index as i64,
147                            page_size: chunk_size as i64,
148                        },
149                    })
150                    .collect::<Vec<_>>();
151
152                let mut tasks = Vec::new();
153                for body in chunked_bodies.iter() {
154                    let sem = semaphore.clone();
155                    tasks.push(async move {
156                        let _permit = sem
157                            .acquire()
158                            .await
159                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
160                        self.get_protocol_components(body).await
161                    });
162                }
163
164                try_join_all(tasks)
165                    .await
166                    .map(|responses| ProtocolComponentRequestResponse {
167                        protocol_components: responses
168                            .into_iter()
169                            .flat_map(|r| r.protocol_components.into_iter())
170                            .collect(),
171                        pagination: PaginationResponse {
172                            page: 0,
173                            page_size: chunk_size as i64,
174                            total: ids.len() as i64,
175                        },
176                    })
177            }
178            _ => {
179                // If no component ids are specified, we need to make requests based on the total
180                // number of results from the first response.
181
182                let initial_request = ProtocolComponentsRequestBody {
183                    protocol_system: request.protocol_system.clone(),
184                    component_ids: request.component_ids.clone(),
185                    tvl_gt: request.tvl_gt,
186                    chain: request.chain,
187                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
188                };
189                let first_response = self
190                    .get_protocol_components(&initial_request)
191                    .await
192                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
193
194                let total_items = first_response.pagination.total;
195                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
196
197                // Initialize the final response accumulator
198                let mut accumulated_response = ProtocolComponentRequestResponse {
199                    protocol_components: first_response.protocol_components,
200                    pagination: PaginationResponse {
201                        page: 0,
202                        page_size: chunk_size as i64,
203                        total: total_items,
204                    },
205                };
206
207                let mut page = 1;
208                while page < total_pages {
209                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
210
211                    // Create request bodies for parallel requests, respecting the concurrency limit
212                    let chunked_bodies = (0..requests_in_this_iteration)
213                        .map(|iter| ProtocolComponentsRequestBody {
214                            protocol_system: request.protocol_system.clone(),
215                            component_ids: request.component_ids.clone(),
216                            tvl_gt: request.tvl_gt,
217                            chain: request.chain,
218                            pagination: PaginationParams {
219                                page: page + iter,
220                                page_size: chunk_size as i64,
221                            },
222                        })
223                        .collect::<Vec<_>>();
224
225                    let tasks: Vec<_> = chunked_bodies
226                        .iter()
227                        .map(|body| {
228                            let sem = semaphore.clone();
229                            async move {
230                                let _permit = sem.acquire().await.map_err(|_| {
231                                    RPCError::Fatal("Semaphore dropped".to_string())
232                                })?;
233                                self.get_protocol_components(body).await
234                            }
235                        })
236                        .collect();
237
238                    let responses = try_join_all(tasks)
239                        .await
240                        .map(|responses| {
241                            let total = responses[0].pagination.total;
242                            ProtocolComponentRequestResponse {
243                                protocol_components: responses
244                                    .into_iter()
245                                    .flat_map(|r| r.protocol_components.into_iter())
246                                    .collect(),
247                                pagination: PaginationResponse {
248                                    page,
249                                    page_size: chunk_size as i64,
250                                    total,
251                                },
252                            }
253                        });
254
255                    // Update the accumulated response or set the initial response
256                    match responses {
257                        Ok(mut resp) => {
258                            accumulated_response
259                                .protocol_components
260                                .append(&mut resp.protocol_components);
261                        }
262                        Err(e) => return Err(e),
263                    }
264
265                    page += concurrency as i64;
266                }
267                Ok(accumulated_response)
268            }
269        }
270    }
271
272    async fn get_protocol_states(
273        &self,
274        request: &ProtocolStateRequestBody,
275    ) -> Result<ProtocolStateRequestResponse, RPCError>;
276
277    #[allow(clippy::too_many_arguments)]
278    async fn get_protocol_states_paginated<T>(
279        &self,
280        chain: Chain,
281        ids: &[T],
282        protocol_system: &str,
283        include_balances: bool,
284        version: &VersionParam,
285        chunk_size: usize,
286        concurrency: usize,
287    ) -> Result<ProtocolStateRequestResponse, RPCError>
288    where
289        T: AsRef<str> + Sync + 'static,
290    {
291        let semaphore = Arc::new(Semaphore::new(concurrency));
292        let chunked_bodies = ids
293            .chunks(chunk_size)
294            .map(|c| ProtocolStateRequestBody {
295                protocol_ids: Some(
296                    c.iter()
297                        .map(|id| id.as_ref().to_string())
298                        .collect(),
299                ),
300                protocol_system: protocol_system.to_string(),
301                chain,
302                include_balances,
303                version: version.clone(),
304                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
305            })
306            .collect::<Vec<_>>();
307
308        let mut tasks = Vec::new();
309        for body in chunked_bodies.iter() {
310            let sem = semaphore.clone();
311            tasks.push(async move {
312                let _permit = sem
313                    .acquire()
314                    .await
315                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
316                self.get_protocol_states(body).await
317            });
318        }
319
320        try_join_all(tasks)
321            .await
322            .map(|responses| {
323                let states = responses
324                    .clone()
325                    .into_iter()
326                    .flat_map(|r| r.states)
327                    .collect();
328                let total = responses
329                    .iter()
330                    .map(|r| r.pagination.total)
331                    .sum();
332                ProtocolStateRequestResponse {
333                    states,
334                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
335                }
336            })
337    }
338
339    /// This function returns only one chunk of tokens. To get all tokens please call
340    /// get_all_tokens.
341    async fn get_tokens(
342        &self,
343        request: &TokensRequestBody,
344    ) -> Result<TokensRequestResponse, RPCError>;
345
346    async fn get_all_tokens(
347        &self,
348        chain: Chain,
349        min_quality: Option<i32>,
350        traded_n_days_ago: Option<u64>,
351        chunk_size: usize,
352    ) -> Result<Vec<ResponseToken>, RPCError> {
353        let mut request_page = 0;
354        let mut all_tokens = Vec::new();
355        loop {
356            let mut response = self
357                .get_tokens(&TokensRequestBody {
358                    token_addresses: None,
359                    min_quality,
360                    traded_n_days_ago,
361                    pagination: PaginationParams {
362                        page: request_page,
363                        page_size: chunk_size.try_into().map_err(|_| {
364                            RPCError::FormatRequest(
365                                "Failed to convert chunk_size into i64".to_string(),
366                            )
367                        })?,
368                    },
369                    chain,
370                })
371                .await?;
372
373            let num_tokens = response.tokens.len();
374            all_tokens.append(&mut response.tokens);
375            request_page += 1;
376
377            if num_tokens < chunk_size {
378                break;
379            }
380        }
381        Ok(all_tokens)
382    }
383
384    async fn get_protocol_systems(
385        &self,
386        request: &ProtocolSystemsRequestBody,
387    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
388
389    async fn get_component_tvl(
390        &self,
391        request: &ComponentTvlRequestBody,
392    ) -> Result<ComponentTvlRequestResponse, RPCError>;
393
394    async fn get_component_tvl_paginated(
395        &self,
396        request: &ComponentTvlRequestBody,
397        chunk_size: usize,
398        concurrency: usize,
399    ) -> Result<ComponentTvlRequestResponse, RPCError> {
400        let semaphore = Arc::new(Semaphore::new(concurrency));
401
402        match request.component_ids {
403            Some(ref ids) => {
404                let chunked_requests = ids
405                    .chunks(chunk_size)
406                    .enumerate()
407                    .map(|(index, _)| ComponentTvlRequestBody {
408                        chain: request.chain,
409                        protocol_system: request.protocol_system.clone(),
410                        component_ids: Some(ids.clone()),
411                        pagination: PaginationParams {
412                            page: index as i64,
413                            page_size: chunk_size as i64,
414                        },
415                    })
416                    .collect::<Vec<_>>();
417
418                let tasks: Vec<_> = chunked_requests
419                    .into_iter()
420                    .map(|req| {
421                        let sem = semaphore.clone();
422                        async move {
423                            let _permit = sem
424                                .acquire()
425                                .await
426                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
427                            self.get_component_tvl(&req).await
428                        }
429                    })
430                    .collect();
431
432                let responses = try_join_all(tasks).await?;
433
434                let mut merged_tvl = HashMap::new();
435                for resp in responses {
436                    for (key, value) in resp.tvl {
437                        *merged_tvl.entry(key).or_insert(0.0) = value;
438                    }
439                }
440
441                Ok(ComponentTvlRequestResponse {
442                    tvl: merged_tvl,
443                    pagination: PaginationResponse {
444                        page: 0,
445                        page_size: chunk_size as i64,
446                        total: ids.len() as i64,
447                    },
448                })
449            }
450            _ => {
451                let first_request = ComponentTvlRequestBody {
452                    chain: request.chain,
453                    protocol_system: request.protocol_system.clone(),
454                    component_ids: request.component_ids.clone(),
455                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
456                };
457
458                let first_response = self
459                    .get_component_tvl(&first_request)
460                    .await?;
461                let total_items = first_response.pagination.total;
462                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
463
464                let mut merged_tvl = first_response.tvl;
465
466                let mut page = 1;
467                while page < total_pages {
468                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
469
470                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
471                        .map(|i| ComponentTvlRequestBody {
472                            chain: request.chain,
473                            protocol_system: request.protocol_system.clone(),
474                            component_ids: request.component_ids.clone(),
475                            pagination: PaginationParams {
476                                page: page + i,
477                                page_size: chunk_size as i64,
478                            },
479                        })
480                        .collect();
481
482                    let tasks: Vec<_> = chunked_requests
483                        .into_iter()
484                        .map(|req| {
485                            let sem = semaphore.clone();
486                            async move {
487                                let _permit = sem.acquire().await.map_err(|_| {
488                                    RPCError::Fatal("Semaphore dropped".to_string())
489                                })?;
490                                self.get_component_tvl(&req).await
491                            }
492                        })
493                        .collect();
494
495                    let responses = try_join_all(tasks).await?;
496
497                    // merge hashmap
498                    for resp in responses {
499                        for (key, value) in resp.tvl {
500                            *merged_tvl.entry(key).or_insert(0.0) += value;
501                        }
502                    }
503
504                    page += concurrency as i64;
505                }
506
507                Ok(ComponentTvlRequestResponse {
508                    tvl: merged_tvl,
509                    pagination: PaginationResponse {
510                        page: 0,
511                        page_size: chunk_size as i64,
512                        total: total_items,
513                    },
514                })
515            }
516        }
517    }
518
519    async fn get_traced_entry_points(
520        &self,
521        request: &TracedEntryPointRequestBody,
522    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
523
524    async fn get_traced_entry_points_paginated(
525        &self,
526        chain: Chain,
527        protocol_system: &str,
528        component_ids: &[String],
529        chunk_size: usize,
530        concurrency: usize,
531    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
532        let semaphore = Arc::new(Semaphore::new(concurrency));
533        let chunked_bodies = component_ids
534            .chunks(chunk_size)
535            .map(|c| TracedEntryPointRequestBody {
536                chain,
537                protocol_system: protocol_system.to_string(),
538                component_ids: Some(c.to_vec()),
539                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
540            })
541            .collect::<Vec<_>>();
542
543        let mut tasks = Vec::new();
544        for body in chunked_bodies.iter() {
545            let sem = semaphore.clone();
546            tasks.push(async move {
547                let _permit = sem
548                    .acquire()
549                    .await
550                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
551                self.get_traced_entry_points(body).await
552            });
553        }
554
555        try_join_all(tasks)
556            .await
557            .map(|responses| {
558                let traced_entry_points = responses
559                    .clone()
560                    .into_iter()
561                    .flat_map(|r| r.traced_entry_points)
562                    .collect();
563                let total = responses
564                    .iter()
565                    .map(|r| r.pagination.total)
566                    .sum();
567                TracedEntryPointRequestResponse {
568                    traced_entry_points,
569                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
570                }
571            })
572    }
573}
574
/// [`RPCClient`] implementation that sends JSON `POST` requests to a Tycho server over HTTP
/// (the client is built with HTTP/2 prior knowledge in [`HttpRPCClient::new`]).
#[derive(Debug, Clone)]
pub struct HttpRPCClient {
    /// HTTP client carrying the default headers (content type, user agent, and optional
    /// authorization) configured in [`HttpRPCClient::new`].
    http_client: Client,
    /// Base URL of the Tycho server; the server version and endpoint name are appended to it.
    url: Url,
}
580
581impl HttpRPCClient {
582    pub fn new(base_uri: &str, auth_key: Option<&str>) -> Result<Self, RPCError> {
583        let uri = base_uri
584            .parse::<Url>()
585            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
586
587        // Add default headers
588        let mut headers = header::HeaderMap::new();
589        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
590        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
591        headers.insert(
592            header::USER_AGENT,
593            header::HeaderValue::from_str(&user_agent)
594                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
595        );
596
597        // Add Authorization if one is given
598        if let Some(key) = auth_key {
599            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
600                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
601            })?;
602            auth_value.set_sensitive(true);
603            headers.insert(header::AUTHORIZATION, auth_value);
604        }
605
606        let client = ClientBuilder::new()
607            .default_headers(headers)
608            .http2_prior_knowledge()
609            .build()
610            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
611        Ok(Self { http_client: client, url: uri })
612    }
613}
614
615#[async_trait]
616impl RPCClient for HttpRPCClient {
617    #[instrument(skip(self, request))]
618    async fn get_contract_state(
619        &self,
620        request: &StateRequestBody,
621    ) -> Result<StateRequestResponse, RPCError> {
622        // Check if contract ids are specified
623        if request
624            .contract_ids
625            .as_ref()
626            .is_none_or(|ids| ids.is_empty())
627        {
628            warn!("No contract ids specified in request.");
629        }
630
631        let uri = format!(
632            "{}/{}/contract_state",
633            self.url
634                .to_string()
635                .trim_end_matches('/'),
636            TYCHO_SERVER_VERSION
637        );
638        debug!(%uri, "Sending contract_state request to Tycho server");
639        trace!(?request, "Sending request to Tycho server");
640
641        let response = self
642            .http_client
643            .post(&uri)
644            .json(request)
645            .send()
646            .await
647            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
648        trace!(?response, "Received response from Tycho server");
649
650        let body = response
651            .text()
652            .await
653            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
654        if body.is_empty() {
655            // Pure native protocols will return empty contract states
656            return Ok(StateRequestResponse {
657                accounts: vec![],
658                pagination: PaginationResponse {
659                    page: request.pagination.page,
660                    page_size: request.pagination.page,
661                    total: 0,
662                },
663            });
664        }
665
666        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
667            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
668        trace!(?accounts, "Received contract_state response from Tycho server");
669
670        Ok(accounts)
671    }
672
673    async fn get_protocol_components(
674        &self,
675        request: &ProtocolComponentsRequestBody,
676    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
677        let uri = format!(
678            "{}/{}/protocol_components",
679            self.url
680                .to_string()
681                .trim_end_matches('/'),
682            TYCHO_SERVER_VERSION,
683        );
684        debug!(%uri, "Sending protocol_components request to Tycho server");
685        trace!(?request, "Sending request to Tycho server");
686
687        let response = self
688            .http_client
689            .post(uri)
690            .header(header::CONTENT_TYPE, "application/json")
691            .json(request)
692            .send()
693            .await
694            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
695
696        trace!(?response, "Received response from Tycho server");
697
698        let body = response
699            .text()
700            .await
701            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
702        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
703            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
704        trace!(?components, "Received protocol_components response from Tycho server");
705
706        Ok(components)
707    }
708
709    async fn get_protocol_states(
710        &self,
711        request: &ProtocolStateRequestBody,
712    ) -> Result<ProtocolStateRequestResponse, RPCError> {
713        // Check if protocol ids are specified
714        if request
715            .protocol_ids
716            .as_ref()
717            .is_none_or(|ids| ids.is_empty())
718        {
719            warn!("No protocol ids specified in request.");
720        }
721
722        let uri = format!(
723            "{}/{}/protocol_state",
724            self.url
725                .to_string()
726                .trim_end_matches('/'),
727            TYCHO_SERVER_VERSION
728        );
729        debug!(%uri, "Sending protocol_states request to Tycho server");
730        trace!(?request, "Sending request to Tycho server");
731
732        let response = self
733            .http_client
734            .post(&uri)
735            .json(request)
736            .send()
737            .await
738            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
739        trace!(?response, "Received response from Tycho server");
740
741        let body = response
742            .text()
743            .await
744            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
745
746        if body.is_empty() {
747            // Pure VM protocols will return empty states
748            return Ok(ProtocolStateRequestResponse {
749                states: vec![],
750                pagination: PaginationResponse {
751                    page: request.pagination.page,
752                    page_size: request.pagination.page_size,
753                    total: 0,
754                },
755            });
756        }
757
758        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
759            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
760        trace!(?states, "Received protocol_states response from Tycho server");
761
762        Ok(states)
763    }
764
765    async fn get_tokens(
766        &self,
767        request: &TokensRequestBody,
768    ) -> Result<TokensRequestResponse, RPCError> {
769        let uri = format!(
770            "{}/{}/tokens",
771            self.url
772                .to_string()
773                .trim_end_matches('/'),
774            TYCHO_SERVER_VERSION
775        );
776        debug!(%uri, "Sending tokens request to Tycho server");
777
778        let response = self
779            .http_client
780            .post(&uri)
781            .json(request)
782            .send()
783            .await
784            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
785
786        let body = response
787            .text()
788            .await
789            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
790        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
791            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
792
793        Ok(tokens)
794    }
795
796    async fn get_protocol_systems(
797        &self,
798        request: &ProtocolSystemsRequestBody,
799    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
800        let uri = format!(
801            "{}/{}/protocol_systems",
802            self.url
803                .to_string()
804                .trim_end_matches('/'),
805            TYCHO_SERVER_VERSION
806        );
807        debug!(%uri, "Sending protocol_systems request to Tycho server");
808        trace!(?request, "Sending request to Tycho server");
809        let response = self
810            .http_client
811            .post(&uri)
812            .json(request)
813            .send()
814            .await
815            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
816        trace!(?response, "Received response from Tycho server");
817        let body = response
818            .text()
819            .await
820            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
821        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
822            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
823        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
824        Ok(protocol_systems)
825    }
826
827    async fn get_component_tvl(
828        &self,
829        request: &ComponentTvlRequestBody,
830    ) -> Result<ComponentTvlRequestResponse, RPCError> {
831        let uri = format!(
832            "{}/{}/component_tvl",
833            self.url
834                .to_string()
835                .trim_end_matches('/'),
836            TYCHO_SERVER_VERSION
837        );
838        debug!(%uri, "Sending get_component_tvl request to Tycho server");
839        trace!(?request, "Sending request to Tycho server");
840        let response = self
841            .http_client
842            .post(&uri)
843            .json(request)
844            .send()
845            .await
846            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
847        trace!(?response, "Received response from Tycho server");
848        let body = response
849            .text()
850            .await
851            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
852        let component_tvl =
853            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
854                error!("Failed to parse component_tvl response: {:?}", &body);
855                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
856            })?;
857        trace!(?component_tvl, "Received component_tvl response from Tycho server");
858        Ok(component_tvl)
859    }
860
861    async fn get_traced_entry_points(
862        &self,
863        request: &TracedEntryPointRequestBody,
864    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
865        let uri = format!(
866            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
867            self.url
868                .to_string()
869                .trim_end_matches('/')
870        );
871        debug!(%uri, "Sending traced_entry_points request to Tycho server");
872        trace!(?request, "Sending request to Tycho server");
873
874        let response = self
875            .http_client
876            .post(&uri)
877            .json(request)
878            .send()
879            .await
880            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
881        trace!(?response, "Received response from Tycho server");
882
883        let body = response
884            .text()
885            .await
886            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
887        let entrypoints =
888            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
889                error!("Failed to parse traced_entry_points response: {:?}", &body);
890                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
891            })?;
892        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
893        Ok(entrypoints)
894    }
895}
896
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        str::FromStr,
    };

    use mockito::Server;
    use rstest::rstest;
    // TODO: remove once deprecated ProtocolId struct is removed
    #[allow(deprecated)]
    use tycho_common::dto::ProtocolId;
    use tycho_common::dto::TracingParams;

    use super::*;

    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
    // purposes. It only reproduces the chunking of `ids` into `ProtocolStateRequestBody`s (one
    // body per `chunk_size` ids) without contacting a server. `_concurrency` is accepted but
    // ignored.
    impl MockRPCClient {
        // The parameter list presumably mirrors `get_protocol_states_paginated`, hence the
        // silenced argument-count lint.
        #[allow(clippy::too_many_arguments)]
        async fn test_get_protocol_states_paginated<T>(
            &self,
            chain: Chain,
            ids: &[T],
            protocol_system: &str,
            include_balances: bool,
            version: &VersionParam,
            chunk_size: usize,
            _concurrency: usize,
        ) -> Vec<ProtocolStateRequestBody>
        where
            T: AsRef<str> + Clone + Send + Sync + 'static,
        {
            ids.chunks(chunk_size)
                .map(|chunk| ProtocolStateRequestBody {
                    protocol_ids: Some(
                        chunk
                            .iter()
                            .map(|id| id.as_ref().to_string())
                            .collect(),
                    ),
                    protocol_system: protocol_system.to_string(),
                    chain,
                    include_balances,
                    version: version.clone(),
                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
                })
                .collect()
        }
    }

    // Both the deprecated `ProtocolId` input and plain `String` ids must be accepted through the
    // same `T: AsRef<str>` bound.
    // TODO: remove once deprecated ProtocolId struct is removed
    #[allow(deprecated)]
    #[rstest]
    #[case::protocol_id_input(vec![
        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
    ])]
    #[case::string_input(vec![
        "id1".to_string(),
        "id2".to_string()
    ])]
    #[tokio::test]
    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
    where
        T: AsRef<str> + Clone + Send + Sync + 'static,
    {
        let mock_client = MockRPCClient::new();

        let request_bodies = mock_client
            .test_get_protocol_states_paginated(
                Chain::Ethereum,
                &ids,
                "test_system",
                true,
                &VersionParam::default(),
                2,
                2,
            )
            .await;

        // Two ids with a chunk size of 2 must produce exactly one request carrying both ids.
        assert_eq!(request_bodies.len(), 1);
        assert_eq!(
            request_bodies[0]
                .protocol_ids
                .as_ref()
                .unwrap()
                .len(),
            2
        );
    }

    #[tokio::test]
    async fn test_get_contract_state() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // Sanity-check the fixture first, so a failure further down points at the client rather
        // than at malformed test data.
        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/contract_state")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;

        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_contract_state(&Default::default())
            .await
            .expect("get state");
        let accounts = response.accounts;

        // Exactly one POST must have reached the mocked endpoint.
        mocked_server.assert();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts[0].slots, HashMap::new());
        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
        assert_eq!(accounts[0].code, [0].to_vec());
        assert_eq!(
            accounts[0].code_hash,
            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_get_protocol_components() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "protocol_components": [
                {
                    "id": "State1",
                    "protocol_system": "ambient",
                    "protocol_type_name": "Pool",
                    "chain": "ethereum",
                    "tokens": [
                        "0x0000000000000000000000000000000000000000",
                        "0x0000000000000000000000000000000000000001"
                    ],
                    "contract_ids": [
                        "0x0000000000000000000000000000000000000000"
                    ],
                    "static_attributes": {
                        "attribute_1": "0x00000000000003e8"
                    },
                    "change": "Creation",
                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "created_at": "2022-01-01T00:00:00"
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_components")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;

        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_components(&Default::default())
            .await
            .expect("get protocol components");
        let components = response.protocol_components;

        mocked_server.assert();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].id, "State1");
        assert_eq!(components[0].protocol_system, "ambient");
        assert_eq!(components[0].protocol_type_name, "Pool");
        assert_eq!(components[0].tokens.len(), 2);
        // "0x00000000000003e8" is 1000 encoded as 8 big-endian bytes.
        let expected_attributes =
            HashMap::from([("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]);
        assert_eq!(components[0].static_attributes, expected_attributes);
    }

    #[tokio::test]
    async fn test_get_protocol_states() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "states": [
                {
                    "component_id": "State1",
                    "attributes": {
                        "attribute_1": "0x00000000000003e8"
                    },
                    "balances": {
                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
                    }
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_state")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_states(&Default::default())
            .await
            .expect("get protocol states");
        let states = response.states;

        mocked_server.assert();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].component_id, "State1");
        let expected_attributes =
            HashMap::from([("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]);
        assert_eq!(states[0].attributes, expected_attributes);
        let expected_balances = HashMap::from([(
            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
                .expect("Unsupported address format"),
            Bytes::from_str("0x01f4").unwrap(),
        )]);
        assert_eq!(states[0].balances, expected_balances);
    }

    #[tokio::test]
    async fn test_get_tokens() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "tokens": [
              {
                "chain": "ethereum",
                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
                "symbol": "WETH",
                "decimals": 18,
                "tax": 0,
                "gas": [
                  29962
                ],
                "quality": 100
              },
              {
                "chain": "ethereum",
                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
                "symbol": "USDC",
                "decimals": 6,
                "tax": 0,
                "gas": [
                  40652
                ],
                "quality": 100
              }
            ],
            "pagination": {
              "page": 0,
              "page_size": 20,
              "total": 10
            }
          }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/tokens")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_tokens(&Default::default())
            .await
            .expect("get tokens");

        let expected = vec![
            ResponseToken {
                chain: Chain::Ethereum,
                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
                symbol: "WETH".to_string(),
                decimals: 18,
                tax: 0,
                gas: vec![Some(29962)],
                quality: 100,
            },
            ResponseToken {
                chain: Chain::Ethereum,
                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
                symbol: "USDC".to_string(),
                decimals: 6,
                tax: 0,
                gas: vec![Some(40652)],
                quality: 100,
            },
        ];

        mocked_server.assert();
        assert_eq!(response.tokens, expected);
        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
    }

    #[tokio::test]
    async fn test_get_protocol_systems() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "protocol_systems": [
                "system1",
                "system2"
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_systems")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_systems(&Default::default())
            .await
            .expect("get protocol systems");
        let protocol_systems = response.protocol_systems;

        mocked_server.assert();
        assert_eq!(protocol_systems, vec!["system1", "system2"]);
    }

    #[tokio::test]
    async fn test_get_component_tvl() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "tvl": {
                "component1": 100.0
            },
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/component_tvl")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_component_tvl(&Default::default())
            .await
            .expect("get component tvl");
        let component_tvl = response.tvl;

        mocked_server.assert();
        assert_eq!(component_tvl.get("component1"), Some(&100.0));
    }

    #[tokio::test]
    async fn test_get_traced_entry_points() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "traced_entry_points": {
                "component_1": [
                    [
                        {
                            "entry_point": {
                                "external_id": "entrypoint_a",
                                "target": "0x0000000000000000000000000000000000000001",
                                "signature": "sig()"
                            },
                            "params": {
                                "method": "rpctracer",
                                "caller": "0x000000000000000000000000000000000000000a",
                                "calldata": "0x000000000000000000000000000000000000000b"
                            }
                        },
                        {
                            "retriggers": [
                                [
                                    "0x00000000000000000000000000000000000000aa",
                                    "0x0000000000000000000000000000000000000aaa"
                                ]
                            ],
                            "accessed_slots": {
                                "0x0000000000000000000000000000000000aaaa": [
                                    "0x0000000000000000000000000000000000aaaa"
                                ]
                            }
                        }
                    ]
                ]
            },
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 1
            }
        }
        "#;
        // Sanity-check the fixture before exercising the client.
        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/traced_entry_points")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_traced_entry_points(&Default::default())
            .await
            .expect("get traced entry points");
        let entrypoints = response.traced_entry_points;

        mocked_server.assert();
        assert_eq!(entrypoints.len(), 1);
        let comp1_entrypoints = entrypoints
            .get("component_1")
            .expect("component_1 entrypoints should exist");
        assert_eq!(comp1_entrypoints.len(), 1);

        // Each element of the JSON array is decoded as an (entry point, trace result) pair.
        let (entrypoint, trace_result) = &comp1_entrypoints[0];
        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
        assert_eq!(
            entrypoint.entry_point.target,
            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
        );
        assert_eq!(entrypoint.entry_point.signature, "sig()");
        // `RPCTracer` is matched irrefutably here; this pattern stops compiling if `TracingParams`
        // gains more variants.
        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
        assert_eq!(
            rpc_params.caller,
            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
        );
        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));

        assert_eq!(
            trace_result.retriggers,
            HashSet::from([(
                Bytes::from("0x00000000000000000000000000000000000000aa"),
                Bytes::from("0x0000000000000000000000000000000000000aaa")
            )])
        );
        assert_eq!(trace_result.accessed_slots.len(), 1);
        assert_eq!(
            trace_result.accessed_slots,
            HashMap::from([(
                Bytes::from("0x0000000000000000000000000000000000aaaa"),
                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
            )])
        );
    }
}