tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
//! This module provides simple, efficient access to Tycho's Remote Procedure Call (RPC)
//! endpoints. These endpoints serve data queries, most notably snapshots of contract and
//! protocol state.
6use std::{collections::HashMap, sync::Arc};
7
8use async_trait::async_trait;
9use futures03::future::try_join_all;
10#[cfg(test)]
11use mockall::automock;
12use reqwest::{header, Client, ClientBuilder, Url};
13use thiserror::Error;
14use tokio::sync::Semaphore;
15use tracing::{debug, error, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, PaginationParams,
19        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
20        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
21        ProtocolSystemsRequestResponse, ResponseToken, StateRequestBody, StateRequestResponse,
22        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
23        TracedEntryPointRequestResponse, VersionParam,
24    },
25    Bytes,
26};
27
28use crate::TYCHO_SERVER_VERSION;
29
30#[derive(Error, Debug)]
31pub enum RPCError {
32    /// The passed tycho url failed to parse.
33    #[error("Failed to parse URL: {0}. Error: {1}")]
34    UrlParsing(String, String),
35
36    /// The request data is not correctly formed.
37    #[error("Failed to format request: {0}")]
38    FormatRequest(String),
39
40    /// Errors forwarded from the HTTP protocol.
41    #[error("Unexpected HTTP client error: {0}")]
42    HttpClient(String),
43
44    /// The response from the server could not be parsed correctly.
45    #[error("Failed to parse response: {0}")]
46    ParseResponse(String),
47
48    /// Other fatal errors.
49    #[error("Fatal error: {0}")]
50    Fatal(String),
51}
52
53#[cfg_attr(test, automock)]
54#[async_trait]
55pub trait RPCClient: Send + Sync {
56    /// Retrieves a snapshot of contract state.
57    async fn get_contract_state(
58        &self,
59        request: &StateRequestBody,
60    ) -> Result<StateRequestResponse, RPCError>;
61
62    async fn get_contract_state_paginated(
63        &self,
64        chain: Chain,
65        ids: &[Bytes],
66        protocol_system: &str,
67        version: &VersionParam,
68        chunk_size: usize,
69        concurrency: usize,
70    ) -> Result<StateRequestResponse, RPCError> {
71        let semaphore = Arc::new(Semaphore::new(concurrency));
72
73        let chunked_bodies = ids
74            .chunks(chunk_size)
75            .map(|chunk| StateRequestBody {
76                contract_ids: Some(chunk.to_vec()),
77                protocol_system: protocol_system.to_string(),
78                chain,
79                version: version.clone(),
80                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
81            })
82            .collect::<Vec<_>>();
83
84        let mut tasks = Vec::new();
85        for body in chunked_bodies.iter() {
86            let sem = semaphore.clone();
87            tasks.push(async move {
88                let _permit = sem
89                    .acquire()
90                    .await
91                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
92                self.get_contract_state(body).await
93            });
94        }
95
96        // Execute all tasks concurrently with the defined concurrency limit.
97        let responses = try_join_all(tasks).await?;
98
99        // Aggregate the responses into a single result.
100        let accounts = responses
101            .iter()
102            .flat_map(|r| r.accounts.clone())
103            .collect();
104        let total: i64 = responses
105            .iter()
106            .map(|r| r.pagination.total)
107            .sum();
108
109        Ok(StateRequestResponse {
110            accounts,
111            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
112        })
113    }
114
115    async fn get_protocol_components(
116        &self,
117        request: &ProtocolComponentsRequestBody,
118    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
119
120    async fn get_protocol_components_paginated(
121        &self,
122        request: &ProtocolComponentsRequestBody,
123        chunk_size: usize,
124        concurrency: usize,
125    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
126        let semaphore = Arc::new(Semaphore::new(concurrency));
127
128        // If a set of component IDs is specified, the maximum return size is already known,
129        // allowing us to pre-compute the number of requests to be made.
130        match request.component_ids {
131            Some(ref ids) => {
132                // We can divide the component_ids into chunks of size chunk_size
133                let chunked_bodies = ids
134                    .chunks(chunk_size)
135                    .enumerate()
136                    .map(|(index, _)| ProtocolComponentsRequestBody {
137                        protocol_system: request.protocol_system.clone(),
138                        component_ids: request.component_ids.clone(),
139                        tvl_gt: request.tvl_gt,
140                        chain: request.chain,
141                        pagination: PaginationParams {
142                            page: index as i64,
143                            page_size: chunk_size as i64,
144                        },
145                    })
146                    .collect::<Vec<_>>();
147
148                let mut tasks = Vec::new();
149                for body in chunked_bodies.iter() {
150                    let sem = semaphore.clone();
151                    tasks.push(async move {
152                        let _permit = sem
153                            .acquire()
154                            .await
155                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
156                        self.get_protocol_components(body).await
157                    });
158                }
159
160                try_join_all(tasks)
161                    .await
162                    .map(|responses| ProtocolComponentRequestResponse {
163                        protocol_components: responses
164                            .into_iter()
165                            .flat_map(|r| r.protocol_components.into_iter())
166                            .collect(),
167                        pagination: PaginationResponse {
168                            page: 0,
169                            page_size: chunk_size as i64,
170                            total: ids.len() as i64,
171                        },
172                    })
173            }
174            _ => {
175                // If no component ids are specified, we need to make requests based on the total
176                // number of results from the first response.
177
178                let initial_request = ProtocolComponentsRequestBody {
179                    protocol_system: request.protocol_system.clone(),
180                    component_ids: request.component_ids.clone(),
181                    tvl_gt: request.tvl_gt,
182                    chain: request.chain,
183                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
184                };
185                let first_response = self
186                    .get_protocol_components(&initial_request)
187                    .await
188                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
189
190                let total_items = first_response.pagination.total;
191                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
192
193                // Initialize the final response accumulator
194                let mut accumulated_response = ProtocolComponentRequestResponse {
195                    protocol_components: first_response.protocol_components,
196                    pagination: PaginationResponse {
197                        page: 0,
198                        page_size: chunk_size as i64,
199                        total: total_items,
200                    },
201                };
202
203                let mut page = 1;
204                while page < total_pages {
205                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
206
207                    // Create request bodies for parallel requests, respecting the concurrency limit
208                    let chunked_bodies = (0..requests_in_this_iteration)
209                        .map(|iter| ProtocolComponentsRequestBody {
210                            protocol_system: request.protocol_system.clone(),
211                            component_ids: request.component_ids.clone(),
212                            tvl_gt: request.tvl_gt,
213                            chain: request.chain,
214                            pagination: PaginationParams {
215                                page: page + iter,
216                                page_size: chunk_size as i64,
217                            },
218                        })
219                        .collect::<Vec<_>>();
220
221                    let tasks: Vec<_> = chunked_bodies
222                        .iter()
223                        .map(|body| {
224                            let sem = semaphore.clone();
225                            async move {
226                                let _permit = sem.acquire().await.map_err(|_| {
227                                    RPCError::Fatal("Semaphore dropped".to_string())
228                                })?;
229                                self.get_protocol_components(body).await
230                            }
231                        })
232                        .collect();
233
234                    let responses = try_join_all(tasks)
235                        .await
236                        .map(|responses| {
237                            let total = responses[0].pagination.total;
238                            ProtocolComponentRequestResponse {
239                                protocol_components: responses
240                                    .into_iter()
241                                    .flat_map(|r| r.protocol_components.into_iter())
242                                    .collect(),
243                                pagination: PaginationResponse {
244                                    page,
245                                    page_size: chunk_size as i64,
246                                    total,
247                                },
248                            }
249                        });
250
251                    // Update the accumulated response or set the initial response
252                    match responses {
253                        Ok(mut resp) => {
254                            accumulated_response
255                                .protocol_components
256                                .append(&mut resp.protocol_components);
257                        }
258                        Err(e) => return Err(e),
259                    }
260
261                    page += concurrency as i64;
262                }
263                Ok(accumulated_response)
264            }
265        }
266    }
267
268    async fn get_protocol_states(
269        &self,
270        request: &ProtocolStateRequestBody,
271    ) -> Result<ProtocolStateRequestResponse, RPCError>;
272
273    #[allow(clippy::too_many_arguments)]
274    async fn get_protocol_states_paginated<T>(
275        &self,
276        chain: Chain,
277        ids: &[T],
278        protocol_system: &str,
279        include_balances: bool,
280        version: &VersionParam,
281        chunk_size: usize,
282        concurrency: usize,
283    ) -> Result<ProtocolStateRequestResponse, RPCError>
284    where
285        T: AsRef<str> + Sync + 'static,
286    {
287        let semaphore = Arc::new(Semaphore::new(concurrency));
288        let chunked_bodies = ids
289            .chunks(chunk_size)
290            .map(|c| ProtocolStateRequestBody {
291                protocol_ids: Some(
292                    c.iter()
293                        .map(|id| id.as_ref().to_string())
294                        .collect(),
295                ),
296                protocol_system: protocol_system.to_string(),
297                chain,
298                include_balances,
299                version: version.clone(),
300                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
301            })
302            .collect::<Vec<_>>();
303
304        let mut tasks = Vec::new();
305        for body in chunked_bodies.iter() {
306            let sem = semaphore.clone();
307            tasks.push(async move {
308                let _permit = sem
309                    .acquire()
310                    .await
311                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
312                self.get_protocol_states(body).await
313            });
314        }
315
316        try_join_all(tasks)
317            .await
318            .map(|responses| {
319                let states = responses
320                    .clone()
321                    .into_iter()
322                    .flat_map(|r| r.states)
323                    .collect();
324                let total = responses
325                    .iter()
326                    .map(|r| r.pagination.total)
327                    .sum();
328                ProtocolStateRequestResponse {
329                    states,
330                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
331                }
332            })
333    }
334
335    /// This function returns only one chunk of tokens. To get all tokens please call
336    /// get_all_tokens.
337    async fn get_tokens(
338        &self,
339        request: &TokensRequestBody,
340    ) -> Result<TokensRequestResponse, RPCError>;
341
342    async fn get_all_tokens(
343        &self,
344        chain: Chain,
345        min_quality: Option<i32>,
346        traded_n_days_ago: Option<u64>,
347        chunk_size: usize,
348    ) -> Result<Vec<ResponseToken>, RPCError> {
349        let mut request_page = 0;
350        let mut all_tokens = Vec::new();
351        loop {
352            let mut response = self
353                .get_tokens(&TokensRequestBody {
354                    token_addresses: None,
355                    min_quality,
356                    traded_n_days_ago,
357                    pagination: PaginationParams {
358                        page: request_page,
359                        page_size: chunk_size.try_into().map_err(|_| {
360                            RPCError::FormatRequest(
361                                "Failed to convert chunk_size into i64".to_string(),
362                            )
363                        })?,
364                    },
365                    chain,
366                })
367                .await?;
368
369            let num_tokens = response.tokens.len();
370            all_tokens.append(&mut response.tokens);
371            request_page += 1;
372
373            if num_tokens < chunk_size {
374                break;
375            }
376        }
377        Ok(all_tokens)
378    }
379
380    async fn get_protocol_systems(
381        &self,
382        request: &ProtocolSystemsRequestBody,
383    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
384
385    async fn get_component_tvl(
386        &self,
387        request: &ComponentTvlRequestBody,
388    ) -> Result<ComponentTvlRequestResponse, RPCError>;
389
390    async fn get_component_tvl_paginated(
391        &self,
392        request: &ComponentTvlRequestBody,
393        chunk_size: usize,
394        concurrency: usize,
395    ) -> Result<ComponentTvlRequestResponse, RPCError> {
396        let semaphore = Arc::new(Semaphore::new(concurrency));
397
398        match request.component_ids {
399            Some(ref ids) => {
400                let chunked_requests = ids
401                    .chunks(chunk_size)
402                    .enumerate()
403                    .map(|(index, _)| ComponentTvlRequestBody {
404                        chain: request.chain,
405                        protocol_system: request.protocol_system.clone(),
406                        component_ids: Some(ids.clone()),
407                        pagination: PaginationParams {
408                            page: index as i64,
409                            page_size: chunk_size as i64,
410                        },
411                    })
412                    .collect::<Vec<_>>();
413
414                let tasks: Vec<_> = chunked_requests
415                    .into_iter()
416                    .map(|req| {
417                        let sem = semaphore.clone();
418                        async move {
419                            let _permit = sem
420                                .acquire()
421                                .await
422                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
423                            self.get_component_tvl(&req).await
424                        }
425                    })
426                    .collect();
427
428                let responses = try_join_all(tasks).await?;
429
430                let mut merged_tvl = HashMap::new();
431                for resp in responses {
432                    for (key, value) in resp.tvl {
433                        *merged_tvl.entry(key).or_insert(0.0) = value;
434                    }
435                }
436
437                Ok(ComponentTvlRequestResponse {
438                    tvl: merged_tvl,
439                    pagination: PaginationResponse {
440                        page: 0,
441                        page_size: chunk_size as i64,
442                        total: ids.len() as i64,
443                    },
444                })
445            }
446            _ => {
447                let first_request = ComponentTvlRequestBody {
448                    chain: request.chain,
449                    protocol_system: request.protocol_system.clone(),
450                    component_ids: request.component_ids.clone(),
451                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
452                };
453
454                let first_response = self
455                    .get_component_tvl(&first_request)
456                    .await?;
457                let total_items = first_response.pagination.total;
458                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
459
460                let mut merged_tvl = first_response.tvl;
461
462                let mut page = 1;
463                while page < total_pages {
464                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
465
466                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
467                        .map(|i| ComponentTvlRequestBody {
468                            chain: request.chain,
469                            protocol_system: request.protocol_system.clone(),
470                            component_ids: request.component_ids.clone(),
471                            pagination: PaginationParams {
472                                page: page + i,
473                                page_size: chunk_size as i64,
474                            },
475                        })
476                        .collect();
477
478                    let tasks: Vec<_> = chunked_requests
479                        .into_iter()
480                        .map(|req| {
481                            let sem = semaphore.clone();
482                            async move {
483                                let _permit = sem.acquire().await.map_err(|_| {
484                                    RPCError::Fatal("Semaphore dropped".to_string())
485                                })?;
486                                self.get_component_tvl(&req).await
487                            }
488                        })
489                        .collect();
490
491                    let responses = try_join_all(tasks).await?;
492
493                    // merge hashmap
494                    for resp in responses {
495                        for (key, value) in resp.tvl {
496                            *merged_tvl.entry(key).or_insert(0.0) += value;
497                        }
498                    }
499
500                    page += concurrency as i64;
501                }
502
503                Ok(ComponentTvlRequestResponse {
504                    tvl: merged_tvl,
505                    pagination: PaginationResponse {
506                        page: 0,
507                        page_size: chunk_size as i64,
508                        total: total_items,
509                    },
510                })
511            }
512        }
513    }
514
515    async fn get_traced_entry_points(
516        &self,
517        request: &TracedEntryPointRequestBody,
518    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
519
520    async fn get_traced_entry_points_paginated(
521        &self,
522        chain: Chain,
523        protocol_system: &str,
524        component_ids: &[String],
525        chunk_size: usize,
526        concurrency: usize,
527    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
528        let semaphore = Arc::new(Semaphore::new(concurrency));
529        let chunked_bodies = component_ids
530            .chunks(chunk_size)
531            .map(|c| TracedEntryPointRequestBody {
532                chain,
533                protocol_system: protocol_system.to_string(),
534                component_ids: Some(c.to_vec()),
535                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
536            })
537            .collect::<Vec<_>>();
538
539        let mut tasks = Vec::new();
540        for body in chunked_bodies.iter() {
541            let sem = semaphore.clone();
542            tasks.push(async move {
543                let _permit = sem
544                    .acquire()
545                    .await
546                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
547                self.get_traced_entry_points(body).await
548            });
549        }
550
551        try_join_all(tasks)
552            .await
553            .map(|responses| {
554                let traced_entry_points = responses
555                    .clone()
556                    .into_iter()
557                    .flat_map(|r| r.traced_entry_points)
558                    .collect();
559                let total = responses
560                    .iter()
561                    .map(|r| r.pagination.total)
562                    .sum();
563                TracedEntryPointRequestResponse {
564                    traced_entry_points,
565                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
566                }
567            })
568    }
569}
570
571#[derive(Debug, Clone)]
572pub struct HttpRPCClient {
573    http_client: Client,
574    url: Url,
575}
576
577impl HttpRPCClient {
578    pub fn new(base_uri: &str, auth_key: Option<&str>) -> Result<Self, RPCError> {
579        let uri = base_uri
580            .parse::<Url>()
581            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
582
583        // Add default headers
584        let mut headers = header::HeaderMap::new();
585        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
586        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
587        headers.insert(
588            header::USER_AGENT,
589            header::HeaderValue::from_str(&user_agent)
590                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
591        );
592
593        // Add Authorization if one is given
594        if let Some(key) = auth_key {
595            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
596                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
597            })?;
598            auth_value.set_sensitive(true);
599            headers.insert(header::AUTHORIZATION, auth_value);
600        }
601
602        let client = ClientBuilder::new()
603            .default_headers(headers)
604            .http2_prior_knowledge()
605            .build()
606            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
607        Ok(Self { http_client: client, url: uri })
608    }
609}
610
611#[async_trait]
612impl RPCClient for HttpRPCClient {
613    #[instrument(skip(self, request))]
614    async fn get_contract_state(
615        &self,
616        request: &StateRequestBody,
617    ) -> Result<StateRequestResponse, RPCError> {
618        // Check if contract ids are specified
619        if request
620            .contract_ids
621            .as_ref()
622            .is_none_or(|ids| ids.is_empty())
623        {
624            warn!("No contract ids specified in request.");
625        }
626
627        let uri = format!(
628            "{}/{}/contract_state",
629            self.url
630                .to_string()
631                .trim_end_matches('/'),
632            TYCHO_SERVER_VERSION
633        );
634        debug!(%uri, "Sending contract_state request to Tycho server");
635        trace!(?request, "Sending request to Tycho server");
636
637        let response = self
638            .http_client
639            .post(&uri)
640            .json(request)
641            .send()
642            .await
643            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
644        trace!(?response, "Received response from Tycho server");
645
646        let body = response
647            .text()
648            .await
649            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
650        if body.is_empty() {
651            // Pure native protocols will return empty contract states
652            return Ok(StateRequestResponse {
653                accounts: vec![],
654                pagination: PaginationResponse {
655                    page: request.pagination.page,
656                    page_size: request.pagination.page,
657                    total: 0,
658                },
659            });
660        }
661
662        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
663            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
664        trace!(?accounts, "Received contract_state response from Tycho server");
665
666        Ok(accounts)
667    }
668
669    async fn get_protocol_components(
670        &self,
671        request: &ProtocolComponentsRequestBody,
672    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
673        let uri = format!(
674            "{}/{}/protocol_components",
675            self.url
676                .to_string()
677                .trim_end_matches('/'),
678            TYCHO_SERVER_VERSION,
679        );
680        debug!(%uri, "Sending protocol_components request to Tycho server");
681        trace!(?request, "Sending request to Tycho server");
682
683        let response = self
684            .http_client
685            .post(uri)
686            .header(header::CONTENT_TYPE, "application/json")
687            .json(request)
688            .send()
689            .await
690            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
691
692        trace!(?response, "Received response from Tycho server");
693
694        let body = response
695            .text()
696            .await
697            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
698        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
699            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
700        trace!(?components, "Received protocol_components response from Tycho server");
701
702        Ok(components)
703    }
704
705    async fn get_protocol_states(
706        &self,
707        request: &ProtocolStateRequestBody,
708    ) -> Result<ProtocolStateRequestResponse, RPCError> {
709        // Check if protocol ids are specified
710        if request
711            .protocol_ids
712            .as_ref()
713            .is_none_or(|ids| ids.is_empty())
714        {
715            warn!("No protocol ids specified in request.");
716        }
717
718        let uri = format!(
719            "{}/{}/protocol_state",
720            self.url
721                .to_string()
722                .trim_end_matches('/'),
723            TYCHO_SERVER_VERSION
724        );
725        debug!(%uri, "Sending protocol_states request to Tycho server");
726        trace!(?request, "Sending request to Tycho server");
727
728        let response = self
729            .http_client
730            .post(&uri)
731            .json(request)
732            .send()
733            .await
734            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
735        trace!(?response, "Received response from Tycho server");
736
737        let body = response
738            .text()
739            .await
740            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
741
742        if body.is_empty() {
743            // Pure VM protocols will return empty states
744            return Ok(ProtocolStateRequestResponse {
745                states: vec![],
746                pagination: PaginationResponse {
747                    page: request.pagination.page,
748                    page_size: request.pagination.page_size,
749                    total: 0,
750                },
751            });
752        }
753
754        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
755            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
756        trace!(?states, "Received protocol_states response from Tycho server");
757
758        Ok(states)
759    }
760
761    async fn get_tokens(
762        &self,
763        request: &TokensRequestBody,
764    ) -> Result<TokensRequestResponse, RPCError> {
765        let uri = format!(
766            "{}/{}/tokens",
767            self.url
768                .to_string()
769                .trim_end_matches('/'),
770            TYCHO_SERVER_VERSION
771        );
772        debug!(%uri, "Sending tokens request to Tycho server");
773
774        let response = self
775            .http_client
776            .post(&uri)
777            .json(request)
778            .send()
779            .await
780            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
781
782        let body = response
783            .text()
784            .await
785            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
786        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
787            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
788
789        Ok(tokens)
790    }
791
792    async fn get_protocol_systems(
793        &self,
794        request: &ProtocolSystemsRequestBody,
795    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
796        let uri = format!(
797            "{}/{}/protocol_systems",
798            self.url
799                .to_string()
800                .trim_end_matches('/'),
801            TYCHO_SERVER_VERSION
802        );
803        debug!(%uri, "Sending protocol_systems request to Tycho server");
804        trace!(?request, "Sending request to Tycho server");
805        let response = self
806            .http_client
807            .post(&uri)
808            .json(request)
809            .send()
810            .await
811            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
812        trace!(?response, "Received response from Tycho server");
813        let body = response
814            .text()
815            .await
816            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
817        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
818            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
819        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
820        Ok(protocol_systems)
821    }
822
823    async fn get_component_tvl(
824        &self,
825        request: &ComponentTvlRequestBody,
826    ) -> Result<ComponentTvlRequestResponse, RPCError> {
827        let uri = format!(
828            "{}/{}/component_tvl",
829            self.url
830                .to_string()
831                .trim_end_matches('/'),
832            TYCHO_SERVER_VERSION
833        );
834        debug!(%uri, "Sending get_component_tvl request to Tycho server");
835        trace!(?request, "Sending request to Tycho server");
836        let response = self
837            .http_client
838            .post(&uri)
839            .json(request)
840            .send()
841            .await
842            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
843        trace!(?response, "Received response from Tycho server");
844        let body = response
845            .text()
846            .await
847            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
848        let component_tvl =
849            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
850                error!("Failed to parse component_tvl response: {:?}", &body);
851                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
852            })?;
853        trace!(?component_tvl, "Received component_tvl response from Tycho server");
854        Ok(component_tvl)
855    }
856
857    async fn get_traced_entry_points(
858        &self,
859        request: &TracedEntryPointRequestBody,
860    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
861        let uri = format!(
862            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
863            self.url
864                .to_string()
865                .trim_end_matches('/')
866        );
867        debug!(%uri, "Sending traced_entry_points request to Tycho server");
868        trace!(?request, "Sending request to Tycho server");
869
870        let response = self
871            .http_client
872            .post(&uri)
873            .json(request)
874            .send()
875            .await
876            .map_err(|e| RPCError::HttpClient(e.to_string()))?;
877        trace!(?response, "Received response from Tycho server");
878
879        let body = response
880            .text()
881            .await
882            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
883        let entrypoints =
884            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
885                error!("Failed to parse traced_entry_points response: {:?}", &body);
886                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
887            })?;
888        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
889        Ok(entrypoints)
890    }
891}
892
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        str::FromStr,
    };

    use mockito::Server;
    use rstest::rstest;
    // TODO: remove once deprecated ProtocolId struct is removed
    #[allow(deprecated)]
    use tycho_common::dto::ProtocolId;
    use tycho_common::dto::TracingParams;

    use super::*;

    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
    // purposes. It only builds the chunked request bodies (one per `chunk_size` ids) without
    // sending them, so tests can check how ids of different input types are converted.
    impl MockRPCClient {
        #[allow(clippy::too_many_arguments)]
        async fn test_get_protocol_states_paginated<T>(
            &self,
            chain: Chain,
            ids: &[T],
            protocol_system: &str,
            include_balances: bool,
            version: &VersionParam,
            chunk_size: usize,
            _concurrency: usize,
        ) -> Vec<ProtocolStateRequestBody>
        where
            T: AsRef<str> + Clone + Send + Sync + 'static,
        {
            ids.chunks(chunk_size)
                .map(|chunk| ProtocolStateRequestBody {
                    protocol_ids: Some(
                        chunk
                            .iter()
                            .map(|id| id.as_ref().to_string())
                            .collect(),
                    ),
                    protocol_system: protocol_system.to_string(),
                    chain,
                    include_balances,
                    version: version.clone(),
                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
                })
                .collect()
        }
    }

    // Both the deprecated `ProtocolId` input and plain strings must produce identical request
    // bodies: two ids with a chunk size of 2 yield a single body containing both ids.
    // TODO: remove once deprecated ProtocolId struct is removed
    #[allow(deprecated)]
    #[rstest]
    #[case::protocol_id_input(vec![
        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
    ])]
    #[case::string_input(vec![
        "id1".to_string(),
        "id2".to_string()
    ])]
    #[tokio::test]
    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
    where
        T: AsRef<str> + Clone + Send + Sync + 'static,
    {
        let mock_client = MockRPCClient::new();

        let request_bodies = mock_client
            .test_get_protocol_states_paginated(
                Chain::Ethereum,
                &ids,
                "test_system",
                true,
                &VersionParam::default(),
                2,
                2,
            )
            .await;

        // Verify that the request bodies have been created correctly
        assert_eq!(request_bodies.len(), 1);
        assert_eq!(
            request_bodies[0]
                .protocol_ids
                .as_ref()
                .unwrap()
                .len(),
            2
        );
    }

    // `get_contract_state` must hit `/v1/contract_state` exactly once and decode the
    // hex-encoded account fields from the response body.
    #[tokio::test]
    async fn test_get_contract_state() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/contract_state")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;

        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_contract_state(&Default::default())
            .await
            .expect("get state");
        let accounts = response.accounts;

        mocked_server.assert();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts[0].slots, HashMap::new());
        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
        assert_eq!(accounts[0].code, [0].to_vec());
        assert_eq!(
            accounts[0].code_hash,
            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
                .unwrap()
        );
    }

    // `get_protocol_components` must hit `/v1/protocol_components` exactly once and decode the
    // component metadata, including hex-encoded static attributes.
    #[tokio::test]
    async fn test_get_protocol_components() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "protocol_components": [
                {
                    "id": "State1",
                    "protocol_system": "ambient",
                    "protocol_type_name": "Pool",
                    "chain": "ethereum",
                    "tokens": [
                        "0x0000000000000000000000000000000000000000",
                        "0x0000000000000000000000000000000000000001"
                    ],
                    "contract_ids": [
                        "0x0000000000000000000000000000000000000000"
                    ],
                    "static_attributes": {
                        "attribute_1": "0x00000000000003e8"
                    },
                    "change": "Creation",
                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "created_at": "2022-01-01T00:00:00"
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_components")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;

        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_components(&Default::default())
            .await
            .expect("get protocol components");
        let components = response.protocol_components;

        mocked_server.assert();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].id, "State1");
        assert_eq!(components[0].protocol_system, "ambient");
        assert_eq!(components[0].protocol_type_name, "Pool");
        assert_eq!(components[0].tokens.len(), 2);
        let expected_attributes =
            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
                .iter()
                .cloned()
                .collect::<HashMap<String, Bytes>>();
        assert_eq!(components[0].static_attributes, expected_attributes);
    }

    // `get_protocol_states` must hit `/v1/protocol_state` exactly once and decode both the
    // attributes and the per-token balances of each state.
    #[tokio::test]
    async fn test_get_protocol_states() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "states": [
                {
                    "component_id": "State1",
                    "attributes": {
                        "attribute_1": "0x00000000000003e8"
                    },
                    "balances": {
                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
                    }
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_state")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_states(&Default::default())
            .await
            .expect("get protocol states");
        let states = response.states;

        mocked_server.assert();
        assert_eq!(states.len(), 1);
        assert_eq!(states[0].component_id, "State1");
        let expected_attributes =
            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
                .iter()
                .cloned()
                .collect::<HashMap<String, Bytes>>();
        assert_eq!(states[0].attributes, expected_attributes);
        let expected_balances = [(
            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
                .expect("Unsupported address format"),
            Bytes::from_str("0x01f4").unwrap(),
        )]
        .iter()
        .cloned()
        .collect::<HashMap<Bytes, Bytes>>();
        assert_eq!(states[0].balances, expected_balances);
    }

    // `get_tokens` must hit `/v1/tokens` exactly once and decode both the token list and the
    // pagination metadata.
    #[tokio::test]
    async fn test_get_tokens() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "tokens": [
              {
                "chain": "ethereum",
                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
                "symbol": "WETH",
                "decimals": 18,
                "tax": 0,
                "gas": [
                  29962
                ],
                "quality": 100
              },
              {
                "chain": "ethereum",
                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
                "symbol": "USDC",
                "decimals": 6,
                "tax": 0,
                "gas": [
                  40652
                ],
                "quality": 100
              }
            ],
            "pagination": {
              "page": 0,
              "page_size": 20,
              "total": 10
            }
          }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/tokens")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_tokens(&Default::default())
            .await
            .expect("get tokens");

        let expected = vec![
            ResponseToken {
                chain: Chain::Ethereum,
                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
                symbol: "WETH".to_string(),
                decimals: 18,
                tax: 0,
                gas: vec![Some(29962)],
                quality: 100,
            },
            ResponseToken {
                chain: Chain::Ethereum,
                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
                symbol: "USDC".to_string(),
                decimals: 6,
                tax: 0,
                gas: vec![Some(40652)],
                quality: 100,
            },
        ];

        mocked_server.assert();
        assert_eq!(response.tokens, expected);
        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
    }

    // `get_protocol_systems` must hit `/v1/protocol_systems` exactly once and return the system
    // names in server order.
    #[tokio::test]
    async fn test_get_protocol_systems() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "protocol_systems": [
                "system1",
                "system2"
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/protocol_systems")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_protocol_systems(&Default::default())
            .await
            .expect("get protocol systems");
        let protocol_systems = response.protocol_systems;

        mocked_server.assert();
        assert_eq!(protocol_systems, vec!["system1", "system2"]);
    }

    // `get_component_tvl` must hit `/v1/component_tvl` exactly once and expose the TVL map keyed
    // by component id.
    #[tokio::test]
    async fn test_get_component_tvl() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "tvl": {
                "component1": 100.0
            },
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/component_tvl")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_component_tvl(&Default::default())
            .await
            .expect("get component tvl");
        let component_tvl = response.tvl;

        mocked_server.assert();
        assert_eq!(component_tvl.get("component1"), Some(&100.0));
    }

    // `get_traced_entry_points` must hit `/v1/traced_entry_points` exactly once and decode the
    // nested (entry point, trace result) pairs, including RPC tracer params, retriggers and
    // accessed slots.
    #[tokio::test]
    async fn test_get_traced_entry_points() {
        let mut server = Server::new_async().await;
        let server_resp = r#"
        {
            "traced_entry_points": {
                "component_1": [
                    [
                        {
                            "entry_point": {
                                "external_id": "entrypoint_a",
                                "target": "0x0000000000000000000000000000000000000001",
                                "signature": "sig()"
                            },
                            "params": {
                                "method": "rpctracer",
                                "caller": "0x000000000000000000000000000000000000000a",
                                "calldata": "0x000000000000000000000000000000000000000b"
                            }
                        },
                        {
                            "retriggers": [
                                [
                                    "0x00000000000000000000000000000000000000aa",
                                    "0x0000000000000000000000000000000000000aaa"
                                ]
                            ],
                            "accessed_slots": {
                                "0x0000000000000000000000000000000000aaaa": [
                                    "0x0000000000000000000000000000000000aaaa"
                                ]
                            }
                        }
                    ]
                ]
            },
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 1
            }
        }
        "#;
        // test that the response is deserialized correctly
        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");

        let mocked_server = server
            .mock("POST", "/v1/traced_entry_points")
            .expect(1)
            .with_body(server_resp)
            .create_async()
            .await;
        let client = HttpRPCClient::new(server.url().as_str(), None).expect("create client");

        let response = client
            .get_traced_entry_points(&Default::default())
            .await
            .expect("get traced entry points");
        let entrypoints = response.traced_entry_points;

        mocked_server.assert();
        assert_eq!(entrypoints.len(), 1);
        let comp1_entrypoints = entrypoints
            .get("component_1")
            .expect("component_1 entrypoints should exist");
        assert_eq!(comp1_entrypoints.len(), 1);

        let (entrypoint, trace_result) = &comp1_entrypoints[0];
        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
        assert_eq!(
            entrypoint.entry_point.target,
            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
        );
        assert_eq!(entrypoint.entry_point.signature, "sig()");
        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
        assert_eq!(
            rpc_params.caller,
            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
        );
        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));

        assert_eq!(
            trace_result.retriggers,
            HashSet::from([(
                Bytes::from("0x00000000000000000000000000000000000000aa"),
                Bytes::from("0x0000000000000000000000000000000000000aaa")
            )])
        );
        assert_eq!(trace_result.accessed_slots.len(), 1);
        assert_eq!(
            trace_result.accessed_slots,
            HashMap::from([(
                Bytes::from("0x0000000000000000000000000000000000aaaa"),
                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
            )])
        );
    }
}