1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
11#[derive(Clone, Debug)]
12pub(crate) enum ComponentFilterVariant {
13 Ids(Vec<ComponentId>),
14 MinimumTVLRange((f64, f64)),
20}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24 variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73 ComponentFilter {
74 variant: ComponentFilterVariant::Ids(
75 ids.into_iter()
76 .map(|id| id.to_lowercase())
77 .collect(),
78 ),
79 }
80 }
81}
82
83#[derive(Default)]
86struct EntrypointRelations {
87 components: HashSet<ComponentId>,
89 contracts: HashSet<Address>,
91}
92
93pub struct ComponentTracker<R: RPCClient> {
95 chain: Chain,
96 protocol_system: ProtocolSystem,
97 filter: ComponentFilter,
98 pub components: HashMap<ComponentId, ProtocolComponent>,
101 entrypoints: HashMap<String, EntrypointRelations>,
103 pub contracts: HashSet<Address>,
106 rpc_client: R,
108}
109
110impl<R> ComponentTracker<R>
111where
112 R: RPCClient,
113{
114 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115 Self {
116 chain,
117 protocol_system: protocol_system.to_string(),
118 filter,
119 components: Default::default(),
120 contracts: Default::default(),
121 rpc_client: rpc,
122 entrypoints: Default::default(),
123 }
124 }
125
126 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129 let body = match &self.filter.variant {
130 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131 &self.protocol_system,
132 ids.clone(),
133 self.chain,
134 ),
135 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136 ProtocolComponentsRequestBody::system_filtered(
137 &self.protocol_system,
138 Some(*upper_tvl_threshold),
139 self.chain,
140 )
141 }
142 };
143 self.components = self
144 .rpc_client
145 .get_protocol_components_paginated(&body, 500, 4)
146 .await?
147 .protocol_components
148 .into_iter()
149 .map(|pc| (pc.id.clone(), pc))
150 .collect::<HashMap<_, _>>();
151
152 self.reinitialize_contracts();
153
154 Ok(())
155 }
156
157 fn reinitialize_contracts(&mut self) {
159 self.contracts = self
161 .components
162 .values()
163 .flat_map(|comp| comp.contract_ids.iter().cloned())
164 .collect();
165
166 let tracked_component_ids = self
168 .components
169 .keys()
170 .cloned()
171 .collect::<HashSet<_>>();
172 for entrypoint in self.entrypoints.values() {
173 if !entrypoint
174 .components
175 .is_disjoint(&tracked_component_ids)
176 {
177 self.contracts
178 .extend(entrypoint.contracts.iter().cloned());
179 }
180 }
181 }
182
183 fn update_contracts(&mut self, components: Vec<ComponentId>) {
185 let tracked_component_ids = components
188 .into_iter()
189 .filter(|id| self.components.contains_key(id))
190 .collect::<HashSet<_>>();
191
192 for comp in &tracked_component_ids {
194 let component = self
195 .components
196 .get(comp)
197 .expect("Component should exist as it was filtered above");
198 self.contracts
199 .extend(component.contract_ids.iter().cloned());
200 }
201
202 for entrypoint in self.entrypoints.values() {
204 if !entrypoint
205 .components
206 .is_disjoint(&tracked_component_ids)
207 {
208 self.contracts
209 .extend(entrypoint.contracts.iter().cloned());
210 }
211 }
212 }
213
214 #[instrument(skip(self, new_components))]
216 pub async fn start_tracking(
217 &mut self,
218 new_components: &[&ComponentId],
219 ) -> Result<(), RPCError> {
220 if new_components.is_empty() {
221 return Ok(());
222 }
223
224 let request = ProtocolComponentsRequestBody::id_filtered(
226 &self.protocol_system,
227 new_components
228 .iter()
229 .map(|&id| id.to_string())
230 .collect(),
231 self.chain,
232 );
233 let components = self
234 .rpc_client
235 .get_protocol_components(&request)
236 .await?
237 .protocol_components
238 .into_iter()
239 .map(|pc| (pc.id.clone(), pc))
240 .collect::<HashMap<_, _>>();
241
242 let component_ids: Vec<_> = components.keys().cloned().collect();
244 let component_count = component_ids.len();
245 self.components.extend(components);
246 self.update_contracts(component_ids);
247
248 debug!(n_components = component_count, "StartedTracking");
249 Ok(())
250 }
251
252 #[instrument(skip(self, to_remove))]
254 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
255 &mut self,
256 to_remove: I,
257 ) -> HashMap<ComponentId, ProtocolComponent> {
258 let mut removed_components = HashMap::new();
259
260 for component_id in to_remove {
261 if let Some(component) = self.components.remove(component_id) {
262 removed_components.insert(component_id.clone(), component);
263 }
264 }
265
266 self.reinitialize_contracts();
269
270 debug!(n_components = removed_components.len(), "StoppedTracking");
271 removed_components
272 }
273
274 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) -> Result<(), RPCError> {
276 for (entrypoint, traces) in &dci_update.trace_results {
278 self.entrypoints
279 .entry(entrypoint.clone())
280 .or_default()
281 .contracts
282 .extend(traces.accessed_slots.keys().cloned());
283 }
284
285 for (component, entrypoints) in &dci_update.new_entrypoints {
287 for entrypoint in entrypoints {
288 let entrypoint_info = self
289 .entrypoints
290 .entry(entrypoint.external_id.clone())
291 .or_default();
292 entrypoint_info
293 .components
294 .insert(component.clone());
295 if self.components.contains_key(component) {
297 self.contracts.extend(
298 entrypoint_info
299 .contracts
300 .iter()
301 .cloned(),
302 );
303 }
304 }
305 }
306
307 Ok(())
308 }
309
310 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
321 &self,
322 ids: I,
323 ) -> HashSet<Address> {
324 ids.into_iter()
325 .filter_map(|cid| {
326 if let Some(comp) = self.components.get(cid) {
327 let dci_contracts: HashSet<Address> = self
329 .entrypoints
330 .values()
331 .filter(|ep| ep.components.contains(cid))
332 .flat_map(|ep| ep.contracts.iter().cloned())
333 .collect();
334 Some(
335 comp.contract_ids
336 .clone()
337 .into_iter()
338 .chain(dci_contracts)
339 .collect::<HashSet<_>>(),
340 )
341 } else {
342 warn!(
343 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
344 );
345 None
346 }
347 })
348 .flatten()
349 .collect()
350 }
351
352 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
353 self.components
354 .keys()
355 .cloned()
356 .collect()
357 }
358
359 pub fn filter_updated_components(
362 &self,
363 deltas: &BlockChanges,
364 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
365 match &self.filter.variant {
366 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
367 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
368 .component_tvl
369 .iter()
370 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
371 .map(|(id, _)| id.clone())
372 .partition(|id| deltas.component_tvl[id] > *add_tvl),
373 }
374 }
375}
376
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Returns a `uniswap-v2` tracker on Ethereum with a zero TVL range and a fresh mock client.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            MockRPCClient::new(),
        )
    }

    /// Returns the component `Component1` together with the contract ids it references.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contracts = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contracts.clone(),
            ..Default::default()
        };
        (contracts, component)
    }

    /// Wraps `component` in a one-item, single-page RPC response.
    fn single_page(component: ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page(component.clone())));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &expected_component);
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_page(component.clone())));

        let requested = [&component_id];
        tracker
            .start_tracking(&requested)
            .await
            .expect("Tracking components failed");

        assert_eq!(tracker.contracts, expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let to_remove = ["Component1".to_string(), "Component2".to_string()];
        let expected: HashMap<_, _> = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let requested = ["Component1".to_string()];

        let contracts = tracker.get_contracts_by_component(&requested);

        let expected: HashSet<Address> = contract_ids.into_iter().collect();
        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let ids = tracker.get_tracked_component_ids();

        assert_eq!(ids, vec!["Component1".to_string()]);
    }
}
511}