1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::models::{
5 blockchain::{BlockAggregatedChanges, DCIUpdate},
6 protocol::ProtocolComponent,
7 Address, Chain, ComponentId, ProtocolSystem,
8};
9
10use crate::{
11 rpc::{ProtocolComponentsPaginatedParams, RPCClient, RPC_CLIENT_CONCURRENCY},
12 RPCError,
13};
14
/// Internal representation of a [`ComponentFilter`]'s selection strategy.
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly these component IDs (lowercased by `ComponentFilter::Ids` before storage).
    Ids(Vec<ComponentId>),
    /// Track components based on their TVL.
    MinimumTVLRange {
        /// `(remove_tvl_threshold, add_tvl_threshold)`: the tracker proposes removal of a
        /// component whose TVL falls below the first value and addition of one whose TVL rises
        /// above the second value.
        range: (f64, f64),
        /// Lowercased component IDs that must never be tracked by this filter.
        blocklisted_ids: HashSet<ComponentId>,
    },
}
30
31#[derive(Clone, Debug)]
32pub struct ComponentFilter {
33 variant: ComponentFilterVariant,
34}
35
36impl ComponentFilter {
37 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
46 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
47 ComponentFilter {
48 variant: ComponentFilterVariant::MinimumTVLRange {
49 range: (min_tvl, min_tvl),
50 blocklisted_ids: HashSet::new(),
51 },
52 }
53 }
54
55 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
70 ComponentFilter {
71 variant: ComponentFilterVariant::MinimumTVLRange {
72 range: (remove_tvl_threshold, add_tvl_threshold),
73 blocklisted_ids: HashSet::new(),
74 },
75 }
76 }
77
78 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
87 ComponentFilter {
88 variant: ComponentFilterVariant::Ids(
89 ids.into_iter()
90 .map(|id| id.to_lowercase())
91 .collect(),
92 ),
93 }
94 }
95
96 pub fn blocklist(mut self, ids: impl IntoIterator<Item = ComponentId>) -> Self {
102 match &mut self.variant {
103 ComponentFilterVariant::Ids(_) => {
104 warn!(
105 "blocklist() has no effect on ComponentFilter::Ids; \
106 remove the component from the ID list instead"
107 );
108 }
109 ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
110 blocklisted_ids.extend(
111 ids.into_iter()
112 .map(|id| id.to_lowercase()),
113 );
114 }
115 }
116 self
117 }
118
119 pub fn is_blocklisted(&self, id: &str) -> bool {
121 match &self.variant {
122 ComponentFilterVariant::Ids(_) => false,
123 ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
124 blocklisted_ids.contains(&id.to_lowercase())
125 }
126 }
127 }
128}
129
130#[derive(Default)]
133struct EntrypointRelations {
134 components: HashSet<ComponentId>,
136 contracts: HashSet<Address>,
138}
139
140pub struct ComponentTracker<R: RPCClient> {
142 chain: Chain,
143 protocol_system: ProtocolSystem,
144 filter: ComponentFilter,
145 pub components: HashMap<ComponentId, ProtocolComponent>,
148 entrypoints: HashMap<String, EntrypointRelations>,
150 pub contracts: HashSet<Address>,
153 rpc_client: R,
155}
156
157impl<R> ComponentTracker<R>
158where
159 R: RPCClient,
160{
161 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
162 Self {
163 chain,
164 protocol_system: protocol_system.to_string(),
165 filter,
166 components: Default::default(),
167 contracts: Default::default(),
168 rpc_client: rpc,
169 entrypoints: Default::default(),
170 }
171 }
172
173 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
176 let (component_ids, tvl_gt) = match &self.filter.variant {
177 ComponentFilterVariant::Ids(ids) => (Some(ids.clone()), None),
178 ComponentFilterVariant::MinimumTVLRange { range: (_, upper_tvl_threshold), .. } => {
179 (None, Some(*upper_tvl_threshold))
180 }
181 };
182 let mut paginated_params = ProtocolComponentsPaginatedParams::new(
183 self.chain,
184 self.protocol_system.as_str(),
185 RPC_CLIENT_CONCURRENCY,
186 );
187 if let Some(ids) = component_ids {
188 paginated_params = paginated_params.with_component_ids(ids);
189 }
190 if let Some(tvl) = tvl_gt {
191 paginated_params = paginated_params.with_tvl_gt(tvl);
192 }
193
194 self.components = self
195 .rpc_client
196 .get_protocol_components_paginated(paginated_params)
197 .await?
198 .into_iter()
199 .map(|comp| (comp.id.clone(), comp))
200 .filter(|(id, _)| !self.filter.is_blocklisted(id))
201 .collect::<HashMap<_, _>>();
202
203 self.reinitialize_contracts();
204
205 Ok(())
206 }
207
208 fn reinitialize_contracts(&mut self) {
210 self.contracts = self
212 .components
213 .values()
214 .flat_map(|comp| comp.contract_addresses.iter().cloned())
215 .collect();
216
217 let tracked_component_ids = self
219 .components
220 .keys()
221 .cloned()
222 .collect::<HashSet<_>>();
223 for entrypoint in self.entrypoints.values() {
224 if !entrypoint
225 .components
226 .is_disjoint(&tracked_component_ids)
227 {
228 self.contracts
229 .extend(entrypoint.contracts.iter().cloned());
230 }
231 }
232 }
233
234 pub(crate) fn update_contracts(&mut self, components: Vec<ComponentId>) {
236 let mut tracked_component_ids = HashSet::new();
238
239 for comp in components {
241 if let Some(component) = self.components.get(&comp) {
242 self.contracts.extend(
243 component
244 .contract_addresses
245 .iter()
246 .cloned(),
247 );
248 tracked_component_ids.insert(comp);
249 }
250 }
251
252 for entrypoint in self.entrypoints.values() {
254 if !entrypoint
255 .components
256 .is_disjoint(&tracked_component_ids)
257 {
258 self.contracts
259 .extend(entrypoint.contracts.iter().cloned());
260 }
261 }
262 }
263
264 #[instrument(skip(self, new_components))]
266 pub async fn start_tracking(
267 &mut self,
268 new_components: &[&ComponentId],
269 ) -> Result<(), RPCError> {
270 let new_components: Vec<_> = new_components
271 .iter()
272 .filter(|id| !self.filter.is_blocklisted(id))
273 .copied()
274 .collect();
275
276 if new_components.is_empty() {
277 return Ok(());
278 }
279
280 let components = self
281 .rpc_client
282 .get_protocol_components(
283 crate::rpc::ProtocolComponentsParams::new(
284 self.chain,
285 self.protocol_system.as_str(),
286 )
287 .with_component_ids(
288 new_components
289 .into_iter()
290 .cloned()
291 .collect(),
292 ),
293 )
294 .await?
295 .into_iter()
296 .map(|comp| (comp.id.clone(), comp))
297 .collect::<HashMap<_, _>>();
298
299 let component_ids: Vec<_> = components.keys().cloned().collect();
301 let component_count = component_ids.len();
302 self.components.extend(components);
303 self.update_contracts(component_ids);
304
305 debug!(n_components = component_count, "StartedTracking");
306 Ok(())
307 }
308
309 #[instrument(skip(self, to_remove))]
311 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
312 &mut self,
313 to_remove: I,
314 ) -> HashMap<ComponentId, ProtocolComponent> {
315 let mut removed_components = HashMap::new();
316
317 for component_id in to_remove {
318 if let Some(component) = self.components.remove(component_id) {
319 removed_components.insert(component_id.clone(), component);
320 }
321 }
322
323 self.reinitialize_contracts();
326
327 debug!(n_components = removed_components.len(), "StoppedTracking");
328 removed_components
329 }
330
331 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
333 for (entrypoint, traces) in &dci_update.trace_results {
335 self.entrypoints
336 .entry(entrypoint.clone())
337 .or_default()
338 .contracts
339 .extend(traces.accessed_slots.keys().cloned());
340 }
341
342 for (component, entrypoints) in &dci_update.new_entrypoints {
344 for entrypoint in entrypoints {
345 let entrypoint_info = self
346 .entrypoints
347 .entry(entrypoint.external_id.clone())
348 .or_default();
349 entrypoint_info
350 .components
351 .insert(component.clone());
352 if self.components.contains_key(component) {
354 self.contracts.extend(
355 entrypoint_info
356 .contracts
357 .iter()
358 .cloned(),
359 );
360 }
361 }
362 }
363 }
364
365 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
376 &self,
377 ids: I,
378 ) -> HashSet<Address> {
379 ids.into_iter()
380 .filter_map(|cid| {
381 if let Some(comp) = self.components.get(cid) {
382 let dci_contracts: HashSet<Address> = self
384 .entrypoints
385 .values()
386 .filter(|ep| ep.components.contains(cid))
387 .flat_map(|ep| ep.contracts.iter().cloned())
388 .collect();
389 Some(
390 comp.contract_addresses
391 .clone()
392 .into_iter()
393 .chain(dci_contracts)
394 .collect::<HashSet<_>>(),
395 )
396 } else {
397 warn!(
398 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
399 );
400 None
401 }
402 })
403 .flatten()
404 .collect()
405 }
406
407 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
408 self.components
409 .keys()
410 .cloned()
411 .collect()
412 }
413
414 pub fn filter_updated_components(
417 &self,
418 deltas: &BlockAggregatedChanges,
419 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
420 match &self.filter.variant {
421 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
422 ComponentFilterVariant::MinimumTVLRange { range: (remove_tvl, add_tvl), .. } => {
423 let (mut to_add, mut to_remove): (Vec<_>, Vec<_>) = deltas
424 .component_tvl
425 .iter()
426 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
427 .map(|(id, _)| id.clone())
428 .partition(|id| deltas.component_tvl[id] > *add_tvl);
429
430 to_add.retain(|id| !self.filter.is_blocklisted(id));
432
433 for id in self.components.keys() {
435 if self.filter.is_blocklisted(id) && !to_remove.contains(id) {
436 to_remove.push(id.clone());
437 }
438 }
439
440 (to_add, to_remove)
441 }
442 }
443 }
444}
445
/// Unit tests for `ComponentTracker`, using a mocked RPC client.
#[cfg(test)]
mod test {
    use tycho_common::Bytes;

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker for "uniswap-v2" on Ethereum with a `(0.0, 0.0)` TVL range and no expectations
    /// configured on the mock RPC client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    /// A single component "Component1" together with its two contract addresses.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_addresses = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_addresses: contract_addresses.clone(),
            ..Default::default()
        };
        (contract_addresses, component)
    }

    // Initial load stores the returned component and derives its contracts.
    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, model_component) = components_response();
        let exp_component = model_component.clone();
        let model_for_mock = model_component.clone();
        // The expectation must be registered before the call under test.
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_| Ok(vec![model_for_mock.clone()]));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    // Newly tracked components are fetched and their contracts added.
    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, model_component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = model_component.id.clone();
        let components_arg = [&component_id];
        let model_for_mock = model_component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(crate::rpc::Page::new(vec![model_for_mock.clone()], 1, 0, 100)));

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    // Only tracked IDs are returned; untracked ones ("Component2") are ignored, and the
    // contract set is emptied once no component references it.
    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, model_component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), model_component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), model_component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    // Without entrypoints, a component's contracts are just its own contract addresses.
    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, model_component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), model_component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, model_component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), model_component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }

    /// Like `with_mocked_rpc`, but the `(0.0, 0.0)` TVL filter also blocklists `blocklisted`.
    fn with_mocked_rpc_and_blocklist(blocklisted: Vec<&str>) -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0).blocklist(
            blocklisted
                .into_iter()
                .map(String::from),
        );
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, rpc)
    }

    // The blocklist entry "component1" matches "Component1" because IDs are lowercased.
    #[tokio::test]
    async fn test_initialise_skips_blocklisted_components() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let (_, model_component) = components_response();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_| Ok(vec![model_component.clone()]));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be in tracker");
    }

    // No RPC expectation is set: start_tracking returns before calling the client when every
    // requested ID is blocklisted.
    #[tokio::test]
    async fn test_start_tracking_skips_blocklisted() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let component_id = "Component1".to_string();
        let components_arg = [&component_id];

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("start_tracking should succeed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be tracked");
    }

    // Both pools exceed the add threshold (10.0), but only the non-blocklisted one is proposed.
    #[test]
    fn test_filter_updated_blocks_blocklisted_add() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["blocklisted_pool"]);
        tracker.filter = ComponentFilter::with_tvl_range(5.0, 10.0)
            .blocklist(vec!["blocklisted_pool".to_string()]);

        let deltas = BlockAggregatedChanges {
            component_tvl: HashMap::from([
                ("blocklisted_pool".to_string(), 100.0),
                ("allowed_pool".to_string(), 100.0),
            ]),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
        assert!(
            !to_add.contains(&"blocklisted_pool".to_string()),
            "Blocklisted component should not be in to_add"
        );
        assert!(
            to_add.contains(&"allowed_pool".to_string()),
            "Non-blocklisted component should be in to_add"
        );
        assert!(to_remove.is_empty());
    }
}