1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{
10 rpc::{RPCClient, RPC_CLIENT_CONCURRENCY},
11 RPCError,
12};
13
/// Internal representation of a [`ComponentFilter`].
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly these component IDs (lowercased by `ComponentFilter::Ids`).
    Ids(Vec<ComponentId>),
    /// Track components based on their TVL, with separate remove/add thresholds.
    MinimumTVLRange {
        /// `(remove_tvl_threshold, add_tvl_threshold)`: a component is scheduled for removal
        /// when its TVL falls below the first value and for addition when it rises above the
        /// second.
        range: (f64, f64),
        /// Lowercased component IDs that must never be tracked.
        blocklisted_ids: HashSet<ComponentId>,
    },
}
29
/// Decides which protocol components a [`ComponentTracker`] follows.
///
/// Construct with [`ComponentFilter::with_tvl_range`] or [`ComponentFilter::Ids`]. TVL-based
/// filters can additionally exclude components via [`ComponentFilter::blocklist`].
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    variant: ComponentFilterVariant,
}
34
35impl ComponentFilter {
36 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
45 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
46 ComponentFilter {
47 variant: ComponentFilterVariant::MinimumTVLRange {
48 range: (min_tvl, min_tvl),
49 blocklisted_ids: HashSet::new(),
50 },
51 }
52 }
53
54 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
69 ComponentFilter {
70 variant: ComponentFilterVariant::MinimumTVLRange {
71 range: (remove_tvl_threshold, add_tvl_threshold),
72 blocklisted_ids: HashSet::new(),
73 },
74 }
75 }
76
77 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
86 ComponentFilter {
87 variant: ComponentFilterVariant::Ids(
88 ids.into_iter()
89 .map(|id| id.to_lowercase())
90 .collect(),
91 ),
92 }
93 }
94
95 pub fn blocklist(mut self, ids: impl IntoIterator<Item = ComponentId>) -> Self {
101 match &mut self.variant {
102 ComponentFilterVariant::Ids(_) => {
103 warn!(
104 "blocklist() has no effect on ComponentFilter::Ids; \
105 remove the component from the ID list instead"
106 );
107 }
108 ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
109 blocklisted_ids.extend(
110 ids.into_iter()
111 .map(|id| id.to_lowercase()),
112 );
113 }
114 }
115 self
116 }
117
118 pub fn is_blocklisted(&self, id: &str) -> bool {
120 match &self.variant {
121 ComponentFilterVariant::Ids(_) => false,
122 ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
123 blocklisted_ids.contains(&id.to_lowercase())
124 }
125 }
126 }
127}
128
/// What the tracker knows about a single DCI entrypoint.
#[derive(Default)]
struct EntrypointRelations {
    /// Components linked to this entrypoint (from `DCIUpdate::new_entrypoints`).
    components: HashSet<ComponentId>,
    /// Contract addresses recorded from the `accessed_slots` keys of this entrypoint's trace
    /// results.
    contracts: HashSet<Address>,
}
138
/// Tracks the protocol components, and the contracts they depend on, that a client follows
/// for one protocol system on one chain.
pub struct ComponentTracker<R: RPCClient> {
    chain: Chain,
    protocol_system: ProtocolSystem,
    /// Decides which components are tracked.
    filter: ComponentFilter,
    /// Tracked components, keyed by component ID.
    pub components: HashMap<ComponentId, ProtocolComponent>,
    /// DCI entrypoint relations, keyed by entrypoint ID (the `trace_results` key or the
    /// entrypoint's `external_id`).
    entrypoints: HashMap<String, EntrypointRelations>,
    /// Contracts of the tracked components' `contract_ids`, plus the contracts of entrypoints
    /// linked to at least one tracked component.
    pub contracts: HashSet<Address>,
    /// Client used to fetch component data.
    rpc_client: R,
}
155
156impl<R> ComponentTracker<R>
157where
158 R: RPCClient,
159{
160 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
161 Self {
162 chain,
163 protocol_system: protocol_system.to_string(),
164 filter,
165 components: Default::default(),
166 contracts: Default::default(),
167 rpc_client: rpc,
168 entrypoints: Default::default(),
169 }
170 }
171
172 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
175 let body = match &self.filter.variant {
176 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
177 &self.protocol_system,
178 ids.clone(),
179 self.chain,
180 ),
181 ComponentFilterVariant::MinimumTVLRange { range: (_, upper_tvl_threshold), .. } => {
182 ProtocolComponentsRequestBody::system_filtered(
183 &self.protocol_system,
184 Some(*upper_tvl_threshold),
185 self.chain,
186 )
187 }
188 };
189 self.components = self
190 .rpc_client
191 .get_protocol_components_paginated(&body, None, RPC_CLIENT_CONCURRENCY)
192 .await?
193 .protocol_components
194 .into_iter()
195 .map(|pc| (pc.id.clone(), pc))
196 .filter(|(id, _)| !self.filter.is_blocklisted(id))
197 .collect::<HashMap<_, _>>();
198
199 self.reinitialize_contracts();
200
201 Ok(())
202 }
203
204 fn reinitialize_contracts(&mut self) {
206 self.contracts = self
208 .components
209 .values()
210 .flat_map(|comp| comp.contract_ids.iter().cloned())
211 .collect();
212
213 let tracked_component_ids = self
215 .components
216 .keys()
217 .cloned()
218 .collect::<HashSet<_>>();
219 for entrypoint in self.entrypoints.values() {
220 if !entrypoint
221 .components
222 .is_disjoint(&tracked_component_ids)
223 {
224 self.contracts
225 .extend(entrypoint.contracts.iter().cloned());
226 }
227 }
228 }
229
230 fn update_contracts(&mut self, components: Vec<ComponentId>) {
232 let mut tracked_component_ids = HashSet::new();
234
235 for comp in components {
237 if let Some(component) = self.components.get(&comp) {
238 self.contracts
239 .extend(component.contract_ids.iter().cloned());
240 tracked_component_ids.insert(comp);
241 }
242 }
243
244 for entrypoint in self.entrypoints.values() {
246 if !entrypoint
247 .components
248 .is_disjoint(&tracked_component_ids)
249 {
250 self.contracts
251 .extend(entrypoint.contracts.iter().cloned());
252 }
253 }
254 }
255
256 #[instrument(skip(self, new_components))]
258 pub async fn start_tracking(
259 &mut self,
260 new_components: &[&ComponentId],
261 ) -> Result<(), RPCError> {
262 let new_components: Vec<_> = new_components
263 .iter()
264 .filter(|id| !self.filter.is_blocklisted(id))
265 .copied()
266 .collect();
267
268 if new_components.is_empty() {
269 return Ok(());
270 }
271
272 let request = ProtocolComponentsRequestBody::id_filtered(
274 &self.protocol_system,
275 new_components
276 .iter()
277 .map(|&id| id.to_string())
278 .collect(),
279 self.chain,
280 );
281 let components = self
282 .rpc_client
283 .get_protocol_components(&request)
284 .await?
285 .protocol_components
286 .into_iter()
287 .map(|pc| (pc.id.clone(), pc))
288 .collect::<HashMap<_, _>>();
289
290 let component_ids: Vec<_> = components.keys().cloned().collect();
292 let component_count = component_ids.len();
293 self.components.extend(components);
294 self.update_contracts(component_ids);
295
296 debug!(n_components = component_count, "StartedTracking");
297 Ok(())
298 }
299
300 #[instrument(skip(self, to_remove))]
302 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
303 &mut self,
304 to_remove: I,
305 ) -> HashMap<ComponentId, ProtocolComponent> {
306 let mut removed_components = HashMap::new();
307
308 for component_id in to_remove {
309 if let Some(component) = self.components.remove(component_id) {
310 removed_components.insert(component_id.clone(), component);
311 }
312 }
313
314 self.reinitialize_contracts();
317
318 debug!(n_components = removed_components.len(), "StoppedTracking");
319 removed_components
320 }
321
322 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
324 for (entrypoint, traces) in &dci_update.trace_results {
326 self.entrypoints
327 .entry(entrypoint.clone())
328 .or_default()
329 .contracts
330 .extend(traces.accessed_slots.keys().cloned());
331 }
332
333 for (component, entrypoints) in &dci_update.new_entrypoints {
335 for entrypoint in entrypoints {
336 let entrypoint_info = self
337 .entrypoints
338 .entry(entrypoint.external_id.clone())
339 .or_default();
340 entrypoint_info
341 .components
342 .insert(component.clone());
343 if self.components.contains_key(component) {
345 self.contracts.extend(
346 entrypoint_info
347 .contracts
348 .iter()
349 .cloned(),
350 );
351 }
352 }
353 }
354 }
355
356 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
367 &self,
368 ids: I,
369 ) -> HashSet<Address> {
370 ids.into_iter()
371 .filter_map(|cid| {
372 if let Some(comp) = self.components.get(cid) {
373 let dci_contracts: HashSet<Address> = self
375 .entrypoints
376 .values()
377 .filter(|ep| ep.components.contains(cid))
378 .flat_map(|ep| ep.contracts.iter().cloned())
379 .collect();
380 Some(
381 comp.contract_ids
382 .clone()
383 .into_iter()
384 .chain(dci_contracts)
385 .collect::<HashSet<_>>(),
386 )
387 } else {
388 warn!(
389 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
390 );
391 None
392 }
393 })
394 .flatten()
395 .collect()
396 }
397
398 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
399 self.components
400 .keys()
401 .cloned()
402 .collect()
403 }
404
405 pub fn filter_updated_components(
408 &self,
409 deltas: &BlockChanges,
410 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
411 match &self.filter.variant {
412 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
413 ComponentFilterVariant::MinimumTVLRange { range: (remove_tvl, add_tvl), .. } => {
414 let (mut to_add, mut to_remove): (Vec<_>, Vec<_>) = deltas
415 .component_tvl
416 .iter()
417 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
418 .map(|(id, _)| id.clone())
419 .partition(|id| deltas.component_tvl[id] > *add_tvl);
420
421 to_add.retain(|id| !self.filter.is_blocklisted(id));
423
424 for id in self.components.keys() {
426 if self.filter.is_blocklisted(id) && !to_remove.contains(id) {
427 to_remove.push(id.clone());
428 }
429 }
430
431 (to_add, to_remove)
432 }
433 }
434 }
435}
436
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    // Tracker for `uniswap-v2` on Ethereum. It is backed by a fresh mock RPC client with no
    // expectations set, and a (0.0, 0.0) TVL range filter with an empty blocklist. Tests attach
    // their own RPC expectations.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    // Fixture: component `Component1` with two contract IDs. The IDs are also returned so tests
    // can build the expected contract set.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    // Initialisation stores the component returned by the paginated request, and derives the
    // contract set from its `contract_ids`.
    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    // start_tracking fetches the requested component by ID and adds its contracts.
    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    // Stopping one tracked and one untracked ID returns only the tracked component, and leaves
    // no contracts behind.
    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    // With no entrypoints, a tracked component's contracts are exactly its `contract_ids`.
    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    // The tracked IDs are the keys of `components`.
    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }

    // Like `with_mocked_rpc`, but the (0.0, 0.0) TVL filter also blocklists the given IDs.
    fn with_mocked_rpc_and_blocklist(blocklisted: Vec<&str>) -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0).blocklist(
            blocklisted
                .into_iter()
                .map(String::from),
        );
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, rpc)
    }

    // The blocklist entry is lowercase while the component ID is `Component1`, so this also
    // exercises case-insensitive blocklist matching during initialisation.
    #[tokio::test]
    async fn test_initialise_skips_blocklisted_components() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let (_, component) = components_response();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be in tracker");
    }

    // No RPC expectation is set on purpose. The only requested ID is blocklisted, so
    // start_tracking returns before making any request. An unexpected mock call would
    // presumably fail the test.
    #[tokio::test]
    async fn test_start_tracking_skips_blocklisted() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let component_id = "Component1".to_string();
        let components_arg = [&component_id];

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("start_tracking should succeed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be tracked");
    }

    // Both pools are above the add threshold (10.0); only the non-blocklisted one may be added.
    // NOTE(review): the filter built by the helper is immediately replaced below. The helper
    // effectively only supplies the mock RPC client.
    #[test]
    fn test_filter_updated_blocks_blocklisted_add() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["blocklisted_pool"]);
        tracker.filter = ComponentFilter::with_tvl_range(5.0, 10.0)
            .blocklist(vec!["blocklisted_pool".to_string()]);

        let deltas = BlockChanges {
            component_tvl: HashMap::from([
                ("blocklisted_pool".to_string(), 100.0),
                ("allowed_pool".to_string(), 100.0),
            ]),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
        assert!(
            !to_add.contains(&"blocklisted_pool".to_string()),
            "Blocklisted component should not be in to_add"
        );
        assert!(
            to_add.contains(&"allowed_pool".to_string()),
            "Non-blocklisted component should be in to_add"
        );
        assert!(to_remove.is_empty());
    }
}