1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, ProtocolComponent, ProtocolComponentsRequestBody},
6 Bytes,
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Selection strategy backing a [`ComponentFilter`].
#[derive(Debug, Clone)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly the components with these ids.
    Ids(Vec<String>),
    /// TVL range stored as `(remove_tvl_threshold, add_tvl_threshold)`.
    ///
    /// Components whose TVL is below the first value are reported for removal. Components whose
    /// TVL is above the second value are reported for addition. Anything in between (bounds
    /// inclusive) appears in neither list.
    MinimumTVLRange((f64, f64)),
}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24 variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<String>) -> ComponentFilter {
73 ComponentFilter { variant: ComponentFilterVariant::Ids(ids) }
74 }
75}
76
77pub struct ComponentTracker<R: RPCClient> {
79 chain: Chain,
80 protocol_system: String,
81 filter: ComponentFilter,
82 pub components: HashMap<String, ProtocolComponent>,
85 pub contracts: HashSet<Bytes>,
88 rpc_client: R,
90}
91
92impl<R> ComponentTracker<R>
93where
94 R: RPCClient,
95{
96 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
97 Self {
98 chain,
99 protocol_system: protocol_system.to_string(),
100 filter,
101 components: Default::default(),
102 contracts: Default::default(),
103 rpc_client: rpc,
104 }
105 }
106 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
108 let body = match &self.filter.variant {
109 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
110 &self.protocol_system,
111 ids.clone(),
112 self.chain,
113 ),
114 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
115 ProtocolComponentsRequestBody::system_filtered(
116 &self.protocol_system,
117 Some(*upper_tvl_threshold),
118 self.chain,
119 )
120 }
121 };
122
123 self.components = self
124 .rpc_client
125 .get_protocol_components_paginated(&body, 500, 4)
126 .await?
127 .protocol_components
128 .into_iter()
129 .map(|pc| (pc.id.clone(), pc))
130 .collect::<HashMap<_, _>>();
131 self.update_contracts();
132 Ok(())
133 }
134
135 fn update_contracts(&mut self) {
136 self.contracts.extend(
137 self.components
138 .values()
139 .flat_map(|comp| comp.contract_ids.iter().cloned()),
140 );
141 }
142
143 #[instrument(skip(self, new_components))]
145 pub async fn start_tracking(&mut self, new_components: &[&String]) -> Result<(), RPCError> {
146 if new_components.is_empty() {
147 return Ok(());
148 }
149 let request = ProtocolComponentsRequestBody::id_filtered(
150 &self.protocol_system,
151 new_components
152 .iter()
153 .map(|pc_id| pc_id.to_string())
154 .collect(),
155 self.chain,
156 );
157
158 self.components.extend(
159 self.rpc_client
160 .get_protocol_components(&request)
161 .await?
162 .protocol_components
163 .into_iter()
164 .map(|pc| (pc.id.clone(), pc)),
165 );
166 self.update_contracts();
167 debug!(n_components = new_components.len(), "StartedTracking");
168 Ok(())
169 }
170
171 #[instrument(skip(self, to_remove))]
173 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a String> + std::fmt::Debug>(
174 &mut self,
175 to_remove: I,
176 ) -> HashMap<String, ProtocolComponent> {
177 let mut n_components = 0;
178 let res = to_remove
179 .into_iter()
180 .filter_map(|k| {
181 let comp = self.components.remove(k);
182 if let Some(component) = &comp {
183 n_components += 1;
184 for contract in component.contract_ids.iter() {
185 self.contracts.remove(contract);
186 }
187 }
188 comp.map(|c| (k.clone(), c))
189 })
190 .collect();
191 debug!(n_components, "StoppedTracking");
192 res
193 }
194
195 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
206 &self,
207 ids: I,
208 ) -> HashSet<Bytes> {
209 ids.into_iter()
210 .filter_map(|cid| {
211 if let Some(comp) = self.components.get(cid) {
212 Some(comp.contract_ids.clone())
213 } else {
214 warn!(
215 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
216 );
217 None
218 }
219 })
220 .flatten()
221 .collect()
222 }
223
224 pub fn get_tracked_component_ids(&self) -> Vec<String> {
225 self.components
226 .keys()
227 .cloned()
228 .collect()
229 }
230
231 pub fn filter_updated_components(&self, deltas: &BlockChanges) -> (Vec<String>, Vec<String>) {
234 match &self.filter.variant {
235 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
236 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
237 .component_tvl
238 .iter()
239 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
240 .map(|(id, _)| id.clone())
241 .partition(|id| deltas.component_tvl[id] > *add_tvl),
242 }
243 }
244}
245
#[cfg(test)]
mod test {
    use tycho_common::dto::{PaginationResponse, ProtocolComponentRequestResponse};

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Builds a `uniswap-v2` / Ethereum tracker with the given filter and a fresh mock RPC
    /// client. Tests register their own expectations on the mock.
    fn with_filter(filter: ComponentFilter) -> ComponentTracker<MockRPCClient> {
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    /// Builds a tracker with a `(0.0, 0.0)` TVL range and a fresh mock RPC client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        with_filter(ComponentFilter::with_tvl_range(0.0, 0.0))
    }

    /// Returns the component `Component1` together with the contract ids it references.
    fn components_response() -> (Vec<Bytes>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.components.is_empty());
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    #[test]
    fn test_get_contracts_by_component_skips_untracked() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string(), "Untracked".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }

    #[test]
    fn test_filter_updated_components_tvl_range() {
        let tracker = with_filter(ComponentFilter::with_tvl_range(10.0, 20.0));
        let deltas = BlockChanges {
            component_tvl: [
                ("below_remove", 5.0),
                ("at_remove", 10.0),
                ("in_band", 15.0),
                ("at_add", 20.0),
                ("above_add", 25.0),
            ]
            .into_iter()
            .map(|(id, tvl)| (id.to_string(), tvl))
            .collect(),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        // Thresholds are exclusive: values exactly on a bound are left alone.
        assert_eq!(to_add, vec!["above_add".to_string()]);
        assert_eq!(to_remove, vec!["below_remove".to_string()]);
    }

    #[test]
    fn test_filter_updated_components_ids_filter() {
        let tracker = with_filter(ComponentFilter::Ids(vec!["Component1".to_string()]));
        let deltas = BlockChanges {
            component_tvl: [("Component2".to_string(), 1_000.0)]
                .into_iter()
                .collect(),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        assert!(to_add.is_empty());
        assert!(to_remove.is_empty());
    }
}