1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
11#[derive(Clone, Debug)]
12pub(crate) enum ComponentFilterVariant {
13 Ids(Vec<ComponentId>),
14 MinimumTVLRange((f64, f64)),
20}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24 variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73 ComponentFilter {
74 variant: ComponentFilterVariant::Ids(
75 ids.into_iter()
76 .map(|id| id.to_lowercase())
77 .collect(),
78 ),
79 }
80 }
81}
82
83#[derive(Default)]
86struct EntrypointRelations {
87 components: HashSet<ComponentId>,
89 contracts: HashSet<Address>,
91}
92
93pub struct ComponentTracker<R: RPCClient> {
95 chain: Chain,
96 protocol_system: ProtocolSystem,
97 filter: ComponentFilter,
98 pub components: HashMap<ComponentId, ProtocolComponent>,
101 entrypoints: HashMap<String, EntrypointRelations>,
103 pub contracts: HashSet<Address>,
106 rpc_client: R,
108}
109
110impl<R> ComponentTracker<R>
111where
112 R: RPCClient,
113{
114 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115 Self {
116 chain,
117 protocol_system: protocol_system.to_string(),
118 filter,
119 components: Default::default(),
120 contracts: Default::default(),
121 rpc_client: rpc,
122 entrypoints: Default::default(),
123 }
124 }
125
126 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129 let body = match &self.filter.variant {
130 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131 &self.protocol_system,
132 ids.clone(),
133 self.chain,
134 ),
135 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136 ProtocolComponentsRequestBody::system_filtered(
137 &self.protocol_system,
138 Some(*upper_tvl_threshold),
139 self.chain,
140 )
141 }
142 };
143 self.components = self
144 .rpc_client
145 .get_protocol_components_paginated(&body, 500, 4)
146 .await?
147 .protocol_components
148 .into_iter()
149 .map(|pc| (pc.id.clone(), pc))
150 .collect::<HashMap<_, _>>();
151
152 self.reinitialize_contracts();
153
154 Ok(())
155 }
156
157 fn reinitialize_contracts(&mut self) {
159 self.contracts = self
161 .components
162 .values()
163 .flat_map(|comp| comp.contract_ids.iter().cloned())
164 .collect();
165
166 let tracked_component_ids = self
168 .components
169 .keys()
170 .cloned()
171 .collect::<HashSet<_>>();
172 for entrypoint in self.entrypoints.values() {
173 if !entrypoint
174 .components
175 .is_disjoint(&tracked_component_ids)
176 {
177 self.contracts
178 .extend(entrypoint.contracts.iter().cloned());
179 }
180 }
181 }
182
183 fn update_contracts(&mut self, components: Vec<ComponentId>) {
185 let tracked_component_ids = components
188 .into_iter()
189 .filter(|id| self.components.contains_key(id))
190 .collect::<HashSet<_>>();
191
192 for comp in &tracked_component_ids {
194 let component = self
195 .components
196 .get(comp)
197 .expect("Component should exist as it was filtered above");
198 self.contracts
199 .extend(component.contract_ids.iter().cloned());
200 }
201
202 for entrypoint in self.entrypoints.values() {
204 if !entrypoint
205 .components
206 .is_disjoint(&tracked_component_ids)
207 {
208 self.contracts
209 .extend(entrypoint.contracts.iter().cloned());
210 }
211 }
212 }
213
214 #[instrument(skip(self, new_components))]
216 pub async fn start_tracking(
217 &mut self,
218 new_components: &[&ComponentId],
219 ) -> Result<(), RPCError> {
220 if new_components.is_empty() {
221 return Ok(());
222 }
223
224 let request = ProtocolComponentsRequestBody::id_filtered(
226 &self.protocol_system,
227 new_components
228 .iter()
229 .map(|&id| id.to_string())
230 .collect(),
231 self.chain,
232 );
233 let components = self
234 .rpc_client
235 .get_protocol_components(&request)
236 .await?
237 .protocol_components
238 .into_iter()
239 .map(|pc| (pc.id.clone(), pc))
240 .collect::<HashMap<_, _>>();
241
242 let component_ids: Vec<_> = components.keys().cloned().collect();
244 let component_count = component_ids.len();
245 self.components.extend(components);
246 self.update_contracts(component_ids);
247
248 debug!(n_components = component_count, "StartedTracking");
249 Ok(())
250 }
251
252 #[instrument(skip(self, to_remove))]
254 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
255 &mut self,
256 to_remove: I,
257 ) -> HashMap<ComponentId, ProtocolComponent> {
258 let mut removed_components = HashMap::new();
259
260 for component_id in to_remove {
261 if let Some(component) = self.components.remove(component_id) {
262 removed_components.insert(component_id.clone(), component);
263 }
264 }
265
266 self.reinitialize_contracts();
269
270 debug!(n_components = removed_components.len(), "StoppedTracking");
271 removed_components
272 }
273
274 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) -> Result<(), RPCError> {
276 for (entrypoint, traces) in &dci_update.trace_results {
278 self.entrypoints
279 .entry(entrypoint.clone())
280 .or_default()
281 .contracts
282 .extend(traces.accessed_slots.keys().cloned());
283 }
284
285 for (component, entrypoints) in &dci_update.new_entrypoints {
287 for entrypoint in entrypoints {
288 let entrypoint_info = self
289 .entrypoints
290 .entry(entrypoint.external_id.clone())
291 .or_default();
292 entrypoint_info
293 .components
294 .insert(component.clone());
295 if self.components.contains_key(component) {
297 self.contracts.extend(
298 entrypoint_info
299 .contracts
300 .iter()
301 .cloned(),
302 );
303 }
304 }
305 }
306
307 Ok(())
308 }
309
310 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
321 &self,
322 ids: I,
323 ) -> HashSet<Address> {
324 ids.into_iter()
325 .filter_map(|cid| {
326 if let Some(comp) = self.components.get(cid) {
327 Some(comp.contract_ids.clone())
328 } else {
329 warn!(
330 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
331 );
332 None
333 }
334 })
335 .flatten()
336 .collect()
337 }
338
339 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
340 self.components
341 .keys()
342 .cloned()
343 .collect()
344 }
345
346 pub fn filter_updated_components(
349 &self,
350 deltas: &BlockChanges,
351 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
352 match &self.filter.variant {
353 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
354 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
355 .component_tvl
356 .iter()
357 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
358 .map(|(id, _)| id.clone())
359 .partition(|id| deltas.component_tvl[id] > *add_tvl),
360 }
361 }
362}
363
364#[cfg(test)]
365mod test {
366 use tycho_common::{
367 dto::{PaginationResponse, ProtocolComponentRequestResponse},
368 Bytes,
369 };
370
371 use super::*;
372 use crate::rpc::MockRPCClient;
373
374 fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
375 let rpc = MockRPCClient::new();
376 ComponentTracker::new(
377 Chain::Ethereum,
378 "uniswap-v2",
379 ComponentFilter::with_tvl_range(0.0, 0.0),
380 rpc,
381 )
382 }
383
384 fn components_response() -> (Vec<Address>, ProtocolComponent) {
385 let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
386 let component = ProtocolComponent {
387 id: "Component1".to_string(),
388 contract_ids: contract_ids.clone(),
389 ..Default::default()
390 };
391 (contract_ids, component)
392 }
393
394 #[tokio::test]
395 async fn test_initialise_components() {
396 let mut tracker = with_mocked_rpc();
397 let (contract_ids, component) = components_response();
398 let exp_component = component.clone();
399 tracker
400 .rpc_client
401 .expect_get_protocol_components_paginated()
402 .returning(move |_, _, _| {
403 Ok(ProtocolComponentRequestResponse {
404 protocol_components: vec![component.clone()],
405 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
406 })
407 });
408
409 tracker
410 .initialise_components()
411 .await
412 .expect("Retrieving components failed");
413
414 assert_eq!(
415 tracker
416 .components
417 .get("Component1")
418 .expect("Component1 not tracked"),
419 &exp_component
420 );
421 assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
422 }
423
424 #[tokio::test]
425 async fn test_start_tracking() {
426 let mut tracker = with_mocked_rpc();
427 let (contract_ids, component) = components_response();
428 let exp_contracts = contract_ids.into_iter().collect();
429 let component_id = component.id.clone();
430 let components_arg = [&component_id];
431 tracker
432 .rpc_client
433 .expect_get_protocol_components()
434 .returning(move |_| {
435 Ok(ProtocolComponentRequestResponse {
436 protocol_components: vec![component.clone()],
437 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
438 })
439 });
440
441 tracker
442 .start_tracking(&components_arg)
443 .await
444 .expect("Tracking components failed");
445
446 assert_eq!(&tracker.contracts, &exp_contracts);
447 assert!(tracker
448 .components
449 .contains_key("Component1"));
450 }
451
452 #[test]
453 fn test_stop_tracking() {
454 let mut tracker = with_mocked_rpc();
455 let (contract_ids, component) = components_response();
456 tracker
457 .components
458 .insert("Component1".to_string(), component.clone());
459 tracker.contracts.extend(contract_ids);
460 let components_arg = ["Component1".to_string(), "Component2".to_string()];
461 let exp = [("Component1".to_string(), component)]
462 .into_iter()
463 .collect();
464
465 let res = tracker.stop_tracking(&components_arg);
466
467 assert_eq!(res, exp);
468 assert!(tracker.contracts.is_empty());
469 }
470
471 #[test]
472 fn test_get_contracts_by_component() {
473 let mut tracker = with_mocked_rpc();
474 let (exp_contracts, component) = components_response();
475 tracker
476 .components
477 .insert("Component1".to_string(), component);
478 let components_arg = ["Component1".to_string()];
479
480 let res = tracker.get_contracts_by_component(&components_arg);
481
482 assert_eq!(res, exp_contracts.into_iter().collect());
483 }
484
485 #[test]
486 fn test_get_tracked_component_ids() {
487 let mut tracker = with_mocked_rpc();
488 let (_, component) = components_response();
489 tracker
490 .components
491 .insert("Component1".to_string(), component);
492 let exp = vec!["Component1".to_string()];
493
494 let res = tracker.get_tracked_component_ids();
495
496 assert_eq!(res, exp);
497 }
498}