// tycho_client/feed/component_tracker.rs

1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5    dto::{BlockChanges, Chain, ProtocolComponent, ProtocolComponentsRequestBody},
6    Bytes,
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Selection strategy held inside a `ComponentFilter`.
#[derive(Debug, Clone)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly the components whose ids are listed here, nothing else.
    Ids(Vec<String>),
    /// `(remove_tvl_threshold, add_tvl_threshold)`.
    ///
    /// A tracked component whose TVL falls below the first value stops being tracked; an
    /// untracked component whose TVL rises above the second value starts being tracked. Using
    /// two thresholds instead of one keeps components whose TVL oscillates around a single
    /// boundary from being added and removed over and over. Both values are expressed in the
    /// chain's native token, e.g. `1.0` is 1 ETH on Ethereum.
    MinimumTVLRange((f64, f64)),
}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24    variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28    /// Creates a `ComponentFilter` that filters components based on a minimum Total Value Locked
29    /// (TVL) threshold.
30    ///
31    /// # Arguments
32    ///
33    /// * `min_tvl` - The minimum TVL required for a component to be tracked. This is denominated in
34    ///   native token of the chain.
35    #[allow(non_snake_case)] // for backwards compatibility
36    #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37    pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38        ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39    }
40
41    /// Creates a `ComponentFilter` with a specified TVL range for adding or removing components
42    /// from tracking.
43    ///
44    /// Components that drop below the `remove_tvl_threshold` will be removed from tracking,
45    /// while components that exceed the `add_tvl_threshold` will be added to tracking.
46    /// This approach helps to reduce fluctuations caused by components hovering around a single
47    /// threshold.
48    ///
49    /// # Arguments
50    ///
51    /// * `remove_tvl_threshold` - The TVL below which a component will be removed from tracking.
52    /// * `add_tvl_threshold` - The TVL above which a component will be added to tracking.
53    ///
54    /// Note: thresholds are denominated in native token of the chain.
55    pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56        ComponentFilter {
57            variant: ComponentFilterVariant::MinimumTVLRange((
58                remove_tvl_threshold,
59                add_tvl_threshold,
60            )),
61        }
62    }
63
64    /// Creates a `ComponentFilter` that **includes only** the components with the specified IDs,
65    /// effectively filtering out all other components.
66    ///
67    /// # Arguments
68    ///
69    /// * `ids` - A vector of component IDs to include in the filter. Only components with these IDs
70    ///   will be tracked.
71    #[allow(non_snake_case)] // for backwards compatibility
72    pub fn Ids(ids: Vec<String>) -> ComponentFilter {
73        ComponentFilter { variant: ComponentFilterVariant::Ids(ids) }
74    }
75}
76
77/// Helper struct to store which components are being tracked atm.
78pub struct ComponentTracker<R: RPCClient> {
79    chain: Chain,
80    protocol_system: String,
81    filter: ComponentFilter,
82    // We will need to request a snapshot for components/Contracts that we did not emit as
83    // snapshot for yet but are relevant now, e.g. because min tvl threshold exceeded.
84    pub components: HashMap<String, ProtocolComponent>,
85    /// derived from tracked components, we need this if subscribed to a vm extractor cause updates
86    /// are emitted on a contract level instead of on a component level.
87    pub contracts: HashSet<Bytes>,
88    /// Client to retrieve necessary protocol components from the rpc.
89    rpc_client: R,
90}
91
92impl<R> ComponentTracker<R>
93where
94    R: RPCClient,
95{
96    pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
97        Self {
98            chain,
99            protocol_system: protocol_system.to_string(),
100            filter,
101            components: Default::default(),
102            contracts: Default::default(),
103            rpc_client: rpc,
104        }
105    }
106    /// Retrieve all components that belong to the system we are extracing and have sufficient tvl.
107    pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
108        let body = match &self.filter.variant {
109            ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
110                &self.protocol_system,
111                ids.clone(),
112                self.chain,
113            ),
114            ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
115                ProtocolComponentsRequestBody::system_filtered(
116                    &self.protocol_system,
117                    Some(*upper_tvl_threshold),
118                    self.chain,
119                )
120            }
121        };
122
123        self.components = self
124            .rpc_client
125            .get_protocol_components_paginated(&body, 500, 4)
126            .await?
127            .protocol_components
128            .into_iter()
129            .map(|pc| (pc.id.clone(), pc))
130            .collect::<HashMap<_, _>>();
131        self.update_contracts();
132        Ok(())
133    }
134
135    fn update_contracts(&mut self) {
136        self.contracts.extend(
137            self.components
138                .values()
139                .flat_map(|comp| comp.contract_ids.iter().cloned()),
140        );
141    }
142
143    /// Add a new component to be tracked
144    #[instrument(skip(self, new_components))]
145    pub async fn start_tracking(&mut self, new_components: &[&String]) -> Result<(), RPCError> {
146        if new_components.is_empty() {
147            return Ok(());
148        }
149        let request = ProtocolComponentsRequestBody::id_filtered(
150            &self.protocol_system,
151            new_components
152                .iter()
153                .map(|pc_id| pc_id.to_string())
154                .collect(),
155            self.chain,
156        );
157
158        self.components.extend(
159            self.rpc_client
160                .get_protocol_components(&request)
161                .await?
162                .protocol_components
163                .into_iter()
164                .map(|pc| (pc.id.clone(), pc)),
165        );
166        self.update_contracts();
167        debug!(n_components = new_components.len(), "StartedTracking");
168        Ok(())
169    }
170
171    /// Stop tracking components
172    #[instrument(skip(self, to_remove))]
173    pub fn stop_tracking<'a, I: IntoIterator<Item = &'a String> + std::fmt::Debug>(
174        &mut self,
175        to_remove: I,
176    ) -> HashMap<String, ProtocolComponent> {
177        let mut n_components = 0;
178        let res = to_remove
179            .into_iter()
180            .filter_map(|k| {
181                let comp = self.components.remove(k);
182                if let Some(component) = &comp {
183                    n_components += 1;
184                    for contract in component.contract_ids.iter() {
185                        self.contracts.remove(contract);
186                    }
187                }
188                comp.map(|c| (k.clone(), c))
189            })
190            .collect();
191        debug!(n_components, "StoppedTracking");
192        res
193    }
194
195    pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
196        &self,
197        ids: I,
198    ) -> HashSet<Bytes> {
199        ids.into_iter()
200            .flat_map(|cid| {
201                let comp = self
202                    .components
203                    .get(cid)
204                    .unwrap_or_else(|| panic!("requested component that is not present: {cid}"));
205                comp.contract_ids.iter().cloned()
206            })
207            .collect()
208    }
209
210    pub fn get_tracked_component_ids(&self) -> Vec<String> {
211        self.components
212            .keys()
213            .cloned()
214            .collect()
215    }
216
217    /// Given BlockChanges, filter out components that are no longer relevant and return the
218    /// components that need to be added or removed.
219    pub fn filter_updated_components(&self, deltas: &BlockChanges) -> (Vec<String>, Vec<String>) {
220        match &self.filter.variant {
221            ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
222            ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
223                .component_tvl
224                .iter()
225                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
226                .map(|(id, _)| id.clone())
227                .partition(|id| deltas.component_tvl[id] > *add_tvl),
228        }
229    }
230}
231
#[cfg(test)]
mod test {
    use tycho_common::dto::{PaginationResponse, ProtocolComponentRequestResponse};

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Tracker for `uniswap-v2` on Ethereum, backed by a fresh mock RPC client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0);
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, MockRPCClient::new())
    }

    /// Component `Component1` made of two contracts, returned together with those contract ids.
    fn components_response() -> (Vec<Bytes>, ProtocolComponent) {
        let contracts = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contracts.clone(),
            ..Default::default()
        };
        (contracts, component)
    }

    /// Wraps `component` in a single-page RPC response.
    fn single_page(component: ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page(component.clone())));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &expected);
        let expected_contracts: HashSet<Bytes> = contract_ids.into_iter().collect();
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Bytes> = contract_ids.into_iter().collect();
        let id = component.id.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_page(component.clone())));

        tracker
            .start_tracking(&[&id])
            .await
            .expect("Tracking components failed");

        assert_eq!(tracker.contracts, expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let to_remove = ["Component1".to_string(), "Component2".to_string()];
        let expected: HashMap<String, ProtocolComponent> =
            HashMap::from([("Component1".to_string(), component)]);

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let requested = ["Component1".to_string()];

        let contracts = tracker.get_contracts_by_component(&requested);

        let expected: HashSet<Bytes> = contract_ids.into_iter().collect();
        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let ids = tracker.get_tracked_component_ids();

        assert_eq!(ids, vec!["Component1".to_string()]);
    }
}