Skip to main content

tycho_client/feed/
component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6    models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{
10    rpc::{RPCClient, RPC_CLIENT_CONCURRENCY},
11    RPCError,
12};
13
14#[derive(Clone, Debug)]
15pub(crate) enum ComponentFilterVariant {
16    Ids(Vec<ComponentId>),
17    /// MinimumTVLRange filters components by TVL thresholds:
18    /// - `range.0` (remove threshold): components below this are removed from tracking
19    /// - `range.1` (add threshold): components above this are added to tracking
20    ///
21    /// This helps buffer against components that fluctuate on the threshold boundary.
22    /// Thresholds are denominated in native token of the chain, for example 1 means 1 ETH on
23    /// ethereum.
24    MinimumTVLRange {
25        range: (f64, f64),
26        blocklisted_ids: HashSet<ComponentId>,
27    },
28}
29
30#[derive(Clone, Debug)]
31pub struct ComponentFilter {
32    variant: ComponentFilterVariant,
33}
34
35impl ComponentFilter {
36    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
37    /// (TVL) threshold.
38    ///
39    /// # Arguments
40    ///
41    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
42    ///   native token of the chain.
43    #[allow(non_snake_case)] // for backwards compatibility
44    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
45    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
46        ComponentFilter {
47            variant: ComponentFilterVariant::MinimumTVLRange {
48                range: (min_tvl, min_tvl),
49                blocklisted_ids: HashSet::new(),
50            },
51        }
52    }
53
54    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
55    /// from tracking.
56    ///
57    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
58    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
59    /// This approach helps to reduce fluctuations caused by components hovering around a single
60    /// threshold.
61    ///
62    /// # Arguments
63    ///
64    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
65    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
66    ///
67    /// Note: thresholds are denominated in native token of the chain.
68    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
69        ComponentFilter {
70            variant: ComponentFilterVariant::MinimumTVLRange {
71                range: (remove_tvl_threshold, add_tvl_threshold),
72                blocklisted_ids: HashSet::new(),
73            },
74        }
75    }
76
77    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
78    /// effectively filtering out all other components.
79    ///
80    /// # Arguments
81    ///
82    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
83    ///   will be tracked.
84    #[allow(non_snake_case)] // for backwards compatibility
85    pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
86        ComponentFilter {
87            variant: ComponentFilterVariant::Ids(
88                ids.into_iter()
89                    .map(|id| id.to_lowercase())
90                    .collect(),
91            ),
92        }
93    }
94
95    /// Blocklist specific component IDs from tracking regardless of other
96    /// filter criteria. IDs are normalized to lowercase.
97    ///
98    /// Has no effect when the filter variant is `Ids`, since the
99    /// inclusion list already defines exactly which components to track.
100    pub fn blocklist(mut self, ids: impl IntoIterator<Item = ComponentId>) -> Self {
101        match &mut self.variant {
102            ComponentFilterVariant::Ids(_) => {
103                warn!(
104                    "blocklist() has no effect on ComponentFilter::Ids; \
105                     remove the component from the ID list instead"
106                );
107            }
108            ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
109                blocklisted_ids.extend(
110                    ids.into_iter()
111                        .map(|id| id.to_lowercase()),
112                );
113            }
114        }
115        self
116    }
117
118    /// Returns true if the given component ID is blocklisted.
119    pub fn is_blocklisted(&self, id: &str) -> bool {
120        match &self.variant {
121            ComponentFilterVariant::Ids(_) => false,
122            ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
123                blocklisted_ids.contains(&id.to_lowercase())
124            }
125        }
126    }
127}
128
129/// Information about an entrypoint, including which components use it and what contracts it
130/// interacts with
131#[derive(Default)]
132struct EntrypointRelations {
133    /// Set of component ids for components that have this entrypoint
134    components: HashSet<ComponentId>,
135    /// Set of detected contracts for the entrypoint
136    contracts: HashSet<Address>,
137}
138
139/// Helper struct to determine which components and contracts are being tracked atm.
140pub struct ComponentTracker<R: RPCClient> {
141    chain: Chain,
142    protocol_system: ProtocolSystem,
143    filter: ComponentFilter,
144    /// We will need to request a snapshot for components/contracts that we did not emit as
145    /// snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
146    pub components: HashMap<ComponentId, ProtocolComponent>,
147    /// Map of entrypoint id to its associated components and contracts
148    entrypoints: HashMap<String, EntrypointRelations>,
149    /// Derived from tracked components. We need this if subscribed to a vm extractor because
150    /// updates are emitted on a contract level instead of a component level.
151    pub contracts: HashSet<Address>,
152    /// Client to retrieve necessary protocol components from the rpc.
153    rpc_client: R,
154}
155
156impl<R> ComponentTracker<R>
157where
158    R: RPCClient,
159{
160    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
161        Self {
162            chain,
163            protocol_system: protocol_system.to_string(),
164            filter,
165            components: Default::default(),
166            contracts: Default::default(),
167            rpc_client: rpc,
168            entrypoints: Default::default(),
169        }
170    }
171
172    /// Retrieves all components that belong to the system we are streaming that have sufficient
173    /// tvl. Also detects which contracts are relevant for simulating on those components.
174    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
175        let body = match &self.filter.variant {
176            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
177                &self.protocol_system,
178                ids.clone(),
179                self.chain,
180            ),
181            ComponentFilterVariant::MinimumTVLRange { range: (_, upper_tvl_threshold), .. } => {
182                ProtocolComponentsRequestBody::system_filtered(
183                    &self.protocol_system,
184                    Some(*upper_tvl_threshold),
185                    self.chain,
186                )
187            }
188        };
189        self.components = self
190            .rpc_client
191            .get_protocol_components_paginated(&body, None, RPC_CLIENT_CONCURRENCY)
192            .await?
193            .protocol_components
194            .into_iter()
195            .map(|pc| (pc.id.clone(), pc))
196            .filter(|(id, _)| !self.filter.is_blocklisted(id))
197            .collect::<HashMap<_, _>>();
198
199        self.reinitialize_contracts();
200
201        Ok(())
202    }
203
204    /// Initialise the tracked contracts list from tracked components and their entrypoints
205    fn reinitialize_contracts(&mut self) {
206        // Add contracts from all tracked components
207        self.contracts = self
208            .components
209            .values()
210            .flat_map(|comp| comp.contract_ids.iter().cloned())
211            .collect();
212
213        // Add contracts from entrypoints that are linked to tracked components
214        let tracked_component_ids = self
215            .components
216            .keys()
217            .cloned()
218            .collect::<HashSet<_>>();
219        for entrypoint in self.entrypoints.values() {
220            if !entrypoint
221                .components
222                .is_disjoint(&tracked_component_ids)
223            {
224                self.contracts
225                    .extend(entrypoint.contracts.iter().cloned());
226            }
227        }
228    }
229
230    /// Update the tracked contracts list with contracts associated with the given components
231    fn update_contracts(&mut self, components: Vec<ComponentId>) {
232        // Only process components that are actually being tracked.
233        let mut tracked_component_ids = HashSet::new();
234
235        // Add contracts from the components
236        for comp in components {
237            if let Some(component) = self.components.get(&comp) {
238                self.contracts
239                    .extend(component.contract_ids.iter().cloned());
240                tracked_component_ids.insert(comp);
241            }
242        }
243
244        // Identify entrypoints linked to the given components
245        for entrypoint in self.entrypoints.values() {
246            if !entrypoint
247                .components
248                .is_disjoint(&tracked_component_ids)
249            {
250                self.contracts
251                    .extend(entrypoint.contracts.iter().cloned());
252            }
253        }
254    }
255
256    /// Add new components to be tracked
257    #[instrument(skip(self, new_components))]
258    pub async fn start_tracking(
259        &mut self,
260        new_components: &[&ComponentId],
261    ) -> Result<(), RPCError> {
262        let new_components: Vec<_> = new_components
263            .iter()
264            .filter(|id| !self.filter.is_blocklisted(id))
265            .copied()
266            .collect();
267
268        if new_components.is_empty() {
269            return Ok(());
270        }
271
272        // Fetch the components
273        let request = ProtocolComponentsRequestBody::id_filtered(
274            &self.protocol_system,
275            new_components
276                .iter()
277                .map(|&id| id.to_string())
278                .collect(),
279            self.chain,
280        );
281        let components = self
282            .rpc_client
283            .get_protocol_components(&request)
284            .await?
285            .protocol_components
286            .into_iter()
287            .map(|pc| (pc.id.clone(), pc))
288            .collect::<HashMap<_, _>>();
289
290        // Update components and contracts
291        let component_ids: Vec<_> = components.keys().cloned().collect();
292        let component_count = component_ids.len();
293        self.components.extend(components);
294        self.update_contracts(component_ids);
295
296        debug!(n_components = component_count, "StartedTracking");
297        Ok(())
298    }
299
300    /// Stop tracking components
301    #[instrument(skip(self, to_remove))]
302    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
303        &mut self,
304        to_remove: I,
305    ) -> HashMap<ComponentId, ProtocolComponent> {
306        let mut removed_components = HashMap::new();
307
308        for component_id in to_remove {
309            if let Some(component) = self.components.remove(component_id) {
310                removed_components.insert(component_id.clone(), component);
311            }
312        }
313
314        // Refresh the tracked contracts list. This is more reliable and efficient than directly
315        // removing contracts from the list because some contracts are shared between components.
316        self.reinitialize_contracts();
317
318        debug!(n_components = removed_components.len(), "StoppedTracking");
319        removed_components
320    }
321
322    /// Updates the tracked entrypoints and contracts based on the given DCI data.
323    pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
324        // Update detected contracts for entrypoints
325        for (entrypoint, traces) in &dci_update.trace_results {
326            self.entrypoints
327                .entry(entrypoint.clone())
328                .or_default()
329                .contracts
330                .extend(traces.accessed_slots.keys().cloned());
331        }
332
333        // Update linked components for entrypoints
334        for (component, entrypoints) in &dci_update.new_entrypoints {
335            for entrypoint in entrypoints {
336                let entrypoint_info = self
337                    .entrypoints
338                    .entry(entrypoint.external_id.clone())
339                    .or_default();
340                entrypoint_info
341                    .components
342                    .insert(component.clone());
343                // If the component is tracked, add the detected contracts to the tracker
344                if self.components.contains_key(component) {
345                    self.contracts.extend(
346                        entrypoint_info
347                            .contracts
348                            .iter()
349                            .cloned(),
350                    );
351                }
352            }
353        }
354    }
355
356    /// Get related contracts for the given component ids. Assumes that the components are already
357    /// tracked, either by calling `start_tracking` or `initialise_components`.
358    ///
359    /// # Arguments
360    ///
361    /// * `ids` - A vector of component IDs to get the contracts for.
362    ///
363    /// # Returns
364    ///
365    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
366    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
367        &self,
368        ids: I,
369    ) -> HashSet<Address> {
370        ids.into_iter()
371            .filter_map(|cid| {
372                if let Some(comp) = self.components.get(cid) {
373                    // Collect contracts from all entrypoints linked to this component
374                    let dci_contracts: HashSet<Address> = self
375                        .entrypoints
376                        .values()
377                        .filter(|ep| ep.components.contains(cid))
378                        .flat_map(|ep| ep.contracts.iter().cloned())
379                        .collect();
380                    Some(
381                        comp.contract_ids
382                            .clone()
383                            .into_iter()
384                            .chain(dci_contracts)
385                            .collect::<HashSet<_>>(),
386                    )
387                } else {
388                    warn!(
389                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
390                    );
391                    None
392                }
393            })
394            .flatten()
395            .collect()
396    }
397
398    pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
399        self.components
400            .keys()
401            .cloned()
402            .collect()
403    }
404
405    /// Given BlockChanges, filter out components that are no longer relevant and return the
406    /// components that need to be added or removed.
407    pub fn filter_updated_components(
408        &self,
409        deltas: &BlockChanges,
410    ) -> (Vec<ComponentId>, Vec<ComponentId>) {
411        match &self.filter.variant {
412            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
413            ComponentFilterVariant::MinimumTVLRange { range: (remove_tvl, add_tvl), .. } => {
414                let (mut to_add, mut to_remove): (Vec<_>, Vec<_>) = deltas
415                    .component_tvl
416                    .iter()
417                    .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
418                    .map(|(id, _)| id.clone())
419                    .partition(|id| deltas.component_tvl[id] > *add_tvl);
420
421                // Never add blocklisted components
422                to_add.retain(|id| !self.filter.is_blocklisted(id));
423
424                // Remove any currently tracked components that are now blocklisted
425                for id in self.components.keys() {
426                    if self.filter.is_blocklisted(id) && !to_remove.contains(id) {
427                        to_remove.push(id.clone());
428                    }
429                }
430
431                (to_add, to_remove)
432            }
433        }
434    }
435}
436
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker over a fresh mock RPC client.
    ///
    /// Uses a `(0.0, 0.0)` TVL range and blocklists every ID in `blocklisted`.
    fn with_mocked_rpc_and_blocklist(blocklisted: Vec<&str>) -> ComponentTracker<MockRPCClient> {
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0).blocklist(
            blocklisted
                .into_iter()
                .map(String::from),
        );
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    /// Tracker over a fresh mock RPC client with a `(0.0, 0.0)` TVL range and no blocklist.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        with_mocked_rpc_and_blocklist(Vec::new())
    }

    /// Returns the component `Component1` together with the two contract IDs it references.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    /// Wraps `component` in a single-page RPC response.
    fn single_page(component: ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page(component.clone())));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &expected_component);
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_page(component.clone())));

        tracker
            .start_tracking(&[&component_id])
            .await
            .expect("Tracking components failed");

        assert_eq!(tracker.contracts, expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let expected: HashMap<ComponentId, ProtocolComponent> =
            HashMap::from([("Component1".to_string(), component)]);

        let removed =
            tracker.stop_tracking(&["Component1".to_string(), "Component2".to_string()]);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let expected: HashSet<Address> = contract_ids.into_iter().collect();

        let contracts = tracker.get_contracts_by_component(&["Component1".to_string()]);

        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let ids = tracker.get_tracked_component_ids();

        assert_eq!(ids, vec!["Component1".to_string()]);
    }

    #[tokio::test]
    async fn test_initialise_skips_blocklisted_components() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let (_, component) = components_response();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page(component.clone())));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be in tracker");
    }

    #[tokio::test]
    async fn test_start_tracking_skips_blocklisted() {
        // No RPC expectation is registered: a request would make the mock panic.
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let component_id = "Component1".to_string();

        tracker
            .start_tracking(&[&component_id])
            .await
            .expect("start_tracking should succeed");

        assert!(tracker.components.is_empty(), "Blocklisted component should not be tracked");
    }

    #[test]
    fn test_filter_updated_blocks_blocklisted_add() {
        let mut tracker = with_mocked_rpc();
        tracker.filter = ComponentFilter::with_tvl_range(5.0, 10.0)
            .blocklist(vec!["blocklisted_pool".to_string()]);
        let deltas = BlockChanges {
            component_tvl: HashMap::from([
                ("blocklisted_pool".to_string(), 100.0),
                ("allowed_pool".to_string(), 100.0),
            ]),
            ..Default::default()
        };

        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);

        assert!(
            !to_add.contains(&"blocklisted_pool".to_string()),
            "Blocklisted component should not be in to_add"
        );
        assert!(
            to_add.contains(&"allowed_pool".to_string()),
            "Non-blocklisted component should be in to_add"
        );
        assert!(to_remove.is_empty());
    }
}