tycho_client/feed/component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6    models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Internal representation of the strategies a [`ComponentFilter`] can use to select components.
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly the listed component ids. When built through `ComponentFilter::Ids` the ids
    /// are stored lowercased.
    Ids(Vec<ComponentId>),
    /// MinimumTVLRange is a tuple of (remove_tvl_threshold, add_tvl_threshold). Components that
    /// drop below the remove threshold will be removed from tracking, components that exceed the
    /// add threshold will be added. This helps buffer against components that fluctuate on the
    /// tvl threshold boundary. The thresholds are denominated in native token of the chain, for
    /// example 1 means 1 ETH on ethereum.
    MinimumTVLRange((f64, f64)),
}
21
/// Filter that decides which components of a protocol system are tracked.
///
/// Construct it with [`ComponentFilter::with_tvl_range`] (TVL based, with a buffer band) or
/// [`ComponentFilter::Ids`] (explicit list of component ids).
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    /// The concrete filtering strategy; kept private so callers go through the constructors.
    variant: ComponentFilterVariant,
}
26
27impl ComponentFilter {
28    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
29    /// (TVL) threshold.
30    ///
31    /// # Arguments
32    ///
33    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
34    ///   native token of the chain.
35    #[allow(non_snake_case)] // for backwards compatibility
36    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38        ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39    }
40
41    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
42    /// from tracking.
43    ///
44    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
45    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
46    /// This approach helps to reduce fluctuations caused by components hovering around a single
47    /// threshold.
48    ///
49    /// # Arguments
50    ///
51    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
52    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
53    ///
54    /// Note: thresholds are denominated in native token of the chain.
55    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56        ComponentFilter {
57            variant: ComponentFilterVariant::MinimumTVLRange((
58                remove_tvl_threshold,
59                add_tvl_threshold,
60            )),
61        }
62    }
63
64    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
65    /// effectively filtering out all other components.
66    ///
67    /// # Arguments
68    ///
69    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
70    ///   will be tracked.
71    #[allow(non_snake_case)] // for backwards compatibility
72    pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73        ComponentFilter {
74            variant: ComponentFilterVariant::Ids(
75                ids.into_iter()
76                    .map(|id| id.to_lowercase())
77                    .collect(),
78            ),
79        }
80    }
81}
82
/// Information about an entrypoint, including which components use it and which contracts it
/// interacts with.
///
/// Entries are created on demand from DCI updates, so either set may still be empty (e.g. an
/// entrypoint that was traced before any component was linked to it).
#[derive(Default)]
struct EntrypointRelations {
    /// Set of component ids for components that have this entrypoint
    components: HashSet<ComponentId>,
    /// Set of detected contracts for the entrypoint (the accessed-slot contract addresses reported
    /// in DCI trace results)
    contracts: HashSet<Address>,
}
92
/// Helper struct to determine which components and contracts are currently being tracked.
pub struct ComponentTracker<R: RPCClient> {
    /// Chain the tracked components live on.
    chain: Chain,
    /// Protocol system whose components are tracked.
    protocol_system: ProtocolSystem,
    /// Filter deciding which components should be tracked.
    filter: ComponentFilter,
    /// Currently tracked components, keyed by component id.
    ///
    /// We will need to request a snapshot for components/contracts that we did not emit as
    /// snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
    pub components: HashMap<ComponentId, ProtocolComponent>,
    /// Map of entrypoint id to its associated components and contracts
    entrypoints: HashMap<String, EntrypointRelations>,
    /// Derived from tracked components (their `contract_ids`) and from the detected contracts of
    /// entrypoints linked to tracked components. We need this if subscribed to a vm extractor
    /// because updates are emitted on a contract level instead of a component level.
    pub contracts: HashSet<Address>,
    /// Client to retrieve necessary protocol components from the rpc.
    rpc_client: R,
}
109
110impl<R> ComponentTracker<R>
111where
112    R: RPCClient,
113{
114    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115        Self {
116            chain,
117            protocol_system: protocol_system.to_string(),
118            filter,
119            components: Default::default(),
120            contracts: Default::default(),
121            rpc_client: rpc,
122            entrypoints: Default::default(),
123        }
124    }
125
126    /// Retrieves all components that belong to the system we are streaming that have sufficient
127    /// tvl. Also detects which contracts are relevant for simulating on those components.
128    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129        let body = match &self.filter.variant {
130            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131                &self.protocol_system,
132                ids.clone(),
133                self.chain,
134            ),
135            ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136                ProtocolComponentsRequestBody::system_filtered(
137                    &self.protocol_system,
138                    Some(*upper_tvl_threshold),
139                    self.chain,
140                )
141            }
142        };
143        self.components = self
144            .rpc_client
145            .get_protocol_components_paginated(&body, 500, 4)
146            .await?
147            .protocol_components
148            .into_iter()
149            .map(|pc| (pc.id.clone(), pc))
150            .collect::<HashMap<_, _>>();
151
152        self.reinitialize_contracts();
153
154        Ok(())
155    }
156
157    /// Initialise the tracked contracts list from tracked components and their entrypoints
158    fn reinitialize_contracts(&mut self) {
159        // Add contracts from all tracked components
160        self.contracts = self
161            .components
162            .values()
163            .flat_map(|comp| comp.contract_ids.iter().cloned())
164            .collect();
165
166        // Add contracts from entrypoints that are linked to tracked components
167        let tracked_component_ids = self
168            .components
169            .keys()
170            .cloned()
171            .collect::<HashSet<_>>();
172        for entrypoint in self.entrypoints.values() {
173            if !entrypoint
174                .components
175                .is_disjoint(&tracked_component_ids)
176            {
177                self.contracts
178                    .extend(entrypoint.contracts.iter().cloned());
179            }
180        }
181    }
182
183    /// Update the tracked contracts list with contracts associated with the given components
184    fn update_contracts(&mut self, components: Vec<ComponentId>) {
185        // Only process components that are actually being tracked. Convert to HashSet for
186        // efficient lookup.
187        let tracked_component_ids = components
188            .into_iter()
189            .filter(|id| self.components.contains_key(id))
190            .collect::<HashSet<_>>();
191
192        // Add contracts from the components
193        for comp in &tracked_component_ids {
194            let component = self
195                .components
196                .get(comp)
197                .expect("Component should exist as it was filtered above");
198            self.contracts
199                .extend(component.contract_ids.iter().cloned());
200        }
201
202        // Identify entrypoints linked to the given components
203        for entrypoint in self.entrypoints.values() {
204            if !entrypoint
205                .components
206                .is_disjoint(&tracked_component_ids)
207            {
208                self.contracts
209                    .extend(entrypoint.contracts.iter().cloned());
210            }
211        }
212    }
213
214    /// Add new components to be tracked
215    #[instrument(skip(self, new_components))]
216    pub async fn start_tracking(
217        &mut self,
218        new_components: &[&ComponentId],
219    ) -> Result<(), RPCError> {
220        if new_components.is_empty() {
221            return Ok(());
222        }
223
224        // Fetch the components
225        let request = ProtocolComponentsRequestBody::id_filtered(
226            &self.protocol_system,
227            new_components
228                .iter()
229                .map(|&id| id.to_string())
230                .collect(),
231            self.chain,
232        );
233        let components = self
234            .rpc_client
235            .get_protocol_components(&request)
236            .await?
237            .protocol_components
238            .into_iter()
239            .map(|pc| (pc.id.clone(), pc))
240            .collect::<HashMap<_, _>>();
241
242        // Update components and contracts
243        let component_ids: Vec<_> = components.keys().cloned().collect();
244        let component_count = component_ids.len();
245        self.components.extend(components);
246        self.update_contracts(component_ids);
247
248        debug!(n_components = component_count, "StartedTracking");
249        Ok(())
250    }
251
252    /// Stop tracking components
253    #[instrument(skip(self, to_remove))]
254    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
255        &mut self,
256        to_remove: I,
257    ) -> HashMap<ComponentId, ProtocolComponent> {
258        let mut removed_components = HashMap::new();
259
260        for component_id in to_remove {
261            if let Some(component) = self.components.remove(component_id) {
262                removed_components.insert(component_id.clone(), component);
263            }
264        }
265
266        // Refresh the tracked contracts list. This is more reliable and efficient than directly
267        // removing contracts from the list because some contracts are shared between components.
268        self.reinitialize_contracts();
269
270        debug!(n_components = removed_components.len(), "StoppedTracking");
271        removed_components
272    }
273
274    /// Updates the tracked entrypoints and contracts based on the given DCI data.
275    pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) -> Result<(), RPCError> {
276        // Update detected contracts for entrypoints
277        for (entrypoint, traces) in &dci_update.trace_results {
278            self.entrypoints
279                .entry(entrypoint.clone())
280                .or_default()
281                .contracts
282                .extend(traces.accessed_slots.keys().cloned());
283        }
284
285        // Update linked components for entrypoints
286        for (component, entrypoints) in &dci_update.new_entrypoints {
287            for entrypoint in entrypoints {
288                let entrypoint_info = self
289                    .entrypoints
290                    .entry(entrypoint.external_id.clone())
291                    .or_default();
292                entrypoint_info
293                    .components
294                    .insert(component.clone());
295                // If the component is tracked, add the detected contracts to the tracker
296                if self.components.contains_key(component) {
297                    self.contracts.extend(
298                        entrypoint_info
299                            .contracts
300                            .iter()
301                            .cloned(),
302                    );
303                }
304            }
305        }
306
307        Ok(())
308    }
309
310    /// Get related contracts for the given component ids. Assumes that the components are already
311    /// tracked, either by calling `start_tracking` or `initialise_components`.
312    ///
313    /// # Arguments
314    ///
315    /// * `ids` - A vector of component IDs to get the contracts for.
316    ///
317    /// # Returns
318    ///
319    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
320    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
321        &self,
322        ids: I,
323    ) -> HashSet<Address> {
324        ids.into_iter()
325            .filter_map(|cid| {
326                if let Some(comp) = self.components.get(cid) {
327                    // Collect contracts from all entrypoints linked to this component
328                    let dci_contracts: HashSet<Address> = self
329                        .entrypoints
330                        .values()
331                        .filter(|ep| ep.components.contains(cid))
332                        .flat_map(|ep| ep.contracts.iter().cloned())
333                        .collect();
334                    Some(
335                        comp.contract_ids
336                            .clone()
337                            .into_iter()
338                            .chain(dci_contracts)
339                            .collect::<HashSet<_>>(),
340                    )
341                } else {
342                    warn!(
343                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
344                    );
345                    None
346                }
347            })
348            .flatten()
349            .collect()
350    }
351
352    pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
353        self.components
354            .keys()
355            .cloned()
356            .collect()
357    }
358
359    /// Given BlockChanges, filter out components that are no longer relevant and return the
360    /// components that need to be added or removed.
361    pub fn filter_updated_components(
362        &self,
363        deltas: &BlockChanges,
364    ) -> (Vec<ComponentId>, Vec<ComponentId>) {
365        match &self.filter.variant {
366            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
367            ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
368                .component_tvl
369                .iter()
370                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
371                .map(|(id, _)| id.clone())
372                .partition(|id| deltas.component_tvl[id] > *add_tvl),
373        }
374    }
375}
376
#[cfg(test)]
mod test {
    //! Unit tests for `ComponentTracker`. RPC responses are canned per test through
    //! `MockRPCClient` expectations.
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Builds a `uniswap-v2` tracker on Ethereum backed by a fresh mock RPC client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    /// Returns the component `Component1` together with its two contract ids.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    /// Initialising stores the components returned by the paginated RPC call and derives the
    /// tracked contracts from them.
    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    /// Starting to track a component fetches it via the (non-paginated) RPC call and adds its
    /// contracts to the tracked set.
    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    /// Stopping returns only the components that were tracked (`Component2` never was) and
    /// drops the contracts that belonged solely to them.
    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    /// Without any entrypoints, a component's contracts are exactly its `contract_ids`.
    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    /// The tracked ids are exactly the keys of the `components` map.
    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }
}