tycho_client/feed/
component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6    models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
11#[derive(Clone, Debug)]
12pub(crate) enum ComponentFilterVariant {
13    Ids(Vec<ComponentId>),
14    /// MinimumTVLRange is a tuple of (remove_tvl_threshold, add_tvl_threshold). Components that
15    /// drop below the remove threshold will be removed from tracking, components that exceed the
16    /// add threshold will be added. This helps buffer against components that fluctuate on the
17    /// tvl threshold boundary. The thresholds are denominated in native token of the chain, for
18    /// example 1 means 1 ETH on ethereum.
19    MinimumTVLRange((f64, f64)),
20}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24    variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
29    /// (TVL) threshold.
30    ///
31    /// # Arguments
32    ///
33    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
34    ///   native token of the chain.
35    #[allow(non_snake_case)] // for backwards compatibility
36    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38        ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39    }
40
41    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
42    /// from tracking.
43    ///
44    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
45    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
46    /// This approach helps to reduce fluctuations caused by components hovering around a single
47    /// threshold.
48    ///
49    /// # Arguments
50    ///
51    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
52    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
53    ///
54    /// Note: thresholds are denominated in native token of the chain.
55    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56        ComponentFilter {
57            variant: ComponentFilterVariant::MinimumTVLRange((
58                remove_tvl_threshold,
59                add_tvl_threshold,
60            )),
61        }
62    }
63
64    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
65    /// effectively filtering out all other components.
66    ///
67    /// # Arguments
68    ///
69    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
70    ///   will be tracked.
71    #[allow(non_snake_case)] // for backwards compatibility
72    pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73        ComponentFilter {
74            variant: ComponentFilterVariant::Ids(
75                ids.into_iter()
76                    .map(|id| id.to_lowercase())
77                    .collect(),
78            ),
79        }
80    }
81}
82
83/// Information about an entrypoint, including which components use it and what contracts it
84/// interacts with
85#[derive(Default)]
86pub struct EntrypointRelations {
87    /// Set of component ids for components that have this entrypoint
88    components: HashSet<ComponentId>,
89    /// Set of detected contracts for the entrypoint
90    contracts: HashSet<Address>,
91}
92
93/// Helper struct to determine which components and contracts are being tracked atm.
94pub struct ComponentTracker<R: RPCClient> {
95    chain: Chain,
96    protocol_system: ProtocolSystem,
97    filter: ComponentFilter,
98    /// We will need to request a snapshot for components/contracts that we did not emit as
99    /// snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
100    pub components: HashMap<ComponentId, ProtocolComponent>,
101    /// Map of entrypoint id to its associated components and contracts
102    entrypoints: HashMap<String, EntrypointRelations>,
103    /// Derived from tracked components. We need this if subscribed to a vm extractor because
104    /// updates are emitted on a contract level instead of a component level.
105    pub contracts: HashSet<Address>,
106    /// Client to retrieve necessary protocol components from the rpc.
107    rpc_client: R,
108}
109
110impl<R> ComponentTracker<R>
111where
112    R: RPCClient,
113{
114    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115        Self {
116            chain,
117            protocol_system: protocol_system.to_string(),
118            filter,
119            components: Default::default(),
120            contracts: Default::default(),
121            rpc_client: rpc,
122            entrypoints: Default::default(),
123        }
124    }
125
126    /// Retrieves all components that belong to the system we are streaming that have sufficient
127    /// tvl. Also detects which contracts are relevant for simulating on those components.
128    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129        let body = match &self.filter.variant {
130            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131                &self.protocol_system,
132                ids.clone(),
133                self.chain,
134            ),
135            ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136                ProtocolComponentsRequestBody::system_filtered(
137                    &self.protocol_system,
138                    Some(*upper_tvl_threshold),
139                    self.chain,
140                )
141            }
142        };
143        self.components = self
144            .rpc_client
145            .get_protocol_components_paginated(&body, 500, 4)
146            .await?
147            .protocol_components
148            .into_iter()
149            .map(|pc| (pc.id.clone(), pc))
150            .collect::<HashMap<_, _>>();
151
152        self.reinitialize_contracts();
153
154        Ok(())
155    }
156
157    /// Initialise the tracked contracts list from tracked components and their entrypoints
158    fn reinitialize_contracts(&mut self) {
159        // Add contracts from all tracked components
160        self.contracts = self
161            .components
162            .values()
163            .flat_map(|comp| comp.contract_ids.iter().cloned())
164            .collect();
165
166        // Add contracts from entrypoints that are linked to tracked components
167        let tracked_component_ids = self
168            .components
169            .keys()
170            .cloned()
171            .collect::<HashSet<_>>();
172        for entrypoint in self.entrypoints.values() {
173            if !entrypoint
174                .components
175                .is_disjoint(&tracked_component_ids)
176            {
177                self.contracts
178                    .extend(entrypoint.contracts.iter().cloned());
179            }
180        }
181    }
182
183    /// Update the tracked contracts list with contracts associated with the given components
184    fn update_contracts(&mut self, components: Vec<ComponentId>) {
185        // Only process components that are actually being tracked.
186        let mut tracked_component_ids = HashSet::new();
187
188        // Add contracts from the components
189        for comp in components {
190            if let Some(component) = self.components.get(&comp) {
191                self.contracts
192                    .extend(component.contract_ids.iter().cloned());
193                tracked_component_ids.insert(comp);
194            }
195        }
196
197        // Identify entrypoints linked to the given components
198        for entrypoint in self.entrypoints.values() {
199            if !entrypoint
200                .components
201                .is_disjoint(&tracked_component_ids)
202            {
203                self.contracts
204                    .extend(entrypoint.contracts.iter().cloned());
205            }
206        }
207    }
208
209    /// Add new components to be tracked
210    #[instrument(skip(self, new_components))]
211    pub async fn start_tracking(
212        &mut self,
213        new_components: &[&ComponentId],
214    ) -> Result<(), RPCError> {
215        if new_components.is_empty() {
216            return Ok(());
217        }
218
219        // Fetch the components
220        let request = ProtocolComponentsRequestBody::id_filtered(
221            &self.protocol_system,
222            new_components
223                .iter()
224                .map(|&id| id.to_string())
225                .collect(),
226            self.chain,
227        );
228        let components = self
229            .rpc_client
230            .get_protocol_components(&request)
231            .await?
232            .protocol_components
233            .into_iter()
234            .map(|pc| (pc.id.clone(), pc))
235            .collect::<HashMap<_, _>>();
236
237        // Update components and contracts
238        let component_ids: Vec<_> = components.keys().cloned().collect();
239        let component_count = component_ids.len();
240        self.components.extend(components);
241        self.update_contracts(component_ids);
242
243        debug!(n_components = component_count, "StartedTracking");
244        Ok(())
245    }
246
247    /// Stop tracking components
248    #[instrument(skip(self, to_remove))]
249    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
250        &mut self,
251        to_remove: I,
252    ) -> HashMap<ComponentId, ProtocolComponent> {
253        let mut removed_components = HashMap::new();
254
255        for component_id in to_remove {
256            if let Some(component) = self.components.remove(component_id) {
257                removed_components.insert(component_id.clone(), component);
258            }
259        }
260
261        // Refresh the tracked contracts list. This is more reliable and efficient than directly
262        // removing contracts from the list because some contracts are shared between components.
263        self.reinitialize_contracts();
264
265        debug!(n_components = removed_components.len(), "StoppedTracking");
266        removed_components
267    }
268
269    /// Updates the tracked entrypoints and contracts based on the given DCI data.
270    pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
271        // Update detected contracts for entrypoints
272        for (entrypoint, traces) in &dci_update.trace_results {
273            self.entrypoints
274                .entry(entrypoint.clone())
275                .or_default()
276                .contracts
277                .extend(traces.accessed_slots.keys().cloned());
278        }
279
280        // Update linked components for entrypoints
281        for (component, entrypoints) in &dci_update.new_entrypoints {
282            for entrypoint in entrypoints {
283                let entrypoint_info = self
284                    .entrypoints
285                    .entry(entrypoint.external_id.clone())
286                    .or_default();
287                entrypoint_info
288                    .components
289                    .insert(component.clone());
290                // If the component is tracked, add the detected contracts to the tracker
291                if self.components.contains_key(component) {
292                    self.contracts.extend(
293                        entrypoint_info
294                            .contracts
295                            .iter()
296                            .cloned(),
297                    );
298                }
299            }
300        }
301    }
302
303    /// Get related contracts for the given component ids. Assumes that the components are already
304    /// tracked, either by calling `start_tracking` or `initialise_components`.
305    ///
306    /// # Arguments
307    ///
308    /// * `ids` - A vector of component IDs to get the contracts for.
309    ///
310    /// # Returns
311    ///
312    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
313    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
314        &self,
315        ids: I,
316    ) -> HashSet<Address> {
317        ids.into_iter()
318            .filter_map(|cid| {
319                if let Some(comp) = self.components.get(cid) {
320                    // Collect contracts from all entrypoints linked to this component
321                    let dci_contracts: HashSet<Address> = self
322                        .entrypoints
323                        .values()
324                        .filter(|ep| ep.components.contains(cid))
325                        .flat_map(|ep| ep.contracts.iter().cloned())
326                        .collect();
327                    Some(
328                        comp.contract_ids
329                            .clone()
330                            .into_iter()
331                            .chain(dci_contracts)
332                            .collect::<HashSet<_>>(),
333                    )
334                } else {
335                    warn!(
336                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
337                    );
338                    None
339                }
340            })
341            .flatten()
342            .collect()
343    }
344
345    pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
346        self.components
347            .keys()
348            .cloned()
349            .collect()
350    }
351
352    /// Given BlockChanges, filter out components that are no longer relevant and return the
353    /// components that need to be added or removed.
354    pub fn filter_updated_components(
355        &self,
356        deltas: &BlockChanges,
357    ) -> (Vec<ComponentId>, Vec<ComponentId>) {
358        match &self.filter.variant {
359            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
360            ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
361                .component_tvl
362                .iter()
363                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
364                .map(|(id, _)| id.clone())
365                .partition(|id| deltas.component_tvl[id] > *add_tvl),
366        }
367    }
368}
369
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker for "uniswap-v2" on Ethereum with a `(0.0, 0.0)` TVL range and a fresh mock RPC
    /// client without any expectations.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0);
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    /// The fixture component "Component1" together with its two contract ids.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    /// RPC response page that contains exactly the given component.
    fn single_component_page(component: &ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component.clone()],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let mocked = component.clone();
        let mut tracker = with_mocked_rpc();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_component_page(&mocked)));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &component);
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let mut tracker = with_mocked_rpc();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_component_page(&component)));

        let to_track = [&component_id];
        tracker
            .start_tracking(&to_track)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let (contract_ids, component) = components_response();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);

        // "Component2" is not tracked and must simply be skipped.
        let to_remove = ["Component1".to_string(), "Component2".to_string()];
        let mut expected = HashMap::new();
        expected.insert("Component1".to_string(), component);

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let (contract_ids, component) = components_response();
        let expected: HashSet<Address> = contract_ids.into_iter().collect();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let requested = ["Component1".to_string()];
        let contracts = tracker.get_contracts_by_component(&requested);

        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let (_, component) = components_response();
        let mut tracker = with_mocked_rpc();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let tracked_ids = tracker.get_tracked_component_ids();

        assert_eq!(tracked_ids, vec!["Component1".to_string()]);
    }
}