tycho_client/feed/component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, ProtocolComponent, ProtocolComponentsRequestBody},
6    Bytes,
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Internal representation of the criterion a [`ComponentFilter`] uses to select components.
#[derive(Debug, Clone)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly the components whose IDs are listed here.
    Ids(Vec<String>),
    /// TVL range given as `(remove_tvl_threshold, add_tvl_threshold)`.
    ///
    /// A tracked component whose TVL drops below the remove threshold stops being tracked; a
    /// component whose TVL rises above the add threshold starts being tracked. Having two
    /// thresholds buffers against components whose TVL fluctuates around a single boundary.
    /// Both values are denominated in the native token of the chain, e.g. `1` means 1 ETH on
    /// Ethereum.
    MinimumTVLRange((f64, f64)),
}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24    variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
29    /// (TVL) threshold.
30    ///
31    /// # Arguments
32    ///
33    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
34    ///   native token of the chain.
35    #[allow(non_snake_case)] // for backwards compatibility
36    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38        ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39    }
40
41    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
42    /// from tracking.
43    ///
44    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
45    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
46    /// This approach helps to reduce fluctuations caused by components hovering around a single
47    /// threshold.
48    ///
49    /// # Arguments
50    ///
51    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
52    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
53    ///
54    /// Note: thresholds are denominated in native token of the chain.
55    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56        ComponentFilter {
57            variant: ComponentFilterVariant::MinimumTVLRange((
58                remove_tvl_threshold,
59                add_tvl_threshold,
60            )),
61        }
62    }
63
64    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
65    /// effectively filtering out all other components.
66    ///
67    /// # Arguments
68    ///
69    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
70    ///   will be tracked.
71    #[allow(non_snake_case)] // for backwards compatibility
72    pub fn Ids(ids: Vec<String>) -> ComponentFilter {
73        ComponentFilter { variant: ComponentFilterVariant::Ids(ids) }
74    }
75}
76
77/// Helper struct to store which components are being tracked atm.
78pub struct ComponentTracker<R: RPCClient> {
79    chain: Chain,
80    protocol_system: String,
81    filter: ComponentFilter,
82    // We will need to request a snapshot for components/Contracts that we did not emit as
83    // snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
84    pub components: HashMap<String, ProtocolComponent>,
85    /// derived from tracked components, we need this if subscribed to a vm extractor cause updates
86    /// are emitted on a contract level instead of on a component level.
87    pub contracts: HashSet<Bytes>,
88    /// Client to retrieve necessary protocol components from the rpc.
89    rpc_client: R,
90}
91
92impl<R> ComponentTracker<R>
93where
94    R: RPCClient,
95{
96    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
97        Self {
98            chain,
99            protocol_system: protocol_system.to_string(),
100            filter,
101            components: Default::default(),
102            contracts: Default::default(),
103            rpc_client: rpc,
104        }
105    }
106    /// Retrieve all components that belong to the system we are extracing and have sufficient tvl.
107    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
108        let body = match &self.filter.variant {
109            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
110                &self.protocol_system,
111                ids.clone(),
112                self.chain,
113            ),
114            ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
115                ProtocolComponentsRequestBody::system_filtered(
116                    &self.protocol_system,
117                    Some(*upper_tvl_threshold),
118                    self.chain,
119                )
120            }
121        };
122
123        self.components = self
124            .rpc_client
125            .get_protocol_components_paginated(&body, 500, 4)
126            .await?
127            .protocol_components
128            .into_iter()
129            .map(|pc| (pc.id.clone(), pc))
130            .collect::<HashMap<_, _>>();
131        self.update_contracts();
132        Ok(())
133    }
134
135    fn update_contracts(&mut self) {
136        self.contracts.extend(
137            self.components
138                .values()
139                .flat_map(|comp| comp.contract_ids.iter().cloned()),
140        );
141    }
142
143    /// Add a new component to be tracked
144    #[instrument(skip(self, new_components))]
145    pub async fn start_tracking(&mut self, new_components: &[&String]) -> Result<(), RPCError> {
146        if new_components.is_empty() {
147            return Ok(());
148        }
149        let request = ProtocolComponentsRequestBody::id_filtered(
150            &self.protocol_system,
151            new_components
152                .iter()
153                .map(|pc_id| pc_id.to_string())
154                .collect(),
155            self.chain,
156        );
157
158        self.components.extend(
159            self.rpc_client
160                .get_protocol_components(&request)
161                .await?
162                .protocol_components
163                .into_iter()
164                .map(|pc| (pc.id.clone(), pc)),
165        );
166        self.update_contracts();
167        debug!(n_components = new_components.len(), "StartedTracking");
168        Ok(())
169    }
170
171    /// Stop tracking components
172    #[instrument(skip(self, to_remove))]
173    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a String> + std::fmt::Debug>(
174        &mut self,
175        to_remove: I,
176    ) -> HashMap<String, ProtocolComponent> {
177        let mut n_components = 0;
178        let res = to_remove
179            .into_iter()
180            .filter_map(|k| {
181                let comp = self.components.remove(k);
182                if let Some(component) = &comp {
183                    n_components += 1;
184                    for contract in component.contract_ids.iter() {
185                        self.contracts.remove(contract);
186                    }
187                }
188                comp.map(|c| (k.clone(), c))
189            })
190            .collect();
191        debug!(n_components, "StoppedTracking");
192        res
193    }
194
195    /// Get related contracts for the given component ids. Assumes that the components are already
196    /// tracked, either by calling `start_tracking` or `initialise_components`.
197    ///
198    /// # Arguments
199    ///
200    /// * `ids` - A vector of component IDs to get the contracts for.
201    ///
202    /// # Returns
203    ///
204    /// A HashSet of contract IDs. Components that are not tracked will be logged and skipped.
205    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
206        &self,
207        ids: I,
208    ) -> HashSet<Bytes> {
209        ids.into_iter()
210            .filter_map(|cid| {
211                if let Some(comp) = self.components.get(cid) {
212                    Some(comp.contract_ids.clone())
213                } else {
214                    warn!(
215                        "Requested component is not tracked: {cid}. Skipping fetching contracts..."
216                    );
217                    None
218                }
219            })
220            .flatten()
221            .collect()
222    }
223
224    pub fn get_tracked_component_ids(&self) -> Vec<String> {
225        self.components
226            .keys()
227            .cloned()
228            .collect()
229    }
230
231    /// Given BlockChanges, filter out components that are no longer relevant and return the
232    /// components that need to be added or removed.
233    pub fn filter_updated_components(&self, deltas: &BlockChanges) -> (Vec<String>, Vec<String>) {
234        match &self.filter.variant {
235            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
236            ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
237                .component_tvl
238                .iter()
239                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
240                .map(|(id, _)| id.clone())
241                .partition(|id| deltas.component_tvl[id] > *add_tvl),
242        }
243    }
244}
245
#[cfg(test)]
mod test {
    use tycho_common::dto::{PaginationResponse, ProtocolComponentRequestResponse};

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Builds an Ethereum `uniswap-v2` tracker backed by a fresh mock RPC client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            MockRPCClient::new(),
        )
    }

    /// Returns the test component `Component1` together with the contract ids it references.
    fn components_response() -> (Vec<Bytes>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    /// Wraps `component` in a single-page RPC response, as returned by the mocked client.
    fn single_page_response(component: &ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component.clone()],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page_response(&component)));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &expected_component);
        let expected_contracts: HashSet<Bytes> = contract_ids.into_iter().collect();
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Bytes> = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let to_track = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_page_response(&component)));

        tracker
            .start_tracking(&to_track)
            .await
            .expect("Tracking components failed");

        assert_eq!(tracker.contracts, expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let expected: HashMap<String, ProtocolComponent> =
            HashMap::from([("Component1".to_string(), component)]);
        // "Component2" is not tracked and must simply be ignored.
        let to_remove = ["Component1".to_string(), "Component2".to_string()];

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let requested = ["Component1".to_string()];

        let contracts = tracker.get_contracts_by_component(&requested);

        let expected: HashSet<Bytes> = contract_ids.into_iter().collect();
        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let ids = tracker.get_tracked_component_ids();

        assert_eq!(ids, vec!["Component1".to_string()]);
    }
}
377}