tycho_client/feed/component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6    models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{
10    rpc::{RPCClient, RPC_CLIENT_CONCURRENCY},
11    RPCError,
12};
13
14#[derive(Clone, Debug)]
15pub(crate) enum ComponentFilterVariant {
16    Ids(Vec<ComponentId>),
17    /// MinimumTVLRange is a tuple of (remove_tvl_threshold, add_tvl_threshold). Components that
18    /// drop below the remove threshold will be removed from tracking, components that exceed the
19    /// add threshold will be added. This helps buffer against components that fluctuate on the
20    /// tvl threshold boundary. The thresholds are denominated in native token of the chain, for
21    /// example 1 means 1 ETH on ethereum.
22    MinimumTVLRange((f64, f64)),
23}
24
25#[derive(Clone, Debug)]
26pub struct ComponentFilter {
27    variant: ComponentFilterVariant,
28}
29
30impl ComponentFilter {
31    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
32    /// (TVL) threshold.
33    ///
34    /// # Arguments
35    ///
36    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
37    ///   native token of the chain.
38    #[allow(non_snake_case)] // for backwards compatibility
39    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
40    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
41        ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
42    }
43
44    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
45    /// from tracking.
46    ///
47    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
48    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
49    /// This approach helps to reduce fluctuations caused by components hovering around a single
50    /// threshold.
51    ///
52    /// # Arguments
53    ///
54    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
55    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
56    ///
57    /// Note: thresholds are denominated in native token of the chain.
58    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
59        ComponentFilter {
60            variant: ComponentFilterVariant::MinimumTVLRange((
61                remove_tvl_threshold,
62                add_tvl_threshold,
63            )),
64        }
65    }
66
67    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
68    /// effectively filtering out all other components.
69    ///
70    /// # Arguments
71    ///
72    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
73    ///   will be tracked.
74    #[allow(non_snake_case)] // for backwards compatibility
75    pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
76        ComponentFilter {
77            variant: ComponentFilterVariant::Ids(
78                ids.into_iter()
79                    .map(|id| id.to_lowercase())
80                    .collect(),
81            ),
82        }
83    }
84}
85
86/// Information about an entrypoint, including which components use it and what contracts it
87/// interacts with
88#[derive(Default)]
89struct EntrypointRelations {
90    /// Set of component ids for components that have this entrypoint
91    components: HashSet<ComponentId>,
92    /// Set of detected contracts for the entrypoint
93    contracts: HashSet<Address>,
94}
95
96/// Helper struct to determine which components and contracts are being tracked atm.
97pub struct ComponentTracker<R: RPCClient> {
98    chain: Chain,
99    protocol_system: ProtocolSystem,
100    filter: ComponentFilter,
101    /// We will need to request a snapshot for components/contracts that we did not emit as
102    /// snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
103    pub components: HashMap<ComponentId, ProtocolComponent>,
104    /// Map of entrypoint id to its associated components and contracts
105    entrypoints: HashMap<String, EntrypointRelations>,
106    /// Derived from tracked components. We need this if subscribed to a vm extractor because
107    /// updates are emitted on a contract level instead of a component level.
108    pub contracts: HashSet<Address>,
109    /// Client to retrieve necessary protocol components from the rpc.
110    rpc_client: R,
111}
112
113impl<R> ComponentTracker<R>
114where
115    R: RPCClient,
116{
117    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
118        Self {
119            chain,
120            protocol_system: protocol_system.to_string(),
121            filter,
122            components: Default::default(),
123            contracts: Default::default(),
124            rpc_client: rpc,
125            entrypoints: Default::default(),
126        }
127    }
128
129    /// Retrieves all components that belong to the system we are streaming that have sufficient
130    /// tvl. Also detects which contracts are relevant for simulating on those components.
131    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
132        let body = match &self.filter.variant {
133            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
134                &self.protocol_system,
135                ids.clone(),
136                self.chain,
137            ),
138            ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
139                ProtocolComponentsRequestBody::system_filtered(
140                    &self.protocol_system,
141                    Some(*upper_tvl_threshold),
142                    self.chain,
143                )
144            }
145        };
146        self.components = self
147            .rpc_client
148            .get_protocol_components_paginated(&body, None, RPC_CLIENT_CONCURRENCY)
149            .await?
150            .protocol_components
151            .into_iter()
152            .map(|pc| (pc.id.clone(), pc))
153            .collect::<HashMap<_, _>>();
154
155        self.reinitialize_contracts();
156
157        Ok(())
158    }
159
160    /// Initialise the tracked contracts list from tracked components and their entrypoints
161    fn reinitialize_contracts(&mut self) {
162        // Add contracts from all tracked components
163        self.contracts = self
164            .components
165            .values()
166            .flat_map(|comp| comp.contract_ids.iter().cloned())
167            .collect();
168
169        // Add contracts from entrypoints that are linked to tracked components
170        let tracked_component_ids = self
171            .components
172            .keys()
173            .cloned()
174            .collect::<HashSet<_>>();
175        for entrypoint in self.entrypoints.values() {
176            if !entrypoint
177                .components
178                .is_disjoint(&tracked_component_ids)
179            {
180                self.contracts
181                    .extend(entrypoint.contracts.iter().cloned());
182            }
183        }
184    }
185
186    /// Update the tracked contracts list with contracts associated with the given components
187    fn update_contracts(&mut self, components: Vec<ComponentId>) {
188        // Only process components that are actually being tracked.
189        let mut tracked_component_ids = HashSet::new();
190
191        // Add contracts from the components
192        for comp in components {
193            if let Some(component) = self.components.get(&comp) {
194                self.contracts
195                    .extend(component.contract_ids.iter().cloned());
196                tracked_component_ids.insert(comp);
197            }
198        }
199
200        // Identify entrypoints linked to the given components
201        for entrypoint in self.entrypoints.values() {
202            if !entrypoint
203                .components
204                .is_disjoint(&tracked_component_ids)
205            {
206                self.contracts
207                    .extend(entrypoint.contracts.iter().cloned());
208            }
209        }
210    }
211
212    /// Add new components to be tracked
213    #[instrument(skip(self, new_components))]
214    pub async fn start_tracking(
215        &mut self,
216        new_components: &[&ComponentId],
217    ) -> Result<(), RPCError> {
218        if new_components.is_empty() {
219            return Ok(());
220        }
221
222        // Fetch the components
223        let request = ProtocolComponentsRequestBody::id_filtered(
224            &self.protocol_system,
225            new_components
226                .iter()
227                .map(|&id| id.to_string())
228                .collect(),
229            self.chain,
230        );
231        let components = self
232            .rpc_client
233            .get_protocol_components(&request)
234            .await?
235            .protocol_components
236            .into_iter()
237            .map(|pc| (pc.id.clone(), pc))
238            .collect::<HashMap<_, _>>();
239
240        // Update components and contracts
241        let component_ids: Vec<_> = components.keys().cloned().collect();
242        let component_count = component_ids.len();
243        self.components.extend(components);
244        self.update_contracts(component_ids);
245
246        debug!(n_components = component_count, "StartedTracking");
247        Ok(())
248    }
249
250    /// Stop tracking components
251    #[instrument(skip(self, to_remove))]
252    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
253        &mut self,
254        to_remove: I,
255    ) -> HashMap<ComponentId, ProtocolComponent> {
256        let mut removed_components = HashMap::new();
257
258        for component_id in to_remove {
259            if let Some(component) = self.components.remove(component_id) {
260                removed_components.insert(component_id.clone(), component);
261            }
262        }
263
264        // Refresh the tracked contracts list. This is more reliable and efficient than directly
265        // removing contracts from the list because some contracts are shared between components.
266        self.reinitialize_contracts();
267
268        debug!(n_components = removed_components.len(), "StoppedTracking");
269        removed_components
270    }
271
272    /// Updates the tracked entrypoints and contracts based on the given DCI data.
273    pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
274        // Update detected contracts for entrypoints
275        for (entrypoint, traces) in &dci_update.trace_results {
276            self.entrypoints
277                .entry(entrypoint.clone())
278                .or_default()
279                .contracts
280                .extend(traces.accessed_slots.keys().cloned());
281        }
282
283        // Update linked components for entrypoints
284        for (component, entrypoints) in &dci_update.new_entrypoints {
285            for entrypoint in entrypoints {
286                let entrypoint_info = self
287                    .entrypoints
288                    .entry(entrypoint.external_id.clone())
289                    .or_default();
290                entrypoint_info
291                    .components
292                    .insert(component.clone());
293                // If the component is tracked, add the detected contracts to the tracker
294                if self.components.contains_key(component) {
295                    self.contracts.extend(
296                        entrypoint_info
297                            .contracts
298                            .iter()
299                            .cloned(),
300                    );
301                }
302            }
303        }
304    }
305
306    /// Get related contracts for the given component ids. Assumes that the components are already
307    /// tracked, either by calling `start_tracking` or `initialise_components`.
308    ///
309    /// # Arguments
310    ///
311    /// * `ids` - A vector of component IDs to get the contracts for.
312    ///
313    /// # Returns
314    ///
315    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
316    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
317        &self,
318        ids: I,
319    ) -> HashSet<Address> {
320        ids.into_iter()
321            .filter_map(|cid| {
322                if let Some(comp) = self.components.get(cid) {
323                    // Collect contracts from all entrypoints linked to this component
324                    let dci_contracts: HashSet<Address> = self
325                        .entrypoints
326                        .values()
327                        .filter(|ep| ep.components.contains(cid))
328                        .flat_map(|ep| ep.contracts.iter().cloned())
329                        .collect();
330                    Some(
331                        comp.contract_ids
332                            .clone()
333                            .into_iter()
334                            .chain(dci_contracts)
335                            .collect::<HashSet<_>>(),
336                    )
337                } else {
338                    warn!(
339                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
340                    );
341                    None
342                }
343            })
344            .flatten()
345            .collect()
346    }
347
348    pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
349        self.components
350            .keys()
351            .cloned()
352            .collect()
353    }
354
355    /// Given BlockChanges, filter out components that are no longer relevant and return the
356    /// components that need to be added or removed.
357    pub fn filter_updated_components(
358        &self,
359        deltas: &BlockChanges,
360    ) -> (Vec<ComponentId>, Vec<ComponentId>) {
361        match &self.filter.variant {
362            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
363            ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
364                .component_tvl
365                .iter()
366                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
367                .map(|(id, _)| id.clone())
368                .partition(|id| deltas.component_tvl[id] > *add_tvl),
369        }
370    }
371}
372
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker with a mocked RPC client and a zero-width TVL range.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    /// A single component ("Component1") together with its contract ids.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }

    /// Only values strictly outside the (remove, add) band are reported; values inside the band
    /// or exactly on a threshold are left alone.
    #[test]
    fn test_filter_updated_components_tvl_range() {
        let tracker = ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(1.0, 10.0),
            MockRPCClient::new(),
        );
        let deltas = BlockChanges {
            component_tvl: [
                ("above".to_string(), 10.5),
                ("at_add".to_string(), 10.0),
                ("inside".to_string(), 5.0),
                ("at_remove".to_string(), 1.0),
                ("below".to_string(), 0.5),
            ]
            .into_iter()
            .collect(),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        assert_eq!(to_add, vec!["above".to_string()]);
        assert_eq!(to_remove, vec!["below".to_string()]);
    }

    /// An id filter never adds or removes components based on TVL.
    #[test]
    fn test_filter_updated_components_ids_filter() {
        let tracker = ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::Ids(vec!["Component1".to_string()]),
            MockRPCClient::new(),
        );
        let deltas = BlockChanges {
            component_tvl: [("Component1".to_string(), 100.0)]
                .into_iter()
                .collect(),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        assert!(to_add.is_empty());
        assert!(to_remove.is_empty());
    }
}