1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{
10 rpc::{RPCClient, RPC_CLIENT_CONCURRENCY},
11 RPCError,
12};
13
/// Internal representation of a [`ComponentFilter`].
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Select components by id (ids are stored lowercased by `ComponentFilter::Ids`).
    Ids(Vec<ComponentId>),
    /// Select components by TVL, stored as `(remove_tvl_threshold, add_tvl_threshold)`.
    MinimumTVLRange((f64, f64)),
}
24
/// Selects which protocol components a [`ComponentTracker`] tracks.
///
/// Construct it with [`ComponentFilter::with_tvl_range`] or [`ComponentFilter::Ids`].
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    variant: ComponentFilterVariant,
}
29
30impl ComponentFilter {
31 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
40 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
41 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
42 }
43
44 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
59 ComponentFilter {
60 variant: ComponentFilterVariant::MinimumTVLRange((
61 remove_tvl_threshold,
62 add_tvl_threshold,
63 )),
64 }
65 }
66
67 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
76 ComponentFilter {
77 variant: ComponentFilterVariant::Ids(
78 ids.into_iter()
79 .map(|id| id.to_lowercase())
80 .collect(),
81 ),
82 }
83 }
84}
85
/// Components and contracts associated with a single entrypoint.
#[derive(Default)]
struct EntrypointRelations {
    /// Components linked to this entrypoint via `DCIUpdate::new_entrypoints`.
    components: HashSet<ComponentId>,
    /// Contracts detected for this entrypoint: the keys of `accessed_slots` in its trace
    /// results.
    contracts: HashSet<Address>,
}
95
/// Keeps track of the protocol components selected by a [`ComponentFilter`] and of the
/// contracts associated with them.
pub struct ComponentTracker<R: RPCClient> {
    /// Chain used in all component requests.
    chain: Chain,
    /// Protocol system name used in all component requests.
    protocol_system: ProtocolSystem,
    /// Decides which components are fetched initially and which TVL changes matter.
    filter: ComponentFilter,
    /// Currently tracked components, keyed by component id.
    pub components: HashMap<ComponentId, ProtocolComponent>,
    /// Entrypoints seen in DCI updates, keyed by the entrypoint's external id.
    entrypoints: HashMap<String, EntrypointRelations>,
    /// Contracts of the tracked components: their own `contract_ids` plus the contracts of
    /// entrypoints linked to at least one tracked component.
    pub contracts: HashSet<Address>,
    /// Client used to fetch protocol components.
    rpc_client: R,
}
112
113impl<R> ComponentTracker<R>
114where
115 R: RPCClient,
116{
117 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
118 Self {
119 chain,
120 protocol_system: protocol_system.to_string(),
121 filter,
122 components: Default::default(),
123 contracts: Default::default(),
124 rpc_client: rpc,
125 entrypoints: Default::default(),
126 }
127 }
128
129 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
132 let body = match &self.filter.variant {
133 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
134 &self.protocol_system,
135 ids.clone(),
136 self.chain,
137 ),
138 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
139 ProtocolComponentsRequestBody::system_filtered(
140 &self.protocol_system,
141 Some(*upper_tvl_threshold),
142 self.chain,
143 )
144 }
145 };
146 self.components = self
147 .rpc_client
148 .get_protocol_components_paginated(&body, None, RPC_CLIENT_CONCURRENCY)
149 .await?
150 .protocol_components
151 .into_iter()
152 .map(|pc| (pc.id.clone(), pc))
153 .collect::<HashMap<_, _>>();
154
155 self.reinitialize_contracts();
156
157 Ok(())
158 }
159
160 fn reinitialize_contracts(&mut self) {
162 self.contracts = self
164 .components
165 .values()
166 .flat_map(|comp| comp.contract_ids.iter().cloned())
167 .collect();
168
169 let tracked_component_ids = self
171 .components
172 .keys()
173 .cloned()
174 .collect::<HashSet<_>>();
175 for entrypoint in self.entrypoints.values() {
176 if !entrypoint
177 .components
178 .is_disjoint(&tracked_component_ids)
179 {
180 self.contracts
181 .extend(entrypoint.contracts.iter().cloned());
182 }
183 }
184 }
185
186 fn update_contracts(&mut self, components: Vec<ComponentId>) {
188 let mut tracked_component_ids = HashSet::new();
190
191 for comp in components {
193 if let Some(component) = self.components.get(&comp) {
194 self.contracts
195 .extend(component.contract_ids.iter().cloned());
196 tracked_component_ids.insert(comp);
197 }
198 }
199
200 for entrypoint in self.entrypoints.values() {
202 if !entrypoint
203 .components
204 .is_disjoint(&tracked_component_ids)
205 {
206 self.contracts
207 .extend(entrypoint.contracts.iter().cloned());
208 }
209 }
210 }
211
212 #[instrument(skip(self, new_components))]
214 pub async fn start_tracking(
215 &mut self,
216 new_components: &[&ComponentId],
217 ) -> Result<(), RPCError> {
218 if new_components.is_empty() {
219 return Ok(());
220 }
221
222 let request = ProtocolComponentsRequestBody::id_filtered(
224 &self.protocol_system,
225 new_components
226 .iter()
227 .map(|&id| id.to_string())
228 .collect(),
229 self.chain,
230 );
231 let components = self
232 .rpc_client
233 .get_protocol_components(&request)
234 .await?
235 .protocol_components
236 .into_iter()
237 .map(|pc| (pc.id.clone(), pc))
238 .collect::<HashMap<_, _>>();
239
240 let component_ids: Vec<_> = components.keys().cloned().collect();
242 let component_count = component_ids.len();
243 self.components.extend(components);
244 self.update_contracts(component_ids);
245
246 debug!(n_components = component_count, "StartedTracking");
247 Ok(())
248 }
249
250 #[instrument(skip(self, to_remove))]
252 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
253 &mut self,
254 to_remove: I,
255 ) -> HashMap<ComponentId, ProtocolComponent> {
256 let mut removed_components = HashMap::new();
257
258 for component_id in to_remove {
259 if let Some(component) = self.components.remove(component_id) {
260 removed_components.insert(component_id.clone(), component);
261 }
262 }
263
264 self.reinitialize_contracts();
267
268 debug!(n_components = removed_components.len(), "StoppedTracking");
269 removed_components
270 }
271
272 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
274 for (entrypoint, traces) in &dci_update.trace_results {
276 self.entrypoints
277 .entry(entrypoint.clone())
278 .or_default()
279 .contracts
280 .extend(traces.accessed_slots.keys().cloned());
281 }
282
283 for (component, entrypoints) in &dci_update.new_entrypoints {
285 for entrypoint in entrypoints {
286 let entrypoint_info = self
287 .entrypoints
288 .entry(entrypoint.external_id.clone())
289 .or_default();
290 entrypoint_info
291 .components
292 .insert(component.clone());
293 if self.components.contains_key(component) {
295 self.contracts.extend(
296 entrypoint_info
297 .contracts
298 .iter()
299 .cloned(),
300 );
301 }
302 }
303 }
304 }
305
306 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
317 &self,
318 ids: I,
319 ) -> HashSet<Address> {
320 ids.into_iter()
321 .filter_map(|cid| {
322 if let Some(comp) = self.components.get(cid) {
323 let dci_contracts: HashSet<Address> = self
325 .entrypoints
326 .values()
327 .filter(|ep| ep.components.contains(cid))
328 .flat_map(|ep| ep.contracts.iter().cloned())
329 .collect();
330 Some(
331 comp.contract_ids
332 .clone()
333 .into_iter()
334 .chain(dci_contracts)
335 .collect::<HashSet<_>>(),
336 )
337 } else {
338 warn!(
339 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
340 );
341 None
342 }
343 })
344 .flatten()
345 .collect()
346 }
347
348 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
349 self.components
350 .keys()
351 .cloned()
352 .collect()
353 }
354
355 pub fn filter_updated_components(
358 &self,
359 deltas: &BlockChanges,
360 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
361 match &self.filter.variant {
362 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
363 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
364 .component_tvl
365 .iter()
366 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
367 .map(|(id, _)| id.clone())
368 .partition(|id| deltas.component_tvl[id] > *add_tvl),
369 }
370 }
371}
372
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Builds an Ethereum `uniswap-v2` tracker backed by a fresh mock RPC client and a
    /// zero-width TVL range filter.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let mock = MockRPCClient::new();
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0);
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, mock)
    }

    /// Returns the component `Component1` together with the contract ids it references.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contracts = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contracts.clone(),
            ..Default::default()
        };
        (contracts, component)
    }

    /// Wraps a single component into a one-page RPC response.
    fn single_page(component: &ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component.clone()],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_component = component.clone();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_page(&component)));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked = tracker
            .components
            .get("Component1")
            .expect("Component1 not tracked");
        assert_eq!(tracked, &expected_component);
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let to_track = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_page(&component)));

        tracker
            .start_tracking(&to_track)
            .await
            .expect("Tracking components failed");

        assert_eq!(tracker.contracts, expected_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let to_remove = ["Component1".to_string(), "Component2".to_string()];
        let mut expected: HashMap<ComponentId, ProtocolComponent> = HashMap::new();
        expected.insert("Component1".to_string(), component);

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let requested = ["Component1".to_string()];
        let expected: HashSet<Address> = contract_ids.into_iter().collect();

        let contracts = tracker.get_contracts_by_component(&requested);

        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);

        let ids = tracker.get_tracked_component_ids();

        assert_eq!(ids, vec!["Component1".to_string()]);
    }
}