// tycho_client/feed/component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::models::{
5    blockchain::{BlockAggregatedChanges, DCIUpdate},
6    protocol::ProtocolComponent,
7    Address, Chain, ComponentId, ProtocolSystem,
8};
9
10use crate::{
11    rpc::{ProtocolComponentsPaginatedParams, RPCClient, RPC_CLIENT_CONCURRENCY},
12    RPCError,
13};
14
15#[derive(Clone, Debug)]
16pub(crate) enum ComponentFilterVariant {
17    Ids(Vec<ComponentId>),
18    /// MinimumTVLRange filters components by TVL thresholds:
19    /// - `range.0` (remove threshold): components below this are removed from tracking
20    /// - `range.1` (add threshold): components above this are added to tracking
21    ///
22    /// This helps buffer against components that fluctuate on the threshold boundary.
23    /// Thresholds are denominated in native token of the chain, for example 1 means 1 ETH on
24    /// ethereum.
25    MinimumTVLRange {
26        range: (f64, f64),
27        blocklisted_ids: HashSet<ComponentId>,
28    },
29}
30
31#[derive(Clone, Debug)]
32pub struct ComponentFilter {
33    variant: ComponentFilterVariant,
34}
35
36impl ComponentFilter {
37    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
38    /// (TVL) threshold.
39    ///
40    /// # Arguments
41    ///
42    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
43    ///   native token of the chain.
44    #[allow(non_snake_case)] // for backwards compatibility
45    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
46    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
47        ComponentFilter {
48            variant: ComponentFilterVariant::MinimumTVLRange {
49                range: (min_tvl, min_tvl),
50                blocklisted_ids: HashSet::new(),
51            },
52        }
53    }
54
55    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
56    /// from tracking.
57    ///
58    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
59    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
60    /// This approach helps to reduce fluctuations caused by components hovering around a single
61    /// threshold.
62    ///
63    /// # Arguments
64    ///
65    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
66    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
67    ///
68    /// Note: thresholds are denominated in native token of the chain.
69    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
70        ComponentFilter {
71            variant: ComponentFilterVariant::MinimumTVLRange {
72                range: (remove_tvl_threshold, add_tvl_threshold),
73                blocklisted_ids: HashSet::new(),
74            },
75        }
76    }
77
78    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
79    /// effectively filtering out all other components.
80    ///
81    /// # Arguments
82    ///
83    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
84    ///   will be tracked.
85    #[allow(non_snake_case)] // for backwards compatibility
86    pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
87        ComponentFilter {
88            variant: ComponentFilterVariant::Ids(
89                ids.into_iter()
90                    .map(|id| id.to_lowercase())
91                    .collect(),
92            ),
93        }
94    }
95
96    /// Blocklist specific component IDs from tracking regardless of other
97    /// filter criteria. IDs are normalized to lowercase.
98    ///
99    /// Has no effect when the filter variant is `Ids`, since the
100    /// inclusion list already defines exactly which components to track.
101    pub fn blocklist(mut self, ids: impl IntoIterator<Item = ComponentId>) -> Self {
102        match &mut self.variant {
103            ComponentFilterVariant::Ids(_) => {
104                warn!(
105                    "blocklist() has no effect on ComponentFilter::Ids; \
106                     remove the component from the ID list instead"
107                );
108            }
109            ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
110                blocklisted_ids.extend(
111                    ids.into_iter()
112                        .map(|id| id.to_lowercase()),
113                );
114            }
115        }
116        self
117    }
118
119    /// Returns true if the given component ID is blocklisted.
120    pub fn is_blocklisted(&self, id: &str) -> bool {
121        match &self.variant {
122            ComponentFilterVariant::Ids(_) => false,
123            ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
124                blocklisted_ids.contains(&id.to_lowercase())
125            }
126        }
127    }
128}
129
130/// Information about an entrypoint, including which components use it and what contracts it
131/// interacts with
132#[derive(Default)]
133struct EntrypointRelations {
134    /// Set of component ids for components that have this entrypoint
135    components: HashSet<ComponentId>,
136    /// Set of detected contracts for the entrypoint
137    contracts: HashSet<Address>,
138}
139
140/// Helper struct to determine which components and contracts are being tracked atm.
141pub struct ComponentTracker<R: RPCClient> {
142    chain: Chain,
143    protocol_system: ProtocolSystem,
144    filter: ComponentFilter,
145    /// We will need to request a snapshot for components/contracts that we did not emit as
146    /// snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
147    pub components: HashMap<ComponentId, ProtocolComponent>,
148    /// Map of entrypoint id to its associated components and contracts
149    entrypoints: HashMap<String, EntrypointRelations>,
150    /// Derived from tracked components. We need this if subscribed to a vm extractor because
151    /// updates are emitted on a contract level instead of a component level.
152    pub contracts: HashSet<Address>,
153    /// Client to retrieve necessary protocol components from the rpc.
154    rpc_client: R,
155}
156
157impl<R> ComponentTracker<R>
158where
159    R: RPCClient,
160{
161    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
162        Self {
163            chain,
164            protocol_system: protocol_system.to_string(),
165            filter,
166            components: Default::default(),
167            contracts: Default::default(),
168            rpc_client: rpc,
169            entrypoints: Default::default(),
170        }
171    }
172
173    /// Retrieves all components that belong to the system we are streaming that have sufficient
174    /// tvl. Also detects which contracts are relevant for simulating on those components.
175    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
176        let (component_ids, tvl_gt) = match &self.filter.variant {
177            ComponentFilterVariant::Ids(ids) => (Some(ids.clone()), None),
178            ComponentFilterVariant::MinimumTVLRange { range: (_, upper_tvl_threshold), .. } => {
179                (None, Some(*upper_tvl_threshold))
180            }
181        };
182        let mut paginated_params = ProtocolComponentsPaginatedParams::new(
183            self.chain,
184            self.protocol_system.as_str(),
185            RPC_CLIENT_CONCURRENCY,
186        );
187        if let Some(ids) = component_ids {
188            paginated_params = paginated_params.with_component_ids(ids);
189        }
190        if let Some(tvl) = tvl_gt {
191            paginated_params = paginated_params.with_tvl_gt(tvl);
192        }
193
194        self.components = self
195            .rpc_client
196            .get_protocol_components_paginated(paginated_params)
197            .await?
198            .into_iter()
199            .map(|comp| (comp.id.clone(), comp))
200            .filter(|(id, _)| !self.filter.is_blocklisted(id))
201            .collect::<HashMap<_, _>>();
202
203        self.reinitialize_contracts();
204
205        Ok(())
206    }
207
208    /// Initialise the tracked contracts list from tracked components and their entrypoints
209    fn reinitialize_contracts(&mut self) {
210        // Add contracts from all tracked components
211        self.contracts = self
212            .components
213            .values()
214            .flat_map(|comp| comp.contract_addresses.iter().cloned())
215            .collect();
216
217        // Add contracts from entrypoints that are linked to tracked components
218        let tracked_component_ids = self
219            .components
220            .keys()
221            .cloned()
222            .collect::<HashSet<_>>();
223        for entrypoint in self.entrypoints.values() {
224            if !entrypoint
225                .components
226                .is_disjoint(&tracked_component_ids)
227            {
228                self.contracts
229                    .extend(entrypoint.contracts.iter().cloned());
230            }
231        }
232    }
233
234    /// Update the tracked contracts list with contracts associated with the given components
235    pub(crate) fn update_contracts(&mut self, components: Vec<ComponentId>) {
236        // Only process components that are actually being tracked.
237        let mut tracked_component_ids = HashSet::new();
238
239        // Add contracts from the components
240        for comp in components {
241            if let Some(component) = self.components.get(&comp) {
242                self.contracts.extend(
243                    component
244                        .contract_addresses
245                        .iter()
246                        .cloned(),
247                );
248                tracked_component_ids.insert(comp);
249            }
250        }
251
252        // Identify entrypoints linked to the given components
253        for entrypoint in self.entrypoints.values() {
254            if !entrypoint
255                .components
256                .is_disjoint(&tracked_component_ids)
257            {
258                self.contracts
259                    .extend(entrypoint.contracts.iter().cloned());
260            }
261        }
262    }
263
264    /// Add new components to be tracked
265    #[instrument(skip(self, new_components))]
266    pub async fn start_tracking(
267        &mut self,
268        new_components: &[&ComponentId],
269    ) -> Result<(), RPCError> {
270        let new_components: Vec<_> = new_components
271            .iter()
272            .filter(|id| !self.filter.is_blocklisted(id))
273            .copied()
274            .collect();
275
276        if new_components.is_empty() {
277            return Ok(());
278        }
279
280        let components = self
281            .rpc_client
282            .get_protocol_components(
283                crate::rpc::ProtocolComponentsParams::new(
284                    self.chain,
285                    self.protocol_system.as_str(),
286                )
287                .with_component_ids(
288                    new_components
289                        .into_iter()
290                        .cloned()
291                        .collect(),
292                ),
293            )
294            .await?
295            .into_iter()
296            .map(|comp| (comp.id.clone(), comp))
297            .collect::<HashMap<_, _>>();
298
299        // Update components and contracts
300        let component_ids: Vec<_> = components.keys().cloned().collect();
301        let component_count = component_ids.len();
302        self.components.extend(components);
303        self.update_contracts(component_ids);
304
305        debug!(n_components = component_count, "StartedTracking");
306        Ok(())
307    }
308
309    /// Stop tracking components
310    #[instrument(skip(self, to_remove))]
311    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
312        &mut self,
313        to_remove: I,
314    ) -> HashMap<ComponentId, ProtocolComponent> {
315        let mut removed_components = HashMap::new();
316
317        for component_id in to_remove {
318            if let Some(component) = self.components.remove(component_id) {
319                removed_components.insert(component_id.clone(), component);
320            }
321        }
322
323        // Refresh the tracked contracts list. This is more reliable and efficient than directly
324        // removing contracts from the list because some contracts are shared between components.
325        self.reinitialize_contracts();
326
327        debug!(n_components = removed_components.len(), "StoppedTracking");
328        removed_components
329    }
330
331    /// Updates the tracked entrypoints and contracts based on the given DCI data.
332    pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
333        // Update detected contracts for entrypoints
334        for (entrypoint, traces) in &dci_update.trace_results {
335            self.entrypoints
336                .entry(entrypoint.clone())
337                .or_default()
338                .contracts
339                .extend(traces.accessed_slots.keys().cloned());
340        }
341
342        // Update linked components for entrypoints
343        for (component, entrypoints) in &dci_update.new_entrypoints {
344            for entrypoint in entrypoints {
345                let entrypoint_info = self
346                    .entrypoints
347                    .entry(entrypoint.external_id.clone())
348                    .or_default();
349                entrypoint_info
350                    .components
351                    .insert(component.clone());
352                // If the component is tracked, add the detected contracts to the tracker
353                if self.components.contains_key(component) {
354                    self.contracts.extend(
355                        entrypoint_info
356                            .contracts
357                            .iter()
358                            .cloned(),
359                    );
360                }
361            }
362        }
363    }
364
365    /// Get related contracts for the given component ids. Assumes that the components are already
366    /// tracked, either by calling `start_tracking` or `initialise_components`.
367    ///
368    /// # Arguments
369    ///
370    /// * `ids` - A vector of component IDs to get the contracts for.
371    ///
372    /// # Returns
373    ///
374    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
375    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
376        &self,
377        ids: I,
378    ) -> HashSet<Address> {
379        ids.into_iter()
380            .filter_map(|cid| {
381                if let Some(comp) = self.components.get(cid) {
382                    // Collect contracts from all entrypoints linked to this component
383                    let dci_contracts: HashSet<Address> = self
384                        .entrypoints
385                        .values()
386                        .filter(|ep| ep.components.contains(cid))
387                        .flat_map(|ep| ep.contracts.iter().cloned())
388                        .collect();
389                    Some(
390                        comp.contract_addresses
391                            .clone()
392                            .into_iter()
393                            .chain(dci_contracts)
394                            .collect::<HashSet<_>>(),
395                    )
396                } else {
397                    warn!(
398                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
399                    );
400                    None
401                }
402            })
403            .flatten()
404            .collect()
405    }
406
407    pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
408        self.components
409            .keys()
410            .cloned()
411            .collect()
412    }
413
414    /// Given BlockAggregatedChanges, filter out components that are no longer relevant and return
415    /// the components that need to be added or removed.
416    pub fn filter_updated_components(
417        &self,
418        deltas: &BlockAggregatedChanges,
419    ) -> (Vec<ComponentId>, Vec<ComponentId>) {
420        match &self.filter.variant {
421            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
422            ComponentFilterVariant::MinimumTVLRange { range: (remove_tvl, add_tvl), .. } => {
423                let (mut to_add, mut to_remove): (Vec<_>, Vec<_>) = deltas
424                    .component_tvl
425                    .iter()
426                    .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
427                    .map(|(id, _)| id.clone())
428                    .partition(|id| deltas.component_tvl[id] > *add_tvl);
429
430                // Never add blocklisted components
431                to_add.retain(|id| !self.filter.is_blocklisted(id));
432
433                // Remove any currently tracked components that are now blocklisted
434                for id in self.components.keys() {
435                    if self.filter.is_blocklisted(id) && !to_remove.contains(id) {
436                        to_remove.push(id.clone());
437                    }
438                }
439
440                (to_add, to_remove)
441            }
442        }
443    }
444}
445
#[cfg(test)]
mod test {
    use tycho_common::Bytes;

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker with a zero TVL range, no blocklist and a mock rpc client without expectations.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0);
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    /// A sample component together with the contract addresses it references.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_addresses = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_addresses: contract_addresses.clone(),
            ..Default::default()
        };
        (contract_addresses, component)
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let (contract_ids, exp_component) = components_response();
        let exp_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let rpc_component = exp_component.clone();
        let mut tracker = with_mocked_rpc();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_| Ok(vec![rpc_component.clone()]));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &exp_component);
        assert_eq!(tracker.contracts, exp_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let (contract_ids, rpc_component) = components_response();
        let exp_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let component_id = rpc_component.id.clone();
        let mut tracker = with_mocked_rpc();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(crate::rpc::Page::new(vec![rpc_component.clone()], 1, 0, 100)));

        tracker
            .start_tracking(&[&component_id])
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let (contract_ids, model_component) = components_response();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), model_component.clone());
        tracker.contracts.extend(contract_ids);
        // "Component2" is not tracked and must simply be ignored.
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp: HashMap<ComponentId, ProtocolComponent> =
            HashMap::from([("Component1".to_string(), model_component)]);

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let (contract_ids, model_component) = components_response();
        let exp_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), model_component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let (_, model_component) = components_response();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), model_component);

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, vec!["Component1".to_string()]);
    }

    /// Same as `with_mocked_rpc`, but with the given ids blocklisted on the filter.
    fn with_mocked_rpc_and_blocklist(blocklisted: Vec<&str>) -> ComponentTracker<MockRPCClient> {
        let blocklisted_ids = blocklisted
            .into_iter()
            .map(String::from);
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0).blocklist(blocklisted_ids);
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    #[tokio::test]
    async fn test_initialise_skips_blocklisted_components() {
        // The blocklist entry is lowercase while the returned component id is "Component1".
        let (_, rpc_component) = components_response();
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_| Ok(vec![rpc_component.clone()]));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be in tracker");
    }

    #[tokio::test]
    async fn test_start_tracking_skips_blocklisted() {
        // No rpc expectation is registered on the mock for this test.
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let component_id = "Component1".to_string();

        tracker
            .start_tracking(&[&component_id])
            .await
            .expect("start_tracking should succeed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be tracked");
    }

    #[test]
    fn test_filter_updated_blocks_blocklisted_add() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["blocklisted_pool"]);
        // Swap in a filter with a non-trivial TVL range, keeping the same blocklist.
        tracker.filter = ComponentFilter::with_tvl_range(5.0, 10.0)
            .blocklist(vec!["blocklisted_pool".to_string()]);
        let blocklisted_id = "blocklisted_pool".to_string();
        let allowed_id = "allowed_pool".to_string();

        // Both pools are far above the add threshold.
        let deltas = BlockAggregatedChanges {
            component_tvl: HashMap::from([
                (blocklisted_id.clone(), 100.0),
                (allowed_id.clone(), 100.0),
            ]),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        assert!(
            !to_add.contains(&blocklisted_id),
            "Blocklisted component should not be in to_add"
        );
        assert!(to_add.contains(&allowed_id), "Non-blocklisted component should be in to_add");
        assert!(to_remove.is_empty());
    }
}